git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/task.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix #2 -- correctly deal with unicode data
[etc/taskwarrior.git] / tasklib / task.py
1 from __future__ import print_function
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import six
8 import subprocess
9
10 DATE_FORMAT = '%Y%m%dT%H%M%SZ'  # strptime/strftime pattern used for the 'entry' and 'due' fields
11 REPR_OUTPUT_SIZE = 10  # max number of items TaskQuerySet.__repr__ lists before truncating
12 PENDING = 'pending'  # status value applied by TaskQuerySet.pending()
13 COMPLETED = 'completed'  # status value applied by TaskQuerySet.completed()
14
15 logger = logging.getLogger(__name__)  # module logger; execute_command logs each command line at DEBUG
16
17
class TaskWarriorException(Exception):
    """
    Error raised by this module when a ``task`` invocation fails or
    produces output that cannot be parsed.
    """
20
21
class TaskResource(object):
    """
    Base class for objects backed by a dictionary of raw field values.

    Item access is routed through optional per-field hooks: reading
    ``obj[key]`` passes the stored value through ``deserialize_<key>`` when
    the object defines such a method, and ``obj[key] = value`` passes the
    value through ``serialize_<key>``. Fields without a hook are stored and
    returned unchanged.
    """

    # Names of the fields that __setitem__ refuses to change.
    read_only_fields = []

    def _load_data(self, data):
        """
        Use ``data`` (a dict, kept by reference) as the raw field storage.
        """
        self._data = data
        # __setitem__ records every changed key in this set, so it has to
        # exist on every resource, not only on subclasses that happen to
        # create it themselves.
        self._modified_fields = set()

    def __getitem__(self, key):
        """
        Return the value of ``key``, run through ``deserialize_<key>`` if
        defined. A missing key yields the hook's result for ``None`` (or
        ``None`` itself when there is no hook).
        """
        raw_value = self._data.get(key)
        hydrate_func = getattr(self, 'deserialize_{0}'.format(key), None)
        if hydrate_func is None:
            return raw_value
        return hydrate_func(raw_value)

    def __setitem__(self, key, value):
        """
        Store ``value`` under ``key``, run through ``serialize_<key>`` if
        defined, and remember the key as modified.

        Raises RuntimeError if ``key`` is listed in ``read_only_fields``.
        """
        if key in self.read_only_fields:
            raise RuntimeError('Field \'%s\' is read-only' % key)
        dehydrate_func = getattr(self, 'serialize_{0}'.format(key), None)
        if dehydrate_func is not None:
            value = dehydrate_func(value)
        self._data[key] = value
        self._modified_fields.add(key)

    def __str__(self):
        """
        Return the text from ``__unicode__`` (which subclasses must provide)
        as a native string: UTF-8 encoded bytes on Python 2, text on
        Python 3.
        """
        s = six.text_type(self.__unicode__())
        if not six.PY3:
            s = s.encode('utf-8')
        return s
46
47
class TaskAnnotation(TaskResource):
    """
    A single annotation belonging to a task.

    Wraps the annotation's raw dict; both of its fields ('entry' and
    'description') are read-only through item assignment.
    """

    read_only_fields = ['entry', 'description']

    def __init__(self, task, data=None):
        """
        ``task`` is the owning task (used by ``remove``); ``data`` is the
        raw annotation dict, kept by reference when given.
        """
        self.task = task
        # Build a fresh dict per instance when none is given; a literal
        # ``{}`` default would be one object shared by every annotation.
        self._load_data({} if data is None else data)

    def deserialize_entry(self, data):
        """Parse the stored 'entry' string into a datetime (None if empty)."""
        return datetime.datetime.strptime(data, DATE_FORMAT) if data else None

    def serialize_entry(self, date):
        """Format a datetime for storage; a false value becomes ''."""
        return date.strftime(DATE_FORMAT) if date else ''

    def remove(self):
        """Ask the owning task to remove this annotation."""
        self.task.remove_annotation(self)

    def __unicode__(self):
        return self['description']

    def __repr__(self):
        # repr() must hand back a native str. Going through __str__ avoids
        # returning unicode on Python 2 (which breaks on non-ASCII text) or
        # None when the description is missing.
        return self.__str__()
68
69
class Task(TaskResource):
    """
    A single task, backed by the raw dict parsed from ``task export``.

    Fields are read and written with item access (``task['due']``); changes
    are collected in ``_modified_fields`` and sent to the ``task`` binary by
    ``save``. All commands are run through ``self.warrior``.
    """

    read_only_fields = ['id', 'entry', 'urgency']

    class DoesNotExist(Exception):
        """Raised by TaskQuerySet.get when no task matches the lookup."""

    def __init__(self, warrior, data=None):
        """
        ``warrior`` executes commands for this task; ``data`` is the raw
        field dict, kept by reference when given.
        """
        self.warrior = warrior
        # A literal ``{}`` default would be shared (and mutated through
        # __setitem__) by every task created without data.
        self._load_data({} if data is None else data)
        self._modified_fields = set()

    def __unicode__(self):
        return self['description']

    def serialize_due(self, date):
        """Format a datetime for storage; a false value becomes ''."""
        # Mirrors TaskAnnotation.serialize_entry so that assigning None does
        # not crash on ``None.strftime``.
        return date.strftime(DATE_FORMAT) if date else ''

    def deserialize_due(self, date_str):
        """Parse the stored 'due' string into a datetime (None if empty)."""
        if not date_str:
            return None
        return datetime.datetime.strptime(date_str, DATE_FORMAT)

    def deserialize_annotations(self, data):
        """Wrap each raw annotation dict in a TaskAnnotation bound to self."""
        return [TaskAnnotation(self, d) for d in data] if data else []

    def deserialize_tags(self, tags):
        """
        Return the tags as stored, splitting a comma-separated string into a
        list. Non-string values (e.g. a list, or None) are returned as is.
        """
        # six.string_types instead of ``basestring``, which does not exist
        # on Python 3.
        if isinstance(tags, six.string_types):
            return tags.split(',') if tags else []
        return tags

    def serialize_tags(self, tags):
        """Join an iterable of tags with commas; a false value becomes ''."""
        return ','.join(tags) if tags else ''

    def delete(self):
        """Run ``<id> delete`` with the confirmation prompt disabled."""
        self.warrior.execute_command([self['id'], 'delete'], config_override={
            'confirmation': 'no',
        })

    def done(self):
        """Run ``<id> done``."""
        self.warrior.execute_command([self['id'], 'done'])

    def save(self):
        """
        Send the modified fields to the ``task`` binary: ``<id> modify ...``
        when the task has an id, ``add ...`` otherwise. The set of modified
        fields is cleared afterwards.
        """
        args = [self['id'], 'modify'] if self['id'] else ['add']
        args.extend(self._get_modified_fields_as_args())
        self.warrior.execute_command(args)
        self._modified_fields.clear()

    def add_annotation(self, annotation):
        """Run ``<id> annotate <annotation>`` and reload the annotations."""
        args = [self['id'], 'annotate', annotation]
        self.warrior.execute_command(args)
        self.refresh(only_fields=['annotations'])

    def remove_annotation(self, annotation):
        """
        Run ``<id> denotate <text>`` and reload the annotations.

        ``annotation`` may be a TaskAnnotation (its description is used) or
        the annotation text itself.
        """
        if isinstance(annotation, TaskAnnotation):
            annotation = annotation['description']
        args = [self['id'], 'denotate', annotation]
        self.warrior.execute_command(args)
        self.refresh(only_fields=['annotations'])

    def _get_modified_fields_as_args(self):
        """Return one ``field:value`` argument per modified field."""
        return ['{0}:{1}'.format(field, self._data[field])
                for field in self._modified_fields]

    def refresh(self, only_fields=None):
        """
        Reload this task's data by exporting it by uuid.

        With ``only_fields``, just those keys are updated in place (keys
        absent from the export are set to None); otherwise the whole raw
        dict is replaced.
        """
        args = [self['uuid'], 'export']
        new_data = json.loads(self.warrior.execute_command(args)[0])
        if only_fields:
            self._data.update(
                dict((k, new_data.get(k)) for k in only_fields))
        else:
            self._data = new_data
144
145
class TaskFilter(object):
    """
    A set of parameters to filter the task list with.
    """

    def __init__(self, filter_params=None):
        """
        ``filter_params`` is an optional iterable of initial filter strings.
        """
        # Always keep a private list: a literal ``[]`` default would be a
        # single object shared by every TaskFilter and grown by add_filter,
        # leaking filters from one instance into all the others.
        self.filter_params = list(filter_params) if filter_params else []

    def add_filter(self, filter_str):
        """Append a raw filter string."""
        self.filter_params.append(filter_str)

    def add_filter_param(self, key, value):
        """
        Append a ``key:value`` filter. Double underscores in ``key`` are
        turned into dots, so ``due__before`` becomes ``due.before``.
        """
        key = key.replace('__', '.')
        self.filter_params.append('{0}:{1}'.format(key, value))

    def get_filter_params(self):
        """Return the filter strings as a new list, skipping empty ones."""
        return [f for f in self.filter_params if f]

    def clone(self):
        """Return a new filter of the same class with a copy of the params."""
        c = self.__class__()
        c.filter_params = list(self.filter_params)
        return c
168
169
class TaskQuerySet(object):
    """
    Represents a lazy lookup for a set of task objects.

    The query is only executed (through ``warrior.filter_tasks``) when the
    result is first needed; the result is then kept in ``_result_cache``.
    ``filter`` and friends never touch this object but return a clone with
    an extended filter and an empty cache.
    """

    def __init__(self, warrior=None, filter_obj=None):
        self.warrior = warrior
        self._result_cache = None
        self.filter_obj = filter_obj or TaskFilter()

    def __deepcopy__(self, memo):
        """
        Deep copy of a QuerySet doesn't populate the cache
        """
        obj = self.__class__()
        # Register the copy before descending into the attributes: the
        # warrior can refer back to this queryset, and without the memo
        # entry that reference would be copied into a second, separate
        # queryset instead of pointing at ``obj``.
        memo[id(self)] = obj
        for k, v in self.__dict__.items():
            if k in ('_iter', '_result_cache'):
                obj.__dict__[k] = None
            else:
                obj.__dict__[k] = copy.deepcopy(v, memo)
        return obj

    def __repr__(self):
        # Fetch one item more than is shown so that truncation is detectable.
        data = list(self[:REPR_OUTPUT_SIZE + 1])
        if len(data) > REPR_OUTPUT_SIZE:
            data[-1] = "...(remaining elements truncated)..."
        return repr(data)

    def __len__(self):
        if self._result_cache is None:
            self._result_cache = list(self)
        return len(self._result_cache)

    def __iter__(self):
        if self._result_cache is None:
            self._result_cache = self._execute()
        return iter(self._result_cache)

    def __getitem__(self, k):
        if self._result_cache is None:
            self._result_cache = list(self)
        return self._result_cache.__getitem__(k)

    def __bool__(self):
        if self._result_cache is not None:
            return bool(self._result_cache)
        try:
            next(iter(self))
        except StopIteration:
            return False
        return True

    def __nonzero__(self):
        # Python 2 spelling of __bool__.
        return type(self).__bool__(self)

    def _clone(self, klass=None, **kwargs):
        """
        Return a new queryset (of ``klass``, defaulting to this class) with
        the same warrior and a copy of the filter. ``kwargs`` are set
        directly as attributes on the clone.
        """
        if klass is None:
            klass = self.__class__
        filter_obj = self.filter_obj.clone()
        c = klass(warrior=self.warrior, filter_obj=filter_obj)
        c.__dict__.update(kwargs)
        return c

    def _execute(self):
        """
        Fetch the tasks which match the current filters.
        """
        return self.warrior.filter_tasks(self.filter_obj)

    def all(self):
        """
        Returns a new TaskQuerySet that is a copy of the current one.
        """
        return self._clone()

    def pending(self):
        """
        Returns a new TaskQuerySet restricted to status ``PENDING``.
        """
        return self.filter(status=PENDING)

    def completed(self):
        """
        Returns a new TaskQuerySet restricted to status ``COMPLETED``.
        """
        return self.filter(status=COMPLETED)

    def filter(self, *args, **kwargs):
        """
        Returns a new TaskQuerySet with the given filters added.

        Positional arguments are added as raw filter strings, keyword
        arguments as ``key:value`` filter parameters.
        """
        clone = self._clone()
        for f in args:
            clone.filter_obj.add_filter(f)
        for key, value in kwargs.items():
            clone.filter_obj.add_filter_param(key, value)
        return clone

    def get(self, **kwargs):
        """
        Performs the query and returns a single object matching the given
        keyword arguments.

        Raises Task.DoesNotExist if nothing matches and ValueError if more
        than one task matches.
        """
        clone = self.filter(**kwargs)
        num = len(clone)
        if num == 1:
            return clone._result_cache[0]
        if not num:
            raise Task.DoesNotExist(
                'Task matching query does not exist. '
                'Lookup parameters were {0}'.format(kwargs))
        raise ValueError(
            'get() returned more than one Task -- it returned {0}! '
            'Lookup parameters were {1}'.format(num, kwargs))
278
279
class TaskWarrior(object):
    """
    Thin wrapper around the ``task`` command line program.

    Holds the configuration overrides passed to every invocation and
    exposes the tasks of the data directory as ``self.tasks``.
    """

    def __init__(self, data_location='~/.task', create=True):
        """
        ``data_location`` is the task data directory (``~`` is expanded).
        With ``create`` true, the directory is created if it is missing.
        """
        data_location = os.path.expanduser(data_location)
        if create and not os.path.exists(data_location):
            os.makedirs(data_location)
        self.config = {
            'data.location': data_location,
        }
        self.tasks = TaskQuerySet(self)

    def _get_command_args(self, args, config_override=None):
        """
        Build the argument list for one ``task`` invocation.

        The result starts with ``task rc:/``, followed by one
        ``rc.<key>=<value>`` entry per configuration item (``self.config``
        updated with ``config_override``, without modifying ``self.config``),
        followed by ``args`` converted to strings.
        """
        command_args = ['task', 'rc:/']
        config = self.config.copy()
        if config_override:
            config.update(config_override)
        for item in config.items():
            command_args.append('rc.{0}={1}'.format(*item))
        command_args.extend(map(str, args))
        return command_args

    def execute_command(self, args, config_override=None):
        """
        Run ``task`` with ``args`` and return its stdout as a list of lines.

        Both output streams are decoded as UTF-8. On a non-zero exit status
        a TaskWarriorException is raised, carrying the last line of stderr,
        or the whole of stdout when stderr is empty.
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode:
            if stderr.strip():
                error_msg = stderr.strip().splitlines()[-1]
            else:
                error_msg = stdout.strip()
            raise TaskWarriorException(error_msg)
        return stdout.strip().split('\n')

    def filter_tasks(self, filter_obj):
        """
        Export the tasks matching ``filter_obj`` and return them as a list
        of Task objects.

        Each non-empty output line is parsed as one JSON object after
        stripping leading/trailing commas. Raises TaskWarriorException for
        a line that is not valid JSON.
        """
        args = ['export', '--'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if line:
                data = line.strip(',')
                try:
                    tasks.append(Task(self, json.loads(data)))
                except ValueError:
                    raise TaskWarriorException('Invalid JSON: %s' % data)
        return tasks

    def merge_with(self, path, push=False):
        """
        Run ``merge <path>`` with ``merge.autopush`` set according to
        ``push``. ``path`` is normalised to end in exactly one slash.
        """
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Run ``undo`` with the confirmation prompt disabled."""
        self.execute_command(['undo'], config_override={
            'confirmation': 'no',
        })