git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/backends.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

[fix] small fix to rebuild
[etc/taskwarrior.git] / tasklib / backends.py
1 import abc
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import re
8 import six
9 import subprocess
10
11 from .task import Task, TaskQuerySet, ReadOnlyDictView
12 from .filters import TaskWarriorFilter
13 from .serializing import local_zone
14
15 DATE_FORMAT_CALC = '%Y-%m-%dT%H:%M:%S'
16
17 logger = logging.getLogger(__name__)
18
19
# Build the abstract base explicitly so this works on both Python 2 and
# Python 3 (the ``metaclass=`` class keyword and the ``__metaclass__``
# attribute are each specific to only one of them).
_BackendBase = abc.ABCMeta('_BackendBase', (object,), {})


class Backend(_BackendBase):
    """
    Interface that every tasklib storage backend implements.

    ``Backend`` has ``abc.ABCMeta`` as its metaclass, so the abstract members
    below are actually enforced: a subclass that leaves any of them
    unimplemented raises ``TypeError`` when instantiated. (Deriving from
    plain ``object`` made the ``abc`` decorators purely decorative.)
    """

    @abc.abstractproperty
    def filter_class(self):
        """Returns the TaskFilter class used by this backend"""
        pass

    @abc.abstractmethod
    def filter_tasks(self, filter_obj):
        """Returns a list of Task objects matching the given filter"""
        pass

    @abc.abstractmethod
    def save_task(self, task):
        """Persists the given task (creating or modifying it)."""
        pass

    @abc.abstractmethod
    def delete_task(self, task):
        """Deletes the given task."""
        pass

    @abc.abstractmethod
    def start_task(self, task):
        """Starts the given task."""
        pass

    @abc.abstractmethod
    def stop_task(self, task):
        """Stops the given task."""
        pass

    @abc.abstractmethod
    def complete_task(self, task):
        """Marks the given task as completed."""
        pass

    @abc.abstractmethod
    def refresh_task(self, task, after_save=False):
        """
        Refreshes the given task. Returns new data dict with serialized
        attributes.
        """
        pass

    @abc.abstractmethod
    def annotate_task(self, task, annotation):
        """Adds the annotation to the given task."""
        pass

    @abc.abstractmethod
    def denotate_task(self, task, annotation):
        """Removes the annotation from the given task."""
        pass

    @abc.abstractmethod
    def sync(self):
        """Syncs the backend database with the taskd server"""
        pass

    def convert_datetime_string(self, value):
        """
        Converts TW syntax datetime string to a localized datetime
        object. This method is not mandatory; backends that do not
        support it keep this default, which raises NotImplementedError.
        """
        raise NotImplementedError
79
80
class TaskWarriorException(Exception):
    """Raised when a TaskWarrior command fails or returns unexpected output."""
83
84
85 class TaskWarrior(Backend):
86
    # Release thresholds compared against ``self.version`` (the text printed
    # by ``task --version``) to switch version-dependent behaviour.
    # NOTE(review): both sides are plain strings, so the comparison is
    # lexicographic; it only orders correctly while every version component
    # stays single-digit (as text, '2.4.10' < '2.4.3').
    VERSION_2_1_0 = six.u('2.1.0')
    VERSION_2_2_0 = six.u('2.2.0')
    VERSION_2_3_0 = six.u('2.3.0')
    VERSION_2_4_0 = six.u('2.4.0')
    VERSION_2_4_1 = six.u('2.4.1')
    VERSION_2_4_2 = six.u('2.4.2')
    VERSION_2_4_3 = six.u('2.4.3')
    VERSION_2_4_4 = six.u('2.4.4')
    VERSION_2_4_5 = six.u('2.4.5')
96
97     def __init__(self, data_location=None, create=True,
98                  taskrc_location='~/.taskrc'):
99         self.taskrc_location = os.path.expanduser(taskrc_location)
100
101         # If taskrc does not exist, pass / to use defaults and avoid creating
102         # dummy .taskrc file by TaskWarrior
103         if not os.path.exists(self.taskrc_location):
104             self.taskrc_location = '/'
105
106         self._config = None
107         self.version = self._get_version()
108         self.overrides = {
109             'confirmation': 'no',
110             'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
111             'recurrence.confirmation': 'no',  # Necessary for modifying R tasks
112
113             # Defaults to on since 2.4.5, we expect off during parsing
114             'json.array': 'off',
115
116             # 2.4.3 onwards supports 0 as infite bulk, otherwise set just
117             # arbitrary big number which is likely to be large enough
118             'bulk': 0 if self.version >= self.VERSION_2_4_3 else 100000,
119         }
120
121         # Set data.location override if passed via kwarg
122         if data_location is not None:
123             data_location = os.path.expanduser(data_location)
124             if create and not os.path.exists(data_location):
125                 os.makedirs(data_location)
126             self.overrides['data.location'] = data_location
127
128         self.tasks = TaskQuerySet(self)
129
130     def _get_command_args(self, args, config_override=None):
131         command_args = ['task', 'rc:{0}'.format(self.taskrc_location)]
132         overrides = self.overrides.copy()
133         overrides.update(config_override or dict())
134         for item in overrides.items():
135             command_args.append('rc.{0}={1}'.format(*item))
136         command_args.extend([
137             x.decode('utf-8') if isinstance(x, six.binary_type)
138             else six.text_type(x) for x in args
139         ])
140         return command_args
141
142     def _get_version(self):
143         p = subprocess.Popen(
144             ['task', '--version'],
145             stdout=subprocess.PIPE,
146             stderr=subprocess.PIPE)
147         stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
148         return stdout.strip('\n')
149
150     def _get_modified_task_fields_as_args(self, task):
151         args = []
152
153         def add_field(field):
154             # Add the output of format_field method to args list (defaults to
155             # field:value)
156             serialized_value = task._serialize(field, task._data[field])
157
158             # Empty values should not be enclosed in quotation marks, see
159             # TW-1510
160             if serialized_value is '':
161                 escaped_serialized_value = ''
162             else:
163                 escaped_serialized_value = six.u("'{0}'").format(
164                     serialized_value)
165
166             format_default = lambda task: six.u("{0}:{1}").format(
167                 field, escaped_serialized_value)
168
169             format_func = getattr(self, 'format_{0}'.format(field),
170                                   format_default)
171
172             args.append(format_func(task))
173
174         # If we're modifying saved task, simply pass on all modified fields
175         if task.saved:
176             for field in task._modified_fields:
177                 add_field(field)
178
179         # For new tasks, pass all fields that make sense
180         else:
181             for field in task._data.keys():
182                 # We cannot set stuff that's read only (ID, UUID, ..)
183                 if field in task.read_only_fields:
184                     continue
185                 # We do not want to do field deletion for new tasks
186                 if task._data[field] is None:
187                     continue
188                 # Otherwise we're fine
189                 add_field(field)
190
191         return args
192
193     def format_depends(self, task):
194         # We need to generate added and removed dependencies list,
195         # since Taskwarrior does not accept redefining dependencies.
196
197         # This cannot be part of serialize_depends, since we need
198         # to keep a list of all depedencies in the _data dictionary,
199         # not just currently added/removed ones
200
201         old_dependencies = task._original_data.get('depends', set())
202
203         added = task['depends'] - old_dependencies
204         removed = old_dependencies - task['depends']
205
206         # Removed dependencies need to be prefixed with '-'
207         return 'depends:' + ','.join(
208             [t['uuid'] for t in added] +
209             ['-' + t['uuid'] for t in removed]
210         )
211
212     def format_description(self, task):
213         # Task version older than 2.4.0 ignores first word of the
214         # task description if description: prefix is used
215         if self.version < self.VERSION_2_4_0:
216             return task._data['description']
217         else:
218             return six.u("description:'{0}'").format(task._data['description']
219                                                      or '')
220
221     def convert_datetime_string(self, value):
222
223         if self.version >= self.VERSION_2_4_0:
224             # For strings, use 'task calc' to evaluate the string to datetime
225             # available since TW 2.4.0
226             args = value.split()
227             result = self.execute_command(['calc'] + args)
228             naive = datetime.datetime.strptime(result[0], DATE_FORMAT_CALC)
229             localized = local_zone.localize(naive)
230         else:
231             raise ValueError("Provided value could not be converted to "
232                              "datetime, its type is not supported: {}"
233                              .format(type(value)))
234
235         return localized
236
237     @property
238     def filter_class(self):
239         return TaskWarriorFilter
240
241     # Public interface
242
243     @property
244     def config(self):
245         # First, check if memoized information is available
246         if self._config:
247             return self._config
248
249         # If not, fetch the config using the 'show' command
250         raw_output = self.execute_command(
251             ['show'],
252             config_override={'verbose': 'nothing'}
253         )
254
255         config = dict()
256         config_regex = re.compile(r'^(?P<key>[^\s]+)\s+(?P<value>[^\s].*$)')
257
258         for line in raw_output:
259             match = config_regex.match(line)
260             if match:
261                 config[match.group('key')] = match.group('value').strip()
262
263         # Memoize the config dict
264         self._config = ReadOnlyDictView(config)
265
266         return self._config
267
268     def execute_command(self, args, config_override=None, allow_failure=True,
269                         return_all=False):
270         command_args = self._get_command_args(
271             args, config_override=config_override)
272         logger.debug(u' '.join(command_args))
273
274         p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
275                              stderr=subprocess.PIPE)
276         stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
277         if p.returncode and allow_failure:
278             if stderr.strip():
279                 error_msg = stderr.strip()
280             else:
281                 error_msg = stdout.strip()
282             error_msg += u'\nCommand used: ' + u' '.join(command_args)
283             raise TaskWarriorException(error_msg)
284
285         # Return all whole triplet only if explicitly asked for
286         if not return_all:
287             return stdout.rstrip().split('\n')
288         else:
289             return (stdout.rstrip().split('\n'),
290                     stderr.rstrip().split('\n'),
291                     p.returncode)
292
293     def enforce_recurrence(self):
294         # Run arbitrary report command which will trigger generation
295         # of recurrent tasks.
296
297         # Only necessary for TW up to 2.4.1, fixed in 2.4.2.
298         if self.version < self.VERSION_2_4_2:
299             self.execute_command(['next'], allow_failure=False)
300
301     def merge_with(self, path, push=False):
302         path = path.rstrip('/') + '/'
303         self.execute_command(['merge', path], config_override={
304             'merge.autopush': 'yes' if push else 'no',
305         })
306
307     def undo(self):
308         self.execute_command(['undo'])
309
310     # Backend interface implementation
311
312     def filter_tasks(self, filter_obj):
313         self.enforce_recurrence()
314         args = ['export'] + filter_obj.get_filter_params()
315         tasks = []
316         for line in self.execute_command(args):
317             if line:
318                 data = line.strip(',')
319                 try:
320                     filtered_task = Task(self)
321                     filtered_task._load_data(json.loads(data))
322                     tasks.append(filtered_task)
323                 except ValueError:
324                     raise TaskWarriorException('Invalid JSON: %s' % data)
325         return tasks
326
327     def save_task(self, task):
328         """Save a task into TaskWarrior database using add/modify call"""
329
330         args = [task['uuid'], 'modify'] if task.saved else ['add']
331         args.extend(self._get_modified_task_fields_as_args(task))
332         output = self.execute_command(args)
333
334         # Parse out the new ID, if the task is being added for the first time
335         if not task.saved:
336             id_lines = [l for l in output if l.startswith('Created task ')]
337
338             # Complain loudly if it seems that more tasks were created
339             # Should not happen.
340             # Expected output: Created task 1.
341             #                  Created task 1 (recurrence template).
342             if len(id_lines) != 1 or len(id_lines[0].split(' ')) not in (3, 5):
343                 raise TaskWarriorException("Unexpected output when creating "
344                                            "task: %s" % '\n'.join(id_lines))
345
346             # Circumvent the ID storage, since ID is considered read-only
347             identifier = id_lines[0].split(' ')[2].rstrip('.')
348
349             # Identifier can be either ID or UUID for completed tasks
350             try:
351                 task._data['id'] = int(identifier)
352             except ValueError:
353                 task._data['uuid'] = identifier
354
355         # Refreshing is very important here, as not only modification time
356         # is updated, but arbitrary attribute may have changed due hooks
357         # altering the data before saving
358         task.refresh(after_save=True)
359
360     def delete_task(self, task):
361         self.execute_command([task['uuid'], 'delete'])
362
363     def start_task(self, task):
364         self.execute_command([task['uuid'], 'start'])
365
366     def stop_task(self, task):
367         self.execute_command([task['uuid'], 'stop'])
368
369     def complete_task(self, task):
370         # Older versions of TW do not stop active task at completion
371         if self.version < self.VERSION_2_4_0 and task.active:
372             task.stop()
373
374         self.execute_command([task['uuid'], 'done'])
375
376     def annotate_task(self, task, annotation):
377         args = [task['uuid'], 'annotate', annotation]
378         self.execute_command(args)
379
380     def denotate_task(self, task, annotation):
381         args = [task['uuid'], 'denotate', annotation]
382         self.execute_command(args)
383
384     def refresh_task(self, task, after_save=False):
385         # We need to use ID as backup for uuid here for the refreshes
386         # of newly saved tasks. Any other place in the code is fine
387         # with using UUID only.
388         args = [task['uuid'] or task['id'], 'export']
389         output = self.execute_command(args)
390
391         def valid(output):
392             return len(output) == 1 and output[0].startswith('{')
393
394         # For older TW versions attempt to uniquely locate the task
395         # using the data we have if it has been just saved.
396         # This can happen when adding a completed task on older TW versions.
397         if (not valid(output) and self.version < self.VERSION_2_4_5
398                 and after_save):
399
400             # Make a copy, removing ID and UUID. It's most likely invalid
401             # (ID 0) if it failed to match a unique task.
402             data = copy.deepcopy(task._data)
403             data.pop('id', None)
404             data.pop('uuid', None)
405
406             taskfilter = self.filter_class(self)
407             for key, value in data.items():
408                 taskfilter.add_filter_param(key, value)
409
410             output = self.execute_command(['export'] +
411                                           taskfilter.get_filter_params())
412
413         # If more than 1 task has been matched still, raise an exception
414         if not valid(output):
415             raise TaskWarriorException(
416                 "Unique identifiers {0} with description: {1} matches "
417                 "multiple tasks: {2}".format(
418                     task['uuid'] or task['id'], task['description'], output)
419             )
420
421         return json.loads(output[0])
422
423     def sync(self):
424         self.execute_command(['sync'])