]> git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/task.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

e0e10409451cfeb5d2c38e72453cf980b4b18265
[etc/taskwarrior.git] / tasklib / task.py
1 from __future__ import print_function
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import pytz
8 import six
9 import sys
10 import subprocess
11 import tzlocal
12
13 DATE_FORMAT = '%Y%m%dT%H%M%SZ'
14 DATE_FORMAT_CALC = '%Y-%m-%dT%H:%M:%S'
15 REPR_OUTPUT_SIZE = 10
16 PENDING = 'pending'
17 COMPLETED = 'completed'
18
19 VERSION_2_1_0 = six.u('2.1.0')
20 VERSION_2_2_0 = six.u('2.2.0')
21 VERSION_2_3_0 = six.u('2.3.0')
22 VERSION_2_4_0 = six.u('2.4.0')
23 VERSION_2_4_1 = six.u('2.4.1')
24 VERSION_2_4_2 = six.u('2.4.2')
25
26 logger = logging.getLogger(__name__)
27 local_zone = tzlocal.get_localzone()
28
29
30 class TaskWarriorException(Exception):
31     pass
32
33
34 class ReadOnlyDictView(object):
35     """
36     Provides simplified read-only view upon dict object.
37     """
38
39     def __init__(self, viewed_dict):
40         self.viewed_dict = viewed_dict
41
42     def __getitem__(self, key):
43         return copy.deepcopy(self.viewed_dict.__getitem__(key))
44
45     def __contains__(self, k):
46         return self.viewed_dict.__contains__(k)
47
48     def __iter__(self):
49         for value in self.viewed_dict:
50             yield copy.deepcopy(value)
51
52     def __len__(self):
53         return len(self.viewed_dict)
54
55     def get(self, key, default=None):
56         return copy.deepcopy(self.viewed_dict.get(key, default))
57
58     def items(self):
59         return [copy.deepcopy(v) for v in self.viewed_dict.items()]
60
61     def values(self):
62         return [copy.deepcopy(v) for v in self.viewed_dict.values()]
63
64
65 class SerializingObject(object):
66     """
67     Common ancestor for TaskResource & TaskFilter, since they both
68     need to serialize arguments.
69
70     Serializing method should hold the following contract:
71       - any empty value (meaning removal of the attribute)
72         is deserialized into a empty string
73       - None denotes a empty value for any attribute
74
75     Deserializing method should hold the following contract:
76       - None denotes a empty value for any attribute (however,
77         this is here as a safeguard, TaskWarrior currently does
78         not export empty-valued attributes) if the attribute
79         is not iterable (e.g. list or set), in which case
80         a empty iterable should be used.
81
82     Normalizing methods should hold the following contract:
83       - They are used to validate and normalize the user input.
84         Any attribute value that comes from the user (during Task
85         initialization, assignign values to Task attributes, or
86         filtering by user-provided values of attributes) is first
87         validated and normalized using the normalize_{key} method.
88       - If validation or normalization fails, normalizer is expected
89         to raise ValueError.
90     """
91
92     def __init__(self, warrior):
93         self.warrior = warrior
94
95     def _deserialize(self, key, value):
96         hydrate_func = getattr(self, 'deserialize_{0}'.format(key),
97                                lambda x: x if x != '' else None)
98         return hydrate_func(value)
99
100     def _serialize(self, key, value):
101         dehydrate_func = getattr(self, 'serialize_{0}'.format(key),
102                                  lambda x: x if x is not None else '')
103         return dehydrate_func(value)
104
105     def _normalize(self, key, value):
106         """
107         Use normalize_<key> methods to normalize user input. Any user
108         input will be normalized at the moment it is used as filter,
109         or entered as a value of Task attribute.
110         """
111
112         # None value should not be converted by normalizer
113         if value is None:
114             return None
115
116         normalize_func = getattr(self, 'normalize_{0}'.format(key),
117                                  lambda x: x)
118
119         return normalize_func(value)
120
121     def timestamp_serializer(self, date):
122         if not date:
123             return ''
124
125         # Any serialized timestamp should be localized, we need to
126         # convert to UTC before converting to string (DATE_FORMAT uses UTC)
127         date = date.astimezone(pytz.utc)
128
129         return date.strftime(DATE_FORMAT)
130
131     def timestamp_deserializer(self, date_str):
132         if not date_str:
133             return None
134
135         # Return timestamp localized in the local zone
136         naive_timestamp = datetime.datetime.strptime(date_str, DATE_FORMAT)
137         localized_timestamp = pytz.utc.localize(naive_timestamp)
138         return localized_timestamp.astimezone(local_zone)
139
140     def serialize_entry(self, value):
141         return self.timestamp_serializer(value)
142
143     def deserialize_entry(self, value):
144         return self.timestamp_deserializer(value)
145
146     def normalize_entry(self, value):
147         return self.datetime_normalizer(value)
148
149     def serialize_modified(self, value):
150         return self.timestamp_serializer(value)
151
152     def deserialize_modified(self, value):
153         return self.timestamp_deserializer(value)
154
155     def normalize_modified(self, value):
156         return self.datetime_normalizer(value)
157
158     def serialize_start(self, value):
159         return self.timestamp_serializer(value)
160
161     def deserialize_start(self, value):
162         return self.timestamp_deserializer(value)
163
164     def normalize_start(self, value):
165         return self.datetime_normalizer(value)
166
167     def serialize_end(self, value):
168         return self.timestamp_serializer(value)
169
170     def deserialize_end(self, value):
171         return self.timestamp_deserializer(value)
172
173     def normalize_end(self, value):
174         return self.datetime_normalizer(value)
175
176     def serialize_due(self, value):
177         return self.timestamp_serializer(value)
178
179     def deserialize_due(self, value):
180         return self.timestamp_deserializer(value)
181
182     def normalize_due(self, value):
183         return self.datetime_normalizer(value)
184
185     def serialize_scheduled(self, value):
186         return self.timestamp_serializer(value)
187
188     def deserialize_scheduled(self, value):
189         return self.timestamp_deserializer(value)
190
191     def normalize_scheduled(self, value):
192         return self.datetime_normalizer(value)
193
194     def serialize_until(self, value):
195         return self.timestamp_serializer(value)
196
197     def deserialize_until(self, value):
198         return self.timestamp_deserializer(value)
199
200     def normalize_until(self, value):
201         return self.datetime_normalizer(value)
202
203     def serialize_wait(self, value):
204         return self.timestamp_serializer(value)
205
206     def deserialize_wait(self, value):
207         return self.timestamp_deserializer(value)
208
209     def normalize_wait(self, value):
210         return self.datetime_normalizer(value)
211
212     def serialize_annotations(self, value):
213         value = value if value is not None else []
214
215         # This may seem weird, but it's correct, we want to export
216         # a list of dicts as serialized value
217         serialized_annotations = [json.loads(annotation.export_data())
218                                   for annotation in value]
219         return serialized_annotations if serialized_annotations else ''
220
221     def deserialize_annotations(self, data):
222         return [TaskAnnotation(self, d) for d in data] if data else []
223
224     def serialize_tags(self, tags):
225         return ','.join(tags) if tags else ''
226
227     def deserialize_tags(self, tags):
228         if isinstance(tags, six.string_types):
229             return tags.split(',') if tags else []
230         return tags or []
231
232     def serialize_depends(self, value):
233         # Return the list of uuids
234         value = value if value is not None else set()
235         return ','.join(task['uuid'] for task in value)
236
237     def deserialize_depends(self, raw_uuids):
238         raw_uuids = raw_uuids or ''  # Convert None to empty string
239         uuids = raw_uuids.split(',')
240         return set(self.warrior.tasks.get(uuid=uuid) for uuid in uuids if uuid)
241
242     def datetime_normalizer(self, value):
243         """
244         Normalizes date/datetime value (considered to come from user input)
245         to localized datetime value. Following conversions happen:
246
247         naive date -> localized datetime with the same date, and time=midnight
248         naive datetime -> localized datetime with the same value
249         localized datetime -> localized datetime (no conversion)
250         """
251
252         if (isinstance(value, datetime.date)
253             and not isinstance(value, datetime.datetime)):
254             # Convert to local midnight
255             value_full = datetime.datetime.combine(value, datetime.time.min)
256             localized = local_zone.localize(value_full)
257         elif isinstance(value, datetime.datetime):
258             if value.tzinfo is None:
259                 # Convert to localized datetime object
260                 localized = local_zone.localize(value)
261             else:
262                 # If the value is already localized, there is no need to change
263                 # time zone at this point. Also None is a valid value too.
264                 localized = value
265         elif isinstance(value, six.string_types):
266             # For strings, use 'task calc' to evaluate the string to datetime
267             args = value.split()
268             result = self.warrior.execute_command(['calc'] + args)
269             naive = datetime.datetime.strptime(result[0], DATE_FORMAT_CALC)
270             localized = local_zone.localize(naive)
271         else:
272             raise ValueError("Provided value could not be converted to "
273                              "datetime, its type is not supported: {}"
274                              .format(type(value)))
275
276         return localized
277
278     def normalize_uuid(self, value):
279         # Enforce sane UUID
280         if not isinstance(value, six.string_types) or value == '':
281             raise ValueError("UUID must be a valid non-empty string, "
282                              "not: {}".format(value))
283
284         return value
285
286
287 class TaskResource(SerializingObject):
288     read_only_fields = []
289
290     def _load_data(self, data):
291         self._data = dict((key, self._deserialize(key, value))
292                           for key, value in data.items())
293         # We need to use a copy for original data, so that changes
294         # are not propagated.
295         self._original_data = copy.deepcopy(self._data)
296
297     def _update_data(self, data, update_original=False):
298         """
299         Low level update of the internal _data dict. Data which are coming as
300         updates should already be serialized. If update_original is True, the
301         original_data dict is updated as well.
302         """
303         self._data.update(dict((key, self._deserialize(key, value))
304                                for key, value in data.items()))
305
306         if update_original:
307             self._original_data = copy.deepcopy(self._data)
308
309
310     def __getitem__(self, key):
311         # This is a workaround to make TaskResource non-iterable
312         # over simple index-based iteration
313         try:
314             int(key)
315             raise StopIteration
316         except ValueError:
317             pass
318
319         if key not in self._data:
320             self._data[key] = self._deserialize(key, None)
321
322         return self._data.get(key)
323
324     def __setitem__(self, key, value):
325         if key in self.read_only_fields:
326             raise RuntimeError('Field \'%s\' is read-only' % key)
327
328         # Normalize the user input before saving it
329         value = self._normalize(key, value)
330         self._data[key] = value
331
332     def __str__(self):
333         s = six.text_type(self.__unicode__())
334         if not six.PY3:
335             s = s.encode('utf-8')
336         return s
337
338     def __repr__(self):
339         return str(self)
340
341     def export_data(self):
342         """
343         Exports current data contained in the Task as JSON
344         """
345
346         # We need to remove spaces for TW-1504, use custom separators
347         data_tuples = ((key, self._serialize(key, value))
348                        for key, value in six.iteritems(self._data))
349
350         # Empty string denotes empty serialized value, we do not want
351         # to pass that to TaskWarrior.
352         data_tuples = filter(lambda t: t[1] is not '', data_tuples)
353         data = dict(data_tuples)
354         return json.dumps(data, separators=(',',':'))
355
356     @property
357     def _modified_fields(self):
358         writable_fields = set(self._data.keys()) - set(self.read_only_fields)
359         for key in writable_fields:
360             new_value = self._data.get(key)
361             old_value = self._original_data.get(key)
362
363             # Make sure not to mark data removal as modified field if the
364             # field originally had some empty value
365             if key in self._data and not new_value and not old_value:
366                 continue
367
368             if new_value != old_value:
369                 yield key
370
371     @property
372     def modified(self):
373         return bool(list(self._modified_fields))
374
375
376 class TaskAnnotation(TaskResource):
377     read_only_fields = ['entry', 'description']
378
379     def __init__(self, task, data={}):
380         self.task = task
381         self._load_data(data)
382         super(TaskAnnotation, self).__init__(task.warrior)
383
384     def remove(self):
385         self.task.remove_annotation(self)
386
387     def __unicode__(self):
388         return self['description']
389
390     def __eq__(self, other):
391         # consider 2 annotations equal if they belong to the same task, and
392         # their data dics are the same
393         return self.task == other.task and self._data == other._data
394
395     __repr__ = __unicode__
396
397
398 class Task(TaskResource):
399     read_only_fields = ['id', 'entry', 'urgency', 'uuid', 'modified']
400
401     class DoesNotExist(Exception):
402         pass
403
404     class CompletedTask(Exception):
405         """
406         Raised when the operation cannot be performed on the completed task.
407         """
408         pass
409
410     class DeletedTask(Exception):
411         """
412         Raised when the operation cannot be performed on the deleted task.
413         """
414         pass
415
416     class NotSaved(Exception):
417         """
418         Raised when the operation cannot be performed on the task, because
419         it has not been saved to TaskWarrior yet.
420         """
421         pass
422
423     @classmethod
424     def from_input(cls, input_file=sys.stdin, modify=None, warrior=None):
425         """
426         Creates a Task object, directly from the stdin, by reading one line.
427         If modify=True, two lines are used, first line interpreted as the
428         original state of the Task object, and second line as its new,
429         modified value. This is consistent with the TaskWarrior's hook
430         system.
431
432         Object created by this method should not be saved, deleted
433         or refreshed, as t could create a infinite loop. For this
434         reason, TaskWarrior instance is set to None.
435
436         Input_file argument can be used to specify the input file,
437         but defaults to sys.stdin.
438         """
439
440         # Detect the hook type if not given directly
441         name = os.path.basename(sys.argv[0])
442         modify = name.startswith('on-modify') if modify is None else modify
443
444         # Create the TaskWarrior instance if none passed
445         if warrior is None:
446             hook_parent_dir = os.path.dirname(os.path.dirname(sys.argv[0]))
447             warrior = TaskWarrior(data_location=hook_parent_dir)
448
449         # TaskWarrior instance is set to None
450         task = cls(warrior)
451
452         # Load the data from the input
453         task._load_data(json.loads(input_file.readline().strip()))
454
455         # If this is a on-modify event, we are provided with additional
456         # line of input, which provides updated data
457         if modify:
458             task._update_data(json.loads(input_file.readline().strip()))
459
460         return task
461
462     def __init__(self, warrior, **kwargs):
463         super(Task, self).__init__(warrior)
464
465         # Check that user is not able to set read-only value in __init__
466         for key in kwargs.keys():
467             if key in self.read_only_fields:
468                 raise RuntimeError('Field \'%s\' is read-only' % key)
469
470         # We serialize the data in kwargs so that users of the library
471         # do not have to pass different data formats via __setitem__ and
472         # __init__ methods, that would be confusing
473
474         # Rather unfortunate syntax due to python2.6 comaptiblity
475         self._data = dict((key, self._normalize(key, value))
476                           for (key, value) in six.iteritems(kwargs))
477         self._original_data = copy.deepcopy(self._data)
478
479         # Provide read only access to the original data
480         self.original = ReadOnlyDictView(self._original_data)
481
482     def __unicode__(self):
483         return self['description']
484
485     def __eq__(self, other):
486         if self['uuid'] and other['uuid']:
487             # For saved Tasks, just define equality by equality of uuids
488             return self['uuid'] == other['uuid']
489         else:
490             # If the tasks are not saved, compare the actual instances
491             return id(self) == id(other)
492
493
494     def __hash__(self):
495         if self['uuid']:
496             # For saved Tasks, just define equality by equality of uuids
497             return self['uuid'].__hash__()
498         else:
499             # If the tasks are not saved, return hash of instance id
500             return id(self).__hash__()
501
502     @property
503     def completed(self):
504         return self['status'] == six.text_type('completed')
505
506     @property
507     def deleted(self):
508         return self['status'] == six.text_type('deleted')
509
510     @property
511     def waiting(self):
512         return self['status'] == six.text_type('waiting')
513
514     @property
515     def pending(self):
516         return self['status'] == six.text_type('pending')
517
518     @property
519     def saved(self):
520         return self['uuid'] is not None or self['id'] is not None
521
522     def serialize_depends(self, cur_dependencies):
523         # Check that all the tasks are saved
524         for task in (cur_dependencies or set()):
525             if not task.saved:
526                 raise Task.NotSaved('Task \'%s\' needs to be saved before '
527                                     'it can be set as dependency.' % task)
528
529         return super(Task, self).serialize_depends(cur_dependencies)
530
531     def format_depends(self):
532         # We need to generate added and removed dependencies list,
533         # since Taskwarrior does not accept redefining dependencies.
534
535         # This cannot be part of serialize_depends, since we need
536         # to keep a list of all depedencies in the _data dictionary,
537         # not just currently added/removed ones
538
539         old_dependencies = self._original_data.get('depends', set())
540
541         added = self['depends'] - old_dependencies
542         removed = old_dependencies - self['depends']
543
544         # Removed dependencies need to be prefixed with '-'
545         return 'depends:' + ','.join(
546                 [t['uuid'] for t in added] +
547                 ['-' + t['uuid'] for t in removed]
548             )
549
550     def format_description(self):
551         # Task version older than 2.4.0 ignores first word of the
552         # task description if description: prefix is used
553         if self.warrior.version < VERSION_2_4_0:
554             return self._data['description']
555         else:
556             return "description:'{0}'".format(self._data['description'] or '')
557
558     def delete(self):
559         if not self.saved:
560             raise Task.NotSaved("Task needs to be saved before it can be deleted")
561
562         # Refresh the status, and raise exception if the task is deleted
563         self.refresh(only_fields=['status'])
564
565         if self.deleted:
566             raise Task.DeletedTask("Task was already deleted")
567
568         self.warrior.execute_command([self['uuid'], 'delete'])
569
570         # Refresh the status again, so that we have updated info stored
571         self.refresh(only_fields=['status', 'start', 'end'])
572
573     def start(self):
574         if not self.saved:
575             raise Task.NotSaved("Task needs to be saved before it can be started")
576
577         # Refresh, and raise exception if task is already completed/deleted
578         self.refresh(only_fields=['status'])
579
580         if self.completed:
581             raise Task.CompletedTask("Cannot start a completed task")
582         elif self.deleted:
583             raise Task.DeletedTask("Deleted task cannot be started")
584
585         self.warrior.execute_command([self['uuid'], 'start'])
586
587         # Refresh the status again, so that we have updated info stored
588         self.refresh(only_fields=['status', 'start'])
589
590     def done(self):
591         if not self.saved:
592             raise Task.NotSaved("Task needs to be saved before it can be completed")
593
594         # Refresh, and raise exception if task is already completed/deleted
595         self.refresh(only_fields=['status'])
596
597         if self.completed:
598             raise Task.CompletedTask("Cannot complete a completed task")
599         elif self.deleted:
600             raise Task.DeletedTask("Deleted task cannot be completed")
601
602         self.warrior.execute_command([self['uuid'], 'done'])
603
604         # Refresh the status again, so that we have updated info stored
605         self.refresh(only_fields=['status', 'start', 'end'])
606
607     def save(self):
608         if self.saved and not self.modified:
609             return
610
611         args = [self['uuid'], 'modify'] if self.saved else ['add']
612         args.extend(self._get_modified_fields_as_args())
613         output = self.warrior.execute_command(args)
614
615         # Parse out the new ID, if the task is being added for the first time
616         if not self.saved:
617             id_lines = [l for l in output if l.startswith('Created task ')]
618
619             # Complain loudly if it seems that more tasks were created
620             # Should not happen
621             if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
622                 raise TaskWarriorException("Unexpected output when creating "
623                                            "task: %s" % '\n'.join(id_lines))
624
625             # Circumvent the ID storage, since ID is considered read-only
626             self._data['id'] = int(id_lines[0].split(' ')[2].rstrip('.'))
627
628         # Refreshing is very important here, as not only modification time
629         # is updated, but arbitrary attribute may have changed due hooks
630         # altering the data before saving
631         self.refresh()
632
633     def add_annotation(self, annotation):
634         if not self.saved:
635             raise Task.NotSaved("Task needs to be saved to add annotation")
636
637         args = [self['uuid'], 'annotate', annotation]
638         self.warrior.execute_command(args)
639         self.refresh(only_fields=['annotations'])
640
641     def remove_annotation(self, annotation):
642         if not self.saved:
643             raise Task.NotSaved("Task needs to be saved to remove annotation")
644
645         if isinstance(annotation, TaskAnnotation):
646             annotation = annotation['description']
647         args = [self['uuid'], 'denotate', annotation]
648         self.warrior.execute_command(args)
649         self.refresh(only_fields=['annotations'])
650
651     def _get_modified_fields_as_args(self):
652         args = []
653
654         def add_field(field):
655             # Add the output of format_field method to args list (defaults to
656             # field:value)
657             serialized_value = self._serialize(field, self._data[field])
658
659             # Empty values should not be enclosed in quotation marks, see
660             # TW-1510
661             if serialized_value is '':
662                 escaped_serialized_value = ''
663             else:
664                 escaped_serialized_value = "'{0}'".format(serialized_value)
665
666             format_default = lambda: "{0}:{1}".format(field,
667                                                       escaped_serialized_value)
668
669             format_func = getattr(self, 'format_{0}'.format(field),
670                                   format_default)
671
672             args.append(format_func())
673
674         # If we're modifying saved task, simply pass on all modified fields
675         if self.saved:
676             for field in self._modified_fields:
677                 add_field(field)
678         # For new tasks, pass all fields that make sense
679         else:
680             for field in self._data.keys():
681                 if field in self.read_only_fields:
682                     continue
683                 add_field(field)
684
685         return args
686
687     def refresh(self, only_fields=[]):
688         # Raise error when trying to refresh a task that has not been saved
689         if not self.saved:
690             raise Task.NotSaved("Task needs to be saved to be refreshed")
691
692         # We need to use ID as backup for uuid here for the refreshes
693         # of newly saved tasks. Any other place in the code is fine
694         # with using UUID only.
695         args = [self['uuid'] or self['id'], 'export']
696         new_data = json.loads(self.warrior.execute_command(args)[0])
697         if only_fields:
698             to_update = dict(
699                 [(k, new_data.get(k)) for k in only_fields])
700             self._update_data(to_update, update_original=True)
701         else:
702             self._load_data(new_data)
703
704 class TaskFilter(SerializingObject):
705     """
706     A set of parameters to filter the task list with.
707     """
708
709     def __init__(self, warrior, filter_params=[]):
710         self.filter_params = filter_params
711         super(TaskFilter, self).__init__(warrior)
712
713     def add_filter(self, filter_str):
714         self.filter_params.append(filter_str)
715
716     def add_filter_param(self, key, value):
717         key = key.replace('__', '.')
718
719         # Replace the value with empty string, since that is the
720         # convention in TW for empty values
721         attribute_key = key.split('.')[0]
722
723         # Since this is user input, we need to normalize before we serialize
724         value = self._normalize(attribute_key, value)
725         value = self._serialize(attribute_key, value)
726
727         # If we are filtering by uuid:, do not use uuid keyword
728         # due to TW-1452 bug
729         if key == 'uuid':
730             self.filter_params.insert(0, value)
731         else:
732             # Surround value with aphostrophes unless it's a empty string
733             value = "'%s'" % value if value else ''
734
735             # We enforce equality match by using 'is' (or 'none') modifier
736             # Without using this syntax, filter fails due to TW-1479
737             modifier = '.is' if value else '.none'
738             key = key + modifier if '.' not in key else key
739
740             self.filter_params.append("{0}:{1}".format(key, value))
741
742     def get_filter_params(self):
743         return [f for f in self.filter_params if f]
744
745     def clone(self):
746         c = self.__class__(self.warrior)
747         c.filter_params = list(self.filter_params)
748         return c
749
750
751 class TaskQuerySet(object):
752     """
753     Represents a lazy lookup for a task objects.
754     """
755
756     def __init__(self, warrior=None, filter_obj=None):
757         self.warrior = warrior
758         self._result_cache = None
759         self.filter_obj = filter_obj or TaskFilter(warrior)
760
761     def __deepcopy__(self, memo):
762         """
763         Deep copy of a QuerySet doesn't populate the cache
764         """
765         obj = self.__class__()
766         for k, v in self.__dict__.items():
767             if k in ('_iter', '_result_cache'):
768                 obj.__dict__[k] = None
769             else:
770                 obj.__dict__[k] = copy.deepcopy(v, memo)
771         return obj
772
773     def __repr__(self):
774         data = list(self[:REPR_OUTPUT_SIZE + 1])
775         if len(data) > REPR_OUTPUT_SIZE:
776             data[-1] = "...(remaining elements truncated)..."
777         return repr(data)
778
779     def __len__(self):
780         if self._result_cache is None:
781             self._result_cache = list(self)
782         return len(self._result_cache)
783
784     def __iter__(self):
785         if self._result_cache is None:
786             self._result_cache = self._execute()
787         return iter(self._result_cache)
788
789     def __getitem__(self, k):
790         if self._result_cache is None:
791             self._result_cache = list(self)
792         return self._result_cache.__getitem__(k)
793
794     def __bool__(self):
795         if self._result_cache is not None:
796             return bool(self._result_cache)
797         try:
798             next(iter(self))
799         except StopIteration:
800             return False
801         return True
802
803     def __nonzero__(self):
804         return type(self).__bool__(self)
805
806     def _clone(self, klass=None, **kwargs):
807         if klass is None:
808             klass = self.__class__
809         filter_obj = self.filter_obj.clone()
810         c = klass(warrior=self.warrior, filter_obj=filter_obj)
811         c.__dict__.update(kwargs)
812         return c
813
814     def _execute(self):
815         """
816         Fetch the tasks which match the current filters.
817         """
818         return self.warrior.filter_tasks(self.filter_obj)
819
820     def all(self):
821         """
822         Returns a new TaskQuerySet that is a copy of the current one.
823         """
824         return self._clone()
825
826     def pending(self):
827         return self.filter(status=PENDING)
828
829     def completed(self):
830         return self.filter(status=COMPLETED)
831
832     def filter(self, *args, **kwargs):
833         """
834         Returns a new TaskQuerySet with the given filters added.
835         """
836         clone = self._clone()
837         for f in args:
838             clone.filter_obj.add_filter(f)
839         for key, value in kwargs.items():
840             clone.filter_obj.add_filter_param(key, value)
841         return clone
842
843     def get(self, **kwargs):
844         """
845         Performs the query and returns a single object matching the given
846         keyword arguments.
847         """
848         clone = self.filter(**kwargs)
849         num = len(clone)
850         if num == 1:
851             return clone._result_cache[0]
852         if not num:
853             raise Task.DoesNotExist(
854                 'Task matching query does not exist. '
855                 'Lookup parameters were {0}'.format(kwargs))
856         raise ValueError(
857             'get() returned more than one Task -- it returned {0}! '
858             'Lookup parameters were {1}'.format(num, kwargs))
859
860
861 class TaskWarrior(object):
862     def __init__(self, data_location='~/.task', create=True, taskrc_location='~/.taskrc'):
863         data_location = os.path.expanduser(data_location)
864         self.taskrc_location = os.path.expanduser(taskrc_location)
865
866         # If taskrc does not exist, pass / to use defaults and avoid creating
867         # dummy .taskrc file by TaskWarrior
868         if not os.path.exists(self.taskrc_location):
869             self.taskrc_location = '/'
870
871         if create and not os.path.exists(data_location):
872             os.makedirs(data_location)
873
874         self.config = {
875             'data.location': data_location,
876             'confirmation': 'no',
877             'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
878             'recurrence.confirmation': 'no',  # Necessary for modifying R tasks
879         }
880         self.tasks = TaskQuerySet(self)
881         self.version = self._get_version()
882
883     def _get_command_args(self, args, config_override={}):
884         command_args = ['task', 'rc:{0}'.format(self.taskrc_location)]
885         config = self.config.copy()
886         config.update(config_override)
887         for item in config.items():
888             command_args.append('rc.{0}={1}'.format(*item))
889         command_args.extend(map(str, args))
890         return command_args
891
892     def _get_version(self):
893         p = subprocess.Popen(
894                 ['task', '--version'],
895                 stdout=subprocess.PIPE,
896                 stderr=subprocess.PIPE)
897         stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
898         return stdout.strip('\n')
899
900     def execute_command(self, args, config_override={}, allow_failure=True):
901         command_args = self._get_command_args(
902             args, config_override=config_override)
903         logger.debug(' '.join(command_args))
904         p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
905                              stderr=subprocess.PIPE)
906         stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
907         if p.returncode and allow_failure:
908             if stderr.strip():
909                 error_msg = stderr.strip()
910             else:
911                 error_msg = stdout.strip()
912             raise TaskWarriorException(error_msg)
913         return stdout.strip().split('\n')
914
915     def enforce_recurrence(self):
916         # Run arbitrary report command which will trigger generation
917         # of recurrent tasks.
918
919         # Only necessary for TW up to 2.4.1, fixed in 2.4.2.
920         if self.version < VERSION_2_4_2:
921             self.execute_command(['next'], allow_failure=False)
922
923     def filter_tasks(self, filter_obj):
924         self.enforce_recurrence()
925         args = ['export', '--'] + filter_obj.get_filter_params()
926         tasks = []
927         for line in self.execute_command(args):
928             if line:
929                 data = line.strip(',')
930                 try:
931                     filtered_task = Task(self)
932                     filtered_task._load_data(json.loads(data))
933                     tasks.append(filtered_task)
934                 except ValueError:
935                     raise TaskWarriorException('Invalid JSON: %s' % data)
936         return tasks
937
938     def merge_with(self, path, push=False):
939         path = path.rstrip('/') + '/'
940         self.execute_command(['merge', path], config_override={
941             'merge.autopush': 'yes' if push else 'no',
942         })
943
944     def undo(self):
945         self.execute_command(['undo'])