X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/dbe26161fa68632d608a440666a0960a32630902..00a302560b92951c22f0f4c8d618cf63de39bd57:/black.py
diff --git a/black.py b/black.py
index e59a1e5..36e49b0 100644
--- a/black.py
+++ b/black.py
@@ -4,6 +4,7 @@ from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
 from enum import Enum, Flag
 from functools import partial, wraps
+import io
 import keyword
 import logging
 from multiprocessing import Manager
@@ -465,8 +466,9 @@ def format_file_in_place(
     """
     if src.suffix == ".pyi":
         mode |= FileMode.PYI
-    with tokenize.open(src) as src_buffer:
-        src_contents = src_buffer.read()
+
+    with open(src, "rb") as buf:
+        newline, encoding, src_contents = prepare_input(buf.read())
     try:
         dst_contents = format_file_contents(
             src_contents, line_length=line_length, fast=fast, mode=mode
@@ -475,7 +477,7 @@ def format_file_in_place(
         return False
 
     if write_back == write_back.YES:
-        with open(src, "w", encoding=src_buffer.encoding) as f:
+        with open(src, "w", encoding=encoding, newline=newline) as f:
             f.write(dst_contents)
     elif write_back == write_back.DIFF:
         src_name = f"{src}  (original)"
@@ -484,7 +486,14 @@ def format_file_in_place(
         if lock:
             lock.acquire()
         try:
-            sys.stdout.write(diff_contents)
+            f = io.TextIOWrapper(
+                sys.stdout.buffer,
+                encoding=encoding,
+                newline=newline,
+                write_through=True,
+            )
+            f.write(diff_contents)
+            f.detach()
         finally:
             if lock:
                 lock.release()
@@ -503,7 +512,7 @@ def format_stdin_to_stdout(
     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
     :func:`format_file_contents`.
     """
-    src = sys.stdin.read()
+    newline, encoding, src = prepare_input(sys.stdin.buffer.read())
     dst = src
     try:
         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
@@ -514,11 +523,25 @@ def format_stdin_to_stdout(
 
     finally:
         if write_back == WriteBack.YES:
-            sys.stdout.write(dst)
+            f = io.TextIOWrapper(
+                sys.stdout.buffer,
+                encoding=encoding,
+                newline=newline,
+                write_through=True,
+            )
+            f.write(dst)
+            f.detach()
         elif write_back == WriteBack.DIFF:
             src_name = "<stdin>  (original)"
             dst_name = "<stdin>  (formatted)"
-            sys.stdout.write(diff(src, dst, src_name, dst_name))
+            f = io.TextIOWrapper(
+                sys.stdout.buffer,
+                encoding=encoding,
+                newline=newline,
+                write_through=True,
+            )
+            f.write(diff(src, dst, src_name, dst_name))
+            f.detach()
 
 
 def format_file_contents(
@@ -579,6 +602,25 @@ def format_str(
     return dst_contents
 
 
+def prepare_input(src: bytes) -> Tuple[str, str, str]:
+    """Analyze `src` and return a tuple of (newline, encoding, decoded_contents)
+
+    Where `newline` is either CRLF or LF, and `decoded_contents` is decoded with
+    universal newlines (i.e. only LF). Empty input yields LF and an empty string.
+    """
+    srcbuf = io.BytesIO(src)
+    encoding, lines = tokenize.detect_encoding(srcbuf.readline)
+    if not lines:
+        # detect_encoding() returns no lines for empty input (or a bare BOM),
+        # so there is no first line to sniff the newline style from.
+        return "\n", encoding, ""
+
+    newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
+    srcbuf.seek(0)
+    with io.TextIOWrapper(srcbuf, encoding) as tiow:
+        return newline, encoding, tiow.read()
+
+
 GRAMMARS = [
     pygram.python_grammar_no_print_statement_no_exec_statement,
     pygram.python_grammar_no_print_statement,
@@ -590,8 +632,7 @@ def lib2to3_parse(src_txt: str) -> Node:
     """Given a string with source, return the lib2to3 Node."""
     grammar = pygram.python_grammar_no_print_statement
     if src_txt[-1] != "\n":
-        nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
-        src_txt += nl
+        src_txt += "\n"
     for grammar in GRAMMARS:
         drv = driver.Driver(grammar, pytree.convert)
         try: