X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/2fa31ff31469e587ea23cb86308495c4673b5ddd..5bc40707afa5fb53bbc2484ed34f69b011b98172:/black.py diff --git a/black.py b/black.py index b2b2db7..203fbfa 100644 --- a/black.py +++ b/black.py @@ -74,7 +74,9 @@ class CannotSplit(Exception): @click.argument( 'src', nargs=-1, - type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True), + type=click.Path( + exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True + ), ) @click.pass_context def main( @@ -89,6 +91,8 @@ def main( elif p.is_file(): # if a file was explicitly given, we don't care about its extension sources.append(p) + elif s == '-': + sources.append(Path('-')) else: err(f'invalid path: {s}') if len(sources) == 0: @@ -97,9 +101,12 @@ def main( p = sources[0] report = Report() try: - changed = format_file_in_place( - p, line_length=line_length, fast=fast, write_back=not check - ) + if not p.is_file() and str(p) == '-': + changed = format_stdin_to_stdout(line_length=line_length, fast=fast) + else: + changed = format_file_in_place( + p, line_length=line_length, fast=fast, write_back=not check + ) report.done(p, changed) except Exception as exc: report.failed(p, str(exc)) @@ -156,34 +163,50 @@ def format_file_in_place( src: Path, line_length: int, fast: bool, write_back: bool = False ) -> bool: """Format the file and rewrite if changed. Return True if changed.""" + with tokenize.open(src) as src_buffer: + src_contents = src_buffer.read() try: - contents, encoding = format_file(src, line_length=line_length, fast=fast) + contents = format_file_contents( + src_contents, line_length=line_length, fast=fast + ) except NothingChanged: return False if write_back: - with open(src, "w", encoding=encoding) as f: + with open(src, "w", encoding=src_buffer.encoding) as f: f.write(contents) return True -def format_file( - src: Path, line_length: int, fast: bool -) -> Tuple[FileContent, Encoding]: +def format_stdin_to_stdout(line_length: int, fast: bool) -> bool: + """Format file on stdin and pipe output to stdout. Return True if changed.""" + contents = sys.stdin.read() + try: + contents = format_file_contents(contents, line_length=line_length, fast=fast) + return True + + except NothingChanged: + return False + + finally: + sys.stdout.write(contents) + + +def format_file_contents( + src_contents: str, line_length: int, fast: bool +) -> FileContent: """Reformats a file and returns its contents and encoding.""" - with tokenize.open(src) as src_buffer: - src_contents = src_buffer.read() if src_contents.strip() == '': - raise NothingChanged(src) + raise NothingChanged dst_contents = format_str(src_contents, line_length=line_length) if src_contents == dst_contents: - raise NothingChanged(src) + raise NothingChanged if not fast: assert_equivalent(src_contents, dst_contents) assert_stable(src_contents, dst_contents, line_length=line_length) - return dst_contents, src_buffer.encoding + return dst_contents def format_str(src_contents: str, line_length: int) -> FileContent: @@ -463,8 +486,7 @@ class Line: return ( (first_leaf.type == token.NAME and first_leaf.value == 'def') or ( - first_leaf.type == token.NAME - and first_leaf.value == 'async' + first_leaf.type == token.ASYNC and second_leaf is not None and second_leaf.type == token.NAME and second_leaf.value == 'def' @@ -723,8 +745,9 @@ class LineGenerator(Visitor[Line]): def visit_default(self, node: LN) -> Iterator[Line]: if isinstance(node, Leaf): + any_open_brackets = self.current_line.bracket_tracker.any_open_brackets() for comment in generate_comments(node): - if self.current_line.bracket_tracker.any_open_brackets(): + if any_open_brackets: # any comment within brackets is subject to splitting self.current_line.append(comment) elif comment.type == token.COMMENT: @@ -736,7 +759,7 @@ class LineGenerator(Visitor[Line]): # regular standalone comment, to be processed later (see # docstring in `generate_comments()` self.standalone_comments.append(comment) - normalize_prefix(node) + normalize_prefix(node, inside_brackets=any_open_brackets) if node.type not in WHITESPACE: for comment in self.standalone_comments: yield from self.line() @@ -793,7 +816,7 @@ class LineGenerator(Visitor[Line]): for child in children: yield from self.visit(child) - if child.type == token.NAME and child.value == 'async': # type: ignore + if child.type == token.ASYNC: break internal_stmt = next(children) @@ -849,7 +872,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 return DOUBLESPACE assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}" - if t == token.COLON and p.type != syms.subscript: + if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}: return NO prev = leaf.prev_sibling @@ -882,7 +905,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 return NO elif prevp.type == token.COLON: - if prevp.parent and prevp.parent.type == syms.subscript: + if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}: return NO elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}: @@ -1216,7 +1239,7 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: current_leaves = body_leaves # Since body is a new indent level, remove spurious leading whitespace. if body_leaves: - normalize_prefix(body_leaves[0]) + normalize_prefix(body_leaves[0], inside_brackets=True) # Build the new lines. for result, leaves in ( (head, head_leaves), (body, body_leaves), (tail, tail_leaves) @@ -1256,7 +1279,7 @@ def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: head_leaves.reverse() # Since body is a new indent level, remove spurious leading whitespace. if body_leaves: - normalize_prefix(body_leaves[0]) + normalize_prefix(body_leaves[0], inside_brackets=True) # Build the new lines. for result, leaves in ( (head, head_leaves), (body, body_leaves), (tail, tail_leaves) @@ -1320,7 +1343,7 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: trailing_comma_safe = trailing_comma_safe and py36 leaf_priority = delimiters.get(id(leaf)) if leaf_priority == delimiter_priority: - normalize_prefix(current_line.leaves[0]) + normalize_prefix(current_line.leaves[0], inside_brackets=True) yield current_line current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) @@ -1331,7 +1354,7 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: and trailing_comma_safe ): current_line.append(Leaf(token.COMMA, ',')) - normalize_prefix(current_line.leaves[0]) + normalize_prefix(current_line.leaves[0], inside_brackets=True) yield current_line @@ -1349,13 +1372,18 @@ def is_import(leaf: Leaf) -> bool: ) -def normalize_prefix(leaf: Leaf) -> None: - """Leave existing extra newlines for imports. Remove everything else.""" - if is_import(leaf): +def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None: + """Leave existing extra newlines if not `inside_brackets`. + + Remove everything else. Note: don't use backslashes for formatting or + you'll lose your voting rights. + """ + if not inside_brackets: spl = leaf.prefix.split('#', 1) - nl_count = spl[0].count('\n') - leaf.prefix = '\n' * nl_count - return + if '\\' not in spl[0]: + nl_count = spl[0].count('\n') + leaf.prefix = '\n' * nl_count + return leaf.prefix = ''