git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/task.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Task: Do not yield corner case of fake removals as modified field
[etc/taskwarrior.git] / tasklib / task.py
1 from __future__ import print_function
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import six
8 import subprocess
9
logger = logging.getLogger(__name__)

# Timestamp format of Taskwarrior date fields; the trailing 'Z' is matched
# literally, so parsed datetimes are naive.
DATE_FORMAT = '%Y%m%dT%H%M%SZ'

# How many tasks repr() of a TaskQuerySet lists before truncating
REPR_OUTPUT_SIZE = 10

# Task status values used by TaskQuerySet.pending() / completed()
PENDING = 'pending'
COMPLETED = 'completed'

# Taskwarrior version strings, compared against TaskWarrior.version
VERSION_2_1_0 = six.u('2.1.0')
VERSION_2_2_0 = six.u('2.2.0')
VERSION_2_3_0 = six.u('2.3.0')
VERSION_2_4_0 = six.u('2.4.0')
21
22
class TaskWarriorException(Exception):
    """
    Raised when a Taskwarrior command fails or produces output that
    cannot be understood.
    """
25
26
class SerializingObject(object):
    """
    Common ancestor for TaskResource & TaskFilter, since they both
    need to serialize arguments.

    Field-specific conversions are found by name: a method called
    ``serialize_<field>`` / ``deserialize_<field>`` is used when it exists,
    otherwise a generic pass-through conversion is applied.
    """

    def _deserialize(self, key, value):
        def passthrough(raw):
            # Taskwarrior uses the empty string for "no value"
            if raw != '':
                return raw
            return None

        converter = getattr(self, 'deserialize_{0}'.format(key), passthrough)
        return converter(value)

    def _serialize(self, key, value):
        def passthrough(raw):
            # None is written out as the empty string
            if raw is None:
                return ''
            return raw

        converter = getattr(self, 'serialize_{0}'.format(key), passthrough)
        return converter(value)

    # Generic timestamp conversion; falsy input maps to None

    def timestamp_serializer(self, date):
        return date.strftime(DATE_FORMAT) if date else None

    def timestamp_deserializer(self, date_str):
        return (datetime.datetime.strptime(date_str, DATE_FORMAT)
                if date_str else None)

    # Date fields, all delegating to the timestamp conversion above

    def serialize_entry(self, value):
        return self.timestamp_serializer(value)

    def deserialize_entry(self, value):
        return self.timestamp_deserializer(value)

    def serialize_modified(self, value):
        return self.timestamp_serializer(value)

    def deserialize_modified(self, value):
        return self.timestamp_deserializer(value)

    def serialize_due(self, value):
        return self.timestamp_serializer(value)

    def deserialize_due(self, value):
        return self.timestamp_deserializer(value)

    def serialize_scheduled(self, value):
        return self.timestamp_serializer(value)

    def deserialize_scheduled(self, value):
        return self.timestamp_deserializer(value)

    def serialize_until(self, value):
        return self.timestamp_serializer(value)

    def deserialize_until(self, value):
        return self.timestamp_deserializer(value)

    def serialize_wait(self, value):
        return self.timestamp_serializer(value)

    def deserialize_wait(self, value):
        return self.timestamp_deserializer(value)

    # Annotations: list of dicts -> list of TaskAnnotation objects

    def deserialize_annotations(self, data):
        if not data:
            return []
        return [TaskAnnotation(self, item) for item in data]

    # Tags: list of strings <-> comma separated string

    def serialize_tags(self, tags):
        if not tags:
            return ''
        return ','.join(tags)

    def deserialize_tags(self, tags):
        if not isinstance(tags, six.string_types):
            return tags or []
        if not tags:
            return []
        return tags.split(',')

    # Dependencies: set of tasks <-> comma separated uuids

    def serialize_depends(self, cur_dependencies):
        # Return the list of uuids
        uuids = [dependency['uuid'] for dependency in cur_dependencies]
        return ','.join(uuids)

    def deserialize_depends(self, raw_uuids):
        # None is treated like an empty string; empty items are skipped
        uuid_list = (raw_uuids or '').split(',')
        return set(self.warrior.tasks.get(uuid=single_uuid)
                   for single_uuid in uuid_list if single_uuid)
108
109
class TaskResource(SerializingObject):
    """
    Base class for objects whose state is a dict of Taskwarrior fields.

    ``_data`` holds the current, deserialized field values and
    ``_original_data`` a deep copy of the values last loaded from
    Taskwarrior, so local changes can be detected by comparing the two.
    """
    # Fields that __setitem__ refuses to change
    read_only_fields = []

    def _load_data(self, data):
        """
        Replace all field values with ``data`` (given in serialized form)
        and record them as the original state.
        """
        self._data = dict((key, self._deserialize(key, value))
                          for key, value in data.items())
        # We need to use a copy for original data, so that changes
        # are not propagated.
        self._original_data = copy.deepcopy(self._data)

    def _update_data(self, data, update_original=False):
        """
        Low level update of the internal _data dict. Data which are coming as
        updates should already be serialized.

        If update_original is True, the updated fields are recorded in the
        original_data dict as well, so they no longer count as modified.
        Only those fields are touched: every other field keeps its original
        value, so pending local changes to it are still detected afterwards.
        """
        updates = dict((key, self._deserialize(key, value))
                       for key, value in data.items())
        self._data.update(updates)

        if update_original:
            # Store separate copies, so that later in-place changes to the
            # current values do not propagate into the original state
            self._original_data.update(copy.deepcopy(updates))

    def __getitem__(self, key):
        """
        Return the value of field ``key``.

        A field that is not present yet is initialised (in _data) with the
        deserialized form of None, and that value is returned.
        """
        # This is a workaround to make TaskResource non-iterable
        # over simple index-based iteration: integer-like keys raise
        # StopIteration, which ends the legacy __getitem__ iteration.
        try:
            int(key)
            raise StopIteration
        except ValueError:
            pass

        if key not in self._data:
            self._data[key] = self._deserialize(key, None)

        return self._data.get(key)

    def __setitem__(self, key, value):
        """
        Set field ``key`` locally; raises RuntimeError for read-only fields.
        """
        if key in self.read_only_fields:
            raise RuntimeError('Field \'%s\' is read-only' % key)
        self._data[key] = value

    def __str__(self):
        # __str__ must return bytes on Python 2, so encode the text there
        s = six.text_type(self.__unicode__())
        if not six.PY3:
            s = s.encode('utf-8')
        return s

    def __repr__(self):
        return str(self)
160
161
class TaskAnnotation(TaskResource):
    """
    A single annotation of a Task, built from the annotation data that
    Taskwarrior exports for the task.
    """
    read_only_fields = ['entry', 'description']

    def __init__(self, task, data=None):
        self.task = task
        # None instead of a shared mutable {} default
        self._load_data(data if data is not None else {})

    def remove(self):
        """Remove this annotation from its task (runs task denotate)."""
        self.task.remove_annotation(self)

    def __unicode__(self):
        return self['description']

    def __eq__(self, other):
        # consider 2 annotations equal if they belong to the same task, and
        # their data dicts are the same; anything else is left to Python
        if not isinstance(other, TaskAnnotation):
            return NotImplemented
        return self.task == other.task and self._data == other._data

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    # __repr__ is inherited from TaskResource (returns str(self)): unlike
    # aliasing __unicode__ it returns a native str on Python 2 (UTF-8 encoded)
    # and does not fail when the description is None.
181
182
class Task(TaskResource):
    """
    A single Taskwarrior task.

    Field values are read and written with ``task['field']``. Writes only
    change the local copy and are sent to Taskwarrior by save(); delete(),
    done() and the annotation methods run their command immediately and
    require the task to be saved.
    """
    read_only_fields = ['id', 'entry', 'urgency', 'uuid', 'modified']

    class DoesNotExist(Exception):
        """
        Raised when a lookup (TaskQuerySet.get) matches no task.
        """
        pass

    class CompletedTask(Exception):
        """
        Raised when the operation cannot be performed on the completed task.
        """
        pass

    class DeletedTask(Exception):
        """
        Raised when the operation cannot be performed on the deleted task.
        """
        pass

    class NotSaved(Exception):
        """
        Raised when the operation cannot be performed on the task, because
        it has not been saved to TaskWarrior yet.
        """
        pass

    def __init__(self, warrior, **kwargs):
        """
        Create a task bound to ``warrior``; ``kwargs`` are initial field
        values in their Python form (e.g. datetimes, lists of tags).

        Raises RuntimeError if a read-only field is passed.
        """
        self.warrior = warrior

        # Check that user is not able to set read-only value in __init__
        for key in kwargs.keys():
            if key in self.read_only_fields:
                raise RuntimeError('Field \'%s\' is read-only' % key)

        # We serialize the data in kwargs so that users of the library
        # do not have to pass different data formats via __setitem__ and
        # __init__ methods, that would be confusing

        # Rather unfortunate syntax due to python2.6 compatibility
        self._load_data(dict((key, self._serialize(key, value))
                             for (key, value) in six.iteritems(kwargs)))

    def __unicode__(self):
        return self['description']

    def __eq__(self, other):
        # Comparisons with non-tasks are left to Python (instead of failing
        # on other['uuid'])
        if not isinstance(other, Task):
            return NotImplemented
        if self['uuid'] and other['uuid']:
            # For saved Tasks, just define equality by equality of uuids
            return self['uuid'] == other['uuid']
        # If the tasks are not saved, compare the actual instances
        return self is other

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        if self['uuid']:
            # For saved Tasks, hash by uuid, consistent with __eq__
            return self['uuid'].__hash__()
        else:
            # If the tasks are not saved, return hash of instance id
            return id(self).__hash__()

    @property
    def _modified_fields(self):
        """Yield the names of writable fields whose value changed locally."""
        writable_fields = set(self._data.keys()) - set(self.read_only_fields)
        for key in writable_fields:
            new_value = self._data.get(key)
            old_value = self._original_data.get(key)

            # Make sure not to mark data removal as modified field if the
            # field originally had some empty value (e.g. missing originally,
            # i.e. None, and an empty list after merely being read)
            if not new_value and not old_value:
                continue

            if new_value != old_value:
                yield key

    @property
    def _is_modified(self):
        return any(True for _ in self._modified_fields)

    @property
    def completed(self):
        return self['status'] == six.text_type('completed')

    @property
    def deleted(self):
        return self['status'] == six.text_type('deleted')

    @property
    def waiting(self):
        return self['status'] == six.text_type('waiting')

    @property
    def pending(self):
        return self['status'] == six.text_type('pending')

    @property
    def saved(self):
        return self['uuid'] is not None or self['id'] is not None

    def serialize_depends(self, cur_dependencies):
        """
        Serialize dependencies to comma separated uuids.

        Raises Task.NotSaved if any of the tasks has not been saved yet.
        """
        # Check that all the tasks are saved
        for task in cur_dependencies:
            if not task.saved:
                raise Task.NotSaved('Task \'%s\' needs to be saved before '
                                    'it can be set as dependency.' % task)

        return super(Task, self).serialize_depends(cur_dependencies)

    def format_depends(self):
        # We need to generate added and removed dependencies list,
        # since Taskwarrior does not accept redefining dependencies.

        # This cannot be part of serialize_depends, since we need
        # to keep a list of all dependencies in the _data dictionary,
        # not just currently added/removed ones

        old_dependencies = self._original_data.get('depends', set())

        added = self['depends'] - old_dependencies
        removed = old_dependencies - self['depends']

        # Removed dependencies need to be prefixed with '-'
        return 'depends:' + ','.join(
            [t['uuid'] for t in added] +
            ['-' + t['uuid'] for t in removed]
        )

    def format_description(self):
        # Task version older than 2.4.0 ignores first word of the
        # task description if description: prefix is used
        # NOTE(review): versions are compared as strings, which is only
        # correct while every version component is a single digit.
        if self.warrior.version < VERSION_2_4_0:
            return self._data['description']
        else:
            return "description:'{0}'".format(self._data['description'] or '')

    def delete(self):
        """
        Delete the task in Taskwarrior and refresh its status.

        Raises Task.NotSaved for unsaved tasks and Task.DeletedTask if the
        task was already deleted.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved before it can be deleted")

        # Refresh the status, and raise exception if the task is deleted
        self.refresh(only_fields=['status'])

        if self.deleted:
            raise Task.DeletedTask("Task was already deleted")

        self.warrior.execute_command([self['uuid'], 'delete'])

        # Refresh the status again, so that we have updated info stored
        self.refresh(only_fields=['status'])

    def done(self):
        """
        Mark the task as done in Taskwarrior and refresh its status.

        Raises Task.NotSaved for unsaved tasks, Task.CompletedTask for
        completed and Task.DeletedTask for deleted ones.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved before it can be completed")

        # Refresh, and raise exception if task is already completed/deleted
        self.refresh(only_fields=['status'])

        if self.completed:
            raise Task.CompletedTask("Cannot complete a completed task")
        elif self.deleted:
            raise Task.DeletedTask("Deleted task cannot be completed")

        self.warrior.execute_command([self['uuid'], 'done'])

        # Refresh the status again, so that we have updated info stored
        self.refresh(only_fields=['status'])

    def save(self):
        """
        Send local changes to Taskwarrior, then reload all fields.

        A new task is created with ``task add``; a saved task gets its
        modified fields via ``task <uuid> modify`` (nothing is run if no
        field was modified). Raises TaskWarriorException if the ID of a
        newly created task cannot be parsed from the output.
        """
        if self.saved and not self._is_modified:
            return

        args = [self['uuid'], 'modify'] if self.saved else ['add']
        args.extend(self._get_modified_fields_as_args())
        output = self.warrior.execute_command(args)

        # Parse out the new ID, if the task is being added for the first time
        if not self.saved:
            id_lines = [line for line in output
                        if line.startswith('Created task ')]

            # Complain loudly if it seems that more tasks were created
            # Should not happen
            if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
                raise TaskWarriorException("Unexpected output when creating "
                                           "task: %s" % '\n'.join(id_lines))

            # Circumvent the ID storage, since ID is considered read-only
            self._data['id'] = int(id_lines[0].split(' ')[2].rstrip('.'))

        self.refresh()

    def add_annotation(self, annotation):
        """Annotate the task with the text ``annotation``."""
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to add annotation")

        args = [self['uuid'], 'annotate', annotation]
        self.warrior.execute_command(args)
        self.refresh(only_fields=['annotations'])

    def remove_annotation(self, annotation):
        """
        Remove an annotation, given as TaskAnnotation or description text.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to remove annotation")

        if isinstance(annotation, TaskAnnotation):
            annotation = annotation['description']
        args = [self['uuid'], 'denotate', annotation]
        self.warrior.execute_command(args)
        self.refresh(only_fields=['annotations'])

    def _get_modified_fields_as_args(self):
        """
        Build the command line arguments describing the task's fields:
        modified fields for saved tasks, all writable fields for new ones.
        """
        args = []

        def add_field(field):
            # Add the output of format_field method to args list (defaults to
            # field:value)
            serialized_value = self._serialize(field, self._data[field]) or ''
            format_default = lambda: "{0}:{1}".format(
                field,
                "'{0}'".format(serialized_value) if serialized_value else ''
            )
            format_func = getattr(self, 'format_{0}'.format(field),
                                  format_default)
            args.append(format_func())

        # If we're modifying saved task, simply pass on all modified fields
        if self.saved:
            for field in self._modified_fields:
                add_field(field)
        # For new tasks, pass all fields that make sense
        else:
            for field in self._data.keys():
                if field in self.read_only_fields:
                    continue
                add_field(field)

        return args

    def refresh(self, only_fields=None):
        """
        Reload field values from Taskwarrior (``task <uuid> export``).

        If ``only_fields`` is given, only those fields are re-read, via
        _update_data(..., update_original=True); otherwise all data is
        replaced via _load_data. Raises Task.NotSaved for unsaved tasks.
        """
        # Raise error when trying to refresh a task that has not been saved
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to be refreshed")

        # We need to use ID as backup for uuid here for the refreshes
        # of newly saved tasks. Any other place in the code is fine
        # with using UUID only.
        args = [self['uuid'] or self['id'], 'export']
        new_data = json.loads(self.warrior.execute_command(args)[0])
        if only_fields:
            to_update = dict(
                [(k, new_data.get(k)) for k in only_fields])
            self._update_data(to_update, update_original=True)
        else:
            self._load_data(new_data)
437
438
class TaskFilter(SerializingObject):
    """
    A set of parameters to filter the task list with.

    The parameters are strings in Taskwarrior's filter syntax; empty ones
    are dropped by get_filter_params().
    """

    def __init__(self, filter_params=None):
        # Always store a fresh list: a shared mutable default (or an aliased
        # caller list) would otherwise be modified by add_filter() and
        # add_filter_param() and leak into unrelated filters.
        if filter_params is None:
            self.filter_params = []
        else:
            self.filter_params = list(filter_params)

    def add_filter(self, filter_str):
        """Append a raw filter expression string."""
        self.filter_params.append(filter_str)

    def add_filter_param(self, key, value):
        """
        Add a ``key:value`` filter.

        A ``__`` in ``key`` becomes a ``.`` modifier separator. Without an
        explicit modifier an exact match is enforced with ``.is`` (non-empty
        values) or ``.none`` (empty values). A ``uuid`` filter is inserted
        at the front as the bare UUID.
        """
        key = key.replace('__', '.')

        # Serialize the value with the serializer of the attribute itself
        # (the part of the key before any modifier)
        attribute_key = key.split('.')[0]
        value = self._serialize(attribute_key, value)

        # If we are filtering by uuid:, do not use uuid keyword
        # due to TW-1452 bug
        if key == 'uuid':
            self.filter_params.insert(0, value)
        else:
            # Surround value with apostrophes unless it's empty; the empty
            # string is the convention in TW for empty values
            value = "'%s'" % value if value else ''

            # We enforce equality match by using 'is' (or 'none') modifier
            # Without using this syntax, filter fails due to TW-1479
            modifier = '.is' if value else '.none'
            key = key + modifier if '.' not in key else key

            self.filter_params.append("{0}:{1}".format(key, value))

    def get_filter_params(self):
        """Return the non-empty filter parameters."""
        return [f for f in self.filter_params if f]

    def clone(self):
        """Return a new filter of the same class with a copy of the params."""
        c = self.__class__()
        c.filter_params = list(self.filter_params)
        return c
480
481
482 class TaskQuerySet(object):
483     """
484     Represents a lazy lookup for a task objects.
485     """
486
487     def __init__(self, warrior=None, filter_obj=None):
488         self.warrior = warrior
489         self._result_cache = None
490         self.filter_obj = filter_obj or TaskFilter()
491
492     def __deepcopy__(self, memo):
493         """
494         Deep copy of a QuerySet doesn't populate the cache
495         """
496         obj = self.__class__()
497         for k, v in self.__dict__.items():
498             if k in ('_iter', '_result_cache'):
499                 obj.__dict__[k] = None
500             else:
501                 obj.__dict__[k] = copy.deepcopy(v, memo)
502         return obj
503
504     def __repr__(self):
505         data = list(self[:REPR_OUTPUT_SIZE + 1])
506         if len(data) > REPR_OUTPUT_SIZE:
507             data[-1] = "...(remaining elements truncated)..."
508         return repr(data)
509
510     def __len__(self):
511         if self._result_cache is None:
512             self._result_cache = list(self)
513         return len(self._result_cache)
514
515     def __iter__(self):
516         if self._result_cache is None:
517             self._result_cache = self._execute()
518         return iter(self._result_cache)
519
520     def __getitem__(self, k):
521         if self._result_cache is None:
522             self._result_cache = list(self)
523         return self._result_cache.__getitem__(k)
524
525     def __bool__(self):
526         if self._result_cache is not None:
527             return bool(self._result_cache)
528         try:
529             next(iter(self))
530         except StopIteration:
531             return False
532         return True
533
534     def __nonzero__(self):
535         return type(self).__bool__(self)
536
537     def _clone(self, klass=None, **kwargs):
538         if klass is None:
539             klass = self.__class__
540         filter_obj = self.filter_obj.clone()
541         c = klass(warrior=self.warrior, filter_obj=filter_obj)
542         c.__dict__.update(kwargs)
543         return c
544
545     def _execute(self):
546         """
547         Fetch the tasks which match the current filters.
548         """
549         return self.warrior.filter_tasks(self.filter_obj)
550
551     def all(self):
552         """
553         Returns a new TaskQuerySet that is a copy of the current one.
554         """
555         return self._clone()
556
557     def pending(self):
558         return self.filter(status=PENDING)
559
560     def completed(self):
561         return self.filter(status=COMPLETED)
562
563     def filter(self, *args, **kwargs):
564         """
565         Returns a new TaskQuerySet with the given filters added.
566         """
567         clone = self._clone()
568         for f in args:
569             clone.filter_obj.add_filter(f)
570         for key, value in kwargs.items():
571             clone.filter_obj.add_filter_param(key, value)
572         return clone
573
574     def get(self, **kwargs):
575         """
576         Performs the query and returns a single object matching the given
577         keyword arguments.
578         """
579         clone = self.filter(**kwargs)
580         num = len(clone)
581         if num == 1:
582             return clone._result_cache[0]
583         if not num:
584             raise Task.DoesNotExist(
585                 'Task matching query does not exist. '
586                 'Lookup parameters were {0}'.format(kwargs))
587         raise ValueError(
588             'get() returned more than one Task -- it returned {0}! '
589             'Lookup parameters were {1}'.format(num, kwargs))
590
591
class TaskWarrior(object):
    """
    Interface to a Taskwarrior installation.

    Every operation runs the ``task`` binary as a subprocess, with
    ``data.location`` pointing at ``data_location`` and confirmations
    disabled.
    """

    def __init__(self, data_location='~/.task', create=True):
        """
        Use ``data_location`` (``~`` is expanded) as Taskwarrior data
        directory, creating it first when ``create`` is set and it does not
        exist. Runs ``task --version`` to determine the version.
        """
        data_location = os.path.expanduser(data_location)
        if create and not os.path.exists(data_location):
            os.makedirs(data_location)
        self.config = {
            'data.location': data_location,
            'confirmation': 'no',
            'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
        }
        self.tasks = TaskQuerySet(self)
        self.version = self._get_version()

    def _get_command_args(self, args, config_override=None):
        """
        Build the argv list for running ``task`` with ``args``.

        ``rc:/`` is passed (presumably so no user taskrc is read -- confirm)
        and every config item, with ``config_override`` applied on top, is
        passed as ``rc.key=value``.
        """
        command_args = ['task', 'rc:/']
        config = self.config.copy()
        config.update(config_override or {})
        for item in config.items():
            command_args.append('rc.{0}={1}'.format(*item))
        # six.text_type instead of str: str() raises UnicodeEncodeError on
        # Python 2 for non-ASCII unicode arguments (e.g. annotation text)
        command_args.extend(map(six.text_type, args))
        return command_args

    def _get_version(self):
        """Return the output of ``task --version`` without surrounding newlines."""
        p = subprocess.Popen(
            ['task', '--version'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        stdout = p.communicate()[0]
        return stdout.decode('utf-8').strip('\n')

    def execute_command(self, args, config_override=None):
        """
        Run ``task`` with ``args`` and return its stdout split into lines.

        Raises TaskWarriorException on a non-zero exit status, with the last
        line of stderr (or stdout if stderr is empty) as message.
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode:
            if stderr.strip():
                error_msg = stderr.strip().splitlines()[-1]
            else:
                error_msg = stdout.strip()
            raise TaskWarriorException(error_msg)
        return stdout.strip().split('\n')

    def filter_tasks(self, filter_obj):
        """
        Return Task objects for the tasks matching ``filter_obj``.

        Every non-empty line of the export output is parsed as JSON after
        stripping surrounding commas. Raises TaskWarriorException if a line
        is not valid JSON.
        """
        args = ['export', '--'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if not line:
                continue
            data = line.strip(',')
            # Only JSON decoding is guarded, so that errors raised while
            # loading the task data are not reported as invalid JSON
            try:
                task_data = json.loads(data)
            except ValueError:
                raise TaskWarriorException('Invalid JSON: %s' % data)
            filtered_task = Task(self)
            filtered_task._load_data(task_data)
            tasks.append(filtered_task)
        return tasks

    def merge_with(self, path, push=False):
        """
        Run ``task merge`` with ``path`` (normalized to end in exactly one
        '/'); ``push`` controls the merge.autopush setting.
        """
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Run ``task undo``."""
        self.execute_command(['undo'])