git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/backends.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

lazy: Provide full implementation of the set interface in LazyUUIDTaskSet
[etc/taskwarrior.git] / tasklib / backends.py
1 import abc
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import re
8 import six
9 import subprocess
10
11 from .task import Task, TaskQuerySet, ReadOnlyDictView
12 from .filters import TaskWarriorFilter
13 from .serializing import local_zone
14
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# strptime() format used to parse the datetime printed by ``task calc``.
DATE_FORMAT_CALC = '%Y-%m-%dT%H:%M:%S'
18
19
class Backend(object):
    """Interface that every tasklib storage backend implements.

    NOTE(review): this class does not use ``abc.ABCMeta`` as its metaclass,
    so the ``abc`` decorators below document the intended contract but are
    not enforced when a subclass is instantiated.
    """

    @abc.abstractproperty
    def filter_class(self):
        """Returns the TaskFilter class used by this backend"""
        pass

    @abc.abstractmethod
    def filter_tasks(self, filter_obj):
        """Returns a list of Task objects matching the given filter"""
        pass

    @abc.abstractmethod
    def save_task(self, task):
        """Persists the given task (creating or modifying it)."""
        pass

    @abc.abstractmethod
    def delete_task(self, task):
        """Deletes the given task."""
        pass

    @abc.abstractmethod
    def start_task(self, task):
        """Marks the given task as started."""
        pass

    @abc.abstractmethod
    def stop_task(self, task):
        """Marks the given task as stopped."""
        pass

    @abc.abstractmethod
    def complete_task(self, task):
        """Marks the given task as completed."""
        pass

    @abc.abstractmethod
    def refresh_task(self, task, after_save=False):
        """
        Refreshes the given task. Returns new data dict with serialized
        attributes.
        """
        pass

    @abc.abstractmethod
    def annotate_task(self, task, annotation):
        """Adds the given annotation to the task."""
        pass

    @abc.abstractmethod
    def denotate_task(self, task, annotation):
        """Removes the given annotation from the task."""
        pass

    @abc.abstractmethod
    def sync(self):
        """Syncs the backend database with the taskd server"""
        pass

    def convert_datetime_string(self, value):
        """
        Converts TW syntax datetime string to a localized datetime
        object. This method is not mandatory.

        :raises NotImplementedError: unless overridden by a subclass.
        """
        # ``NotImplemented`` is a constant, not an exception; raising it
        # yields a TypeError on Python 3. NotImplementedError is the
        # exception meant for unimplemented optional methods.
        raise NotImplementedError
79
80
class TaskWarriorException(Exception):
    """Raised by the TaskWarrior backend when a ``task`` command fails or
    its output cannot be parsed."""
83
84
class TaskWarrior(Backend):
    """Backend that stores tasks through the ``task`` command-line tool.

    Every operation spawns a ``task`` subprocess (see ``execute_command``)
    with the ``rc.`` overrides held in ``self.overrides`` and parses its
    output.
    """

    # NOTE(review): ``self.version`` (the raw ``task --version`` output) is
    # compared against these constants as plain strings, i.e.
    # lexicographically. That is only correct while every version component
    # is a single digit (e.g. '2.10.0' would sort before '2.4.0').
    VERSION_2_1_0 = six.u('2.1.0')
    VERSION_2_2_0 = six.u('2.2.0')
    VERSION_2_3_0 = six.u('2.3.0')
    VERSION_2_4_0 = six.u('2.4.0')
    VERSION_2_4_1 = six.u('2.4.1')
    VERSION_2_4_2 = six.u('2.4.2')
    VERSION_2_4_3 = six.u('2.4.3')
    VERSION_2_4_4 = six.u('2.4.4')
    VERSION_2_4_5 = six.u('2.4.5')

    def __init__(self, data_location=None, create=True,
                 taskrc_location='~/.taskrc'):
        """Set up the backend and query the installed TaskWarrior version.

        :param data_location: optional ``data.location`` override (``~`` is
            expanded); when None, no override is passed.
        :param create: if true, create ``data_location`` when it is missing.
        :param taskrc_location: path to the taskrc file (``~`` is expanded).
        """
        self.taskrc_location = os.path.expanduser(taskrc_location)

        # If taskrc does not exist, pass / to use defaults and avoid creating
        # dummy .taskrc file by TaskWarrior
        if not os.path.exists(self.taskrc_location):
            self.taskrc_location = '/'

        self._config = None
        self.version = self._get_version()
        self.overrides = {
            'confirmation': 'no',
            'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
            'recurrence.confirmation': 'no',  # Necessary for modifying R tasks

            # Defaults to on since 2.4.5, we expect off during parsing
            'json.array': 'off',

            # 2.4.3 onwards supports 0 as infinite bulk, otherwise set just
            # arbitrary big number which is likely to be large enough
            'bulk': 0 if self.version >= self.VERSION_2_4_3 else 100000,
        }

        # Set data.location override if passed via kwarg
        if data_location is not None:
            data_location = os.path.expanduser(data_location)
            if create and not os.path.exists(data_location):
                os.makedirs(data_location)
            self.overrides['data.location'] = data_location

        self.tasks = TaskQuerySet(self)

    def _get_command_args(self, args, config_override=None):
        """Build the argv list for a ``task`` invocation.

        The result is ``task rc:<taskrc>``, then one ``rc.key=value`` per
        entry of ``self.overrides`` updated with ``config_override``, then
        ``args`` converted to text.
        """
        command_args = ['task', 'rc:{0}'.format(self.taskrc_location)]
        overrides = self.overrides.copy()
        overrides.update(config_override or dict())
        for item in overrides.items():
            command_args.append('rc.{0}={1}'.format(*item))
        command_args.extend(map(six.text_type, args))
        return command_args

    def _get_version(self):
        """Return the output of ``task --version`` with newlines stripped."""
        p = subprocess.Popen(
            ['task', '--version'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        return stdout.strip('\n')

    def _get_modified_task_fields_as_args(self, task):
        """Return the ``add``/``modify`` arguments describing ``task``.

        For a saved task only the modified fields are emitted; for a new
        task every field not listed in ``task.read_only_fields`` is emitted.
        """
        args = []

        def add_field(field):
            # Add the output of format_field method to args list (defaults to
            # field:value)
            serialized_value = task._serialize(field, task._data[field])

            # Empty values should not be enclosed in quotation marks, see
            # TW-1510. Compare by value: an identity check (``is ''``) misses
            # e.g. u'' on Python 2 and is a SyntaxWarning on Python 3.8+.
            if serialized_value == '':
                escaped_serialized_value = ''
            else:
                escaped_serialized_value = six.u("'{0}'").format(
                    serialized_value)

            format_default = lambda task: six.u("{0}:{1}").format(
                field, escaped_serialized_value)

            format_func = getattr(self, 'format_{0}'.format(field),
                                  format_default)

            args.append(format_func(task))

        # If we're modifying saved task, simply pass on all modified fields
        if task.saved:
            for field in task._modified_fields:
                add_field(field)
        # For new tasks, pass all fields that make sense
        else:
            for field in task._data.keys():
                if field in task.read_only_fields:
                    continue
                add_field(field)

        return args

    def format_depends(self, task):
        """Return a ``depends:`` argument listing added/removed dependencies."""
        # We need to generate added and removed dependencies list,
        # since Taskwarrior does not accept redefining dependencies.

        # This cannot be part of serialize_depends, since we need
        # to keep a list of all dependencies in the _data dictionary,
        # not just currently added/removed ones

        old_dependencies = task._original_data.get('depends', set())

        added = task['depends'] - old_dependencies
        removed = old_dependencies - task['depends']

        # Removed dependencies need to be prefixed with '-'
        return 'depends:' + ','.join(
            [t['uuid'] for t in added] +
            ['-' + t['uuid'] for t in removed]
        )

    def format_description(self, task):
        """Return the argument used to pass the task description."""
        # Task version older than 2.4.0 ignores first word of the
        # task description if description: prefix is used
        if self.version < self.VERSION_2_4_0:
            return task._data['description']
        else:
            return six.u("description:'{0}'").format(
                task._data['description'] or '')

    def convert_datetime_string(self, value):
        """Evaluate a TW datetime expression via ``task calc``.

        :returns: a datetime localized with ``local_zone``.
        :raises ValueError: on TaskWarrior versions older than 2.4.0, which
            lack ``task calc``.
        """
        if self.version >= self.VERSION_2_4_0:
            # For strings, use 'task calc' to evaluate the string to datetime
            # available since TW 2.4.0
            args = value.split()
            result = self.execute_command(['calc'] + args)
            naive = datetime.datetime.strptime(result[0], DATE_FORMAT_CALC)
            localized = local_zone.localize(naive)
        else:
            raise ValueError("Provided value could not be converted to "
                             "datetime, its type is not supported: {}"
                             .format(type(value)))

        return localized

    @property
    def filter_class(self):
        """The filter class used by this backend (TaskWarriorFilter)."""
        return TaskWarriorFilter

    # Public interface

    @property
    def config(self):
        """TaskWarrior's configuration parsed from ``task show``.

        The result is memoized as a ReadOnlyDictView after the first access.
        """
        # Check memoization by identity so that an empty (falsy) config is
        # not re-fetched on every access.
        if self._config is not None:
            return self._config

        # If not, fetch the config using the 'show' command
        raw_output = self.execute_command(
            ['show'],
            config_override={'verbose': 'nothing'}
        )

        config = dict()
        config_regex = re.compile(r'^(?P<key>[^\s]+)\s+(?P<value>[^\s].+$)')

        for line in raw_output:
            match = config_regex.match(line)
            if match:
                config[match.group('key')] = match.group('value').strip()

        # Memoize the config dict
        self._config = ReadOnlyDictView(config)

        return self._config

    def execute_command(self, args, config_override=None, allow_failure=True,
                        return_all=False):
        """Run ``task`` with ``args`` and return its output.

        :param config_override: extra ``rc.`` overrides for this call only.
        :param allow_failure: despite its name, when true a non-zero exit
            status raises TaskWarriorException (message taken from stderr,
            or from stdout if stderr is blank); when false it is ignored.
        :param return_all: if true, return ``(stdout_lines, stderr_lines,
            returncode)`` instead of just the stdout lines.
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode and allow_failure:
            if stderr.strip():
                error_msg = stderr.strip()
            else:
                error_msg = stdout.strip()
            raise TaskWarriorException(error_msg)

        # Return the whole triplet only if explicitly asked for
        if not return_all:
            return stdout.rstrip().split('\n')
        else:
            return (stdout.rstrip().split('\n'),
                    stderr.rstrip().split('\n'),
                    p.returncode)

    def enforce_recurrence(self):
        """Trigger generation of recurrent tasks on TW versions < 2.4.2."""
        # Run arbitrary report command which will trigger generation
        # of recurrent tasks.

        # Only necessary for TW up to 2.4.1, fixed in 2.4.2.
        if self.version < self.VERSION_2_4_2:
            self.execute_command(['next'], allow_failure=False)

    def merge_with(self, path, push=False):
        """Run ``task merge`` against ``path``; ``push`` sets merge.autopush."""
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Run ``task undo``."""
        self.execute_command(['undo'])

    # Backend interface implementation

    def filter_tasks(self, filter_obj):
        """Return Task objects for every task exported by ``filter_obj``.

        :raises TaskWarriorException: if an exported line cannot be loaded.
        """
        self.enforce_recurrence()
        args = ['export'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if line:
                # With json.array=off, export prints one object per line,
                # possibly followed by a separating comma.
                data = line.strip(',')
                try:
                    filtered_task = Task(self)
                    filtered_task._load_data(json.loads(data))
                    tasks.append(filtered_task)
                except ValueError:
                    raise TaskWarriorException('Invalid JSON: %s' % data)
        return tasks

    def save_task(self, task):
        """Save a task into TaskWarrior database using add/modify call"""

        args = [task['uuid'], 'modify'] if task.saved else ['add']
        args.extend(self._get_modified_task_fields_as_args(task))
        output = self.execute_command(args)

        # Parse out the new ID, if the task is being added for the first time
        if not task.saved:
            id_lines = [l for l in output if l.startswith('Created task ')]

            # Complain loudly if it seems that more tasks were created
            # Should not happen
            if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
                raise TaskWarriorException("Unexpected output when creating "
                                           "task: %s" % '\n'.join(id_lines))

            # Circumvent the ID storage, since ID is considered read-only
            identifier = id_lines[0].split(' ')[2].rstrip('.')

            # Identifier can be either ID or UUID for completed tasks
            try:
                task._data['id'] = int(identifier)
            except ValueError:
                task._data['uuid'] = identifier

        # Refreshing is very important here, as not only modification time
        # is updated, but arbitrary attribute may have changed due hooks
        # altering the data before saving
        task.refresh(after_save=True)

    def delete_task(self, task):
        """Run ``task <uuid> delete``."""
        self.execute_command([task['uuid'], 'delete'])

    def start_task(self, task):
        """Run ``task <uuid> start``."""
        self.execute_command([task['uuid'], 'start'])

    def stop_task(self, task):
        """Run ``task <uuid> stop``."""
        self.execute_command([task['uuid'], 'stop'])

    def complete_task(self, task):
        """Run ``task <uuid> done``, stopping it first on TW < 2.4.0."""
        # Older versions of TW do not stop active task at completion
        if self.version < self.VERSION_2_4_0 and task.active:
            task.stop()

        self.execute_command([task['uuid'], 'done'])

    def annotate_task(self, task, annotation):
        """Run ``task <uuid> annotate <annotation>``."""
        args = [task['uuid'], 'annotate', annotation]
        self.execute_command(args)

    def denotate_task(self, task, annotation):
        """Run ``task <uuid> denotate <annotation>``."""
        args = [task['uuid'], 'denotate', annotation]
        self.execute_command(args)

    def refresh_task(self, task, after_save=False):
        """Re-export ``task`` and return its data as a parsed JSON dict.

        :param after_save: when true on TW < 2.4.5, fall back to locating
            the task by its other attributes if the ID/UUID lookup fails.
        :raises TaskWarriorException: if exactly one task cannot be located.
        """
        # We need to use ID as backup for uuid here for the refreshes
        # of newly saved tasks. Any other place in the code is fine
        # with using UUID only.
        args = [task['uuid'] or task['id'], 'export']
        output = self.execute_command(args)

        def valid(output):
            return len(output) == 1 and output[0].startswith('{')

        # For older TW versions attempt to uniquely locate the task
        # using the data we have if it has been just saved.
        # This can happen when adding a completed task on older TW versions.
        if (not valid(output) and self.version < self.VERSION_2_4_5
                and after_save):

            # Make a copy, removing ID and UUID. It's most likely invalid
            # (ID 0) if it failed to match a unique task.
            data = copy.deepcopy(task._data)
            data.pop('id', None)
            data.pop('uuid', None)

            taskfilter = self.filter_class(self)
            for key, value in data.items():
                taskfilter.add_filter_param(key, value)

            output = self.execute_command(['export'] +
                                          taskfilter.get_filter_params())

        # If more than 1 task has been matched still, raise an exception
        if not valid(output):
            raise TaskWarriorException(
                "Unique identifiers {0} with description: {1} matches "
                "multiple tasks: {2}".format(
                    task['uuid'] or task['id'], task['description'], output)
            )

        return json.loads(output[0])

    def sync(self):
        """Run ``task sync``."""
        self.execute_command(['sync'])