git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/backends.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

TaskQuerySet: Always require backend set properly
[etc/taskwarrior.git] / tasklib / backends.py
1 import abc
2 import datetime
3 import json
4 import logging
5 import os
6 import re
7 import six
8 import subprocess
9
10 from .task import Task, TaskQuerySet
11 from .filters import TaskWarriorFilter
12 from .serializing import local_zone
13
# strptime pattern used to parse the datetime printed by `task calc`.
DATE_FORMAT_CALC = "%Y-%m-%dT%H:%M:%S"

# Module-level logger; executed task command lines are logged at DEBUG.
logger = logging.getLogger(name=__name__)
17
class Backend(object):
    """
    Interface implemented by tasklib storage backends.

    The members below are marked with the ``abc`` decorators to document
    which of them a backend is expected to provide. Note that this class
    does not use ``abc.ABCMeta`` as its metaclass, so the markers are not
    enforced when a subclass is instantiated.
    """

    @abc.abstractproperty
    def filter_class(self):
        """Returns the TaskFilter class used by this backend"""
        pass

    @abc.abstractmethod
    def filter_tasks(self, filter_obj):
        """Returns a list of Task objects matching the given filter"""
        pass

    @abc.abstractmethod
    def save_task(self, task):
        """Persists the given task (creating it or saving its changes)."""
        pass

    @abc.abstractmethod
    def delete_task(self, task):
        """Deletes the given task."""
        pass

    @abc.abstractmethod
    def start_task(self, task):
        """Starts the given task."""
        pass

    @abc.abstractmethod
    def stop_task(self, task):
        """Stops the given task."""
        pass

    @abc.abstractmethod
    def complete_task(self, task):
        """Marks the given task as completed."""
        pass

    @abc.abstractmethod
    def refresh_task(self, task, after_save=False):
        """
        Refreshes the given task. Returns new data dict with serialized
        attributes.
        """
        pass

    @abc.abstractmethod
    def annotate_task(self, task, annotation):
        """Adds ``annotation`` to the given task."""
        pass

    @abc.abstractmethod
    def denotate_task(self, task, annotation):
        """Removes ``annotation`` from the given task."""
        pass

    @abc.abstractmethod
    def sync(self):
        """Syncs the backend database with the taskd server"""
        pass

    def convert_datetime_string(self, value):
        """
        Converts TW syntax datetime string to a localized datetime
        object. This method is not mandatory; backends that do not
        support it raise NotImplementedError.
        """
        # `NotImplemented` is the binary-operator sentinel, not an exception;
        # raising it would produce a TypeError instead.
        raise NotImplementedError
77
78
class TaskWarriorException(Exception):
    """Raised when a Taskwarrior command fails or returns unexpected output."""
81
82
class TaskWarrior(Backend):
    """
    Backend implementation that drives the Taskwarrior ``task`` binary.

    Every operation spawns ``task`` as a subprocess, passing the configured
    taskrc plus the ``rc.`` overrides stored in ``self.config``.
    """

    VERSION_2_1_0 = six.u('2.1.0')
    VERSION_2_2_0 = six.u('2.2.0')
    VERSION_2_3_0 = six.u('2.3.0')
    VERSION_2_4_0 = six.u('2.4.0')
    VERSION_2_4_1 = six.u('2.4.1')
    VERSION_2_4_2 = six.u('2.4.2')
    VERSION_2_4_3 = six.u('2.4.3')
    VERSION_2_4_4 = six.u('2.4.4')
    VERSION_2_4_5 = six.u('2.4.5')
    # NOTE(review): self.version is compared against these constants as a
    # plain string, which is only correct while every version component is
    # a single digit (lexicographically '2.10.0' < '2.4.0').

    def __init__(self, data_location=None, create=True,
                 taskrc_location='~/.taskrc'):
        """
        :param data_location: optional path overriding ``data.location``
            (``~`` is expanded).
        :param create: create ``data_location`` if it does not exist yet.
        :param taskrc_location: path of the taskrc file. If it does not
            exist, ``/`` is passed instead so Taskwarrior uses its defaults.
        """
        self.taskrc_location = os.path.expanduser(taskrc_location)

        # If taskrc does not exist, pass / to use defaults and avoid creating
        # dummy .taskrc file by TaskWarrior
        if not os.path.exists(self.taskrc_location):
            self.taskrc_location = '/'

        self.version = self._get_version()
        self.config = {
            'confirmation': 'no',
            'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
            'recurrence.confirmation': 'no',  # Necessary for modifying R tasks

            # Defaults to on since 2.4.5, we expect off during parsing
            'json.array': 'off',

            # 2.4.3 onwards supports 0 as infinite bulk, otherwise set just
            # arbitrary big number which is likely to be large enough
            'bulk': 0 if self.version >= self.VERSION_2_4_3 else 100000,
        }

        # Set data.location override if passed via kwarg
        if data_location is not None:
            data_location = os.path.expanduser(data_location)
            if create and not os.path.exists(data_location):
                os.makedirs(data_location)
            self.config['data.location'] = data_location

        self.tasks = TaskQuerySet(self)

    def _get_command_args(self, args, config_override=None):
        """
        Build the full ``task`` argv: the rc file, one ``rc.key=value``
        per config entry (``config_override`` wins over ``self.config``),
        then ``args`` converted to text.
        """
        command_args = ['task', 'rc:{0}'.format(self.taskrc_location)]
        config = self.config.copy()
        config.update(config_override or dict())
        for item in config.items():
            command_args.append('rc.{0}={1}'.format(*item))
        command_args.extend(map(six.text_type, args))
        return command_args

    def _get_version(self):
        """Return the output of ``task --version`` without newlines."""
        p = subprocess.Popen(
                ['task', '--version'],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE)
        stdout, _ = [x.decode('utf-8') for x in p.communicate()]
        return stdout.strip('\n')

    def _get_modified_task_fields_as_args(self, task):
        """
        Build the ``field:value`` command-line arguments for ``task``.

        Saved tasks contribute only their modified fields; unsaved tasks
        contribute every field that is not read-only. A method named
        ``format_<field>`` on this backend overrides the default formatting.
        """
        args = []

        def add_field(field):
            # Add the output of format_field method to args list (defaults to
            # field:value)
            serialized_value = task._serialize(field, task._data[field])

            # Empty values should not be enclosed in quotation marks, see
            # TW-1510. Compare by equality: an identity check against a
            # literal is implementation-dependent (and fails for u'' on Py2).
            if serialized_value == '':
                escaped_serialized_value = ''
            else:
                escaped_serialized_value = six.u("'{0}'").format(
                    serialized_value)

            format_default = lambda task: six.u("{0}:{1}").format(
                field, escaped_serialized_value)

            format_func = getattr(self, 'format_{0}'.format(field),
                                  format_default)

            args.append(format_func(task))

        # If we're modifying saved task, simply pass on all modified fields
        if task.saved:
            for field in task._modified_fields:
                add_field(field)
        # For new tasks, pass all fields that make sense
        else:
            for field in task._data.keys():
                if field in task.read_only_fields:
                    continue
                add_field(field)

        return args

    def format_depends(self, task):
        """Return a ``depends:`` argument listing added/removed UUIDs."""
        # We need to generate added and removed dependencies list,
        # since Taskwarrior does not accept redefining dependencies.

        # This cannot be part of serialize_depends, since we need
        # to keep a list of all dependencies in the _data dictionary,
        # not just currently added/removed ones

        old_dependencies = task._original_data.get('depends', set())

        added = task['depends'] - old_dependencies
        removed = old_dependencies - task['depends']

        # Removed dependencies need to be prefixed with '-'
        return 'depends:' + ','.join(
                [t['uuid'] for t in added] +
                ['-' + t['uuid'] for t in removed]
            )

    def format_description(self, task):
        """Return the description argument appropriate for this TW version."""
        # Task version older than 2.4.0 ignores first word of the
        # task description if description: prefix is used
        if self.version < self.VERSION_2_4_0:
            return task._data['description']
        else:
            return six.u("description:'{0}'").format(
                task._data['description'] or '')

    def convert_datetime_string(self, value):
        """
        Evaluate a TW datetime expression with ``task calc`` and return it
        as a datetime localized to ``local_zone``.

        :raises ValueError: on Taskwarrior older than 2.4.0, which lacks
            ``task calc``.
        """
        if self.version >= self.VERSION_2_4_0:
            # For strings, use 'task calc' to evaluate the string to datetime
            # available since TW 2.4.0
            args = value.split()
            result = self.execute_command(['calc'] + args)
            naive = datetime.datetime.strptime(result[0], DATE_FORMAT_CALC)
            localized = local_zone.localize(naive)
            return localized
        else:
            raise ValueError("Provided value could not be converted to "
                             "datetime, its type is not supported: {}"
                             .format(type(value)))

    @property
    def filter_class(self):
        return TaskWarriorFilter

    # Public interface

    def get_config(self):
        """Return the effective configuration (``task show``) as a dict."""
        raw_output = self.execute_command(
                ['show'],
                config_override={'verbose': 'nothing'}
            )

        config = {}
        # The value group must accept single-character values such as
        # "bulk 0"; `.+` after the first char would silently drop them.
        config_regex = re.compile(r'^(?P<key>[^\s]+)\s+(?P<value>[^\s].*$)')

        for line in raw_output:
            match = config_regex.match(line)
            if match:
                config[match.group('key')] = match.group('value').strip()

        return config

    def execute_command(self, args, config_override=None, allow_failure=True,
                        return_all=False):
        """
        Run ``task`` with ``args`` and return its output.

        :param config_override: extra ``rc.`` settings for this call only.
        :param allow_failure: despite its name, when True (the default) a
            non-zero exit status raises TaskWarriorException; pass False to
            ignore the exit status.
        :param return_all: if True, return ``(stdout_lines, stderr_lines,
            returncode)`` instead of just the stdout lines.
        :raises TaskWarriorException: on failure, with stderr (or stdout
            when stderr is empty) as the message.
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode and allow_failure:
            if stderr.strip():
                error_msg = stderr.strip()
            else:
                error_msg = stdout.strip()
            raise TaskWarriorException(error_msg)

        # Return all whole triplet only if explicitly asked for
        if not return_all:
            return stdout.rstrip().split('\n')
        else:
            return (stdout.rstrip().split('\n'),
                    stderr.rstrip().split('\n'),
                    p.returncode)

    def enforce_recurrence(self):
        """Trigger generation of recurring task instances on old TW."""
        # Run arbitrary report command which will trigger generation
        # of recurrent tasks.

        # Only necessary for TW up to 2.4.1, fixed in 2.4.2.
        if self.version < self.VERSION_2_4_2:
            self.execute_command(['next'], allow_failure=False)

    def merge_with(self, path, push=False):
        """Run ``task merge`` against ``path`` (optionally auto-pushing)."""
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        """Run ``task undo``."""
        self.execute_command(['undo'])

    # Backend interface implementation

    def filter_tasks(self, filter_obj):
        """Return Task objects exported by TW for ``filter_obj``'s params."""
        self.enforce_recurrence()
        args = ['export', '--'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if line:
                # With json.array off, export prints one object per line,
                # possibly with a trailing comma.
                data = line.strip(',')
                try:
                    filtered_task = Task(self)
                    filtered_task._load_data(json.loads(data))
                    tasks.append(filtered_task)
                except ValueError:
                    raise TaskWarriorException('Invalid JSON: %s' % data)
        return tasks

    def save_task(self, task):
        """Save a task into TaskWarrior database using add/modify call"""

        args = [task['uuid'], 'modify'] if task.saved else ['add']
        args.extend(self._get_modified_task_fields_as_args(task))
        output = self.execute_command(args)

        # Parse out the new ID, if the task is being added for the first time
        if not task.saved:
            id_lines = [line for line in output
                        if line.startswith('Created task ')]

            # Complain loudly if it seems that more tasks were created
            # Should not happen
            if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
                raise TaskWarriorException("Unexpected output when creating "
                                           "task: %s" % '\n'.join(id_lines))

            # Circumvent the ID storage, since ID is considered read-only
            identifier = id_lines[0].split(' ')[2].rstrip('.')

            # Identifier can be either ID or UUID for completed tasks
            try:
                task._data['id'] = int(identifier)
            except ValueError:
                task._data['uuid'] = identifier

        # Refreshing is very important here, as not only modification time
        # is updated, but arbitrary attribute may have changed due hooks
        # altering the data before saving
        task.refresh(after_save=True)

    def delete_task(self, task):
        self.execute_command([task['uuid'], 'delete'])

    def start_task(self, task):
        self.execute_command([task['uuid'], 'start'])

    def stop_task(self, task):
        self.execute_command([task['uuid'], 'stop'])

    def complete_task(self, task):
        # Older versions of TW do not stop active task at completion
        if self.version < self.VERSION_2_4_0 and task.active:
            task.stop()

        self.execute_command([task['uuid'], 'done'])

    def annotate_task(self, task, annotation):
        args = [task['uuid'], 'annotate', annotation]
        self.execute_command(args)

    def denotate_task(self, task, annotation):
        args = [task['uuid'], 'denotate', annotation]
        self.execute_command(args)

    def refresh_task(self, task, after_save=False):
        """
        Export the given task from TW and return its decoded JSON dict.

        :raises TaskWarriorException: if the export does not yield exactly
            one task.
        """
        # `copy` is not imported at module level; import it here so the
        # old-TW fallback below does not fail with a NameError.
        import copy

        # We need to use ID as backup for uuid here for the refreshes
        # of newly saved tasks. Any other place in the code is fine
        # with using UUID only.
        args = [task['uuid'] or task['id'], 'export']
        output = self.execute_command(args)

        def valid(output):
            return len(output) == 1 and output[0].startswith('{')

        # For older TW versions attempt to uniquely locate the task
        # using the data we have if it has been just saved.
        # This can happen when adding a completed task on older TW versions.
        if (not valid(output) and self.version < self.VERSION_2_4_5
                and after_save):

            # Make a copy, removing ID and UUID. It's most likely invalid
            # (ID 0) if it failed to match a unique task.
            data = copy.deepcopy(task._data)
            data.pop('id', None)
            data.pop('uuid', None)

            taskfilter = self.filter_class(self)
            for key, value in data.items():
                taskfilter.add_filter_param(key, value)

            output = self.execute_command(['export', '--'] +
                taskfilter.get_filter_params())

        # If not exactly 1 task has been matched still, raise an exception
        # (the message below also covers the zero-match case).
        if not valid(output):
            raise TaskWarriorException(
                "Unique identifiers {0} with description: {1} matches "
                "multiple tasks: {2}".format(
                task['uuid'] or task['id'], task['description'], output)
            )

        return json.loads(output[0])

    def sync(self):
        self.execute_command(['sync'])