from concurrent.futures import Executor, ProcessPoolExecutor
 from enum import Enum, Flag
 from functools import partial, wraps
+import io
 import keyword
 import logging
 from multiprocessing import Manager
     """
     if src.suffix == ".pyi":
         mode |= FileMode.PYI
-    with tokenize.open(src) as src_buffer:
-        src_contents = src_buffer.read()
+
+    with open(src, "rb") as buf:
+        newline, encoding, src_contents = prepare_input(buf.read())
     try:
         dst_contents = format_file_contents(
             src_contents, line_length=line_length, fast=fast, mode=mode
         return False
 
     if write_back == write_back.YES:
-        with open(src, "w", encoding=src_buffer.encoding) as f:
+        with open(src, "w", encoding=encoding, newline=newline) as f:
             f.write(dst_contents)
     elif write_back == write_back.DIFF:
         src_name = f"{src}  (original)"
         if lock:
             lock.acquire()
         try:
-            sys.stdout.write(diff_contents)
+            f = io.TextIOWrapper(
+                sys.stdout.buffer,
+                encoding=encoding,
+                newline=newline,
+                write_through=True,
+            )
+            f.write(diff_contents)
+            f.detach()
         finally:
             if lock:
                 lock.release()
     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
     :func:`format_file_contents`.
     """
-    src = sys.stdin.read()
+    newline, encoding, src = prepare_input(sys.stdin.buffer.read())
     dst = src
     try:
         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
 
     finally:
         if write_back == WriteBack.YES:
-            sys.stdout.write(dst)
+            f = io.TextIOWrapper(
+                sys.stdout.buffer,
+                encoding=encoding,
+                newline=newline,
+                write_through=True,
+            )
+            f.write(dst)
+            f.detach()
         elif write_back == WriteBack.DIFF:
             src_name = "<stdin>  (original)"
             dst_name = "<stdin>  (formatted)"
-            sys.stdout.write(diff(src, dst, src_name, dst_name))
+            f = io.TextIOWrapper(
+                sys.stdout.buffer,
+                encoding=encoding,
+                newline=newline,
+                write_through=True,
+            )
+            f.write(diff(src, dst, src_name, dst_name))
+            f.detach()
 
 
 def format_file_contents(
     return dst_contents
 
 
+def prepare_input(src: bytes) -> Tuple[str, str, str]:
+    """Analyze `src` and return a tuple of (newline, encoding, decoded_contents).
+
+    Where `newline` is either CRLF or LF, and `decoded_contents` is decoded with
+    universal newlines (i.e. only LF). The encoding is detected the same way the
+    interpreter does it (BOM / PEP 263 coding cookie, defaulting to UTF-8).
+
+    Empty input yields an LF newline, the default encoding and an empty string
+    instead of failing, since there is no first line to inspect.
+    """
+    srcbuf = io.BytesIO(src)
+    encoding, lines = tokenize.detect_encoding(srcbuf.readline)
+    if not lines:
+        # detect_encoding() returns no lines for empty input; indexing
+        # lines[0] below would raise IndexError.
+        return "\n", encoding, ""
+
+    # Only the first physical line is inspected; its terminator decides the
+    # line ending used when the formatted result is written back.
+    newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
+    srcbuf.seek(0)
+    with io.TextIOWrapper(srcbuf, encoding) as tiow:
+        return newline, encoding, tiow.read()
+
+
 GRAMMARS = [
     pygram.python_grammar_no_print_statement_no_exec_statement,
     pygram.python_grammar_no_print_statement,
     """Given a string with source, return the lib2to3 Node."""
     grammar = pygram.python_grammar_no_print_statement
     if src_txt[-1] != "\n":
-        nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
-        src_txt += nl
+        src_txt += "\n"
     for grammar in GRAMMARS:
         drv = driver.Driver(grammar, pytree.convert)
         try:
 
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from functools import partial
-from io import StringIO
+from io import BytesIO, TextIOWrapper
 import os
 from pathlib import Path
 import sys
         source, expected = read_data("../black")
         hold_stdin, hold_stdout = sys.stdin, sys.stdout
         try:
-            sys.stdin, sys.stdout = StringIO(source), StringIO()
-            sys.stdin.name = "<stdin>"
+            sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
+            sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
+            sys.stdin.buffer.name = "<stdin>"  # type: ignore
             black.format_stdin_to_stdout(
                 line_length=ll, fast=True, write_back=black.WriteBack.YES
             )
         expected, _ = read_data("expression.diff")
         hold_stdin, hold_stdout = sys.stdin, sys.stdout
         try:
-            sys.stdin, sys.stdout = StringIO(source), StringIO()
-            sys.stdin.name = "<stdin>"
+            sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
+            sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
+            sys.stdin.buffer.name = "<stdin>"  # type: ignore
             black.format_stdin_to_stdout(
                 line_length=ll, fast=True, write_back=black.WriteBack.DIFF
             )
         tmp_file = Path(black.dump_to_file(source))
         hold_stdout = sys.stdout
         try:
-            sys.stdout = StringIO()
+            sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
             sys.stdout.seek(0)
             actual = sys.stdout.read()
             result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
             self.assertEqual(result.exit_code, 2)
 
+    def test_preserves_line_endings(self) -> None:
+        """Formatting a file in place keeps its LF or CRLF line endings."""
+        with TemporaryDirectory() as workspace:
+            test_file = Path(workspace) / "test.py"
+            for nl in ["\n", "\r\n"]:
+                with self.subTest(newline=repr(nl)):
+                    contents = nl.join(["def f(  ):", "    pass"])
+                    test_file.write_bytes(contents.encode())
+                    ff(test_file, write_back=black.WriteBack.YES)
+                    updated_contents: bytes = test_file.read_bytes()
+                    # Compare exact bytes: every line, including the trailing
+                    # newline black appends, must end with `nl`. A containment
+                    # check alone would accept mixed LF/CRLF output.
+                    expected = nl.join(["def f():", "    pass", ""]).encode()
+                    self.assertEqual(updated_contents, expected)
+
 
 if __name__ == "__main__":
     unittest.main()