git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/task.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

004ed906730a9ca83f51e04fa57a275b63df6f51
[etc/taskwarrior.git] / tasklib / task.py
1 from __future__ import print_function
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import six
8 import subprocess
9
# strftime/strptime format used to (de)serialize task timestamps,
# e.g. '20150101T120000Z' (the trailing 'Z' is matched literally).
DATE_FORMAT = '%Y%m%dT%H%M%SZ'

# How many elements TaskQuerySet.__repr__ shows before truncating.
REPR_OUTPUT_SIZE = 10

# Task status values used by TaskQuerySet.pending() / completed().
PENDING, COMPLETED = 'pending', 'completed'

# TaskWarrior versions as text, so they compare against the text that
# TaskWarrior._get_version() reads from ``task --version``.
VERSION_2_1_0 = six.u('2.1.0')
VERSION_2_2_0 = six.u('2.2.0')
VERSION_2_3_0 = six.u('2.3.0')
VERSION_2_4_0 = six.u('2.4.0')

logger = logging.getLogger(__name__)
21
22
class TaskWarriorException(Exception):
    """Raised when the ``task`` command fails or its output is unexpected."""
25
26
class SerializingObject(object):
    """
    Common ancestor for TaskResource & TaskFilter, since they both
    need to serialize arguments.

    Conversion is looked up per field name: ``serialize_<field>`` turns a
    Python value into the stored/command-line form and
    ``deserialize_<field>`` turns it back. Fields without a dedicated
    method fall back to mapping ``None`` <-> ``''``.
    """

    def _deserialize(self, key, value):
        """Convert the stored value of ``key`` into its Python form."""
        hydrate_func = getattr(self, 'deserialize_{0}'.format(key),
                               lambda x: x if x != '' else None)
        return hydrate_func(value)

    def _serialize(self, key, value):
        """Convert a Python value for ``key`` into its stored form."""
        dehydrate_func = getattr(self, 'serialize_{0}'.format(key),
                                 lambda x: x if x is not None else '')
        return dehydrate_func(value)

    def timestamp_serializer(self, date):
        """Format ``date`` with DATE_FORMAT; falsy input gives None."""
        if not date:
            return None
        return date.strftime(DATE_FORMAT)

    def timestamp_deserializer(self, date_str):
        """
        Parse a DATE_FORMAT string into a datetime; falsy input gives None.

        The trailing 'Z' of DATE_FORMAT is matched literally, so the
        returned datetime is naive (no tzinfo attached).
        """
        if not date_str:
            return None
        return datetime.datetime.strptime(date_str, DATE_FORMAT)

    def serialize_entry(self, value):
        return self.timestamp_serializer(value)

    def deserialize_entry(self, value):
        return self.timestamp_deserializer(value)

    def serialize_modified(self, value):
        return self.timestamp_serializer(value)

    def deserialize_modified(self, value):
        return self.timestamp_deserializer(value)

    def serialize_due(self, value):
        return self.timestamp_serializer(value)

    def deserialize_due(self, value):
        return self.timestamp_deserializer(value)

    def serialize_scheduled(self, value):
        return self.timestamp_serializer(value)

    def deserialize_scheduled(self, value):
        return self.timestamp_deserializer(value)

    def serialize_until(self, value):
        return self.timestamp_serializer(value)

    def deserialize_until(self, value):
        return self.timestamp_deserializer(value)

    def serialize_wait(self, value):
        return self.timestamp_serializer(value)

    def deserialize_wait(self, value):
        return self.timestamp_deserializer(value)

    def deserialize_annotations(self, data):
        """Wrap each raw annotation dict in a TaskAnnotation owned by self."""
        return [TaskAnnotation(self, d) for d in data] if data else []

    def serialize_tags(self, tags):
        """Join tags into a comma-separated string ('' for no tags)."""
        return ','.join(tags) if tags else ''

    def deserialize_tags(self, tags):
        """
        Split a comma-separated tag string into a list.

        Non-string values (e.g. a list already, or None) are returned
        unchanged.
        """
        # ``basestring`` only exists on Python 2; six.string_types is
        # (basestring,) there and (str,) on Python 3.
        if isinstance(tags, six.string_types):
            return tags.split(',') if tags else []
        return tags

    def serialize_depends(self, cur_dependencies):
        """Return the comma-separated uuids of the given tasks."""
        return ','.join(task['uuid'] for task in cur_dependencies)

    def deserialize_depends(self, raw_uuids):
        """
        Turn a comma-separated uuid string into a set of Task objects.

        Requires ``self.warrior`` (only Task sets it). Each uuid is looked up
        with a separate ``warrior.tasks.get`` query.
        """
        raw_uuids = raw_uuids or ''  # Convert None to empty string
        uuids = raw_uuids.split(',')
        return set(self.warrior.tasks.get(uuid=uuid) for uuid in uuids if uuid)
108
109
class TaskResource(SerializingObject):
    """
    Base for objects backed by a dict of serialized TaskWarrior data.

    ``_data`` holds the current serialized values; ``_original_data`` is a
    snapshot taken when the data was loaded, used to detect changes.
    Item access converts through the SerializingObject (de)serializers.
    """
    read_only_fields = []

    def _load_data(self, data):
        """Adopt ``data`` as the backing dict and snapshot it."""
        self._data = data
        # The snapshot must not follow later edits to _data, hence a copy.
        # NOTE(review): the copy is shallow; nested values such as the
        # list of annotation dicts are shared with _data.
        self._original_data = data.copy()

    def _update_data(self, data, update_original=False):
        """
        Low level update of the internal _data dict. Data which are coming as
        updates should already be serialized. If update_original is True, the
        original_data dict is updated as well.
        """
        destinations = [self._data]
        if update_original:
            destinations.append(self._original_data)
        for destination in destinations:
            destination.update(data)

    def __getitem__(self, key):
        # Defeat Python's legacy index-based iteration (obj[0], obj[1], ...):
        # any key that int() accepts stops iteration at once, so a
        # TaskResource is not iterable. Other keys are field lookups.
        try:
            int(key)
        except ValueError:
            return self._deserialize(key, self._data.get(key))
        raise StopIteration

    def __setitem__(self, key, value):
        """Serialize and store ``value``; read-only fields are refused."""
        if key in self.read_only_fields:
            raise RuntimeError('Field \'%s\' is read-only' % key)
        self._data[key] = self._serialize(key, value)

    def __str__(self):
        # Text on Python 3, UTF-8 encoded bytes on Python 2.
        text = six.text_type(self.__unicode__())
        return text if six.PY3 else text.encode('utf-8')

    def __repr__(self):
        return str(self)
156
157
class TaskAnnotation(TaskResource):
    """
    A single annotation of a Task.

    ``entry`` and ``description`` are read-only here; the annotation is
    removed through its owning task (see ``remove``).
    """
    read_only_fields = ['entry', 'description']

    def __init__(self, task, data=None):
        """
        :param task: the Task this annotation belongs to.
        :param data: raw (serialized) annotation dict; a fresh empty dict
            when omitted.
        """
        self.task = task
        # A ``{}`` default would be one dict shared by every annotation
        # created without data; _load_data keeps it as self._data, so
        # __setitem__ would write into that shared default.
        self._load_data(data if data is not None else {})

    def remove(self):
        """Remove this annotation from its task in TaskWarrior."""
        self.task.remove_annotation(self)

    def __unicode__(self):
        return self['description']

    __repr__ = __unicode__
172
173
class Task(TaskResource):
    """
    A TaskWarrior task.

    Field changes made via item assignment are kept locally until ``save``
    is called; ``refresh`` reloads the data through ``task export``.
    """
    read_only_fields = ['id', 'entry', 'urgency', 'uuid', 'modified']

    class DoesNotExist(Exception):
        pass

    class CompletedTask(Exception):
        """
        Raised when the operation cannot be performed on the completed task.
        """
        pass

    class DeletedTask(Exception):
        """
        Raised when the operation cannot be performed on the deleted task.
        """
        pass

    class NotSaved(Exception):
        """
        Raised when the operation cannot be performed on the task, because
        it has not been saved to TaskWarrior yet.
        """
        pass

    def __init__(self, warrior, **kwargs):
        """
        :param warrior: the TaskWarrior instance used to run commands.
        :param kwargs: initial field values, in their Python form.
        :raises RuntimeError: if a read-only field is passed.
        """
        self.warrior = warrior

        # Check that user is not able to set read-only value in __init__
        for key in kwargs:
            if key in self.read_only_fields:
                raise RuntimeError('Field \'%s\' is read-only' % key)

        # We serialize the data in kwargs so that users of the library
        # do not have to pass different data formats via __setitem__ and
        # __init__ methods, that would be confusing

        # Rather unfortunate syntax due to python2.6 compatibility
        self._load_data(dict((key, self._serialize(key, value))
                        for (key, value) in six.iteritems(kwargs)))

    def __unicode__(self):
        return self['description']

    def __eq__(self, other):
        # Only Tasks compare equal to Tasks; for anything else let Python
        # try the other operand instead of failing on ``other['uuid']``.
        if not isinstance(other, Task):
            return NotImplemented
        if self['uuid'] and other['uuid']:
            # For saved Tasks, just define equality by equality of uuids
            return self['uuid'] == other['uuid']
        # If the tasks are not saved, compare the actual instances
        return id(self) == id(other)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so spell it out.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Consistent with __eq__: saved tasks hash by uuid, unsaved ones by
        # instance id. NOTE: the hash therefore changes once save() loads a
        # uuid, so unsaved tasks should not be kept in sets/dict keys
        # across a save.
        if self['uuid']:
            return self['uuid'].__hash__()
        else:
            return id(self).__hash__()

    @property
    def _modified_fields(self):
        """Yield writable fields whose value differs from the snapshot."""
        writable_fields = set(self._data.keys()) - set(self.read_only_fields)
        for key in writable_fields:
            if self._data.get(key) != self._original_data.get(key):
                yield key

    @property
    def completed(self):
        return self['status'] == six.text_type('completed')

    @property
    def deleted(self):
        return self['status'] == six.text_type('deleted')

    @property
    def waiting(self):
        return self['status'] == six.text_type('waiting')

    @property
    def pending(self):
        return self['status'] == six.text_type('pending')

    @property
    def saved(self):
        """True once the task has a uuid or an id."""
        return self['uuid'] is not None or self['id'] is not None

    def serialize_depends(self, cur_dependencies):
        """
        Serialize dependencies to a uuid list.

        :raises Task.NotSaved: if any dependency has not been saved yet.
        """
        # Check that all the tasks are saved
        for task in cur_dependencies:
            if not task.saved:
                raise Task.NotSaved('Task \'%s\' needs to be saved before '
                                    'it can be set as dependency.' % task)

        return super(Task, self).serialize_depends(cur_dependencies)

    def format_depends(self):
        # We need to generate added and removed dependencies list,
        # since Taskwarrior does not accept redefining dependencies.

        # This cannot be part of serialize_depends, since we need
        # to keep a list of all dependencies in the _data dictionary,
        # not just currently added/removed ones

        old_dependencies_raw = self._original_data.get('depends', '')
        old_dependencies = self.deserialize_depends(old_dependencies_raw)

        added = self['depends'] - old_dependencies
        removed = old_dependencies - self['depends']

        # Removed dependencies need to be prefixed with '-'
        return 'depends:' + ','.join(
                [t['uuid'] for t in added] +
                ['-' + t['uuid'] for t in removed]
            )

    def format_description(self):
        # Task version older than 2.4.0 ignores first word of the
        # task description if description: prefix is used.
        # NOTE(review): this is a plain string comparison of version text.
        if self.warrior.version < VERSION_2_4_0:
            return self._data['description']
        else:
            return "description:'{0}'".format(self._data['description'] or '')

    def delete(self):
        """
        Delete the task in TaskWarrior and refresh its status.

        :raises Task.NotSaved: if the task was never saved.
        :raises Task.DeletedTask: if the task is already deleted.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved before it can be deleted")

        # Refresh the status, and raise exception if the task is deleted
        self.refresh(only_fields=['status'])

        if self.deleted:
            raise Task.DeletedTask("Task was already deleted")

        self.warrior.execute_command([self['uuid'], 'delete'])

        # Refresh the status again, so that we have updated info stored
        self.refresh(only_fields=['status'])

    def done(self):
        """
        Mark the task as done in TaskWarrior and refresh its status.

        :raises Task.NotSaved: if the task was never saved.
        :raises Task.CompletedTask: if the task is already completed.
        :raises Task.DeletedTask: if the task is deleted.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved before it can be completed")

        # Refresh, and raise exception if task is already completed/deleted
        self.refresh(only_fields=['status'])

        if self.completed:
            raise Task.CompletedTask("Cannot complete a completed task")
        elif self.deleted:
            raise Task.DeletedTask("Deleted task cannot be completed")

        self.warrior.execute_command([self['uuid'], 'done'])

        # Refresh the status again, so that we have updated info stored
        self.refresh(only_fields=['status'])

    def save(self):
        """
        Add the task (if new) or modify it, then reload all its data.

        :raises TaskWarriorException: if the output of ``add`` does not
            contain exactly one well-formed 'Created task N.' line.
        """
        args = [self['uuid'], 'modify'] if self.saved else ['add']
        args.extend(self._get_modified_fields_as_args())
        output = self.warrior.execute_command(args)

        # Parse out the new ID, if the task is being added for the first time
        if not self.saved:
            id_lines = [l for l in output if l.startswith('Created task ')]

            # Complain loudly if it seems that more tasks were created
            # Should not happen
            if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
                raise TaskWarriorException("Unexpected output when creating "
                                           "task: %s" % '\n'.join(id_lines))

            # Circumvent the ID storage, since ID is considered read-only
            self._data['id'] = int(id_lines[0].split(' ')[2].rstrip('.'))

        self.refresh()

    def add_annotation(self, annotation):
        """Annotate the saved task with the text ``annotation``."""
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to add annotation")

        args = [self['uuid'], 'annotate', annotation]
        self.warrior.execute_command(args)
        self.refresh(only_fields=['annotations'])

    def remove_annotation(self, annotation):
        """
        Remove an annotation, given as a TaskAnnotation or its description.
        """
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to remove annotation")

        if isinstance(annotation, TaskAnnotation):
            annotation = annotation['description']
        args = [self['uuid'], 'denotate', annotation]
        self.warrior.execute_command(args)
        self.refresh(only_fields=['annotations'])

    def _get_modified_fields_as_args(self):
        """
        Build the field arguments for ``add``/``modify``.

        Each field uses its ``format_<field>`` method when one exists,
        otherwise ``field:'value'``. NOTE(review): single quotes inside a
        value are not escaped.
        """
        args = []

        def add_field(field):
            # Add the output of format_field method to args list (defaults to
            # field:value)
            format_default = lambda k: "{0}:'{1}'".format(k, self._data[k] or '')
            format_func = getattr(self, 'format_{0}'.format(field),
                                  lambda: format_default(field))
            args.append(format_func())

        # If we're modifying saved task, simply pass on all modified fields
        if self.saved:
            for field in self._modified_fields:
                add_field(field)
        # For new tasks, pass all fields that make sense
        else:
            for field in self._data.keys():
                if field in self.read_only_fields:
                    continue
                add_field(field)

        return args

    def refresh(self, only_fields=None):
        """
        Reload the task data from ``task export``.

        :param only_fields: if given and non-empty, only these fields are
            updated (in both the current data and the snapshot); otherwise
            all data is replaced.
        :raises Task.NotSaved: if the task was never saved.
        """
        # Raise error when trying to refresh a task that has not been saved
        if not self.saved:
            raise Task.NotSaved("Task needs to be saved to be refreshed")

        # We need to use ID as backup for uuid here for the refreshes
        # of newly saved tasks. Any other place in the code is fine
        # with using UUID only.
        args = [self['uuid'] or self['id'], 'export']
        # Assumes the export prints this task's JSON on the first line.
        new_data = json.loads(self.warrior.execute_command(args)[0])
        if only_fields:
            to_update = dict(
                [(k, new_data.get(k)) for k in only_fields])
            self._update_data(to_update, update_original=True)
        else:
            self._load_data(new_data)
410
411
class TaskFilter(SerializingObject):
    """
    A set of parameters to filter the task list with.

    The parameters are strings; TaskWarrior.filter_tasks passes them to
    ``task export`` after ``--``.
    """

    def __init__(self, filter_params=None):
        # add_filter() appends in place, so a ``[]`` default would be one
        # list shared (and polluted) by every TaskFilter built without
        # arguments, e.g. the ones TaskQuerySet creates by default.
        self.filter_params = filter_params if filter_params is not None else []

    def add_filter(self, filter_str):
        """Append a raw filter string."""
        self.filter_params.append(filter_str)

    def add_filter_param(self, key, value):
        """
        Add a ``key:'value'`` filter, serializing ``value`` first.

        Double underscores in ``key`` become dots (``a__b`` -> ``a.b``); the
        part before the first dot selects the serializer. A ``uuid`` filter
        is added as a bare leading value instead.
        """
        key = key.replace('__', '.')

        # Replace the value with empty string, since that is the
        # convention in TW for empty values. Some serializers (the
        # timestamp ones) return None for empty input, which would
        # otherwise be rendered as the literal text 'None'.
        attribute_key = key.split('.')[0]
        value = self._serialize(attribute_key, value)
        if value is None:
            value = ''

        # If we are filtering by uuid:, do not use uuid keyword
        # due to TW-1452 bug
        if key == 'uuid':
            self.filter_params.insert(0, value)
        else:
            self.filter_params.append("{0}:'{1}'".format(key, value))

    def get_filter_params(self):
        """Return the non-empty filter parameters."""
        return [f for f in self.filter_params if f]

    def clone(self):
        """Return a copy with its own parameter list."""
        c = self.__class__()
        c.filter_params = list(self.filter_params)
        return c
445
446
class TaskQuerySet(object):
    """
    Represents a lazy lookup for a task objects.

    No command runs until results are needed (iteration, len(), indexing,
    truth testing); the fetched tasks are then cached on the instance.
    The filtering methods return new querysets and leave this one as is.
    """

    def __init__(self, warrior=None, filter_obj=None):
        self.warrior = warrior
        self._result_cache = None
        self.filter_obj = filter_obj or TaskFilter()

    def __deepcopy__(self, memo):
        """
        Deep copy of a QuerySet doesn't populate the cache
        """
        duplicate = self.__class__()
        not_copied = ('_iter', '_result_cache')
        for name, attr in self.__dict__.items():
            duplicate.__dict__[name] = (
                None if name in not_copied else copy.deepcopy(attr, memo))
        return duplicate

    def __repr__(self):
        preview = list(self[:REPR_OUTPUT_SIZE + 1])
        if len(preview) > REPR_OUTPUT_SIZE:
            preview[-1] = "...(remaining elements truncated)..."
        return repr(preview)

    def _populate_cache(self):
        # Fill the cache (via __iter__) unless that already happened.
        if self._result_cache is None:
            self._result_cache = list(self)

    def __len__(self):
        self._populate_cache()
        return len(self._result_cache)

    def __iter__(self):
        cached = self._result_cache
        if cached is None:
            cached = self._result_cache = self._execute()
        return iter(cached)

    def __getitem__(self, k):
        self._populate_cache()
        return self._result_cache[k]

    def __bool__(self):
        if self._result_cache is None:
            # Truthy as soon as a single result exists.
            for _ in self:
                return True
            return False
        return bool(self._result_cache)

    def __nonzero__(self):
        return type(self).__bool__(self)

    def _clone(self, klass=None, **kwargs):
        target_class = self.__class__ if klass is None else klass
        duplicate = target_class(warrior=self.warrior,
                                 filter_obj=self.filter_obj.clone())
        duplicate.__dict__.update(kwargs)
        return duplicate

    def _execute(self):
        """
        Fetch the tasks which match the current filters.
        """
        return self.warrior.filter_tasks(self.filter_obj)

    def all(self):
        """
        Returns a new TaskQuerySet that is a copy of the current one.
        """
        return self._clone()

    def pending(self):
        return self.filter(status=PENDING)

    def completed(self):
        return self.filter(status=COMPLETED)

    def filter(self, *args, **kwargs):
        """
        Returns a new TaskQuerySet with the given filters added.
        """
        narrowed = self._clone()
        criteria = narrowed.filter_obj
        for raw_filter in args:
            criteria.add_filter(raw_filter)
        for name, wanted in kwargs.items():
            criteria.add_filter_param(name, wanted)
        return narrowed

    def get(self, **kwargs):
        """
        Performs the query and returns a single object matching the given
        keyword arguments.

        Raises Task.DoesNotExist when nothing matches and ValueError when
        more than one task matches.
        """
        matches = self.filter(**kwargs)
        count = len(matches)
        if count == 1:
            return matches._result_cache[0]
        if count == 0:
            raise Task.DoesNotExist(
                'Task matching query does not exist. '
                'Lookup parameters were {0}'.format(kwargs))
        raise ValueError(
            'get() returned more than one Task -- it returned {0}! '
            'Lookup parameters were {1}'.format(count, kwargs))
555
556
class TaskWarrior(object):
    """
    Wrapper around the ``task`` command line program.

    Every command is run as ``task rc:/ rc.<key>=<value> ... <args>``, with
    the rc overrides taken from ``self.config``.
    NOTE(review): ``rc:/`` presumably keeps the user's own rc file out of
    the picture — confirm against TaskWarrior docs.
    """

    def __init__(self, data_location='~/.task', create=True):
        """
        :param data_location: TaskWarrior data directory ('~' is expanded).
        :param create: create the directory if it does not exist.
        """
        data_location = os.path.expanduser(data_location)
        if create and not os.path.exists(data_location):
            os.makedirs(data_location)
        self.config = {
            'data.location': data_location,
            'confirmation': 'no',
        }
        self.tasks = TaskQuerySet(self)
        self.version = self._get_version()

    def _get_command_args(self, args, config_override=None):
        """
        Build the argv list for running ``task`` with ``args``.

        :param args: command arguments; each is converted with ``str``.
        :param config_override: optional dict of rc settings that take
            precedence over ``self.config`` for this command only.
        """
        command_args = ['task', 'rc:/']
        config = self.config.copy()
        config.update(config_override or {})
        for item in config.items():
            command_args.append('rc.{0}={1}'.format(*item))
        command_args.extend(map(str, args))
        return command_args

    def _get_version(self):
        """Return the output of ``task --version`` minus newlines at the ends."""
        p = subprocess.Popen(
                ['task', '--version'],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        return stdout.strip('\n')

    def execute_command(self, args, config_override=None):
        """
        Run ``task`` with ``args`` and return its stdout split into lines.

        Empty output yields ``['']``.

        :raises TaskWarriorException: on a non-zero exit status; the message
            is the last line of stderr, or the whole stdout if stderr is
            empty.
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode:
            if stderr.strip():
                error_msg = stderr.strip().splitlines()[-1]
            else:
                error_msg = stdout.strip()
            raise TaskWarriorException(error_msg)
        return stdout.strip().split('\n')

    def filter_tasks(self, filter_obj):
        """
        Return a Task for every line of ``task export`` matching filter_obj.

        Each non-empty output line is parsed as one JSON object after
        stripping commas at its ends.

        :raises TaskWarriorException: if a line is not valid JSON.
        """
        args = ['export', '--'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if not line:
                continue
            data = line.strip(',')
            # Only the parse belongs in the try: it is the step expected to
            # raise ValueError on malformed output.
            try:
                parsed = json.loads(data)
            except ValueError:
                raise TaskWarriorException('Invalid JSON: %s' % data)
            filtered_task = Task(self)
            filtered_task._load_data(parsed)
            tasks.append(filtered_task)
        return tasks

    def merge_with(self, path, push=False):
        """
        Run ``task merge`` with ``path`` (a trailing '/' is ensured).

        :param push: sets ``rc.merge.autopush`` to 'yes' instead of 'no'.
        """
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Run ``task undo``."""
        self.execute_command(['undo'])
606                 data = line.strip(',')
607                 try:
608                     filtered_task = Task(self)
609                     filtered_task._load_data(json.loads(data))
610                     tasks.append(filtered_task)
611                 except ValueError:
612                     raise TaskWarriorException('Invalid JSON: %s' % data)
613         return tasks
614
615     def merge_with(self, path, push=False):
616         path = path.rstrip('/') + '/'
617         self.execute_command(['merge', path], config_override={
618             'merge.autopush': 'yes' if push else 'no',
619         })
620
621     def undo(self):
622         self.execute_command(['undo'])