git.madduck.net Git - etc/taskwarrior.git/blobdiff - tasklib/task.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

tests: Test execute command with return_all flag
[etc/taskwarrior.git] / tasklib / task.py
index 79899ba6b38cd349216f2e362e9bd90e9217e335..bb535172f1f48ebd6988e7243b81b486778cd783 100644 (file)
@@ -560,7 +560,7 @@ class Task(TaskResource):
         if self.warrior.version < VERSION_2_4_0:
             return self._data['description']
         else:
-            return "description:'{0}'".format(self._data['description'] or '')
+            return six.u("description:'{0}'").format(self._data['description'] or '')
 
     def delete(self):
         if not self.saved:
@@ -668,9 +668,9 @@ class Task(TaskResource):
             if serialized_value is '':
                 escaped_serialized_value = ''
             else:
-                escaped_serialized_value = "'{0}'".format(serialized_value)
+                escaped_serialized_value = six.u("'{0}'").format(serialized_value)
 
-            format_default = lambda: "{0}:{1}".format(field,
+            format_default = lambda: six.u("{0}:{1}").format(field,
                                                       escaped_serialized_value)
 
             format_func = getattr(self, 'format_{0}'.format(field),
@@ -744,7 +744,7 @@ class TaskFilter(SerializingObject):
             modifier = '.is' if value else '.none'
             key = key + modifier if '.' not in key else key
 
-            self.filter_params.append("{0}:{1}".format(key, value))
+            self.filter_params.append(six.u("{0}:{1}").format(key, value))
 
     def get_filter_params(self):
         return [f for f in self.filter_params if f]
@@ -899,7 +899,7 @@ class TaskWarrior(object):
         config.update(config_override)
         for item in config.items():
             command_args.append('rc.{0}={1}'.format(*item))
-        command_args.extend(map(str, args))
+        command_args.extend(map(six.text_type, args))
         return command_args
 
     def _get_version(self):
@@ -910,7 +910,8 @@ class TaskWarrior(object):
         stdout, stderr = [x.decode('utf-8') for x in p.communicate()]
         return stdout.strip('\n')
 
-    def execute_command(self, args, config_override={}, allow_failure=True):
+    def execute_command(self, args, config_override={}, allow_failure=True,
+                        return_all=False):
         command_args = self._get_command_args(
             args, config_override=config_override)
         logger.debug(' '.join(command_args))
@@ -923,7 +924,14 @@ class TaskWarrior(object):
             else:
                 error_msg = stdout.strip()
             raise TaskWarriorException(error_msg)
-        return stdout.rstrip().split('\n')
+
+        # Return all whole triplet only if explicitly asked for
+        if not return_all:
+            return stdout.rstrip().split('\n')
+        else:
+            return (stdout.rstrip().split('\n'),
+                    stderr.rstrip().split('\n'),
+                    p.returncode)
 
     def enforce_recurrence(self):
         # Run arbitrary report command which will trigger generation