git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/task.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Install taskwarrior before running tests
[etc/taskwarrior.git] / tasklib / task.py
1 import copy
2 import datetime
3 import json
4 import logging
5 import os
6 import subprocess
7
# strftime/strptime pattern for timestamps exchanged with taskwarrior,
# e.g. '20140101T120000Z'. The trailing 'Z' is a literal character, so the
# datetimes produced by strptime with this pattern are naive.
DATE_FORMAT = '%Y%m%dT%H%M%SZ'

# How many tasks TaskQuerySet.__repr__ shows before truncating the rest.
REPR_OUTPUT_SIZE = 10

# Status values used by TaskQuerySet.pending() and .completed().
PENDING, COMPLETED = 'pending', 'completed'

logger = logging.getLogger(__name__)
14
15
class TaskWarriorException(Exception):
    """Raised when a ``task`` command invocation exits with a non-zero status."""
18
19
class Task(object):
    """
    A single taskwarrior task.

    Field values live in ``self._data`` in serialized form. Reading
    ``task[key]`` passes the stored value through ``deserialize_<key>`` when
    such a method exists. Assigning ``task[key] = value`` stores the result
    of ``serialize_<key>`` (when defined) and records *key* as modified, so
    that :meth:`save` only sends changed fields.
    """

    class DoesNotExist(Exception):
        pass

    def __init__(self, warrior, data=None):
        """
        :param warrior: object providing ``execute_command`` (a TaskWarrior).
        :param data: dict of serialized field values; a new empty dict is
            used when omitted.
        """
        self.warrior = warrior
        # One fresh dict per instance. The previous ``data={}`` default was a
        # single dict shared by every Task created without data, so setting a
        # field on one such task silently changed all of them.
        self._data = data if data is not None else {}
        self._modified_fields = set()

    def __unicode__(self):
        return self['description']

    def __getitem__(self, key):
        # Fall back to the identity function for fields without a hook.
        hydrate_func = getattr(self, 'deserialize_{0}'.format(key),
                               lambda x: x)
        return hydrate_func(self._data.get(key))

    def __setitem__(self, key, value):
        dehydrate_func = getattr(self, 'serialize_{0}'.format(key),
                                 lambda x: x)
        self._data[key] = dehydrate_func(value)
        self._modified_fields.add(key)

    def serialize_due(self, date):
        # None is stored as '' (which deserialize_due reads back as None)
        # instead of failing on None.strftime().
        if date is None:
            return ''
        return date.strftime(DATE_FORMAT)

    def deserialize_due(self, date_str):
        if not date_str:
            return None
        return datetime.datetime.strptime(date_str, DATE_FORMAT)

    def serialize_annotations(self, annotations):
        # Copy each annotation dict so the caller's objects keep their
        # datetime values; the old code overwrote them in place.
        return [dict(ann, entry=ann['entry'].strftime(DATE_FORMAT))
                for ann in annotations or []]

    def deserialize_annotations(self, annotations):
        # Build new dicts rather than converting the stored ones in place:
        # the old in-place conversion replaced the strings in self._data
        # with datetimes, so reading task['annotations'] a second time failed
        # inside strptime(). A missing field yields [] instead of TypeError.
        return [dict(ann, entry=datetime.datetime.strptime(ann['entry'],
                                                           DATE_FORMAT))
                for ann in annotations or []]

    def deserialize_tags(self, tags):
        # ``basestring`` exists only on Python 2; on Python 3 the bare name
        # raised NameError, so fall back to ``str`` there.
        try:
            string_types = basestring
        except NameError:
            string_types = str
        if isinstance(tags, string_types):
            return tags.split(',') if tags else []
        return tags

    def serialize_tags(self, tags):
        return ','.join(tags) if tags else ''

    def delete(self):
        """Run ``<id> delete`` with ``rc.confirmation=no``."""
        self.warrior.execute_command([self['id'], 'delete'], config_override={
            'confirmation': 'no',
        })

    def done(self):
        """Run ``<id> done`` for this task."""
        self.warrior.execute_command([self['id'], 'done'])

    def save(self):
        """
        Send the modified fields to taskwarrior.

        Runs ``<id> modify`` when the task has a truthy id and ``add``
        otherwise, then forgets the recorded modifications.
        """
        # NOTE(review): the id of a newly added task is not read back into
        # self._data, so calling save() again issues another 'add'. A falsy
        # id such as 0 also selects 'add' — confirm that is intended.
        args = [self['id'], 'modify'] if self['id'] else ['add']
        args.extend(self._get_modified_fields_as_args())
        self.warrior.execute_command(args)
        self._modified_fields.clear()

    def _get_modified_fields_as_args(self):
        """Return a ``field:value`` string for every modified field."""
        return ['{}:{}'.format(field, self._data[field])
                for field in self._modified_fields]

    __repr__ = __unicode__
94
95
class TaskFilter(object):
    """
    A set of parameters to filter the task list with.

    Parameters are kept as plain strings in ``filter_params``;
    :meth:`get_filter_params` returns the non-empty ones.
    """

    def __init__(self, filter_params=None):
        # Always own a private list. The old ``filter_params=[]`` default was
        # one list shared by every TaskFilter(), and a caller-supplied list
        # was aliased, so add_filter() mutated state this object did not own.
        self.filter_params = list(filter_params) if filter_params else []

    def add_filter(self, filter_str):
        """Append a raw filter string unchanged."""
        self.filter_params.append(filter_str)

    def add_filter_param(self, key, value):
        """
        Append a ``key:value`` filter string.

        Double underscores in *key* are turned into dots, so ``due__before``
        becomes ``due.before``.
        """
        key = key.replace('__', '.')
        self.filter_params.append('{0}:{1}'.format(key, value))

    def get_filter_params(self):
        """Return the filter strings, skipping empty/falsy entries."""
        return [f for f in self.filter_params if f]

    def clone(self):
        """Return a copy whose parameter list is independent of this one."""
        c = self.__class__()
        c.filter_params = list(self.filter_params)
        return c
118
119
120 class TaskQuerySet(object):
121     """
122     Represents a lazy lookup for a task objects.
123     """
124
125     def __init__(self, warrior=None, filter_obj=None):
126         self.warrior = warrior
127         self._result_cache = None
128         self.filter_obj = filter_obj or TaskFilter()
129
130     def __deepcopy__(self, memo):
131         """
132         Deep copy of a QuerySet doesn't populate the cache
133         """
134         obj = self.__class__()
135         for k, v in self.__dict__.items():
136             if k in ('_iter', '_result_cache'):
137                 obj.__dict__[k] = None
138             else:
139                 obj.__dict__[k] = copy.deepcopy(v, memo)
140         return obj
141
142     def __repr__(self):
143         data = list(self[:REPR_OUTPUT_SIZE + 1])
144         if len(data) > REPR_OUTPUT_SIZE:
145             data[-1] = "...(remaining elements truncated)..."
146         return repr(data)
147
148     def __len__(self):
149         if self._result_cache is None:
150             self._result_cache = list(self)
151         return len(self._result_cache)
152
153     def __iter__(self):
154         if self._result_cache is None:
155             self._result_cache = self._execute()
156         return iter(self._result_cache)
157
158     def __getitem__(self, k):
159         if self._result_cache is None:
160             self._result_cache = list(self)
161         return self._result_cache.__getitem__(k)
162
163     def __bool__(self):
164         if self._result_cache is not None:
165             return bool(self._result_cache)
166         try:
167             next(iter(self))
168         except StopIteration:
169             return False
170         return True
171
172     def __nonzero__(self):
173         return type(self).__bool__(self)
174
175     def _clone(self, klass=None, **kwargs):
176         if klass is None:
177             klass = self.__class__
178         filter_obj = self.filter_obj.clone()
179         c = klass(warrior=self.warrior, filter_obj=filter_obj)
180         c.__dict__.update(kwargs)
181         return c
182
183     def _execute(self):
184         """
185         Fetch the tasks which match the current filters.
186         """
187         return self.warrior.filter_tasks(self.filter_obj)
188
189     def all(self):
190         """
191         Returns a new TaskQuerySet that is a copy of the current one.
192         """
193         return self._clone()
194
195     def pending(self):
196         return self.filter(status=PENDING)
197
198     def completed(self):
199         return self.filter(status=COMPLETED)
200
201     def filter(self, *args, **kwargs):
202         """
203         Returns a new TaskQuerySet with the given filters added.
204         """
205         clone = self._clone()
206         for f in args:
207             clone.filter_obj.add_filter(f)
208         for key, value in kwargs.items():
209             clone.filter_obj.add_filter_param(key, value)
210         return clone
211
212     def get(self, **kwargs):
213         """
214         Performs the query and returns a single object matching the given
215         keyword arguments.
216         """
217         clone = self.filter(**kwargs)
218         num = len(clone)
219         if num == 1:
220             return clone._result_cache[0]
221         if not num:
222             raise Task.DoesNotExist(
223                 'Task matching query does not exist. '
224                 'Lookup parameters were {0}'.format(kwargs))
225         raise ValueError(
226             'get() returned more than one Task -- it returned {0}! '
227             'Lookup parameters were {1}'.format(num, kwargs))
228
229
class TaskWarrior(object):
    """
    Wrapper around the ``task`` command-line program.

    Commands are run as ``task rc:/ rc.<key>=<value> ... <args>``. The
    ``rc.`` options come from ``self.config`` plus any per-call override.
    """

    def __init__(self, data_location='~/.task', create=True):
        """
        :param data_location: taskwarrior data directory; ``~`` is expanded.
        :param create: create *data_location* if it does not exist yet.
        """
        data_location = os.path.expanduser(data_location)
        # Honour ``create``: it used to be accepted but ignored, so the
        # directory was always created.
        if create and not os.path.exists(data_location):
            os.makedirs(data_location)
        self.config = {
            'data.location': data_location,
        }
        self.tasks = TaskQuerySet(self)

    def _get_command_args(self, args, config_override=None):
        """
        Build the argv list for one ``task`` invocation.

        :param args: command arguments; each is converted with ``str``.
        :param config_override: extra ``rc.`` settings for this call only.
            They take precedence over ``self.config``, which is not modified.
        """
        # NOTE(review): 'rc:/' presumably stops the user's own taskrc from
        # being loaded, so that only the rc.* options below apply — confirm.
        command_args = ['task', 'rc:/']
        config = self.config.copy()
        config.update(config_override or {})
        for item in config.items():
            command_args.append('rc.{0}={1}'.format(*item))
        command_args.extend(map(str, args))
        return command_args

    def execute_command(self, args, config_override=None):
        """
        Run ``task`` with *args* and return its stripped stdout split on '\\n'.

        :raises TaskWarriorException: if the process exits non-zero. The
            message is the last line of stderr, or the exit status when
            stderr is empty.
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = p.communicate()
        # The pipes yield bytes. Decode them so the str operations below also
        # work on Python 3, where bytes.split('\n') raises TypeError.
        stdout = stdout.decode('utf-8')
        stderr = stderr.decode('utf-8', 'replace')
        if p.returncode:
            error_lines = stderr.strip().splitlines()
            if error_lines:
                error_msg = error_lines[-1]
            else:
                # An empty stderr previously raised IndexError here.
                error_msg = 'task exited with status {0}'.format(p.returncode)
            raise TaskWarriorException(error_msg)
        return stdout.strip().split('\n')

    def filter_tasks(self, filter_obj):
        """
        Return a list of Task objects matching *filter_obj*.

        Runs ``task export -- <filter params>`` and parses every non-empty
        output line as JSON after stripping surrounding commas.
        """
        # NOTE(review): assumes one JSON object per output line. Lines such
        # as '[' or ']' would make json.loads fail — confirm for the
        # taskwarrior version in use.
        args = ['export', '--'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if line:
                tasks.append(Task(self, json.loads(line.strip(','))))
        return tasks

    def merge_with(self, path, push=False):
        """
        Run ``task merge`` against *path*, normalized to one trailing '/'.

        :param push: set ``rc.merge.autopush`` to 'yes' instead of 'no'.
        """
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Run ``task undo`` with ``rc.confirmation=no``."""
        self.execute_command(['undo'], config_override={
            'confirmation': 'no',
        })