git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/task.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

d5913af9adda62e66f05bb7931410c94a0000f37
[etc/taskwarrior.git] / tasklib / task.py
1 from __future__ import print_function
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import six
8 import subprocess
9
# Layout of Taskwarrior timestamps such as 20150101T120000Z. The trailing
# 'Z' is matched literally, so parsed datetimes carry no tzinfo.
DATE_FORMAT = '%Y%m%dT%H%M%SZ'

# Number of tasks TaskQuerySet.__repr__ shows before truncating the rest.
REPR_OUTPUT_SIZE = 10

# Task status values used by TaskQuerySet.pending() / completed().
PENDING = 'pending'
COMPLETED = 'completed'

# Taskwarrior release markers for version-dependent behaviour; they are
# compared against the string reported by ``task --version`` (e.g.
# Task.format_description checks VERSION_2_4_0).
VERSION_2_1_0 = six.u('2.1.0')
VERSION_2_2_0 = six.u('2.2.0')
VERSION_2_3_0 = six.u('2.3.0')
VERSION_2_4_0 = six.u('2.4.0')

# Module-level logger; executed task commands are logged at DEBUG level.
logger = logging.getLogger(__name__)
21
22
class TaskWarriorException(Exception):
    """
    Raised when the ``task`` command exits with an error or produces
    output that tasklib cannot interpret.
    """
25
26
class SerializingObject(object):
    """
    Common ancestor for TaskResource & TaskFilter, since they both
    need to serialize arguments.

    Conversion is dispatched by field name: field ``foo`` is handled by
    ``serialize_foo`` / ``deserialize_foo`` when such a method exists, and
    by a generic fallback otherwise.
    """

    def _deserialize(self, key, value):
        """
        Convert raw Taskwarrior data for field *key* into a Python value.

        Without a ``deserialize_<key>`` method, an empty string becomes
        None and any other value is returned unchanged.
        """
        hydrate_func = getattr(self, 'deserialize_{0}'.format(key),
                               lambda x: x if x != '' else None)
        return hydrate_func(value)

    def _serialize(self, key, value):
        """
        Convert a Python value for field *key* into its Taskwarrior form.

        Without a ``serialize_<key>`` method, None becomes the empty string
        and any other value is returned unchanged.
        """
        dehydrate_func = getattr(self, 'serialize_{0}'.format(key),
                                 lambda x: x if x is not None else '')
        return dehydrate_func(value)

    def timestamp_serializer(self, date):
        """Format *date* with DATE_FORMAT; a falsy value yields None."""
        if not date:
            return None
        return date.strftime(DATE_FORMAT)

    def timestamp_deserializer(self, date_str):
        """
        Parse a DATE_FORMAT string into a naive datetime.

        A falsy value yields None; a malformed string raises ValueError
        (from ``datetime.strptime``).
        """
        if not date_str:
            return None
        return datetime.datetime.strptime(date_str, DATE_FORMAT)

    # All date fields share the timestamp (de)serializers above.

    def serialize_entry(self, value):
        return self.timestamp_serializer(value)

    def deserialize_entry(self, value):
        return self.timestamp_deserializer(value)

    def serialize_modified(self, value):
        return self.timestamp_serializer(value)

    def deserialize_modified(self, value):
        return self.timestamp_deserializer(value)

    def serialize_due(self, value):
        return self.timestamp_serializer(value)

    def deserialize_due(self, value):
        return self.timestamp_deserializer(value)

    def serialize_scheduled(self, value):
        return self.timestamp_serializer(value)

    def deserialize_scheduled(self, value):
        return self.timestamp_deserializer(value)

    def serialize_until(self, value):
        return self.timestamp_serializer(value)

    def deserialize_until(self, value):
        return self.timestamp_deserializer(value)

    def serialize_wait(self, value):
        return self.timestamp_serializer(value)

    def deserialize_wait(self, value):
        return self.timestamp_deserializer(value)

    def deserialize_annotations(self, data):
        """Wrap each raw annotation in a TaskAnnotation bound to self."""
        return [TaskAnnotation(self, d) for d in data] if data else []

    def serialize_tags(self, tags):
        """Join *tags* with commas; an empty or None value yields ''."""
        return ','.join(tags) if tags else ''

    def deserialize_tags(self, tags):
        """
        Return tags as a list.

        Accepts a comma-separated string, or an already-split sequence
        which is returned as-is (None/empty becomes []).
        """
        # six.string_types is (str,) on Python 3 and (basestring,) on
        # Python 2; the bare ``basestring`` name does not exist on Python 3
        # and raised NameError there.
        if isinstance(tags, six.string_types):
            return tags.split(',') if tags else []
        return tags or []

    def serialize_depends(self, cur_dependencies):
        """Return the comma-separated uuids of the dependency tasks."""
        # Return the list of uuids
        return ','.join(task['uuid'] for task in cur_dependencies)

    def deserialize_depends(self, raw_uuids):
        """
        Resolve a comma-separated uuid string into a set of Task objects.

        Each uuid is looked up via ``self.warrior.tasks.get``, so this
        requires a ``warrior`` attribute and runs one query per uuid.
        """
        raw_uuids = raw_uuids or ''  # Convert None to empty string
        uuids = raw_uuids.split(',')
        return set(self.warrior.tasks.get(uuid=uuid) for uuid in uuids if uuid)
108
109
class TaskResource(SerializingObject):
    """
    Dictionary-like wrapper around one Taskwarrior record.

    Deserialized field values live in ``_data``; ``_original_data`` holds
    a deep-copied snapshot taken at load time so that later edits can be
    compared against it.
    """

    # Fields that __setitem__ refuses to change; subclasses override this.
    read_only_fields = []

    def _load_data(self, data):
        """Deserialize every field in *data* and snapshot the result."""
        loaded = dict(
            (field, self._deserialize(field, raw))
            for field, raw in data.items()
        )
        self._data = loaded
        # Deep copy, so in-place edits of values in _data (e.g. appending
        # to a list) never show up in the snapshot.
        self._original_data = copy.deepcopy(loaded)

    def __getitem__(self, key):
        """
        Return field *key*, or its deserializer's default when the stored
        value is missing or falsy.

        Integer-like keys raise StopIteration so the legacy index-based
        iteration protocol stops at once, keeping instances effectively
        non-iterable.
        """
        try:
            int(key)
        except ValueError:
            pass
        else:
            raise StopIteration

        stored = self._data.get(key)
        if stored:
            return stored
        return self._deserialize(key, None)

    def __setitem__(self, key, value):
        """Set field *key*; fields in read_only_fields raise RuntimeError."""
        if key not in self.read_only_fields:
            self._data[key] = value
            return
        raise RuntimeError('Field \'%s\' is read-only' % key)

    def __str__(self):
        """Return the text form, encoded as UTF-8 bytes on Python 2."""
        text = six.text_type(self.__unicode__())
        return text if six.PY3 else text.encode('utf-8')

    def __repr__(self):
        return str(self)
144
145
class TaskAnnotation(TaskResource):
    """
    A single annotation attached to a task.
    """

    read_only_fields = ['entry', 'description']

    def __init__(self, task, data=None):
        """
        :param task: owning object; remove() calls its remove_annotation()
        :param data: raw annotation fields; None means no fields
        """
        self.task = task
        # None instead of a shared mutable ``{}`` default argument.
        self._load_data(data if data is not None else {})

    def remove(self):
        """Remove this annotation from its task."""
        self.task.remove_annotation(self)

    def __unicode__(self):
        return self['description']

    # No ``__repr__ = __unicode__`` alias: on Python 2 __repr__ must return
    # a byte string, and a non-ASCII unicode description made repr() fail.
    # The inherited TaskResource.__repr__ goes through __str__, which
    # encodes to UTF-8 on Python 2 and returns the same text on Python 3.
160
161
class Task(TaskResource):
    """
    A single Taskwarrior task.

    Fields are read and written with item access (``task['project']``).
    Local field changes are sent to Taskwarrior only by save(); delete(),
    done() and the annotation helpers run a task command immediately.
    """

    read_only_fields = ['id', 'entry', 'urgency', 'uuid', 'modified']

    class DoesNotExist(Exception):
        """
        Raised when a lookup matches no task.
        """
        pass

    class CompletedTask(Exception):
        """
        Raised when the operation cannot be performed on the completed task.
        """
        pass

    class DeletedTask(Exception):
        """
        Raised when the operation cannot be performed on the deleted task.
        """
        pass

    class NotSaved(Exception):
        """
        Raised when the operation cannot be performed on the task, because
        it has not been saved to TaskWarrior yet.
        """
        pass

    def __init__(self, warrior, **kwargs):
        """
        Create an unsaved task bound to *warrior*.

        Keyword arguments are initial field values, given in the same
        Python form that __setitem__ accepts.

        :raises RuntimeError: if a read-only field is passed
        """
        self.warrior = warrior

        # Check that user is not able to set read-only value in __init__
        for key in kwargs.keys():
            if key in self.read_only_fields:
                raise RuntimeError('Field \'%s\' is read-only' % key)

        # We serialize the data in kwargs so that users of the library
        # do not have to pass different data formats via __setitem__ and
        # __init__ methods, that would be confusing

        # Rather unfortunate syntax due to python2.6 compatibility
        self._load_data(dict((key, self._serialize(key, value))
                             for (key, value) in six.iteritems(kwargs)))

    def __unicode__(self):
        return self['description']

    def __eq__(self, other):
        """
        Saved tasks are equal when their uuids match; unsaved tasks are
        only equal to themselves. Non-Task operands return NotImplemented
        instead of failing on ``other['uuid']``.
        """
        if not isinstance(other, Task):
            return NotImplemented
        if self['uuid'] and other['uuid']:
            # For saved Tasks, just define equality by equality of uuids
            return self['uuid'] == other['uuid']
        # If the tasks are not saved, compare the actual instances
        return id(self) == id(other)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so define it explicitly.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    def __hash__(self):
        if self['uuid']:
            # Saved tasks hash by uuid, consistent with __eq__
            return self['uuid'].__hash__()
        # Unsaved tasks hash by instance identity
        return id(self).__hash__()

    @property
    def _modified_fields(self):
        """Yield writable fields whose value differs from the loaded snapshot."""
        writable_fields = set(self._data.keys()) - set(self.read_only_fields)
        for key in writable_fields:
            if self._data.get(key) != self._original_data.get(key):
                yield key

    @property
    def _is_modified(self):
        return bool(list(self._modified_fields))

    @property
    def completed(self):
        return self['status'] == six.text_type('completed')

    @property
    def deleted(self):
        return self['status'] == six.text_type('deleted')

    @property
    def waiting(self):
        return self['status'] == six.text_type('waiting')

    @property
    def pending(self):
        return self['status'] == six.text_type('pending')

    @property
    def saved(self):
        """True once the task has a uuid or an id."""
        return self['uuid'] is not None or self['id'] is not None

    def serialize_depends(self, cur_dependencies):
        """
        Serialize dependency tasks to comma-separated uuids.

        :raises Task.NotSaved: if any dependency has not been saved yet,
            since dependencies are referenced by uuid
        """
        # Check that all the tasks are saved
        for task in cur_dependencies:
            if not task.saved:
                raise Task.NotSaved('Task \'%s\' needs to be saved before '
                                    'it can be set as dependency.' % task)

        return super(Task, self).serialize_depends(cur_dependencies)

    def format_depends(self):
        """
        Return the ``depends:`` command argument.

        Only the difference against the loaded state is sent: added uuids
        as-is, removed ones prefixed with ``-``.
        """
        # We need to generate added and removed dependencies list,
        # since Taskwarrior does not accept redefining dependencies.

        # This cannot be part of serialize_depends, since we need
        # to keep a list of all dependencies in the _data dictionary,
        # not just currently added/removed ones

        old_dependencies = self._original_data.get('depends', set())

        added = self['depends'] - old_dependencies
        removed = old_dependencies - self['depends']

        # Removed dependencies need to be prefixed with '-'
        return 'depends:' + ','.join(
            [t['uuid'] for t in added] +
            ['-' + t['uuid'] for t in removed]
        )

    def format_description(self):
        """Return the description argument for the installed Taskwarrior."""
        # Task version older than 2.4.0 ignores first word of the
        # task description if description: prefix is used
        # NOTE(review): versions are compared as plain strings, which is
        # only correct while every version component is a single digit.
        if self.warrior.version < VERSION_2_4_0:
            return self._data['description']
        else:
            return "description:'{0}'".format(self._data['description'] or '')

    def delete(self):
        """
        Delete the task in Taskwarrior and reload its data.

        :raises Task.NotSaved: if the task was never saved
        :raises Task.DeletedTask: if the task is already deleted
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved before it can be deleted")

        # Refresh the status, and raise exception if the task is deleted
        self.refresh()

        if self.deleted:
            raise Task.DeletedTask("Task was already deleted")

        self.warrior.execute_command([self['uuid'], 'delete'])

        # Refresh the status again, so that we have updated info stored
        self.refresh()

    def done(self):
        """
        Mark the task as done in Taskwarrior and reload its data.

        :raises Task.NotSaved: if the task was never saved
        :raises Task.CompletedTask: if the task is already completed
        :raises Task.DeletedTask: if the task is deleted
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved before it can be completed")

        # Refresh, and raise exception if task is already completed/deleted
        self.refresh()

        if self.completed:
            raise Task.CompletedTask("Cannot complete a completed task")
        elif self.deleted:
            raise Task.DeletedTask("Deleted task cannot be completed")

        self.warrior.execute_command([self['uuid'], 'done'])

        # Refresh the status again, so that we have updated info stored
        self.refresh()

    def save(self):
        """
        Add a new task, or send the modified fields of a saved task.

        Does nothing for a saved task without modifications; otherwise the
        task data is reloaded afterwards.

        :raises TaskWarriorException: if adding does not report exactly one
            "Created task ..." line of three words
        """
        if self.saved and not self._is_modified:
            return

        args = [self['uuid'], 'modify'] if self.saved else ['add']
        args.extend(self._get_modified_fields_as_args())
        output = self.warrior.execute_command(args)

        # Parse out the new ID, if the task is being added for the first time
        if not self.saved:
            id_lines = [line for line in output
                        if line.startswith('Created task ')]

            # Complain loudly if it seems that more tasks were created
            # Should not happen
            if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
                # Fall back to the full output when no "Created task" line
                # was found, so the error message is never empty.
                raise TaskWarriorException("Unexpected output when creating "
                                           "task: %s" % '\n'.join(id_lines or output))

            # Circumvent the ID storage, since ID is considered read-only
            self._data['id'] = int(id_lines[0].split(' ')[2].rstrip('.'))

        self.refresh()

    def add_annotation(self, annotation):
        """Annotate the saved task with the text *annotation*, then reload."""
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to add annotation")

        args = [self['uuid'], 'annotate', annotation]
        self.warrior.execute_command(args)
        self.refresh()

    def remove_annotation(self, annotation):
        """
        Remove an annotation (a TaskAnnotation or its description text)
        from the saved task, then reload.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to remove annotation")

        if isinstance(annotation, TaskAnnotation):
            annotation = annotation['description']
        args = [self['uuid'], 'denotate', annotation]
        self.warrior.execute_command(args)
        self.refresh()

    def _get_modified_fields_as_args(self):
        """
        Build command arguments for the fields that need sending: modified
        fields of a saved task, or all writable fields of a new one.
        """
        args = []

        def add_field(field):
            # Append format_<field>() output when that method exists,
            # otherwise the default field:'value' form.
            serialized_value = self._serialize(field, self._data[field]) or ''

            def format_default():
                return "{0}:'{1}'".format(field, serialized_value)

            format_func = getattr(self, 'format_{0}'.format(field),
                                  format_default)
            args.append(format_func())

        # If we're modifying saved task, simply pass on all modified fields
        if self.saved:
            for field in self._modified_fields:
                add_field(field)
        # For new tasks, pass all fields that make sense
        else:
            for field in self._data.keys():
                if field in self.read_only_fields:
                    continue
                add_field(field)

        return args

    def refresh(self):
        """
        Reload the task's data from ``task <uuid or id> export``.

        :raises Task.NotSaved: if the task was never saved
        """
        # Raise error when trying to refresh a task that has not been saved
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to be refreshed")

        # We need to use ID as backup for uuid here for the refreshes
        # of newly saved tasks. Any other place in the code is fine
        # with using UUID only.
        args = [self['uuid'] or self['id'], 'export']
        new_data = json.loads(self.warrior.execute_command(args)[0])
        self._load_data(new_data)
400
401
class TaskFilter(SerializingObject):
    """
    A set of parameters to filter the task list with.
    """

    def __init__(self, filter_params=None):
        """
        :param filter_params: list of raw filter strings to start from.
            A passed list is used as-is (not copied); when omitted or None,
            a new empty list is created for this instance.
        """
        # A literal ``[]`` default is one list object shared by every
        # TaskFilter created without arguments, so add_filter() on one of
        # them silently added the filter to all of them.
        if filter_params is None:
            filter_params = []
        self.filter_params = filter_params

    def add_filter(self, filter_str):
        """Append *filter_str* to the parameters unmodified."""
        self.filter_params.append(filter_str)

    def add_filter_param(self, key, value):
        """
        Add a ``key:'value'`` filter built from a keyword-style argument.

        Double underscores in *key* become dots. The value is serialized
        with the serializer of the attribute before the first dot. A key
        without a dot gets an ``.is`` modifier, or ``.none`` when the value
        is empty. A ``uuid`` key inserts the bare value at the front.
        """
        key = key.replace('__', '.')

        # Replace the value with empty string, since that is the
        # convention in TW for empty values
        attribute_key = key.split('.')[0]
        value = self._serialize(attribute_key, value)

        # If we are filtering by uuid:, do not use uuid keyword
        # due to TW-1452 bug
        if key == 'uuid':
            self.filter_params.insert(0, value)
        else:
            # Surround value with apostrophes unless it's an empty string
            value = "'%s'" % value if value else ''

            # We enforce equality match by using 'is' (or 'none') modifier
            # Without using this syntax, filter fails due to TW-1479
            modifier = '.is' if value else '.none'
            key = key + modifier if '.' not in key else key

            self.filter_params.append("{0}:{1}".format(key, value))

    def get_filter_params(self):
        """Return the parameters with empty entries dropped."""
        return [f for f in self.filter_params if f]

    def clone(self):
        """Return a new filter holding a copy of this parameter list."""
        c = self.__class__()
        c.filter_params = list(self.filter_params)
        return c
443
444
445 class TaskQuerySet(object):
446     """
447     Represents a lazy lookup for a task objects.
448     """
449
450     def __init__(self, warrior=None, filter_obj=None):
451         self.warrior = warrior
452         self._result_cache = None
453         self.filter_obj = filter_obj or TaskFilter()
454
455     def __deepcopy__(self, memo):
456         """
457         Deep copy of a QuerySet doesn't populate the cache
458         """
459         obj = self.__class__()
460         for k, v in self.__dict__.items():
461             if k in ('_iter', '_result_cache'):
462                 obj.__dict__[k] = None
463             else:
464                 obj.__dict__[k] = copy.deepcopy(v, memo)
465         return obj
466
467     def __repr__(self):
468         data = list(self[:REPR_OUTPUT_SIZE + 1])
469         if len(data) > REPR_OUTPUT_SIZE:
470             data[-1] = "...(remaining elements truncated)..."
471         return repr(data)
472
473     def __len__(self):
474         if self._result_cache is None:
475             self._result_cache = list(self)
476         return len(self._result_cache)
477
478     def __iter__(self):
479         if self._result_cache is None:
480             self._result_cache = self._execute()
481         return iter(self._result_cache)
482
483     def __getitem__(self, k):
484         if self._result_cache is None:
485             self._result_cache = list(self)
486         return self._result_cache.__getitem__(k)
487
488     def __bool__(self):
489         if self._result_cache is not None:
490             return bool(self._result_cache)
491         try:
492             next(iter(self))
493         except StopIteration:
494             return False
495         return True
496
497     def __nonzero__(self):
498         return type(self).__bool__(self)
499
500     def _clone(self, klass=None, **kwargs):
501         if klass is None:
502             klass = self.__class__
503         filter_obj = self.filter_obj.clone()
504         c = klass(warrior=self.warrior, filter_obj=filter_obj)
505         c.__dict__.update(kwargs)
506         return c
507
508     def _execute(self):
509         """
510         Fetch the tasks which match the current filters.
511         """
512         return self.warrior.filter_tasks(self.filter_obj)
513
514     def all(self):
515         """
516         Returns a new TaskQuerySet that is a copy of the current one.
517         """
518         return self._clone()
519
520     def pending(self):
521         return self.filter(status=PENDING)
522
523     def completed(self):
524         return self.filter(status=COMPLETED)
525
526     def filter(self, *args, **kwargs):
527         """
528         Returns a new TaskQuerySet with the given filters added.
529         """
530         clone = self._clone()
531         for f in args:
532             clone.filter_obj.add_filter(f)
533         for key, value in kwargs.items():
534             clone.filter_obj.add_filter_param(key, value)
535         return clone
536
537     def get(self, **kwargs):
538         """
539         Performs the query and returns a single object matching the given
540         keyword arguments.
541         """
542         clone = self.filter(**kwargs)
543         num = len(clone)
544         if num == 1:
545             return clone._result_cache[0]
546         if not num:
547             raise Task.DoesNotExist(
548                 'Task matching query does not exist. '
549                 'Lookup parameters were {0}'.format(kwargs))
550         raise ValueError(
551             'get() returned more than one Task -- it returned {0}! '
552             'Lookup parameters were {1}'.format(num, kwargs))
553
554
class TaskWarrior(object):
    """
    Interface to a Taskwarrior data directory through the ``task`` CLI.
    """

    def __init__(self, data_location='~/.task', create=True):
        """
        :param data_location: Taskwarrior data directory; ``~`` is expanded
        :param create: create the directory if it does not exist yet
        """
        data_location = os.path.expanduser(data_location)
        if create and not os.path.exists(data_location):
            os.makedirs(data_location)
        # Passed as rc.<key>=<value> overrides on every task invocation.
        self.config = {
            'data.location': data_location,
            'confirmation': 'no',
            'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
        }
        self.tasks = TaskQuerySet(self)
        self.version = self._get_version()

    def _get_command_args(self, args, config_override=None):
        """
        Build the argv list for one ``task`` invocation.

        :param args: command arguments, each converted with str()
        :param config_override: optional mapping merged over self.config
            for this call only (self.config itself is not modified)
        """
        # NOTE(review): 'rc:/' presumably keeps the user's own taskrc from
        # being read, so only the rc.* overrides below apply — confirm.
        command_args = ['task', 'rc:/']
        config = self.config.copy()
        if config_override:
            config.update(config_override)
        for item in config.items():
            command_args.append('rc.{0}={1}'.format(*item))
        command_args.extend(map(str, args))
        return command_args

    def _get_version(self):
        """Return the output of ``task --version`` without trailing newlines."""
        p = subprocess.Popen(
            ['task', '--version'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        stdout, _stderr = [x.decode('utf-8') for x in p.communicate()]
        return stdout.strip('\n')

    def execute_command(self, args, config_override=None):
        """
        Run ``task`` with *args* and return its stdout as a list of lines.

        :raises TaskWarriorException: on a non-zero exit status, carrying
            the last stderr line (or all of stdout if stderr is empty)
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode:
            if stderr.strip():
                error_msg = stderr.strip().splitlines()[-1]
            else:
                error_msg = stdout.strip()
            raise TaskWarriorException(error_msg)
        return stdout.strip().split('\n')

    def filter_tasks(self, filter_obj):
        """
        Export the tasks matching *filter_obj* and return them as Task objects.

        :raises TaskWarriorException: if an output line is not valid JSON
        """
        args = ['export', '--'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if not line:
                continue
            data = line.strip(',')
            try:
                task_data = json.loads(data)
            except ValueError:
                raise TaskWarriorException('Invalid JSON: %s' % data)
            # Task construction stays outside the try: ValueErrors raised
            # while deserializing fields (e.g. a malformed date, or get()
            # matching several dependency tasks) are not JSON errors and
            # must not be reported as such.
            filtered_task = Task(self)
            filtered_task._load_data(task_data)
            tasks.append(filtered_task)
        return tasks

    def merge_with(self, path, push=False):
        """
        Run ``task merge`` against *path* (normalized to one trailing
        slash), setting merge.autopush according to *push*.
        """
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Run ``task undo``."""
        self.execute_command(['undo'])