git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/backends.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Merge pull request #54 from nesaro/removes_warning
[etc/taskwarrior.git] / tasklib / backends.py
1 import abc
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import re
8 import six
9 import subprocess
10
11 from .task import Task, TaskQuerySet, ReadOnlyDictView
12 from .filters import TaskWarriorFilter
13 from .serializing import local_zone
14
15 DATE_FORMAT_CALC = '%Y-%m-%dT%H:%M:%S'
16
17 logger = logging.getLogger(__name__)
18
19
class Backend(object):
    """
    Interface implemented by the storage backends of tasklib.

    NOTE(review): the abc decorators used below have no enforcing effect
    here, since this class does not declare abc.ABCMeta as its metaclass;
    the stubs stay callable and simply return None.
    """

    @abc.abstractproperty
    def filter_class(self):
        """Returns the TaskFilter class used by this backend"""

    @abc.abstractmethod
    def filter_tasks(self, filter_obj):
        """Returns a list of Task objects matching the given filter"""

    @abc.abstractmethod
    def save_task(self, task):
        """Saves the given task in the backend"""

    @abc.abstractmethod
    def delete_task(self, task):
        """Deletes the given task"""

    @abc.abstractmethod
    def start_task(self, task):
        """Starts the given task"""

    @abc.abstractmethod
    def stop_task(self, task):
        """Stops the given task"""

    @abc.abstractmethod
    def complete_task(self, task):
        """Marks the given task as done"""

    @abc.abstractmethod
    def refresh_task(self, task, after_save=False):
        """
        Refreshes the given task. Returns new data dict with serialized
        attributes.
        """

    @abc.abstractmethod
    def annotate_task(self, task, annotation):
        """Adds the given annotation to the task"""

    @abc.abstractmethod
    def denotate_task(self, task, annotation):
        """Removes the given annotation from the task"""

    @abc.abstractmethod
    def sync(self):
        """Syncs the backend database with the taskd server"""

    def convert_datetime_string(self, value):
        """
        Converts TW syntax datetime string to a localized datetime
        object. This method is not mandatory.
        """
        raise NotImplementedError
79
80
class TaskWarriorException(Exception):
    """
    Raised when a TaskWarrior command fails or produces output that
    cannot be processed.
    """
83
84
class TaskWarrior(Backend):
    """
    Backend that operates on a TaskWarrior database by invoking the
    ``task`` command line binary.
    """

    # Known TaskWarrior releases; methods of this class compare
    # self.version against them to enable version specific behaviour.
    # NOTE(review): these are plain text values, so such comparisons are
    # lexicographic -- confirm this is acceptable for versions with
    # multi-digit components.
    VERSION_2_1_0 = six.u('2.1.0')
    VERSION_2_2_0 = six.u('2.2.0')
    VERSION_2_3_0 = six.u('2.3.0')
    VERSION_2_4_0 = six.u('2.4.0')
    VERSION_2_4_1 = six.u('2.4.1')
    VERSION_2_4_2 = six.u('2.4.2')
    VERSION_2_4_3 = six.u('2.4.3')
    VERSION_2_4_4 = six.u('2.4.4')
    VERSION_2_4_5 = six.u('2.4.5')
96
97     def __init__(self, data_location=None, create=True,
98                  taskrc_location='~/.taskrc'):
99         self.taskrc_location = os.path.expanduser(taskrc_location)
100
101         # If taskrc does not exist, pass / to use defaults and avoid creating
102         # dummy .taskrc file by TaskWarrior
103         if not os.path.exists(self.taskrc_location):
104             self.taskrc_location = '/'
105
106         self._config = None
107         self.version = self._get_version()
108         self.overrides = {
109             'confirmation': 'no',
110             'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
111             'recurrence.confirmation': 'no',  # Necessary for modifying R tasks
112
113             # Defaults to on since 2.4.5, we expect off during parsing
114             'json.array': 'off',
115
116             # 2.4.3 onwards supports 0 as infite bulk, otherwise set just
117             # arbitrary big number which is likely to be large enough
118             'bulk': 0 if self.version >= self.VERSION_2_4_3 else 100000,
119         }
120
121         # Set data.location override if passed via kwarg
122         if data_location is not None:
123             data_location = os.path.expanduser(data_location)
124             if create and not os.path.exists(data_location):
125                 os.makedirs(data_location)
126             self.overrides['data.location'] = data_location
127
128         self.tasks = TaskQuerySet(self)
129
130     def _get_command_args(self, args, config_override=None):
131         command_args = ['task', 'rc:{0}'.format(self.taskrc_location)]
132         overrides = self.overrides.copy()
133         overrides.update(config_override or dict())
134         for item in overrides.items():
135             command_args.append('rc.{0}={1}'.format(*item))
136         command_args.extend([
137             x.decode('utf-8') if isinstance(x, six.binary_type)
138             else six.text_type(x) for x in args
139         ])
140         return command_args
141
142     def _get_version(self):
143         p = subprocess.Popen(
144             ['task', '--version'],
145             stdout=subprocess.PIPE,
146             stderr=subprocess.PIPE)
147         stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
148         return stdout.strip('\n')
149
150     def _get_modified_task_fields_as_args(self, task):
151         args = []
152
153         def add_field(field):
154             # Add the output of format_field method to args list (defaults to
155             # field:value)
156             serialized_value = task._serialize(field, task._data[field])
157
158             # Empty values should not be enclosed in quotation marks, see
159             # TW-1510
160             if serialized_value is '':
161                 escaped_serialized_value = ''
162             else:
163                 escaped_serialized_value = six.u("'{0}'").format(
164                     serialized_value)
165
166             format_default = lambda task: six.u("{0}:{1}").format(
167                 field, escaped_serialized_value)
168
169             format_func = getattr(self, 'format_{0}'.format(field),
170                                   format_default)
171
172             args.append(format_func(task))
173
174         # If we're modifying saved task, simply pass on all modified fields
175         if task.saved:
176             for field in task._modified_fields:
177                 add_field(field)
178
179         # For new tasks, pass all fields that make sense
180         else:
181             for field in task._data.keys():
182                 # We cannot set stuff that's read only (ID, UUID, ..)
183                 if field in task.read_only_fields:
184                     continue
185                 # We do not want to do field deletion for new tasks
186                 if task._data[field] is None:
187                     continue
188                 # Otherwise we're fine
189                 add_field(field)
190
191         return args
192
193     def format_depends(self, task):
194         # We need to generate added and removed dependencies list,
195         # since Taskwarrior does not accept redefining dependencies.
196
197         # This cannot be part of serialize_depends, since we need
198         # to keep a list of all depedencies in the _data dictionary,
199         # not just currently added/removed ones
200
201         old_dependencies = task._original_data.get('depends', set())
202
203         added = task['depends'] - old_dependencies
204         removed = old_dependencies - task['depends']
205
206         # Removed dependencies need to be prefixed with '-'
207         return 'depends:' + ','.join(
208             [t['uuid'] for t in added] +
209             ['-' + t['uuid'] for t in removed]
210         )
211
212     def format_description(self, task):
213         # Task version older than 2.4.0 ignores first word of the
214         # task description if description: prefix is used
215         if self.version < self.VERSION_2_4_0:
216             return task._data['description']
217         else:
218             return six.u("description:'{0}'").format(
219                 task._data['description'] or '',
220             )
221
222     def convert_datetime_string(self, value):
223
224         if self.version >= self.VERSION_2_4_0:
225             # For strings, use 'task calc' to evaluate the string to datetime
226             # available since TW 2.4.0
227             args = value.split()
228             result = self.execute_command(['calc'] + args)
229             naive = datetime.datetime.strptime(result[0], DATE_FORMAT_CALC)
230             localized = local_zone.localize(naive)
231         else:
232             raise ValueError(
233                 'Provided value could not be converted to '
234                 'datetime, its type is not supported: {}'
235                 .format(type(value)),
236             )
237
238         return localized
239
    @property
    def filter_class(self):
        """Returns TaskWarriorFilter, the filter class of this backend"""
        return TaskWarriorFilter
243
244     # Public interface
245
246     @property
247     def config(self):
248         # First, check if memoized information is available
249         if self._config:
250             return self._config
251
252         # If not, fetch the config using the 'show' command
253         raw_output = self.execute_command(
254             ['show'],
255             config_override={'verbose': 'nothing'}
256         )
257
258         config = dict()
259         config_regex = re.compile(r'^(?P<key>[^\s]+)\s+(?P<value>[^\s].*$)')
260
261         for line in raw_output:
262             match = config_regex.match(line)
263             if match:
264                 config[match.group('key')] = match.group('value').strip()
265
266         # Memoize the config dict
267         self._config = ReadOnlyDictView(config)
268
269         return self._config
270
271     def execute_command(self, args, config_override=None, allow_failure=True,
272                         return_all=False):
273         command_args = self._get_command_args(
274             args, config_override=config_override)
275         logger.debug(u' '.join(command_args))
276
277         p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
278                              stderr=subprocess.PIPE)
279         stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
280         if p.returncode and allow_failure:
281             if stderr.strip():
282                 error_msg = stderr.strip()
283             else:
284                 error_msg = stdout.strip()
285             error_msg += u'\nCommand used: ' + u' '.join(command_args)
286             raise TaskWarriorException(error_msg)
287
288         # Return all whole triplet only if explicitly asked for
289         if not return_all:
290             return stdout.rstrip().split('\n')
291         else:
292             return (stdout.rstrip().split('\n'),
293                     stderr.rstrip().split('\n'),
294                     p.returncode)
295
296     def enforce_recurrence(self):
297         # Run arbitrary report command which will trigger generation
298         # of recurrent tasks.
299
300         # Only necessary for TW up to 2.4.1, fixed in 2.4.2.
301         if self.version < self.VERSION_2_4_2:
302             self.execute_command(['next'], allow_failure=False)
303
304     def merge_with(self, path, push=False):
305         path = path.rstrip('/') + '/'
306         self.execute_command(['merge', path], config_override={
307             'merge.autopush': 'yes' if push else 'no',
308         })
309
310     def undo(self):
311         self.execute_command(['undo'])
312
313     # Backend interface implementation
314
315     def filter_tasks(self, filter_obj):
316         self.enforce_recurrence()
317         args = ['export'] + filter_obj.get_filter_params()
318         tasks = []
319         for line in self.execute_command(args):
320             if line:
321                 data = line.strip(',')
322                 try:
323                     filtered_task = Task(self)
324                     filtered_task._load_data(json.loads(data))
325                     tasks.append(filtered_task)
326                 except ValueError:
327                     raise TaskWarriorException('Invalid JSON: %s' % data)
328         return tasks
329
330     def save_task(self, task):
331         """Save a task into TaskWarrior database using add/modify call"""
332
333         args = [task['uuid'], 'modify'] if task.saved else ['add']
334         args.extend(self._get_modified_task_fields_as_args(task))
335         output = self.execute_command(args)
336
337         # Parse out the new ID, if the task is being added for the first time
338         if not task.saved:
339             id_lines = [l for l in output if l.startswith('Created task ')]
340
341             # Complain loudly if it seems that more tasks were created
342             # Should not happen.
343             # Expected output: Created task 1.
344             #                  Created task 1 (recurrence template).
345             if len(id_lines) != 1 or len(id_lines[0].split(' ')) not in (3, 5):
346                 raise TaskWarriorException(
347                     'Unexpected output when creating '
348                     'task: %s' % '\n'.join(id_lines),
349                 )
350
351             # Circumvent the ID storage, since ID is considered read-only
352             identifier = id_lines[0].split(' ')[2].rstrip('.')
353
354             # Identifier can be either ID or UUID for completed tasks
355             try:
356                 task._data['id'] = int(identifier)
357             except ValueError:
358                 task._data['uuid'] = identifier
359
360         # Refreshing is very important here, as not only modification time
361         # is updated, but arbitrary attribute may have changed due hooks
362         # altering the data before saving
363         task.refresh(after_save=True)
364
365     def delete_task(self, task):
366         self.execute_command([task['uuid'], 'delete'])
367
368     def start_task(self, task):
369         self.execute_command([task['uuid'], 'start'])
370
371     def stop_task(self, task):
372         self.execute_command([task['uuid'], 'stop'])
373
374     def complete_task(self, task):
375         # Older versions of TW do not stop active task at completion
376         if self.version < self.VERSION_2_4_0 and task.active:
377             task.stop()
378
379         self.execute_command([task['uuid'], 'done'])
380
381     def annotate_task(self, task, annotation):
382         args = [task['uuid'], 'annotate', annotation]
383         self.execute_command(args)
384
385     def denotate_task(self, task, annotation):
386         args = [task['uuid'], 'denotate', annotation]
387         self.execute_command(args)
388
389     def refresh_task(self, task, after_save=False):
390         # We need to use ID as backup for uuid here for the refreshes
391         # of newly saved tasks. Any other place in the code is fine
392         # with using UUID only.
393         args = [task['uuid'] or task['id'], 'export']
394         output = self.execute_command(args)
395
396         def valid(output):
397             return len(output) == 1 and output[0].startswith('{')
398
399         # For older TW versions attempt to uniquely locate the task
400         # using the data we have if it has been just saved.
401         # This can happen when adding a completed task on older TW versions.
402         if (not valid(output) and self.version < self.VERSION_2_4_5
403                 and after_save):
404
405             # Make a copy, removing ID and UUID. It's most likely invalid
406             # (ID 0) if it failed to match a unique task.
407             data = copy.deepcopy(task._data)
408             data.pop('id', None)
409             data.pop('uuid', None)
410
411             taskfilter = self.filter_class(self)
412             for key, value in data.items():
413                 taskfilter.add_filter_param(key, value)
414
415             output = self.execute_command(['export'] +
416                                           taskfilter.get_filter_params())
417
418         # If more than 1 task has been matched still, raise an exception
419         if not valid(output):
420             raise TaskWarriorException(
421                 'Unique identifiers {0} with description: {1} matches '
422                 'multiple tasks: {2}'.format(
423                     task['uuid'] or task['id'], task['description'], output)
424             )
425
426         return json.loads(output[0])
427
428     def sync(self):
429         self.execute_command(['sync'])