git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/backends.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

make backends.TaskWarrior inherit from Backend
[etc/taskwarrior.git] / tasklib / backends.py
1 import abc
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import re
8 import six
9 import subprocess
10 import copy
11
12 from .task import Task, TaskQuerySet
13 from .filters import TaskWarriorFilter
14 from .serializing import local_zone
15
# strptime() format of the timestamps printed by `task calc`
# (parsed in TaskWarrior.convert_datetime_string).
DATE_FORMAT_CALC = "%Y-%m-%dT%H:%M:%S"

# Module-level logger; TaskWarrior.execute_command logs every command line
# it runs at DEBUG level.
logger = logging.getLogger(__name__)
19
class Backend(object):
    """Interface that tasklib storage backends implement.

    Concrete backends (for example ``TaskWarrior`` in this module) override
    the members marked abstract below.

    NOTE(review): the ``abc`` decorators are informational only. The class
    does not use ``abc.ABCMeta`` as its metaclass, so a subclass that
    forgets to override a member can still be instantiated, and the
    inherited stub simply returns ``None``.
    """

    @abc.abstractproperty
    def filter_class(self):
        """Returns the TaskFilter class used by this backend"""
        pass

    @abc.abstractmethod
    def filter_tasks(self, filter_obj):
        """Returns a list of Task objects matching the given filter"""
        pass

    @abc.abstractmethod
    def save_task(self, task):
        """Persists the given task in the backend's storage."""
        pass

    @abc.abstractmethod
    def delete_task(self, task):
        """Deletes the given task."""
        pass

    @abc.abstractmethod
    def start_task(self, task):
        """Marks the given task as started."""
        pass

    @abc.abstractmethod
    def stop_task(self, task):
        """Marks the given task as stopped."""
        pass

    @abc.abstractmethod
    def complete_task(self, task):
        """Marks the given task as completed."""
        pass

    @abc.abstractmethod
    def refresh_task(self, task, after_save=False):
        """
        Refreshes the given task. Returns new data dict with serialized
        attributes.
        """
        pass

    @abc.abstractmethod
    def annotate_task(self, task, annotation):
        """Adds the given annotation to the task."""
        pass

    @abc.abstractmethod
    def denotate_task(self, task, annotation):
        """Removes the given annotation from the task."""
        pass

    @abc.abstractmethod
    def sync(self):
        """Syncs the backend database with the taskd server"""
        pass

    def convert_datetime_string(self, value):
        """
        Converts TW syntax datetime string to a localized datetime
        object. This method is not mandatory.

        :raises NotImplementedError: unless a subclass overrides it.
        """
        # `NotImplemented` is a comparison sentinel, not an exception class.
        # Raising it produces a TypeError rather than the intended signal.
        raise NotImplementedError
79
80
class TaskWarriorException(Exception):
    """Raised when a TaskWarrior command fails or its output cannot be parsed."""
83
84
class TaskWarrior(Backend):
    """Backend that drives the ``task`` command-line program.

    Every operation runs ``task`` as a subprocess. The ``rc.*`` overrides in
    ``self.config`` disable interactive confirmations and keep the output in
    the shape this class parses.
    """

    # Version strings compared against ``self.version`` to enable
    # version-specific workarounds.
    # NOTE(review): versions are compared as plain strings. That ordering is
    # only correct while every component is a single digit (e.g. '2.10.0'
    # would sort before '2.4.3').
    VERSION_2_1_0 = six.u('2.1.0')
    VERSION_2_2_0 = six.u('2.2.0')
    VERSION_2_3_0 = six.u('2.3.0')
    VERSION_2_4_0 = six.u('2.4.0')
    VERSION_2_4_1 = six.u('2.4.1')
    VERSION_2_4_2 = six.u('2.4.2')
    VERSION_2_4_3 = six.u('2.4.3')
    VERSION_2_4_4 = six.u('2.4.4')
    VERSION_2_4_5 = six.u('2.4.5')

    def __init__(self, data_location=None, create=True,
                 taskrc_location='~/.taskrc'):
        """
        :param data_location: optional path used as the ``data.location``
            override; ``~`` is expanded.
        :param create: when True, create ``data_location`` if it does not
            exist yet.
        :param taskrc_location: taskrc file to use; ``~`` is expanded. If the
            file does not exist, ``/`` is passed instead so TaskWarrior falls
            back to its defaults.
        """
        self.taskrc_location = os.path.expanduser(taskrc_location)

        # If taskrc does not exist, pass / to use defaults and avoid creating
        # dummy .taskrc file by TaskWarrior
        if not os.path.exists(self.taskrc_location):
            self.taskrc_location = '/'

        self.version = self._get_version()
        self.config = {
            'confirmation': 'no',
            'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
            'recurrence.confirmation': 'no',  # Necessary for modifying R tasks

            # Defaults to on since 2.4.5, we expect off during parsing
            'json.array': 'off',

            # 2.4.3 onwards supports 0 as infinite bulk, otherwise set just
            # an arbitrary big number which is likely to be large enough
            'bulk': 0 if self.version >= self.VERSION_2_4_3 else 100000,
        }

        # Set data.location override if passed via kwarg
        if data_location is not None:
            data_location = os.path.expanduser(data_location)
            if create and not os.path.exists(data_location):
                os.makedirs(data_location)
            self.config['data.location'] = data_location

        self.tasks = TaskQuerySet(self)

    def _get_command_args(self, args, config_override=None):
        """Build the argv list for one ``task`` invocation.

        Entries of ``config_override`` take precedence over ``self.config``.
        Each resulting setting is passed as ``rc.<key>=<value>``, followed by
        ``args`` converted to text.
        """
        command_args = ['task', 'rc:{0}'.format(self.taskrc_location)]
        config = self.config.copy()
        config.update(config_override or dict())
        for item in config.items():
            command_args.append('rc.{0}={1}'.format(*item))
        command_args.extend(map(six.text_type, args))
        return command_args

    def _get_version(self):
        """Return the output of ``task --version`` with newlines stripped."""
        p = subprocess.Popen(
                ['task', '--version'],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE)
        stdout, _stderr = [x.decode('utf-8') for x in p.communicate()]
        return stdout.strip('\n')

    def _get_modified_task_fields_as_args(self, task):
        """Return the ``add``/``modify`` command arguments describing ``task``.

        For an already saved task only the modified fields are emitted. For
        a new task every field except the read-only ones is emitted. A field
        is rendered by a ``format_<field>`` method of this backend when one
        exists, otherwise as ``field:'value'``.
        """
        args = []

        def add_field(field):
            # Add the output of format_field method to args list (defaults to
            # field:value)
            serialized_value = task._serialize(field, task._data[field])

            # Empty values should not be enclosed in quotation marks, see
            # TW-1510. Compare by value: `is ''` depended on string interning
            # and never matched u'' on Python 2.
            if serialized_value == '':
                escaped_serialized_value = ''
            else:
                # NOTE(review): single quotes inside the value are not
                # escaped, so such values break the surrounding quoting.
                escaped_serialized_value = six.u("'{0}'").format(
                    serialized_value)

            def format_default(task):
                return six.u("{0}:{1}").format(field, escaped_serialized_value)

            format_func = getattr(self, 'format_{0}'.format(field),
                                  format_default)

            args.append(format_func(task))

        # If we're modifying saved task, simply pass on all modified fields
        if task.saved:
            for field in task._modified_fields:
                add_field(field)
        # For new tasks, pass all fields that make sense
        else:
            for field in task._data.keys():
                if field in task.read_only_fields:
                    continue
                add_field(field)

        return args

    def format_depends(self, task):
        """Render the ``depends`` field as added and removed dependency UUIDs."""
        # We need to generate added and removed dependencies list,
        # since Taskwarrior does not accept redefining dependencies.

        # This cannot be part of serialize_depends, since we need
        # to keep a list of all depedencies in the _data dictionary,
        # not just currently added/removed ones

        old_dependencies = task._original_data.get('depends', set())

        added = task['depends'] - old_dependencies
        removed = old_dependencies - task['depends']

        # Removed dependencies need to be prefixed with '-'
        return 'depends:' + ','.join(
                [t['uuid'] for t in added] +
                ['-' + t['uuid'] for t in removed]
            )

    def format_description(self, task):
        """Render the description argument in a version-appropriate form."""
        # Task version older than 2.4.0 ignores first word of the
        # task description if description: prefix is used
        if self.version < self.VERSION_2_4_0:
            return task._data['description']
        else:
            return six.u("description:'{0}'").format(
                task._data['description'] or '')

    def convert_datetime_string(self, value):
        """Evaluate a TW datetime expression via ``task calc``.

        :returns: the result as a datetime localized with ``local_zone``.
        :raises ValueError: when TaskWarrior is older than 2.4.0, which has
            no ``calc`` command.
        """
        if self.version >= self.VERSION_2_4_0:
            # For strings, use 'task calc' to evaluate the string to datetime
            # available since TW 2.4.0
            args = value.split()
            result = self.execute_command(['calc'] + args)
            naive = datetime.datetime.strptime(result[0], DATE_FORMAT_CALC)
            localized = local_zone.localize(naive)
        else:
            raise ValueError("Provided value could not be converted to "
                             "datetime, its type is not supported: {}"
                             .format(type(value)))

        return localized

    @property
    def filter_class(self):
        """The filter class used to build ``task`` filter arguments."""
        return TaskWarriorFilter

    # Public interface

    def get_config(self):
        """Return the configuration reported by ``task show``.

        :returns: dict mapping each configuration key to its string value.
        """
        raw_output = self.execute_command(
                ['show'],
                config_override={'verbose': 'nothing'}
            )

        config = dict()
        # After the first non-space character the value may be empty, so
        # single-character settings such as "0" or "1" are kept.
        config_regex = re.compile(r'^(?P<key>[^\s]+)\s+(?P<value>[^\s].*)$')

        for line in raw_output:
            match = config_regex.match(line)
            if match:
                config[match.group('key')] = match.group('value').strip()

        return config

    def execute_command(self, args, config_override=None, allow_failure=True,
                        return_all=False):
        """Run ``task`` with ``args`` and return its output.

        :param args: arguments appended after the ``rc`` overrides.
        :param config_override: extra ``rc.*`` settings for this call only.
        :param allow_failure: despite its name, when True a non-zero exit
            status raises TaskWarriorException (with stderr, or stdout if
            stderr is blank). When False the exit status is ignored.
        :param return_all: when True return ``(stdout_lines, stderr_lines,
            returncode)`` instead of only the stdout lines.
        :returns: stdout split into lines after stripping trailing whitespace.
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode and allow_failure:
            if stderr.strip():
                error_msg = stderr.strip()
            else:
                error_msg = stdout.strip()
            raise TaskWarriorException(error_msg)

        # Return all whole triplet only if explicitly asked for
        if not return_all:
            return stdout.rstrip().split('\n')
        else:
            return (stdout.rstrip().split('\n'),
                    stderr.rstrip().split('\n'),
                    p.returncode)

    def enforce_recurrence(self):
        """Trigger generation of recurring task instances on old TW versions."""
        # Run arbitrary report command which will trigger generation
        # of recurrent tasks.

        # Only necessary for TW up to 2.4.1, fixed in 2.4.2.
        if self.version < self.VERSION_2_4_2:
            self.execute_command(['next'], allow_failure=False)

    def merge_with(self, path, push=False):
        """Run ``task merge`` against ``path``, optionally auto-pushing."""
        # Ensure exactly one trailing slash on the merge path
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Run ``task undo``."""
        self.execute_command(['undo'])

    # Backend interface implementation

    def filter_tasks(self, filter_obj):
        """Return Task objects for every line of ``task export`` output."""
        self.enforce_recurrence()
        args = ['export', '--'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if line:
                # With json.array off, each line holds one object and may
                # end with a separating comma.
                data = line.strip(',')
                try:
                    filtered_task = Task(self)
                    filtered_task._load_data(json.loads(data))
                    tasks.append(filtered_task)
                except ValueError:
                    raise TaskWarriorException('Invalid JSON: %s' % data)
        return tasks

    def save_task(self, task):
        """Save a task into TaskWarrior database using add/modify call"""

        args = [task['uuid'], 'modify'] if task.saved else ['add']
        args.extend(self._get_modified_task_fields_as_args(task))
        output = self.execute_command(args)

        # Parse out the new ID, if the task is being added for the first time
        if not task.saved:
            id_lines = [l for l in output if l.startswith('Created task ')]

            # Complain loudly if it seems that more tasks were created
            # Should not happen
            if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
                raise TaskWarriorException("Unexpected output when creating "
                                           "task: %s" % '\n'.join(id_lines))

            # Circumvent the ID storage, since ID is considered read-only
            identifier = id_lines[0].split(' ')[2].rstrip('.')

            # Identifier can be either ID or UUID for completed tasks
            try:
                task._data['id'] = int(identifier)
            except ValueError:
                task._data['uuid'] = identifier

        # Refreshing is very important here, as not only modification time
        # is updated, but arbitrary attribute may have changed due hooks
        # altering the data before saving
        task.refresh(after_save=True)

    def delete_task(self, task):
        """Run ``task <uuid> delete``."""
        self.execute_command([task['uuid'], 'delete'])

    def start_task(self, task):
        """Run ``task <uuid> start``."""
        self.execute_command([task['uuid'], 'start'])

    def stop_task(self, task):
        """Run ``task <uuid> stop``."""
        self.execute_command([task['uuid'], 'stop'])

    def complete_task(self, task):
        """Run ``task <uuid> done``, stopping an active task first if needed."""
        # Older versions of TW do not stop active task at completion
        if self.version < self.VERSION_2_4_0 and task.active:
            task.stop()

        self.execute_command([task['uuid'], 'done'])

    def annotate_task(self, task, annotation):
        """Run ``task <uuid> annotate <annotation>``."""
        args = [task['uuid'], 'annotate', annotation]
        self.execute_command(args)

    def denotate_task(self, task, annotation):
        """Run ``task <uuid> denotate <annotation>``."""
        args = [task['uuid'], 'denotate', annotation]
        self.execute_command(args)

    def refresh_task(self, task, after_save=False):
        """Re-export ``task`` from TaskWarrior and return its data dict.

        :param after_save: True when called right after saving. On
            TaskWarrior older than 2.4.5 this enables a fallback lookup that
            filters on the task's remaining attributes.
        :raises TaskWarriorException: if the export does not produce exactly
            one line starting with ``{``.
        """
        # We need to use ID as backup for uuid here for the refreshes
        # of newly saved tasks. Any other place in the code is fine
        # with using UUID only.
        args = [task['uuid'] or task['id'], 'export']
        output = self.execute_command(args)

        def valid(output):
            return len(output) == 1 and output[0].startswith('{')

        # For older TW versions attempt to uniquely locate the task
        # using the data we have if it has been just saved.
        # This can happen when adding a completed task on older TW versions.
        if (not valid(output) and self.version < self.VERSION_2_4_5
                and after_save):

            # Make a copy, removing ID and UUID. It's most likely invalid
            # (ID 0) if it failed to match a unique task.
            data = copy.deepcopy(task._data)
            data.pop('id', None)
            data.pop('uuid', None)

            taskfilter = self.filter_class(self)
            for key, value in data.items():
                taskfilter.add_filter_param(key, value)

            output = self.execute_command(['export', '--'] +
                                          taskfilter.get_filter_params())

        # If the task still cannot be located uniquely, raise an exception.
        # NOTE: the message mentions "multiple tasks", but this branch is
        # also reached when nothing matched at all.
        if not valid(output):
            raise TaskWarriorException(
                "Unique identifiers {0} with description: {1} matches "
                "multiple tasks: {2}".format(
                    task['uuid'] or task['id'], task['description'], output)
            )

        return json.loads(output[0])

    def sync(self):
        """Run ``task sync``."""
        self.execute_command(['sync'])