X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/2d80366ac1304c6eff84604d1308ecae5daeef32..4bcae4cf839aeba828bcdc4764369ae790f81d0d:/blib2to3/pgen2/grammar.py?ds=sidebyside

diff --git a/blib2to3/pgen2/grammar.py b/blib2to3/pgen2/grammar.py
index c00cb22..d6f0fc2 100644
--- a/blib2to3/pgen2/grammar.py
+++ b/blib2to3/pgen2/grammar.py
@@ -13,8 +13,9 @@ fallback token code OP, but the parser needs the actual token code.
 """
 
 # Python imports
-import collections
+import os
 import pickle
+import tempfile
 
 # Local imports
 from . import token
@@ -84,11 +85,16 @@ class Grammar(object):
         self.tokens = {}
         self.symbol2label = {}
         self.start = 256
+        # Python 3.7+ parses async as a keyword, not an identifier
+        self.async_keywords = False
 
     def dump(self, filename):
         """Dump the grammar tables to a pickle file."""
-        with open(filename, "wb") as f:
+        with tempfile.NamedTemporaryFile(
+            dir=os.path.dirname(filename), delete=False
+        ) as f:
             pickle.dump(self.__dict__, f, pickle.HIGHEST_PROTOCOL)
+        os.replace(f.name, filename)
 
     def load(self, filename):
         """Load the grammar tables from a pickle file."""
@@ -105,17 +111,25 @@ class Grammar(object):
         Copy the grammar.
         """
         new = self.__class__()
-        for dict_attr in ("symbol2number", "number2symbol", "dfas", "keywords",
-                          "tokens", "symbol2label"):
+        for dict_attr in (
+            "symbol2number",
+            "number2symbol",
+            "dfas",
+            "keywords",
+            "tokens",
+            "symbol2label",
+        ):
             setattr(new, dict_attr, getattr(self, dict_attr).copy())
         new.labels = self.labels[:]
         new.states = self.states[:]
         new.start = self.start
+        new.async_keywords = self.async_keywords
         return new
 
     def report(self):
         """Dump the grammar tables to standard output, for debugging."""
         from pprint import pprint
+
         print("s2n")
         pprint(self.symbol2number)
         print("n2s")
@@ -179,6 +193,7 @@ opmap_raw = """
 // DOUBLESLASH
 //= DOUBLESLASHEQUAL
 -> RARROW
+:= COLONEQUAL
 """
 
 opmap = {}