git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/backends.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

TaskWarrior: Actually return converted datetime object
[etc/taskwarrior.git] / tasklib / backends.py
import abc
import copy
import datetime
import json
import logging
import os
import re
import subprocess

import six

from .task import Task, TaskQuerySet
from .filters import TaskWarriorFilter
from .serializing import local_zone
13
# strptime layout used to parse the timestamp printed by `task calc`.
DATE_FORMAT_CALC = '%Y-%m-%dT%H:%M:%S'

# Module-level logger; command lines sent to the `task` binary are logged
# at DEBUG level.
logger = logging.getLogger(name=__name__)
17
class Backend(object):
    """Interface for tasklib storage backends.

    The ``abc`` decorators below document which members a backend must
    provide. Note that this class does not use ``abc.ABCMeta`` as its
    metaclass, so the decorators are not enforced at instantiation time.
    """

    @abc.abstractproperty
    def filter_class(self):
        """Returns the TaskFilter class used by this backend"""
        pass

    @abc.abstractmethod
    def filter_tasks(self, filter_obj):
        """Returns a list of Task objects matching the given filter"""
        pass

    @abc.abstractmethod
    def save_task(self, task):
        pass

    @abc.abstractmethod
    def delete_task(self, task):
        pass

    @abc.abstractmethod
    def start_task(self, task):
        pass

    @abc.abstractmethod
    def stop_task(self, task):
        pass

    @abc.abstractmethod
    def complete_task(self, task):
        pass

    @abc.abstractmethod
    def refresh_task(self, task, after_save=False):
        """
        Refreshes the given task. Returns new data dict with serialized
        attributes.
        """
        pass

    @abc.abstractmethod
    def annotate_task(self, task, annotation):
        pass

    @abc.abstractmethod
    def denotate_task(self, task, annotation):
        pass

    @abc.abstractmethod
    def sync(self):
        """Syncs the backend database with the taskd server"""
        pass

    def convert_datetime_string(self, value):
        """
        Converts TW syntax datetime string to a localized datetime
        object. This method is not mandatory.

        :raises NotImplementedError: unless a backend overrides it.
        """
        # `NotImplemented` is a comparison sentinel, not an exception class;
        # `raise NotImplemented` would surface as a confusing TypeError.
        raise NotImplementedError
77
78
class TaskWarriorException(Exception):
    """Raised when a `task` command fails or produces unexpected output."""
81
82
class TaskWarrior(Backend):
    """Backend that drives the ``task`` command-line binary.

    Every operation is translated into a ``task`` invocation run through
    :mod:`subprocess`; task data is read back via ``task export`` (JSON).
    """

    # Version checks in this class compare these strings against the output
    # of `task --version` using plain string comparison.
    VERSION_2_1_0 = six.u('2.1.0')
    VERSION_2_2_0 = six.u('2.2.0')
    VERSION_2_3_0 = six.u('2.3.0')
    VERSION_2_4_0 = six.u('2.4.0')
    VERSION_2_4_1 = six.u('2.4.1')
    VERSION_2_4_2 = six.u('2.4.2')
    VERSION_2_4_3 = six.u('2.4.3')
    VERSION_2_4_4 = six.u('2.4.4')
    VERSION_2_4_5 = six.u('2.4.5')

    def __init__(self, data_location=None, create=True,
                 taskrc_location='~/.taskrc'):
        """
        :param data_location: optional override for TW's ``data.location``
            (``~`` is expanded).
        :param create: create ``data_location`` if it does not exist yet.
        :param taskrc_location: path of the taskrc file to use; when it does
            not exist, ``/`` is passed to TW instead.
        """
        self.taskrc_location = os.path.expanduser(taskrc_location)

        # If taskrc does not exist, pass / to use defaults and avoid creating
        # dummy .taskrc file by TaskWarrior
        if not os.path.exists(self.taskrc_location):
            self.taskrc_location = '/'

        self.version = self._get_version()
        self.config = {
            'confirmation': 'no',
            'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
            'recurrence.confirmation': 'no',  # Necessary for modifying R tasks

            # Defaults to on since 2.4.5, we expect off during parsing
            'json.array': 'off',

            # 2.4.3 onwards supports 0 as infite bulk, otherwise set just
            # arbitrary big number which is likely to be large enough
            'bulk': 0 if self.version >= self.VERSION_2_4_3 else 100000,
        }

        # Set data.location override if passed via kwarg
        if data_location is not None:
            data_location = os.path.expanduser(data_location)
            if create and not os.path.exists(data_location):
                os.makedirs(data_location)
            self.config['data.location'] = data_location

        self.tasks = TaskQuerySet(self)

    def _get_command_args(self, args, config_override=None):
        """Build the full argv for a `task` call.

        The taskrc location and every ``self.config`` entry (optionally
        overridden by ``config_override``) are passed as ``rc`` arguments,
        followed by ``args`` converted to text.
        """
        command_args = ['task', 'rc:{0}'.format(self.taskrc_location)]
        config = self.config.copy()
        config.update(config_override or dict())
        for item in config.items():
            command_args.append('rc.{0}={1}'.format(*item))
        command_args.extend(map(six.text_type, args))
        return command_args

    def _get_version(self):
        """Return the version string printed by `task --version`."""
        p = subprocess.Popen(
            ['task', '--version'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        stdout, _ = [x.decode('utf-8') for x in p.communicate()]
        return stdout.strip('\n')

    def _get_modified_task_fields_as_args(self, task):
        """Return the ``field:value`` arguments needed to save ``task``.

        Saved tasks contribute only their modified fields; new tasks
        contribute every field that is not read-only. A field for which this
        backend defines a ``format_<field>`` method is rendered by it.
        """
        args = []

        def add_field(field):
            # Add the output of format_field method to args list (defaults to
            # field:value)
            serialized_value = task._serialize(field, task._data[field])

            # Empty values should not be enclosed in quotation marks, see
            # TW-1510. Compare by value: an identity test (`is ''`) depends
            # on string interning and misses u'' on Python 2.
            if serialized_value == '':
                escaped_serialized_value = ''
            else:
                escaped_serialized_value = six.u("'{0}'").format(
                    serialized_value)

            def format_default(task):
                return six.u("{0}:{1}").format(field, escaped_serialized_value)

            format_func = getattr(self, 'format_{0}'.format(field),
                                  format_default)

            args.append(format_func(task))

        # If we're modifying saved task, simply pass on all modified fields
        if task.saved:
            for field in task._modified_fields:
                add_field(field)
        # For new tasks, pass all fields that make sense
        else:
            for field in task._data.keys():
                if field in task.read_only_fields:
                    continue
                add_field(field)

        return args

    def format_depends(self, task):
        """Render the ``depends`` change as added/removed UUIDs."""
        # We need to generate added and removed dependencies list,
        # since Taskwarrior does not accept redefining dependencies.

        # This cannot be part of serialize_depends, since we need
        # to keep a list of all depedencies in the _data dictionary,
        # not just currently added/removed ones

        old_dependencies = task._original_data.get('depends', set())

        added = task['depends'] - old_dependencies
        removed = old_dependencies - task['depends']

        # Removed dependencies need to be prefixed with '-'
        return 'depends:' + ','.join(
            [t['uuid'] for t in added] +
            ['-' + t['uuid'] for t in removed]
        )

    def format_description(self, task):
        """Render the description argument for the running TW version."""
        # Task version older than 2.4.0 ignores first word of the
        # task description if description: prefix is used
        if self.version < self.VERSION_2_4_0:
            return task._data['description']
        else:
            return six.u("description:'{0}'").format(
                task._data['description'] or '')

    def convert_datetime_string(self, value):
        """Evaluate a TW datetime expression using `task calc`.

        :returns: datetime localized with ``local_zone``.
        :raises ValueError: on TW older than 2.4.0, which has no `task calc`.
        """
        if self.version < self.VERSION_2_4_0:
            raise ValueError("Provided value could not be converted to "
                             "datetime, its type is not supported: {}"
                             .format(type(value)))

        # For strings, use 'task calc' to evaluate the string to datetime
        # available since TW 2.4.0
        args = value.split()
        result = self.execute_command(['calc'] + args)
        naive = datetime.datetime.strptime(result[0], DATE_FORMAT_CALC)
        return local_zone.localize(naive)

    @property
    def filter_class(self):
        return TaskWarriorFilter

    # Public interface

    def get_config(self):
        """Return the effective TW configuration parsed from `task show`."""
        raw_output = self.execute_command(
            ['show'],
            config_override={'verbose': 'nothing'}
        )

        config = dict()
        # `.*` (not `.+`) so that single-character values such as "0" or
        # "y" are not silently dropped.
        config_regex = re.compile(r'^(?P<key>[^\s]+)\s+(?P<value>[^\s].*$)')

        for line in raw_output:
            match = config_regex.match(line)
            if match:
                config[match.group('key')] = match.group('value').strip()

        return config

    def execute_command(self, args, config_override=None, allow_failure=True,
                        return_all=False):
        """Run `task` with ``args`` and return its output lines.

        :param allow_failure: despite the name, when True (the default) a
            non-zero exit status raises TaskWarriorException (stderr, or
            stdout if stderr is empty, as message); pass False to ignore
            the exit status.
        :param return_all: return ``(stdout_lines, stderr_lines,
            returncode)`` instead of just the stdout lines.
        """
        command_args = self._get_command_args(
            args, config_override=config_override)
        logger.debug(' '.join(command_args))
        p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
        if p.returncode and allow_failure:
            if stderr.strip():
                error_msg = stderr.strip()
            else:
                error_msg = stdout.strip()
            raise TaskWarriorException(error_msg)

        # Return all whole triplet only if explicitly asked for
        if not return_all:
            return stdout.rstrip().split('\n')
        else:
            return (stdout.rstrip().split('\n'),
                    stderr.rstrip().split('\n'),
                    p.returncode)

    def enforce_recurrence(self):
        # Run arbitrary report command which will trigger generation
        # of recurrent tasks.

        # Only necessary for TW up to 2.4.1, fixed in 2.4.2.
        if self.version < self.VERSION_2_4_2:
            self.execute_command(['next'], allow_failure=False)

    def merge_with(self, path, push=False):
        """Run `task merge` against ``path`` (normalized to end with '/')."""
        path = path.rstrip('/') + '/'
        self.execute_command(['merge', path], config_override={
            'merge.autopush': 'yes' if push else 'no',
        })

    def undo(self):
        self.execute_command(['undo'])

    # Backend interface implementation

    def filter_tasks(self, filter_obj):
        """Return Task objects for every line of `task export <filter>`."""
        self.enforce_recurrence()
        args = ['export', '--'] + filter_obj.get_filter_params()
        tasks = []
        for line in self.execute_command(args):
            if not line:
                continue
            data = line.strip(',')
            # Only the JSON decoding is guarded, so errors raised while
            # loading the data into a Task are not misreported as bad JSON.
            try:
                parsed = json.loads(data)
            except ValueError:
                raise TaskWarriorException('Invalid JSON: %s' % data)
            filtered_task = Task(self)
            filtered_task._load_data(parsed)
            tasks.append(filtered_task)
        return tasks

    def save_task(self, task):
        """Save a task into TaskWarrior database using add/modify call"""

        args = [task['uuid'], 'modify'] if task.saved else ['add']
        args.extend(self._get_modified_task_fields_as_args(task))
        output = self.execute_command(args)

        # Parse out the new ID, if the task is being added for the first time
        if not task.saved:
            id_lines = [l for l in output if l.startswith('Created task ')]

            # Complain loudly if it seems that more tasks were created
            # Should not happen
            if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
                raise TaskWarriorException("Unexpected output when creating "
                                           "task: %s" % '\n'.join(id_lines))

            # Circumvent the ID storage, since ID is considered read-only
            identifier = id_lines[0].split(' ')[2].rstrip('.')

            # Identifier can be either ID or UUID for completed tasks
            try:
                task._data['id'] = int(identifier)
            except ValueError:
                task._data['uuid'] = identifier

        # Refreshing is very important here, as not only modification time
        # is updated, but arbitrary attribute may have changed due hooks
        # altering the data before saving
        task.refresh(after_save=True)

    def delete_task(self, task):
        self.execute_command([task['uuid'], 'delete'])

    def start_task(self, task):
        self.execute_command([task['uuid'], 'start'])

    def stop_task(self, task):
        self.execute_command([task['uuid'], 'stop'])

    def complete_task(self, task):
        # Older versions of TW do not stop active task at completion
        if self.version < self.VERSION_2_4_0 and task.active:
            task.stop()

        self.execute_command([task['uuid'], 'done'])

    def annotate_task(self, task, annotation):
        args = [task['uuid'], 'annotate', annotation]
        self.execute_command(args)

    def denotate_task(self, task, annotation):
        args = [task['uuid'], 'denotate', annotation]
        self.execute_command(args)

    def refresh_task(self, task, after_save=False):
        """Return the task's data dict as re-exported by TW.

        :raises TaskWarriorException: if the export does not yield exactly
            one JSON line.
        """
        # We need to use ID as backup for uuid here for the refreshes
        # of newly saved tasks. Any other place in the code is fine
        # with using UUID only.
        args = [task['uuid'] or task['id'], 'export']
        output = self.execute_command(args)

        def valid(output):
            return len(output) == 1 and output[0].startswith('{')

        # For older TW versions attempt to uniquely locate the task
        # using the data we have if it has been just saved.
        # This can happen when adding a completed task on older TW versions.
        if (not valid(output) and self.version < self.VERSION_2_4_5
                and after_save):

            # Make a copy, removing ID and UUID. It's most likely invalid
            # (ID 0) if it failed to match a unique task.
            data = copy.deepcopy(task._data)
            data.pop('id', None)
            data.pop('uuid', None)

            taskfilter = self.filter_class(self)
            for key, value in data.items():
                taskfilter.add_filter_param(key, value)

            output = self.execute_command(
                ['export', '--'] + taskfilter.get_filter_params())

        # If more than 1 task has been matched still, raise an exception
        if not valid(output):
            raise TaskWarriorException(
                "Unique identifiers {0} with description: {1} matches "
                "multiple tasks: {2}".format(
                    task['uuid'] or task['id'], task['description'], output)
            )

        return json.loads(output[0])

    def sync(self):
        self.execute_command(['sync'])