]> git.madduck.net Git - etc/taskwarrior.git/blob - tasklib/task.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

0f8d963140949683cc878f9a192a83bf1802bfb4
[etc/taskwarrior.git] / tasklib / task.py
1 from __future__ import print_function
2 import copy
3 import datetime
4 import json
5 import logging
6 import os
7 import pytz
8 import re
9 import six
10 import sys
11 import subprocess
12 import tzlocal
13
14 DATE_FORMAT = '%Y%m%dT%H%M%SZ'
15 DATE_FORMAT_CALC = '%Y-%m-%dT%H:%M:%S'
16 REPR_OUTPUT_SIZE = 10
17 PENDING = 'pending'
18 COMPLETED = 'completed'
19
20 VERSION_2_1_0 = six.u('2.1.0')
21 VERSION_2_2_0 = six.u('2.2.0')
22 VERSION_2_3_0 = six.u('2.3.0')
23 VERSION_2_4_0 = six.u('2.4.0')
24 VERSION_2_4_1 = six.u('2.4.1')
25 VERSION_2_4_2 = six.u('2.4.2')
26 VERSION_2_4_3 = six.u('2.4.3')
27
28 logger = logging.getLogger(__name__)
29 local_zone = tzlocal.get_localzone()
30
31
32 class TaskWarriorException(Exception):
33     pass
34
35
36 class ReadOnlyDictView(object):
37     """
38     Provides simplified read-only view upon dict object.
39     """
40
41     def __init__(self, viewed_dict):
42         self.viewed_dict = viewed_dict
43
44     def __getitem__(self, key):
45         return copy.deepcopy(self.viewed_dict.__getitem__(key))
46
47     def __contains__(self, k):
48         return self.viewed_dict.__contains__(k)
49
50     def __iter__(self):
51         for value in self.viewed_dict:
52             yield copy.deepcopy(value)
53
54     def __len__(self):
55         return len(self.viewed_dict)
56
57     def get(self, key, default=None):
58         return copy.deepcopy(self.viewed_dict.get(key, default))
59
60     def items(self):
61         return [copy.deepcopy(v) for v in self.viewed_dict.items()]
62
63     def values(self):
64         return [copy.deepcopy(v) for v in self.viewed_dict.values()]
65
66
67 class SerializingObject(object):
68     """
69     Common ancestor for TaskResource & TaskFilter, since they both
70     need to serialize arguments.
71
72     Serializing method should hold the following contract:
73       - any empty value (meaning removal of the attribute)
74         is deserialized into a empty string
75       - None denotes a empty value for any attribute
76
77     Deserializing method should hold the following contract:
78       - None denotes a empty value for any attribute (however,
79         this is here as a safeguard, TaskWarrior currently does
80         not export empty-valued attributes) if the attribute
81         is not iterable (e.g. list or set), in which case
82         a empty iterable should be used.
83
84     Normalizing methods should hold the following contract:
85       - They are used to validate and normalize the user input.
86         Any attribute value that comes from the user (during Task
87         initialization, assignign values to Task attributes, or
88         filtering by user-provided values of attributes) is first
89         validated and normalized using the normalize_{key} method.
90       - If validation or normalization fails, normalizer is expected
91         to raise ValueError.
92     """
93
94     def __init__(self, warrior):
95         self.warrior = warrior
96
97     def _deserialize(self, key, value):
98         hydrate_func = getattr(self, 'deserialize_{0}'.format(key),
99                                lambda x: x if x != '' else None)
100         return hydrate_func(value)
101
102     def _serialize(self, key, value):
103         dehydrate_func = getattr(self, 'serialize_{0}'.format(key),
104                                  lambda x: x if x is not None else '')
105         return dehydrate_func(value)
106
107     def _normalize(self, key, value):
108         """
109         Use normalize_<key> methods to normalize user input. Any user
110         input will be normalized at the moment it is used as filter,
111         or entered as a value of Task attribute.
112         """
113
114         # None value should not be converted by normalizer
115         if value is None:
116             return None
117
118         normalize_func = getattr(self, 'normalize_{0}'.format(key),
119                                  lambda x: x)
120
121         return normalize_func(value)
122
123     def timestamp_serializer(self, date):
124         if not date:
125             return ''
126
127         # Any serialized timestamp should be localized, we need to
128         # convert to UTC before converting to string (DATE_FORMAT uses UTC)
129         date = date.astimezone(pytz.utc)
130
131         return date.strftime(DATE_FORMAT)
132
133     def timestamp_deserializer(self, date_str):
134         if not date_str:
135             return None
136
137         # Return timestamp localized in the local zone
138         naive_timestamp = datetime.datetime.strptime(date_str, DATE_FORMAT)
139         localized_timestamp = pytz.utc.localize(naive_timestamp)
140         return localized_timestamp.astimezone(local_zone)
141
142     def serialize_entry(self, value):
143         return self.timestamp_serializer(value)
144
145     def deserialize_entry(self, value):
146         return self.timestamp_deserializer(value)
147
148     def normalize_entry(self, value):
149         return self.datetime_normalizer(value)
150
151     def serialize_modified(self, value):
152         return self.timestamp_serializer(value)
153
154     def deserialize_modified(self, value):
155         return self.timestamp_deserializer(value)
156
157     def normalize_modified(self, value):
158         return self.datetime_normalizer(value)
159
160     def serialize_start(self, value):
161         return self.timestamp_serializer(value)
162
163     def deserialize_start(self, value):
164         return self.timestamp_deserializer(value)
165
166     def normalize_start(self, value):
167         return self.datetime_normalizer(value)
168
169     def serialize_end(self, value):
170         return self.timestamp_serializer(value)
171
172     def deserialize_end(self, value):
173         return self.timestamp_deserializer(value)
174
175     def normalize_end(self, value):
176         return self.datetime_normalizer(value)
177
178     def serialize_due(self, value):
179         return self.timestamp_serializer(value)
180
181     def deserialize_due(self, value):
182         return self.timestamp_deserializer(value)
183
184     def normalize_due(self, value):
185         return self.datetime_normalizer(value)
186
187     def serialize_scheduled(self, value):
188         return self.timestamp_serializer(value)
189
190     def deserialize_scheduled(self, value):
191         return self.timestamp_deserializer(value)
192
193     def normalize_scheduled(self, value):
194         return self.datetime_normalizer(value)
195
196     def serialize_until(self, value):
197         return self.timestamp_serializer(value)
198
199     def deserialize_until(self, value):
200         return self.timestamp_deserializer(value)
201
202     def normalize_until(self, value):
203         return self.datetime_normalizer(value)
204
205     def serialize_wait(self, value):
206         return self.timestamp_serializer(value)
207
208     def deserialize_wait(self, value):
209         return self.timestamp_deserializer(value)
210
211     def normalize_wait(self, value):
212         return self.datetime_normalizer(value)
213
214     def serialize_annotations(self, value):
215         value = value if value is not None else []
216
217         # This may seem weird, but it's correct, we want to export
218         # a list of dicts as serialized value
219         serialized_annotations = [json.loads(annotation.export_data())
220                                   for annotation in value]
221         return serialized_annotations if serialized_annotations else ''
222
223     def deserialize_annotations(self, data):
224         return [TaskAnnotation(self, d) for d in data] if data else []
225
226     def serialize_tags(self, tags):
227         return ','.join(tags) if tags else ''
228
229     def deserialize_tags(self, tags):
230         if isinstance(tags, six.string_types):
231             return tags.split(',') if tags else []
232         return tags or []
233
234     def serialize_depends(self, value):
235         # Return the list of uuids
236         value = value if value is not None else set()
237         return ','.join(task['uuid'] for task in value)
238
239     def deserialize_depends(self, raw_uuids):
240         raw_uuids = raw_uuids or ''  # Convert None to empty string
241         uuids = raw_uuids.split(',')
242         return set(self.warrior.tasks.get(uuid=uuid) for uuid in uuids if uuid)
243
244     def datetime_normalizer(self, value):
245         """
246         Normalizes date/datetime value (considered to come from user input)
247         to localized datetime value. Following conversions happen:
248
249         naive date -> localized datetime with the same date, and time=midnight
250         naive datetime -> localized datetime with the same value
251         localized datetime -> localized datetime (no conversion)
252         """
253
254         if (isinstance(value, datetime.date)
255             and not isinstance(value, datetime.datetime)):
256             # Convert to local midnight
257             value_full = datetime.datetime.combine(value, datetime.time.min)
258             localized = local_zone.localize(value_full)
259         elif isinstance(value, datetime.datetime):
260             if value.tzinfo is None:
261                 # Convert to localized datetime object
262                 localized = local_zone.localize(value)
263             else:
264                 # If the value is already localized, there is no need to change
265                 # time zone at this point. Also None is a valid value too.
266                 localized = value
267         elif (isinstance(value, six.string_types)
268                 and self.warrior.version >= VERSION_2_4_0):
269             # For strings, use 'task calc' to evaluate the string to datetime
270             # available since TW 2.4.0
271             args = value.split()
272             result = self.warrior.execute_command(['calc'] + args)
273             naive = datetime.datetime.strptime(result[0], DATE_FORMAT_CALC)
274             localized = local_zone.localize(naive)
275         else:
276             raise ValueError("Provided value could not be converted to "
277                              "datetime, its type is not supported: {}"
278                              .format(type(value)))
279
280         return localized
281
282     def normalize_uuid(self, value):
283         # Enforce sane UUID
284         if not isinstance(value, six.string_types) or value == '':
285             raise ValueError("UUID must be a valid non-empty string, "
286                              "not: {}".format(value))
287
288         return value
289
290
291 class TaskResource(SerializingObject):
292     read_only_fields = []
293
294     def _load_data(self, data):
295         self._data = dict((key, self._deserialize(key, value))
296                           for key, value in data.items())
297         # We need to use a copy for original data, so that changes
298         # are not propagated.
299         self._original_data = copy.deepcopy(self._data)
300
301     def _update_data(self, data, update_original=False, remove_missing=False):
302         """
303         Low level update of the internal _data dict. Data which are coming as
304         updates should already be serialized. If update_original is True, the
305         original_data dict is updated as well.
306         """
307         self._data.update(dict((key, self._deserialize(key, value))
308                                for key, value in data.items()))
309
310         # In certain situations, we want to treat missing keys as removals
311         if remove_missing:
312             for key in set(self._data.keys()) - set(data.keys()):
313                 self._data[key] = None
314
315         if update_original:
316             self._original_data = copy.deepcopy(self._data)
317
318
319     def __getitem__(self, key):
320         # This is a workaround to make TaskResource non-iterable
321         # over simple index-based iteration
322         try:
323             int(key)
324             raise StopIteration
325         except ValueError:
326             pass
327
328         if key not in self._data:
329             self._data[key] = self._deserialize(key, None)
330
331         return self._data.get(key)
332
333     def __setitem__(self, key, value):
334         if key in self.read_only_fields:
335             raise RuntimeError('Field \'%s\' is read-only' % key)
336
337         # Normalize the user input before saving it
338         value = self._normalize(key, value)
339         self._data[key] = value
340
341     def __str__(self):
342         s = six.text_type(self.__unicode__())
343         if not six.PY3:
344             s = s.encode('utf-8')
345         return s
346
347     def __repr__(self):
348         return str(self)
349
350     def export_data(self):
351         """
352         Exports current data contained in the Task as JSON
353         """
354
355         # We need to remove spaces for TW-1504, use custom separators
356         data_tuples = ((key, self._serialize(key, value))
357                        for key, value in six.iteritems(self._data))
358
359         # Empty string denotes empty serialized value, we do not want
360         # to pass that to TaskWarrior.
361         data_tuples = filter(lambda t: t[1] is not '', data_tuples)
362         data = dict(data_tuples)
363         return json.dumps(data, separators=(',',':'))
364
365     @property
366     def _modified_fields(self):
367         writable_fields = set(self._data.keys()) - set(self.read_only_fields)
368         for key in writable_fields:
369             new_value = self._data.get(key)
370             old_value = self._original_data.get(key)
371
372             # Make sure not to mark data removal as modified field if the
373             # field originally had some empty value
374             if key in self._data and not new_value and not old_value:
375                 continue
376
377             if new_value != old_value:
378                 yield key
379
380     @property
381     def modified(self):
382         return bool(list(self._modified_fields))
383
384
385 class TaskAnnotation(TaskResource):
386     read_only_fields = ['entry', 'description']
387
388     def __init__(self, task, data={}):
389         self.task = task
390         self._load_data(data)
391         super(TaskAnnotation, self).__init__(task.warrior)
392
393     def remove(self):
394         self.task.remove_annotation(self)
395
396     def __unicode__(self):
397         return self['description']
398
399     def __eq__(self, other):
400         # consider 2 annotations equal if they belong to the same task, and
401         # their data dics are the same
402         return self.task == other.task and self._data == other._data
403
404     __repr__ = __unicode__
405
406
407 class Task(TaskResource):
408     read_only_fields = ['id', 'entry', 'urgency', 'uuid', 'modified']
409
410     class DoesNotExist(Exception):
411         pass
412
413     class CompletedTask(Exception):
414         """
415         Raised when the operation cannot be performed on the completed task.
416         """
417         pass
418
419     class DeletedTask(Exception):
420         """
421         Raised when the operation cannot be performed on the deleted task.
422         """
423         pass
424
425     class ActiveTask(Exception):
426         """
427         Raised when the operation cannot be performed on the active task.
428         """
429         pass
430
431     class InactiveTask(Exception):
432         """
433         Raised when the operation cannot be performed on an inactive task.
434         """
435         pass
436
437     class NotSaved(Exception):
438         """
439         Raised when the operation cannot be performed on the task, because
440         it has not been saved to TaskWarrior yet.
441         """
442         pass
443
444     @classmethod
445     def from_input(cls, input_file=sys.stdin, modify=None, warrior=None):
446         """
447         Creates a Task object, directly from the stdin, by reading one line.
448         If modify=True, two lines are used, first line interpreted as the
449         original state of the Task object, and second line as its new,
450         modified value. This is consistent with the TaskWarrior's hook
451         system.
452
453         Object created by this method should not be saved, deleted
454         or refreshed, as t could create a infinite loop. For this
455         reason, TaskWarrior instance is set to None.
456
457         Input_file argument can be used to specify the input file,
458         but defaults to sys.stdin.
459         """
460
461         # Detect the hook type if not given directly
462         name = os.path.basename(sys.argv[0])
463         modify = name.startswith('on-modify') if modify is None else modify
464
465         # Create the TaskWarrior instance if none passed
466         if warrior is None:
467             hook_parent_dir = os.path.dirname(os.path.dirname(sys.argv[0]))
468             warrior = TaskWarrior(data_location=hook_parent_dir)
469
470         # TaskWarrior instance is set to None
471         task = cls(warrior)
472
473         # Load the data from the input
474         task._load_data(json.loads(input_file.readline().strip()))
475
476         # If this is a on-modify event, we are provided with additional
477         # line of input, which provides updated data
478         if modify:
479             task._update_data(json.loads(input_file.readline().strip()),
480                               remove_missing=True)
481
482         return task
483
484     def __init__(self, warrior, **kwargs):
485         super(Task, self).__init__(warrior)
486
487         # Check that user is not able to set read-only value in __init__
488         for key in kwargs.keys():
489             if key in self.read_only_fields:
490                 raise RuntimeError('Field \'%s\' is read-only' % key)
491
492         # We serialize the data in kwargs so that users of the library
493         # do not have to pass different data formats via __setitem__ and
494         # __init__ methods, that would be confusing
495
496         # Rather unfortunate syntax due to python2.6 comaptiblity
497         self._data = dict((key, self._normalize(key, value))
498                           for (key, value) in six.iteritems(kwargs))
499         self._original_data = copy.deepcopy(self._data)
500
501         # Provide read only access to the original data
502         self.original = ReadOnlyDictView(self._original_data)
503
504     def __unicode__(self):
505         return self['description']
506
507     def __eq__(self, other):
508         if self['uuid'] and other['uuid']:
509             # For saved Tasks, just define equality by equality of uuids
510             return self['uuid'] == other['uuid']
511         else:
512             # If the tasks are not saved, compare the actual instances
513             return id(self) == id(other)
514
515
516     def __hash__(self):
517         if self['uuid']:
518             # For saved Tasks, just define equality by equality of uuids
519             return self['uuid'].__hash__()
520         else:
521             # If the tasks are not saved, return hash of instance id
522             return id(self).__hash__()
523
524     @property
525     def completed(self):
526         return self['status'] == six.text_type('completed')
527
528     @property
529     def deleted(self):
530         return self['status'] == six.text_type('deleted')
531
532     @property
533     def waiting(self):
534         return self['status'] == six.text_type('waiting')
535
536     @property
537     def pending(self):
538         return self['status'] == six.text_type('pending')
539
540     @property
541     def active(self):
542         return self['start'] is not None
543
544     @property
545     def saved(self):
546         return self['uuid'] is not None or self['id'] is not None
547
548     def serialize_depends(self, cur_dependencies):
549         # Check that all the tasks are saved
550         for task in (cur_dependencies or set()):
551             if not task.saved:
552                 raise Task.NotSaved('Task \'%s\' needs to be saved before '
553                                     'it can be set as dependency.' % task)
554
555         return super(Task, self).serialize_depends(cur_dependencies)
556
557     def format_depends(self):
558         # We need to generate added and removed dependencies list,
559         # since Taskwarrior does not accept redefining dependencies.
560
561         # This cannot be part of serialize_depends, since we need
562         # to keep a list of all depedencies in the _data dictionary,
563         # not just currently added/removed ones
564
565         old_dependencies = self._original_data.get('depends', set())
566
567         added = self['depends'] - old_dependencies
568         removed = old_dependencies - self['depends']
569
570         # Removed dependencies need to be prefixed with '-'
571         return 'depends:' + ','.join(
572                 [t['uuid'] for t in added] +
573                 ['-' + t['uuid'] for t in removed]
574             )
575
576     def format_description(self):
577         # Task version older than 2.4.0 ignores first word of the
578         # task description if description: prefix is used
579         if self.warrior.version < VERSION_2_4_0:
580             return self._data['description']
581         else:
582             return six.u("description:'{0}'").format(self._data['description'] or '')
583
584     def delete(self):
585         if not self.saved:
586             raise Task.NotSaved("Task needs to be saved before it can be deleted")
587
588         # Refresh the status, and raise exception if the task is deleted
589         self.refresh(only_fields=['status'])
590
591         if self.deleted:
592             raise Task.DeletedTask("Task was already deleted")
593
594         self.warrior.execute_command([self['uuid'], 'delete'])
595
596         # Refresh the status again, so that we have updated info stored
597         self.refresh(only_fields=['status', 'start', 'end'])
598
599     def start(self):
600         if not self.saved:
601             raise Task.NotSaved("Task needs to be saved before it can be started")
602
603         # Refresh, and raise exception if task is already completed/deleted
604         self.refresh(only_fields=['status'])
605
606         if self.completed:
607             raise Task.CompletedTask("Cannot start a completed task")
608         elif self.deleted:
609             raise Task.DeletedTask("Deleted task cannot be started")
610         elif self.active:
611             raise Task.ActiveTask("Task is already active")
612
613         self.warrior.execute_command([self['uuid'], 'start'])
614
615         # Refresh the status again, so that we have updated info stored
616         self.refresh(only_fields=['status', 'start'])
617
618     def stop(self):
619         if not self.saved:
620             raise Task.NotSaved("Task needs to be saved before it can be stopped")
621
622         # Refresh, and raise exception if task is already completed/deleted
623         self.refresh(only_fields=['status'])
624
625         if not self.active:
626             raise Task.InactiveTask("Cannot stop an inactive task")
627
628         self.warrior.execute_command([self['uuid'], 'stop'])
629
630         # Refresh the status again, so that we have updated info stored
631         self.refresh(only_fields=['status', 'start'])
632
633     def done(self):
634         if not self.saved:
635             raise Task.NotSaved("Task needs to be saved before it can be completed")
636
637         # Refresh, and raise exception if task is already completed/deleted
638         self.refresh(only_fields=['status'])
639
640         if self.completed:
641             raise Task.CompletedTask("Cannot complete a completed task")
642         elif self.deleted:
643             raise Task.DeletedTask("Deleted task cannot be completed")
644
645         # Older versions of TW do not stop active task at completion
646         if self.warrior.version < VERSION_2_4_0 and self.active:
647             self.stop()
648
649         self.warrior.execute_command([self['uuid'], 'done'])
650
651         # Refresh the status again, so that we have updated info stored
652         self.refresh(only_fields=['status', 'start', 'end'])
653
654     def save(self):
655         if self.saved and not self.modified:
656             return
657
658         args = [self['uuid'], 'modify'] if self.saved else ['add']
659         args.extend(self._get_modified_fields_as_args())
660         output = self.warrior.execute_command(args)
661
662         # Parse out the new ID, if the task is being added for the first time
663         if not self.saved:
664             id_lines = [l for l in output if l.startswith('Created task ')]
665
666             # Complain loudly if it seems that more tasks were created
667             # Should not happen
668             if len(id_lines) != 1 or len(id_lines[0].split(' ')) != 3:
669                 raise TaskWarriorException("Unexpected output when creating "
670                                            "task: %s" % '\n'.join(id_lines))
671
672             # Circumvent the ID storage, since ID is considered read-only
673             self._data['id'] = int(id_lines[0].split(' ')[2].rstrip('.'))
674
675         # Refreshing is very important here, as not only modification time
676         # is updated, but arbitrary attribute may have changed due hooks
677         # altering the data before saving
678         self.refresh()
679
680     def add_annotation(self, annotation):
681         if not self.saved:
682             raise Task.NotSaved("Task needs to be saved to add annotation")
683
684         args = [self['uuid'], 'annotate', annotation]
685         self.warrior.execute_command(args)
686         self.refresh(only_fields=['annotations'])
687
688     def remove_annotation(self, annotation):
689         if not self.saved:
690             raise Task.NotSaved("Task needs to be saved to remove annotation")
691
692         if isinstance(annotation, TaskAnnotation):
693             annotation = annotation['description']
694         args = [self['uuid'], 'denotate', annotation]
695         self.warrior.execute_command(args)
696         self.refresh(only_fields=['annotations'])
697
698     def _get_modified_fields_as_args(self):
699         args = []
700
701         def add_field(field):
702             # Add the output of format_field method to args list (defaults to
703             # field:value)
704             serialized_value = self._serialize(field, self._data[field])
705
706             # Empty values should not be enclosed in quotation marks, see
707             # TW-1510
708             if serialized_value is '':
709                 escaped_serialized_value = ''
710             else:
711                 escaped_serialized_value = six.u("'{0}'").format(serialized_value)
712
713             format_default = lambda: six.u("{0}:{1}").format(field,
714                                                       escaped_serialized_value)
715
716             format_func = getattr(self, 'format_{0}'.format(field),
717                                   format_default)
718
719             args.append(format_func())
720
721         # If we're modifying saved task, simply pass on all modified fields
722         if self.saved:
723             for field in self._modified_fields:
724                 add_field(field)
725         # For new tasks, pass all fields that make sense
726         else:
727             for field in self._data.keys():
728                 if field in self.read_only_fields:
729                     continue
730                 add_field(field)
731
732         return args
733
734     def refresh(self, only_fields=[]):
735         # Raise error when trying to refresh a task that has not been saved
736         if not self.saved:
737             raise Task.NotSaved("Task needs to be saved to be refreshed")
738
739         # We need to use ID as backup for uuid here for the refreshes
740         # of newly saved tasks. Any other place in the code is fine
741         # with using UUID only.
742         args = [self['uuid'] or self['id'], 'export']
743         new_data = json.loads(self.warrior.execute_command(args)[0])
744         if only_fields:
745             to_update = dict(
746                 [(k, new_data.get(k)) for k in only_fields])
747             self._update_data(to_update, update_original=True)
748         else:
749             self._load_data(new_data)
750
751 class TaskFilter(SerializingObject):
752     """
753     A set of parameters to filter the task list with.
754     """
755
756     def __init__(self, warrior, filter_params=[]):
757         self.filter_params = filter_params
758         super(TaskFilter, self).__init__(warrior)
759
760     def add_filter(self, filter_str):
761         self.filter_params.append(filter_str)
762
763     def add_filter_param(self, key, value):
764         key = key.replace('__', '.')
765
766         # Replace the value with empty string, since that is the
767         # convention in TW for empty values
768         attribute_key = key.split('.')[0]
769
770         # Since this is user input, we need to normalize before we serialize
771         value = self._normalize(attribute_key, value)
772         value = self._serialize(attribute_key, value)
773
774         # If we are filtering by uuid:, do not use uuid keyword
775         # due to TW-1452 bug
776         if key == 'uuid':
777             self.filter_params.insert(0, value)
778         else:
779             # Surround value with aphostrophes unless it's a empty string
780             value = "'%s'" % value if value else ''
781
782             # We enforce equality match by using 'is' (or 'none') modifier
783             # Without using this syntax, filter fails due to TW-1479
784             modifier = '.is' if value else '.none'
785             key = key + modifier if '.' not in key else key
786
787             self.filter_params.append(six.u("{0}:{1}").format(key, value))
788
789     def get_filter_params(self):
790         return [f for f in self.filter_params if f]
791
792     def clone(self):
793         c = self.__class__(self.warrior)
794         c.filter_params = list(self.filter_params)
795         return c
796
797
798 class TaskQuerySet(object):
799     """
800     Represents a lazy lookup for a task objects.
801     """
802
803     def __init__(self, warrior=None, filter_obj=None):
804         self.warrior = warrior
805         self._result_cache = None
806         self.filter_obj = filter_obj or TaskFilter(warrior)
807
808     def __deepcopy__(self, memo):
809         """
810         Deep copy of a QuerySet doesn't populate the cache
811         """
812         obj = self.__class__()
813         for k, v in self.__dict__.items():
814             if k in ('_iter', '_result_cache'):
815                 obj.__dict__[k] = None
816             else:
817                 obj.__dict__[k] = copy.deepcopy(v, memo)
818         return obj
819
820     def __repr__(self):
821         data = list(self[:REPR_OUTPUT_SIZE + 1])
822         if len(data) > REPR_OUTPUT_SIZE:
823             data[-1] = "...(remaining elements truncated)..."
824         return repr(data)
825
826     def __len__(self):
827         if self._result_cache is None:
828             self._result_cache = list(self)
829         return len(self._result_cache)
830
831     def __iter__(self):
832         if self._result_cache is None:
833             self._result_cache = self._execute()
834         return iter(self._result_cache)
835
836     def __getitem__(self, k):
837         if self._result_cache is None:
838             self._result_cache = list(self)
839         return self._result_cache.__getitem__(k)
840
841     def __bool__(self):
842         if self._result_cache is not None:
843             return bool(self._result_cache)
844         try:
845             next(iter(self))
846         except StopIteration:
847             return False
848         return True
849
850     def __nonzero__(self):
851         return type(self).__bool__(self)
852
853     def _clone(self, klass=None, **kwargs):
854         if klass is None:
855             klass = self.__class__
856         filter_obj = self.filter_obj.clone()
857         c = klass(warrior=self.warrior, filter_obj=filter_obj)
858         c.__dict__.update(kwargs)
859         return c
860
861     def _execute(self):
862         """
863         Fetch the tasks which match the current filters.
864         """
865         return self.warrior.filter_tasks(self.filter_obj)
866
867     def all(self):
868         """
869         Returns a new TaskQuerySet that is a copy of the current one.
870         """
871         return self._clone()
872
873     def pending(self):
874         return self.filter(status=PENDING)
875
876     def completed(self):
877         return self.filter(status=COMPLETED)
878
879     def filter(self, *args, **kwargs):
880         """
881         Returns a new TaskQuerySet with the given filters added.
882         """
883         clone = self._clone()
884         for f in args:
885             clone.filter_obj.add_filter(f)
886         for key, value in kwargs.items():
887             clone.filter_obj.add_filter_param(key, value)
888         return clone
889
890     def get(self, **kwargs):
891         """
892         Performs the query and returns a single object matching the given
893         keyword arguments.
894         """
895         clone = self.filter(**kwargs)
896         num = len(clone)
897         if num == 1:
898             return clone._result_cache[0]
899         if not num:
900             raise Task.DoesNotExist(
901                 'Task matching query does not exist. '
902                 'Lookup parameters were {0}'.format(kwargs))
903         raise ValueError(
904             'get() returned more than one Task -- it returned {0}! '
905             'Lookup parameters were {1}'.format(num, kwargs))
906
907
908 class TaskWarrior(object):
909     def __init__(self, data_location=None, create=True, taskrc_location='~/.taskrc'):
910         self.taskrc_location = os.path.expanduser(taskrc_location)
911
912         # If taskrc does not exist, pass / to use defaults and avoid creating
913         # dummy .taskrc file by TaskWarrior
914         if not os.path.exists(self.taskrc_location):
915             self.taskrc_location = '/'
916
917         self.version = self._get_version()
918         self.config = {
919             'confirmation': 'no',
920             'dependency.confirmation': 'no',  # See TW-1483 or taskrc man page
921             'recurrence.confirmation': 'no',  # Necessary for modifying R tasks
922             # 2.4.3 onwards supports 0 as infite bulk, otherwise set just
923             # arbitrary big number which is likely to be large enough
924             'bulk': 0 if self.version >= VERSION_2_4_3 else 100000,
925         }
926
927         # Set data.location override if passed via kwarg
928         if data_location is not None:
929             data_location = os.path.expanduser(data_location)
930             if create and not os.path.exists(data_location):
931                 os.makedirs(data_location)
932             self.config['data.location'] = data_location
933
934         self.tasks = TaskQuerySet(self)
935
936     def _get_command_args(self, args, config_override={}):
937         command_args = ['task', 'rc:{0}'.format(self.taskrc_location)]
938         config = self.config.copy()
939         config.update(config_override)
940         for item in config.items():
941             command_args.append('rc.{0}={1}'.format(*item))
942         command_args.extend(map(six.text_type, args))
943         return command_args
944
945     def _get_version(self):
946         p = subprocess.Popen(
947                 ['task', '--version'],
948                 stdout=subprocess.PIPE,
949                 stderr=subprocess.PIPE)
950         stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
951         return stdout.strip('\n')
952
953     def get_config(self):
954         raw_output = self.execute_command(
955                 ['show'],
956                 config_override={'verbose': 'nothing'}
957             )
958
959         config = dict()
960         config_regex = re.compile(r'^(?P<key>[_a-z\.]+)\s+(?P<value>[^\s].+$)')
961
962         for line in raw_output:
963             match = config_regex.match(line)
964             if match:
965                 config[match.group('key')] = match.group('value').strip()
966
967         return config
968
969     def execute_command(self, args, config_override={}, allow_failure=True,
970                         return_all=False):
971         command_args = self._get_command_args(
972             args, config_override=config_override)
973         logger.debug(' '.join(command_args))
974         p = subprocess.Popen(command_args, stdout=subprocess.PIPE,
975                              stderr=subprocess.PIPE)
976         stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
977         if p.returncode and allow_failure:
978             if stderr.strip():
979                 error_msg = stderr.strip()
980             else:
981                 error_msg = stdout.strip()
982             raise TaskWarriorException(error_msg)
983
984         # Return all whole triplet only if explicitly asked for
985         if not return_all:
986             return stdout.rstrip().split('\n')
987         else:
988             return (stdout.rstrip().split('\n'),
989                     stderr.rstrip().split('\n'),
990                     p.returncode)
991
992     def enforce_recurrence(self):
993         # Run arbitrary report command which will trigger generation
994         # of recurrent tasks.
995
996         # Only necessary for TW up to 2.4.1, fixed in 2.4.2.
997         if self.version < VERSION_2_4_2:
998             self.execute_command(['next'], allow_failure=False)
999
1000     def filter_tasks(self, filter_obj):
1001         self.enforce_recurrence()
1002         args = ['export', '--'] + filter_obj.get_filter_params()
1003         tasks = []
1004         for line in self.execute_command(args):
1005             if line:
1006                 data = line.strip(',')
1007                 try:
1008                     filtered_task = Task(self)
1009                     filtered_task._load_data(json.loads(data))
1010                     tasks.append(filtered_task)
1011                 except ValueError:
1012                     raise TaskWarriorException('Invalid JSON: %s' % data)
1013         return tasks
1014
1015     def merge_with(self, path, push=False):
1016         path = path.rstrip('/') + '/'
1017         self.execute_command(['merge', path], config_override={
1018             'merge.autopush': 'yes' if push else 'no',
1019         })
1020
1021     def undo(self):
1022         self.execute_command(['undo'])