git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/backends.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

backend: Do not assume that all command arguments are (byte)strings
[etc/taskwarrior.git] / tasklib / backends.py
1 import abc
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import re
8 import six
9 import subprocess
10
11 from .task import Task, TaskQuerySet, ReadOnlyDictView
12 from .filters import TaskWarriorFilter
13 from .serializing import local_zone
14
# strptime() layout of the timestamp printed by ``task calc``; used by
# TaskWarrior.convert_datetime_string to parse that output.
DATE_FORMAT_CALC = "%Y-%m-%dT%H:%M:%S"

# Module-level logger; TaskWarrior.execute_command logs each command line
# it runs at DEBUG level.
logger = logging.getLogger(__name__)
18
19
class Backend(object):
    """Interface for tasklib storage backends.

    A backend supplies the filter class used to build queries plus the
    operations to query, persist and modify tasks.

    NOTE(review): this class is not created with ``abc.ABCMeta`` as its
    metaclass, so the ``abc.abstractmethod`` markers below document the
    interface but are not enforced: ``Backend`` can be instantiated and the
    default bodies simply return None.
    """

    @property
    @abc.abstractmethod
    def filter_class(self):
        """Return the TaskFilter class used by this backend."""

    @abc.abstractmethod
    def filter_tasks(self, filter_obj):
        """Return a list of Task objects matching the given filter."""

    @abc.abstractmethod
    def save_task(self, task):
        """Persist ``task``, creating it or applying its modifications."""

    @abc.abstractmethod
    def delete_task(self, task):
        """Delete ``task``."""

    @abc.abstractmethod
    def start_task(self, task):
        """Start ``task``."""

    @abc.abstractmethod
    def stop_task(self, task):
        """Stop ``task``."""

    @abc.abstractmethod
    def complete_task(self, task):
        """Mark ``task`` as completed."""

    @abc.abstractmethod
    def refresh_task(self, task, after_save=False):
        """
        Refreshes the given task. Returns new data dict with serialized
        attributes.

        ``after_save`` signals that the task has just been saved, so an
        implementation may need a fallback lookup to locate it.
        """

    @abc.abstractmethod
    def annotate_task(self, task, annotation):
        """Add ``annotation`` to ``task``."""

    @abc.abstractmethod
    def denotate_task(self, task, annotation):
        """Remove ``annotation`` from ``task``."""

    @abc.abstractmethod
    def sync(self):
        """Syncs the backend database with the taskd server"""

    def convert_datetime_string(self, value):
        """
        Converts TW syntax datetime string to a localized datetime
        object. This method is not mandatory.

        :raises NotImplementedError: unless a subclass overrides it.
        """
        raise NotImplementedError
79
80
class TaskWarriorException(Exception):
    """Raised when a ``task`` command fails or its output cannot be used."""
83
84
class TaskWarrior(Backend):
    """Backend that talks to Taskwarrior by running the ``task`` executable.

    Every operation becomes a ``task`` command line (see
    ``execute_command``). Configuration is adjusted through
    ``rc.<name>=<value>`` arguments, and exported tasks are parsed as JSON.
    """

    # Taskwarrior releases whose behaviour this backend special-cases. They
    # stay plain text strings (public class API), but must be compared via
    # ``_version_at_least`` rather than ``<``/``>=``. Plain string comparison
    # mis-orders releases such as "2.4.10" and "2.4.3".
    VERSION_2_1_0 = six.u('2.1.0')
    VERSION_2_2_0 = six.u('2.2.0')
    VERSION_2_3_0 = six.u('2.3.0')
    VERSION_2_4_0 = six.u('2.4.0')
    VERSION_2_4_1 = six.u('2.4.1')
    VERSION_2_4_2 = six.u('2.4.2')
    VERSION_2_4_3 = six.u('2.4.3')
    VERSION_2_4_4 = six.u('2.4.4')
    VERSION_2_4_5 = six.u('2.4.5')

    def __init__(self, data_location=None, create=True,
                 taskrc_location='~/.taskrc'):
        """Locate the taskrc, detect the Taskwarrior version, set overrides.

        :param data_location: optional directory passed to Taskwarrior as
            the ``data.location`` override (``~`` is expanded).
        :param create: if true, create ``data_location`` when it is missing.
        :param taskrc_location: taskrc file to use (``~`` is expanded). When
            it does not exist, ``/`` is passed instead.
        """
        self.taskrc_location = os.path.expanduser(taskrc_location)

        # If taskrc does not exist, pass / to use defaults and avoid creating
        # dummy .taskrc file by TaskWarrior
        if not os.path.exists(self.taskrc_location):
            self.taskrc_location = '/'

        # Lazily filled by the ``config`` property.
        self._config = None
        # Text printed by ``task --version`` (surrounding newlines stripped).
        self.version = self._get_version()
        # ``rc.<key>=<value>`` overrides added to every command line.
        self.overrides = {
            'confirmation': 'no',
            'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
            'recurrence.confirmation': 'no',  # Necessary for modifying R tasks

            # Defaults to on since 2.4.5, we expect off during parsing
            'json.array': 'off',

            # 2.4.3 onwards supports 0 as infinite bulk, otherwise set just
            # arbitrary big number which is likely to be large enough
            'bulk': (0 if self._version_at_least(self.VERSION_2_4_3)
                     else 100000),
        }

        # Set data.location override if passed via kwarg
        if data_location is not None:
            data_location = os.path.expanduser(data_location)
            if create and not os.path.exists(data_location):
                os.makedirs(data_location)
            self.overrides['data.location'] = data_location

        self.tasks = TaskQuerySet(self)

    @staticmethod
    def _parse_version(version):
        """Convert a version string into a tuple of ints for comparison.

        Only the leading dotted run of digits is used, so ``'2.4.10'`` gives
        ``(2, 4, 10)`` and any trailing suffix is ignored. An empty tuple is
        returned when no number is found; it sorts below every real version.
        """
        match = re.match(r'\s*(\d+(?:\.\d+)*)', version or '')
        if match is None:
            return ()
        return tuple(int(part) for part in match.group(1).split('.'))

    def _version_at_least(self, other):
        """Return True if the detected Taskwarrior version is >= ``other``."""
        return self._parse_version(self.version) >= self._parse_version(other)

    def _get_command_args(self, args, config_override=None):
        """Build the full argv list for a ``task`` invocation.

        The list consists of:

        - ``task``;
        - the ``rc:<taskrc>`` argument;
        - one ``rc.<key>=<value>`` per entry of ``self.overrides`` (updated
          with ``config_override``);
        - ``args``.

        Byte strings in ``args`` are decoded as UTF-8. Every other value is
        converted to text, so callers may pass non-string arguments.
        """
        command_args = ['task', 'rc:{0}'.format(self.taskrc_location)]
        overrides = self.overrides.copy()
        overrides.update(config_override or dict())
        for item in overrides.items():
            command_args.append('rc.{0}={1}'.format(*item))
        command_args.extend([
            x.decode('utf-8') if isinstance(x, six.binary_type)
            else six.text_type(x) for x in args
        ])
        return command_args

    def _get_version(self):
        """Run ``task --version`` and return its stdout minus newlines."""
        p = subprocess.Popen(
            ['task', '--version'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        return stdout.strip('\n')

    def _get_modified_task_fields_as_args(self, task):
        """Return the ``field:value`` arguments needed to save ``task``.

        Which fields are included depends on the task:

        - A saved task contributes only its modified fields.
        - A new task contributes every field that is not read-only and whose
          value is not None.

        A field is rendered by ``self.format_<field>(task)`` when such a
        method exists, otherwise as ``field:'<serialized value>'``.
        """
        args = []

        def add_field(field):
            # Add the output of format_field method to args list (defaults to
            # field:value)
            serialized_value = task._serialize(field, task._data[field])

            # Empty values should not be enclosed in quotation marks, see
            # TW-1510. Compare by equality: identity ("is") against a string
            # literal is implementation-dependent and misses u'' on Python 2.
            if serialized_value == '':
                escaped_serialized_value = ''
            else:
                escaped_serialized_value = six.u("'{0}'").format(
                    serialized_value)

            format_default = lambda task: six.u("{0}:{1}").format(
                field, escaped_serialized_value)

            format_func = getattr(self, 'format_{0}'.format(field),
                                  format_default)

            args.append(format_func(task))

        # If we're modifying saved task, simply pass on all modified fields
        if task.saved:
            for field in task._modified_fields:
                add_field(field)

        # For new tasks, pass all fields that make sense
        else:
            for field in task._data.keys():
                # We cannot set stuff that's read only (ID, UUID, ..)
                if field in task.read_only_fields:
                    continue
                # We do not want to do field deletion for new tasks
                if task._data[field] is None:
                    continue
                # Otherwise we're fine
                add_field(field)

        return args

    def format_depends(self, task):
        """Return a ``depends:`` argument listing dependency changes.

        Added dependencies appear as their UUIDs. Removed dependencies
        appear as ``-<uuid>``.
        """
        # We need to generate added and removed dependencies list,
        # since Taskwarrior does not accept redefining dependencies.

        # This cannot be part of serialize_depends, since we need
        # to keep a list of all depedencies in the _data dictionary,
        # not just currently added/removed ones

        old_dependencies = task._original_data.get('depends', set())

        added = task['depends'] - old_dependencies
        removed = old_dependencies - task['depends']

        # Removed dependencies need to be prefixed with '-'
        return 'depends:' + ','.join(
            [t['uuid'] for t in added] +
            ['-' + t['uuid'] for t in removed]
        )

    def format_description(self, task):
        """Return the argument that sets the task description."""
        # Task version older than 2.4.0 ignores first word of the
        # task description if description: prefix is used
        if not self._version_at_least(self.VERSION_2_4_0):
            return task._data['description']
        else:
            return six.u("description:'{0}'").format(
                task._data['description'] or '')

    def convert_datetime_string(self, value):
        """Evaluate a Taskwarrior date expression into a localized datetime.

        The expression is split on whitespace and passed to ``task calc``
        (Taskwarrior 2.4.0 or newer). The first output line is parsed with
        ``DATE_FORMAT_CALC`` and localized to ``local_zone``.

        :raises ValueError: on Taskwarrior versions older than 2.4.0.
        """
        if self._version_at_least(self.VERSION_2_4_0):
            # For strings, use 'task calc' to evaluate the string to datetime
            # available since TW 2.4.0
            args = value.split()
            result = self.execute_command(['calc'] + args)
            naive = datetime.datetime.strptime(result[0], DATE_FORMAT_CALC)
            localized = local_zone.localize(naive)
        else:
            raise ValueError("Provided value could not be converted to "
                             "datetime, its type is not supported: {}"
                             .format(type(value)))

        return localized

    @property
    def filter_class(self):
        """The filter class used to build queries for this backend."""
        return TaskWarriorFilter

    # Public interface

    @property
    def config(self):
        """Taskwarrior configuration as reported by ``task show``.

        The mapping of setting names to string values is fetched once,
        wrapped in ``ReadOnlyDictView`` and memoized.
        """
        # First, check if memoized information is available. Test for None
        # rather than truthiness so an empty result is cached as well.
        if self._config is not None:
            return self._config

        # If not, fetch the config using the 'show' command
        raw_output = self.execute_command(
            ['show'],
            config_override={'verbose': 'nothing'}
        )

        config = dict()
        # Lines look like "<key><whitespace><value>". The value only has to
        # start with a non-space character, so one-character values such as
        # "0" are kept (the former "[^\s].+" silently dropped them).
        config_regex = re.compile(r'^(?P<key>[^\s]+)\s+(?P<value>[^\s].*$)')

        for line in raw_output:
            match = config_regex.match(line)
            if match:
                config[match.group('key')] = match.group('value').strip()

        # Memoize the config dict
        self._config = ReadOnlyDictView(config)

        return self._config

    def execute_command(self, args, config_override=None, allow_failure=True,
                        return_all=False):
        """Run ``task`` with ``args`` and return its output as lines.

        :param args: command arguments, appended after the rc overrides.
        :param config_override: extra ``rc.`` overrides for this call only.
        :param allow_failure: despite its name, when true a non-zero exit
            status raises TaskWarriorException. Pass False to ignore the exit
            status.
        :param return_all: if true, return ``(stdout_lines, stderr_lines,
            returncode)`` instead of only the stdout lines.
        :raises TaskWarriorException: on failure when ``allow_failure`` is
            true. The message is stderr (or stdout when stderr is blank) plus
            the command line used.
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(u' '.join(command_args))

        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode and allow_failure:
            if stderr.strip():
                error_msg = stderr.strip()
            else:
                error_msg = stdout.strip()
            error_msg += u'\nCommand used: ' + u' '.join(command_args)
            raise TaskWarriorException(error_msg)

        # Return all whole triplet only if explicitly asked for
        if not return_all:
            return stdout.rstrip().split('\n')
        else:
            return (stdout.rstrip().split('\n'),
                    stderr.rstrip().split('\n'),
                    p.returncode)

    def enforce_recurrence(self):
        """Trigger generation of recurring task instances on old versions."""
        # Run arbitrary report command which will trigger generation
        # of recurrent tasks.

        # Only necessary for TW up to 2.4.1, fixed in 2.4.2.
        if not self._version_at_least(self.VERSION_2_4_2):
            self.execute_command(['next'], allow_failure=False)

    def merge_with(self, path, push=False):
        """Run ``task merge`` against ``path``.

        The path is normalized to a single trailing slash. ``push`` sets the
        ``merge.autopush`` override.
        """
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Run ``task undo``."""
        self.execute_command(['undo'])

    # Backend interface implementation

    def filter_tasks(self, filter_obj):
        """Export the tasks matching ``filter_obj`` as Task objects.

        :raises TaskWarriorException: if an exported line cannot be loaded.
        """
        self.enforce_recurrence()
        args = ['export'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if line:
                # With json.array off, each task is one JSON object per line,
                # possibly followed by a separating comma.
                data = line.strip(',')
                try:
                    filtered_task = Task(self)
                    filtered_task._load_data(json.loads(data))
                    tasks.append(filtered_task)
                except ValueError:
                    raise TaskWarriorException('Invalid JSON: %s' % data)
        return tasks

    def save_task(self, task):
        """Save a task into TaskWarrior database using add/modify call"""

        args = [task['uuid'], 'modify'] if task.saved else ['add']
        args.extend(self._get_modified_task_fields_as_args(task))
        output = self.execute_command(args)

        # Parse out the new ID, if the task is being added for the first time
        if not task.saved:
            id_lines = [l for l in output if l.startswith('Created task ')]

            # Complain loudly if it seems that more tasks were created
            # Should not happen.
            # Expected output: Created task 1.
            #                  Created task 1 (recurrence template).
            if len(id_lines) != 1 or len(id_lines[0].split(' ')) not in (3, 5):
                raise TaskWarriorException("Unexpected output when creating "
                                           "task: %s" % '\n'.join(id_lines))

            # Circumvent the ID storage, since ID is considered read-only
            identifier = id_lines[0].split(' ')[2].rstrip('.')

            # Identifier can be either ID or UUID for completed tasks
            try:
                task._data['id'] = int(identifier)
            except ValueError:
                task._data['uuid'] = identifier

        # Refreshing is very important here, as not only modification time
        # is updated, but arbitrary attribute may have changed due hooks
        # altering the data before saving
        task.refresh(after_save=True)

    def delete_task(self, task):
        """Run ``task <uuid> delete``."""
        self.execute_command([task['uuid'], 'delete'])

    def start_task(self, task):
        """Run ``task <uuid> start``."""
        self.execute_command([task['uuid'], 'start'])

    def stop_task(self, task):
        """Run ``task <uuid> stop``."""
        self.execute_command([task['uuid'], 'stop'])

    def complete_task(self, task):
        """Run ``task <uuid> done``, stopping the task first on old versions."""
        # Older versions of TW do not stop active task at completion
        if not self._version_at_least(self.VERSION_2_4_0) and task.active:
            task.stop()

        self.execute_command([task['uuid'], 'done'])

    def annotate_task(self, task, annotation):
        """Run ``task <uuid> annotate <annotation>``."""
        args = [task['uuid'], 'annotate', annotation]
        self.execute_command(args)

    def denotate_task(self, task, annotation):
        """Run ``task <uuid> denotate <annotation>``."""
        args = [task['uuid'], 'denotate', annotation]
        self.execute_command(args)

    def refresh_task(self, task, after_save=False):
        """Export ``task`` again and return its data as a dict.

        :param after_save: on versions older than 2.4.5, allows a fallback
            lookup by the task's other attributes when its identifier does
            not select exactly one task.
        :raises TaskWarriorException: if the export is not exactly one JSON
            object line.
        """
        # We need to use ID as backup for uuid here for the refreshes
        # of newly saved tasks. Any other place in the code is fine
        # with using UUID only.
        args = [task['uuid'] or task['id'], 'export']
        output = self.execute_command(args)

        def valid(output):
            return len(output) == 1 and output[0].startswith('{')

        # For older TW versions attempt to uniquely locate the task
        # using the data we have if it has been just saved.
        # This can happen when adding a completed task on older TW versions.
        if (not valid(output)
                and not self._version_at_least(self.VERSION_2_4_5)
                and after_save):

            # Make a copy, removing ID and UUID. It's most likely invalid
            # (ID 0) if it failed to match a unique task.
            data = copy.deepcopy(task._data)
            data.pop('id', None)
            data.pop('uuid', None)

            taskfilter = self.filter_class(self)
            for key, value in data.items():
                taskfilter.add_filter_param(key, value)

            output = self.execute_command(['export'] +
                                          taskfilter.get_filter_params())

        # If more than 1 task has been matched still, raise an exception
        if not valid(output):
            raise TaskWarriorException(
                "Unique identifiers {0} with description: {1} matches "
                "multiple tasks: {2}".format(
                    task['uuid'] or task['id'], task['description'], output)
            )

        return json.loads(output[0])

    def sync(self):
        """Run ``task sync``."""
        self.execute_command(['sync'])