git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/backends.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

pep8/flake8 fixes
[etc/taskwarrior.git] / tasklib / backends.py
1 import abc
2 import datetime
3 import json
4 import logging
5 import os
6 import re
7 import six
8 import subprocess
9 import copy
10
11 from .task import Task, TaskQuerySet
12 from .filters import TaskWarriorFilter
13 from .serializing import local_zone
14
# strptime layout used to parse the first output line of `task calc`
# (see TaskWarrior.convert_datetime_string).
DATE_FORMAT_CALC = '%Y-%m-%dT%H:%M:%S'

# Module-level logger; execute_command logs each assembled command line
# through it at DEBUG level.
logger = logging.getLogger(__name__)
18
19
class Backend(object):
    """
    Declares the operations a task storage backend has to provide.

    The methods are marked with the ``abc`` decorators, but the class does
    not set ``abc.ABCMeta`` as its metaclass, so the markers document the
    expected interface rather than being enforced at instantiation time.
    """

    @abc.abstractproperty
    def filter_class(self):
        """Returns the TaskFilter class used by this backend"""
        pass

    @abc.abstractmethod
    def filter_tasks(self, filter_obj):
        """Returns a list of Task objects matching the given filter"""
        pass

    @abc.abstractmethod
    def save_task(self, task):
        """Stores the given task in the backend"""
        pass

    @abc.abstractmethod
    def delete_task(self, task):
        """Deletes the given task"""
        pass

    @abc.abstractmethod
    def start_task(self, task):
        """Starts the given task"""
        pass

    @abc.abstractmethod
    def stop_task(self, task):
        """Stops the given task"""
        pass

    @abc.abstractmethod
    def complete_task(self, task):
        """Completes the given task"""
        pass

    @abc.abstractmethod
    def refresh_task(self, task, after_save=False):
        """
        Refreshes the given task. Returns new data dict with serialized
        attributes.
        """
        pass

    @abc.abstractmethod
    def annotate_task(self, task, annotation):
        """Adds the given annotation to the task"""
        pass

    @abc.abstractmethod
    def denotate_task(self, task, annotation):
        """Removes the given annotation from the task"""
        pass

    @abc.abstractmethod
    def sync(self):
        """Syncs the backend database with the taskd server"""
        pass

    def convert_datetime_string(self, value):
        """
        Converts TW syntax datetime string to a localized datetime
        object. This method is not mandatory.

        :raises NotImplementedError: always, unless a subclass overrides
            this method
        """
        # NotImplemented is a comparison sentinel, not an exception class;
        # raising it fails with TypeError on Python 3.
        raise NotImplementedError
79
80
class TaskWarriorException(Exception):
    """Error raised by the TaskWarrior backend in this module."""
83
84
class TaskWarrior(Backend):
    """
    Backend which performs every operation by running the ``task`` command
    line binary in a subprocess and parsing what it prints.
    """

    VERSION_2_1_0 = six.u('2.1.0')
    VERSION_2_2_0 = six.u('2.2.0')
    VERSION_2_3_0 = six.u('2.3.0')
    VERSION_2_4_0 = six.u('2.4.0')
    VERSION_2_4_1 = six.u('2.4.1')
    VERSION_2_4_2 = six.u('2.4.2')
    VERSION_2_4_3 = six.u('2.4.3')
    VERSION_2_4_4 = six.u('2.4.4')
    VERSION_2_4_5 = six.u('2.4.5')

    def __init__(self, data_location=None, create=True, taskrc_location='~/.taskrc'):
        """
        :param data_location: optional path passed to task as the
            ``data.location`` override; ``~`` is expanded
        :param create: when true, a missing ``data_location`` directory
            is created
        :param taskrc_location: path of the taskrc file; ``~`` is expanded
        """
        self.taskrc_location = os.path.expanduser(taskrc_location)

        # If taskrc does not exist, pass / to use defaults and avoid creating
        # dummy .taskrc file by TaskWarrior
        if not os.path.exists(self.taskrc_location):
            self.taskrc_location = '/'

        self.version = self._get_version()
        self.config = {
            'confirmation': 'no',
            'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
            'recurrence.confirmation': 'no',  # Necessary for modifying R tasks

            # Defaults to on since 2.4.5, we expect off during parsing
            'json.array': 'off',

            # 2.4.3 onwards supports 0 as infinite bulk, otherwise set just
            # arbitrary big number which is likely to be large enough
            'bulk': (100000 if self._version_before(self.VERSION_2_4_3)
                     else 0),
        }

        # Set data.location override if passed via kwarg
        if data_location is not None:
            data_location = os.path.expanduser(data_location)
            if create and not os.path.exists(data_location):
                os.makedirs(data_location)
            self.config['data.location'] = data_location

        self.tasks = TaskQuerySet(self)

    @staticmethod
    def _version_key(version):
        """
        Converts a dotted version string such as '2.4.10' into a tuple of
        integers, (2, 4, 10), so that versions are ordered numerically
        instead of character by character.
        """
        return tuple(int(number) for number in
                     re.findall(r'\d+', six.text_type(version)))

    def _version_before(self, version):
        """
        Returns True if the detected task version is older than the given
        version string.

        Plain string comparison would order '2.10.0' before '2.4.3', hence
        the numeric comparison.
        """
        return self._version_key(self.version) < self._version_key(version)

    def _get_command_args(self, args, config_override=None):
        """
        Builds the full argument list for a task invocation: the binary
        name, the taskrc location, one ``rc.key=value`` override per
        config item and finally the given args converted to text.

        :param args: iterable of command arguments
        :param config_override: optional dict whose items take precedence
            over ``self.config`` for this call only
        """
        command_args = ['task', 'rc:{0}'.format(self.taskrc_location)]
        config = self.config.copy()
        config.update(config_override or dict())
        for key, value in config.items():
            command_args.append('rc.{0}={1}'.format(key, value))
        command_args.extend(map(six.text_type, args))
        return command_args

    def _get_version(self):
        """
        Returns the output of ``task --version`` without surrounding
        newlines.
        """
        p = subprocess.Popen(
            ['task', '--version'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        return stdout.strip('\n')

    def _get_modified_task_fields_as_args(self, task):
        """
        Returns the list of command line arguments describing the task's
        fields: only the modified fields for a saved task, all fields
        except the read-only ones for a new task.
        """
        args = []

        def add_field(field):
            # Add the output of format_field method to args list (defaults to
            # field:value)
            serialized_value = task._serialize(field, task._data[field])

            # Empty values should not be enclosed in quotation marks, see
            # TW-1510
            # Compare by value: identity with a string literal is an
            # implementation detail.
            if serialized_value == '':
                escaped_serialized_value = ''
            else:
                escaped_serialized_value = six.u("'{0}'").format(
                    serialized_value)

            def format_default(task):
                return six.u("{0}:{1}").format(
                    field, escaped_serialized_value)

            # A format_<field> method on the backend, if present, takes
            # precedence over the generic field:value form
            format_func = getattr(self, 'format_{0}'.format(field),
                                  format_default)

            args.append(format_func(task))

        # If we're modifying saved task, simply pass on all modified fields
        if task.saved:
            for field in task._modified_fields:
                add_field(field)
        # For new tasks, pass all fields that make sense
        else:
            for field in task._data:
                if field in task.read_only_fields:
                    continue
                add_field(field)

        return args

    def format_depends(self, task):
        """
        Returns the ``depends:`` argument listing the UUIDs of the added
        dependencies and, prefixed with '-', of the removed ones.
        """
        # We need to generate added and removed dependencies list,
        # since Taskwarrior does not accept redefining dependencies.

        # This cannot be part of serialize_depends, since we need
        # to keep a list of all dependencies in the _data dictionary,
        # not just currently added/removed ones

        old_dependencies = task._original_data.get('depends', set())

        added = task['depends'] - old_dependencies
        removed = old_dependencies - task['depends']

        # Removed dependencies need to be prefixed with '-'
        return 'depends:' + ','.join(
            [t['uuid'] for t in added] +
            ['-' + t['uuid'] for t in removed]
        )

    def format_description(self, task):
        """
        Returns the command line argument which sets the task description.
        """
        # Task version older than 2.4.0 ignores first word of the
        # task description if description: prefix is used
        if self._version_before(self.VERSION_2_4_0):
            return task._data['description']
        else:
            return six.u("description:'{0}'").format(task._data['description'] or '')

    def convert_datetime_string(self, value):
        """
        Evaluates a TW syntax datetime string using ``task calc`` and
        returns the result localized with ``local_zone``.

        :raises ValueError: if the task version is older than 2.4.0,
            which is when ``task calc`` became available
        """
        if self._version_before(self.VERSION_2_4_0):
            raise ValueError("Provided value could not be converted to "
                             "datetime, its type is not supported: {}"
                             .format(type(value)))

        # For strings, use 'task calc' to evaluate the string to datetime
        # available since TW 2.4.0
        args = value.split()
        result = self.execute_command(['calc'] + args)
        naive = datetime.datetime.strptime(result[0], DATE_FORMAT_CALC)
        localized = local_zone.localize(naive)

        return localized

    @property
    def filter_class(self):
        """Returns the filter class used by this backend"""
        return TaskWarriorFilter

    # Public interface

    def get_config(self):
        """
        Returns a dict of the configuration values printed by ``task show``.
        """
        raw_output = self.execute_command(
            ['show'],
            config_override={'verbose': 'nothing'}
        )

        config = dict()
        # Each matching line is a key, whitespace, and the rest of the
        # line as value; lines of any other shape are skipped
        config_regex = re.compile(r'^(?P<key>[^\s]+)\s+(?P<value>[^\s].+$)')

        for line in raw_output:
            match = config_regex.match(line)
            if match:
                config[match.group('key')] = match.group('value').strip()

        return config

    def execute_command(self, args, config_override=None, allow_failure=True,
                        return_all=False):
        """
        Runs the task binary with the given arguments.

        :param args: iterable of command arguments
        :param config_override: optional dict of rc settings which take
            precedence over ``self.config`` for this call only
        :param allow_failure: despite its name, a true value (the default)
            makes a non-zero return code raise; pass False to have the
            return code ignored
        :param return_all: when true, return stderr and the return code
            as well
        :returns: list of stdout lines, or a (stdout lines, stderr lines,
            return code) tuple if ``return_all`` is set
        :raises TaskWarriorException: on non-zero return code while
            ``allow_failure`` is true; the message is stderr, or stdout
            if stderr is blank
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode and allow_failure:
            if stderr.strip():
                error_msg = stderr.strip()
            else:
                error_msg = stdout.strip()
            raise TaskWarriorException(error_msg)

        # Return the whole triplet only if explicitly asked for
        if not return_all:
            return stdout.rstrip().split('\n')
        else:
            return (stdout.rstrip().split('\n'),
                    stderr.rstrip().split('\n'),
                    p.returncode)

    def enforce_recurrence(self):
        """
        Triggers generation of recurrent tasks on task versions which
        need it.
        """
        # Run arbitrary report command which will trigger generation
        # of recurrent tasks.

        # Only necessary for TW up to 2.4.1, fixed in 2.4.2.
        if self._version_before(self.VERSION_2_4_2):
            self.execute_command(['next'], allow_failure=False)

    def merge_with(self, path, push=False):
        """
        Runs ``task merge`` against the given path, which is normalized
        to end with exactly one slash.

        :param push: value of the ``merge.autopush`` override
        """
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Runs ``task undo``"""
        self.execute_command(['undo'])

    # Backend interface implementation

    def filter_tasks(self, filter_obj):
        """
        Returns a list of Task objects built from the ``task export``
        output for the given filter.

        :raises TaskWarriorException: if an output line cannot be loaded
            (ValueError while parsing or loading the data)
        """
        self.enforce_recurrence()
        args = ['export', '--'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if line:
                # One JSON object per line is expected (json.array is
                # off); drop the separating commas
                data = line.strip(',')
                try:
                    filtered_task = Task(self)
                    filtered_task._load_data(json.loads(data))
                    tasks.append(filtered_task)
                except ValueError:
                    raise TaskWarriorException('Invalid JSON: %s' % data)
        return tasks

    def save_task(self, task):
        """Save a task into TaskWarrior database using add/modify call"""

        args = [task['uuid'], 'modify'] if task.saved else ['add']
        args.extend(self._get_modified_task_fields_as_args(task))
        output = self.execute_command(args)

        # Parse out the new ID, if the task is being added for the first time
        if not task.saved:
            id_lines = [line for line in output
                        if line.startswith('Created task ')]

            # Complain loudly if it seems that more tasks were created
            # Should not happen
            if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
                raise TaskWarriorException("Unexpected output when creating "
                                           "task: %s" % '\n'.join(id_lines))

            # Circumvent the ID storage, since ID is considered read-only
            identifier = id_lines[0].split(' ')[2].rstrip('.')

            # Identifier can be either ID or UUID for completed tasks
            try:
                task._data['id'] = int(identifier)
            except ValueError:
                task._data['uuid'] = identifier

        # Refreshing is very important here, as not only modification time
        # is updated, but arbitrary attributes may have changed due to hooks
        # altering the data before saving
        task.refresh(after_save=True)

    def delete_task(self, task):
        """Runs ``task <uuid> delete`` for the given task"""
        self.execute_command([task['uuid'], 'delete'])

    def start_task(self, task):
        """Runs ``task <uuid> start`` for the given task"""
        self.execute_command([task['uuid'], 'start'])

    def stop_task(self, task):
        """Runs ``task <uuid> stop`` for the given task"""
        self.execute_command([task['uuid'], 'stop'])

    def complete_task(self, task):
        """Runs ``task <uuid> done`` for the given task"""
        # Older versions of TW do not stop active task at completion
        if self._version_before(self.VERSION_2_4_0) and task.active:
            task.stop()

        self.execute_command([task['uuid'], 'done'])

    def annotate_task(self, task, annotation):
        """Runs ``task <uuid> annotate <annotation>`` for the given task"""
        args = [task['uuid'], 'annotate', annotation]
        self.execute_command(args)

    def denotate_task(self, task, annotation):
        """Runs ``task <uuid> denotate <annotation>`` for the given task"""
        args = [task['uuid'], 'denotate', annotation]
        self.execute_command(args)

    def refresh_task(self, task, after_save=False):
        """
        Exports the given task and returns its data as a dict parsed from
        the JSON output.

        :param after_save: set when refreshing a task which has just been
            saved; enables the lookup by attributes on older task versions
        :raises TaskWarriorException: if the export does not yield exactly
            one line starting with '{'
        """
        # We need to use ID as backup for uuid here for the refreshes
        # of newly saved tasks. Any other place in the code is fine
        # with using UUID only.
        args = [task['uuid'] or task['id'], 'export']
        output = self.execute_command(args)

        def valid(output):
            return len(output) == 1 and output[0].startswith('{')

        # For older TW versions attempt to uniquely locate the task
        # using the data we have if it has been just saved.
        # This can happen when adding a completed task on older TW versions.
        if (not valid(output) and self._version_before(self.VERSION_2_4_5)
                and after_save):

            # Make a copy, removing ID and UUID. It's most likely invalid
            # (ID 0) if it failed to match a unique task.
            data = copy.deepcopy(task._data)
            data.pop('id', None)
            data.pop('uuid', None)

            taskfilter = self.filter_class(self)
            for key, value in data.items():
                taskfilter.add_filter_param(key, value)

            output = self.execute_command(['export', '--'] +
                                          taskfilter.get_filter_params())

        # If more than 1 task has been matched still, raise an exception
        if not valid(output):
            raise TaskWarriorException(
                "Unique identifiers {0} with description: {1} matches "
                "multiple tasks: {2}".format(
                    task['uuid'] or task['id'], task['description'], output)
            )

        return json.loads(output[0])

    def sync(self):
        """Syncs the backend database with the taskd server"""
        self.execute_command(['sync'])