git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/task.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Task: Write serializers and deserializers for all timestamp attributes
[etc/taskwarrior.git] / tasklib / task.py
1 from __future__ import print_function
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import six
8 import subprocess
9
# strftime/strptime pattern for TaskWarrior timestamps. The trailing 'Z' is
# matched literally; datetimes are formatted and parsed as naive values
# without any timezone conversion.
DATE_FORMAT = '%Y%m%dT%H%M%SZ'

# Number of tasks TaskQuerySet.__repr__ shows before truncating.
REPR_OUTPUT_SIZE = 10

# Status values used by TaskQuerySet.pending() / TaskQuerySet.completed().
PENDING = 'pending'
COMPLETED = 'completed'

# TaskWarrior versions as text, for comparison against TaskWarrior.version.
# NOTE(review): these are compared as plain strings, which orders correctly
# only while each version component is a single digit.
VERSION_2_1_0 = six.u('2.1.0')
VERSION_2_2_0 = six.u('2.2.0')
VERSION_2_3_0 = six.u('2.3.0')
VERSION_2_4_0 = six.u('2.4.0')

logger = logging.getLogger(__name__)
21
22
class TaskWarriorException(Exception):
    """
    Raised when the ``task`` binary exits with an error, or produces output
    that cannot be interpreted.
    """
25
26
class TaskResource(object):
    """
    Base class for dict-backed TaskWarrior objects (tasks, annotations).

    ``_data`` holds the current values in serialized form, and
    ``_original_data`` holds a snapshot of the values as last loaded; the
    snapshot is what modifications are detected against. Item access runs
    values through optional ``deserialize_<key>`` / ``serialize_<key>``
    methods.
    """
    read_only_fields = []

    def _load_data(self, data):
        """Replace the current data with ``data`` and snapshot it as original."""
        self._data = data
        # Keep an independent copy of the loaded data, so that changes made
        # through _data are not propagated to the snapshot. A deep copy is
        # required: loaded data can contain mutable containers (e.g. lists
        # of tags or of annotation dicts), which a shallow copy would share.
        self._original_data = copy.deepcopy(data)

    def _update_data(self, data, update_original=False):
        """
        Low level update of the internal _data dict. Data which are coming as
        updates should already be serialized. If update_original is True, the
        original_data dict is updated as well.
        """

        self._data.update(data)

        if update_original:
            # Copy for the same reason as in _load_data: the snapshot must
            # not share mutable values with _data.
            self._original_data.update(copy.deepcopy(data))

    def __getitem__(self, key):
        # This is a workaround to make TaskResource non-iterable
        # over simple index-based iteration: the legacy iteration protocol
        # calls __getitem__(0), __getitem__(1), ... and stops on
        # StopIteration, so integer-like keys end the iteration at once.
        try:
            int(key)
            raise StopIteration
        except ValueError:
            pass

        return self._deserialize(key, self._data.get(key))

    def __setitem__(self, key, value):
        """
        Set ``key`` to the serialized form of ``value``.

        :raises RuntimeError: if ``key`` is listed in read_only_fields.
        """
        if key in self.read_only_fields:
            raise RuntimeError('Field \'%s\' is read-only' % key)
        self._data[key] = self._serialize(key, value)

    def _deserialize(self, key, value):
        # Use deserialize_<key> when defined, otherwise return value as-is
        hydrate_func = getattr(self, 'deserialize_{0}'.format(key),
                               lambda x: x)
        return hydrate_func(value)

    def _serialize(self, key, value):
        # Use serialize_<key> when defined, otherwise store value as-is
        dehydrate_func = getattr(self, 'serialize_{0}'.format(key),
                                 lambda x: x)
        return dehydrate_func(value)

    def __str__(self):
        s = six.text_type(self.__unicode__())
        if not six.PY3:
            # Python 2 expects __str__ to return bytes
            s = s.encode('utf-8')
        return s

    def __repr__(self):
        return str(self)

    def timestamp_serializer(self, date):
        """Format a datetime with DATE_FORMAT; falsy values become None."""
        if not date:
            return None
        return date.strftime(DATE_FORMAT)

    def timestamp_deserializer(self, date_str):
        """Parse a DATE_FORMAT string into a naive datetime; falsy -> None."""
        if not date_str:
            return None
        return datetime.datetime.strptime(date_str, DATE_FORMAT)

    def serialize_entry(self, value):
        return self.timestamp_serializer(value)

    def deserialize_entry(self, value):
        return self.timestamp_deserializer(value)
99
100
class TaskAnnotation(TaskResource):
    """
    A single annotation attached to a Task.

    Its data dict is used directly as the resource data; ``entry`` and
    ``description`` cannot be changed through item assignment.
    """
    read_only_fields = ['entry', 'description']

    def __init__(self, task, data=None):
        """
        :param task: the Task this annotation belongs to.
        :param data: annotation data dict; defaults to a fresh empty dict.
        """
        self.task = task
        # A fresh dict per instance: a shared {} default would be stored as
        # _data and mutated by item assignment on every default annotation.
        self._load_data(data if data is not None else {})

    def remove(self):
        """Remove this annotation from its task."""
        self.task.remove_annotation(self)

    def __unicode__(self):
        return self['description']

    __repr__ = __unicode__
115
116
class Task(TaskResource):
    """
    A single TaskWarrior task.

    Field values are kept in serialized form in the underlying data dict;
    ``serialize_<field>`` / ``deserialize_<field>`` methods convert them on
    assignment and on access. Commands are executed through the TaskWarrior
    instance passed as ``warrior``.
    """
    read_only_fields = ['id', 'entry', 'urgency', 'uuid', 'modified']

    class DoesNotExist(Exception):
        """
        Raised when no task matches a lookup (see TaskQuerySet.get).
        """
        pass

    class CompletedTask(Exception):
        """
        Raised when the operation cannot be performed on the completed task.
        """
        pass

    class DeletedTask(Exception):
        """
        Raised when the operation cannot be performed on the deleted task.
        """
        pass

    class NotSaved(Exception):
        """
        Raised when the operation cannot be performed on the task, because
        it has not been saved to TaskWarrior yet.
        """
        pass

    def __init__(self, warrior, **kwargs):
        """
        Create a task bound to ``warrior`` with the given initial fields.

        Values in ``kwargs`` are given in deserialized form (e.g. datetimes
        for timestamps) and are serialized right away.

        :raises RuntimeError: if a read-only field is passed.
        """
        self.warrior = warrior

        # Check that user is not able to set read-only value in __init__
        for key in kwargs.keys():
            if key in self.read_only_fields:
                raise RuntimeError('Field \'%s\' is read-only' % key)

        # We serialize the data in kwargs so that users of the library
        # do not have to pass different data formats via __setitem__ and
        # __init__ methods, that would be confusing

        # Rather unfortunate syntax due to python2.6 compatibility
        self._load_data(dict((key, self._serialize(key, value))
                             for (key, value) in six.iteritems(kwargs)))

    def __unicode__(self):
        return self['description']

    def __eq__(self, other):
        # Comparing against a non-Task would otherwise fail on other['uuid'];
        # NotImplemented lets Python fall back to its default comparison.
        if not isinstance(other, Task):
            return NotImplemented

        if self['uuid'] and other['uuid']:
            # For saved Tasks, just define equality by equality of uuids
            return self['uuid'] == other['uuid']
        else:
            # If the tasks are not saved, compare the actual instances
            return id(self) == id(other)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so define it explicitly
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    def __hash__(self):
        if self['uuid']:
            # For saved Tasks, just define equality by equality of uuids
            return self['uuid'].__hash__()
        else:
            # If the tasks are not saved, return hash of instance id
            return id(self).__hash__()

    @property
    def _modified_fields(self):
        """Yield writable fields whose value differs from the loaded one."""
        writable_fields = set(self._data.keys()) - set(self.read_only_fields)
        for key in writable_fields:
            if self._data.get(key) != self._original_data.get(key):
                yield key

    @property
    def completed(self):
        return self['status'] == six.text_type('completed')

    @property
    def deleted(self):
        return self['status'] == six.text_type('deleted')

    @property
    def waiting(self):
        return self['status'] == six.text_type('waiting')

    @property
    def pending(self):
        return self['status'] == six.text_type('pending')

    @property
    def saved(self):
        """True once the task has a uuid or an id."""
        return self['uuid'] is not None or self['id'] is not None

    # Timestamp attributes: all share the DATE_FORMAT-based conversion
    # defined on TaskResource (entry is handled there as well).

    def serialize_due(self, value):
        return self.timestamp_serializer(value)

    def deserialize_due(self, value):
        return self.timestamp_deserializer(value)

    def serialize_scheduled(self, value):
        return self.timestamp_serializer(value)

    def deserialize_scheduled(self, value):
        return self.timestamp_deserializer(value)

    def serialize_until(self, value):
        return self.timestamp_serializer(value)

    def deserialize_until(self, value):
        return self.timestamp_deserializer(value)

    def serialize_wait(self, value):
        return self.timestamp_serializer(value)

    def deserialize_wait(self, value):
        return self.timestamp_deserializer(value)

    def serialize_start(self, value):
        return self.timestamp_serializer(value)

    def deserialize_start(self, value):
        return self.timestamp_deserializer(value)

    def serialize_end(self, value):
        return self.timestamp_serializer(value)

    def deserialize_end(self, value):
        return self.timestamp_deserializer(value)

    def serialize_modified(self, value):
        return self.timestamp_serializer(value)

    def deserialize_modified(self, value):
        return self.timestamp_deserializer(value)

    def serialize_depends(self, cur_dependencies):
        """
        Serialize a collection of Tasks into a comma-separated uuid string.

        :raises Task.NotSaved: if any of the tasks has not been saved yet.
        """
        # Check that all the tasks are saved
        for task in cur_dependencies:
            if not task.saved:
                raise Task.NotSaved('Task \'%s\' needs to be saved before '
                                    'it can be set as dependency.' % task)

        # Return the list of uuids
        return ','.join(task['uuid'] for task in cur_dependencies)

    def deserialize_depends(self, raw_uuids):
        """Turn a comma-separated uuid string (or None) into a set of Tasks."""
        raw_uuids = raw_uuids or ''  # Convert None to empty string
        uuids = raw_uuids.split(',')
        # NOTE: each uuid is looked up with its own warrior.tasks.get() call
        return set(self.warrior.tasks.get(uuid=uuid) for uuid in uuids if uuid)

    def format_depends(self):
        # We need to generate added and removed dependencies list,
        # since Taskwarrior does not accept redefining dependencies.

        # This cannot be part of serialize_depends, since we need
        # to keep a list of all depedencies in the _data dictionary,
        # not just currently added/removed ones

        old_dependencies_raw = self._original_data.get('depends', '')
        old_dependencies = self.deserialize_depends(old_dependencies_raw)

        added = self['depends'] - old_dependencies
        removed = old_dependencies - self['depends']

        # Removed dependencies need to be prefixed with '-'
        return 'depends:' + ','.join(
            [t['uuid'] for t in added] +
            ['-' + t['uuid'] for t in removed]
        )

    def deserialize_annotations(self, data):
        return [TaskAnnotation(self, d) for d in data] if data else []

    def deserialize_tags(self, tags):
        # Tags may arrive as a comma-separated string (the serialized form
        # produced by serialize_tags) or already as a list; only strings
        # need splitting. six.string_types works on Python 2 and 3, unlike
        # the Python 2 only builtin basestring.
        if isinstance(tags, six.string_types):
            return tags.split(',') if tags else []
        return tags

    def serialize_tags(self, tags):
        return ','.join(tags) if tags else ''

    def format_description(self):
        # Task version older than 2.4.0 ignores first word of the
        # task description if description: prefix is used
        # NOTE(review): versions are compared as strings, and single quotes
        # inside the description are not escaped.
        if self.warrior.version < VERSION_2_4_0:
            return self._data['description']
        else:
            return "description:'{0}'".format(self._data['description'] or '')

    def delete(self):
        """
        Delete the task in TaskWarrior and refresh its status.

        :raises Task.NotSaved: if the task has not been saved.
        :raises Task.DeletedTask: if the task is already deleted.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved before it can be deleted")

        # Refresh the status, and raise exception if the task is deleted
        self.refresh(only_fields=['status'])

        if self.deleted:
            raise Task.DeletedTask("Task was already deleted")

        self.warrior.execute_command([self['uuid'], 'delete'])

        # Refresh the status again, so that we have updated info stored
        self.refresh(only_fields=['status'])

    def done(self):
        """
        Mark the task as done in TaskWarrior and refresh its status.

        :raises Task.NotSaved: if the task has not been saved.
        :raises Task.CompletedTask: if the task is already completed.
        :raises Task.DeletedTask: if the task is deleted.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved before it can be completed")

        # Refresh, and raise exception if task is already completed/deleted
        self.refresh(only_fields=['status'])

        if self.completed:
            raise Task.CompletedTask("Cannot complete a completed task")
        elif self.deleted:
            raise Task.DeletedTask("Deleted task cannot be completed")

        self.warrior.execute_command([self['uuid'], 'done'])

        # Refresh the status again, so that we have updated info stored
        self.refresh(only_fields=['status'])

    def save(self):
        """
        Add the task to TaskWarrior, or send the modified fields of an
        already saved task, then reload its data.

        A saved task without modified fields is left untouched.

        :raises TaskWarriorException: if the output of creating a new task
            does not contain exactly one well-formed 'Created task' line.
        """
        # Nothing to send for an unmodified saved task; a bare 'modify'
        # with no arguments is presumably rejected by TaskWarrior.
        if self.saved and not any(True for _ in self._modified_fields):
            return

        args = [self['uuid'], 'modify'] if self.saved else ['add']
        args.extend(self._get_modified_fields_as_args())
        output = self.warrior.execute_command(args)

        # Parse out the new ID, if the task is being added for the first time
        if not self.saved:
            id_lines = [l for l in output if l.startswith('Created task ')]

            # Complain loudly if it seems that more tasks were created
            # Should not happen
            if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
                raise TaskWarriorException("Unexpected output when creating "
                                           "task: %s" % '\n'.join(id_lines))

            # Circumvent the ID storage, since ID is considered read-only
            self._data['id'] = int(id_lines[0].split(' ')[2].rstrip('.'))

        self.refresh()

    def add_annotation(self, annotation):
        """
        Add an annotation text to the saved task and reload its annotations.

        :raises Task.NotSaved: if the task has not been saved.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to add annotation")

        args = [self['uuid'], 'annotate', annotation]
        self.warrior.execute_command(args)
        self.refresh(only_fields=['annotations'])

    def remove_annotation(self, annotation):
        """
        Remove an annotation (TaskAnnotation or its description text) and
        reload the task's annotations.

        :raises Task.NotSaved: if the task has not been saved.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to remove annotation")

        if isinstance(annotation, TaskAnnotation):
            annotation = annotation['description']
        args = [self['uuid'], 'denotate', annotation]
        self.warrior.execute_command(args)
        self.refresh(only_fields=['annotations'])

    def _get_modified_fields_as_args(self):
        """
        Build the command-line arguments describing this task's fields:
        modified fields for a saved task, all writable fields for a new one.
        """
        args = []

        def add_field(field):
            # Prefer a format_<field> method; otherwise emit field:'value'
            # from the serialized value.
            format_func = getattr(self, 'format_{0}'.format(field), None)
            if format_func is not None:
                args.append(format_func())
            else:
                args.append(
                    "{0}:'{1}'".format(field, self._data[field] or ''))

        # If we're modifying saved task, simply pass on all modified fields
        if self.saved:
            for field in self._modified_fields:
                add_field(field)
        # For new tasks, pass all fields that make sense
        else:
            for field in self._data.keys():
                if field in self.read_only_fields:
                    continue
                add_field(field)

        return args

    def refresh(self, only_fields=None):
        """
        Reload the task's data from TaskWarrior.

        :param only_fields: optional list of field names; when given, only
            these fields are updated (in both current and original data),
            otherwise all data is replaced.
        :raises Task.NotSaved: if the task has not been saved.
        """
        # Raise error when trying to refresh a task that has not been saved
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to be refreshed")

        # We need to use ID as backup for uuid here for the refreshes
        # of newly saved tasks. Any other place in the code is fine
        # with using UUID only.
        args = [self['uuid'] or self['id'], 'export']
        new_data = json.loads(self.warrior.execute_command(args)[0])
        if only_fields:
            to_update = dict((k, new_data.get(k)) for k in only_fields)
            self._update_data(to_update, update_original=True)
        else:
            self._load_data(new_data)
394
395
class TaskFilter(object):
    """
    A set of parameters to filter the task list with.

    Parameters are kept as a list of strings in TaskWarrior filter syntax;
    get_filter_params() returns the non-empty ones.
    """

    def __init__(self, filter_params=None):
        """
        :param filter_params: optional initial iterable of filter strings.
            It is copied, so later additions never modify the caller's list
            or leak between TaskFilter instances (as a shared mutable
            default argument would).
        """
        self.filter_params = list(filter_params) if filter_params else []

    def add_filter(self, filter_str):
        """Append a raw filter expression."""
        self.filter_params.append(filter_str)

    def add_filter_param(self, key, value):
        """
        Append a ``key:value`` filter; ``__`` in key becomes ``.``
        (e.g. ``due__before`` -> ``due.before``).
        """
        key = key.replace('__', '.')

        # Replace the value with empty string, since that is the
        # convention in TW for empty values
        value = value if value is not None else ''

        # If we are filtering by uuid:, do not use uuid keyword
        # due to TW-1452 bug
        if key == 'uuid':
            self.filter_params.insert(0, value)
        else:
            self.filter_params.append('{0}:{1}'.format(key, value))

    def get_filter_params(self):
        """Return the filter parameters, skipping empty ones."""
        return [f for f in self.filter_params if f]

    def clone(self):
        """Return an independent copy of this filter."""
        c = self.__class__()
        c.filter_params = list(self.filter_params)
        return c
428
429
430 class TaskQuerySet(object):
431     """
432     Represents a lazy lookup for a task objects.
433     """
434
435     def __init__(self, warrior=None, filter_obj=None):
436         self.warrior = warrior
437         self._result_cache = None
438         self.filter_obj = filter_obj or TaskFilter()
439
440     def __deepcopy__(self, memo):
441         """
442         Deep copy of a QuerySet doesn't populate the cache
443         """
444         obj = self.__class__()
445         for k, v in self.__dict__.items():
446             if k in ('_iter', '_result_cache'):
447                 obj.__dict__[k] = None
448             else:
449                 obj.__dict__[k] = copy.deepcopy(v, memo)
450         return obj
451
452     def __repr__(self):
453         data = list(self[:REPR_OUTPUT_SIZE + 1])
454         if len(data) > REPR_OUTPUT_SIZE:
455             data[-1] = "...(remaining elements truncated)..."
456         return repr(data)
457
458     def __len__(self):
459         if self._result_cache is None:
460             self._result_cache = list(self)
461         return len(self._result_cache)
462
463     def __iter__(self):
464         if self._result_cache is None:
465             self._result_cache = self._execute()
466         return iter(self._result_cache)
467
468     def __getitem__(self, k):
469         if self._result_cache is None:
470             self._result_cache = list(self)
471         return self._result_cache.__getitem__(k)
472
473     def __bool__(self):
474         if self._result_cache is not None:
475             return bool(self._result_cache)
476         try:
477             next(iter(self))
478         except StopIteration:
479             return False
480         return True
481
482     def __nonzero__(self):
483         return type(self).__bool__(self)
484
485     def _clone(self, klass=None, **kwargs):
486         if klass is None:
487             klass = self.__class__
488         filter_obj = self.filter_obj.clone()
489         c = klass(warrior=self.warrior, filter_obj=filter_obj)
490         c.__dict__.update(kwargs)
491         return c
492
493     def _execute(self):
494         """
495         Fetch the tasks which match the current filters.
496         """
497         return self.warrior.filter_tasks(self.filter_obj)
498
499     def all(self):
500         """
501         Returns a new TaskQuerySet that is a copy of the current one.
502         """
503         return self._clone()
504
505     def pending(self):
506         return self.filter(status=PENDING)
507
508     def completed(self):
509         return self.filter(status=COMPLETED)
510
511     def filter(self, *args, **kwargs):
512         """
513         Returns a new TaskQuerySet with the given filters added.
514         """
515         clone = self._clone()
516         for f in args:
517             clone.filter_obj.add_filter(f)
518         for key, value in kwargs.items():
519             clone.filter_obj.add_filter_param(key, value)
520         return clone
521
522     def get(self, **kwargs):
523         """
524         Performs the query and returns a single object matching the given
525         keyword arguments.
526         """
527         clone = self.filter(**kwargs)
528         num = len(clone)
529         if num == 1:
530             return clone._result_cache[0]
531         if not num:
532             raise Task.DoesNotExist(
533                 'Task matching query does not exist. '
534                 'Lookup parameters were {0}'.format(kwargs))
535         raise ValueError(
536             'get() returned more than one Task -- it returned {0}! '
537             'Lookup parameters were {1}'.format(num, kwargs))
538
539
class TaskWarrior(object):
    """
    Wrapper around the ``task`` command line binary.

    Commands are run with ``rc:/`` plus ``rc.<key>=<value>`` arguments built
    from ``self.config`` (presumably so the user's own taskrc is not
    consulted -- confirm).
    """

    def __init__(self, data_location='~/.task', create=True):
        """
        :param data_location: TaskWarrior data directory; ``~`` is expanded.
        :param create: create the directory if it does not exist.
        """
        data_location = os.path.expanduser(data_location)
        if create and not os.path.exists(data_location):
            os.makedirs(data_location)
        self.config = {
            # data_location has already been expanded above
            'data.location': data_location,
            'confirmation': 'no',
        }
        self.tasks = TaskQuerySet(self)
        self.version = self._get_version()

    def _get_command_args(self, args, config_override=None):
        """
        Build the argv list for ``task``.

        :param args: command arguments; each is converted with str().
        :param config_override: optional dict of config values that take
            precedence over self.config for this command only.
        """
        command_args = ['task', 'rc:/']
        config = self.config.copy()
        # None default instead of a shared mutable {}
        config.update(config_override or {})
        for item in config.items():
            command_args.append('rc.{0}={1}'.format(*item))
        command_args.extend(map(str, args))
        return command_args

    def _get_version(self):
        """Return the output of ``task --version`` with newlines stripped."""
        p = subprocess.Popen(
            ['task', '--version'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        return stdout.strip('\n')

    def execute_command(self, args, config_override=None):
        """
        Run ``task`` with the given arguments and return stdout split into
        lines (empty output yields ``['']``).

        :raises TaskWarriorException: if ``task`` exits with a non-zero
            status; the message is the last stderr line, or stdout when
            stderr is empty.
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode:
            if stderr.strip():
                error_msg = stderr.strip().splitlines()[-1]
            else:
                error_msg = stdout.strip()
            raise TaskWarriorException(error_msg)
        return stdout.strip().split('\n')

    def filter_tasks(self, filter_obj):
        """
        Export tasks matching ``filter_obj`` and return them as Task objects.

        :raises TaskWarriorException: if an exported line is not valid JSON.
        """
        args = ['export', '--'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if line:
                # Export lines may carry a separating comma; strip it
                data = line.strip(',')
                try:
                    filtered_task = Task(self)
                    filtered_task._load_data(json.loads(data))
                    tasks.append(filtered_task)
                except ValueError:
                    raise TaskWarriorException('Invalid JSON: %s' % data)
        return tasks

    def merge_with(self, path, push=False):
        """Run ``task merge`` against ``path``; ``push`` sets merge.autopush."""
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Run ``task undo``."""
        self.execute_command(['undo'])