X-Git-Url: https://git.madduck.net/etc/taskwarrior.git/blobdiff_plain/5421d65128402657771440f1f667631d56d44fe8..80230e4e58118eaec09b0a83459b8543fbf78260:/tasklib/task.py?ds=sidebyside diff --git a/tasklib/task.py b/tasklib/task.py index 3097f4c..537c792 100644 --- a/tasklib/task.py +++ b/tasklib/task.py @@ -4,9 +4,11 @@ import datetime import json import logging import os +import pytz import six import sys import subprocess +import tzlocal DATE_FORMAT = '%Y%m%dT%H%M%SZ' REPR_OUTPUT_SIZE = 10 @@ -19,12 +21,47 @@ VERSION_2_3_0 = six.u('2.3.0') VERSION_2_4_0 = six.u('2.4.0') logger = logging.getLogger(__name__) +local_zone = tzlocal.get_localzone() class TaskWarriorException(Exception): pass +class ReadOnlyDictView(object): + """ + Provides simplified read-only view upon dict object. + """ + + def __init__(self, viewed_dict): + self.viewed_dict = viewed_dict + + def __getitem__(self, key): + return copy.deepcopy(self.viewed_dict.__getitem__(key)) + + def __contains__(self, k): + return self.viewed_dict.__contains__(k) + + def __iter__(self): + for value in self.viewed_dict: + yield copy.deepcopy(value) + + def __len__(self): + return len(self.viewed_dict) + + def get(self, key, default=None): + return copy.deepcopy(self.viewed_dict.get(key, default)) + + def has_key(self, key): + return self.viewed_dict.has_key(key) + + def items(self): + return [copy.deepcopy(v) for v in self.viewed_dict.items()] + + def values(self): + return [copy.deepcopy(v) for v in self.viewed_dict.values()] + + class SerializingObject(object): """ Common ancestor for TaskResource & TaskFilter, since they both @@ -41,6 +78,15 @@ class SerializingObject(object): not export empty-valued attributes) if the attribute is not iterable (e.g. list or set), in which case a empty iterable should be used. + + Normalizing methods should hold the following contract: + - They are used to validate and normalize the user input. + Any attribute value that comes from the user (during Task + initialization, assignign values to Task attributes, or + filtering by user-provided values of attributes) is first + validated and normalized using the normalize_{key} method. + - If validation or normalization fails, normalizer is expected + to raise ValueError. """ def _deserialize(self, key, value): @@ -68,12 +114,21 @@ class SerializingObject(object): def timestamp_serializer(self, date): if not date: return '' + + # Any serialized timestamp should be localized, we need to + # convert to UTC before converting to string (DATE_FORMAT uses UTC) + date = date.astimezone(pytz.utc) + return date.strftime(DATE_FORMAT) def timestamp_deserializer(self, date_str): if not date_str: return None - return datetime.datetime.strptime(date_str, DATE_FORMAT) + + # Return timestamp localized in the local zone + naive_timestamp = datetime.datetime.strptime(date_str, DATE_FORMAT) + localized_timestamp = pytz.utc.localize(naive_timestamp) + return localized_timestamp.astimezone(local_zone) def serialize_entry(self, value): return self.timestamp_serializer(value) @@ -81,36 +136,54 @@ class SerializingObject(object): def deserialize_entry(self, value): return self.timestamp_deserializer(value) + def normalize_entry(self, value): + return self.datetime_normalizer(value) + def serialize_modified(self, value): return self.timestamp_serializer(value) def deserialize_modified(self, value): return self.timestamp_deserializer(value) + def normalize_modified(self, value): + return self.datetime_normalizer(value) + def serialize_due(self, value): return self.timestamp_serializer(value) def deserialize_due(self, value): return self.timestamp_deserializer(value) + def normalize_due(self, value): + return self.datetime_normalizer(value) + def serialize_scheduled(self, value): return self.timestamp_serializer(value) def deserialize_scheduled(self, value): return self.timestamp_deserializer(value) + def normalize_scheduled(self, value): + return self.datetime_normalizer(value) + def serialize_until(self, value): return self.timestamp_serializer(value) def deserialize_until(self, value): return self.timestamp_deserializer(value) + def normalize_until(self, value): + return self.datetime_normalizer(value) + def serialize_wait(self, value): return self.timestamp_serializer(value) def deserialize_wait(self, value): return self.timestamp_deserializer(value) + def normalize_wait(self, value): + return self.datetime_normalizer(value) + def serialize_annotations(self, value): value = value if value is not None else [] @@ -141,6 +214,38 @@ class SerializingObject(object): uuids = raw_uuids.split(',') return set(self.warrior.tasks.get(uuid=uuid) for uuid in uuids if uuid) + def datetime_normalizer(self, value): + """ + Normalizes date/datetime value (considered to come from user input) + to localized datetime value. Following conversions happen: + + naive date -> localized datetime with the same date, and time=midnight + naive datetime -> localized datetime with the same value + localized datetime -> localized datetime (no conversion) + """ + + if (isinstance(value, datetime.date) + and not isinstance(value, datetime.datetime)): + # Convert to local midnight + value_full = datetime.datetime.combine(value, datetime.time.min) + localized = local_zone.localize(value_full) + elif isinstance(value, datetime.datetime) and value.tzinfo is None: + # Convert to localized datetime object + localized = local_zone.localize(value) + else: + # If the value is already localized, there is no need to change + # time zone at this point. Also None is a valid value too. + localized = value + + return localized + + def normalize_uuid(self, value): + # Enforce sane UUID + if not isinstance(value, six.text_type) or value == '': + raise ValueError("UUID must be a valid non-empty string.") + + return value + class TaskResource(SerializingObject): read_only_fields = [] @@ -182,6 +287,9 @@ class TaskResource(SerializingObject): def __setitem__(self, key, value): if key in self.read_only_fields: raise RuntimeError('Field \'%s\' is read-only' % key) + + # Normalize the user input before saving it + value = self._normalize(key, value) self._data[key] = value def __str__(self): @@ -325,6 +433,9 @@ class Task(TaskResource): for (key, value) in six.iteritems(kwargs)) self._original_data = copy.deepcopy(self._data) + # Provide read only access to the original data + self.original = ReadOnlyDictView(self._original_data) + def __unicode__(self): return self['description'] @@ -548,6 +659,9 @@ class TaskFilter(SerializingObject): # Replace the value with empty string, since that is the # convention in TW for empty values attribute_key = key.split('.')[0] + + # Since this is user input, we need to normalize before we serialize + value = self._normalize(key, value) value = self._serialize(attribute_key, value) # If we are filtering by uuid:, do not use uuid keyword @@ -715,14 +829,14 @@ class TaskWarrior(object): stdout, stderr = [x.decode('utf-8') for x in p.communicate()] return stdout.strip('\n') - def execute_command(self, args, config_override={}): + def execute_command(self, args, config_override={}, allow_failure=True): command_args = self._get_command_args( args, config_override=config_override) logger.debug(' '.join(command_args)) p = subprocess.Popen(command_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout, stderr = [x.decode('utf-8') for x in p.communicate()] - if p.returncode: + if p.returncode and allow_failure: if stderr.strip(): error_msg = stderr.strip().splitlines()[-1] else: @@ -730,7 +844,15 @@ class TaskWarrior(object): raise TaskWarriorException(error_msg) return stdout.strip().split('\n') + def enforce_recurrence(self): + # Run arbitrary report command which will trigger generation + # of recurrent tasks. + # TODO: Make a version dependant enforcement once + # TW-1531 is handled + self.execute_command(['next'], allow_failure=False) + def filter_tasks(self, filter_obj): + self.enforce_recurrence() args = ['export', '--'] + filter_obj.get_filter_params() tasks = [] for line in self.execute_command(args):