]> git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/task.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

TaskFilter: Enforce exact match when building filter
[etc/taskwarrior.git] / tasklib / task.py
1 from __future__ import print_function
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import six
8 import subprocess
9
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# strftime/strptime layout used for every serialized timestamp field.
DATE_FORMAT = '%Y%m%dT%H%M%SZ'

# Number of items TaskQuerySet.__repr__ shows before truncating its output.
REPR_OUTPUT_SIZE = 10

# Status values used by the TaskQuerySet.pending()/completed() shortcuts.
PENDING = 'pending'
COMPLETED = 'completed'

# Taskwarrior version markers, kept as text strings so they can be compared
# with the version string reported by the ``task`` binary.
VERSION_2_1_0 = six.u('2.1.0')
VERSION_2_2_0 = six.u('2.2.0')
VERSION_2_3_0 = six.u('2.3.0')
VERSION_2_4_0 = six.u('2.4.0')
21
22
class TaskWarriorException(Exception):
    """
    Raised when a ``task`` command fails or produces unexpected output.
    """
25
26
class SerializingObject(object):
    """
    Common ancestor for TaskResource & TaskFilter, since they both
    need to serialize arguments.

    Conversion is dispatched on the field name: ``_serialize(key, value)``
    looks up a ``serialize_<key>`` method on the instance and
    ``_deserialize(key, value)`` looks up ``deserialize_<key>``. Fields
    without a dedicated method fall back to mapping ``None`` to ``''``
    (serializing) and ``''`` to ``None`` (deserializing).
    """

    def _deserialize(self, key, value):
        """
        Convert the stored `value` of field `key` to its Python form.
        """
        hydrate_func = getattr(self, 'deserialize_{0}'.format(key),
                               lambda x: x if x != '' else None)
        return hydrate_func(value)

    def _serialize(self, key, value):
        """
        Convert the Python `value` of field `key` to its stored form.
        """
        dehydrate_func = getattr(self, 'serialize_{0}'.format(key),
                                 lambda x: x if x is not None else '')
        return dehydrate_func(value)

    def timestamp_serializer(self, date):
        """
        Format a datetime using DATE_FORMAT; falsy input yields None.
        """
        if not date:
            return None
        return date.strftime(DATE_FORMAT)

    def timestamp_deserializer(self, date_str):
        """
        Parse a DATE_FORMAT string into a datetime; falsy input yields None.
        """
        if not date_str:
            return None
        return datetime.datetime.strptime(date_str, DATE_FORMAT)

    # All timestamp fields share the two helpers above.

    def serialize_entry(self, value):
        return self.timestamp_serializer(value)

    def deserialize_entry(self, value):
        return self.timestamp_deserializer(value)

    def serialize_modified(self, value):
        return self.timestamp_serializer(value)

    def deserialize_modified(self, value):
        return self.timestamp_deserializer(value)

    def serialize_due(self, value):
        return self.timestamp_serializer(value)

    def deserialize_due(self, value):
        return self.timestamp_deserializer(value)

    def serialize_scheduled(self, value):
        return self.timestamp_serializer(value)

    def deserialize_scheduled(self, value):
        return self.timestamp_deserializer(value)

    def serialize_until(self, value):
        return self.timestamp_serializer(value)

    def deserialize_until(self, value):
        return self.timestamp_deserializer(value)

    def serialize_wait(self, value):
        return self.timestamp_serializer(value)

    def deserialize_wait(self, value):
        return self.timestamp_deserializer(value)

    def deserialize_annotations(self, data):
        """
        Wrap each raw annotation in a TaskAnnotation bound to this object.
        """
        return [TaskAnnotation(self, d) for d in data] if data else []

    def serialize_tags(self, tags):
        """
        Join the tags with commas; no tags serializes to an empty string.
        """
        return ','.join(tags) if tags else ''

    def deserialize_tags(self, tags):
        """
        Split a comma-separated tag string into a list.

        Values that are not strings (for example an already parsed list)
        are returned unchanged.
        """
        # six.string_types covers str/unicode on Python 2 and str on
        # Python 3, where the ``basestring`` builtin no longer exists.
        if isinstance(tags, six.string_types):
            return tags.split(',') if tags else []
        return tags

    def serialize_depends(self, cur_dependencies):
        """
        Return the comma-separated uuids of the given tasks.
        """
        return ','.join(task['uuid'] for task in cur_dependencies)

    def deserialize_depends(self, raw_uuids):
        """
        Turn a comma-separated uuid string into a set of tasks.

        Every uuid is looked up through ``self.warrior.tasks.get``, so the
        instance has to provide a ``warrior`` attribute.
        """
        raw_uuids = raw_uuids or ''  # Convert None to empty string
        uuids = raw_uuids.split(',')
        return set(self.warrior.tasks.get(uuid=uuid) for uuid in uuids if uuid)
108
109
class TaskResource(SerializingObject):
    """
    Dict-backed resource whose fields are accessed with item syntax.

    ``_data`` holds the current serialized values and ``_original_data``
    a snapshot of them taken when the data was loaded. Reading an item
    deserializes the stored value, writing an item serializes it.
    """

    read_only_fields = []

    def _load_data(self, data):
        """
        Replace the stored data with `data` and snapshot it.
        """
        # The snapshot must be a separate dict so that later changes to
        # _data do not show up in it. A shallow copy is enough, since the
        # data dict uses only primitive data types.
        self._original_data = data.copy()
        self._data = data

    def _update_data(self, data, update_original=False):
        """
        Low level update of the internal _data dict. Incoming values are
        expected to be serialized already. With update_original set, the
        _original_data snapshot receives the same update.
        """
        targets = [self._data]
        if update_original:
            targets.append(self._original_data)

        for target in targets:
            target.update(data)

    def __getitem__(self, key):
        # Integer-like keys end the lookup with StopIteration. This is a
        # workaround to make TaskResource non-iterable over simple
        # index-based iteration.
        try:
            int(key)
        except ValueError:
            pass
        else:
            raise StopIteration

        stored = self._data.get(key)
        return self._deserialize(key, stored)

    def __setitem__(self, key, value):
        if key in self.read_only_fields:
            raise RuntimeError("Field '%s' is read-only" % key)

        serialized = self._serialize(key, value)
        self._data[key] = serialized

    def __str__(self):
        text = six.text_type(self.__unicode__())
        # On Python 2 the text is returned as utf-8 encoded bytes
        return text if six.PY3 else text.encode('utf-8')

    def __repr__(self):
        return str(self)
156
157
class TaskAnnotation(TaskResource):
    """
    A single annotation attached to a task.

    Both of its fields are read-only; an annotation is removed through
    the task that owns it.
    """

    read_only_fields = ['entry', 'description']

    def __init__(self, task, data=None):
        """
        Bind the annotation to `task` and load its raw `data` dict.

        When `data` is omitted the annotation starts with an empty dict.
        """
        self.task = task
        # _load_data stores the dict by reference, so a shared ``{}``
        # default would be shared by every annotation created without
        # data. Create a fresh dict per instance instead.
        self._load_data({} if data is None else data)

    def remove(self):
        """
        Remove this annotation from the task it belongs to.
        """
        self.task.remove_annotation(self)

    def __unicode__(self):
        return self['description']

    __repr__ = __unicode__
172
173
class Task(TaskResource):
    """
    A single task, bound to the TaskWarrior instance that stores it.

    Fields are read and written with item access (``task['due']``), which
    converts between Python values and the serialized form kept in
    ``_data``. Field changes stay local until `save` is called.
    """

    read_only_fields = ['id', 'entry', 'urgency', 'uuid', 'modified']

    class DoesNotExist(Exception):
        """
        Raised by TaskQuerySet.get when no task matches the lookup.
        """

    class CompletedTask(Exception):
        """
        Raised when the operation cannot be performed on the completed task.
        """

    class DeletedTask(Exception):
        """
        Raised when the operation cannot be performed on the deleted task.
        """

    class NotSaved(Exception):
        """
        Raised when the operation cannot be performed on the task, because
        it has not been saved to TaskWarrior yet.
        """

    def __init__(self, warrior, **kwargs):
        """
        Create a new, unsaved task with the initial field values `kwargs`.

        Raises RuntimeError if a read-only field is passed.
        """
        self.warrior = warrior

        # Check that user is not able to set read-only value in __init__
        for key in kwargs:
            if key in self.read_only_fields:
                raise RuntimeError('Field \'%s\' is read-only' % key)

        # We serialize the data in kwargs so that users of the library
        # do not have to pass different data formats via __setitem__ and
        # __init__ methods, that would be confusing

        # Rather unfortunate syntax due to python2.6 compatibility
        self._load_data(dict((key, self._serialize(key, value))
                        for (key, value) in six.iteritems(kwargs)))

    def __unicode__(self):
        return self['description']

    def __eq__(self, other):
        # Only tasks can be compared by uuid. Returning NotImplemented
        # lets Python fall back to its default handling, so that e.g.
        # ``task == None`` is False instead of raising TypeError.
        if not isinstance(other, Task):
            return NotImplemented

        if self['uuid'] and other['uuid']:
            # For saved Tasks, just define equality by equality of uuids
            return self['uuid'] == other['uuid']
        else:
            # If the tasks are not saved, compare the actual instances
            return id(self) == id(other)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so define it explicitly
        # to keep both operators consistent.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        uuid = self['uuid']
        if uuid:
            # For saved Tasks, hash the uuid, matching __eq__
            return hash(uuid)
        else:
            # If the tasks are not saved, return hash of instance id
            return hash(id(self))

    @property
    def _modified_fields(self):
        """
        Yield the writable fields whose value differs from the snapshot
        taken when the data was last loaded.
        """
        writable_fields = set(self._data.keys()) - set(self.read_only_fields)
        for key in writable_fields:
            if self._data.get(key) != self._original_data.get(key):
                yield key

    @property
    def completed(self):
        return self['status'] == six.text_type('completed')

    @property
    def deleted(self):
        return self['status'] == six.text_type('deleted')

    @property
    def waiting(self):
        return self['status'] == six.text_type('waiting')

    @property
    def pending(self):
        return self['status'] == six.text_type('pending')

    @property
    def saved(self):
        """
        True once the task has a uuid or an id.
        """
        return self['uuid'] is not None or self['id'] is not None

    def serialize_depends(self, cur_dependencies):
        """
        Serialize the dependencies as comma-separated uuids.

        Raises Task.NotSaved if any of the tasks has not been saved yet.
        """
        # Check that all the tasks are saved
        for task in cur_dependencies:
            if not task.saved:
                raise Task.NotSaved('Task \'%s\' needs to be saved before '
                                    'it can be set as dependency.' % task)

        return super(Task, self).serialize_depends(cur_dependencies)

    def format_depends(self):
        """
        Build the ``depends:`` argument listing added and removed
        dependencies.
        """
        # We need to generate added and removed dependencies list,
        # since Taskwarrior does not accept redefining dependencies.

        # This cannot be part of serialize_depends, since we need
        # to keep a list of all dependencies in the _data dictionary,
        # not just currently added/removed ones

        old_dependencies_raw = self._original_data.get('depends', '')
        old_dependencies = self.deserialize_depends(old_dependencies_raw)

        # Deserializing looks every dependency up through the warrior,
        # so read the current value only once.
        current_dependencies = self['depends']

        added = current_dependencies - old_dependencies
        removed = old_dependencies - current_dependencies

        # Removed dependencies need to be prefixed with '-'
        return 'depends:' + ','.join(
            [t['uuid'] for t in added] +
            ['-' + t['uuid'] for t in removed]
        )

    def format_description(self):
        """
        Build the command-line argument that sets the description.
        """
        # Task version older than 2.4.0 ignores first word of the
        # task description if description: prefix is used
        if self.warrior.version < VERSION_2_4_0:
            return self._data['description']
        else:
            return "description:'{0}'".format(self._data['description'] or '')

    def delete(self):
        """
        Delete the task.

        Raises Task.NotSaved for an unsaved task and Task.DeletedTask if
        the task is already deleted.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved before it can be deleted")

        # Refresh the status, and raise exception if the task is deleted
        self.refresh(only_fields=['status'])

        if self.deleted:
            raise Task.DeletedTask("Task was already deleted")

        self.warrior.execute_command([self['uuid'], 'delete'])

        # Refresh the status again, so that we have updated info stored
        self.refresh(only_fields=['status'])

    def done(self):
        """
        Mark the task as completed.

        Raises Task.NotSaved for an unsaved task, Task.CompletedTask if it
        is already completed and Task.DeletedTask if it is deleted.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved before it can be completed")

        # Refresh, and raise exception if task is already completed/deleted
        self.refresh(only_fields=['status'])

        if self.completed:
            raise Task.CompletedTask("Cannot complete a completed task")
        elif self.deleted:
            raise Task.DeletedTask("Deleted task cannot be completed")

        self.warrior.execute_command([self['uuid'], 'done'])

        # Refresh the status again, so that we have updated info stored
        self.refresh(only_fields=['status'])

    def save(self):
        """
        Write the task to TaskWarrior and reload it.

        A saved task is modified, an unsaved one is added. Raises
        TaskWarriorException if the output of the add command does not
        contain exactly one well-formed 'Created task' line.
        """
        is_new = not self.saved
        args = ['add'] if is_new else [self['uuid'], 'modify']
        args.extend(self._get_modified_fields_as_args())
        output = self.warrior.execute_command(args)

        # Parse out the new ID, if the task is being added for the first time
        if is_new:
            id_lines = [l for l in output if l.startswith('Created task ')]

            # Complain loudly if it seems that more tasks were created
            # Should not happen
            if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
                raise TaskWarriorException("Unexpected output when creating "
                                           "task: %s" % '\n'.join(id_lines))

            # Circumvent the ID storage, since ID is considered read-only
            self._data['id'] = int(id_lines[0].split(' ')[2].rstrip('.'))

        self.refresh()

    def add_annotation(self, annotation):
        """
        Annotate the task with the given text and reload its annotations.

        Raises Task.NotSaved for an unsaved task.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to add annotation")

        args = [self['uuid'], 'annotate', annotation]
        self.warrior.execute_command(args)
        self.refresh(only_fields=['annotations'])

    def remove_annotation(self, annotation):
        """
        Remove an annotation, given as a TaskAnnotation or as its text.

        Raises Task.NotSaved for an unsaved task.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to remove annotation")

        if isinstance(annotation, TaskAnnotation):
            annotation = annotation['description']
        args = [self['uuid'], 'denotate', annotation]
        self.warrior.execute_command(args)
        self.refresh(only_fields=['annotations'])

    def _get_modified_fields_as_args(self):
        """
        Return the command-line arguments describing the fields to write.
        """
        args = []

        def add_field(field):
            # Add the output of format_<field> method to args list, or
            # field:'value' if the field has no such method
            format_func = getattr(self, 'format_{0}'.format(field), None)
            if format_func is None:
                args.append("{0}:'{1}'".format(field, self._data[field] or ''))
            else:
                args.append(format_func())

        # If we're modifying saved task, simply pass on all modified fields
        if self.saved:
            for field in self._modified_fields:
                add_field(field)
        # For new tasks, pass all fields that make sense
        else:
            for field in self._data.keys():
                if field in self.read_only_fields:
                    continue
                add_field(field)

        return args

    def refresh(self, only_fields=None):
        """
        Reload the task data from TaskWarrior.

        With `only_fields` given, just those fields are updated (in both
        the current data and the snapshot); otherwise all data is
        replaced. Raises Task.NotSaved for an unsaved task.
        """
        # Raise error when trying to refresh a task that has not been saved
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to be refreshed")

        # We need to use ID as backup for uuid here for the refreshes
        # of newly saved tasks. Any other place in the code is fine
        # with using UUID only.
        args = [self['uuid'] or self['id'], 'export']
        new_data = json.loads(self.warrior.execute_command(args)[0])
        if only_fields:
            to_update = dict((k, new_data.get(k)) for k in only_fields)
            self._update_data(to_update, update_original=True)
        else:
            self._load_data(new_data)
410
411
class TaskFilter(SerializingObject):
    """
    A set of parameters to filter the task list with.
    """

    def __init__(self, filter_params=None):
        """
        Start from the given list of filter strings, or from an empty one.
        """
        # The list is appended to by add_filter/add_filter_param, so a
        # shared ``[]`` default would leak filters between every filter
        # object created without arguments. Create a fresh list instead.
        self.filter_params = [] if filter_params is None else filter_params

    def add_filter(self, filter_str):
        """
        Append a raw filter string.
        """
        self.filter_params.append(filter_str)

    def add_filter_param(self, key, value):
        """
        Add a filter on attribute `key` with the (unserialized) `value`.

        Double underscores in `key` are turned into dots.
        """
        key = key.replace('__', '.')

        # Replace the value with empty string, since that is the
        # convention in TW for empty values
        attribute_key = key.split('.')[0]
        value = self._serialize(attribute_key, value)

        # If we are filtering by uuid:, do not use uuid keyword
        # due to TW-1452 bug
        if key == 'uuid':
            self.filter_params.insert(0, value)
        else:
            # We enforce equality match by using is keyword
            # Also, without using this syntax, filter fails due to TW-1479
            self.filter_params.append("{0}.is:'{1}'".format(key, value))

    def get_filter_params(self):
        """
        Return the non-empty filter strings.
        """
        return [f for f in self.filter_params if f]

    def clone(self):
        """
        Return a new filter holding its own copy of the parameter list.
        """
        c = self.__class__()
        c.filter_params = list(self.filter_params)
        return c
447
448
449 class TaskQuerySet(object):
450     """
451     Represents a lazy lookup for a task objects.
452     """
453
454     def __init__(self, warrior=None, filter_obj=None):
455         self.warrior = warrior
456         self._result_cache = None
457         self.filter_obj = filter_obj or TaskFilter()
458
459     def __deepcopy__(self, memo):
460         """
461         Deep copy of a QuerySet doesn't populate the cache
462         """
463         obj = self.__class__()
464         for k, v in self.__dict__.items():
465             if k in ('_iter', '_result_cache'):
466                 obj.__dict__[k] = None
467             else:
468                 obj.__dict__[k] = copy.deepcopy(v, memo)
469         return obj
470
471     def __repr__(self):
472         data = list(self[:REPR_OUTPUT_SIZE + 1])
473         if len(data) > REPR_OUTPUT_SIZE:
474             data[-1] = "...(remaining elements truncated)..."
475         return repr(data)
476
477     def __len__(self):
478         if self._result_cache is None:
479             self._result_cache = list(self)
480         return len(self._result_cache)
481
482     def __iter__(self):
483         if self._result_cache is None:
484             self._result_cache = self._execute()
485         return iter(self._result_cache)
486
487     def __getitem__(self, k):
488         if self._result_cache is None:
489             self._result_cache = list(self)
490         return self._result_cache.__getitem__(k)
491
492     def __bool__(self):
493         if self._result_cache is not None:
494             return bool(self._result_cache)
495         try:
496             next(iter(self))
497         except StopIteration:
498             return False
499         return True
500
501     def __nonzero__(self):
502         return type(self).__bool__(self)
503
504     def _clone(self, klass=None, **kwargs):
505         if klass is None:
506             klass = self.__class__
507         filter_obj = self.filter_obj.clone()
508         c = klass(warrior=self.warrior, filter_obj=filter_obj)
509         c.__dict__.update(kwargs)
510         return c
511
512     def _execute(self):
513         """
514         Fetch the tasks which match the current filters.
515         """
516         return self.warrior.filter_tasks(self.filter_obj)
517
518     def all(self):
519         """
520         Returns a new TaskQuerySet that is a copy of the current one.
521         """
522         return self._clone()
523
524     def pending(self):
525         return self.filter(status=PENDING)
526
527     def completed(self):
528         return self.filter(status=COMPLETED)
529
530     def filter(self, *args, **kwargs):
531         """
532         Returns a new TaskQuerySet with the given filters added.
533         """
534         clone = self._clone()
535         for f in args:
536             clone.filter_obj.add_filter(f)
537         for key, value in kwargs.items():
538             clone.filter_obj.add_filter_param(key, value)
539         return clone
540
541     def get(self, **kwargs):
542         """
543         Performs the query and returns a single object matching the given
544         keyword arguments.
545         """
546         clone = self.filter(**kwargs)
547         num = len(clone)
548         if num == 1:
549             return clone._result_cache[0]
550         if not num:
551             raise Task.DoesNotExist(
552                 'Task matching query does not exist. '
553                 'Lookup parameters were {0}'.format(kwargs))
554         raise ValueError(
555             'get() returned more than one Task -- it returned {0}! '
556             'Lookup parameters were {1}'.format(num, kwargs))
557
558
class TaskWarrior(object):
    """
    Wrapper around the ``task`` command line binary.

    Every command is run as a subprocess, with the entries of ``config``
    passed as ``rc.<key>=<value>`` overrides.
    """

    def __init__(self, data_location='~/.task', create=True):
        """
        Set up the wrapper for the data directory `data_location`.

        A leading ``~`` is expanded. With `create` set, a missing
        directory is created. The version of the ``task`` binary is
        queried right away.
        """
        data_location = os.path.expanduser(data_location)
        if create and not os.path.exists(data_location):
            os.makedirs(data_location)
        self.config = {
            'data.location': data_location,
            'confirmation': 'no',
        }
        self.tasks = TaskQuerySet(self)
        self.version = self._get_version()

    def _get_command_args(self, args, config_override=None):
        """
        Build the full argument list for running ``task`` with `args`.

        `config_override` entries take precedence over ``self.config``
        for this one command; ``self.config`` itself is not modified.
        """
        command_args = ['task', 'rc:/']
        config = self.config.copy()
        if config_override:
            config.update(config_override)
        for item in config.items():
            command_args.append('rc.{0}={1}'.format(*item))
        command_args.extend(map(str, args))
        return command_args

    def _get_version(self):
        """
        Return the output of ``task --version`` without newlines around it.
        """
        p = subprocess.Popen(
            ['task', '--version'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        return stdout.strip('\n')

    def execute_command(self, args, config_override=None):
        """
        Run ``task`` with `args` and return its stdout as a list of lines.

        Raises TaskWarriorException if the command exits with a non-zero
        return code. The message is the last line of stderr, or the
        whole stdout when stderr is empty.
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode:
            if stderr.strip():
                error_msg = stderr.strip().splitlines()[-1]
            else:
                error_msg = stdout.strip()
            raise TaskWarriorException(error_msg)
        return stdout.strip().split('\n')

    def filter_tasks(self, filter_obj):
        """
        Export the tasks matching `filter_obj` and return them as a list
        of Task objects.

        Raises TaskWarriorException if an exported line is not valid JSON.
        """
        args = ['export', '--'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if not line:
                continue
            # Every non-empty line is parsed on its own, after removing
            # the commas around it
            data = line.strip(',')
            try:
                task_data = json.loads(data)
            except ValueError:
                raise TaskWarriorException('Invalid JSON: %s' % data)
            filtered_task = Task(self)
            filtered_task._load_data(task_data)
            tasks.append(filtered_task)
        return tasks

    def merge_with(self, path, push=False):
        """
        Run the ``merge`` command against `path`, which is passed with
        exactly one trailing slash; `push` sets ``merge.autopush``.
        """
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """
        Run the ``undo`` command.
        """
        self.execute_command(['undo'])