git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/backends.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add support for custom `task` location or command
[etc/taskwarrior.git] / tasklib / backends.py
1 import abc
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import re
8 import six
9 import subprocess
10
11 from .task import Task, TaskQuerySet, ReadOnlyDictView
12 from .filters import TaskWarriorFilter
13 from .serializing import local_zone
14
15 DATE_FORMAT_CALC = '%Y-%m-%dT%H:%M:%S'
16
17 logger = logging.getLogger(__name__)
18
19
class Backend(object):
    """
    Interface that a tasklib storage backend is expected to provide.

    Note that this class is not created with abc.ABCMeta as its metaclass,
    so the abstract markers below only document intent: instantiating a
    subclass that omits some of the methods is not prevented.
    """

    @abc.abstractproperty
    def filter_class(self):
        """Returns the TaskFilter class used by this backend"""

    @abc.abstractmethod
    def filter_tasks(self, filter_obj):
        """Returns a list of Task objects matching the given filter"""

    @abc.abstractmethod
    def save_task(self, task):
        """Saves the given task in the backend"""

    @abc.abstractmethod
    def delete_task(self, task):
        """Deletes the given task"""

    @abc.abstractmethod
    def start_task(self, task):
        """Starts the given task"""

    @abc.abstractmethod
    def stop_task(self, task):
        """Stops the given task"""

    @abc.abstractmethod
    def complete_task(self, task):
        """Completes the given task"""

    @abc.abstractmethod
    def refresh_task(self, task, after_save=False):
        """
        Refreshes the given task. Returns new data dict with serialized
        attributes.
        """

    @abc.abstractmethod
    def annotate_task(self, task, annotation):
        """Adds the given annotation to the task"""

    @abc.abstractmethod
    def denotate_task(self, task, annotation):
        """Removes the given annotation from the task"""

    @abc.abstractmethod
    def sync(self):
        """Syncs the backend database with the taskd server"""

    def convert_datetime_string(self, value):
        """
        Converts TW syntax datetime string to a localized datetime
        object. This method is not mandatory.
        """
        raise NotImplementedError
79
80
class TaskWarriorException(Exception):
    """Raised when a TaskWarrior command fails or prints unexpected output"""
83
84
class TaskWarrior(Backend):
    """
    Backend implementation that drives TaskWarrior by running the task
    command in a subprocess and parsing its output.
    """

    # Version strings that methods of this class compare self.version
    # against to select version-specific behaviour.
    VERSION_2_1_0 = six.u('2.1.0')
    VERSION_2_2_0 = six.u('2.2.0')
    VERSION_2_3_0 = six.u('2.3.0')
    VERSION_2_4_0 = six.u('2.4.0')
    VERSION_2_4_1 = six.u('2.4.1')
    VERSION_2_4_2 = six.u('2.4.2')
    VERSION_2_4_3 = six.u('2.4.3')
    VERSION_2_4_4 = six.u('2.4.4')
    VERSION_2_4_5 = six.u('2.4.5')
96
97     def __init__(self, data_location=None, create=True,
98                  taskrc_location='~/.taskrc', task_command='task'):
99         self.taskrc_location = os.path.expanduser(taskrc_location)
100
101         # If taskrc does not exist, pass / to use defaults and avoid creating
102         # dummy .taskrc file by TaskWarrior
103         if not os.path.exists(self.taskrc_location):
104             self.taskrc_location = '/'
105
106         self.task_command = task_command
107
108         self._config = None
109         self.version = self._get_version()
110         self.overrides = {
111             'confirmation': 'no',
112             'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
113             'recurrence.confirmation': 'no',  # Necessary for modifying R tasks
114
115             # Defaults to on since 2.4.5, we expect off during parsing
116             'json.array': 'off',
117
118             # 2.4.3 onwards supports 0 as infite bulk, otherwise set just
119             # arbitrary big number which is likely to be large enough
120             'bulk': 0 if self.version >= self.VERSION_2_4_3 else 100000,
121         }
122
123         # Set data.location override if passed via kwarg
124         if data_location is not None:
125             data_location = os.path.expanduser(data_location)
126             if create and not os.path.exists(data_location):
127                 os.makedirs(data_location)
128             self.overrides['data.location'] = data_location
129
130         self.tasks = TaskQuerySet(self)
131
132     def _get_task_command(self):
133         return self.task_command.split()
134
135     def _get_command_args(self, args, config_override=None):
136         command_args = self._get_task_command() + ['rc:{0}'.format(self.taskrc_location)]
137         overrides = self.overrides.copy()
138         overrides.update(config_override or dict())
139         for item in overrides.items():
140             command_args.append('rc.{0}={1}'.format(*item))
141         command_args.extend([
142             x.decode('utf-8') if isinstance(x, six.binary_type)
143             else six.text_type(x) for x in args
144         ])
145         return command_args
146
147     def _get_version(self):
148         p = subprocess.Popen(
149             self._get_task_command() + ['--version'],
150             stdout=subprocess.PIPE,
151             stderr=subprocess.PIPE)
152         stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
153         return stdout.strip('\n')
154
155     def _get_modified_task_fields_as_args(self, task):
156         args = []
157
158         def add_field(field):
159             # Add the output of format_field method to args list (defaults to
160             # field:value)
161             serialized_value = task._serialize(field, task._data[field])
162
163             # Empty values should not be enclosed in quotation marks, see
164             # TW-1510
165             if serialized_value is '':
166                 escaped_serialized_value = ''
167             else:
168                 escaped_serialized_value = six.u("'{0}'").format(
169                     serialized_value)
170
171             format_default = lambda task: six.u("{0}:{1}").format(
172                 field, escaped_serialized_value)
173
174             format_func = getattr(self, 'format_{0}'.format(field),
175                                   format_default)
176
177             args.append(format_func(task))
178
179         # If we're modifying saved task, simply pass on all modified fields
180         if task.saved:
181             for field in task._modified_fields:
182                 add_field(field)
183
184         # For new tasks, pass all fields that make sense
185         else:
186             for field in task._data.keys():
187                 # We cannot set stuff that's read only (ID, UUID, ..)
188                 if field in task.read_only_fields:
189                     continue
190                 # We do not want to do field deletion for new tasks
191                 if task._data[field] is None:
192                     continue
193                 # Otherwise we're fine
194                 add_field(field)
195
196         return args
197
198     def format_depends(self, task):
199         # We need to generate added and removed dependencies list,
200         # since Taskwarrior does not accept redefining dependencies.
201
202         # This cannot be part of serialize_depends, since we need
203         # to keep a list of all depedencies in the _data dictionary,
204         # not just currently added/removed ones
205
206         old_dependencies = task._original_data.get('depends', set())
207
208         added = task['depends'] - old_dependencies
209         removed = old_dependencies - task['depends']
210
211         # Removed dependencies need to be prefixed with '-'
212         return 'depends:' + ','.join(
213             [t['uuid'] for t in added] +
214             ['-' + t['uuid'] for t in removed]
215         )
216
217     def format_description(self, task):
218         # Task version older than 2.4.0 ignores first word of the
219         # task description if description: prefix is used
220         if self.version < self.VERSION_2_4_0:
221             return task._data['description']
222         else:
223             return six.u("description:'{0}'").format(
224                 task._data['description'] or '',
225             )
226
227     def convert_datetime_string(self, value):
228
229         if self.version >= self.VERSION_2_4_0:
230             # For strings, use 'calc' to evaluate the string to datetime
231             # available since TW 2.4.0
232             args = value.split()
233             result = self.execute_command(['calc'] + args)
234             naive = datetime.datetime.strptime(result[0], DATE_FORMAT_CALC)
235             localized = local_zone.localize(naive)
236         else:
237             raise ValueError(
238                 'Provided value could not be converted to '
239                 'datetime, its type is not supported: {}'
240                 .format(type(value)),
241             )
242
243         return localized
244
    @property
    def filter_class(self):
        """Returns the TaskFilter class used by this backend"""
        return TaskWarriorFilter
248
249     # Public interface
250
251     @property
252     def config(self):
253         # First, check if memoized information is available
254         if self._config:
255             return self._config
256
257         # If not, fetch the config using the 'show' command
258         raw_output = self.execute_command(
259             ['show'],
260             config_override={'verbose': 'nothing'}
261         )
262
263         config = dict()
264         config_regex = re.compile(r'^(?P<key>[^\s]+)\s+(?P<value>[^\s].*$)')
265
266         for line in raw_output:
267             match = config_regex.match(line)
268             if match:
269                 config[match.group('key')] = match.group('value').strip()
270
271         # Memoize the config dict
272         self._config = ReadOnlyDictView(config)
273
274         return self._config
275
276     def execute_command(self, args, config_override=None, allow_failure=True,
277                         return_all=False):
278         command_args = self._get_command_args(
279             args, config_override=config_override)
280         logger.debug(u' '.join(command_args))
281
282         p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
283                              stderr=subprocess.PIPE)
284         stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
285         if p.returncode and allow_failure:
286             if stderr.strip():
287                 error_msg = stderr.strip()
288             else:
289                 error_msg = stdout.strip()
290             error_msg += u'\nCommand used: ' + u' '.join(command_args)
291             raise TaskWarriorException(error_msg)
292
293         # Return all whole triplet only if explicitly asked for
294         if not return_all:
295             return stdout.rstrip().split('\n')
296         else:
297             return (stdout.rstrip().split('\n'),
298                     stderr.rstrip().split('\n'),
299                     p.returncode)
300
301     def enforce_recurrence(self):
302         # Run arbitrary report command which will trigger generation
303         # of recurrent tasks.
304
305         # Only necessary for TW up to 2.4.1, fixed in 2.4.2.
306         if self.version < self.VERSION_2_4_2:
307             self.execute_command(['next'], allow_failure=False)
308
309     def merge_with(self, path, push=False):
310         path = path.rstrip('/') + '/'
311         self.execute_command(['merge', path], config_override={
312             'merge.autopush': 'yes' if push else 'no',
313         })
314
315     def undo(self):
316         self.execute_command(['undo'])
317
318     # Backend interface implementation
319
320     def filter_tasks(self, filter_obj):
321         self.enforce_recurrence()
322         args = ['export'] + filter_obj.get_filter_params()
323         tasks = []
324         for line in self.execute_command(args):
325             if line:
326                 data = line.strip(',')
327                 try:
328                     filtered_task = Task(self)
329                     filtered_task._load_data(json.loads(data))
330                     tasks.append(filtered_task)
331                 except ValueError:
332                     raise TaskWarriorException('Invalid JSON: %s' % data)
333         return tasks
334
335     def save_task(self, task):
336         """Save a task into TaskWarrior database using add/modify call"""
337
338         args = [task['uuid'], 'modify'] if task.saved else ['add']
339         args.extend(self._get_modified_task_fields_as_args(task))
340         output = self.execute_command(args)
341
342         # Parse out the new ID, if the task is being added for the first time
343         if not task.saved:
344             id_lines = [l for l in output if l.startswith('Created task ')]
345
346             # Complain loudly if it seems that more tasks were created
347             # Should not happen.
348             # Expected output: Created task 1.
349             #                  Created task 1 (recurrence template).
350             if len(id_lines) != 1 or len(id_lines[0].split(' ')) not in (3, 5):
351                 raise TaskWarriorException(
352                     'Unexpected output when creating '
353                     'task: %s' % '\n'.join(id_lines),
354                 )
355
356             # Circumvent the ID storage, since ID is considered read-only
357             identifier = id_lines[0].split(' ')[2].rstrip('.')
358
359             # Identifier can be either ID or UUID for completed tasks
360             try:
361                 task._data['id'] = int(identifier)
362             except ValueError:
363                 task._data['uuid'] = identifier
364
365         # Refreshing is very important here, as not only modification time
366         # is updated, but arbitrary attribute may have changed due hooks
367         # altering the data before saving
368         task.refresh(after_save=True)
369
370     def delete_task(self, task):
371         self.execute_command([task['uuid'], 'delete'])
372
373     def start_task(self, task):
374         self.execute_command([task['uuid'], 'start'])
375
376     def stop_task(self, task):
377         self.execute_command([task['uuid'], 'stop'])
378
379     def complete_task(self, task):
380         # Older versions of TW do not stop active task at completion
381         if self.version < self.VERSION_2_4_0 and task.active:
382             task.stop()
383
384         self.execute_command([task['uuid'], 'done'])
385
386     def annotate_task(self, task, annotation):
387         args = [task['uuid'], 'annotate', annotation]
388         self.execute_command(args)
389
390     def denotate_task(self, task, annotation):
391         args = [task['uuid'], 'denotate', annotation]
392         self.execute_command(args)
393
394     def refresh_task(self, task, after_save=False):
395         # We need to use ID as backup for uuid here for the refreshes
396         # of newly saved tasks. Any other place in the code is fine
397         # with using UUID only.
398         args = [task['uuid'] or task['id'], 'export']
399         output = self.execute_command(args)
400
401         def valid(output):
402             return len(output) == 1 and output[0].startswith('{')
403
404         # For older TW versions attempt to uniquely locate the task
405         # using the data we have if it has been just saved.
406         # This can happen when adding a completed task on older TW versions.
407         if (not valid(output) and self.version < self.VERSION_2_4_5
408                 and after_save):
409
410             # Make a copy, removing ID and UUID. It's most likely invalid
411             # (ID 0) if it failed to match a unique task.
412             data = copy.deepcopy(task._data)
413             data.pop('id', None)
414             data.pop('uuid', None)
415
416             taskfilter = self.filter_class(self)
417             for key, value in data.items():
418                 taskfilter.add_filter_param(key, value)
419
420             output = self.execute_command(['export'] +
421                                           taskfilter.get_filter_params())
422
423         # If more than 1 task has been matched still, raise an exception
424         if not valid(output):
425             raise TaskWarriorException(
426                 'Unique identifiers {0} with description: {1} matches '
427                 'multiple tasks: {2}'.format(
428                     task['uuid'] or task['id'], task['description'], output)
429             )
430
431         return json.loads(output[0])
432
433     def sync(self):
434         self.execute_command(['sync'])