X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/63da5d088cd8d38e925d9d45b194705fb5258ecc..d10b56e6f797878b7c76d69777f25907beb4cddd:/blib2to3/pgen2/grammar.py?ds=sidebyside

diff --git a/blib2to3/pgen2/grammar.py b/blib2to3/pgen2/grammar.py
index a1da546..d6f0fc2 100644
--- a/blib2to3/pgen2/grammar.py
+++ b/blib2to3/pgen2/grammar.py
@@ -13,7 +13,9 @@ fallback token code OP, but the parser needs the actual token code.
 """
 
 # Python imports
+import os
 import pickle
+import tempfile
 
 # Local imports
 from . import token
@@ -83,11 +85,32 @@ class Grammar(object):
         self.tokens = {}
         self.symbol2label = {}
         self.start = 256
+        # Python 3.7+ parses async as a keyword, not an identifier
+        self.async_keywords = False
 
     def dump(self, filename):
-        """Dump the grammar tables to a pickle file."""
-        with open(filename, "wb") as f:
-            pickle.dump(self.__dict__, f, pickle.HIGHEST_PROTOCOL)
+        """Dump the grammar tables to a pickle file.
+
+        The pickle is first written to a temporary file in the same
+        directory as ``filename`` and then moved into place with
+        ``os.replace``, so ``filename`` is never left half-written.
+        If pickling or the rename fails, the temporary file is removed
+        before the exception propagates.
+        """
+        tmp = tempfile.NamedTemporaryFile(
+            dir=os.path.dirname(filename), delete=False
+        )
+        try:
+            with tmp as f:
+                pickle.dump(self.__dict__, f, pickle.HIGHEST_PROTOCOL)
+            os.replace(tmp.name, filename)
+        except BaseException:
+            # delete=False means nobody else cleans this up; don't leak it.
+            try:
+                os.unlink(tmp.name)
+            except OSError:
+                pass
+            raise
 
     def load(self, filename):
         """Load the grammar tables from a pickle file."""
@@ -104,17 +127,25 @@ class Grammar(object):
         Copy the grammar.
         """
         new = self.__class__()
-        for dict_attr in ("symbol2number", "number2symbol", "dfas", "keywords",
-                          "tokens", "symbol2label"):
+        for dict_attr in (
+            "symbol2number",
+            "number2symbol",
+            "dfas",
+            "keywords",
+            "tokens",
+            "symbol2label",
+        ):
             setattr(new, dict_attr, getattr(self, dict_attr).copy())
         new.labels = self.labels[:]
         new.states = self.states[:]
         new.start = self.start
+        new.async_keywords = self.async_keywords
         return new
 
     def report(self):
         """Dump the grammar tables to standard output, for debugging."""
         from pprint import pprint
+
         print("s2n")
         pprint(self.symbol2number)
         print("n2s")
@@ -178,6 +209,7 @@ opmap_raw = """
 // DOUBLESLASH
 //= DOUBLESLASHEQUAL
 -> RARROW
+:= COLONEQUAL
 """
 
 opmap = {}