X-Git-Url: https://git.madduck.net/etc/taskwarrior.git/blobdiff_plain/a40915e62131c8828305ec59ed630c63e6abecd7..90d018e0f0c05ecbaed4699bc743b846abeafced:/tasklib/task.py?ds=inline

diff --git a/tasklib/task.py b/tasklib/task.py
index f482f53..72e16af 100644
--- a/tasklib/task.py
+++ b/tasklib/task.py
@@ -4,11 +4,14 @@ import datetime
 import json
 import logging
 import os
+import pytz
 import six
 import sys
 import subprocess
+import tzlocal
 
 DATE_FORMAT = '%Y%m%dT%H%M%SZ'
+DATE_FORMAT_CALC = '%Y-%m-%dT%H:%M:%S'
 REPR_OUTPUT_SIZE = 10
 PENDING = 'pending'
 COMPLETED = 'completed'
@@ -17,14 +20,49 @@ VERSION_2_1_0 = six.u('2.1.0')
 VERSION_2_2_0 = six.u('2.2.0')
 VERSION_2_3_0 = six.u('2.3.0')
 VERSION_2_4_0 = six.u('2.4.0')
+VERSION_2_4_1 = six.u('2.4.1')
+VERSION_2_4_2 = six.u('2.4.2')
+VERSION_2_4_3 = six.u('2.4.3')
 
 logger = logging.getLogger(__name__)
+local_zone = tzlocal.get_localzone()
 
 
 class TaskWarriorException(Exception):
     pass
 
 
+class ReadOnlyDictView(object):
+    """
+    Read-only facade over a dict; every lookup hands back a deep copy.
+    """
+
+    def __init__(self, viewed_dict):
+        self.viewed_dict = viewed_dict
+
+    def __getitem__(self, key):
+        return copy.deepcopy(self.viewed_dict[key])
+
+    def __contains__(self, k):
+        return k in self.viewed_dict
+
+    def __iter__(self):
+        for key in self.viewed_dict:
+            yield copy.deepcopy(key)
+
+    def __len__(self):
+        return len(self.viewed_dict)
+
+    def get(self, key, default=None):
+        return copy.deepcopy(self.viewed_dict.get(key, default))
+
+    def items(self):
+        return [copy.deepcopy(pair) for pair in self.viewed_dict.items()]
+
+    def values(self):
+        return list(map(copy.deepcopy, self.viewed_dict.values()))
+
+
 class SerializingObject(object):
     """
     Common ancestor for TaskResource & TaskFilter, since they both
@@ -41,8 +79,20 @@ class SerializingObject(object):
         not export empty-valued attributes) if the attribute
         is not iterable (e.g. list or set), in which case
         a empty iterable should be used.
+
+    Normalizing methods should hold the following contract:
+      - They are used to validate and normalize the user input.
+        Any attribute value that comes from the user (during Task
+        initialization, assigning values to Task attributes, or
+        filtering by user-provided values of attributes) is first
+        validated and normalized using the normalize_{key} method.
+      - If validation or normalization fails, normalizer is expected
+        to raise ValueError.
     """
 
+    def __init__(self, warrior):
+        self.warrior = warrior
+
     def _deserialize(self, key, value):
         hydrate_func = getattr(self, 'deserialize_{0}'.format(key),
                                lambda x: x if x != '' else None)
@@ -53,15 +103,40 @@ class SerializingObject(object):
                                  lambda x: x if x is not None else '')
         return dehydrate_func(value)
 
+    def _normalize(self, key, value):
+        """
+        Run user-supplied ``value`` through the ``normalize_<key>``
+        method when one exists; otherwise hand the value back as-is.
+        ``None`` always passes through without being normalized.
+        """
+
+        if value is None:
+            return None
+
+        try:
+            normalizer = getattr(self, 'normalize_{0}'.format(key))
+        except AttributeError:
+            return value
+        return normalizer(value)
+
     def timestamp_serializer(self, date):
         if not date:
             return ''
+
+        # Any serialized timestamp should be localized, we need to
+        # convert to UTC before converting to string (DATE_FORMAT uses UTC)
+        date = date.astimezone(pytz.utc)
+
         return date.strftime(DATE_FORMAT)
 
     def timestamp_deserializer(self, date_str):
         if not date_str:
             return None
-        return datetime.datetime.strptime(date_str, DATE_FORMAT)
+
+        # Return timestamp localized in the local zone
+        naive_timestamp = datetime.datetime.strptime(date_str, DATE_FORMAT)
+        localized_timestamp = pytz.utc.localize(naive_timestamp)
+        return localized_timestamp.astimezone(local_zone)
 
     def serialize_entry(self, value):
         return self.timestamp_serializer(value)
@@ -69,36 +144,81 @@ class SerializingObject(object):
     def deserialize_entry(self, value):
         return self.timestamp_deserializer(value)
 
+    def normalize_entry(self, value):
+        return self.datetime_normalizer(value)
+
     def serialize_modified(self, value):
         return self.timestamp_serializer(value)
 
     def deserialize_modified(self, value):
         return self.timestamp_deserializer(value)
 
+    def normalize_modified(self, value):
+        return self.datetime_normalizer(value)
+
+    def serialize_start(self, value):
+        return self.timestamp_serializer(value)
+
+    def deserialize_start(self, value):
+        return self.timestamp_deserializer(value)
+
+    def normalize_start(self, value):
+        return self.datetime_normalizer(value)
+
+    def serialize_end(self, value):
+        return self.timestamp_serializer(value)
+
+    def deserialize_end(self, value):
+        return self.timestamp_deserializer(value)
+
+    def normalize_end(self, value):
+        return self.datetime_normalizer(value)
+
     def serialize_due(self, value):
         return self.timestamp_serializer(value)
 
     def deserialize_due(self, value):
         return self.timestamp_deserializer(value)
 
+    def normalize_due(self, value):
+        return self.datetime_normalizer(value)
+
     def serialize_scheduled(self, value):
         return self.timestamp_serializer(value)
 
     def deserialize_scheduled(self, value):
         return self.timestamp_deserializer(value)
 
+    def normalize_scheduled(self, value):
+        return self.datetime_normalizer(value)
+
     def serialize_until(self, value):
         return self.timestamp_serializer(value)
 
     def deserialize_until(self, value):
         return self.timestamp_deserializer(value)
 
+    def normalize_until(self, value):
+        return self.datetime_normalizer(value)
+
     def serialize_wait(self, value):
         return self.timestamp_serializer(value)
 
     def deserialize_wait(self, value):
         return self.timestamp_deserializer(value)
 
+    def normalize_wait(self, value):
+        return self.datetime_normalizer(value)
+
+    def serialize_annotations(self, value):
+        # Each annotation exports itself as a JSON string; parse it back so
+        # the serialized form is a list of dicts. None or an empty result
+        # is represented by '' instead.
+        if value is None:
+            return ''
+        exported = [json.loads(note.export_data()) for note in value]
+        return exported or ''
+
     def deserialize_annotations(self, data):
         return [TaskAnnotation(self, d) for d in data] if data else []
 
@@ -120,6 +240,52 @@ class SerializingObject(object):
         uuids = raw_uuids.split(',')
         return set(self.warrior.tasks.get(uuid=uuid) for uuid in uuids if uuid)
 
+    def datetime_normalizer(self, value):
+        """
+        Normalize user-supplied ``value`` to a timezone-aware datetime.
+
+        naive date -> localized datetime with the same date, and time=midnight
+        naive datetime -> localized datetime with the same value
+        localized datetime -> returned unchanged
+        string (TW >= 2.4.0) -> evaluated by 'task calc', then localized
+        Raises ValueError for any other type (or strings on older TW).
+        """
+
+        if (isinstance(value, datetime.date)
+            and not isinstance(value, datetime.datetime)):
+            # Convert to local midnight
+            value_full = datetime.datetime.combine(value, datetime.time.min)
+            localized = local_zone.localize(value_full)
+        elif isinstance(value, datetime.datetime):
+            if value.tzinfo is None:
+                # Convert to localized datetime object
+                localized = local_zone.localize(value)
+            else:
+                # Already timezone-aware, keep the caller's zone as-is
+                localized = value
+        elif (isinstance(value, six.string_types)
+                and self.warrior.version >= VERSION_2_4_0):
+            # 'task calc' is available since TW 2.4.0, so 2.4.0 itself must
+            # be accepted here (hence >=, not >)
+            args = value.split()
+            result = self.warrior.execute_command(['calc'] + args)
+            naive = datetime.datetime.strptime(result[0], DATE_FORMAT_CALC)
+            localized = local_zone.localize(naive)
+        else:
+            raise ValueError("Provided value could not be converted to "
+                             "datetime, its type is not supported: {0}"
+                             .format(type(value)))
+
+        return localized
+
+    def normalize_uuid(self, value):
+        """Return ``value`` if it is a non-empty string, else ValueError."""
+        if not isinstance(value, six.string_types) or value == '':
+            raise ValueError("UUID must be a valid non-empty string, "
+                             "not: {0}".format(value))
+
+        return value
+
 
 class TaskResource(SerializingObject):
     read_only_fields = []
@@ -161,6 +327,9 @@ class TaskResource(SerializingObject):
     def __setitem__(self, key, value):
         if key in self.read_only_fields:
             raise RuntimeError('Field \'%s\' is read-only' % key)
+
+        # Normalize the user input before saving it
+        value = self._normalize(key, value)
         self._data[key] = value
 
     def __str__(self):
@@ -213,6 +382,7 @@ class TaskAnnotation(TaskResource):
     def __init__(self, task, data={}):
         self.task = task
         self._load_data(data)
+        super(TaskAnnotation, self).__init__(task.warrior)
 
     def remove(self):
         self.task.remove_annotation(self)
@@ -254,7 +424,7 @@ class Task(TaskResource):
         pass
 
     @classmethod
-    def from_input(cls, input_file=sys.stdin, modify=None):
+    def from_input(cls, input_file=sys.stdin, modify=None, warrior=None):
         """
         Creates a Task object, directly from the stdin, by reading one line.
         If modify=True, two lines are used, first line interpreted as the
@@ -270,13 +440,18 @@ class Task(TaskResource):
         but defaults to sys.stdin.
         """
 
-        # TaskWarrior instance is set to None
-        task = cls(None)
-
         # Detect the hook type if not given directly
         name = os.path.basename(sys.argv[0])
         modify = name.startswith('on-modify') if modify is None else modify
 
+        # Create the TaskWarrior instance if none passed
+        if warrior is None:
+            hook_parent_dir = os.path.dirname(os.path.dirname(sys.argv[0]))
+            warrior = TaskWarrior(data_location=hook_parent_dir)
+
+        # Bind the new task to the TaskWarrior instance
+        task = cls(warrior)
+
         # Load the data from the input
         task._load_data(json.loads(input_file.readline().strip()))
 
@@ -288,7 +463,7 @@ class Task(TaskResource):
         return task
 
     def __init__(self, warrior, **kwargs):
-        self.warrior = warrior
+        super(Task, self).__init__(warrior)
 
         # Check that user is not able to set read-only value in __init__
         for key in kwargs.keys():
@@ -300,8 +475,12 @@ class Task(TaskResource):
         # __init__ methods, that would be confusing
 
         # Rather unfortunate syntax due to python2.6 comaptiblity
-        self._load_data(dict((key, self._serialize(key, value))
-                        for (key, value) in six.iteritems(kwargs)))
+        self._data = dict((key, self._normalize(key, value))
+                          for (key, value) in six.iteritems(kwargs))
+        self._original_data = copy.deepcopy(self._data)
+
+        # Provide read only access to the original data
+        self.original = ReadOnlyDictView(self._original_data)
 
     def __unicode__(self):
         return self['description']
@@ -392,8 +571,24 @@ class Task(TaskResource):
         self.warrior.execute_command([self['uuid'], 'delete'])
 
         # Refresh the status again, so that we have updated info stored
+        self.refresh(only_fields=['status', 'start', 'end'])
+
+    def start(self):
+        if not self.saved:
+            raise Task.NotSaved("Task needs to be saved before it can be started")
+
+        # Refresh, and raise exception if task is already completed/deleted
         self.refresh(only_fields=['status'])
 
+        if self.completed:
+            raise Task.CompletedTask("Cannot start a completed task")
+        elif self.deleted:
+            raise Task.DeletedTask("Deleted task cannot be started")
+
+        self.warrior.execute_command([self['uuid'], 'start'])
+
+        # Refresh the status again, so that we have updated info stored
+        self.refresh(only_fields=['status', 'start'])
 
     def done(self):
         if not self.saved:
@@ -410,7 +605,7 @@ class Task(TaskResource):
         self.warrior.execute_command([self['uuid'], 'done'])
 
         # Refresh the status again, so that we have updated info stored
-        self.refresh(only_fields=['status'])
+        self.refresh(only_fields=['status', 'start', 'end'])
 
     def save(self):
         if self.saved and not self.modified:
@@ -514,8 +709,9 @@ class TaskFilter(SerializingObject):
     A set of parameters to filter the task list with.
     """
 
-    def __init__(self, filter_params=[]):
-        self.filter_params = filter_params
+    def __init__(self, warrior, filter_params=None):
+        self.filter_params = [] if filter_params is None else filter_params
+        super(TaskFilter, self).__init__(warrior)
 
     def add_filter(self, filter_str):
         self.filter_params.append(filter_str)
@@ -526,6 +722,9 @@ class TaskFilter(SerializingObject):
         # Replace the value with empty string, since that is the
         # convention in TW for empty values
         attribute_key = key.split('.')[0]
+
+        # Since this is user input, we need to normalize before we serialize
+        value = self._normalize(attribute_key, value)
         value = self._serialize(attribute_key, value)
 
         # If we are filtering by uuid:, do not use uuid keyword
@@ -547,7 +746,7 @@ class TaskFilter(SerializingObject):
         return [f for f in self.filter_params if f]
 
     def clone(self):
-        c = self.__class__()
+        c = self.__class__(self.warrior)
         c.filter_params = list(self.filter_params)
         return c
 
@@ -560,7 +759,7 @@ class TaskQuerySet(object):
     def __init__(self, warrior=None, filter_obj=None):
         self.warrior = warrior
         self._result_cache = None
-        self.filter_obj = filter_obj or TaskFilter()
+        self.filter_obj = filter_obj or TaskFilter(warrior)
 
     def __deepcopy__(self, memo):
         """
@@ -663,21 +862,32 @@ class TaskQuerySet(object):
 
 
 class TaskWarrior(object):
-    def __init__(self, data_location='~/.task', create=True):
+    def __init__(self, data_location='~/.task', create=True, taskrc_location='~/.taskrc'):
         data_location = os.path.expanduser(data_location)
+        self.taskrc_location = os.path.expanduser(taskrc_location)
+
+        # If taskrc does not exist, pass / to use defaults and avoid creating
+        # dummy .taskrc file by TaskWarrior
+        if not os.path.exists(self.taskrc_location):
+            self.taskrc_location = '/'
+
         if create and not os.path.exists(data_location):
             os.makedirs(data_location)
+
+        self.version = self._get_version()
         self.config = {
-            'data.location': os.path.expanduser(data_location),
+            'data.location': data_location,
             'confirmation': 'no',
             'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
             'recurrence.confirmation': 'no',  # Necessary for modifying R tasks
+            # TW 2.4.3 onwards treats bulk=0 as infinite; older versions get
+            # an arbitrary number that is likely to be large enough
+            'bulk': 0 if self.version >= VERSION_2_4_3 else 100000,
         }
         self.tasks = TaskQuerySet(self)
-        self.version = self._get_version()
 
     def _get_command_args(self, args, config_override={}):
-        command_args = ['task', 'rc:/']
+        command_args = ['task', 'rc:{0}'.format(self.taskrc_location)]
         config = self.config.copy()
         config.update(config_override)
         for item in config.items():
@@ -693,22 +903,31 @@ class TaskWarrior(object):
         stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
         return stdout.strip('\n')
 
-    def execute_command(self, args, config_override={}):
+    def execute_command(self, args, config_override={}, allow_failure=True):
         command_args = self._get_command_args(
             args, config_override=config_override)
         logger.debug(' '.join(command_args))
         p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE)
         stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
-        if p.returncode:
+        if p.returncode and allow_failure:
             if stderr.strip():
-                error_msg = stderr.strip().splitlines()[-1]
+                error_msg = stderr.strip()
             else:
                 error_msg = stdout.strip()
             raise TaskWarriorException(error_msg)
-        return stdout.strip().split('\n')
+        return stdout.rstrip().split('\n')
+
+    def enforce_recurrence(self):
+        # Run arbitrary report command which will trigger generation
+        # of recurrent tasks. Only necessary for TW up to 2.4.1, fixed in
+        # 2.4.2. NOTE: allow_failure=False makes execute_command skip
+        # raising TaskWarriorException when the report exits non-zero.
+        if self.version < VERSION_2_4_2:
+            self.execute_command(['next'], allow_failure=False)
 
     def filter_tasks(self, filter_obj):
+        self.enforce_recurrence()
         args = ['export', '--'] + filter_obj.get_filter_params()
         tasks = []
         for line in self.execute_command(args):