git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/backends.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

make backends.TaskWarrior inherit from Backend
[etc/taskwarrior.git] / tasklib / backends.py
1 import abc
2 import datetime
3 import json
4 import logging
5 import os
6 import re
7 import six
8 import subprocess
9 import copy
10
11 from .task import Task, TaskQuerySet
12 from .filters import TaskWarriorFilter
13 from .serializing import local_zone
14
# strptime() format of the timestamp printed by ``task calc``; used by
# TaskWarrior.convert_datetime_string to parse that command's output.
DATE_FORMAT_CALC = '%Y-%m-%dT%H:%M:%S'

# Module-level logger; TaskWarrior.execute_command logs every command line
# it runs at DEBUG level.
logger = logging.getLogger(__name__)
18
class Backend(object):
    """
    Interface implemented by tasklib storage backends.

    NOTE(review): this class does not declare ``abc.ABCMeta`` as its
    metaclass, so the ``abc.abstract*`` decorators below only document the
    interface -- Python will not refuse to instantiate a subclass that leaves
    some of them unimplemented.
    """

    @abc.abstractproperty
    def filter_class(self):
        """Returns the TaskFilter class used by this backend"""
        pass

    @abc.abstractmethod
    def filter_tasks(self, filter_obj):
        """Returns a list of Task objects matching the given filter"""
        pass

    @abc.abstractmethod
    def save_task(self, task):
        """Persists ``task`` in the backend (adding or modifying it)."""
        pass

    @abc.abstractmethod
    def delete_task(self, task):
        """Deletes ``task`` in the backend."""
        pass

    @abc.abstractmethod
    def start_task(self, task):
        """Marks ``task`` as started."""
        pass

    @abc.abstractmethod
    def stop_task(self, task):
        """Marks ``task`` as stopped."""
        pass

    @abc.abstractmethod
    def complete_task(self, task):
        """Marks ``task`` as completed."""
        pass

    @abc.abstractmethod
    def refresh_task(self, task, after_save=False):
        """
        Refreshes the given task. Returns new data dict with serialized
        attributes.
        """
        pass

    @abc.abstractmethod
    def annotate_task(self, task, annotation):
        """Adds ``annotation`` to ``task``."""
        pass

    @abc.abstractmethod
    def denotate_task(self, task, annotation):
        """Removes ``annotation`` from ``task``."""
        pass

    @abc.abstractmethod
    def sync(self):
        """Syncs the backend database with the taskd server"""
        pass

    def convert_datetime_string(self, value):
        """
        Converts TW syntax datetime string to a localized datetime
        object. This method is not mandatory.

        :raises NotImplementedError: always, in this base implementation;
            backends that support the conversion override this method.
        """
        # ``NotImplemented`` is a comparison sentinel, not an exception class:
        # ``raise NotImplemented`` fails with a TypeError instead of signalling
        # "not supported". NotImplementedError is the intended exception.
        raise NotImplementedError
78
79
class TaskWarriorException(Exception):
    """
    Raised by the TaskWarrior backend when a ``task`` command exits with an
    error or produces output that cannot be parsed.
    """
82
83
class TaskWarrior(Backend):
    """
    Backend that drives the ``task`` command-line program.

    Every operation spawns ``task`` with ``rc.*`` overrides built from
    ``self.config`` (see ``_get_command_args``) and parses its output.

    NOTE(review): ``self.version`` and the ``VERSION_*`` constants are plain
    strings, so the ``<`` / ``>=`` version checks in this class compare
    lexicographically. That orders the versions listed here correctly, but
    would misorder e.g. ``2.4.10`` against ``2.4.3``.
    """

    VERSION_2_1_0 = six.u('2.1.0')
    VERSION_2_2_0 = six.u('2.2.0')
    VERSION_2_3_0 = six.u('2.3.0')
    VERSION_2_4_0 = six.u('2.4.0')
    VERSION_2_4_1 = six.u('2.4.1')
    VERSION_2_4_2 = six.u('2.4.2')
    VERSION_2_4_3 = six.u('2.4.3')
    VERSION_2_4_4 = six.u('2.4.4')
    VERSION_2_4_5 = six.u('2.4.5')

    def __init__(self, data_location=None, create=True,
                 taskrc_location='~/.taskrc'):
        """
        :param data_location: optional directory passed to ``task`` as
            ``rc.data.location``; ``~`` is expanded.
        :param create: when true, create ``data_location`` if it does not
            exist yet.
        :param taskrc_location: path of the taskrc file; ``~`` is expanded.
            If the file does not exist, ``/`` is passed instead.
        """
        self.taskrc_location = os.path.expanduser(taskrc_location)

        # If taskrc does not exist, pass / to use defaults and avoid creating
        # dummy .taskrc file by TaskWarrior
        if not os.path.exists(self.taskrc_location):
            self.taskrc_location = '/'

        self.version = self._get_version()
        self.config = {
            'confirmation': 'no',
            'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
            'recurrence.confirmation': 'no',  # Necessary for modifying R tasks

            # Defaults to on since 2.4.5, we expect off during parsing
            'json.array': 'off',

            # 2.4.3 onwards supports 0 as infinite bulk, otherwise set just
            # arbitrary big number which is likely to be large enough
            'bulk': 0 if self.version >= self.VERSION_2_4_3 else 100000,
        }

        # Set data.location override if passed via kwarg
        if data_location is not None:
            data_location = os.path.expanduser(data_location)
            if create and not os.path.exists(data_location):
                os.makedirs(data_location)
            self.config['data.location'] = data_location

        self.tasks = TaskQuerySet(self)

    def _get_command_args(self, args, config_override=None):
        """
        Build the argv list for a ``task`` invocation.

        ``self.config`` merged with ``config_override`` (the override wins)
        is emitted as ``rc.<key>=<value>`` arguments, followed by ``args``
        converted to text.
        """
        command_args = ['task', 'rc:{0}'.format(self.taskrc_location)]
        config = self.config.copy()
        config.update(config_override or dict())
        for item in config.items():
            command_args.append('rc.{0}={1}'.format(*item))
        command_args.extend(map(six.text_type, args))
        return command_args

    def _get_version(self):
        """Return the output of ``task --version`` without surrounding newlines."""
        p = subprocess.Popen(
                ['task', '--version'],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        return stdout.strip('\n')

    def _get_modified_task_fields_as_args(self, task):
        """
        Build the ``field:value`` arguments for an ``add``/``modify`` call.

        For saved tasks only the fields in ``task._modified_fields`` are
        emitted; for new tasks every field in ``task._data`` that is not in
        ``task.read_only_fields``. A ``format_<field>`` method on this backend,
        if present, replaces the default ``field:'value'`` formatting.
        """
        args = []

        def add_field(field):
            # Add the output of format_field method to args list (defaults to
            # field:value)
            serialized_value = task._serialize(field, task._data[field])

            # Empty values should not be enclosed in quotation marks, see
            # TW-1510. Compare by equality: ``is ''`` depends on string
            # interning (and is False for u'' on Python 2).
            if serialized_value == '':
                escaped_serialized_value = ''
            else:
                escaped_serialized_value = six.u("'{0}'").format(
                    serialized_value)

            def format_default(task):
                return six.u("{0}:{1}").format(field, escaped_serialized_value)

            format_func = getattr(self, 'format_{0}'.format(field),
                                  format_default)

            args.append(format_func(task))

        # If we're modifying saved task, simply pass on all modified fields
        if task.saved:
            for field in task._modified_fields:
                add_field(field)
        # For new tasks, pass all fields that make sense
        else:
            for field in task._data.keys():
                if field in task.read_only_fields:
                    continue
                add_field(field)

        return args

    def format_depends(self, task):
        """
        Format the ``depends`` argument as the difference to the originally
        loaded dependencies: added UUIDs plain, removed UUIDs prefixed '-'.
        """
        # We need to generate added and removed dependencies list,
        # since Taskwarrior does not accept redefining dependencies.

        # This cannot be part of serialize_depends, since we need
        # to keep a list of all dependencies in the _data dictionary,
        # not just currently added/removed ones

        old_dependencies = task._original_data.get('depends', set())

        added = task['depends'] - old_dependencies
        removed = old_dependencies - task['depends']

        # Removed dependencies need to be prefixed with '-'
        return 'depends:' + ','.join(
                [t['uuid'] for t in added] +
                ['-' + t['uuid'] for t in removed]
            )

    def format_description(self, task):
        """Format the description argument, honouring TW version quirks."""
        # Task version older than 2.4.0 ignores first word of the
        # task description if description: prefix is used
        if self.version < self.VERSION_2_4_0:
            return task._data['description']
        else:
            return six.u("description:'{0}'").format(
                task._data['description'] or '')

    def convert_datetime_string(self, value):
        """
        Convert a TaskWarrior date expression string to a datetime localized
        with ``local_zone``, by evaluating it with ``task calc``.

        :raises ValueError: on TaskWarrior versions older than 2.4.0, which
            lack ``task calc``.
        :raises TaskWarriorException: if ``task calc`` fails.
        """
        if self.version >= self.VERSION_2_4_0:
            # For strings, use 'task calc' to evaluate the string to datetime
            # available since TW 2.4.0
            args = value.split()
            result = self.execute_command(['calc'] + args)
            naive = datetime.datetime.strptime(result[0], DATE_FORMAT_CALC)
            localized = local_zone.localize(naive)
        else:
            raise ValueError("Provided value could not be converted to "
                             "datetime, its type is not supported: {}"
                             .format(type(value)))

        return localized

    @property
    def filter_class(self):
        """The filter class used to build ``task`` filter arguments."""
        return TaskWarriorFilter

    # Public interface

    def get_config(self):
        """
        Return the effective configuration reported by ``task show`` as a
        dict mapping setting names to (stripped) string values.
        """
        raw_output = self.execute_command(
                ['show'],
                config_override={'verbose': 'nothing'}
            )

        config = dict()
        # ``.*`` (not ``.+``) after the first value character, so that
        # single-character values such as ``0`` are not silently dropped.
        config_regex = re.compile(r'^(?P<key>[^\s]+)\s+(?P<value>[^\s].*$)')

        for line in raw_output:
            match = config_regex.match(line)
            if match:
                config[match.group('key')] = match.group('value').strip()

        return config

    def execute_command(self, args, config_override=None, allow_failure=True,
                        return_all=False):
        """
        Run ``task`` with ``args`` and return its output.

        :param args: command arguments, appended after the ``rc`` overrides.
        :param config_override: extra ``rc.*`` settings for this call only.
        :param allow_failure: despite its name, when true a non-zero exit
            status raises TaskWarriorException; pass False to ignore the
            exit status.
        :param return_all: when true, return ``(stdout_lines, stderr_lines,
            returncode)`` instead of only the stdout lines.
        :returns: stdout (trailing whitespace removed) split on newlines.
        :raises TaskWarriorException: see ``allow_failure``; the message is
            stderr, or stdout when stderr is blank.
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode and allow_failure:
            if stderr.strip():
                error_msg = stderr.strip()
            else:
                error_msg = stdout.strip()
            raise TaskWarriorException(error_msg)

        # Return all whole triplet only if explicitly asked for
        if not return_all:
            return stdout.rstrip().split('\n')
        else:
            return (stdout.rstrip().split('\n'),
                    stderr.rstrip().split('\n'),
                    p.returncode)

    def enforce_recurrence(self):
        """Trigger generation of recurring task instances on TW < 2.4.2."""
        # Run arbitrary report command which will trigger generation
        # of recurrent tasks.

        # Only necessary for TW up to 2.4.1, fixed in 2.4.2.
        if self.version < self.VERSION_2_4_2:
            self.execute_command(['next'], allow_failure=False)

    def merge_with(self, path, push=False):
        """
        Run ``task merge`` with ``path`` (a trailing '/' is ensured);
        ``push`` sets ``merge.autopush``.
        """
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Run ``task undo``."""
        self.execute_command(['undo'])

    # Backend interface implementation

    def filter_tasks(self, filter_obj):
        """
        Return Task objects for every line of ``task export`` output that
        matches ``filter_obj``.

        :raises TaskWarriorException: if an output line is not valid JSON.
        """
        self.enforce_recurrence()
        args = ['export', '--'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if line:
                # With json.array off, tasks are emitted one per line,
                # possibly separated by trailing commas.
                data = line.strip(',')
                try:
                    filtered_task = Task(self)
                    filtered_task._load_data(json.loads(data))
                    tasks.append(filtered_task)
                except ValueError:
                    raise TaskWarriorException('Invalid JSON: %s' % data)
        return tasks

    def save_task(self, task):
        """Save a task into TaskWarrior database using add/modify call"""

        args = [task['uuid'], 'modify'] if task.saved else ['add']
        args.extend(self._get_modified_task_fields_as_args(task))
        output = self.execute_command(args)

        # Parse out the new ID, if the task is being added for the first time
        if not task.saved:
            id_lines = [l for l in output if l.startswith('Created task ')]

            # Complain loudly if it seems that more tasks were created
            # Should not happen
            if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
                raise TaskWarriorException("Unexpected output when creating "
                                           "task: %s" % '\n'.join(id_lines))

            # Circumvent the ID storage, since ID is considered read-only
            identifier = id_lines[0].split(' ')[2].rstrip('.')

            # Identifier can be either ID or UUID for completed tasks
            try:
                task._data['id'] = int(identifier)
            except ValueError:
                task._data['uuid'] = identifier

        # Refreshing is very important here, as not only modification time
        # is updated, but arbitrary attribute may have changed due hooks
        # altering the data before saving
        task.refresh(after_save=True)

    def delete_task(self, task):
        """Run ``task <uuid> delete``."""
        self.execute_command([task['uuid'], 'delete'])

    def start_task(self, task):
        """Run ``task <uuid> start``."""
        self.execute_command([task['uuid'], 'start'])

    def stop_task(self, task):
        """Run ``task <uuid> stop``."""
        self.execute_command([task['uuid'], 'stop'])

    def complete_task(self, task):
        """Run ``task <uuid> done``, stopping active tasks first on TW < 2.4.0."""
        # Older versions of TW do not stop active task at completion
        if self.version < self.VERSION_2_4_0 and task.active:
            task.stop()

        self.execute_command([task['uuid'], 'done'])

    def annotate_task(self, task, annotation):
        """Run ``task <uuid> annotate <annotation>``."""
        args = [task['uuid'], 'annotate', annotation]
        self.execute_command(args)

    def denotate_task(self, task, annotation):
        """Run ``task <uuid> denotate <annotation>``."""
        args = [task['uuid'], 'denotate', annotation]
        self.execute_command(args)

    def refresh_task(self, task, after_save=False):
        """
        Return the freshly exported data dict of ``task``.

        :param after_save: when true on TW < 2.4.5, fall back to locating the
            task by its other attributes if the UUID/ID lookup is not unique.
        :raises TaskWarriorException: if exactly one task cannot be matched.
        """
        # We need to use ID as backup for uuid here for the refreshes
        # of newly saved tasks. Any other place in the code is fine
        # with using UUID only.
        args = [task['uuid'] or task['id'], 'export']
        output = self.execute_command(args)

        def valid(output):
            return len(output) == 1 and output[0].startswith('{')

        # For older TW versions attempt to uniquely locate the task
        # using the data we have if it has been just saved.
        # This can happen when adding a completed task on older TW versions.
        if (not valid(output) and self.version < self.VERSION_2_4_5
                and after_save):

            # Make a copy, removing ID and UUID. It's most likely invalid
            # (ID 0) if it failed to match a unique task.
            data = copy.deepcopy(task._data)
            data.pop('id', None)
            data.pop('uuid', None)

            taskfilter = self.filter_class(self)
            for key, value in data.items():
                taskfilter.add_filter_param(key, value)

            output = self.execute_command(['export', '--'] +
                taskfilter.get_filter_params())

        # If more than 1 task has been matched still, raise an exception.
        # NOTE: valid() is also False when nothing matched (output == ['']),
        # so this message is raised in that case too.
        if not valid(output):
            raise TaskWarriorException(
                "Unique identifiers {0} with description: {1} matches "
                "multiple tasks: {2}".format(
                task['uuid'] or task['id'], task['description'], output)
            )

        return json.loads(output[0])

    def sync(self):
        """Run ``task sync``."""
        self.execute_command(['sync'])