git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/task.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Task: Add NotSaved exception
[etc/taskwarrior.git] / tasklib / task.py
1 from __future__ import print_function
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import six
8 import subprocess
9
# strftime/strptime layout of TaskWarrior date fields (the literal 'Z'
# suffix marks the value as UTC).
DATE_FORMAT = '%Y%m%dT%H%M%SZ'

# Number of tasks TaskQuerySet.__repr__ shows before truncating the list.
REPR_OUTPUT_SIZE = 10

# Status values used by TaskQuerySet.pending() / TaskQuerySet.completed().
PENDING, COMPLETED = 'pending', 'completed'

logger = logging.getLogger(__name__)
16
17
class TaskWarriorException(Exception):
    """Raised when the ``task`` command fails or its output cannot be parsed."""
20
21
class TaskResource(object):
    """
    Base class for dict-backed TaskWarrior objects (tasks, annotations).

    Raw values live in ``self._data``. Item access goes through optional
    per-field hooks: ``__getitem__`` applies ``deserialize_<field>`` and
    ``__setitem__`` applies ``serialize_<field>`` when a subclass defines
    them. Fields listed in ``read_only_fields`` cannot be assigned.
    """
    read_only_fields = []

    def _load_data(self, data):
        # Keep a shallow copy instead of the caller's dict. Subclasses used a
        # ``data={}`` default, i.e. one dict object shared by every instance
        # built without data; aliasing it made writes through __setitem__
        # leak into all of those instances. ``None`` means "no data".
        self._data = dict(data) if data is not None else {}

    def __getitem__(self, key):
        """
        Return field *key*, passed through ``deserialize_<key>`` if defined.

        A missing field is handed to the hook (or returned) as ``None``.
        """
        hydrate_func = getattr(self, 'deserialize_{0}'.format(key),
                               lambda x: x)
        return hydrate_func(self._data.get(key))

    def __setitem__(self, key, value):
        """
        Store *value* under *key*, serialized via ``serialize_<key>`` if
        defined, and record *key* as modified.

        :raises RuntimeError: if *key* is in ``read_only_fields``.
        """
        if key in self.read_only_fields:
            raise RuntimeError('Field \'%s\' is read-only' % key)
        dehydrate_func = getattr(self, 'serialize_{0}'.format(key),
                                 lambda x: x)
        self._data[key] = dehydrate_func(value)
        # NOTE(review): requires the subclass to have initialised
        # ``self._modified_fields`` (Task does so in __init__).
        self._modified_fields.add(key)

    def __str__(self):
        # Text on Python 3, UTF-8 encoded bytes on Python 2.
        s = six.text_type(self.__unicode__())
        if not six.PY3:
            s = s.encode('utf-8')
        return s

    def __repr__(self):
        return str(self)
49
50
class TaskAnnotation(TaskResource):
    """
    A single annotation of a Task.

    Both fields, ``entry`` (a timestamp) and ``description`` (the
    annotation text), are read-only.
    """
    read_only_fields = ['entry', 'description']

    def __init__(self, task, data=None):
        """
        :param task: the Task this annotation belongs to.
        :param data: raw annotation dict; ``None`` (the default) means an
            empty annotation.
        """
        self.task = task
        # A fresh dict per instance instead of one shared ``{}`` default.
        self._load_data(data if data is not None else {})

    def deserialize_entry(self, data):
        # Stored as a DATE_FORMAT string, exposed as a naive datetime.
        return datetime.datetime.strptime(data, DATE_FORMAT) if data else None

    def serialize_entry(self, date):
        return date.strftime(DATE_FORMAT) if date else ''

    def remove(self):
        """Remove this annotation from its task in TaskWarrior."""
        self.task.remove_annotation(self)

    def __unicode__(self):
        return self['description']

    # __repr__ is inherited from TaskResource (returns ``str(self)``). The
    # former ``__repr__ = __unicode__`` alias returned the raw description,
    # which made repr() raise when it was None (non-string result) and, on
    # Python 2, when it contained non-ASCII characters.
71
72
class Task(TaskResource):
    """
    A single TaskWarrior task.

    Fields are read and written with item syntax (``task['description']``).
    Assignments are buffered locally and sent to TaskWarrior by ``save()``;
    ``refresh()`` reloads the task's data from TaskWarrior.
    """
    read_only_fields = ['id', 'entry', 'urgency', 'uuid']

    class DoesNotExist(Exception):
        """Raised when a lookup matches no task."""
        pass

    class CompletedTask(Exception):
        """
        Raised when the operation cannot be performed on the completed task.
        """
        pass

    class DeletedTask(Exception):
        """
        Raised when the operation cannot be performed on the deleted task.
        """
        pass

    class NotSaved(Exception):
        """
        Raised when the operation cannot be performed on the task, because
        it has not been saved to TaskWarrior yet.
        """
        pass

    def __init__(self, warrior, data=None):
        """
        :param warrior: TaskWarrior instance used to run ``task`` commands.
        :param data: raw task dict as exported by TaskWarrior. ``None`` (the
            default) starts an empty, not-yet-saved task.
        """
        self.warrior = warrior
        # Use a fresh dict per task. The former ``data={}`` default was one
        # dict object shared by every Task created without data, so setting
        # a field on one new task leaked into all the others.
        self._load_data(data if data is not None else {})
        self._modified_fields = set()

    def __unicode__(self):
        return self['description']

    @property
    def completed(self):
        return self['status'] == six.text_type('completed')

    @property
    def deleted(self):
        return self['status'] == six.text_type('deleted')

    @property
    def waiting(self):
        return self['status'] == six.text_type('waiting')

    @property
    def pending(self):
        return self['status'] == six.text_type('pending')

    @property
    def saved(self):
        """True once the task is known to TaskWarrior (has a uuid or an id)."""
        return self['uuid'] is not None or self['id'] is not None

    def serialize_due(self, date):
        # Accept None/empty like serialize_entry and serialize_tags do. An
        # empty string is TW's convention for an empty value, so
        # ``task['due'] = None`` no longer raises AttributeError.
        return date.strftime(DATE_FORMAT) if date else ''

    def deserialize_due(self, date_str):
        if not date_str:
            return None
        return datetime.datetime.strptime(date_str, DATE_FORMAT)

    def deserialize_annotations(self, data):
        return [TaskAnnotation(self, d) for d in data] if data else []

    def deserialize_tags(self, tags):
        # A comma-separated string (the serialize_tags format) is split into
        # a list; any other value is returned unchanged. ``basestring`` does
        # not exist on Python 3, whereas six.string_types works on both.
        if isinstance(tags, six.string_types):
            return tags.split(',') if tags else []
        return tags

    def serialize_tags(self, tags):
        return ','.join(tags) if tags else ''

    def delete(self):
        """
        Delete the task in TaskWarrior.

        :raises NotSaved: if the task has not been saved yet.
        :raises DeletedTask: if the task is already deleted.
        """
        if not self.saved:
            raise self.NotSaved("Task needs to be saved before it can be deleted")

        # Refresh the status, and raise exception if the task is deleted
        self.refresh(only_fields=['status'])

        if self.deleted:
            raise self.DeletedTask("Task was already deleted")

        # NOTE(review): confirmation=no presumably suppresses TW's
        # interactive prompt for the delete command.
        self.warrior.execute_command([self['uuid'], 'delete'], config_override={
            'confirmation': 'no',
        })

        # Refresh the status again, so that we have updated info stored
        self.refresh(only_fields=['status'])

    def done(self):
        """
        Mark the task as completed in TaskWarrior.

        :raises NotSaved: if the task has not been saved yet.
        :raises CompletedTask: if the task is already completed.
        :raises DeletedTask: if the task is deleted.
        """
        if not self.saved:
            raise self.NotSaved("Task needs to be saved before it can be completed")

        # Refresh, and raise exception if task is already completed/deleted
        self.refresh(only_fields=['status'])

        if self.completed:
            raise self.CompletedTask("Cannot complete a completed task")
        elif self.deleted:
            raise self.DeletedTask("Deleted task cannot be completed")

        self.warrior.execute_command([self['uuid'], 'done'])

        # Refresh the status again, so that we have updated info stored
        self.refresh(only_fields=['status'])

    def save(self):
        """
        Write the task to TaskWarrior, then reload all of its fields.

        A new task is created with ``add``. An already saved task is updated
        with ``modify``, passing only the fields changed since the last save.

        :raises TaskWarriorException: if the output of ``add`` does not
            contain exactly one parsable "Created task <id>." line.
        """
        args = [self['uuid'], 'modify'] if self.saved else ['add']
        args.extend(self._get_modified_fields_as_args())
        output = self.warrior.execute_command(args)

        # Parse out the new ID, if the task is being added for the first time
        if not self.saved:
            id_lines = [line for line in output
                        if line.startswith('Created task ')]

            # Complain loudly if it seems that more tasks were created
            # Should not happen
            if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
                raise TaskWarriorException("Unexpected output when creating "
                                           "task: %s" % '\n'.join(id_lines))

            # Circumvent the ID storage, since ID is considered read-only
            self._data['id'] = int(id_lines[0].split(' ')[2].rstrip('.'))

        self._modified_fields.clear()
        self.refresh()

    def add_annotation(self, annotation):
        """
        Add an annotation with text *annotation* and reload the annotations.

        :raises NotSaved: if the task has not been saved yet.
        """
        if not self.saved:
            raise self.NotSaved("Task needs to be saved to add annotation")

        args = [self['uuid'], 'annotate', annotation]
        self.warrior.execute_command(args)
        self.refresh(only_fields=['annotations'])

    def remove_annotation(self, annotation):
        """
        Remove an annotation, given either as a TaskAnnotation or as its
        description text, and reload the annotations.

        :raises NotSaved: if the task has not been saved yet.
        """
        if not self.saved:
            # The message previously said "to add annotation".
            raise self.NotSaved("Task needs to be saved to remove annotation")

        if isinstance(annotation, TaskAnnotation):
            annotation = annotation['description']
        args = [self['uuid'], 'denotate', annotation]
        self.warrior.execute_command(args)
        self.refresh(only_fields=['annotations'])

    def _get_modified_fields_as_args(self):
        """Return the ``field:value`` command-line tokens used by save()."""
        if self.saved:
            # If we're modifying saved task, simply pass on all modified fields
            fields = self._modified_fields
        else:
            # For new tasks, pass all fields that make sense
            fields = [field for field in self._data
                      if field not in self.read_only_fields]
        return ['{0}:{1}'.format(field, self._data[field]) for field in fields]

    def refresh(self, only_fields=()):
        """
        Reload the task's data from TaskWarrior via ``export``.

        :param only_fields: if non-empty, update only these fields (a field
            absent from the export becomes None); otherwise replace all
            local data with the exported data.
        :raises NotSaved: if the task has not been saved yet.
        """
        # Raise error when trying to refresh a task that has not been saved
        if not self.saved:
            raise self.NotSaved("Task needs to be saved to be refreshed")

        # We need to use ID as backup for uuid here for the refreshes
        # of newly saved tasks. Any other place in the code is fine
        # with using UUID only.
        args = [self['uuid'] or self['id'], 'export']
        new_data = json.loads(self.warrior.execute_command(args)[0])
        if only_fields:
            self._data.update((k, new_data.get(k)) for k in only_fields)
        else:
            self._data = new_data
251
252
class TaskFilter(object):
    """
    A set of parameters to filter the task list with.

    Parameters are raw TaskWarrior filter tokens (for example
    ``status:pending``); TaskWarrior.filter_tasks appends them to the
    ``export`` command line.
    """

    def __init__(self, filter_params=None):
        """
        :param filter_params: optional iterable of initial filter tokens.
            It is copied. The former ``[]`` default was a single list shared
            by every TaskFilter created without arguments, and add_filter /
            add_filter_param mutated it.
        """
        self.filter_params = list(filter_params) if filter_params else []

    def add_filter(self, filter_str):
        """Append a raw filter token as-is."""
        self.filter_params.append(filter_str)

    def add_filter_param(self, key, value):
        """
        Append a ``key:value`` token.

        Double underscores in *key* become dots (``project__name`` ->
        ``project.name``); a ``None`` value becomes the empty string.
        """
        key = key.replace('__', '.')

        # Replace the value with empty string, since that is the
        # convention in TW for empty values
        value = value if value is not None else ''
        self.filter_params.append('{0}:{1}'.format(key, value))

    def get_filter_params(self):
        """Return the tokens to pass to TaskWarrior, skipping empty ones."""
        return [f for f in self.filter_params if f]

    def clone(self):
        """Return an independent copy of this filter."""
        c = self.__class__()
        c.filter_params = list(self.filter_params)
        return c
279
280
281 class TaskQuerySet(object):
282     """
283     Represents a lazy lookup for a task objects.
284     """
285
286     def __init__(self, warrior=None, filter_obj=None):
287         self.warrior = warrior
288         self._result_cache = None
289         self.filter_obj = filter_obj or TaskFilter()
290
291     def __deepcopy__(self, memo):
292         """
293         Deep copy of a QuerySet doesn't populate the cache
294         """
295         obj = self.__class__()
296         for k, v in self.__dict__.items():
297             if k in ('_iter', '_result_cache'):
298                 obj.__dict__[k] = None
299             else:
300                 obj.__dict__[k] = copy.deepcopy(v, memo)
301         return obj
302
303     def __repr__(self):
304         data = list(self[:REPR_OUTPUT_SIZE + 1])
305         if len(data) > REPR_OUTPUT_SIZE:
306             data[-1] = "...(remaining elements truncated)..."
307         return repr(data)
308
309     def __len__(self):
310         if self._result_cache is None:
311             self._result_cache = list(self)
312         return len(self._result_cache)
313
314     def __iter__(self):
315         if self._result_cache is None:
316             self._result_cache = self._execute()
317         return iter(self._result_cache)
318
319     def __getitem__(self, k):
320         if self._result_cache is None:
321             self._result_cache = list(self)
322         return self._result_cache.__getitem__(k)
323
324     def __bool__(self):
325         if self._result_cache is not None:
326             return bool(self._result_cache)
327         try:
328             next(iter(self))
329         except StopIteration:
330             return False
331         return True
332
333     def __nonzero__(self):
334         return type(self).__bool__(self)
335
336     def _clone(self, klass=None, **kwargs):
337         if klass is None:
338             klass = self.__class__
339         filter_obj = self.filter_obj.clone()
340         c = klass(warrior=self.warrior, filter_obj=filter_obj)
341         c.__dict__.update(kwargs)
342         return c
343
344     def _execute(self):
345         """
346         Fetch the tasks which match the current filters.
347         """
348         return self.warrior.filter_tasks(self.filter_obj)
349
350     def all(self):
351         """
352         Returns a new TaskQuerySet that is a copy of the current one.
353         """
354         return self._clone()
355
356     def pending(self):
357         return self.filter(status=PENDING)
358
359     def completed(self):
360         return self.filter(status=COMPLETED)
361
362     def filter(self, *args, **kwargs):
363         """
364         Returns a new TaskQuerySet with the given filters added.
365         """
366         clone = self._clone()
367         for f in args:
368             clone.filter_obj.add_filter(f)
369         for key, value in kwargs.items():
370             clone.filter_obj.add_filter_param(key, value)
371         return clone
372
373     def get(self, **kwargs):
374         """
375         Performs the query and returns a single object matching the given
376         keyword arguments.
377         """
378         clone = self.filter(**kwargs)
379         num = len(clone)
380         if num == 1:
381             return clone._result_cache[0]
382         if not num:
383             raise Task.DoesNotExist(
384                 'Task matching query does not exist. '
385                 'Lookup parameters were {0}'.format(kwargs))
386         raise ValueError(
387             'get() returned more than one Task -- it returned {0}! '
388             'Lookup parameters were {1}'.format(num, kwargs))
389
390
class TaskWarrior(object):
    """
    Front end to one TaskWarrior database, driven through the ``task`` CLI.

    ``self.tasks`` is a TaskQuerySet over the tasks of this database.
    """

    def __init__(self, data_location='~/.task', create=True):
        """
        :param data_location: TaskWarrior data directory (``~`` is expanded).
        :param create: create the directory if it does not exist yet.
        """
        data_location = os.path.expanduser(data_location)
        if create and not os.path.exists(data_location):
            os.makedirs(data_location)
        self.config = {
            # Already expanded above; expanding a second time is redundant.
            'data.location': data_location,
        }
        self.tasks = TaskQuerySet(self)

    def _get_command_args(self, args, config_override=None):
        """
        Build the argv for one ``task`` invocation.

        Every entry of ``self.config``, updated with *config_override* (which
        takes precedence), becomes an ``rc.<key>=<value>`` token. These are
        followed by *args* converted with ``str``. ``self.config`` itself is
        not modified.
        """
        # NOTE(review): ``rc:/`` presumably points TaskWarrior away from the
        # user's own taskrc so only the rc.* overrides apply -- confirm.
        command_args = ['task', 'rc:/']
        config = self.config.copy()
        if config_override:
            config.update(config_override)
        command_args.extend('rc.{0}={1}'.format(key, value)
                            for key, value in config.items())
        command_args.extend(map(str, args))
        return command_args

    def execute_command(self, args, config_override=None):
        """
        Run ``task`` with *args* and return stdout as a list of lines.

        Surrounding whitespace is stripped first, so empty output yields
        ``['']``.

        :raises TaskWarriorException: on a non-zero exit status. The message
            is the last line of stderr, or the stripped stdout if stderr is
            empty.
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode:
            if stderr.strip():
                error_msg = stderr.strip().splitlines()[-1]
            else:
                error_msg = stdout.strip()
            raise TaskWarriorException(error_msg)
        return stdout.strip().split('\n')

    def filter_tasks(self, filter_obj):
        """
        Run ``export`` restricted by *filter_obj* and return matching Tasks.

        Each non-empty output line is parsed as one JSON task object after
        stripping surrounding commas.

        :raises TaskWarriorException: if a line is not valid JSON.
        """
        args = ['export', '--'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if line:
                data = line.strip(',')
                try:
                    tasks.append(Task(self, json.loads(data)))
                except ValueError:
                    raise TaskWarriorException('Invalid JSON: %s' % data)
        return tasks

    def merge_with(self, path, push=False):
        """
        Run ``task merge`` against *path*; ``merge.autopush`` is set to
        'yes' when *push* is true and to 'no' otherwise.
        """
        # Normalise to exactly one trailing slash.
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Run ``task undo`` with ``confirmation=no``."""
        self.execute_command(['undo'], config_override={
            'confirmation': 'no',
        })