浏览代码

Add $INCLUDE support

This doesn't support the origin stacking defined in RFC 1034 because:

1. Doing so would require us to maintain a real $INCLUDE stack instead
   of just chaining iterators; and

2. The expected use case is including automatically-generated snippets
   in zones that are being maintained with zc, so there's no real need
   for origin fiddling anyway because whatever automation is
   generating the snippets can just generate FQDNs if necessary.

If really needed, we could fix this, but, YAGNI.
Rob Austein 1 年之前
父节点
当前提交
df853d56f4
共有 2 个文件被更改,包括 60 次插入28 次删除
  1. 18 4
      README.md
  2. 42 24
      zc

+ 18 - 4
README.md

@@ -423,12 +423,26 @@ Examples:
     $RANGE dhcp-f{:03x} 10.1.0.50 10.2.255.254 50
 
 
-#### `$INCLUDE` and `$GENERATE` ####
+#### `$INCLUDE` ####
 
-The `$INCLUDE` and `$GENERATE` control operators are not currently implemented.
+`$INCLUDE` is a standard control operator, but for the main expected
+`zc` use cases there's not much need for it.
 
-`$INCLUDE` is a standard control operator, but we appear to have no
-current need for it.
+`zc` supports a limited form of the `$INCLUDE` operator, intended
+mainly for automation (that is, for cases where one wants to include a
+machine-generated set of DNS data into a larger zone that you're
+maintaining with `zc`).  Limitations:
+
+1. `zc` doesn't support the optional `origin` field of the `$INCLUDE`
+   operator as defined in RFC 1035.
+
+2. `zc` does *not* preserve the current `$ORIGIN` value of the outer
+   file while processing `$INCLUDE`, so if the included file changes
+   the `$ORIGIN`, the outer file will see that change.  Don't do that.
+
+#### `$GENERATE` ####
+
+The `$GENERATE` control operators is not currently implemented.
 
 `$GENERATE` is a BIND-specific control operator.  We could implement
 it if there were a real need, but the `$RANGE` operator covers the

+ 42 - 24
zc

@@ -32,6 +32,7 @@ from argparse           import ArgumentParser, ArgumentDefaultsHelpFormatter, \
                                RawDescriptionHelpFormatter, FileType
 from socket             import inet_ntop, inet_pton, AF_INET, AF_INET6
 from collections        import OrderedDict
+from itertools          import chain
 
 import dns.reversename
 import dns.rdataclass
@@ -152,24 +153,28 @@ class ZoneGen(object):
       + $MAP <boolean>
       + $RANGE <start-addr> <stop-addr> [<offset> [<multiplier> [<mapaddr>]]]
       + $REVERSE_ZONE <zone-name> [<zone-name> ...]
+      + $INCLUDE <file-name>
 
-    At present $INCLUDE and $GENERATE are not supported: we don't really need the former,
-    and $RANGE is (intended as) a replacement for the latter.
+    At present $GENERATE is not supported: $RANGE is (intended as) a replacement.
     """
 
-    def __init__(self, input, filename, now, reverse):
+    def __init__(self, input, now, reverse, opener):
         self.input      = input
-        self.filename   = filename
         self.now        = now
+        self.opener     = opener
         self.lines      = []
         self.origin     = None
         self.cur_origin = None
         self.map        = OrderedDict()
         self.map_enable = False
         self.reverse    = []
-        logger.info("Compiling zone %s", filename)
+        last_filename   = None
         try:
-            for self.lineno, self.line in enumerate(input, 1):
+            while True:
+                self.lineno, self.line, self.filename = next(self.input)
+                if self.filename != last_filename:
+                    logger.info("Compiling %s", self.filename)
+                    last_filename = self.filename
                 self.line = self.line.rstrip()
                 part = self.line.partition(";")
                 token = part[0].split()
@@ -191,6 +196,8 @@ class ZoneGen(object):
                     self.rr(name, addr, comment)
                     if self.map_enable:
                         self.map_rr(name, addr, comment)
+        except StopIteration:
+            pass
         except Exception as e:
             logger.error("{self.filename}:{self.lineno}: {e!s}: {self.line}\n".format(self = self, e = e))
             sys.exit(1)
@@ -251,8 +258,8 @@ class ZoneGen(object):
     def handle_MAP(self, cmd):
         self.map_enable = self.get_mapping_state(cmd)
 
-    def handle_INCLUDE(self, name):
-        raise NotImplementedError("Not implemented")
+    def handle_INCLUDE(self, filename):
+        self.input = chain(self.opener(filename), self.input)
 
     def handle_GENERATE(self, name, *args):
         raise NotImplementedError("Not implemented (try $RANGE)")
@@ -296,7 +303,7 @@ class ZoneGen(object):
                         z.find_rdataset(rname, PTR, create = True).add(rdata, ttl)
                         break
                 else:
-                    logger.warn("%29s (%-16s %s) does not match any given reverse zone", rname, addr, name)
+                    logger.warning("%29s (%-16s %s) does not match any given reverse zone", rname, addr, name)
 
 
 class ZoneHerd(object):
@@ -307,13 +314,13 @@ class ZoneHerd(object):
     a confirmation dance when running as git {pre,post}-receive hooks
     """
 
-    def __init__(self, inputs, outdir, tempword = "RENMWO"):
+    def __init__(self, inputs, outdir, opener, tempword = "RENMWO"):
         self.names = OrderedDict()
         atexit.register(self.cleanup)
 
         now = int(time.time())
         reverse = OrderedDict()
-        forward = [ZoneGen(lines, name, now, reverse) for lines, name in inputs]
+        forward = [ZoneGen(input, now, reverse, opener) for input in inputs]
 
         header = ";; Generated by zc at {time}, do not edit by hand\n\n".format(
             time = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(now)))
@@ -368,13 +375,17 @@ class GitView(object):
                 self.commit = newsha
                 break
         if self.commit is not None:
-            tree = self.repo.commit(self.commit).tree
-            self.jcfg = json.loads(tree["config.json"].data_stream.read())
+            self.tree = self.repo.commit(self.commit).tree
+            self.jcfg = json.load(self.tree["config.json"].data_stream)
             log_level = self.jcfg.get("log-level", "warning").strip()
             self.stderr_logger.setLevel(log_levels[log_level])
-            self.zone_blobs = [tree[name] for name in self.jcfg["zones"]]
+            self.zone_inputs = [self.opener(name) for name in self.jcfg["zones"]]
             self.log_user_hook_commit()
 
+    def opener(self, name):
+        for lineno, line in enumerate(self.tree[name].data_stream.read().decode().splitlines(), 1):
+            yield lineno, line, name
+
     def configure_logging(self):
         self.stderr_logger = logging.StreamHandler()
         self.stderr_logger.setLevel(logging.WARNING)
@@ -465,7 +476,14 @@ def cli_main():
 
     logging.basicConfig(format = "%(message)s", level = log_levels[args.log_level])
 
-    herd = ZoneHerd(((input, input.name) for input in args.input), args.output_directory)
+    def opener(f):
+        if isinstance(f, str):
+            f = open(f, "r")
+        with f:
+            for lineno, line in enumerate(f, 1):
+                yield lineno, line, f.name
+
+    herd = ZoneHerd((opener(input) for input in args.input), args.output_directory, opener)
     herd.finish()
 
 
@@ -507,9 +525,7 @@ def pre_receive_main():
         if not stat.S_ISFIFO(os.fstat(fifo).st_mode):
             raise RuntimeError("{} is not a FIFO!".format(gv.fifo_name))
 
-        herd = ZoneHerd(((blob.data_stream.read().splitlines(), blob.name) for blob in gv.zone_blobs),
-                        gv.outdir,
-                        gv.commit)
+        herd = ZoneHerd(gv.zone_inputs, gv.outdir, gv.opener, gv.commit)
 
         logging.getLogger().removeHandler(gv.stderr_logger)
 
@@ -524,7 +540,7 @@ def pre_receive_main():
             t = time.time()
             if not select.select([fifo], [], [], remaining)[0]:
                 break               # Timeout
-            chunk = os.read(fifo, 1024)
+            chunk = os.read(fifo, 1024).decode()
             if chunk == "":
                 break               # EOF
             confirmation += chunk
@@ -533,11 +549,12 @@ def pre_receive_main():
                 herd.finish()       # Success
                 if gv.postcmd:
                     logger.info("Running post-command %r", gv.postcmd)
-                    proc = subprocess.Popen(gv.postcmd, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
-                    for line in proc.stdout.read().splitlines():
-                        logger.info(">> %s", line)
-                    proc.stdout.close()
-                    proc.wait()
+                    with subprocess.Popen(gv.postcmd,
+                                          stdout = subprocess.PIPE,
+                                          stderr = subprocess.STDOUT,
+                                          text = True, errors = "backslashreplace") as proc:
+                        for line in proc.stdout:
+                            logger.info(">> %s", line.rstrip())
                 break
             remaining -= time.time() - t
 
@@ -559,6 +576,7 @@ def post_receive_main():
         gv = GitView()
         if gv.commit is not None:
             with open(gv.fifo_name, "w") as f:
+                logger.debug("Commit: %s", gv.commit)
                 f.write(gv.commit + "\n")
     except Exception as e:
         logger.error("%s", e)