from concurrent.futures import Executor, ProcessPoolExecutor
from enum import Enum, Flag
from functools import partial, wraps
+import io
import keyword
import logging
from multiprocessing import Manager
"""
if src.suffix == ".pyi":
mode |= FileMode.PYI
- with tokenize.open(src) as src_buffer:
- src_contents = src_buffer.read()
+
+ with open(src, "rb") as buf:
+ newline, encoding, src_contents = prepare_input(buf.read())
try:
dst_contents = format_file_contents(
src_contents, line_length=line_length, fast=fast, mode=mode
return False
if write_back == write_back.YES:
- with open(src, "w", encoding=src_buffer.encoding) as f:
+ with open(src, "w", encoding=encoding, newline=newline) as f:
f.write(dst_contents)
elif write_back == write_back.DIFF:
src_name = f"{src} (original)"
if lock:
lock.acquire()
try:
- sys.stdout.write(diff_contents)
+ f = io.TextIOWrapper(
+ sys.stdout.buffer,
+ encoding=encoding,
+ newline=newline,
+ write_through=True,
+ )
+ f.write(diff_contents)
+ f.detach()
finally:
if lock:
lock.release()
`line_length`, `fast`, and `mode` arguments are passed to
:func:`format_file_contents`.
"""
- src = sys.stdin.read()
+ newline, encoding, src = prepare_input(sys.stdin.buffer.read())
dst = src
try:
dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
finally:
if write_back == WriteBack.YES:
- sys.stdout.write(dst)
+ f = io.TextIOWrapper(
+ sys.stdout.buffer,
+ encoding=encoding,
+ newline=newline,
+ write_through=True,
+ )
+ f.write(dst)
+ f.detach()
elif write_back == WriteBack.DIFF:
src_name = "<stdin> (original)"
dst_name = "<stdin> (formatted)"
- sys.stdout.write(diff(src, dst, src_name, dst_name))
+ f = io.TextIOWrapper(
+ sys.stdout.buffer,
+ encoding=encoding,
+ newline=newline,
+ write_through=True,
+ )
+ f.write(diff(src, dst, src_name, dst_name))
+ f.detach()
def format_file_contents(
return dst_contents
def prepare_input(src: bytes) -> Tuple[str, str, str]:
    """Analyze `src` and return a tuple of (newline, encoding, decoded_contents).

    `newline` is CRLF if the first line of `src` ends with CRLF, and LF
    otherwise. Only the first line is inspected, so files with mixed line
    endings are classified by their first line alone.

    `encoding` is the source encoding reported by `tokenize.detect_encoding`
    (UTF-8 BOM or PEP 263 coding cookie, defaulting to "utf-8").

    `decoded_contents` is `src` decoded with that encoding using universal
    newlines, i.e. it contains only LF line endings.

    Empty input (or input consisting solely of a UTF-8 BOM) yields an LF
    newline, the detected encoding, and an empty string.

    Raises SyntaxError (from `tokenize.detect_encoding`) for an invalid or
    conflicting encoding declaration, and UnicodeDecodeError if the bytes
    cannot be decoded with the detected encoding.
    """
    srcbuf = io.BytesIO(src)
    encoding, lines = tokenize.detect_encoding(srcbuf.readline)
    if not lines:
        # detect_encoding() returns no lines when there is no first line to
        # read; indexing lines[0] would raise IndexError, so default to LF.
        return "\n", encoding, ""
    newline = "\r\n" if lines[0][-2:] == b"\r\n" else "\n"
    # Rewind: detect_encoding() consumed up to two lines from the buffer.
    srcbuf.seek(0)
    # The default newline=None enables universal-newline translation, so the
    # decoded text contains only LF. The context manager closes the wrapper.
    with io.TextIOWrapper(srcbuf, encoding) as tiow:
        return newline, encoding, tiow.read()
+
+
GRAMMARS = [
pygram.python_grammar_no_print_statement_no_exec_statement,
pygram.python_grammar_no_print_statement,
"""Given a string with source, return the lib2to3 Node."""
grammar = pygram.python_grammar_no_print_statement
if src_txt[-1] != "\n":
- nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
- src_txt += nl
+ src_txt += "\n"
for grammar in GRAMMARS:
drv = driver.Driver(grammar, pytree.convert)
try:
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from functools import partial
-from io import StringIO
+from io import BytesIO, TextIOWrapper
import os
from pathlib import Path
import sys
source, expected = read_data("../black")
hold_stdin, hold_stdout = sys.stdin, sys.stdout
try:
- sys.stdin, sys.stdout = StringIO(source), StringIO()
- sys.stdin.name = "<stdin>"
+ sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
+ sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
+ sys.stdin.buffer.name = "<stdin>" # type: ignore
black.format_stdin_to_stdout(
line_length=ll, fast=True, write_back=black.WriteBack.YES
)
expected, _ = read_data("expression.diff")
hold_stdin, hold_stdout = sys.stdin, sys.stdout
try:
- sys.stdin, sys.stdout = StringIO(source), StringIO()
- sys.stdin.name = "<stdin>"
+ sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
+ sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
+ sys.stdin.buffer.name = "<stdin>" # type: ignore
black.format_stdin_to_stdout(
line_length=ll, fast=True, write_back=black.WriteBack.DIFF
)
tmp_file = Path(black.dump_to_file(source))
hold_stdout = sys.stdout
try:
- sys.stdout = StringIO()
+ sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
sys.stdout.seek(0)
actual = sys.stdout.read()
result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
self.assertEqual(result.exit_code, 2)
def test_preserves_line_endings(self) -> None:
    """Reformatting a file in place must keep its line-ending style.

    For both LF and CRLF input, the file is badly formatted on purpose so
    that black has to rewrite it. Every line ending in the rewritten file
    must then match the input's style.
    """
    with TemporaryDirectory() as workspace:
        test_file = Path(workspace) / "test.py"
        for nl in ["\n", "\r\n"]:
            with self.subTest(newline=repr(nl)):
                contents = nl.join(["def f( ):", " pass"])
                test_file.write_bytes(contents.encode())
                # Require an actual rewrite. Otherwise the untouched LF input
                # would satisfy the checks below without black doing anything.
                self.assertTrue(ff(test_file, write_back=black.WriteBack.YES))
                updated_contents: bytes = test_file.read_bytes()
                self.assertNotEqual(contents.encode(), updated_contents)
                self.assertIn(nl.encode(), updated_contents)  # type: ignore
                # Every LF must carry the expected prefix: no CR at all for
                # LF files, and a CR before every LF for CRLF files.
                lf_count = updated_contents.count(b"\n")
                crlf_count = updated_contents.count(b"\r\n")
                if nl == "\n":
                    self.assertEqual(crlf_count, 0)
                else:
                    self.assertEqual(crlf_count, lf_count)
+
if __name__ == "__main__":
unittest.main()