git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/task.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add workaround for TW-1452 bug
[etc/taskwarrior.git] / tasklib / task.py
1 from __future__ import print_function
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import six
8 import subprocess
9
# strftime/strptime pattern for the compact timestamps TaskWarrior stores
# (note the literal 'Z' suffix).
DATE_FORMAT = '%Y%m%dT%H%M%SZ'
# How many tasks TaskQuerySet.__repr__ shows before truncating the listing.
REPR_OUTPUT_SIZE = 10
# Values of a task's ``status`` field used by the queryset shortcuts.
PENDING, COMPLETED = 'pending', 'completed'

logger = logging.getLogger(__name__)
16
17
class TaskWarriorException(Exception):
    """
    Error raised when a ``task`` invocation exits with a non-zero status
    or when its export output cannot be parsed as JSON.
    """
20
21
class TaskResource(object):
    """
    Base class for objects backed by a dict of raw TaskWarrior field values.

    Item access goes through optional per-field hooks. Reading ``obj[key]``
    passes the stored value through ``deserialize_<key>``. Writing
    ``obj[key] = value`` stores ``serialize_<key>(value)``. Fields without a
    hook pass through unchanged.

    Subclasses list write-protected keys in ``read_only_fields``. They must
    provide ``__unicode__`` (used by ``__str__``). They must also set a
    ``_modified_fields`` set before any field is assigned.
    """

    # Keys that __setitem__ refuses to change.
    read_only_fields = []

    def _load_data(self, data):
        # Keep a private copy so that item assignment never mutates the
        # caller's dict -- or a mutable default argument that is shared
        # between instances.
        self._data = dict(data) if data else {}

    def __getitem__(self, key):
        hydrate_func = getattr(self, 'deserialize_{0}'.format(key),
                               lambda x: x)
        return hydrate_func(self._data.get(key))

    def __setitem__(self, key, value):
        """
        Store *value* under *key* after serialization.

        Raises RuntimeError if *key* is listed in ``read_only_fields``.
        """
        if key in self.read_only_fields:
            raise RuntimeError('Field \'%s\' is read-only' % key)
        dehydrate_func = getattr(self, 'serialize_{0}'.format(key),
                                 lambda x: x)
        self._data[key] = dehydrate_func(value)
        # Track the key so callers can tell which fields were changed.
        self._modified_fields.add(key)

    def __str__(self):
        s = six.text_type(self.__unicode__())
        # Python 2's __str__ must return bytes, so encode the text there.
        if not six.PY3:
            s = s.encode('utf-8')
        return s

    def __repr__(self):
        return str(self)
49
50
class TaskAnnotation(TaskResource):
    """
    One annotation of a Task, built from an entry of its ``annotations`` list.

    Both stored fields are read-only. Use remove() to delete the annotation
    through the owning task.
    """

    read_only_fields = ['entry', 'description']

    def __init__(self, task, data=None):
        self.task = task
        # A fresh dict per instance instead of a shared mutable default.
        self._load_data({} if data is None else data)

    def deserialize_entry(self, data):
        return datetime.datetime.strptime(data, DATE_FORMAT) if data else None

    def serialize_entry(self, date):
        return date.strftime(DATE_FORMAT) if date else ''

    def remove(self):
        """Remove this annotation from its task (a TaskWarrior denotate)."""
        self.task.remove_annotation(self)

    def __unicode__(self):
        return self['description']

    # __repr__ is inherited from TaskResource (str(self)). Aliasing it to
    # __unicode__ returned None for annotations without a description (a
    # TypeError from repr()) and returned unicode on Python 2.
71
72
class Task(TaskResource):
    """
    A single TaskWarrior task, backed by the dict TaskWarrior exports.

    Fields are read and written with item access (``task['due']``). Writes
    are remembered in ``_modified_fields`` and are only sent to TaskWarrior
    by save().
    """

    # Fields that __setitem__ refuses to change.
    read_only_fields = ['id', 'entry', 'urgency']

    class DoesNotExist(Exception):
        pass

    def __init__(self, warrior, data=None):
        self.warrior = warrior
        # A fresh dict per instance: a shared ``{}`` default would be
        # mutated by __setitem__ and leak fields between unrelated tasks.
        self._load_data({} if data is None else data)
        self._modified_fields = set()

    def __unicode__(self):
        return self['description']

    def serialize_due(self, date):
        # Clearing the due date (None) must not crash. An empty value is
        # TaskWarrior's representation of an unset field, matching
        # serialize_tags.
        return date.strftime(DATE_FORMAT) if date else ''

    def deserialize_due(self, date_str):
        if not date_str:
            return None
        return datetime.datetime.strptime(date_str, DATE_FORMAT)

    def deserialize_annotations(self, data):
        return [TaskAnnotation(self, d) for d in data] if data else []

    def deserialize_tags(self, tags):
        # ``basestring`` does not exist on Python 3; six.string_types covers
        # both major versions.
        if isinstance(tags, six.string_types):
            return tags.split(',') if tags else []
        return tags

    def serialize_tags(self, tags):
        return ','.join(tags) if tags else ''

    def _get_identifier(self):
        """
        Return the word used to address this task on the command line.

        The uuid is preferred: ids are working-set numbers that TaskWarrior
        renumbers, and completed tasks export an id of 0 (falsy). The id is
        used only when no uuid is known.
        """
        return self['uuid'] or self['id']

    def delete(self):
        """Delete the task without an interactive confirmation prompt."""
        self.warrior.execute_command(
            [self._get_identifier(), 'delete'],
            config_override={'confirmation': 'no'})

    def done(self):
        """Mark the task as done."""
        self.warrior.execute_command([self._get_identifier(), 'done'])

    def save(self):
        """
        Send the fields changed through item assignment to TaskWarrior.

        Runs ``modify`` on a task that already has a uuid/id and ``add``
        otherwise, then clears the set of modified fields.
        """
        # NOTE(review): after ``add`` the new task's id/uuid are not read
        # back into this object.
        identifier = self._get_identifier()
        args = [identifier, 'modify'] if identifier else ['add']
        args.extend(self._get_modified_fields_as_args())
        self.warrior.execute_command(args)
        self._modified_fields.clear()

    def add_annotation(self, annotation):
        """Annotate the task with *annotation* and reload its annotations."""
        args = [self._get_identifier(), 'annotate', annotation]
        self.warrior.execute_command(args)
        self.refresh(only_fields=['annotations'])

    def remove_annotation(self, annotation):
        """
        Remove *annotation* (a TaskAnnotation or its description text) and
        reload the task's annotations.
        """
        if isinstance(annotation, TaskAnnotation):
            annotation = annotation['description']
        args = [self._get_identifier(), 'denotate', annotation]
        self.warrior.execute_command(args)
        self.refresh(only_fields=['annotations'])

    def _get_modified_fields_as_args(self):
        # One ``field:value`` word per changed field, using serialized values.
        return ['{0}:{1}'.format(field, self._data[field])
                for field in self._modified_fields]

    def refresh(self, only_fields=()):
        """
        Reload the task's data from TaskWarrior by uuid.

        When *only_fields* is given, only those keys are replaced. Otherwise
        the whole data dict is swapped for the exported one.
        """
        args = [self['uuid'], 'export']
        lines = self.warrior.execute_command(args)
        # Skip the brackets of a JSON-array export. A record line may also
        # end with a separating comma (filter_tasks strips it too). If no
        # record line remains, json.loads('') raises ValueError.
        record = next(
            (line for line in lines if line.strip() not in ('[', ']')), '')
        new_data = json.loads(record.strip().rstrip(','))
        if only_fields:
            self._data.update((k, new_data.get(k)) for k in only_fields)
        else:
            self._data = new_data
147
148
class TaskFilter(object):
    """
    A set of parameters to filter the task list with.

    Parameters are kept as the literal command-line words passed to
    ``task``. Empty entries are dropped by get_filter_params().
    """

    def __init__(self, filter_params=None):
        # Copy into a new list per instance. A mutable ``[]`` default is
        # shared by every call, so add_filter() on one default-constructed
        # filter would leak into all the others.
        self.filter_params = list(filter_params) if filter_params else []

    def add_filter(self, filter_str):
        """Append a raw filter word such as ``'+home'`` unchanged."""
        self.filter_params.append(filter_str)

    def add_filter_param(self, key, value):
        """
        Append a ``key:value`` filter built from a keyword argument.

        Double underscores in *key* become dots (``due__before`` becomes
        ``due.before``), and a None *value* becomes the empty string.
        """
        key = key.replace('__', '.')

        # Replace the value with empty string, since that is the
        # convention in TW for empty values
        value = value if value is not None else ''

        # If we are filtering by uuid:, do not use the uuid keyword
        # due to the TW-1452 bug; put the bare uuid first instead.
        if key == 'uuid':
            self.filter_params.insert(0, value)
        else:
            self.filter_params.append('{0}:{1}'.format(key, value))

    def get_filter_params(self):
        """Return the filter words, skipping empty ones."""
        return [f for f in self.filter_params if f]

    def clone(self):
        """Return a copy whose parameter list can be extended independently."""
        c = self.__class__()
        c.filter_params = list(self.filter_params)
        return c
181
182
183 class TaskQuerySet(object):
184     """
185     Represents a lazy lookup for a task objects.
186     """
187
188     def __init__(self, warrior=None, filter_obj=None):
189         self.warrior = warrior
190         self._result_cache = None
191         self.filter_obj = filter_obj or TaskFilter()
192
193     def __deepcopy__(self, memo):
194         """
195         Deep copy of a QuerySet doesn't populate the cache
196         """
197         obj = self.__class__()
198         for k, v in self.__dict__.items():
199             if k in ('_iter', '_result_cache'):
200                 obj.__dict__[k] = None
201             else:
202                 obj.__dict__[k] = copy.deepcopy(v, memo)
203         return obj
204
205     def __repr__(self):
206         data = list(self[:REPR_OUTPUT_SIZE + 1])
207         if len(data) > REPR_OUTPUT_SIZE:
208             data[-1] = "...(remaining elements truncated)..."
209         return repr(data)
210
211     def __len__(self):
212         if self._result_cache is None:
213             self._result_cache = list(self)
214         return len(self._result_cache)
215
216     def __iter__(self):
217         if self._result_cache is None:
218             self._result_cache = self._execute()
219         return iter(self._result_cache)
220
221     def __getitem__(self, k):
222         if self._result_cache is None:
223             self._result_cache = list(self)
224         return self._result_cache.__getitem__(k)
225
226     def __bool__(self):
227         if self._result_cache is not None:
228             return bool(self._result_cache)
229         try:
230             next(iter(self))
231         except StopIteration:
232             return False
233         return True
234
235     def __nonzero__(self):
236         return type(self).__bool__(self)
237
238     def _clone(self, klass=None, **kwargs):
239         if klass is None:
240             klass = self.__class__
241         filter_obj = self.filter_obj.clone()
242         c = klass(warrior=self.warrior, filter_obj=filter_obj)
243         c.__dict__.update(kwargs)
244         return c
245
246     def _execute(self):
247         """
248         Fetch the tasks which match the current filters.
249         """
250         return self.warrior.filter_tasks(self.filter_obj)
251
252     def all(self):
253         """
254         Returns a new TaskQuerySet that is a copy of the current one.
255         """
256         return self._clone()
257
258     def pending(self):
259         return self.filter(status=PENDING)
260
261     def completed(self):
262         return self.filter(status=COMPLETED)
263
264     def filter(self, *args, **kwargs):
265         """
266         Returns a new TaskQuerySet with the given filters added.
267         """
268         clone = self._clone()
269         for f in args:
270             clone.filter_obj.add_filter(f)
271         for key, value in kwargs.items():
272             clone.filter_obj.add_filter_param(key, value)
273         return clone
274
275     def get(self, **kwargs):
276         """
277         Performs the query and returns a single object matching the given
278         keyword arguments.
279         """
280         clone = self.filter(**kwargs)
281         num = len(clone)
282         if num == 1:
283             return clone._result_cache[0]
284         if not num:
285             raise Task.DoesNotExist(
286                 'Task matching query does not exist. '
287                 'Lookup parameters were {0}'.format(kwargs))
288         raise ValueError(
289             'get() returned more than one Task -- it returned {0}! '
290             'Lookup parameters were {1}'.format(num, kwargs))
291
292
class TaskWarrior(object):
    """
    Runs the ``task`` command-line program against one data directory.

    ``self.config`` holds the rc overrides passed on every invocation, and
    ``self.tasks`` is a TaskQuerySet over all tasks.
    """

    def __init__(self, data_location='~/.task', create=True):
        """
        Expand ``~`` in *data_location*. When *create* is true, also create
        the directory (and any missing parents) if it does not exist.
        """
        data_location = os.path.expanduser(data_location)
        if create and not os.path.exists(data_location):
            os.makedirs(data_location)
        self.config = {
            # Already expanded above; no second expanduser needed.
            'data.location': data_location,
        }
        self.tasks = TaskQuerySet(self)

    def _get_command_args(self, args, config_override=None):
        """
        Build the argv for one ``task`` invocation.

        Every entry of self.config, updated with *config_override*, becomes
        an ``rc.<name>=<value>`` argument. The *args*, converted with str(),
        follow.
        """
        # NOTE(review): ``rc:/`` presumably stops the user's own taskrc from
        # being read -- confirm against TaskWarrior's rc handling.
        command_args = ['task', 'rc:/']
        config = self.config.copy()
        config.update(config_override or {})
        command_args.extend('rc.{0}={1}'.format(name, value)
                            for name, value in config.items())
        command_args.extend(map(str, args))
        return command_args

    def execute_command(self, args, config_override=None):
        """
        Run ``task`` with *args* and return its stripped stdout split on
        newlines. Empty output yields ``['']``.

        Raises TaskWarriorException when the process exits non-zero. The
        message is the last line of stderr, or the whole of stdout when
        stderr is empty.
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode:
            if stderr.strip():
                error_msg = stderr.strip().splitlines()[-1]
            else:
                error_msg = stdout.strip()
            raise TaskWarriorException(error_msg)
        return stdout.strip().split('\n')

    def filter_tasks(self, filter_obj):
        """
        Export the tasks matching *filter_obj* and wrap each one in a Task.

        Raises TaskWarriorException if a record line is not valid JSON.
        """
        args = ['export', '--'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            line = line.strip()
            # Skip blank lines and the brackets of a JSON-array export.
            # Records may end with a separating comma.
            if not line or line in ('[', ']'):
                continue
            data = line.strip(',')
            try:
                record = json.loads(data)
            except ValueError:
                raise TaskWarriorException('Invalid JSON: %s' % data)
            tasks.append(Task(self, record))
        return tasks

    def merge_with(self, path, push=False):
        """
        Run ``task merge`` with *path* (normalised to one trailing slash),
        setting rc.merge.autopush from *push*.
        """
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Run ``task undo`` with rc.confirmation=no so it does not prompt."""
        self.execute_command(['undo'], config_override={
            'confirmation': 'no',
        })