git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/backends.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Merge pull request #42 from robgolding63/issue-41-fix
[etc/taskwarrior.git] / tasklib / backends.py
1 import abc
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import re
8 import six
9 import subprocess
10
11 from .task import Task, TaskQuerySet, ReadOnlyDictView
12 from .filters import TaskWarriorFilter
13 from .serializing import local_zone
14
# strptime() format of the timestamp printed by `task calc`; used by
# TaskWarrior.convert_datetime_string to parse that command's output.
DATE_FORMAT_CALC = '%Y-%m-%dT%H:%M:%S'

# Module-level logger; TaskWarrior.execute_command logs every command line
# it runs at DEBUG level.
logger = logging.getLogger(__name__)
18
19
class Backend(object):
    """Interface that tasklib storage backends implement.

    NOTE: no ``abc.ABCMeta`` metaclass is declared on this class, so the
    abstract markers below document the expected contract but are not
    enforced when a subclass is instantiated; the base implementations
    simply return ``None``.
    """

    @property
    @abc.abstractmethod
    def filter_class(self):
        """Return the TaskFilter class used by this backend."""

    @abc.abstractmethod
    def filter_tasks(self, filter_obj):
        """Return a list of Task objects matching ``filter_obj``."""

    @abc.abstractmethod
    def save_task(self, task):
        """Persist ``task`` (create it or apply its modifications)."""

    @abc.abstractmethod
    def delete_task(self, task):
        """Delete ``task``."""

    @abc.abstractmethod
    def start_task(self, task):
        """Mark ``task`` as started."""

    @abc.abstractmethod
    def stop_task(self, task):
        """Mark ``task`` as stopped."""

    @abc.abstractmethod
    def complete_task(self, task):
        """Mark ``task`` as done."""

    @abc.abstractmethod
    def refresh_task(self, task, after_save=False):
        """Reload ``task`` from the backend.

        Returns a new data dict with serialized attributes.
        """

    @abc.abstractmethod
    def annotate_task(self, task, annotation):
        """Add ``annotation`` to ``task``."""

    @abc.abstractmethod
    def denotate_task(self, task, annotation):
        """Remove ``annotation`` from ``task``."""

    @abc.abstractmethod
    def sync(self):
        """Sync the backend database with the taskd server."""

    def convert_datetime_string(self, value):
        """Convert a TaskWarrior-syntax datetime string to a localized datetime.

        Optional for backends: the base implementation always raises
        NotImplementedError.
        """
        raise NotImplementedError
79
80
class TaskWarriorException(Exception):
    """Raised when a ``task`` invocation fails or yields unexpected output."""
83
84
class TaskWarrior(Backend):
    """Backend that drives the ``task`` command-line program.

    Every backend operation becomes a ``task`` subprocess call, built by
    ``_get_command_args`` and run by ``execute_command``. Task data is read
    via ``task export``. ``json.array`` is forced off in ``self.overrides``,
    so the export is parsed as one JSON object per line.
    """

    # NOTE: these constants and ``self.version`` (raw ``task --version``
    # output) are plain strings and are compared lexicographically. That
    # matches numeric ordering only while every version component is a
    # single digit (as strings, '2.10.0' < '2.4.3').
    VERSION_2_1_0 = six.u('2.1.0')
    VERSION_2_2_0 = six.u('2.2.0')
    VERSION_2_3_0 = six.u('2.3.0')
    VERSION_2_4_0 = six.u('2.4.0')
    VERSION_2_4_1 = six.u('2.4.1')
    VERSION_2_4_2 = six.u('2.4.2')
    VERSION_2_4_3 = six.u('2.4.3')
    VERSION_2_4_4 = six.u('2.4.4')
    VERSION_2_4_5 = six.u('2.4.5')

    def __init__(self, data_location=None, create=True,
                 taskrc_location='~/.taskrc'):
        """Set up the backend. Runs ``task --version`` once.

        :param data_location: optional directory passed to TaskWarrior as
            ``rc.data.location`` (``~`` is expanded). When ``None``, no
            data.location override is added.
        :param create: if true, create ``data_location`` when it does not
            exist yet.
        :param taskrc_location: path of the taskrc file (``~`` is expanded).
            If the file does not exist, ``/`` is passed instead.
        """
        self.taskrc_location = os.path.expanduser(taskrc_location)

        # If taskrc does not exist, pass / to use defaults and avoid creating
        # dummy .taskrc file by TaskWarrior
        if not os.path.exists(self.taskrc_location):
            self.taskrc_location = '/'

        # Populated lazily by the ``config`` property.
        self._config = None
        self.version = self._get_version()
        self.overrides = {
            'confirmation': 'no',
            'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
            'recurrence.confirmation': 'no',  # Necessary for modifying R tasks

            # Defaults to on since 2.4.5, we expect off during parsing
            'json.array': 'off',

            # 2.4.3 onwards supports 0 as infinite bulk, otherwise set just
            # an arbitrary big number which is likely to be large enough
            'bulk': 0 if self.version >= self.VERSION_2_4_3 else 100000,
        }

        # Set data.location override if passed via kwarg
        if data_location is not None:
            data_location = os.path.expanduser(data_location)
            if create and not os.path.exists(data_location):
                os.makedirs(data_location)
            self.overrides['data.location'] = data_location

        self.tasks = TaskQuerySet(self)

    def _get_command_args(self, args, config_override=None):
        """Build the argv list for a ``task`` invocation.

        The result is ``task rc:<taskrc>``, then one ``rc.<key>=<value>``
        argument per override, then each item of ``args`` converted to text.
        Entries in ``config_override`` take precedence over
        ``self.overrides``.
        """
        command_args = ['task', 'rc:{0}'.format(self.taskrc_location)]
        overrides = self.overrides.copy()
        overrides.update(config_override or dict())
        for item in overrides.items():
            command_args.append('rc.{0}={1}'.format(*item))
        command_args.extend(map(six.text_type, args))
        return command_args

    def _get_version(self):
        """Return the output of ``task --version`` with newlines stripped.

        The exit status and stderr of the command are not inspected.
        """
        p = subprocess.Popen(
            ['task', '--version'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        return stdout.strip('\n')

    def _get_modified_task_fields_as_args(self, task):
        """Translate task fields into command-line arguments.

        A saved task contributes only its modified fields. A new task
        contributes every field not listed in ``task.read_only_fields``.
        A field for which this backend defines ``format_<field>`` is rendered
        by that method. Any other field is rendered as ``field:'value'``, or
        as ``field:`` when the serialized value is empty.
        """
        args = []

        def add_field(field):
            serialized_value = task._serialize(field, task._data[field])

            # Empty values should not be enclosed in quotation marks, see
            # TW-1510. Compare by equality: the previous identity check
            # (``is ''``) depended on string interning, is False for u''
            # vs '' on Python 2, and is a SyntaxWarning on Python 3.8+.
            if serialized_value == '':
                escaped_serialized_value = ''
            else:
                escaped_serialized_value = six.u("'{0}'").format(
                    serialized_value)

            def format_default(task):
                return six.u("{0}:{1}").format(
                    field, escaped_serialized_value)

            format_func = getattr(self, 'format_{0}'.format(field),
                                  format_default)

            args.append(format_func(task))

        # If we're modifying saved task, simply pass on all modified fields
        if task.saved:
            for field in task._modified_fields:
                add_field(field)
        # For new tasks, pass all fields that make sense
        else:
            for field in task._data.keys():
                if field in task.read_only_fields:
                    continue
                add_field(field)

        return args

    def format_depends(self, task):
        """Render the ``depends`` change as ``depends:<uuid>,-<uuid>,...``.

        TaskWarrior does not accept redefining the whole dependency list, so
        only the difference from ``task._original_data`` is sent: added
        dependencies as plain UUIDs and removed ones prefixed with ``-``.
        This cannot live in serialize_depends, because ``_data`` must keep
        the full dependency set rather than just the added/removed ones.
        """
        old_dependencies = task._original_data.get('depends', set())

        added = task['depends'] - old_dependencies
        removed = old_dependencies - task['depends']

        # Removed dependencies need to be prefixed with '-'
        return 'depends:' + ','.join(
            [t['uuid'] for t in added] +
            ['-' + t['uuid'] for t in removed]
        )

    def format_description(self, task):
        """Render the description argument.

        TaskWarrior older than 2.4.0 ignores the first word of the
        description when the ``description:`` prefix is used, so the raw
        description is passed for those versions.
        """
        if self.version < self.VERSION_2_4_0:
            return task._data['description']
        else:
            return six.u("description:'{0}'").format(
                task._data['description'] or '')

    def convert_datetime_string(self, value):
        """Evaluate a TaskWarrior date expression to a localized datetime.

        Runs ``task calc`` (available since TaskWarrior 2.4.0), parses the
        first output line with ``DATE_FORMAT_CALC`` and localizes the result
        with ``local_zone``.

        :raises ValueError: when the TaskWarrior version is older than 2.4.0.
        """
        if self.version >= self.VERSION_2_4_0:
            args = value.split()
            result = self.execute_command(['calc'] + args)
            naive = datetime.datetime.strptime(result[0], DATE_FORMAT_CALC)
            localized = local_zone.localize(naive)
        else:
            raise ValueError("Provided value could not be converted to "
                             "datetime, its type is not supported: {}"
                             .format(type(value)))

        return localized

    @property
    def filter_class(self):
        """The filter class used to build ``task`` filter arguments."""
        return TaskWarriorFilter

    # Public interface

    @property
    def config(self):
        """Read-only mapping of TaskWarrior configuration (``task show``).

        The result is computed on first access and memoized.
        """
        # Check against None rather than truthiness so that an empty config
        # is memoized too instead of re-running `task show` every access.
        if self._config is not None:
            return self._config

        raw_output = self.execute_command(
            ['show'],
            config_override={'verbose': 'nothing'}
        )

        config = dict()
        # ``.*`` (not ``.+``) so that single-character values such as "0"
        # or "y" are captured; ``[^\s].+`` required at least two characters.
        config_regex = re.compile(r'^(?P<key>[^\s]+)\s+(?P<value>[^\s].*$)')

        for line in raw_output:
            match = config_regex.match(line)
            if match:
                config[match.group('key')] = match.group('value').strip()

        self._config = ReadOnlyDictView(config)

        return self._config

    def execute_command(self, args, config_override=None, allow_failure=True,
                        return_all=False):
        """Run ``task`` with ``args`` and the configured overrides.

        :param args: command arguments; each is converted to text.
        :param config_override: extra ``rc.<key>=<value>`` settings for this
            call only; they take precedence over ``self.overrides``.
        :param allow_failure: despite the name, when true a non-zero exit
            status raises TaskWarriorException (message taken from stderr,
            or stdout if stderr is blank). When false, the exit status is
            ignored.
        :param return_all: when true, return ``(stdout_lines, stderr_lines,
            returncode)`` instead of just the stdout lines.
        :returns: stdout with trailing whitespace removed, split on newlines
            (so empty output yields ``['']``).
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode and allow_failure:
            if stderr.strip():
                error_msg = stderr.strip()
            else:
                error_msg = stdout.strip()
            raise TaskWarriorException(error_msg)

        # Return the whole triplet only if explicitly asked for
        if not return_all:
            return stdout.rstrip().split('\n')
        else:
            return (stdout.rstrip().split('\n'),
                    stderr.rstrip().split('\n'),
                    p.returncode)

    def enforce_recurrence(self):
        """Trigger generation of recurrent task instances if needed.

        Running an arbitrary report (``next``) generates recurrent tasks.
        This is only necessary for TaskWarrior up to 2.4.1 (fixed in 2.4.2).
        Failures of the report command are ignored.
        """
        if self.version < self.VERSION_2_4_2:
            self.execute_command(['next'], allow_failure=False)

    def merge_with(self, path, push=False):
        """Run ``task merge`` against ``path`` (normalized to end in '/').

        :param push: sets ``merge.autopush`` to ``yes`` when true, ``no``
            otherwise.
        """
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Run ``task undo``."""
        self.execute_command(['undo'])

    # Backend interface implementation

    def filter_tasks(self, filter_obj):
        """Return Task objects for every line of ``task export <filter>``.

        :raises TaskWarriorException: if a line cannot be loaded (any
            ValueError while decoding or loading it is reported as
            invalid JSON).
        """
        self.enforce_recurrence()
        args = ['export'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if line:
                # Lines may carry a trailing comma separating JSON objects.
                data = line.strip(',')
                try:
                    filtered_task = Task(self)
                    filtered_task._load_data(json.loads(data))
                    tasks.append(filtered_task)
                except ValueError:
                    raise TaskWarriorException('Invalid JSON: %s' % data)
        return tasks

    def save_task(self, task):
        """Save a task into TaskWarrior database using add/modify call.

        For a new task, the identifier from the "Created task ..." output
        line is stored in ``task._data`` as ``id`` (or as ``uuid`` when it
        is not an integer). The task is then refreshed from TaskWarrior.

        :raises TaskWarriorException: if the add command does not print
            exactly one recognizable "Created task" line.
        """
        args = [task['uuid'], 'modify'] if task.saved else ['add']
        args.extend(self._get_modified_task_fields_as_args(task))
        output = self.execute_command(args)

        # Parse out the new ID, if the task is being added for the first time
        if not task.saved:
            id_lines = [line for line in output
                        if line.startswith('Created task ')]

            # Complain loudly if it seems that more tasks were created
            # Should not happen.
            # Expected output: Created task 1.
            #                  Created task 1 (recurrence template).
            if len(id_lines) != 1 or len(id_lines[0].split(' ')) not in (3, 5):
                # Fall back to the full output when no "Created task" line
                # was found, so the error message is not empty.
                raise TaskWarriorException(
                    "Unexpected output when creating "
                    "task: %s" % '\n'.join(id_lines or output))

            # Circumvent the ID storage, since ID is considered read-only
            identifier = id_lines[0].split(' ')[2].rstrip('.')

            # Identifier can be either ID or UUID for completed tasks
            try:
                task._data['id'] = int(identifier)
            except ValueError:
                task._data['uuid'] = identifier

        # Refreshing is very important here, as not only modification time
        # is updated, but arbitrary attribute may have changed due hooks
        # altering the data before saving
        task.refresh(after_save=True)

    def delete_task(self, task):
        """Run ``task <uuid> delete``."""
        self.execute_command([task['uuid'], 'delete'])

    def start_task(self, task):
        """Run ``task <uuid> start``."""
        self.execute_command([task['uuid'], 'start'])

    def stop_task(self, task):
        """Run ``task <uuid> stop``."""
        self.execute_command([task['uuid'], 'stop'])

    def complete_task(self, task):
        """Run ``task <uuid> done``, stopping an active task first on <2.4.0.

        Older versions of TaskWarrior do not stop an active task at
        completion.
        """
        if self.version < self.VERSION_2_4_0 and task.active:
            task.stop()

        self.execute_command([task['uuid'], 'done'])

    def annotate_task(self, task, annotation):
        """Run ``task <uuid> annotate <annotation>``."""
        args = [task['uuid'], 'annotate', annotation]
        self.execute_command(args)

    def denotate_task(self, task, annotation):
        """Run ``task <uuid> denotate <annotation>``."""
        args = [task['uuid'], 'denotate', annotation]
        self.execute_command(args)

    def refresh_task(self, task, after_save=False):
        """Return the task's current data dict as decoded from ``task export``.

        :param after_save: on TaskWarrior older than 2.4.5, if the
            UUID/ID lookup does not yield exactly one task, retry with a
            filter built from the task's other fields.
        :raises TaskWarriorException: if no single exported task is found.
        """
        # We need to use ID as backup for uuid here for the refreshes
        # of newly saved tasks. Any other place in the code is fine
        # with using UUID only.
        args = [task['uuid'] or task['id'], 'export']
        output = self.execute_command(args)

        def valid(output):
            # Exactly one line, and it looks like a JSON object.
            return len(output) == 1 and output[0].startswith('{')

        # For older TW versions attempt to uniquely locate the task
        # using the data we have if it has been just saved.
        # This can happen when adding a completed task on older TW versions.
        if (not valid(output) and self.version < self.VERSION_2_4_5
                and after_save):

            # Make a copy, removing ID and UUID. It's most likely invalid
            # (ID 0) if it failed to match a unique task.
            data = copy.deepcopy(task._data)
            data.pop('id', None)
            data.pop('uuid', None)

            taskfilter = self.filter_class(self)
            for key, value in data.items():
                taskfilter.add_filter_param(key, value)

            output = self.execute_command(['export'] +
                                          taskfilter.get_filter_params())

        # If more than 1 task has been matched still, raise an exception
        if not valid(output):
            raise TaskWarriorException(
                "Unique identifiers {0} with description: {1} matches "
                "multiple tasks: {2}".format(
                    task['uuid'] or task['id'], task['description'], output)
            )

        return json.loads(output[0])

    def sync(self):
        """Run ``task sync``."""
        self.execute_command(['sync'])