git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/task.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you read the Git project's submission guidelines and adhere to them, I would be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

7d1bbe4b1fffdb23b90c990cca5ac7edf66762c8
[etc/taskwarrior.git] / tasklib / task.py
1 from __future__ import print_function
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import six
8 import subprocess
9
# strftime/strptime pattern used to (de)serialize task timestamps.
DATE_FORMAT = '%Y%m%dT%H%M%SZ'

# Number of tasks TaskQuerySet.__repr__ lists before truncating.
REPR_OUTPUT_SIZE = 10

# Status values used by TaskQuerySet.pending() / TaskQuerySet.completed().
PENDING, COMPLETED = 'pending', 'completed'

# Taskwarrior releases as text, for comparisons against TaskWarrior.version
# (e.g. in Task.format_description).
VERSION_2_1_0, VERSION_2_2_0, VERSION_2_3_0, VERSION_2_4_0 = map(
    six.u, ('2.1.0', '2.2.0', '2.3.0', '2.4.0'))

logger = logging.getLogger(__name__)
21
22
class TaskWarriorException(Exception):
    """
    Raised when the task command fails or produces output that cannot be
    interpreted.
    """
25
26
class SerializingObject(object):
    """
    Common ancestor for TaskResource & TaskFilter, since they both
    need to serialize arguments.

    Per-field conversion hooks are looked up by name: ``serialize_<key>``
    turns a Python value into the form stored/passed to Taskwarrior, and
    ``deserialize_<key>`` converts it back. Fields without a dedicated hook
    map None to '' when serializing and '' to None when deserializing.
    """

    def _deserialize(self, key, value):
        hydrate_func = getattr(self, 'deserialize_{0}'.format(key),
                               lambda x: x if x != '' else None)
        return hydrate_func(value)

    def _serialize(self, key, value):
        dehydrate_func = getattr(self, 'serialize_{0}'.format(key),
                                 lambda x: x if x is not None else '')
        return dehydrate_func(value)

    def timestamp_serializer(self, date):
        """Format a datetime with DATE_FORMAT; falsy input yields None."""
        if not date:
            return None
        return date.strftime(DATE_FORMAT)

    def timestamp_deserializer(self, date_str):
        """Parse a DATE_FORMAT string into a naive datetime; falsy -> None."""
        if not date_str:
            return None
        return datetime.datetime.strptime(date_str, DATE_FORMAT)

    def serialize_entry(self, value):
        return self.timestamp_serializer(value)

    def deserialize_entry(self, value):
        return self.timestamp_deserializer(value)

    def serialize_modified(self, value):
        return self.timestamp_serializer(value)

    def deserialize_modified(self, value):
        return self.timestamp_deserializer(value)

    def serialize_due(self, value):
        return self.timestamp_serializer(value)

    def deserialize_due(self, value):
        return self.timestamp_deserializer(value)

    def serialize_scheduled(self, value):
        return self.timestamp_serializer(value)

    def deserialize_scheduled(self, value):
        return self.timestamp_deserializer(value)

    def serialize_until(self, value):
        return self.timestamp_serializer(value)

    def deserialize_until(self, value):
        return self.timestamp_deserializer(value)

    def serialize_wait(self, value):
        return self.timestamp_serializer(value)

    def deserialize_wait(self, value):
        return self.timestamp_deserializer(value)

    def deserialize_annotations(self, data):
        """Wrap each raw annotation dict in a TaskAnnotation; None -> []."""
        return [TaskAnnotation(self, d) for d in data] if data else []

    def serialize_tags(self, tags):
        """Join tags with commas; an empty/None tag list becomes ''."""
        return ','.join(tags) if tags else ''

    def deserialize_tags(self, tags):
        """
        Accept either a comma-separated string or an already-split list.
        """
        # six.string_types is basestring on Python 2 and str on Python 3;
        # the bare name basestring does not exist on Python 3.
        if isinstance(tags, six.string_types):
            return tags.split(',') if tags else []
        return tags

    def serialize_depends(self, cur_dependencies):
        """Return the dependencies as a comma-separated list of uuids."""
        return ','.join(task['uuid'] for task in cur_dependencies)

    def deserialize_depends(self, raw_uuids):
        """
        Resolve a comma-separated uuid string into a set of tasks.

        NOTE(review): relies on ``self.warrior``, which only Task sets.
        """
        raw_uuids = raw_uuids or ''  # Convert None to empty string
        uuids = raw_uuids.split(',')
        return set(self.warrior.tasks.get(uuid=uuid) for uuid in uuids if uuid)
108
109
class TaskResource(SerializingObject):
    """
    Base class for objects backed by a dict of serialized Taskwarrior data.

    ``_data`` holds the current serialized values; ``_original_data`` is a
    snapshot taken on load, used to tell which fields were changed locally.
    """
    read_only_fields = []

    def _load_data(self, data):
        # Keep the live dict plus an independent snapshot so later edits to
        # _data are not reflected in _original_data. A shallow copy is
        # enough because the values are primitive data types.
        self._data = data
        self._original_data = data.copy()

    def _update_data(self, data, update_original=False):
        """
        Merge already-serialized ``data`` into ``_data``.

        With ``update_original`` set, the same values are merged into
        ``_original_data`` as well.
        """
        targets = [self._data]
        if update_original:
            targets.append(self._original_data)
        for target in targets:
            target.update(data)

    def __getitem__(self, key):
        # Integer-like keys raise StopIteration, so Python's legacy
        # index-based iteration (obj[0], obj[1], ...) ends immediately and
        # the resource is effectively non-iterable.
        try:
            int(key)
        except ValueError:
            pass
        else:
            raise StopIteration
        return self._deserialize(key, self._data.get(key))

    def __setitem__(self, key, value):
        if key not in self.read_only_fields:
            self._data[key] = self._serialize(key, value)
            return
        raise RuntimeError('Field \'%s\' is read-only' % key)

    def __str__(self):
        text = six.text_type(self.__unicode__())
        # Python 2 expects __str__ to return a byte string.
        return text if six.PY3 else text.encode('utf-8')

    def __repr__(self):
        return str(self)
156
157
class TaskAnnotation(TaskResource):
    """
    An annotation attached to a Task; ``description`` holds its text.
    """
    read_only_fields = ['entry', 'description']

    def __init__(self, task, data=None):
        """
        :param task: the Task this annotation belongs to.
        :param data: raw annotation dict; defaults to a fresh empty dict.
        """
        self.task = task
        # _load_data keeps the given dict itself as _data, so a shared
        # mutable default would let __setitem__ leak values between
        # annotations; create a new dict per instance instead.
        self._load_data({} if data is None else data)

    def remove(self):
        """Remove this annotation from its task."""
        self.task.remove_annotation(self)

    def __unicode__(self):
        return self['description']

    __repr__ = __unicode__
172
173
class Task(TaskResource):
    """
    A single Taskwarrior task bound to a TaskWarrior instance.

    Field values live in ``_data`` in serialized form and are converted
    back through the ``deserialize_*`` hooks on item access.
    """
    read_only_fields = ['id', 'entry', 'urgency', 'uuid', 'modified']

    class DoesNotExist(Exception):
        """
        Raised when a lookup matches no task.
        """
        pass

    class CompletedTask(Exception):
        """
        Raised when the operation cannot be performed on the completed task.
        """
        pass

    class DeletedTask(Exception):
        """
        Raised when the operation cannot be performed on the deleted task.
        """
        pass

    class NotSaved(Exception):
        """
        Raised when the operation cannot be performed on the task, because
        it has not been saved to TaskWarrior yet.
        """
        pass

    def __init__(self, warrior, **kwargs):
        """
        Create a task from keyword field values.

        :param warrior: TaskWarrior instance used to run commands.
        :param kwargs: initial field values, in the same (deserialized)
            format accepted by ``__setitem__``.
        :raises RuntimeError: if a read-only field is passed.
        """
        self.warrior = warrior

        # Check that user is not able to set read-only value in __init__
        for key in kwargs.keys():
            if key in self.read_only_fields:
                raise RuntimeError('Field \'%s\' is read-only' % key)

        # Serialize kwargs so that users of the library do not have to pass
        # different data formats via __setitem__ and __init__.
        # dict() over a generator (instead of a dict comprehension) keeps
        # Python 2.6 compatibility.
        self._load_data(dict((key, self._serialize(key, value))
                             for (key, value) in six.iteritems(kwargs)))

    def __unicode__(self):
        return self['description']

    def __eq__(self, other):
        # Comparing with a non-Task (None, strings, ...) must not try to
        # subscript it; let Python fall back to its default handling.
        if not isinstance(other, Task):
            return NotImplemented
        if self['uuid'] and other['uuid']:
            # For saved Tasks, just define equality by equality of uuids
            return self['uuid'] == other['uuid']
        # If the tasks are not saved, compare the actual instances
        return id(self) == id(other)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so spell it out.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        if self['uuid']:
            # For saved Tasks, hash by uuid (consistent with __eq__)
            return self['uuid'].__hash__()
        # If the tasks are not saved, return hash of instance id
        return id(self).__hash__()

    @property
    def _modified_fields(self):
        """Yield writable fields whose value differs from the loaded one."""
        writable_fields = set(self._data.keys()) - set(self.read_only_fields)
        for key in writable_fields:
            if self._data.get(key) != self._original_data.get(key):
                yield key

    @property
    def completed(self):
        return self['status'] == six.text_type('completed')

    @property
    def deleted(self):
        return self['status'] == six.text_type('deleted')

    @property
    def waiting(self):
        return self['status'] == six.text_type('waiting')

    @property
    def pending(self):
        return self['status'] == six.text_type('pending')

    @property
    def saved(self):
        """True once the task has a uuid or an id."""
        return self['uuid'] is not None or self['id'] is not None

    def serialize_depends(self, cur_dependencies):
        """
        Serialize dependencies to uuids.

        :raises Task.NotSaved: if any dependency has not been saved.
        """
        for task in cur_dependencies:
            if not task.saved:
                raise Task.NotSaved('Task \'%s\' needs to be saved before '
                                    'it can be set as dependency.' % task)

        return super(Task, self).serialize_depends(cur_dependencies)

    def format_depends(self):
        """
        Build the ``depends:`` argument as a diff against the loaded value.

        Added dependencies are listed by uuid, removed ones prefixed with
        '-', since Taskwarrior does not accept redefining dependencies.
        """
        # This cannot be part of serialize_depends, since we need
        # to keep a list of all dependencies in the _data dictionary,
        # not just currently added/removed ones
        old_dependencies_raw = self._original_data.get('depends', '')
        old_dependencies = self.deserialize_depends(old_dependencies_raw)

        added = self['depends'] - old_dependencies
        removed = old_dependencies - self['depends']

        # Removed dependencies need to be prefixed with '-'
        return 'depends:' + ','.join(
            [t['uuid'] for t in added] +
            ['-' + t['uuid'] for t in removed]
        )

    def format_description(self):
        """Build the description argument for add/modify."""
        # Task version older than 2.4.0 ignores first word of the
        # task description if description: prefix is used.
        # NOTE(review): this is a plain string comparison of version
        # strings, which is only correct for single-digit components.
        if self.warrior.version < VERSION_2_4_0:
            return self._data['description']
        return "description:'{0}'".format(self._data['description'] or '')

    def delete(self):
        """
        Delete the task in Taskwarrior and refresh its status.

        :raises Task.NotSaved: if the task was never saved.
        :raises Task.DeletedTask: if the task is already deleted.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved before it can be deleted")

        # Refresh the status, and raise exception if the task is deleted
        self.refresh(only_fields=['status'])

        if self.deleted:
            raise Task.DeletedTask("Task was already deleted")

        self.warrior.execute_command([self['uuid'], 'delete'])

        # Refresh the status again, so that we have updated info stored
        self.refresh(only_fields=['status'])

    def done(self):
        """
        Mark the task as done in Taskwarrior and refresh its status.

        :raises Task.NotSaved: if the task was never saved.
        :raises Task.CompletedTask: if the task is already completed.
        :raises Task.DeletedTask: if the task is deleted.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved before it can be completed")

        # Refresh, and raise exception if task is already completed/deleted
        self.refresh(only_fields=['status'])

        if self.completed:
            raise Task.CompletedTask("Cannot complete a completed task")
        elif self.deleted:
            raise Task.DeletedTask("Deleted task cannot be completed")

        self.warrior.execute_command([self['uuid'], 'done'])

        # Refresh the status again, so that we have updated info stored
        self.refresh(only_fields=['status'])

    def save(self):
        """
        Create the task (``add``) or push local changes (``modify``), then
        reload its data from Taskwarrior.

        Saving an already-saved task without local modifications does
        nothing.

        :raises TaskWarriorException: if the ``add`` output cannot be parsed.
        """
        # Without modifications the modify call would carry no arguments
        # at all (just '<uuid> modify'), so skip the round-trip entirely.
        if self.saved and not list(self._modified_fields):
            return

        args = [self['uuid'], 'modify'] if self.saved else ['add']
        args.extend(self._get_modified_fields_as_args())
        output = self.warrior.execute_command(args)

        # Parse out the new ID, if the task is being added for the first time
        if not self.saved:
            id_lines = [l for l in output if l.startswith('Created task ')]

            # Complain loudly if it seems that more tasks were created
            # Should not happen
            if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
                raise TaskWarriorException("Unexpected output when creating "
                                           "task: %s" % '\n'.join(id_lines))

            # Circumvent the ID storage, since ID is considered read-only
            self._data['id'] = int(id_lines[0].split(' ')[2].rstrip('.'))

        self.refresh()

    def add_annotation(self, annotation):
        """
        Add an annotation text to the task.

        :raises Task.NotSaved: if the task was never saved.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to add annotation")

        args = [self['uuid'], 'annotate', annotation]
        self.warrior.execute_command(args)
        self.refresh(only_fields=['annotations'])

    def remove_annotation(self, annotation):
        """
        Remove an annotation, given as a TaskAnnotation or its text.

        :raises Task.NotSaved: if the task was never saved.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to remove annotation")

        if isinstance(annotation, TaskAnnotation):
            annotation = annotation['description']
        args = [self['uuid'], 'denotate', annotation]
        self.warrior.execute_command(args)
        self.refresh(only_fields=['annotations'])

    def _get_modified_fields_as_args(self):
        """
        Build command-line arguments for the fields to send to Taskwarrior:
        the modified fields of a saved task, or every writable field of a
        new one.
        """
        args = []

        def add_field(field):
            # Add the output of format_field method to args list (defaults to
            # field:value)
            format_default = lambda k: "{0}:'{1}'".format(k, self._data[k] or '')
            format_func = getattr(self, 'format_{0}'.format(field),
                                  lambda: format_default(field))
            args.append(format_func())

        # If we're modifying saved task, simply pass on all modified fields
        if self.saved:
            for field in self._modified_fields:
                add_field(field)
        # For new tasks, pass all fields that make sense
        else:
            for field in self._data.keys():
                if field in self.read_only_fields:
                    continue
                add_field(field)

        return args

    def refresh(self, only_fields=None):
        """
        Reload task data from Taskwarrior's export.

        :param only_fields: if given (non-empty), update just these fields
            (in both the live data and the snapshot); otherwise replace all
            data.
        :raises Task.NotSaved: if the task was never saved.
        """
        # Raise error when trying to refresh a task that has not been saved
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to be refreshed")

        # We need to use ID as backup for uuid here for the refreshes
        # of newly saved tasks. Any other place in the code is fine
        # with using UUID only.
        args = [self['uuid'] or self['id'], 'export']
        new_data = json.loads(self.warrior.execute_command(args)[0])
        if only_fields:
            to_update = dict(
                [(k, new_data.get(k)) for k in only_fields])
            self._update_data(to_update, update_original=True)
        else:
            self._load_data(new_data)
410
411
class TaskFilter(SerializingObject):
    """
    A set of parameters to filter the task list with.
    """

    def __init__(self, filter_params=None):
        # Each filter needs its own list: add_filter() and
        # add_filter_param() mutate it in place, so a shared default list
        # would leak filters between unrelated TaskFilter instances.
        self.filter_params = [] if filter_params is None else filter_params

    def add_filter(self, filter_str):
        """Append a raw Taskwarrior filter expression."""
        self.filter_params.append(filter_str)

    def add_filter_param(self, key, value):
        """
        Add a ``key=value`` filter.

        Double underscores in ``key`` become dots (``due__before`` ->
        ``due.before``); the value is serialized with the hook of the
        attribute name before the first dot.
        """
        key = key.replace('__', '.')

        # Serialize using the bare attribute name; the default serializer
        # turns None into '', the TW convention for empty values.
        attribute_key = key.split('.')[0]
        value = self._serialize(attribute_key, value)

        # If we are filtering by uuid:, do not use uuid keyword
        # due to TW-1452 bug
        if key == 'uuid':
            self.filter_params.insert(0, value)
        else:
            # Surround value with apostrophes unless it's an empty string
            value = "'%s'" % value if value else ''

            # We enforce equality match by using 'is' (or 'none') modifier
            # Without using this syntax, filter fails due to TW-1479
            modifier = '.is' if value else '.none'
            key = key + modifier if '.' not in key else key

            self.filter_params.append("{0}:{1}".format(key, value))

    def get_filter_params(self):
        """Return the filter parameters, skipping empty entries."""
        return [f for f in self.filter_params if f]

    def clone(self):
        """Return a copy with an independent parameter list."""
        c = self.__class__()
        c.filter_params = list(self.filter_params)
        return c
453
454
class TaskQuerySet(object):
    """
    Lazy, chainable lookup of tasks.

    Filtering methods return new query sets; ``warrior.filter_tasks`` is
    only called when results are first needed, and the results are then
    cached on the instance.
    """

    def __init__(self, warrior=None, filter_obj=None):
        self.warrior = warrior
        self._result_cache = None
        self.filter_obj = filter_obj or TaskFilter()

    def __deepcopy__(self, memo):
        """
        Deep-copy the query set, leaving the result cache empty.
        """
        duplicate = self.__class__()
        for name, value in self.__dict__.items():
            uncached = name in ('_iter', '_result_cache')
            duplicate.__dict__[name] = (None if uncached
                                        else copy.deepcopy(value, memo))
        return duplicate

    def __repr__(self):
        # Fetch one element more than shown, to detect truncation.
        preview = list(self[:REPR_OUTPUT_SIZE + 1])
        if len(preview) > REPR_OUTPUT_SIZE:
            preview[-1] = "...(remaining elements truncated)..."
        return repr(preview)

    def __len__(self):
        cache = self._result_cache
        if cache is None:
            cache = self._result_cache = list(self)
        return len(cache)

    def __iter__(self):
        cache = self._result_cache
        if cache is None:
            cache = self._result_cache = self._execute()
        return iter(cache)

    def __getitem__(self, k):
        cache = self._result_cache
        if cache is None:
            cache = self._result_cache = list(self)
        return cache[k]

    def __bool__(self):
        if self._result_cache is None:
            # Populates the cache; any first element means non-empty.
            for _ in self:
                return True
            return False
        return bool(self._result_cache)

    def __nonzero__(self):
        # Python 2 spelling of __bool__.
        return type(self).__bool__(self)

    def _clone(self, klass=None, **kwargs):
        """Copy this query set (optionally as ``klass``) with a cloned filter."""
        target_class = self.__class__ if klass is None else klass
        duplicate = target_class(warrior=self.warrior,
                                 filter_obj=self.filter_obj.clone())
        duplicate.__dict__.update(kwargs)
        return duplicate

    def _execute(self):
        """
        Fetch the tasks which match the current filters.
        """
        return self.warrior.filter_tasks(self.filter_obj)

    def all(self):
        """
        Returns a new TaskQuerySet that is a copy of the current one.
        """
        return self._clone()

    def pending(self):
        """Query set restricted to pending tasks."""
        return self.filter(status=PENDING)

    def completed(self):
        """Query set restricted to completed tasks."""
        return self.filter(status=COMPLETED)

    def filter(self, *args, **kwargs):
        """
        Returns a new TaskQuerySet with the given filters added.

        Positional arguments are raw filter expressions; keyword arguments
        are field filters.
        """
        narrowed = self._clone()
        flt = narrowed.filter_obj
        for raw in args:
            flt.add_filter(raw)
        for name, val in kwargs.items():
            flt.add_filter_param(name, val)
        return narrowed

    def get(self, **kwargs):
        """
        Performs the query and returns a single object matching the given
        keyword arguments.

        :raises Task.DoesNotExist: if nothing matches.
        :raises ValueError: if more than one task matches.
        """
        matches = self.filter(**kwargs)
        count = len(matches)
        if count == 0:
            raise Task.DoesNotExist(
                'Task matching query does not exist. '
                'Lookup parameters were {0}'.format(kwargs))
        if count > 1:
            raise ValueError(
                'get() returned more than one Task -- it returned {0}! '
                'Lookup parameters were {1}'.format(count, kwargs))
        return matches._result_cache[0]
563
564
class TaskWarrior(object):
    """
    Thin wrapper around the ``task`` command-line program.
    """

    def __init__(self, data_location='~/.task', create=True):
        """
        :param data_location: Taskwarrior data directory ('~' is expanded).
        :param create: create the directory if it does not exist.
        """
        data_location = os.path.expanduser(data_location)
        if create and not os.path.exists(data_location):
            os.makedirs(data_location)
        self.config = {
            # Already expanded above.
            'data.location': data_location,
            'confirmation': 'no',
        }
        self.tasks = TaskQuerySet(self)
        self.version = self._get_version()

    def _get_command_args(self, args, config_override=None):
        """
        Build the argv for a task invocation: ``rc:/`` plus one
        ``rc.key=value`` per config entry (overrides applied), then ``args``
        converted with str().
        """
        command_args = ['task', 'rc:/']
        config = self.config.copy()
        config.update(config_override or {})
        for item in config.items():
            command_args.append('rc.{0}={1}'.format(*item))
        command_args.extend(map(str, args))
        return command_args

    def _get_version(self):
        """Return the output of ``task --version`` without newlines."""
        p = subprocess.Popen(
            ['task', '--version'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        return stdout.strip('\n')

    def execute_command(self, args, config_override=None):
        """
        Run task with ``args`` and return stdout split into lines.

        :raises TaskWarriorException: on a non-zero exit code; the message
            is the last stderr line, or the whole stdout if stderr is empty.
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode:
            if stderr.strip():
                error_msg = stderr.strip().splitlines()[-1]
            else:
                error_msg = stdout.strip()
            raise TaskWarriorException(error_msg)
        return stdout.strip().split('\n')

    def filter_tasks(self, filter_obj):
        """
        Export the tasks matching ``filter_obj`` and wrap them as Tasks.

        Each non-empty output line is parsed as one JSON object after
        stripping surrounding commas.

        :raises TaskWarriorException: if a line is not valid JSON.
        """
        args = ['export', '--'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if not line:
                continue
            data = line.strip(',')
            # Only the JSON parsing is guarded, so unrelated ValueErrors
            # are not misreported as invalid JSON.
            try:
                task_data = json.loads(data)
            except ValueError:
                raise TaskWarriorException('Invalid JSON: %s' % data)
            filtered_task = Task(self)
            filtered_task._load_data(task_data)
            tasks.append(filtered_task)
        return tasks

    def merge_with(self, path, push=False):
        """Run ``task merge`` with ``path`` (a trailing '/' is enforced)."""
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Run ``task undo``."""
        self.execute_command(['undo'])