]> git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/task.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

coverage: Configure coveralls to not include test files
[etc/taskwarrior.git] / tasklib / task.py
1 from __future__ import print_function
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import pytz
8 import six
9 import sys
10 import subprocess
11 import tzlocal
12
13 DATE_FORMAT = '%Y%m%dT%H%M%SZ'
14 REPR_OUTPUT_SIZE = 10
15 PENDING = 'pending'
16 COMPLETED = 'completed'
17
18 VERSION_2_1_0 = six.u('2.1.0')
19 VERSION_2_2_0 = six.u('2.2.0')
20 VERSION_2_3_0 = six.u('2.3.0')
21 VERSION_2_4_0 = six.u('2.4.0')
22
23 logger = logging.getLogger(__name__)
24 local_zone = tzlocal.get_localzone()
25
26
27 class TaskWarriorException(Exception):
28     pass
29
30
31 class ReadOnlyDictView(object):
32     """
33     Provides simplified read-only view upon dict object.
34     """
35
36     def __init__(self, viewed_dict):
37         self.viewed_dict = viewed_dict
38
39     def __getitem__(self, key):
40         return copy.deepcopy(self.viewed_dict.__getitem__(key))
41
42     def __contains__(self, k):
43         return self.viewed_dict.__contains__(k)
44
45     def __iter__(self):
46         for value in self.viewed_dict:
47             yield copy.deepcopy(value)
48
49     def __len__(self):
50         return len(self.viewed_dict)
51
52     def get(self, key, default=None):
53         return copy.deepcopy(self.viewed_dict.get(key, default))
54
55     def items(self):
56         return [copy.deepcopy(v) for v in self.viewed_dict.items()]
57
58     def values(self):
59         return [copy.deepcopy(v) for v in self.viewed_dict.values()]
60
61
62 class SerializingObject(object):
63     """
64     Common ancestor for TaskResource & TaskFilter, since they both
65     need to serialize arguments.
66
67     Serializing method should hold the following contract:
68       - any empty value (meaning removal of the attribute)
69         is deserialized into a empty string
70       - None denotes a empty value for any attribute
71
72     Deserializing method should hold the following contract:
73       - None denotes a empty value for any attribute (however,
74         this is here as a safeguard, TaskWarrior currently does
75         not export empty-valued attributes) if the attribute
76         is not iterable (e.g. list or set), in which case
77         a empty iterable should be used.
78
79     Normalizing methods should hold the following contract:
80       - They are used to validate and normalize the user input.
81         Any attribute value that comes from the user (during Task
82         initialization, assignign values to Task attributes, or
83         filtering by user-provided values of attributes) is first
84         validated and normalized using the normalize_{key} method.
85       - If validation or normalization fails, normalizer is expected
86         to raise ValueError.
87     """
88
89     def _deserialize(self, key, value):
90         hydrate_func = getattr(self, 'deserialize_{0}'.format(key),
91                                lambda x: x if x != '' else None)
92         return hydrate_func(value)
93
94     def _serialize(self, key, value):
95         dehydrate_func = getattr(self, 'serialize_{0}'.format(key),
96                                  lambda x: x if x is not None else '')
97         return dehydrate_func(value)
98
99     def _normalize(self, key, value):
100         """
101         Use normalize_<key> methods to normalize user input. Any user
102         input will be normalized at the moment it is used as filter,
103         or entered as a value of Task attribute.
104         """
105
106         normalize_func = getattr(self, 'normalize_{0}'.format(key),
107                                  lambda x: x)
108
109         return normalize_func(value)
110
111     def timestamp_serializer(self, date):
112         if not date:
113             return ''
114
115         # Any serialized timestamp should be localized, we need to
116         # convert to UTC before converting to string (DATE_FORMAT uses UTC)
117         date = date.astimezone(pytz.utc)
118
119         return date.strftime(DATE_FORMAT)
120
121     def timestamp_deserializer(self, date_str):
122         if not date_str:
123             return None
124
125         # Return timestamp localized in the local zone
126         naive_timestamp = datetime.datetime.strptime(date_str, DATE_FORMAT)
127         localized_timestamp = pytz.utc.localize(naive_timestamp)
128         return localized_timestamp.astimezone(local_zone)
129
130     def serialize_entry(self, value):
131         return self.timestamp_serializer(value)
132
133     def deserialize_entry(self, value):
134         return self.timestamp_deserializer(value)
135
136     def normalize_entry(self, value):
137         return self.datetime_normalizer(value)
138
139     def serialize_modified(self, value):
140         return self.timestamp_serializer(value)
141
142     def deserialize_modified(self, value):
143         return self.timestamp_deserializer(value)
144
145     def normalize_modified(self, value):
146         return self.datetime_normalizer(value)
147
148     def serialize_due(self, value):
149         return self.timestamp_serializer(value)
150
151     def deserialize_due(self, value):
152         return self.timestamp_deserializer(value)
153
154     def normalize_due(self, value):
155         return self.datetime_normalizer(value)
156
157     def serialize_scheduled(self, value):
158         return self.timestamp_serializer(value)
159
160     def deserialize_scheduled(self, value):
161         return self.timestamp_deserializer(value)
162
163     def normalize_scheduled(self, value):
164         return self.datetime_normalizer(value)
165
166     def serialize_until(self, value):
167         return self.timestamp_serializer(value)
168
169     def deserialize_until(self, value):
170         return self.timestamp_deserializer(value)
171
172     def normalize_until(self, value):
173         return self.datetime_normalizer(value)
174
175     def serialize_wait(self, value):
176         return self.timestamp_serializer(value)
177
178     def deserialize_wait(self, value):
179         return self.timestamp_deserializer(value)
180
181     def normalize_wait(self, value):
182         return self.datetime_normalizer(value)
183
184     def serialize_annotations(self, value):
185         value = value if value is not None else []
186
187         # This may seem weird, but it's correct, we want to export
188         # a list of dicts as serialized value
189         serialized_annotations = [json.loads(annotation.export_data())
190                                   for annotation in value]
191         return serialized_annotations if serialized_annotations else ''
192
193     def deserialize_annotations(self, data):
194         return [TaskAnnotation(self, d) for d in data] if data else []
195
196     def serialize_tags(self, tags):
197         return ','.join(tags) if tags else ''
198
199     def deserialize_tags(self, tags):
200         if isinstance(tags, six.string_types):
201             return tags.split(',') if tags else []
202         return tags or []
203
204     def serialize_depends(self, value):
205         # Return the list of uuids
206         value = value if value is not None else set()
207         return ','.join(task['uuid'] for task in value)
208
209     def deserialize_depends(self, raw_uuids):
210         raw_uuids = raw_uuids or ''  # Convert None to empty string
211         uuids = raw_uuids.split(',')
212         return set(self.warrior.tasks.get(uuid=uuid) for uuid in uuids if uuid)
213
214     def datetime_normalizer(self, value):
215         """
216         Normalizes date/datetime value (considered to come from user input)
217         to localized datetime value. Following conversions happen:
218
219         naive date -> localized datetime with the same date, and time=midnight
220         naive datetime -> localized datetime with the same value
221         localized datetime -> localized datetime (no conversion)
222         """
223
224         if (isinstance(value, datetime.date)
225             and not isinstance(value, datetime.datetime)):
226             # Convert to local midnight
227             value_full = datetime.datetime.combine(value, datetime.time.min)
228             localized = local_zone.localize(value_full)
229         elif isinstance(value, datetime.datetime) and value.tzinfo is None:
230             # Convert to localized datetime object
231             localized = local_zone.localize(value)
232         else:
233             # If the value is already localized, there is no need to change
234             # time zone at this point. Also None is a valid value too.
235             localized = value
236         
237         return localized
238
239     def normalize_uuid(self, value):
240         # Enforce sane UUID
241         if not isinstance(value, six.text_type) or value == '':
242             raise ValueError("UUID must be a valid non-empty string.")
243
244         return value
245
246
247 class TaskResource(SerializingObject):
248     read_only_fields = []
249
250     def _load_data(self, data):
251         self._data = dict((key, self._deserialize(key, value))
252                           for key, value in data.items())
253         # We need to use a copy for original data, so that changes
254         # are not propagated.
255         self._original_data = copy.deepcopy(self._data)
256
257     def _update_data(self, data, update_original=False):
258         """
259         Low level update of the internal _data dict. Data which are coming as
260         updates should already be serialized. If update_original is True, the
261         original_data dict is updated as well.
262         """
263         self._data.update(dict((key, self._deserialize(key, value))
264                                for key, value in data.items()))
265
266         if update_original:
267             self._original_data = copy.deepcopy(self._data)
268
269
270     def __getitem__(self, key):
271         # This is a workaround to make TaskResource non-iterable
272         # over simple index-based iteration
273         try:
274             int(key)
275             raise StopIteration
276         except ValueError:
277             pass
278
279         if key not in self._data:
280             self._data[key] = self._deserialize(key, None)
281
282         return self._data.get(key)
283
284     def __setitem__(self, key, value):
285         if key in self.read_only_fields:
286             raise RuntimeError('Field \'%s\' is read-only' % key)
287
288         # Normalize the user input before saving it
289         value = self._normalize(key, value)
290         self._data[key] = value
291
292     def __str__(self):
293         s = six.text_type(self.__unicode__())
294         if not six.PY3:
295             s = s.encode('utf-8')
296         return s
297
298     def __repr__(self):
299         return str(self)
300
301     def export_data(self):
302         """
303         Exports current data contained in the Task as JSON
304         """
305
306         # We need to remove spaces for TW-1504, use custom separators
307         data_tuples = ((key, self._serialize(key, value))
308                        for key, value in six.iteritems(self._data))
309
310         # Empty string denotes empty serialized value, we do not want
311         # to pass that to TaskWarrior.
312         data_tuples = filter(lambda t: t[1] is not '', data_tuples)
313         data = dict(data_tuples)
314         return json.dumps(data, separators=(',',':'))
315
316     @property
317     def _modified_fields(self):
318         writable_fields = set(self._data.keys()) - set(self.read_only_fields)
319         for key in writable_fields:
320             new_value = self._data.get(key)
321             old_value = self._original_data.get(key)
322
323             # Make sure not to mark data removal as modified field if the
324             # field originally had some empty value
325             if key in self._data and not new_value and not old_value:
326                 continue
327
328             if new_value != old_value:
329                 yield key
330
331     @property
332     def modified(self):
333         return bool(list(self._modified_fields))
334
335
336 class TaskAnnotation(TaskResource):
337     read_only_fields = ['entry', 'description']
338
339     def __init__(self, task, data={}):
340         self.task = task
341         self._load_data(data)
342
343     def remove(self):
344         self.task.remove_annotation(self)
345
346     def __unicode__(self):
347         return self['description']
348
349     def __eq__(self, other):
350         # consider 2 annotations equal if they belong to the same task, and
351         # their data dics are the same
352         return self.task == other.task and self._data == other._data
353
354     __repr__ = __unicode__
355
356
357 class Task(TaskResource):
358     read_only_fields = ['id', 'entry', 'urgency', 'uuid', 'modified']
359
360     class DoesNotExist(Exception):
361         pass
362
363     class CompletedTask(Exception):
364         """
365         Raised when the operation cannot be performed on the completed task.
366         """
367         pass
368
369     class DeletedTask(Exception):
370         """
371         Raised when the operation cannot be performed on the deleted task.
372         """
373         pass
374
375     class NotSaved(Exception):
376         """
377         Raised when the operation cannot be performed on the task, because
378         it has not been saved to TaskWarrior yet.
379         """
380         pass
381
382     @classmethod
383     def from_input(cls, input_file=sys.stdin, modify=None):
384         """
385         Creates a Task object, directly from the stdin, by reading one line.
386         If modify=True, two lines are used, first line interpreted as the
387         original state of the Task object, and second line as its new,
388         modified value. This is consistent with the TaskWarrior's hook
389         system.
390
391         Object created by this method should not be saved, deleted
392         or refreshed, as t could create a infinite loop. For this
393         reason, TaskWarrior instance is set to None.
394
395         Input_file argument can be used to specify the input file,
396         but defaults to sys.stdin.
397         """
398
399         # TaskWarrior instance is set to None
400         task = cls(None)
401
402         # Detect the hook type if not given directly
403         name = os.path.basename(sys.argv[0])
404         modify = name.startswith('on-modify') if modify is None else modify
405
406         # Load the data from the input
407         task._load_data(json.loads(input_file.readline().strip()))
408
409         # If this is a on-modify event, we are provided with additional
410         # line of input, which provides updated data
411         if modify:
412             task._update_data(json.loads(input_file.readline().strip()))
413
414         return task
415
416     def __init__(self, warrior, **kwargs):
417         self.warrior = warrior
418
419         # Check that user is not able to set read-only value in __init__
420         for key in kwargs.keys():
421             if key in self.read_only_fields:
422                 raise RuntimeError('Field \'%s\' is read-only' % key)
423
424         # We serialize the data in kwargs so that users of the library
425         # do not have to pass different data formats via __setitem__ and
426         # __init__ methods, that would be confusing
427
428         # Rather unfortunate syntax due to python2.6 comaptiblity
429         self._data = dict((key, self._normalize(key, value))
430                           for (key, value) in six.iteritems(kwargs))
431         self._original_data = copy.deepcopy(self._data)
432
433         # Provide read only access to the original data
434         self.original = ReadOnlyDictView(self._original_data)
435
436     def __unicode__(self):
437         return self['description']
438
439     def __eq__(self, other):
440         if self['uuid'] and other['uuid']:
441             # For saved Tasks, just define equality by equality of uuids
442             return self['uuid'] == other['uuid']
443         else:
444             # If the tasks are not saved, compare the actual instances
445             return id(self) == id(other)
446
447
448     def __hash__(self):
449         if self['uuid']:
450             # For saved Tasks, just define equality by equality of uuids
451             return self['uuid'].__hash__()
452         else:
453             # If the tasks are not saved, return hash of instance id
454             return id(self).__hash__()
455
456     @property
457     def completed(self):
458         return self['status'] == six.text_type('completed')
459
460     @property
461     def deleted(self):
462         return self['status'] == six.text_type('deleted')
463
464     @property
465     def waiting(self):
466         return self['status'] == six.text_type('waiting')
467
468     @property
469     def pending(self):
470         return self['status'] == six.text_type('pending')
471
472     @property
473     def saved(self):
474         return self['uuid'] is not None or self['id'] is not None
475
476     def serialize_depends(self, cur_dependencies):
477         # Check that all the tasks are saved
478         for task in (cur_dependencies or set()):
479             if not task.saved:
480                 raise Task.NotSaved('Task \'%s\' needs to be saved before '
481                                     'it can be set as dependency.' % task)
482
483         return super(Task, self).serialize_depends(cur_dependencies)
484
485     def format_depends(self):
486         # We need to generate added and removed dependencies list,
487         # since Taskwarrior does not accept redefining dependencies.
488
489         # This cannot be part of serialize_depends, since we need
490         # to keep a list of all depedencies in the _data dictionary,
491         # not just currently added/removed ones
492
493         old_dependencies = self._original_data.get('depends', set())
494
495         added = self['depends'] - old_dependencies
496         removed = old_dependencies - self['depends']
497
498         # Removed dependencies need to be prefixed with '-'
499         return 'depends:' + ','.join(
500                 [t['uuid'] for t in added] +
501                 ['-' + t['uuid'] for t in removed]
502             )
503
504     def format_description(self):
505         # Task version older than 2.4.0 ignores first word of the
506         # task description if description: prefix is used
507         if self.warrior.version < VERSION_2_4_0:
508             return self._data['description']
509         else:
510             return "description:'{0}'".format(self._data['description'] or '')
511
512     def delete(self):
513         if not self.saved:
514             raise Task.NotSaved("Task needs to be saved before it can be deleted")
515
516         # Refresh the status, and raise exception if the task is deleted
517         self.refresh(only_fields=['status'])
518
519         if self.deleted:
520             raise Task.DeletedTask("Task was already deleted")
521
522         self.warrior.execute_command([self['uuid'], 'delete'])
523
524         # Refresh the status again, so that we have updated info stored
525         self.refresh(only_fields=['status'])
526
527
528     def done(self):
529         if not self.saved:
530             raise Task.NotSaved("Task needs to be saved before it can be completed")
531
532         # Refresh, and raise exception if task is already completed/deleted
533         self.refresh(only_fields=['status'])
534
535         if self.completed:
536             raise Task.CompletedTask("Cannot complete a completed task")
537         elif self.deleted:
538             raise Task.DeletedTask("Deleted task cannot be completed")
539
540         self.warrior.execute_command([self['uuid'], 'done'])
541
542         # Refresh the status again, so that we have updated info stored
543         self.refresh(only_fields=['status'])
544
545     def save(self):
546         if self.saved and not self.modified:
547             return
548
549         args = [self['uuid'], 'modify'] if self.saved else ['add']
550         args.extend(self._get_modified_fields_as_args())
551         output = self.warrior.execute_command(args)
552
553         # Parse out the new ID, if the task is being added for the first time
554         if not self.saved:
555             id_lines = [l for l in output if l.startswith('Created task ')]
556
557             # Complain loudly if it seems that more tasks were created
558             # Should not happen
559             if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
560                 raise TaskWarriorException("Unexpected output when creating "
561                                            "task: %s" % '\n'.join(id_lines))
562
563             # Circumvent the ID storage, since ID is considered read-only
564             self._data['id'] = int(id_lines[0].split(' ')[2].rstrip('.'))
565
566         # Refreshing is very important here, as not only modification time
567         # is updated, but arbitrary attribute may have changed due hooks
568         # altering the data before saving
569         self.refresh()
570
571     def add_annotation(self, annotation):
572         if not self.saved:
573             raise Task.NotSaved("Task needs to be saved to add annotation")
574
575         args = [self['uuid'], 'annotate', annotation]
576         self.warrior.execute_command(args)
577         self.refresh(only_fields=['annotations'])
578
579     def remove_annotation(self, annotation):
580         if not self.saved:
581             raise Task.NotSaved("Task needs to be saved to remove annotation")
582
583         if isinstance(annotation, TaskAnnotation):
584             annotation = annotation['description']
585         args = [self['uuid'], 'denotate', annotation]
586         self.warrior.execute_command(args)
587         self.refresh(only_fields=['annotations'])
588
589     def _get_modified_fields_as_args(self):
590         args = []
591
592         def add_field(field):
593             # Add the output of format_field method to args list (defaults to
594             # field:value)
595             serialized_value = self._serialize(field, self._data[field])
596
597             # Empty values should not be enclosed in quotation marks, see
598             # TW-1510
599             if serialized_value is '':
600                 escaped_serialized_value = ''
601             else:
602                 escaped_serialized_value = "'{0}'".format(serialized_value)
603
604             format_default = lambda: "{0}:{1}".format(field,
605                                                       escaped_serialized_value)
606
607             format_func = getattr(self, 'format_{0}'.format(field),
608                                   format_default)
609
610             args.append(format_func())
611
612         # If we're modifying saved task, simply pass on all modified fields
613         if self.saved:
614             for field in self._modified_fields:
615                 add_field(field)
616         # For new tasks, pass all fields that make sense
617         else:
618             for field in self._data.keys():
619                 if field in self.read_only_fields:
620                     continue
621                 add_field(field)
622
623         return args
624
625     def refresh(self, only_fields=[]):
626         # Raise error when trying to refresh a task that has not been saved
627         if not self.saved:
628             raise Task.NotSaved("Task needs to be saved to be refreshed")
629
630         # We need to use ID as backup for uuid here for the refreshes
631         # of newly saved tasks. Any other place in the code is fine
632         # with using UUID only.
633         args = [self['uuid'] or self['id'], 'export']
634         new_data = json.loads(self.warrior.execute_command(args)[0])
635         if only_fields:
636             to_update = dict(
637                 [(k, new_data.get(k)) for k in only_fields])
638             self._update_data(to_update, update_original=True)
639         else:
640             self._load_data(new_data)
641
642 class TaskFilter(SerializingObject):
643     """
644     A set of parameters to filter the task list with.
645     """
646
647     def __init__(self, filter_params=[]):
648         self.filter_params = filter_params
649
650     def add_filter(self, filter_str):
651         self.filter_params.append(filter_str)
652
653     def add_filter_param(self, key, value):
654         key = key.replace('__', '.')
655
656         # Replace the value with empty string, since that is the
657         # convention in TW for empty values
658         attribute_key = key.split('.')[0]
659
660         # Since this is user input, we need to normalize before we serialize
661         value = self._normalize(key, value)
662         value = self._serialize(attribute_key, value)
663
664         # If we are filtering by uuid:, do not use uuid keyword
665         # due to TW-1452 bug
666         if key == 'uuid':
667             self.filter_params.insert(0, value)
668         else:
669             # Surround value with aphostrophes unless it's a empty string
670             value = "'%s'" % value if value else ''
671
672             # We enforce equality match by using 'is' (or 'none') modifier
673             # Without using this syntax, filter fails due to TW-1479
674             modifier = '.is' if value else '.none'
675             key = key + modifier if '.' not in key else key
676
677             self.filter_params.append("{0}:{1}".format(key, value))
678
679     def get_filter_params(self):
680         return [f for f in self.filter_params if f]
681
682     def clone(self):
683         c = self.__class__()
684         c.filter_params = list(self.filter_params)
685         return c
686
687
688 class TaskQuerySet(object):
689     """
690     Represents a lazy lookup for a task objects.
691     """
692
693     def __init__(self, warrior=None, filter_obj=None):
694         self.warrior = warrior
695         self._result_cache = None
696         self.filter_obj = filter_obj or TaskFilter()
697
698     def __deepcopy__(self, memo):
699         """
700         Deep copy of a QuerySet doesn't populate the cache
701         """
702         obj = self.__class__()
703         for k, v in self.__dict__.items():
704             if k in ('_iter', '_result_cache'):
705                 obj.__dict__[k] = None
706             else:
707                 obj.__dict__[k] = copy.deepcopy(v, memo)
708         return obj
709
710     def __repr__(self):
711         data = list(self[:REPR_OUTPUT_SIZE + 1])
712         if len(data) > REPR_OUTPUT_SIZE:
713             data[-1] = "...(remaining elements truncated)..."
714         return repr(data)
715
716     def __len__(self):
717         if self._result_cache is None:
718             self._result_cache = list(self)
719         return len(self._result_cache)
720
721     def __iter__(self):
722         if self._result_cache is None:
723             self._result_cache = self._execute()
724         return iter(self._result_cache)
725
726     def __getitem__(self, k):
727         if self._result_cache is None:
728             self._result_cache = list(self)
729         return self._result_cache.__getitem__(k)
730
731     def __bool__(self):
732         if self._result_cache is not None:
733             return bool(self._result_cache)
734         try:
735             next(iter(self))
736         except StopIteration:
737             return False
738         return True
739
740     def __nonzero__(self):
741         return type(self).__bool__(self)
742
743     def _clone(self, klass=None, **kwargs):
744         if klass is None:
745             klass = self.__class__
746         filter_obj = self.filter_obj.clone()
747         c = klass(warrior=self.warrior, filter_obj=filter_obj)
748         c.__dict__.update(kwargs)
749         return c
750
751     def _execute(self):
752         """
753         Fetch the tasks which match the current filters.
754         """
755         return self.warrior.filter_tasks(self.filter_obj)
756
757     def all(self):
758         """
759         Returns a new TaskQuerySet that is a copy of the current one.
760         """
761         return self._clone()
762
763     def pending(self):
764         return self.filter(status=PENDING)
765
766     def completed(self):
767         return self.filter(status=COMPLETED)
768
769     def filter(self, *args, **kwargs):
770         """
771         Returns a new TaskQuerySet with the given filters added.
772         """
773         clone = self._clone()
774         for f in args:
775             clone.filter_obj.add_filter(f)
776         for key, value in kwargs.items():
777             clone.filter_obj.add_filter_param(key, value)
778         return clone
779
780     def get(self, **kwargs):
781         """
782         Performs the query and returns a single object matching the given
783         keyword arguments.
784         """
785         clone = self.filter(**kwargs)
786         num = len(clone)
787         if num == 1:
788             return clone._result_cache[0]
789         if not num:
790             raise Task.DoesNotExist(
791                 'Task matching query does not exist. '
792                 'Lookup parameters were {0}'.format(kwargs))
793         raise ValueError(
794             'get() returned more than one Task -- it returned {0}! '
795             'Lookup parameters were {1}'.format(num, kwargs))
796
797
798 class TaskWarrior(object):
799     def __init__(self, data_location='~/.task', create=True):
800         data_location = os.path.expanduser(data_location)
801         if create and not os.path.exists(data_location):
802             os.makedirs(data_location)
803         self.config = {
804             'data.location': os.path.expanduser(data_location),
805             'confirmation': 'no',
806             'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
807             'recurrence.confirmation': 'no',  # Necessary for modifying R tasks
808         }
809         self.tasks = TaskQuerySet(self)
810         self.version = self._get_version()
811
812     def _get_command_args(self, args, config_override={}):
813         command_args = ['task', 'rc:/']
814         config = self.config.copy()
815         config.update(config_override)
816         for item in config.items():
817             command_args.append('rc.{0}={1}'.format(*item))
818         command_args.extend(map(str, args))
819         return command_args
820
821     def _get_version(self):
822         p = subprocess.Popen(
823                 ['task', '--version'],
824                 stdout=subprocess.PIPE,
825                 stderr=subprocess.PIPE)
826         stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
827         return stdout.strip('\n')
828
829     def execute_command(self, args, config_override={}, allow_failure=True):
830         command_args = self._get_command_args(
831             args, config_override=config_override)
832         logger.debug(' '.join(command_args))
833         p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
834                              stderr=subprocess.PIPE)
835         stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
836         if p.returncode and allow_failure:
837             if stderr.strip():
838                 error_msg = stderr.strip().splitlines()[-1]
839             else:
840                 error_msg = stdout.strip()
841             raise TaskWarriorException(error_msg)
842         return stdout.strip().split('\n')
843
844     def enforce_recurrence(self):
845         # Run arbitrary report command which will trigger generation
846         # of recurrent tasks.
847         # TODO: Make a version dependant enforcement once
848         #       TW-1531 is handled
849         self.execute_command(['next'], allow_failure=False)
850
851     def filter_tasks(self, filter_obj):
852         self.enforce_recurrence()
853         args = ['export', '--'] + filter_obj.get_filter_params()
854         tasks = []
855         for line in self.execute_command(args):
856             if line:
857                 data = line.strip(',')
858                 try:
859                     filtered_task = Task(self)
860                     filtered_task._load_data(json.loads(data))
861                     tasks.append(filtered_task)
862                 except ValueError:
863                     raise TaskWarriorException('Invalid JSON: %s' % data)
864         return tasks
865
866     def merge_with(self, path, push=False):
867         path = path.rstrip('/') + '/'
868         self.execute_command(['merge', path], config_override={
869             'merge.autopush': 'yes' if push else 'no',
870         })
871
872     def undo(self):
873         self.execute_command(['undo'])