X-Git-Url: https://git.madduck.net/etc/taskwarrior.git/blobdiff_plain/ae7397e43578a8adcff05714ea6aa5960af13145..5c5b35a097608b63c3d54acf2f07d5cb00e85796:/tasklib/task.py?ds=inline

diff --git a/tasklib/task.py b/tasklib/task.py
index ca21b2b..b7bba4b 100644
--- a/tasklib/task.py
+++ b/tasklib/task.py
@@ -7,25 +7,49 @@ import os
 import pytz
 import six
 import sys
-import subprocess
 import tzlocal
 
+from backends import TaskWarrior, TaskWarriorException
+
 DATE_FORMAT = '%Y%m%dT%H%M%SZ'
+DATE_FORMAT_CALC = '%Y-%m-%dT%H:%M:%S'
 REPR_OUTPUT_SIZE = 10
 PENDING = 'pending'
 COMPLETED = 'completed'
 
-VERSION_2_1_0 = six.u('2.1.0')
-VERSION_2_2_0 = six.u('2.2.0')
-VERSION_2_3_0 = six.u('2.3.0')
-VERSION_2_4_0 = six.u('2.4.0')
-
 logger = logging.getLogger(__name__)
 local_zone = tzlocal.get_localzone()
 
 
-class TaskWarriorException(Exception):
-    pass
+class ReadOnlyDictView(object):
+    """
+    Provides a simplified read-only view of a dict object.
+    """
+
+    def __init__(self, viewed_dict):
+        self.viewed_dict = viewed_dict
+
+    def __getitem__(self, key):
+        return copy.deepcopy(self.viewed_dict.__getitem__(key))
+
+    def __contains__(self, k):
+        return self.viewed_dict.__contains__(k)
+
+    def __iter__(self):
+        for value in self.viewed_dict:
+            yield copy.deepcopy(value)
+
+    def __len__(self):
+        return len(self.viewed_dict)
+
+    def get(self, key, default=None):
+        return copy.deepcopy(self.viewed_dict.get(key, default))
+
+    def items(self):
+        return [copy.deepcopy(v) for v in self.viewed_dict.items()]
+
+    def values(self):
+        return [copy.deepcopy(v) for v in self.viewed_dict.values()]
 
 
 class SerializingObject(object):
@@ -44,8 +68,20 @@ class SerializingObject(object):
         not export empty-valued attributes) if the attribute
         is not iterable (e.g. list or set), in which case
         a empty iterable should be used.
+
+    Normalizing methods should hold the following contract:
+      - They are used to validate and normalize the user input.
+        Any attribute value that comes from the user (during Task
+        initialization, assigning values to Task attributes, or
+        filtering by user-provided values of attributes) is first
+        validated and normalized using the normalize_{key} method.
+      - If validation or normalization fails, the normalizer is expected
+        to raise ValueError.
     """
 
+    def __init__(self, warrior):
+        self.warrior = warrior
+
     def _deserialize(self, key, value):
         hydrate_func = getattr(self, 'deserialize_{0}'.format(key),
                                lambda x: x if x != '' else None)
@@ -63,6 +99,10 @@ class SerializingObject(object):
         or entered as a value of Task attribute.
         """
 
+        # None value should not be converted by normalizer
+        if value is None:
+            return None
+
         normalize_func = getattr(self, 'normalize_{0}'.format(key),
                                  lambda x: x)
 
@@ -93,36 +133,72 @@ class SerializingObject(object):
     def deserialize_entry(self, value):
         return self.timestamp_deserializer(value)
 
+    def normalize_entry(self, value):
+        return self.datetime_normalizer(value)
+
     def serialize_modified(self, value):
         return self.timestamp_serializer(value)
 
     def deserialize_modified(self, value):
         return self.timestamp_deserializer(value)
 
+    def normalize_modified(self, value):
+        return self.datetime_normalizer(value)
+
+    def serialize_start(self, value):
+        return self.timestamp_serializer(value)
+
+    def deserialize_start(self, value):
+        return self.timestamp_deserializer(value)
+
+    def normalize_start(self, value):
+        return self.datetime_normalizer(value)
+
+    def serialize_end(self, value):
+        return self.timestamp_serializer(value)
+
+    def deserialize_end(self, value):
+        return self.timestamp_deserializer(value)
+
+    def normalize_end(self, value):
+        return self.datetime_normalizer(value)
+
     def serialize_due(self, value):
         return self.timestamp_serializer(value)
 
     def deserialize_due(self, value):
         return self.timestamp_deserializer(value)
 
+    def normalize_due(self, value):
+        return self.datetime_normalizer(value)
+
     def serialize_scheduled(self, value):
         return self.timestamp_serializer(value)
 
     def deserialize_scheduled(self, value):
         return self.timestamp_deserializer(value)
 
+    def normalize_scheduled(self, value):
+        return self.datetime_normalizer(value)
+
     def serialize_until(self, value):
         return self.timestamp_serializer(value)
 
     def deserialize_until(self, value):
         return self.timestamp_deserializer(value)
 
+    def normalize_until(self, value):
+        return self.datetime_normalizer(value)
+
     def serialize_wait(self, value):
         return self.timestamp_serializer(value)
 
     def deserialize_wait(self, value):
         return self.timestamp_deserializer(value)
 
+    def normalize_wait(self, value):
+        return self.datetime_normalizer(value)
+
     def serialize_annotations(self, value):
         value = value if value is not None else []
 
@@ -149,11 +225,18 @@ class SerializingObject(object):
         return ','.join(task['uuid'] for task in value)
 
     def deserialize_depends(self, raw_uuids):
-        raw_uuids = raw_uuids or ''  # Convert None to empty string
-        uuids = raw_uuids.split(',')
+        raw_uuids = raw_uuids or []  # Convert None to empty list
+
+        # TW 2.4.4 encodes list of dependencies as a single string
+        if type(raw_uuids) is not list:
+            uuids = raw_uuids.split(',')
+        # TW 2.4.5 and later export them as a list, so no conversion is needed
+        else:
+            uuids = raw_uuids
+
         return set(self.warrior.tasks.get(uuid=uuid) for uuid in uuids if uuid)
 
-    def normalize_datetime(self, value):
+    def datetime_normalizer(self, value):
         """
         Normalizes date/datetime value (considered to come from user input)
         to localized datetime value. Following conversions happen:
@@ -168,16 +251,36 @@ class SerializingObject(object):
             # Convert to local midnight
             value_full = datetime.datetime.combine(value, datetime.time.min)
             localized = local_zone.localize(value_full)
-        elif isinstance(value, datetime.datetime) and value.tzinfo is None:
-            # Convert to localized datetime object
-            localized = local_zone.localize(value)
+        elif isinstance(value, datetime.datetime):
+            if value.tzinfo is None:
+                # Convert to localized datetime object
+                localized = local_zone.localize(value)
+            else:
+                # The value is already localized, so there is no need to change
+                # its time zone at this point.
+                localized = value
+        elif (isinstance(value, six.string_types)
+                and self.warrior.version >= VERSION_2_4_0):
+            # For strings, use 'task calc' to evaluate the string to datetime
+            # available since TW 2.4.0
+            args = value.split()
+            result = self.warrior.execute_command(['calc'] + args)
+            naive = datetime.datetime.strptime(result[0], DATE_FORMAT_CALC)
+            localized = local_zone.localize(naive)
         else:
-            # If the value is already localized, there is no need to change
-            # time zone at this point. Also None is a valid value too.
-            localized = value
-        
+            raise ValueError("Provided value could not be converted to "
+                             "datetime, its type is not supported: {}"
+                             .format(type(value)))
+
         return localized
-            
+
+    def normalize_uuid(self, value):
+        # Enforce sane UUID
+        if not isinstance(value, six.string_types) or value == '':
+            raise ValueError("UUID must be a valid non-empty string, "
+                             "not: {}".format(value))
+
+        return value
 
 
 class TaskResource(SerializingObject):
@@ -190,7 +293,7 @@ class TaskResource(SerializingObject):
         # are not propagated.
         self._original_data = copy.deepcopy(self._data)
 
-    def _update_data(self, data, update_original=False):
+    def _update_data(self, data, update_original=False, remove_missing=False):
         """
         Low level update of the internal _data dict. Data which are coming as
         updates should already be serialized. If update_original is True, the
@@ -199,6 +302,11 @@ class TaskResource(SerializingObject):
         self._data.update(dict((key, self._deserialize(key, value))
                                for key, value in data.items()))
 
+        # In certain situations, we want to treat missing keys as removals
+        if remove_missing:
+            for key in set(self._data.keys()) - set(data.keys()):
+                self._data[key] = None
+
         if update_original:
             self._original_data = copy.deepcopy(self._data)
 
@@ -221,11 +329,8 @@ class TaskResource(SerializingObject):
         if key in self.read_only_fields:
             raise RuntimeError('Field \'%s\' is read-only' % key)
 
-        # Localize any naive date/datetime to the detected timezone
-        if (isinstance(value, datetime.datetime) or
-            isinstance(value, datetime.date)):
-            value = self.normalize_datetime(value)
-
+        # Normalize the user input before saving it
+        value = self._normalize(key, value)
         self._data[key] = value
 
     def __str__(self):
@@ -275,9 +380,10 @@ class TaskResource(SerializingObject):
 class TaskAnnotation(TaskResource):
     read_only_fields = ['entry', 'description']
 
-    def __init__(self, task, data={}):
+    def __init__(self, task, data=None):
         self.task = task
-        self._load_data(data)
+        self._load_data(data or dict())
+        super(TaskAnnotation, self).__init__(task.warrior)
 
     def remove(self):
         self.task.remove_annotation(self)
@@ -311,6 +417,18 @@ class Task(TaskResource):
         """
         pass
 
+    class ActiveTask(Exception):
+        """
+        Raised when the operation cannot be performed on the active task.
+        """
+        pass
+
+    class InactiveTask(Exception):
+        """
+        Raised when the operation cannot be performed on an inactive task.
+        """
+        pass
+
     class NotSaved(Exception):
         """
         Raised when the operation cannot be performed on the task, because
@@ -319,7 +437,7 @@ class Task(TaskResource):
         pass
 
     @classmethod
-    def from_input(cls, input_file=sys.stdin, modify=None):
+    def from_input(cls, input_file=sys.stdin, modify=None, warrior=None):
         """
         Creates a Task object, directly from the stdin, by reading one line.
         If modify=True, two lines are used, first line interpreted as the
@@ -335,25 +453,31 @@ class Task(TaskResource):
         but defaults to sys.stdin.
         """
 
-        # TaskWarrior instance is set to None
-        task = cls(None)
-
         # Detect the hook type if not given directly
         name = os.path.basename(sys.argv[0])
         modify = name.startswith('on-modify') if modify is None else modify
 
+        # Create the TaskWarrior instance if none passed
+        if warrior is None:
+            hook_parent_dir = os.path.dirname(os.path.dirname(sys.argv[0]))
+            warrior = TaskWarrior(data_location=hook_parent_dir)
+
+        # Create the task bound to the TaskWarrior instance
+        task = cls(warrior)
+
         # Load the data from the input
         task._load_data(json.loads(input_file.readline().strip()))
 
         # If this is a on-modify event, we are provided with additional
         # line of input, which provides updated data
         if modify:
-            task._update_data(json.loads(input_file.readline().strip()))
+            task._update_data(json.loads(input_file.readline().strip()),
+                              remove_missing=True)
 
         return task
 
     def __init__(self, warrior, **kwargs):
-        self.warrior = warrior
+        super(Task, self).__init__(warrior)
 
         # Check that user is not able to set read-only value in __init__
         for key in kwargs.keys():
@@ -369,6 +493,9 @@ class Task(TaskResource):
                           for (key, value) in six.iteritems(kwargs))
         self._original_data = copy.deepcopy(self._data)
 
+        # Provide read-only access to the original data
+        self.original = ReadOnlyDictView(self._original_data)
+
     def __unicode__(self):
         return self['description']
 
@@ -405,6 +532,10 @@ class Task(TaskResource):
     def pending(self):
         return self['status'] == six.text_type('pending')
 
+    @property
+    def active(self):
+        return self['start'] is not None
+
     @property
     def saved(self):
         return self['uuid'] is not None or self['id'] is not None
@@ -443,7 +574,7 @@ class Task(TaskResource):
         if self.warrior.version < VERSION_2_4_0:
             return self._data['description']
         else:
-            return "description:'{0}'".format(self._data['description'] or '')
+            return six.u("description:'{0}'").format(self._data['description'] or '')
 
     def delete(self):
         if not self.saved:
@@ -455,11 +586,44 @@ class Task(TaskResource):
         if self.deleted:
             raise Task.DeletedTask("Task was already deleted")
 
-        self.warrior.execute_command([self['uuid'], 'delete'])
+        self.backend.delete_task(self)
+
+        # Refresh the status again, so that we have updated info stored
+        self.refresh(only_fields=['status', 'start', 'end'])
+
+    def start(self):
+        if not self.saved:
+            raise Task.NotSaved("Task needs to be saved before it can be started")
+
+        # Refresh, and raise exception if task is already completed/deleted
+        self.refresh(only_fields=['status'])
+
+        if self.completed:
+            raise Task.CompletedTask("Cannot start a completed task")
+        elif self.deleted:
+            raise Task.DeletedTask("Deleted task cannot be started")
+        elif self.active:
+            raise Task.ActiveTask("Task is already active")
+
+        self.backend.start_task(self)
 
         # Refresh the status again, so that we have updated info stored
+        self.refresh(only_fields=['status', 'start'])
+
+    def stop(self):
+        if not self.saved:
+            raise Task.NotSaved("Task needs to be saved before it can be stopped")
+
+        # Refresh, and raise exception if the task is not active
         self.refresh(only_fields=['status'])
 
+        if not self.active:
+            raise Task.InactiveTask("Cannot stop an inactive task")
+
+        self.backend.stop_task(self)
+
+        # Refresh the status again, so that we have updated info stored
+        self.refresh(only_fields=['status', 'start'])
 
     def done(self):
         if not self.saved:
@@ -473,36 +637,21 @@ class Task(TaskResource):
         elif self.deleted:
             raise Task.DeletedTask("Deleted task cannot be completed")
 
+        # Older versions of TW do not stop an active task at completion
+        if self.warrior.version < VERSION_2_4_0 and self.active:
+            self.stop()
+
         self.warrior.execute_command([self['uuid'], 'done'])
 
         # Refresh the status again, so that we have updated info stored
-        self.refresh(only_fields=['status'])
+        self.refresh(only_fields=['status', 'start', 'end'])
 
     def save(self):
         if self.saved and not self.modified:
             return
 
-        args = [self['uuid'], 'modify'] if self.saved else ['add']
-        args.extend(self._get_modified_fields_as_args())
-        output = self.warrior.execute_command(args)
-
-        # Parse out the new ID, if the task is being added for the first time
-        if not self.saved:
-            id_lines = [l for l in output if l.startswith('Created task ')]
-
-            # Complain loudly if it seems that more tasks were created
-            # Should not happen
-            if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
-                raise TaskWarriorException("Unexpected output when creating "
-                                           "task: %s" % '\n'.join(id_lines))
-
-            # Circumvent the ID storage, since ID is considered read-only
-            self._data['id'] = int(id_lines[0].split(' ')[2].rstrip('.'))
-
-        # Refreshing is very important here, as not only modification time
-        # is updated, but arbitrary attribute may have changed due hooks
-        # altering the data before saving
-        self.refresh()
+        # All the actual work is done by the backend
+        self.backend.save_task(self)
 
     def add_annotation(self, annotation):
         if not self.saved:
@@ -535,9 +684,9 @@ class Task(TaskResource):
             if serialized_value is '':
                 escaped_serialized_value = ''
             else:
-                escaped_serialized_value = "'{0}'".format(serialized_value)
+                escaped_serialized_value = six.u("'{0}'").format(serialized_value)
 
-            format_default = lambda: "{0}:{1}".format(field,
+            format_default = lambda: six.u("{0}:{1}").format(field,
                                                       escaped_serialized_value)
 
             format_func = getattr(self, 'format_{0}'.format(field),
@@ -558,7 +707,7 @@ class Task(TaskResource):
 
         return args
 
-    def refresh(self, only_fields=[]):
+    def refresh(self, only_fields=None, after_save=False):
         # Raise error when trying to refresh a task that has not been saved
         if not self.saved:
             raise Task.NotSaved("Task needs to be saved to be refreshed")
@@ -567,7 +716,39 @@ class Task(TaskResource):
         # of newly saved tasks. Any other place in the code is fine
         # with using UUID only.
         args = [self['uuid'] or self['id'], 'export']
-        new_data = json.loads(self.warrior.execute_command(args)[0])
+        output = self.warrior.execute_command(args)
+
+        def valid(output):
+            return len(output) == 1 and output[0].startswith('{')
+
+        # For older TW versions attempt to uniquely locate the task
+        # using the data we have if it has been just saved.
+        # This can happen when adding a completed task on older TW versions.
+        if (not valid(output) and self.warrior.version < VERSION_2_4_5
+                and after_save):
+
+            # Make a copy, removing ID and UUID. It's most likely invalid
+            # (ID 0) if it failed to match a unique task.
+            data = copy.deepcopy(self._data)
+            data.pop('id', None)
+            data.pop('uuid', None)
+
+            taskfilter = TaskFilter(self.warrior)
+            for key, value in data.items():
+                taskfilter.add_filter_param(key, value)
+
+            output = self.warrior.execute_command(['export', '--'] +
+                taskfilter.get_filter_params())
+
+        # If we still did not match exactly one task, raise an exception
+        if not valid(output):
+            raise TaskWarriorException(
+                "Unique identifiers {0} with description: {1} matches "
+                "multiple tasks: {2}".format(
+                self['uuid'] or self['id'], self['description'], output)
+            )
+
+        new_data = json.loads(output[0])
         if only_fields:
             to_update = dict(
                 [(k, new_data.get(k)) for k in only_fields])
@@ -580,8 +761,9 @@ class TaskFilter(SerializingObject):
     A set of parameters to filter the task list with.
     """
 
-    def __init__(self, filter_params=[]):
-        self.filter_params = filter_params
+    def __init__(self, warrior, filter_params=None):
+        self.filter_params = filter_params or []
+        super(TaskFilter, self).__init__(warrior)
 
     def add_filter(self, filter_str):
         self.filter_params.append(filter_str)
@@ -593,12 +775,8 @@ class TaskFilter(SerializingObject):
         # convention in TW for empty values
         attribute_key = key.split('.')[0]
 
-        # Since this is user input, we need to normalize datetime
-        # objects
-        if (isinstance(value, datetime.datetime) or
-            isinstance(value, datetime.date)):
-            value = self.normalize_datetime(value)
-
+        # Since this is user input, we need to normalize before we serialize
+        value = self._normalize(attribute_key, value)
         value = self._serialize(attribute_key, value)
 
         # If we are filtering by uuid:, do not use uuid keyword
@@ -611,16 +789,18 @@ class TaskFilter(SerializingObject):
 
             # We enforce equality match by using 'is' (or 'none') modifier
             # Without using this syntax, filter fails due to TW-1479
-            modifier = '.is' if value else '.none'
-            key = key + modifier if '.' not in key else key
+            # which is, however, fixed in 2.4.5
+            if self.warrior.version < VERSION_2_4_5:
+                modifier = '.is' if value else '.none'
+                key = key + modifier if '.' not in key else key
 
-            self.filter_params.append("{0}:{1}".format(key, value))
+            self.filter_params.append(six.u("{0}:{1}").format(key, value))
 
     def get_filter_params(self):
         return [f for f in self.filter_params if f]
 
     def clone(self):
-        c = self.__class__()
+        c = self.__class__(self.warrior)
         c.filter_params = list(self.filter_params)
         return c
 
@@ -633,7 +813,7 @@ class TaskQuerySet(object):
     def __init__(self, warrior=None, filter_obj=None):
         self.warrior = warrior
         self._result_cache = None
-        self.filter_obj = filter_obj or TaskFilter()
+        self.filter_obj = filter_obj or TaskFilter(warrior)
 
     def __deepcopy__(self, memo):
         """
@@ -733,73 +913,3 @@ class TaskQuerySet(object):
         raise ValueError(
             'get() returned more than one Task -- it returned {0}! '
             'Lookup parameters were {1}'.format(num, kwargs))
-
-
-class TaskWarrior(object):
-    def __init__(self, data_location='~/.task', create=True):
-        data_location = os.path.expanduser(data_location)
-        if create and not os.path.exists(data_location):
-            os.makedirs(data_location)
-        self.config = {
-            'data.location': os.path.expanduser(data_location),
-            'confirmation': 'no',
-            'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
-            'recurrence.confirmation': 'no',  # Necessary for modifying R tasks
-        }
-        self.tasks = TaskQuerySet(self)
-        self.version = self._get_version()
-
-    def _get_command_args(self, args, config_override={}):
-        command_args = ['task', 'rc:/']
-        config = self.config.copy()
-        config.update(config_override)
-        for item in config.items():
-            command_args.append('rc.{0}={1}'.format(*item))
-        command_args.extend(map(str, args))
-        return command_args
-
-    def _get_version(self):
-        p = subprocess.Popen(
-                ['task', '--version'],
-                stdout=subprocess.PIPE,
-                stderr=subprocess.PIPE)
-        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
-        return stdout.strip('\n')
-
-    def execute_command(self, args, config_override={}):
-        command_args = self._get_command_args(
-            args, config_override=config_override)
-        logger.debug(' '.join(command_args))
-        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
-                             stderr=subprocess.PIPE)
-        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
-        if p.returncode:
-            if stderr.strip():
-                error_msg = stderr.strip().splitlines()[-1]
-            else:
-                error_msg = stdout.strip()
-            raise TaskWarriorException(error_msg)
-        return stdout.strip().split('\n')
-
-    def filter_tasks(self, filter_obj):
-        args = ['export', '--'] + filter_obj.get_filter_params()
-        tasks = []
-        for line in self.execute_command(args):
-            if line:
-                data = line.strip(',')
-                try:
-                    filtered_task = Task(self)
-                    filtered_task._load_data(json.loads(data))
-                    tasks.append(filtered_task)
-                except ValueError:
-                    raise TaskWarriorException('Invalid JSON: %s' % data)
-        return tasks
-
-    def merge_with(self, path, push=False):
-        path = path.rstrip('/') + '/'
-        self.execute_command(['merge', path], config_override={
-            'merge.autopush': 'yes' if push else 'no',
-        })
-
-    def undo(self):
-        self.execute_command(['undo'])