git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/backends.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

fa5ce866a97f8d3f00964ac817aa68f9e1613e9f
[etc/taskwarrior.git] / tasklib / backends.py
1 import abc
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import re
8 import six
9 import subprocess
10 import copy
11
12 from .task import Task, TaskQuerySet, ReadOnlyDictView
13 from .filters import TaskWarriorFilter
14 from .serializing import local_zone
15
# strptime() format used to parse the timestamp printed by 'task calc'
DATE_FORMAT_CALC = '%Y-%m-%dT%H:%M:%S'

# Module-level logger; every executed task command is logged at debug level
logger = logging.getLogger(__name__)
19
20
class Backend(object):
    """
    Interface that every task storage backend is expected to implement.

    NOTE: the members below carry ``abc`` decorators, but the class is not
    created with the ``abc.ABCMeta`` metaclass, so Python does not refuse to
    instantiate it (or subclasses that leave members out). The decorators
    therefore only document which members a backend has to provide.
    """

    @abc.abstractproperty
    def filter_class(self):
        """Returns the TaskFilter class used by this backend"""
        pass

    @abc.abstractmethod
    def filter_tasks(self, filter_obj):
        """Returns a list of Task objects matching the given filter"""
        pass

    @abc.abstractmethod
    def save_task(self, task):
        """Persists the given task in the backend"""
        pass

    @abc.abstractmethod
    def delete_task(self, task):
        """Deletes the given task"""
        pass

    @abc.abstractmethod
    def start_task(self, task):
        """Starts the given task"""
        pass

    @abc.abstractmethod
    def stop_task(self, task):
        """Stops the given task"""
        pass

    @abc.abstractmethod
    def complete_task(self, task):
        """Marks the given task as completed"""
        pass

    @abc.abstractmethod
    def refresh_task(self, task, after_save=False):
        """
        Refreshes the given task. Returns new data dict with serialized
        attributes.
        """
        pass

    @abc.abstractmethod
    def annotate_task(self, task, annotation):
        """Adds the given annotation to the task"""
        pass

    @abc.abstractmethod
    def denotate_task(self, task, annotation):
        """Removes the given annotation from the task"""
        pass

    @abc.abstractmethod
    def sync(self):
        """Syncs the backend database with the taskd server"""
        pass

    def convert_datetime_string(self, value):
        """
        Converts TW syntax datetime string to a localized datetime
        object. This method is not mandatory.

        Raises NotImplementedError unless a backend overrides it.
        """
        # NotImplemented is a comparison sentinel, not an exception class;
        # raising it yields a confusing TypeError instead of the intended
        # "not implemented" signal.
        raise NotImplementedError
80
81
class TaskWarriorException(Exception):
    """
    Error raised by the TaskWarrior backend, e.g. when a task command
    exits with a failure or produces output that cannot be interpreted.
    """
84
85
class TaskWarrior(Backend):
    """
    Backend that talks to the ``task`` command line binary.

    Every operation is carried out by spawning ``task`` as a subprocess with
    a fixed set of ``rc.*`` overrides (see ``self.overrides``) and parsing
    what it prints.
    """

    VERSION_2_1_0 = six.u('2.1.0')
    VERSION_2_2_0 = six.u('2.2.0')
    VERSION_2_3_0 = six.u('2.3.0')
    VERSION_2_4_0 = six.u('2.4.0')
    VERSION_2_4_1 = six.u('2.4.1')
    VERSION_2_4_2 = six.u('2.4.2')
    VERSION_2_4_3 = six.u('2.4.3')
    VERSION_2_4_4 = six.u('2.4.4')
    VERSION_2_4_5 = six.u('2.4.5')

    def __init__(self, data_location=None, create=True,
                 taskrc_location='~/.taskrc'):
        """
        Set up the backend.

        :param data_location: optional directory passed to task as the
            ``data.location`` override ('~' is expanded)
        :param create: create ``data_location`` if it does not exist yet
        :param taskrc_location: path of the taskrc file to use ('~' is
            expanded)
        """
        self.taskrc_location = os.path.expanduser(taskrc_location)

        # If taskrc does not exist, pass / to use defaults and avoid creating
        # dummy .taskrc file by TaskWarrior
        if not os.path.exists(self.taskrc_location):
            self.taskrc_location = '/'

        self._config = None
        self.version = self._get_version()
        self.overrides = {
            'confirmation': 'no',
            'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
            'recurrence.confirmation': 'no',  # Necessary for modifying R tasks

            # Defaults to on since 2.4.5, we expect off during parsing
            'json.array': 'off',

            # 2.4.3 onwards supports 0 as infinite bulk, otherwise set just
            # arbitrary big number which is likely to be large enough
            'bulk': (0 if self._version_at_least(self.VERSION_2_4_3)
                     else 100000),
        }

        # Set data.location override if passed via kwarg
        if data_location is not None:
            data_location = os.path.expanduser(data_location)
            if create and not os.path.exists(data_location):
                os.makedirs(data_location)
            self.overrides['data.location'] = data_location

        self.tasks = TaskQuerySet(self)

    @staticmethod
    def _parse_version(version):
        """
        Turn a dotted version string into a tuple of ints.

        Only the leading ``N(.N)*`` part is considered; an empty tuple is
        returned if the string does not start with a number.
        """
        match = re.match(r'\d+(?:\.\d+)*', version.strip())
        if match is None:
            return ()
        return tuple(int(part) for part in match.group(0).split('.'))

    def _version_at_least(self, minimum):
        """
        Return True if ``self.version`` is ``minimum`` or newer.

        The comparison is numeric per component. Comparing the raw strings
        would order them lexicographically, e.g. '2.10.0' < '2.4.0'.
        """
        current = self._parse_version(self.version)
        required = self._parse_version(minimum)

        # Pad with zeros so that e.g. '2.4' and '2.4.0' compare equal
        width = max(len(current), len(required))
        current += (0,) * (width - len(current))
        required += (0,) * (width - len(required))
        return current >= required

    def _get_command_args(self, args, config_override=None):
        """
        Build the full argument list for a task invocation: the binary, the
        taskrc location, all ``rc.key=value`` overrides (instance wide ones
        updated with ``config_override``) and finally ``args`` as text.
        """
        command_args = ['task', 'rc:{0}'.format(self.taskrc_location)]
        overrides = self.overrides.copy()
        overrides.update(config_override or dict())
        for item in overrides.items():
            command_args.append('rc.{0}={1}'.format(*item))
        command_args.extend(map(six.text_type, args))
        return command_args

    def _get_version(self):
        """Return the output of ``task --version`` without newlines."""
        p = subprocess.Popen(
            ['task', '--version'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        return stdout.strip('\n')

    def _get_modified_task_fields_as_args(self, task):
        """
        Return the command line arguments describing the task's fields.

        For a saved task only the modified fields are included, for a new
        task all fields except the read-only ones.
        """
        args = []

        def add_field(field):
            # Add the output of format_field method to args list (defaults to
            # field:value)
            serialized_value = task._serialize(field, task._data[field])

            # Empty values should not be enclosed in quotation marks, see
            # TW-1510. Compare by equality: identity with a '' literal is
            # not guaranteed (and never holds for unicode on Python 2).
            if serialized_value == '':
                escaped_serialized_value = ''
            else:
                escaped_serialized_value = six.u("'{0}'").format(
                    serialized_value)

            def format_default(task):
                return six.u("{0}:{1}").format(
                    field, escaped_serialized_value)

            # A format_<field> method on the backend takes precedence
            format_func = getattr(self, 'format_{0}'.format(field),
                                  format_default)

            args.append(format_func(task))

        # If we're modifying saved task, simply pass on all modified fields
        if task.saved:
            for field in task._modified_fields:
                add_field(field)
        # For new tasks, pass all fields that make sense
        else:
            for field in task._data.keys():
                if field in task.read_only_fields:
                    continue
                add_field(field)

        return args

    def format_depends(self, task):
        """
        Return the ``depends:`` argument listing the UUIDs of added
        dependencies and, prefixed with '-', of removed ones.
        """
        # We need to generate added and removed dependencies list,
        # since Taskwarrior does not accept redefining dependencies.

        # This cannot be part of serialize_depends, since we need
        # to keep a list of all dependencies in the _data dictionary,
        # not just currently added/removed ones

        old_dependencies = task._original_data.get('depends', set())

        added = task['depends'] - old_dependencies
        removed = old_dependencies - task['depends']

        # Removed dependencies need to be prefixed with '-'
        return 'depends:' + ','.join(
            [t['uuid'] for t in added] +
            ['-' + t['uuid'] for t in removed]
        )

    def format_description(self, task):
        """Return the command line argument carrying the description."""
        # Task version older than 2.4.0 ignores first word of the
        # task description if description: prefix is used
        if not self._version_at_least(self.VERSION_2_4_0):
            return task._data['description']

        return six.u("description:'{0}'").format(
            task._data['description'] or '')

    def convert_datetime_string(self, value):
        """
        Evaluate a TW syntax datetime string using ``task calc`` and return
        the result as a datetime localized to the local zone.

        Raises ValueError if the installed task is older than 2.4.0, which
        is the first version providing ``calc``.
        """
        if not self._version_at_least(self.VERSION_2_4_0):
            raise ValueError("Provided value could not be converted to "
                             "datetime, its type is not supported: {0}"
                             .format(type(value)))

        # For strings, use 'task calc' to evaluate the string to datetime
        # available since TW 2.4.0
        args = value.split()
        result = self.execute_command(['calc'] + args)
        naive = datetime.datetime.strptime(result[0], DATE_FORMAT_CALC)
        return local_zone.localize(naive)

    @property
    def filter_class(self):
        """The filter class used by this backend (TaskWarriorFilter)."""
        return TaskWarriorFilter

    # Public interface

    @property
    def config(self):
        """
        Read-only view of the configuration reported by ``task show``.

        The result is fetched on first access and memoized afterwards.
        """
        # First, check if memoized information is available. Test against
        # None so that an empty view is not fetched over and over again.
        if self._config is not None:
            return self._config

        # If not, fetch the config using the 'show' command
        raw_output = self.execute_command(
            ['show'],
            config_override={'verbose': 'nothing'}
        )

        config = dict()
        # key, whitespace, value. The value may be a single character
        # (e.g. '0'), hence '.*' rather than '.+' after its first character.
        config_regex = re.compile(r'^(?P<key>[^\s]+)\s+(?P<value>[^\s].*$)')

        for line in raw_output:
            match = config_regex.match(line)
            if match:
                config[match.group('key')] = match.group('value').strip()

        # Memoize the config dict
        self._config = ReadOnlyDictView(config)

        return self._config

    def execute_command(self, args, config_override=None, allow_failure=True,
                        return_all=False):
        """
        Run ``task`` with the given arguments and return its output.

        :param args: arguments appended after the rc overrides
        :param config_override: dict of additional rc overrides for this
            call only
        :param allow_failure: NOTE: despite its name, a true value (the
            default) makes a non-zero exit status raise
            TaskWarriorException; pass False to ignore the exit status.
            Callers rely on this meaning, so it is kept as is.
        :param return_all: return a (stdout lines, stderr lines, return
            code) triplet instead of just the stdout lines
        :raises TaskWarriorException: on non-zero exit status (see
            ``allow_failure``), carrying stderr or, if that is empty, stdout
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode and allow_failure:
            if stderr.strip():
                error_msg = stderr.strip()
            else:
                error_msg = stdout.strip()
            raise TaskWarriorException(error_msg)

        # Return the whole triplet only if explicitly asked for
        if not return_all:
            return stdout.rstrip().split('\n')
        else:
            return (stdout.rstrip().split('\n'),
                    stderr.rstrip().split('\n'),
                    p.returncode)

    def enforce_recurrence(self):
        """Make old task versions generate pending recurrent tasks."""
        # Run arbitrary report command which will trigger generation
        # of recurrent tasks.

        # Only necessary for TW up to 2.4.1, fixed in 2.4.2.
        if not self._version_at_least(self.VERSION_2_4_2):
            # allow_failure=False: the exit status of the report is ignored
            self.execute_command(['next'], allow_failure=False)

    def merge_with(self, path, push=False):
        """
        Run ``task merge`` against ``path`` (a trailing slash is ensured),
        with ``merge.autopush`` set according to ``push``.
        """
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Run ``task undo``."""
        self.execute_command(['undo'])

    # Backend interface implementation

    def filter_tasks(self, filter_obj):
        """
        Export the tasks matching ``filter_obj`` and return them as a list
        of Task objects.

        :raises TaskWarriorException: if an exported line is not valid JSON
        """
        self.enforce_recurrence()
        args = ['export', '--'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if line:
                # Lines may carry a separating comma, which is not part of
                # the JSON object itself
                data = line.strip(',')
                try:
                    filtered_task = Task(self)
                    filtered_task._load_data(json.loads(data))
                    tasks.append(filtered_task)
                except ValueError:
                    raise TaskWarriorException('Invalid JSON: %s' % data)
        return tasks

    def save_task(self, task):
        """Save a task into TaskWarrior database using add/modify call"""

        args = [task['uuid'], 'modify'] if task.saved else ['add']
        args.extend(self._get_modified_task_fields_as_args(task))
        output = self.execute_command(args)

        # Parse out the new ID, if the task is being added for the first time
        if not task.saved:
            id_lines = [line for line in output
                        if line.startswith('Created task ')]

            # Complain loudly if it seems that more tasks were created
            # Should not happen
            if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
                raise TaskWarriorException("Unexpected output when creating "
                                           "task: %s" % '\n'.join(id_lines))

            # Circumvent the ID storage, since ID is considered read-only
            identifier = id_lines[0].split(' ')[2].rstrip('.')

            # Identifier can be either ID or UUID for completed tasks
            try:
                task._data['id'] = int(identifier)
            except ValueError:
                task._data['uuid'] = identifier

        # Refreshing is very important here, as not only modification time
        # is updated, but arbitrary attribute may have changed due hooks
        # altering the data before saving
        task.refresh(after_save=True)

    def delete_task(self, task):
        """Run ``task <uuid> delete`` for the given task."""
        self.execute_command([task['uuid'], 'delete'])

    def start_task(self, task):
        """Run ``task <uuid> start`` for the given task."""
        self.execute_command([task['uuid'], 'start'])

    def stop_task(self, task):
        """Run ``task <uuid> stop`` for the given task."""
        self.execute_command([task['uuid'], 'stop'])

    def complete_task(self, task):
        """Run ``task <uuid> done`` for the given task."""
        # Older versions of TW do not stop active task at completion
        if not self._version_at_least(self.VERSION_2_4_0) and task.active:
            task.stop()

        self.execute_command([task['uuid'], 'done'])

    def annotate_task(self, task, annotation):
        """Run ``task <uuid> annotate <annotation>`` for the given task."""
        args = [task['uuid'], 'annotate', annotation]
        self.execute_command(args)

    def denotate_task(self, task, annotation):
        """Run ``task <uuid> denotate <annotation>`` for the given task."""
        args = [task['uuid'], 'denotate', annotation]
        self.execute_command(args)

    def refresh_task(self, task, after_save=False):
        """
        Export the given task again and return the parsed JSON data.

        :param after_save: the task has just been saved; on versions older
            than 2.4.5 this enables a fallback lookup by the task's data
        :raises TaskWarriorException: if the export does not yield exactly
            one task
        """
        # We need to use ID as backup for uuid here for the refreshes
        # of newly saved tasks. Any other place in the code is fine
        # with using UUID only.
        args = [task['uuid'] or task['id'], 'export']
        output = self.execute_command(args)

        def valid(output):
            return len(output) == 1 and output[0].startswith('{')

        # For older TW versions attempt to uniquely locate the task
        # using the data we have if it has been just saved.
        # This can happen when adding a completed task on older TW versions.
        if (not valid(output)
                and not self._version_at_least(self.VERSION_2_4_5)
                and after_save):

            # Make a copy, removing ID and UUID. It's most likely invalid
            # (ID 0) if it failed to match a unique task.
            data = copy.deepcopy(task._data)
            data.pop('id', None)
            data.pop('uuid', None)

            taskfilter = self.filter_class(self)
            for key, value in data.items():
                taskfilter.add_filter_param(key, value)

            output = self.execute_command(['export', '--'] +
                                          taskfilter.get_filter_params())

        # If more than 1 task has been matched still, raise an exception
        if not valid(output):
            raise TaskWarriorException(
                "Unique identifiers {0} with description: {1} matches "
                "multiple tasks: {2}".format(
                    task['uuid'] or task['id'], task['description'], output)
            )

        return json.loads(output[0])

    def sync(self):
        """Run ``task sync``."""
        self.execute_command(['sync'])