git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/task.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add a couple of simple tests and fix typo
[etc/taskwarrior.git] / tasklib / task.py
1 import copy
2 import datetime
3 import json
4 import logging
5 import os
6 import subprocess
7
# strftime/strptime pattern used to (de)serialize task timestamps such as
# the ``due`` field and annotation ``entry`` values.
DATE_FORMAT = '%Y%m%dT%H%M%SZ'

# How many tasks TaskQuerySet.__repr__ shows before truncating its output.
REPR_OUTPUT_SIZE = 10

# Status value that TaskQuerySet.pending() filters on.
PENDING = 'pending'

# Module logger; TaskWarrior.execute_command logs every command at DEBUG.
logger = logging.getLogger(__name__)
13
14
class TaskWarriorException(Exception):
    """Error raised when a ``task`` command run by this module fails."""
17
18
class Task(object):
    """
    A single TaskWarrior task backed by a dict of serialized field values.

    Reading ``task[key]`` passes the stored value through an optional
    ``deserialize_<key>`` method; assigning ``task[key] = value`` passes it
    through an optional ``serialize_<key>`` method and records *key* as
    modified so that save() only sends the fields that changed.
    """

    class DoesNotExist(Exception):
        """Raised by TaskQuerySet.get() when no task matches the lookup."""

    def __init__(self, warrior, data=None):
        """
        :param warrior: TaskWarrior instance used to run ``task`` commands.
        :param data: serialized field values, e.g. one record decoded from
            ``task export`` output. It is stored as-is (not copied), so item
            assignment updates it in place. Defaults to a new empty dict.
        """
        self.warrior = warrior
        # One dict per instance: a shared ``{}`` default would be mutated by
        # __setitem__ and leak fields into every later Task(warrior).
        self._data = data if data is not None else {}
        self._modified_fields = set()

    def __unicode__(self):
        """Return the task description (``None`` when it is not set)."""
        return self['description']

    def __getitem__(self, key):
        """
        Return field *key*, run through ``deserialize_<key>`` if defined.

        A missing field is passed to the hook as ``None``.
        """
        hydrate_func = getattr(self, 'deserialize_{0}'.format(key),
                               lambda x: x)
        return hydrate_func(self._data.get(key))

    def __setitem__(self, key, value):
        """Store *value* (run through ``serialize_<key>`` if defined)."""
        dehydrate_func = getattr(self, 'serialize_{0}'.format(key),
                                 lambda x: x)
        self._data[key] = dehydrate_func(value)
        self._modified_fields.add(key)

    def serialize_due(self, date):
        """Format *date* with DATE_FORMAT; ``None`` becomes ``''``."""
        if date is None:
            # Mirrors deserialize_due, which maps an empty value to None.
            return ''
        return date.strftime(DATE_FORMAT)

    def deserialize_due(self, date_str):
        """Parse a DATE_FORMAT string into a datetime; empty -> ``None``."""
        if not date_str:
            return None
        return datetime.datetime.strptime(date_str, DATE_FORMAT)

    def serialize_annotations(self, annotations):
        """
        Return copies of the annotation dicts with ``entry`` formatted.

        The caller's dicts are left untouched; ``None``/empty -> ``[]``.
        """
        return [dict(ann, entry=ann['entry'].strftime(DATE_FORMAT))
                for ann in annotations or []]

    def deserialize_annotations(self, annotations):
        """
        Return copies of the annotation dicts with ``entry`` parsed.

        Copies are made so the stored data keeps its string timestamps and
        can be read repeatedly; a missing field yields ``[]``.
        """
        return [dict(ann, entry=datetime.datetime.strptime(ann['entry'],
                                                           DATE_FORMAT))
                for ann in annotations or []]

    def deserialize_tags(self, tags):
        """Split a comma-separated tag string into a list (``''`` -> ``[]``).

        Non-string values (e.g. an already-split list) are returned as-is.
        """
        # type(u'') is ``unicode`` on Python 2 and ``str`` on Python 3, so
        # this covers text on both without the Python 2-only ``basestring``.
        if isinstance(tags, (str, type(u''))):
            return tags.split(',') if tags else []
        return tags

    def serialize_tags(self, tags):
        """Join *tags* with commas; a falsy value becomes ``''``."""
        return ','.join(tags) if tags else ''

    def delete(self):
        """Run ``task <id> delete`` with ``rc.confirmation=no``.

        The override presumably suppresses the interactive prompt.
        """
        self.warrior.execute_command([self['id'], 'delete'], config_override={
            'confirmation': 'no',
        })

    def done(self):
        """Run ``task <id> done``."""
        self.warrior.execute_command([self['id'], 'done'])

    def save(self):
        """
        Send the modified fields to TaskWarrior and clear the modified set.

        Uses ``task <id> modify`` when the task has a truthy id and
        ``task add`` otherwise.
        """
        # NOTE(review): the id assigned by ``add`` is not read back here, so
        # a later save() of the same object will ``add`` again — confirm.
        args = [self['id'], 'modify'] if self['id'] else ['add']
        args.extend(self._get_modified_fields_as_args())
        self.warrior.execute_command(args)
        self._modified_fields.clear()

    def _get_modified_fields_as_args(self):
        """Return ``field:value`` arguments for every modified field."""
        return ['{0}:{1}'.format(field, self._data[field])
                for field in self._modified_fields]

    __repr__ = __unicode__
94
95
class TaskFilter(object):
    """
    A set of parameters to filter the task list with.

    Each parameter is a string handed to ``task`` as a filter argument,
    either verbatim (add_filter) or built as ``key:value``
    (add_filter_param).
    """

    def __init__(self, filter_params=None):
        """
        :param filter_params: optional iterable of initial filter strings.
            It is copied, so later add_filter* calls never mutate the
            caller's list (or a shared default).
        """
        self.filter_params = list(filter_params or [])

    def add_filter(self, filter_str):
        """Append *filter_str* verbatim."""
        self.filter_params.append(filter_str)

    def add_filter_param(self, key, value):
        """
        Append ``key:value``.

        Double underscores in *key* become dots, so ``project__startswith``
        is sent as ``project.startswith``.
        """
        key = key.replace('__', '.')
        self.filter_params.append('{0}:{1}'.format(key, value))

    def get_filter_params(self):
        """Return the filter strings, skipping empty ones."""
        return [f for f in self.filter_params if f]

    def clone(self):
        """Return an independent copy of this filter."""
        c = self.__class__()
        c.filter_params = list(self.filter_params)
        return c
118
119
120 class TaskQuerySet(object):
121     """
122     Represents a lazy lookup for a task objects.
123     """
124
125     def __init__(self, warrior=None, filter_obj=None):
126         self.warrior = warrior
127         self._result_cache = None
128         self.filter_obj = filter_obj or TaskFilter()
129
130     def __deepcopy__(self, memo):
131         """
132         Deep copy of a QuerySet doesn't populate the cache
133         """
134         obj = self.__class__()
135         for k, v in self.__dict__.items():
136             if k in ('_iter', '_result_cache'):
137                 obj.__dict__[k] = None
138             else:
139                 obj.__dict__[k] = copy.deepcopy(v, memo)
140         return obj
141
142     def __repr__(self):
143         data = list(self[:REPR_OUTPUT_SIZE + 1])
144         if len(data) > REPR_OUTPUT_SIZE:
145             data[-1] = "...(remaining elements truncated)..."
146         return repr(data)
147
148     def __len__(self):
149         if self._result_cache is None:
150             self._result_cache = list(self)
151         return len(self._result_cache)
152
153     def __iter__(self):
154         if self._result_cache is None:
155             self._result_cache = self._execute()
156         return iter(self._result_cache)
157
158     def __getitem__(self, k):
159         if self._result_cache is None:
160             self._result_cache = list(self)
161         return self._result_cache.__getitem__(k)
162
163     def __bool__(self):
164         if self._result_cache is not None:
165             return bool(self._result_cache)
166         try:
167             next(iter(self))
168         except StopIteration:
169             return False
170         return True
171
172     def __nonzero__(self):
173         return type(self).__bool__(self)
174
175     def _clone(self, klass=None, **kwargs):
176         if klass is None:
177             klass = self.__class__
178         filter_obj = self.filter_obj.clone()
179         c = klass(warrior=self.warrior, filter_obj=filter_obj)
180         c.__dict__.update(kwargs)
181         return c
182
183     def _execute(self):
184         """
185         Fetch the tasks which match the current filters.
186         """
187         return self.warrior.filter_tasks(self.filter_obj)
188
189     def all(self):
190         """
191         Returns a new TaskQuerySet that is a copy of the current one.
192         """
193         return self._clone()
194
195     def pending(self):
196         return self.filter(status=PENDING)
197
198     def filter(self, *args, **kwargs):
199         """
200         Returns a new TaskQuerySet with the given filters added.
201         """
202         clone = self._clone()
203         for f in args:
204             clone.filter_obj.add_filter(f)
205         for key, value in kwargs.items():
206             clone.filter_obj.add_filter_param(key, value)
207         return clone
208
209     def get(self, **kwargs):
210         """
211         Performs the query and returns a single object matching the given
212         keyword arguments.
213         """
214         clone = self.filter(**kwargs)
215         num = len(clone)
216         if num == 1:
217             return clone._result_cache[0]
218         if not num:
219             raise Task.DoesNotExist(
220                 'Task matching query does not exist. '
221                 'Lookup parameters were {0}'.format(kwargs))
222         raise ValueError(
223             'get() returned more than one Task -- it returned {0}! '
224             'Lookup parameters were {1}'.format(num, kwargs))
225
226
class TaskWarrior(object):
    """
    Runs the ``task`` command-line tool against one data directory and
    exposes that directory's tasks as a lazy TaskQuerySet (``self.tasks``).
    """

    def __init__(self, data_location='~/.task', create=True):
        """
        :param data_location: TaskWarrior data directory; ``~`` is expanded.
        :param create: create the directory if it does not exist yet. When
            False, a missing directory is left alone.
        """
        data_location = os.path.expanduser(data_location)
        if create and not os.path.exists(data_location):
            os.makedirs(data_location)
        self.config = {
            'data.location': data_location,
        }
        self.tasks = TaskQuerySet(self)

    def _get_command_args(self, args, config_override=None):
        """
        Build the argv list for a ``task`` invocation.

        Each entry of ``self.config`` (overlaid with *config_override*) is
        passed as ``rc.<key>=<value>``, followed by *args* converted with
        ``str``. ``self.config`` itself is not modified.
        """
        # NOTE(review): ``rc:/`` points the rc file at ``/``, presumably so
        # the user's own taskrc is ignored — confirm.
        command_args = ['task', 'rc:/']
        config = self.config.copy()
        config.update(config_override or {})
        for item in config.items():
            command_args.append('rc.{0}={1}'.format(*item))
        command_args.extend(map(str, args))
        return command_args

    def execute_command(self, args, config_override=None):
        """
        Run ``task`` with *args* and return its stripped stdout split into
        lines.

        :raises TaskWarriorException: if the command exits non-zero. The
            message is the last line of stderr, or a generic message with
            the exit code when stderr is empty.
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        # universal_newlines yields text rather than bytes on Python 3, so
        # the str-based splitting below works on both major versions.
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE, universal_newlines=True)
        stdout, stderr = p.communicate()
        if p.returncode:
            error_lines = stderr.strip().splitlines()
            if error_lines:
                error_msg = error_lines[-1]
            else:
                error_msg = 'task command failed with exit code {0}'.format(
                    p.returncode)
            raise TaskWarriorException(error_msg)
        return stdout.strip().split('\n')

    def filter_tasks(self, filter_obj):
        """
        Run ``task export`` with *filter_obj*'s parameters and return a
        Task for every non-empty output line (surrounding commas stripped
        before JSON decoding).
        """
        args = ['export', '--'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if line:
                tasks.append(Task(self, json.loads(line.strip(','))))
        return tasks

    def merge_with(self, path, push=False):
        """
        Run ``task merge`` against *path* (normalised to end in exactly one
        ``/``), setting ``merge.autopush`` to ``yes`` or ``no`` from *push*.
        """
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Run ``task undo`` with ``rc.confirmation=no``."""
        self.execute_command(['undo'], config_override={
            'confirmation': 'no',
        })