git.madduck.net Git - etc/taskwarrior.git/blobdiff - tasklib/task.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Task: Stop before marking as done with older TW versions
[etc/taskwarrior.git] / tasklib / task.py
index e0e10409451cfeb5d2c38e72453cf980b4b18265..1121988249f3745264f3c588a69c9cee6517fd03 100644 (file)
@@ -22,6 +22,7 @@ VERSION_2_3_0 = six.u('2.3.0')
 VERSION_2_4_0 = six.u('2.4.0')
 VERSION_2_4_1 = six.u('2.4.1')
 VERSION_2_4_2 = six.u('2.4.2')
+VERSION_2_4_3 = six.u('2.4.3')
 
 logger = logging.getLogger(__name__)
 local_zone = tzlocal.get_localzone()
@@ -262,8 +263,10 @@ class SerializingObject(object):
                 # If the value is already localized, there is no need to change
                 # time zone at this point. Also None is a valid value too.
                 localized = value
-        elif isinstance(value, six.string_types):
+        elif (isinstance(value, six.string_types)
+                and self.warrior.version >= VERSION_2_4_0):
             # For strings, use 'task calc' to evaluate the string to datetime
+            # available since TW 2.4.0
             args = value.split()
             result = self.warrior.execute_command(['calc'] + args)
             naive = datetime.datetime.strptime(result[0], DATE_FORMAT_CALC)
@@ -294,7 +297,7 @@ class TaskResource(SerializingObject):
         # are not propagated.
         self._original_data = copy.deepcopy(self._data)
 
-    def _update_data(self, data, update_original=False):
+    def _update_data(self, data, update_original=False, remove_missing=False):
         """
         Low level update of the internal _data dict. Data which are coming as
         updates should already be serialized. If update_original is True, the
@@ -303,6 +306,11 @@ class TaskResource(SerializingObject):
         self._data.update(dict((key, self._deserialize(key, value))
                                for key, value in data.items()))
 
+        # In certain situations, we want to treat missing keys as removals
+        if remove_missing:
+            for key in set(self._data.keys()) - set(data.keys()):
+                self._data[key] = None
+
         if update_original:
             self._original_data = copy.deepcopy(self._data)
 
@@ -413,6 +421,12 @@ class Task(TaskResource):
         """
         pass
 
+    class InactiveTask(Exception):
+        """
+        Raised when the operation cannot be performed on an inactive task.
+        """
+        pass
+
     class NotSaved(Exception):
         """
         Raised when the operation cannot be performed on the task, because
@@ -455,7 +469,8 @@ class Task(TaskResource):
         # If this is a on-modify event, we are provided with additional
         # line of input, which provides updated data
         if modify:
-            task._update_data(json.loads(input_file.readline().strip()))
+            task._update_data(json.loads(input_file.readline().strip()),
+                              remove_missing=True)
 
         return task
 
@@ -515,6 +530,10 @@ class Task(TaskResource):
     def pending(self):
         return self['status'] == six.text_type('pending')
 
+    @property
+    def active(self):
+        return self['start'] is not None
+
     @property
     def saved(self):
         return self['uuid'] is not None or self['id'] is not None
@@ -553,7 +572,7 @@ class Task(TaskResource):
         if self.warrior.version < VERSION_2_4_0:
             return self._data['description']
         else:
-            return "description:'{0}'".format(self._data['description'] or '')
+            return six.u("description:'{0}'").format(self._data['description'] or '')
 
     def delete(self):
         if not self.saved:
@@ -587,6 +606,21 @@ class Task(TaskResource):
         # Refresh the status again, so that we have updated info stored
         self.refresh(only_fields=['status', 'start'])
 
+    def stop(self):
+        if not self.saved:
+            raise Task.NotSaved("Task needs to be saved before it can be stopped")
+
+        # Refresh, and raise exception if task is already completed/deleted
+        self.refresh(only_fields=['status'])
+
+        if not self.active:
+            raise Task.InactiveTask("Cannot stop an inactive task")
+
+        self.warrior.execute_command([self['uuid'], 'stop'])
+
+        # Refresh the status again, so that we have updated info stored
+        self.refresh(only_fields=['status', 'start'])
+
     def done(self):
         if not self.saved:
             raise Task.NotSaved("Task needs to be saved before it can be completed")
@@ -599,6 +633,10 @@ class Task(TaskResource):
         elif self.deleted:
             raise Task.DeletedTask("Deleted task cannot be completed")
 
+        # Older versions of TW do not stop active task at completion
+        if self.warrior.version < VERSION_2_4_0 and self.active:
+            self.stop()
+
         self.warrior.execute_command([self['uuid'], 'done'])
 
         # Refresh the status again, so that we have updated info stored
@@ -661,9 +699,9 @@ class Task(TaskResource):
             if serialized_value is '':
                 escaped_serialized_value = ''
             else:
-                escaped_serialized_value = "'{0}'".format(serialized_value)
+                escaped_serialized_value = six.u("'{0}'").format(serialized_value)
 
-            format_default = lambda: "{0}:{1}".format(field,
+            format_default = lambda: six.u("{0}:{1}").format(field,
                                                       escaped_serialized_value)
 
             format_func = getattr(self, 'format_{0}'.format(field),
@@ -737,7 +775,7 @@ class TaskFilter(SerializingObject):
             modifier = '.is' if value else '.none'
             key = key + modifier if '.' not in key else key
 
-            self.filter_params.append("{0}:{1}".format(key, value))
+            self.filter_params.append(six.u("{0}:{1}").format(key, value))
 
     def get_filter_params(self):
         return [f for f in self.filter_params if f]
@@ -859,8 +897,7 @@ class TaskQuerySet(object):
 
 
 class TaskWarrior(object):
-    def __init__(self, data_location='~/.task', create=True, taskrc_location='~/.taskrc'):
-        data_location = os.path.expanduser(data_location)
+    def __init__(self, data_location=None, create=True, taskrc_location='~/.taskrc'):
         self.taskrc_location = os.path.expanduser(taskrc_location)
 
         # If taskrc does not exist, pass / to use defaults and avoid creating
@@ -868,17 +905,24 @@ class TaskWarrior(object):
         if not os.path.exists(self.taskrc_location):
             self.taskrc_location = '/'
 
-        if create and not os.path.exists(data_location):
-            os.makedirs(data_location)
-
+        self.version = self._get_version()
         self.config = {
-            'data.location': data_location,
             'confirmation': 'no',
             'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
             'recurrence.confirmation': 'no',  # Necessary for modifying R tasks
+            # 2.4.3 onwards supports 0 as infite bulk, otherwise set just
+            # arbitrary big number which is likely to be large enough
+            'bulk': 0 if self.version >= VERSION_2_4_3 else 100000,
         }
+
+        # Set data.location override if passed via kwarg
+        if data_location is not None:
+            data_location = os.path.expanduser(data_location)
+            if create and not os.path.exists(data_location):
+                os.makedirs(data_location)
+            self.config['data.location'] = data_location
+
         self.tasks = TaskQuerySet(self)
-        self.version = self._get_version()
 
     def _get_command_args(self, args, config_override={}):
         command_args = ['task', 'rc:{0}'.format(self.taskrc_location)]
@@ -886,7 +930,7 @@ class TaskWarrior(object):
         config.update(config_override)
         for item in config.items():
             command_args.append('rc.{0}={1}'.format(*item))
-        command_args.extend(map(str, args))
+        command_args.extend(map(six.text_type, args))
         return command_args
 
     def _get_version(self):
@@ -897,7 +941,8 @@ class TaskWarrior(object):
         stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
         return stdout.strip('\n')
 
-    def execute_command(self, args, config_override={}, allow_failure=True):
+    def execute_command(self, args, config_override={}, allow_failure=True,
+                        return_all=False):
         command_args = self._get_command_args(
             args, config_override=config_override)
         logger.debug(' '.join(command_args))
@@ -910,7 +955,14 @@ class TaskWarrior(object):
             else:
                 error_msg = stdout.strip()
             raise TaskWarriorException(error_msg)
-        return stdout.strip().split('\n')
+
+        # Return all whole triplet only if explicitly asked for
+        if not return_all:
+            return stdout.rstrip().split('\n')
+        else:
+            return (stdout.rstrip().split('\n'),
+                    stderr.rstrip().split('\n'),
+                    p.returncode)
 
     def enforce_recurrence(self):
         # Run arbitrary report command which will trigger generation