git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/backends.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

pep8/flake8 fixes
[etc/taskwarrior.git] / tasklib / backends.py
1 import abc
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import re
8 import six
9 import subprocess
10 import copy
11
12 from .task import Task, TaskQuerySet
13 from .filters import TaskWarriorFilter
14 from .serializing import local_zone
15
# strptime() format of the datetime strings printed by ``task calc``
# (parsed in TaskWarrior.convert_datetime_string).
DATE_FORMAT_CALC = '%Y-%m-%dT%H:%M:%S'

# Module-level logger; every executed ``task`` command line is logged at
# DEBUG level.
logger = logging.getLogger(__name__)
19
20
class Backend(object):
    """
    Interface that every tasklib storage backend implements.

    NOTE: this class does not use ``abc.ABCMeta`` as its metaclass, so the
    ``abc`` decorators below document the required interface but are not
    enforced when a subclass is instantiated.
    """

    @abc.abstractproperty
    def filter_class(self):
        """Returns the TaskFilter class used by this backend"""
        pass

    @abc.abstractmethod
    def filter_tasks(self, filter_obj):
        """Returns a list of Task objects matching the given filter"""
        pass

    @abc.abstractmethod
    def save_task(self, task):
        """Persists the given task (adding it or modifying it)."""
        pass

    @abc.abstractmethod
    def delete_task(self, task):
        """Deletes the given task."""
        pass

    @abc.abstractmethod
    def start_task(self, task):
        """Starts the given task."""
        pass

    @abc.abstractmethod
    def stop_task(self, task):
        """Stops the given task."""
        pass

    @abc.abstractmethod
    def complete_task(self, task):
        """Marks the given task as completed."""
        pass

    @abc.abstractmethod
    def refresh_task(self, task, after_save=False):
        """
        Refreshes the given task. Returns new data dict with serialized
        attributes.
        """
        pass

    @abc.abstractmethod
    def annotate_task(self, task, annotation):
        """Adds ``annotation`` to the given task."""
        pass

    @abc.abstractmethod
    def denotate_task(self, task, annotation):
        """Removes ``annotation`` from the given task."""
        pass

    @abc.abstractmethod
    def sync(self):
        """Syncs the backend database with the taskd server"""
        pass

    def convert_datetime_string(self, value):
        """
        Converts TW syntax datetime string to a localized datetime
        object. This method is not mandatory.

        Raises:
            NotImplementedError: always, for backends that do not override
                this method.
        """
        # ``NotImplemented`` is a constant, not an exception; raising it
        # produces a TypeError. NotImplementedError is the intended signal.
        raise NotImplementedError(
            'convert_datetime_string is not supported by this backend')
80
81
class TaskWarriorException(Exception):
    """Raised when a TaskWarrior command fails or its output is unexpected."""
84
85
class TaskWarrior(Backend):
    """
    Backend that talks to TaskWarrior through the ``task`` command line
    program.

    Every operation runs ``task`` as a subprocess, passing ``rc:<taskrc>``
    plus one ``rc.<key>=<value>`` override per entry of ``self.config``.
    """

    VERSION_2_1_0 = six.u('2.1.0')
    VERSION_2_2_0 = six.u('2.2.0')
    VERSION_2_3_0 = six.u('2.3.0')
    VERSION_2_4_0 = six.u('2.4.0')
    VERSION_2_4_1 = six.u('2.4.1')
    VERSION_2_4_2 = six.u('2.4.2')
    VERSION_2_4_3 = six.u('2.4.3')
    VERSION_2_4_4 = six.u('2.4.4')
    VERSION_2_4_5 = six.u('2.4.5')

    def __init__(self, data_location=None, create=True,
                 taskrc_location='~/.taskrc'):
        """
        Args:
            data_location: optional override of TaskWarrior's
                ``data.location`` setting; ``~`` is expanded.
            create: if true, create ``data_location`` when it does not
                exist yet.
            taskrc_location: path of the taskrc file; ``~`` is expanded.
        """
        self.taskrc_location = os.path.expanduser(taskrc_location)

        # If taskrc does not exist, pass / to use defaults and avoid creating
        # dummy .taskrc file by TaskWarrior
        if not os.path.exists(self.taskrc_location):
            self.taskrc_location = '/'

        # NOTE(review): versions are compared as plain strings below, which
        # is only correct while every version component is a single digit.
        self.version = self._get_version()
        self.config = {
            'confirmation': 'no',
            'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
            'recurrence.confirmation': 'no',  # Necessary for modifying R tasks

            # Defaults to on since 2.4.5, we expect off during parsing
            'json.array': 'off',

            # 2.4.3 onwards supports 0 as infinite bulk, otherwise set just
            # arbitrary big number which is likely to be large enough
            'bulk': 0 if self.version >= self.VERSION_2_4_3 else 100000,
        }

        # Set data.location override if passed via kwarg
        if data_location is not None:
            data_location = os.path.expanduser(data_location)
            if create and not os.path.exists(data_location):
                os.makedirs(data_location)
            self.config['data.location'] = data_location

        self.tasks = TaskQuerySet(self)

    def _get_command_args(self, args, config_override=None):
        """
        Builds the full ``task`` argv: taskrc, one ``rc.`` override per
        config entry (``config_override`` taking precedence over
        ``self.config``), then ``args`` converted to text.
        """
        command_args = ['task', 'rc:{0}'.format(self.taskrc_location)]
        config = self.config.copy()
        config.update(config_override or dict())
        for item in config.items():
            command_args.append('rc.{0}={1}'.format(*item))
        command_args.extend(map(six.text_type, args))
        return command_args

    def _get_version(self):
        """Returns the output of ``task --version`` without newlines."""
        p = subprocess.Popen(
            ['task', '--version'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        # Only stdout carries the version string; stderr is discarded.
        stdout = p.communicate()[0].decode('utf-8')
        return stdout.strip('\n')

    def _get_modified_task_fields_as_args(self, task):
        """
        Returns the ``add``/``modify`` arguments for ``task``: the modified
        fields of a saved task, or every non-read-only field of a new one.
        A ``format_<field>`` method on this backend overrides the default
        ``field:'value'`` rendering.
        """
        args = []

        def add_field(field):
            # Add the output of format_field method to args list (defaults to
            # field:value)
            serialized_value = task._serialize(field, task._data[field])

            # Empty values should not be enclosed in quotation marks, see
            # TW-1510. Compare by equality: an identity check against a
            # string literal is unreliable (e.g. u'' is not '' on Python 2).
            if serialized_value == '':
                escaped_serialized_value = ''
            else:
                escaped_serialized_value = six.u("'{0}'").format(
                    serialized_value)

            def format_default(task):
                return six.u("{0}:{1}").format(
                    field, escaped_serialized_value)

            format_func = getattr(self, 'format_{0}'.format(field),
                                  format_default)

            args.append(format_func(task))

        # If we're modifying saved task, simply pass on all modified fields
        if task.saved:
            for field in task._modified_fields:
                add_field(field)
        # For new tasks, pass all fields that make sense
        else:
            for field in task._data.keys():
                if field in task.read_only_fields:
                    continue
                add_field(field)

        return args

    def format_depends(self, task):
        """
        Renders the ``depends`` argument as the added dependency UUIDs
        followed by the removed ones prefixed with ``-``.
        """
        # We need to generate added and removed dependencies list,
        # since Taskwarrior does not accept redefining dependencies.

        # This cannot be part of serialize_depends, since we need
        # to keep a list of all dependencies in the _data dictionary,
        # not just currently added/removed ones

        old_dependencies = task._original_data.get('depends', set())

        added = task['depends'] - old_dependencies
        removed = old_dependencies - task['depends']

        # Removed dependencies need to be prefixed with '-'
        return 'depends:' + ','.join(
            [t['uuid'] for t in added] +
            ['-' + t['uuid'] for t in removed]
        )

    def format_description(self, task):
        """Renders the task description argument for add/modify."""
        # Task version older than 2.4.0 ignores first word of the
        # task description if description: prefix is used
        if self.version < self.VERSION_2_4_0:
            return task._data['description']
        else:
            return six.u("description:'{0}'").format(
                task._data['description'] or '')

    def convert_datetime_string(self, value):
        """
        Evaluates a TW datetime expression with ``task calc`` (TW >= 2.4.0)
        and returns the result localized with ``local_zone``.

        Raises:
            ValueError: on TaskWarrior versions older than 2.4.0.
        """
        if self.version >= self.VERSION_2_4_0:
            # For strings, use 'task calc' to evaluate the string to datetime
            # available since TW 2.4.0
            args = value.split()
            result = self.execute_command(['calc'] + args)
            naive = datetime.datetime.strptime(result[0], DATE_FORMAT_CALC)
            localized = local_zone.localize(naive)
        else:
            raise ValueError("Provided value could not be converted to "
                             "datetime, its type is not supported: {}"
                             .format(type(value)))

        return localized

    @property
    def filter_class(self):
        """The filter class used to build ``task`` filter arguments."""
        return TaskWarriorFilter

    # Public interface

    def get_config(self):
        """
        Returns the effective configuration reported by ``task show`` as a
        dict mapping setting names to their (string) values.
        """
        raw_output = self.execute_command(
            ['show'],
            config_override={'verbose': 'nothing'}
        )

        config = dict()
        # The value group must accept single-character values (e.g.
        # ``bulk 0``); the previous ``.+`` required at least two characters.
        config_regex = re.compile(r'^(?P<key>[^\s]+)\s+(?P<value>[^\s].*$)')

        for line in raw_output:
            match = config_regex.match(line)
            if match:
                config[match.group('key')] = match.group('value').strip()

        return config

    def execute_command(self, args, config_override=None, allow_failure=True,
                        return_all=False):
        """
        Runs ``task`` with ``args`` and returns its stdout split into lines.

        Args:
            args: command arguments appended after the rc overrides.
            config_override: extra ``rc.`` settings for this call only.
            allow_failure: NOTE: despite the name, when true (the default) a
                non-zero exit status raises TaskWarriorException; pass False
                to ignore failures.
            return_all: if true, return ``(stdout_lines, stderr_lines,
                returncode)`` instead of just the stdout lines.

        Raises:
            TaskWarriorException: on non-zero exit when ``allow_failure`` is
                true; the message is stderr, or stdout if stderr is empty.
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode and allow_failure:
            if stderr.strip():
                error_msg = stderr.strip()
            else:
                error_msg = stdout.strip()
            raise TaskWarriorException(error_msg)

        # Return all whole triplet only if explicitly asked for
        if not return_all:
            return stdout.rstrip().split('\n')
        else:
            return (stdout.rstrip().split('\n'),
                    stderr.rstrip().split('\n'),
                    p.returncode)

    def enforce_recurrence(self):
        """Triggers generation of recurring task instances on old TW."""
        # Run arbitrary report command which will trigger generation
        # of recurrent tasks.

        # Only necessary for TW up to 2.4.1, fixed in 2.4.2.
        if self.version < self.VERSION_2_4_2:
            self.execute_command(['next'], allow_failure=False)

    def merge_with(self, path, push=False):
        """Runs ``task merge`` against ``path``, optionally auto-pushing."""
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Runs ``task undo``."""
        self.execute_command(['undo'])

    # Backend interface implementation

    def filter_tasks(self, filter_obj):
        """
        Returns the Task objects exported by ``task export`` for the filter
        parameters of ``filter_obj``.

        Raises:
            TaskWarriorException: if an exported line cannot be loaded.
        """
        self.enforce_recurrence()
        args = ['export', '--'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if line:
                # With json.array=off each task is one JSON object per line,
                # possibly followed by a separating comma.
                data = line.strip(',')
                try:
                    filtered_task = Task(self)
                    filtered_task._load_data(json.loads(data))
                    tasks.append(filtered_task)
                except ValueError:
                    raise TaskWarriorException('Invalid JSON: %s' % data)
        return tasks

    def save_task(self, task):
        """Save a task into TaskWarrior database using add/modify call"""

        args = [task['uuid'], 'modify'] if task.saved else ['add']
        args.extend(self._get_modified_task_fields_as_args(task))
        output = self.execute_command(args)

        # Parse out the new ID, if the task is being added for the first time
        if not task.saved:
            id_lines = [line for line in output
                        if line.startswith('Created task ')]

            # Complain loudly if it seems that more tasks were created
            # Should not happen
            if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
                raise TaskWarriorException("Unexpected output when creating "
                                           "task: %s" % '\n'.join(id_lines))

            # Circumvent the ID storage, since ID is considered read-only
            identifier = id_lines[0].split(' ')[2].rstrip('.')

            # Identifier can be either ID or UUID for completed tasks
            try:
                task._data['id'] = int(identifier)
            except ValueError:
                task._data['uuid'] = identifier

        # Refreshing is very important here, as not only modification time
        # is updated, but arbitrary attribute may have changed due hooks
        # altering the data before saving
        task.refresh(after_save=True)

    def delete_task(self, task):
        """Runs ``task <uuid> delete``."""
        self.execute_command([task['uuid'], 'delete'])

    def start_task(self, task):
        """Runs ``task <uuid> start``."""
        self.execute_command([task['uuid'], 'start'])

    def stop_task(self, task):
        """Runs ``task <uuid> stop``."""
        self.execute_command([task['uuid'], 'stop'])

    def complete_task(self, task):
        """Runs ``task <uuid> done``, stopping active tasks first on old TW."""
        # Older versions of TW do not stop active task at completion
        if self.version < self.VERSION_2_4_0 and task.active:
            task.stop()

        self.execute_command([task['uuid'], 'done'])

    def annotate_task(self, task, annotation):
        """Runs ``task <uuid> annotate <annotation>``."""
        args = [task['uuid'], 'annotate', annotation]
        self.execute_command(args)

    def denotate_task(self, task, annotation):
        """Runs ``task <uuid> denotate <annotation>``."""
        args = [task['uuid'], 'denotate', annotation]
        self.execute_command(args)

    def refresh_task(self, task, after_save=False):
        """
        Re-exports ``task`` and returns the dict parsed from its JSON line.

        Args:
            after_save: true right after ``save_task``; on TW older than
                2.4.5 this enables a fallback lookup by the task's other
                attributes when UUID/ID do not identify it.

        Raises:
            TaskWarriorException: if the export is not exactly one JSON
                object line (no match or several matches).
        """
        # We need to use ID as backup for uuid here for the refreshes
        # of newly saved tasks. Any other place in the code is fine
        # with using UUID only.
        args = [task['uuid'] or task['id'], 'export']
        output = self.execute_command(args)

        def valid(output):
            return len(output) == 1 and output[0].startswith('{')

        # For older TW versions attempt to uniquely locate the task
        # using the data we have if it has been just saved.
        # This can happen when adding a completed task on older TW versions.
        if (not valid(output) and self.version < self.VERSION_2_4_5
                and after_save):

            # Make a copy, removing ID and UUID. It's most likely invalid
            # (ID 0) if it failed to match a unique task.
            data = copy.deepcopy(task._data)
            data.pop('id', None)
            data.pop('uuid', None)

            taskfilter = self.filter_class(self)
            for key, value in data.items():
                taskfilter.add_filter_param(key, value)

            output = self.execute_command(['export', '--'] +
                                          taskfilter.get_filter_params())

        # If the task still could not be identified uniquely, raise
        if not valid(output):
            raise TaskWarriorException(
                "Unique identifiers {0} with description: {1} matches "
                "multiple tasks: {2}".format(
                    task['uuid'] or task['id'], task['description'], output)
            )

        return json.loads(output[0])

    def sync(self):
        """Runs ``task sync``."""
        self.execute_command(['sync'])