git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/task.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Restore _update_data and only_fields, and add a test for race conditions
[etc/taskwarrior.git] / tasklib / task.py
1 from __future__ import print_function
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import six
8 import subprocess
9
# strftime/strptime pattern used by the timestamp (de)serializers below.
DATE_FORMAT = '%Y%m%dT%H%M%SZ'
# Maximum number of tasks shown by TaskQuerySet.__repr__ before truncating.
REPR_OUTPUT_SIZE = 10
# Status values used by the TaskQuerySet.pending() / completed() shortcuts.
PENDING = 'pending'
COMPLETED = 'completed'

# Version strings that TaskWarrior.version can be compared against.
VERSION_2_1_0 = six.u('2.1.0')
VERSION_2_2_0 = six.u('2.2.0')
VERSION_2_3_0 = six.u('2.3.0')
VERSION_2_4_0 = six.u('2.4.0')

# Module-level logger; execute_command logs every command line at DEBUG.
logger = logging.getLogger(__name__)
21
22
class TaskWarriorException(Exception):
    """
    Raised when talking to the ``task`` command fails: a non-zero exit
    status, unexpected output, or output that cannot be parsed.
    """
25
26
class SerializingObject(object):
    """
    Common ancestor for TaskResource & TaskFilter, since they both
    need to serialize arguments.

    Conversion is dispatched on the field name: ``_serialize`` looks up a
    ``serialize_<field>`` method and ``_deserialize`` a
    ``deserialize_<field>`` method. When no such method exists, a default
    is used that only maps between ``None`` and the empty string.
    """

    def _deserialize(self, key, value):
        """
        Convert the raw ``value`` of field ``key`` to its Python form.
        Without a dedicated deserializer, '' becomes None and any other
        value is returned unchanged.
        """
        hydrate_func = getattr(self, 'deserialize_{0}'.format(key),
                               lambda x: x if x != '' else None)
        return hydrate_func(value)

    def _serialize(self, key, value):
        """
        Convert the Python ``value`` of field ``key`` to its raw form.
        Without a dedicated serializer, None becomes '' and any other
        value is returned unchanged.
        """
        dehydrate_func = getattr(self, 'serialize_{0}'.format(key),
                                 lambda x: x if x is not None else '')
        return dehydrate_func(value)

    def timestamp_serializer(self, date):
        """Format ``date`` using DATE_FORMAT; falsy input gives None."""
        if not date:
            return None
        return date.strftime(DATE_FORMAT)

    def timestamp_deserializer(self, date_str):
        """Parse ``date_str`` using DATE_FORMAT; falsy input gives None."""
        if not date_str:
            return None
        return datetime.datetime.strptime(date_str, DATE_FORMAT)

    # All date-like fields share the two timestamp helpers above.

    def serialize_entry(self, value):
        return self.timestamp_serializer(value)

    def deserialize_entry(self, value):
        return self.timestamp_deserializer(value)

    def serialize_modified(self, value):
        return self.timestamp_serializer(value)

    def deserialize_modified(self, value):
        return self.timestamp_deserializer(value)

    def serialize_due(self, value):
        return self.timestamp_serializer(value)

    def deserialize_due(self, value):
        return self.timestamp_deserializer(value)

    def serialize_scheduled(self, value):
        return self.timestamp_serializer(value)

    def deserialize_scheduled(self, value):
        return self.timestamp_deserializer(value)

    def serialize_until(self, value):
        return self.timestamp_serializer(value)

    def deserialize_until(self, value):
        return self.timestamp_deserializer(value)

    def serialize_wait(self, value):
        return self.timestamp_serializer(value)

    def deserialize_wait(self, value):
        return self.timestamp_deserializer(value)

    def deserialize_annotations(self, data):
        """Wrap every raw annotation in a TaskAnnotation bound to self."""
        return [TaskAnnotation(self, d) for d in data] if data else []

    def serialize_tags(self, tags):
        """Join the tags with commas; no tags serialize to ''."""
        return ','.join(tags) if tags else ''

    def deserialize_tags(self, tags):
        """Accept either a comma separated string or a list of tags."""
        if isinstance(tags, six.string_types):
            return tags.split(',') if tags else []
        return tags or []

    def serialize_depends(self, cur_dependencies):
        """
        Return the comma separated uuids of the given tasks.

        A missing collection (None or empty) serializes to '', the same
        way serialize_tags treats missing tags, instead of failing while
        iterating over None.
        """
        if not cur_dependencies:
            return ''
        return ','.join(task['uuid'] for task in cur_dependencies)

    def deserialize_depends(self, raw_uuids):
        """
        Turn a comma separated uuid string into a set of tasks, each
        fetched through ``self.warrior.tasks.get``.
        """
        raw_uuids = raw_uuids or ''  # Convert None to empty string
        uuids = raw_uuids.split(',')
        return set(self.warrior.tasks.get(uuid=uuid) for uuid in uuids if uuid)
108
109
class TaskResource(SerializingObject):
    """
    Dictionary-like base for objects whose fields live in ``_data``.

    ``_original_data`` keeps a deep copy of the values as they were last
    loaded. Fields listed in ``read_only_fields`` cannot be assigned
    through item access.
    """
    read_only_fields = []

    def _load_data(self, data):
        """Replace all fields with the deserialized content of ``data``."""
        loaded = {}
        for name, raw in data.items():
            loaded[name] = self._deserialize(name, raw)
        self._data = loaded
        # The snapshot has to be an independent copy, otherwise later
        # changes of _data would leak into it.
        self._original_data = copy.deepcopy(loaded)

    def _update_data(self, data, update_original=False):
        """
        Low level update of the internal _data dict. The incoming values
        are expected in serialized form. With update_original set, the
        _original_data snapshot is refreshed as well.
        """
        fresh = {}
        for name, raw in data.items():
            fresh[name] = self._deserialize(name, raw)
        self._data.update(fresh)

        if update_original:
            self._original_data = copy.deepcopy(self._data)

    def __getitem__(self, key):
        # Integer-like keys stop index-based iteration right away, which
        # keeps TaskResource from being iterable by accident.
        try:
            int(key)
        except ValueError:
            pass
        else:
            raise StopIteration

        # Unknown fields are filled in with their deserialized "empty"
        # value on first access.
        if key not in self._data:
            self._data[key] = self._deserialize(key, None)

        return self._data[key]

    def __setitem__(self, key, value):
        if key in self.read_only_fields:
            raise RuntimeError('Field \'%s\' is read-only' % key)
        self._data[key] = value

    def __str__(self):
        # Python 2 expects bytes from __str__, Python 3 expects text.
        text = six.text_type(self.__unicode__())
        return text if six.PY3 else text.encode('utf-8')

    def __repr__(self):
        return str(self)
160
161
class TaskAnnotation(TaskResource):
    """
    A single annotation of a task.

    ``task`` is the object the annotation belongs to; ``data`` holds the
    raw (serialized) annotation fields.
    """
    read_only_fields = ['entry', 'description']

    def __init__(self, task, data=None):
        self.task = task
        # None instead of a shared mutable {} default
        self._load_data({} if data is None else data)

    def remove(self):
        """Remove this annotation from the task it belongs to."""
        self.task.remove_annotation(self)

    def __unicode__(self):
        return self['description']

    def __eq__(self, other):
        # Anything that is not an annotation is left to the default
        # comparison instead of failing on a missing attribute.
        if not isinstance(other, TaskAnnotation):
            return NotImplemented
        # consider 2 annotations equal if they belong to the same task, and
        # their data dicts are the same
        return self.task == other.task and self._data == other._data

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so define it explicitly.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    __repr__ = __unicode__
181
182
class Task(TaskResource):
    """
    A single task, manipulated through the ``task`` command line tool.

    Fields are read and written with item access (``task['due']``).
    Changes stay local until ``save()`` passes them to TaskWarrior via the
    ``warrior`` object given to the constructor.
    """
    read_only_fields = ['id', 'entry', 'urgency', 'uuid', 'modified']

    class DoesNotExist(Exception):
        """
        Raised when a lookup does not match any task.
        """
        pass

    class CompletedTask(Exception):
        """
        Raised when the operation cannot be performed on the completed task.
        """
        pass

    class DeletedTask(Exception):
        """
        Raised when the operation cannot be performed on the deleted task.
        """
        pass

    class NotSaved(Exception):
        """
        Raised when the operation cannot be performed on the task, because
        it has not been saved to TaskWarrior yet.
        """
        pass

    def __init__(self, warrior, **kwargs):
        """
        Create a new, unsaved task. Keyword arguments are initial field
        values; passing a read-only field raises RuntimeError.
        """
        self.warrior = warrior

        # Check that user is not able to set read-only value in __init__
        for key in kwargs.keys():
            if key in self.read_only_fields:
                raise RuntimeError('Field \'%s\' is read-only' % key)

        # We serialize the data in kwargs so that users of the library
        # do not have to pass different data formats via __setitem__ and
        # __init__ methods, that would be confusing

        # Rather unfortunate syntax due to python2.6 compatibility
        self._load_data(dict((key, self._serialize(key, value))
                        for (key, value) in six.iteritems(kwargs)))

    def __unicode__(self):
        return self['description']

    def __eq__(self, other):
        # Anything that is not a Task is left to the default comparison
        # instead of failing while looking up its uuid.
        if not isinstance(other, Task):
            return NotImplemented
        if self['uuid'] and other['uuid']:
            # For saved Tasks, just define equality by equality of uuids
            return self['uuid'] == other['uuid']
        else:
            # If the tasks are not saved, compare the actual instances
            return id(self) == id(other)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so define it explicitly.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        if self['uuid']:
            # For saved Tasks, hash the uuid so that the hash agrees
            # with __eq__
            return self['uuid'].__hash__()
        else:
            # If the tasks are not saved, return hash of instance id
            return id(self).__hash__()

    @property
    def _modified_fields(self):
        """Yield writable fields whose value differs from the snapshot."""
        writable_fields = set(self._data.keys()) - set(self.read_only_fields)
        for key in writable_fields:
            if self._data.get(key) != self._original_data.get(key):
                yield key

    @property
    def _is_modified(self):
        return bool(list(self._modified_fields))

    @property
    def completed(self):
        return self['status'] == six.text_type('completed')

    @property
    def deleted(self):
        return self['status'] == six.text_type('deleted')

    @property
    def waiting(self):
        return self['status'] == six.text_type('waiting')

    @property
    def pending(self):
        return self['status'] == six.text_type('pending')

    @property
    def saved(self):
        """True once the task has either a uuid or an id."""
        return self['uuid'] is not None or self['id'] is not None

    def serialize_depends(self, cur_dependencies):
        """
        Serialize dependencies to a comma separated uuid string.

        Raises Task.NotSaved if any of the dependencies is unsaved.
        """
        # No dependencies at all (None or empty) serialize to '' rather
        # than failing while iterating over None.
        if not cur_dependencies:
            return ''

        # Check that all the tasks are saved
        for task in cur_dependencies:
            if not task.saved:
                raise Task.NotSaved('Task \'%s\' needs to be saved before '
                                    'it can be set as dependency.' % task)

        return super(Task, self).serialize_depends(cur_dependencies)

    def format_depends(self):
        # We need to generate added and removed dependencies list,
        # since Taskwarrior does not accept redefining dependencies.

        # This cannot be part of serialize_depends, since we need
        # to keep a list of all dependencies in the _data dictionary,
        # not just currently added/removed ones

        old_dependencies = self._original_data.get('depends', set())

        added = self['depends'] - old_dependencies
        removed = old_dependencies - self['depends']

        # Removed dependencies need to be prefixed with '-'
        return 'depends:' + ','.join(
                [t['uuid'] for t in added] +
                ['-' + t['uuid'] for t in removed]
            )

    def format_description(self):
        # Task version older than 2.4.0 ignores first word of the
        # task description if description: prefix is used
        # NOTE(review): both sides of this comparison are strings, so the
        # ordering is lexicographic ('2.10.0' sorts before '2.4.0') --
        # confirm this is acceptable for the supported versions.
        if self.warrior.version < VERSION_2_4_0:
            return self._data['description']
        else:
            return "description:'{0}'".format(self._data['description'] or '')

    def delete(self):
        """
        Delete the task in TaskWarrior.

        Raises Task.NotSaved for an unsaved task and Task.DeletedTask if
        the task is already deleted.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved before it can be deleted")

        # Refresh the status, and raise exception if the task is deleted
        self.refresh(only_fields=['status'])

        if self.deleted:
            raise Task.DeletedTask("Task was already deleted")

        self.warrior.execute_command([self['uuid'], 'delete'])

        # Refresh the status again, so that we have updated info stored
        self.refresh(only_fields=['status'])

    def done(self):
        """
        Mark the task as completed in TaskWarrior.

        Raises Task.NotSaved for an unsaved task, Task.CompletedTask if it
        is already completed and Task.DeletedTask if it is deleted.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved before it can be completed")

        # Refresh, and raise exception if task is already completed/deleted
        self.refresh(only_fields=['status'])

        if self.completed:
            raise Task.CompletedTask("Cannot complete a completed task")
        elif self.deleted:
            raise Task.DeletedTask("Deleted task cannot be completed")

        self.warrior.execute_command([self['uuid'], 'done'])

        # Refresh the status again, so that we have updated info stored
        self.refresh(only_fields=['status'])

    def save(self):
        """
        Write the task to TaskWarrior.

        An unsaved task is created with ``add``; a saved one gets its
        modified fields sent with ``modify``. Nothing happens for a saved
        task without modifications. The task is refreshed afterwards.

        Raises TaskWarriorException if the output of ``add`` does not
        contain exactly one well-formed 'Created task' line.
        """
        if self.saved and not self._is_modified:
            return

        args = [self['uuid'], 'modify'] if self.saved else ['add']
        args.extend(self._get_modified_fields_as_args())
        output = self.warrior.execute_command(args)

        # Parse out the new ID, if the task is being added for the first time
        if not self.saved:
            id_lines = [l for l in output if l.startswith('Created task ')]

            # Complain loudly if it seems that more tasks were created
            # Should not happen
            if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
                raise TaskWarriorException("Unexpected output when creating "
                                           "task: %s" % '\n'.join(id_lines))

            # Circumvent the ID storage, since ID is considered read-only
            self._data['id'] = int(id_lines[0].split(' ')[2].rstrip('.'))

        self.refresh()

    def add_annotation(self, annotation):
        """Annotate the saved task with the given text."""
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to add annotation")

        args = [self['uuid'], 'annotate', annotation]
        self.warrior.execute_command(args)
        self.refresh(only_fields=['annotations'])

    def remove_annotation(self, annotation):
        """
        Remove an annotation, given either as a TaskAnnotation or as the
        annotation text.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to remove annotation")

        if isinstance(annotation, TaskAnnotation):
            annotation = annotation['description']
        args = [self['uuid'], 'denotate', annotation]
        self.warrior.execute_command(args)
        self.refresh(only_fields=['annotations'])

    def _get_modified_fields_as_args(self):
        """
        Build the command line arguments describing the field values:
        only the modified fields for a saved task, all writable fields
        for a new one.
        """
        args = []

        def add_field(field):
            # Add the output of format_field method to args list (defaults to
            # field:value)
            serialized_value = self._serialize(field, self._data[field]) or ''

            def format_default():
                return "{0}:{1}".format(
                    field,
                    "'{0}'".format(serialized_value) if serialized_value else ''
                )

            format_func = getattr(self, 'format_{0}'.format(field),
                                  format_default)
            args.append(format_func())

        # If we're modifying saved task, simply pass on all modified fields
        if self.saved:
            for field in self._modified_fields:
                add_field(field)
        # For new tasks, pass all fields that make sense
        else:
            for field in self._data.keys():
                if field in self.read_only_fields:
                    continue
                add_field(field)

        return args

    def refresh(self, only_fields=None):
        """
        Reload the task from TaskWarrior.

        With ``only_fields`` (a list of field names) just those fields are
        updated, together with the snapshot of original data; otherwise
        all data is replaced. Raises Task.NotSaved for an unsaved task.
        """
        # Raise error when trying to refresh a task that has not been saved
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to be refreshed")

        # We need to use ID as backup for uuid here for the refreshes
        # of newly saved tasks. Any other place in the code is fine
        # with using UUID only.
        args = [self['uuid'] or self['id'], 'export']
        new_data = json.loads(self.warrior.execute_command(args)[0])
        if only_fields:
            to_update = dict(
                [(k, new_data.get(k)) for k in only_fields])
            self._update_data(to_update, update_original=True)
        else:
            self._load_data(new_data)
429
430
class TaskFilter(SerializingObject):
    """
    A set of parameters to filter the task list with.
    """

    def __init__(self, filter_params=None):
        # A list default would be shared by every filter created without
        # arguments, and add_filter/add_filter_param mutate the list in
        # place. Create a fresh one per instance instead.
        self.filter_params = [] if filter_params is None else filter_params

    def add_filter(self, filter_str):
        """Append a raw filter expression."""
        self.filter_params.append(filter_str)

    def add_filter_param(self, key, value):
        """
        Add a ``key=value`` constraint. A double underscore in ``key``
        stands for the dot that separates attribute and modifier.
        """
        key = key.replace('__', '.')

        # Replace the value with empty string, since that is the
        # convention in TW for empty values
        attribute_key = key.split('.')[0]
        value = self._serialize(attribute_key, value)

        # If we are filtering by uuid:, do not use uuid keyword
        # due to TW-1452 bug
        if key == 'uuid':
            self.filter_params.insert(0, value)
        else:
            # Surround value with apostrophes unless it's an empty string
            value = "'%s'" % value if value else ''

            # We enforce equality match by using 'is' (or 'none') modifier
            # Without using this syntax, filter fails due to TW-1479
            modifier = '.is' if value else '.none'
            key = key + modifier if '.' not in key else key

            self.filter_params.append("{0}:{1}".format(key, value))

    def get_filter_params(self):
        """Return the non-empty filter parameters."""
        return [f for f in self.filter_params if f]

    def clone(self):
        """Return a new filter with its own copy of the parameters."""
        c = self.__class__()
        c.filter_params = list(self.filter_params)
        return c
472
473
class TaskQuerySet(object):
    """
    Lazy lookup of task objects: the filter is only executed when the
    results are first needed, and the results are cached afterwards.
    """

    def __init__(self, warrior=None, filter_obj=None):
        if not filter_obj:
            filter_obj = TaskFilter()
        self.warrior = warrior
        self._result_cache = None
        self.filter_obj = filter_obj

    def __deepcopy__(self, memo):
        """
        Deep copy everything except the cached results, which are reset.
        """
        duplicate = self.__class__()
        for name, value in self.__dict__.items():
            if name in ('_iter', '_result_cache'):
                duplicate.__dict__[name] = None
                continue
            duplicate.__dict__[name] = copy.deepcopy(value, memo)
        return duplicate

    def __repr__(self):
        # Fetch one element more than is shown, to know whether the
        # output has to be marked as truncated.
        preview = list(self[:REPR_OUTPUT_SIZE + 1])
        if len(preview) > REPR_OUTPUT_SIZE:
            preview[-1] = "...(remaining elements truncated)..."
        return repr(preview)

    def __len__(self):
        if self._result_cache is None:
            self._result_cache = list(self)
        cached = self._result_cache
        return len(cached)

    def __iter__(self):
        if self._result_cache is None:
            self._result_cache = self._execute()
        cached = self._result_cache
        return iter(cached)

    def __getitem__(self, k):
        if self._result_cache is None:
            self._result_cache = list(self)
        return self._result_cache[k]

    def __bool__(self):
        if self._result_cache is not None:
            return bool(self._result_cache)
        # Not evaluated yet: the query set is true as soon as iteration
        # yields a first element.
        for _ in self:
            return True
        return False

    def __nonzero__(self):
        # Python 2 spelling of __bool__
        return type(self).__bool__(self)

    def _clone(self, klass=None, **kwargs):
        """
        Create a new query set with a cloned filter; ``kwargs`` are set
        as attributes on the result.
        """
        target_class = self.__class__ if klass is None else klass
        cloned_filter = self.filter_obj.clone()
        fresh = target_class(warrior=self.warrior, filter_obj=cloned_filter)
        fresh.__dict__.update(kwargs)
        return fresh

    def _execute(self):
        """
        Ask the warrior for the tasks matching the current filter.
        """
        return self.warrior.filter_tasks(self.filter_obj)

    def all(self):
        """
        Return a new TaskQuerySet with the same filter as this one.
        """
        return self._clone()

    def pending(self):
        return self.filter(status=PENDING)

    def completed(self):
        return self.filter(status=COMPLETED)

    def filter(self, *args, **kwargs):
        """
        Return a new TaskQuerySet with the given filters added:
        positional arguments as raw filter strings, keyword arguments
        as field constraints.
        """
        narrowed = self._clone()
        for raw_filter in args:
            narrowed.filter_obj.add_filter(raw_filter)
        for field, wanted in kwargs.items():
            narrowed.filter_obj.add_filter_param(field, wanted)
        return narrowed

    def get(self, **kwargs):
        """
        Run the query and return the only task matching the keyword
        arguments. Raises Task.DoesNotExist when nothing matches and
        ValueError when more than one task does.
        """
        matches = self.filter(**kwargs)
        count = len(matches)
        if count > 1:
            raise ValueError(
                'get() returned more than one Task -- it returned {0}! '
                'Lookup parameters were {1}'.format(count, kwargs))
        if count == 0:
            raise Task.DoesNotExist(
                'Task matching query does not exist. '
                'Lookup parameters were {0}'.format(kwargs))
        return matches._result_cache[0]
582
583
class TaskWarrior(object):
    """
    Thin wrapper around the ``task`` command line tool.

    ``config`` holds the rc overrides passed with every command, ``tasks``
    is a TaskQuerySet bound to this instance and ``version`` is the
    output of ``task --version``.
    """

    def __init__(self, data_location='~/.task', create=True):
        """
        Use ``data_location`` (``~`` is expanded) as the data directory;
        with ``create`` set, the directory is created if missing.
        """
        data_location = os.path.expanduser(data_location)
        if create and not os.path.exists(data_location):
            os.makedirs(data_location)
        self.config = {
            # Already expanded above, no need to expand a second time
            'data.location': data_location,
            'confirmation': 'no',
            'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
        }
        self.tasks = TaskQuerySet(self)
        self.version = self._get_version()

    def _get_command_args(self, args, config_override=None):
        """
        Build the full command line: ``task rc:/``, one ``rc.key=value``
        item per config entry (``config_override`` applied on top of
        ``self.config``, which is left untouched), then ``args`` converted
        to strings.
        """
        command_args = ['task', 'rc:/']
        config = self.config.copy()
        # None (the default) or an empty mapping means no overrides
        if config_override:
            config.update(config_override)
        for item in config.items():
            command_args.append('rc.{0}={1}'.format(*item))
        command_args.extend(str(arg) for arg in args)
        return command_args

    def _get_version(self):
        """Return the output of ``task --version`` without newlines."""
        p = subprocess.Popen(
                ['task', '--version'],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE)
        stdout = p.communicate()[0].decode('utf-8')
        return stdout.strip('\n')

    def execute_command(self, args, config_override=None):
        """
        Run ``task`` with the given arguments and return the lines of its
        standard output.

        Raises TaskWarriorException on a non-zero exit status, using the
        last line of stderr, or stdout when stderr is empty.
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode:
            if stderr.strip():
                error_msg = stderr.strip().splitlines()[-1]
            else:
                error_msg = stdout.strip()
            raise TaskWarriorException(error_msg)
        return stdout.strip().split('\n')

    def filter_tasks(self, filter_obj):
        """
        Export the tasks matching ``filter_obj`` and return them as a
        list of Task objects. Raises TaskWarriorException for a line that
        cannot be loaded.
        """
        args = ['export', '--'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if line:
                # Drop the commas surrounding each exported line
                data = line.strip(',')
                try:
                    filtered_task = Task(self)
                    filtered_task._load_data(json.loads(data))
                    tasks.append(filtered_task)
                except ValueError:
                    raise TaskWarriorException('Invalid JSON: %s' % data)
        return tasks

    def merge_with(self, path, push=False):
        """
        Run ``merge`` against ``path`` (given exactly one trailing
        slash); ``push`` controls the merge.autopush setting.
        """
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Run the ``undo`` command."""
        self.execute_command(['undo'])