git.madduck.net Git - etc/taskwarrior.git/blobdiff - tasklib/task.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Task: Stop before marking as done with older TW versions
[etc/taskwarrior.git] / tasklib / task.py
index f070e6b3b8bd103ecad6d00a04123b84d524402c..1121988249f3745264f3c588a69c9cee6517fd03 100644 (file)
@@ -297,7 +297,7 @@ class TaskResource(SerializingObject):
         # are not propagated.
         self._original_data = copy.deepcopy(self._data)
 
-    def _update_data(self, data, update_original=False):
+    def _update_data(self, data, update_original=False, remove_missing=False):
         """
         Low level update of the internal _data dict. Data which are coming as
         updates should already be serialized. If update_original is True, the
@@ -306,6 +306,11 @@ class TaskResource(SerializingObject):
         self._data.update(dict((key, self._deserialize(key, value))
                                for key, value in data.items()))
 
+        # In certain situations, we want to treat missing keys as removals
+        if remove_missing:
+            for key in set(self._data.keys()) - set(data.keys()):
+                self._data[key] = None
+
         if update_original:
             self._original_data = copy.deepcopy(self._data)
 
@@ -416,6 +421,12 @@ class Task(TaskResource):
         """
         pass
 
+    class InactiveTask(Exception):
+        """
+        Raised when the operation cannot be performed on an inactive task.
+        """
+        pass
+
     class NotSaved(Exception):
         """
         Raised when the operation cannot be performed on the task, because
@@ -458,7 +469,8 @@ class Task(TaskResource):
         # If this is a on-modify event, we are provided with additional
         # line of input, which provides updated data
         if modify:
-            task._update_data(json.loads(input_file.readline().strip()))
+            task._update_data(json.loads(input_file.readline().strip()),
+                              remove_missing=True)
 
         return task
 
@@ -560,7 +572,7 @@ class Task(TaskResource):
         if self.warrior.version < VERSION_2_4_0:
             return self._data['description']
         else:
-            return "description:'{0}'".format(self._data['description'] or '')
+            return six.u("description:'{0}'").format(self._data['description'] or '')
 
     def delete(self):
         if not self.saved:
@@ -594,6 +606,21 @@ class Task(TaskResource):
         # Refresh the status again, so that we have updated info stored
         self.refresh(only_fields=['status', 'start'])
 
+    def stop(self):
+        if not self.saved:
+            raise Task.NotSaved("Task needs to be saved before it can be stopped")
+
+        # Refresh, and raise exception if task is already completed/deleted
+        self.refresh(only_fields=['status'])
+
+        if not self.active:
+            raise Task.InactiveTask("Cannot stop an inactive task")
+
+        self.warrior.execute_command([self['uuid'], 'stop'])
+
+        # Refresh the status again, so that we have updated info stored
+        self.refresh(only_fields=['status', 'start'])
+
     def done(self):
         if not self.saved:
             raise Task.NotSaved("Task needs to be saved before it can be completed")
@@ -606,6 +633,10 @@ class Task(TaskResource):
         elif self.deleted:
             raise Task.DeletedTask("Deleted task cannot be completed")
 
+        # Older versions of TW do not stop active task at completion
+        if self.warrior.version < VERSION_2_4_0 and self.active:
+            self.stop()
+
         self.warrior.execute_command([self['uuid'], 'done'])
 
         # Refresh the status again, so that we have updated info stored
@@ -668,9 +699,9 @@ class Task(TaskResource):
             if serialized_value is '':
                 escaped_serialized_value = ''
             else:
-                escaped_serialized_value = "'{0}'".format(serialized_value)
+                escaped_serialized_value = six.u("'{0}'").format(serialized_value)
 
-            format_default = lambda: "{0}:{1}".format(field,
+            format_default = lambda: six.u("{0}:{1}").format(field,
                                                       escaped_serialized_value)
 
             format_func = getattr(self, 'format_{0}'.format(field),
@@ -744,7 +775,7 @@ class TaskFilter(SerializingObject):
             modifier = '.is' if value else '.none'
             key = key + modifier if '.' not in key else key
 
-            self.filter_params.append("{0}:{1}".format(key, value))
+            self.filter_params.append(six.u("{0}:{1}").format(key, value))
 
     def get_filter_params(self):
         return [f for f in self.filter_params if f]
@@ -881,7 +912,7 @@ class TaskWarrior(object):
             'recurrence.confirmation': 'no',  # Necessary for modifying R tasks
             # 2.4.3 onwards supports 0 as infite bulk, otherwise set just
             # arbitrary big number which is likely to be large enough
-            'bulk': 0 if self.version > VERSION_2_4_3 else 100000,
+            'bulk': 0 if self.version >= VERSION_2_4_3 else 100000,
         }
 
         # Set data.location override if passed via kwarg
@@ -899,7 +930,7 @@ class TaskWarrior(object):
         config.update(config_override)
         for item in config.items():
             command_args.append('rc.{0}={1}'.format(*item))
-        command_args.extend(map(str, args))
+        command_args.extend(map(six.text_type, args))
         return command_args
 
     def _get_version(self):
@@ -910,7 +941,8 @@ class TaskWarrior(object):
         stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
         return stdout.strip('\n')
 
-    def execute_command(self, args, config_override={}, allow_failure=True):
+    def execute_command(self, args, config_override={}, allow_failure=True,
+                        return_all=False):
         command_args = self._get_command_args(
             args, config_override=config_override)
         logger.debug(' '.join(command_args))
@@ -923,7 +955,14 @@ class TaskWarrior(object):
             else:
                 error_msg = stdout.strip()
             raise TaskWarriorException(error_msg)
-        return stdout.rstrip().split('\n')
+
+        # Return all whole triplet only if explicitly asked for
+        if not return_all:
+            return stdout.rstrip().split('\n')
+        else:
+            return (stdout.rstrip().split('\n'),
+                    stderr.rstrip().split('\n'),
+                    p.returncode)
 
     def enforce_recurrence(self):
         # Run arbitrary report command which will trigger generation