Browse Source

Add $INCLUDE support

This doesn't support the origin stacking defined in RFC 1034 because:

1. Doing so would require us to maintain a real $INCLUDE stack instead
   of just chaining iterators; and

2. The expected use case is including automatically-generated snippets
   in zones that are being maintained with zc, so there's no real need
   for origin fiddling anyway because whatever automation is
   generating the snippets can just generate FQDNs if necessary.

If really needed, we could fix this, but, YAGNI.
Rob Austein 1 year ago
parent
commit
df853d56f4
2 changed files with 60 additions and 28 deletions
  1. 18 4
      README.md
  2. 42 24
      zc

+ 18 - 4
README.md

@@ -423,12 +423,26 @@ Examples:
     $RANGE dhcp-f{:03x} 10.1.0.50 10.2.255.254 50
     $RANGE dhcp-f{:03x} 10.1.0.50 10.2.255.254 50
 
 
 
 
-#### `$INCLUDE` and `$GENERATE` ####
+#### `$INCLUDE` ####
 
 
-The `$INCLUDE` and `$GENERATE` control operators are not currently implemented.
+`$INCLUDE` is a standard control operator, but for the main expected
+`zc` use cases there's not much need for it.
 
 
-`$INCLUDE` is a standard control operator, but we appear to have no
-current need for it.
+`zc` supports a limited form of the `$INCLUDE` operator, intended
+mainly for automation (that is, for cases where one wants to include a
+machine-generated set of DNS data into a larger zone that you're
+maintaining with `zc`).  Limitations:
+
+1. `zc` doesn't support the optional `origin` field of the `$INCLUDE`
+   operator as defined in RFC 1035.
+
+2. `zc` does *not* preserve the current `$ORIGIN` value of the outer
+   file while processing `$INCLUDE`, so if the included file changes
+   the `$ORIGIN`, the outer file will see that change.  Don't do that.
+
+#### `$GENERATE` ####
+
+The `$GENERATE` control operators is not currently implemented.
 
 
 `$GENERATE` is a BIND-specific control operator.  We could implement
 `$GENERATE` is a BIND-specific control operator.  We could implement
 it if there were a real need, but the `$RANGE` operator covers the
 it if there were a real need, but the `$RANGE` operator covers the

+ 42 - 24
zc

@@ -32,6 +32,7 @@ from argparse           import ArgumentParser, ArgumentDefaultsHelpFormatter, \
                                RawDescriptionHelpFormatter, FileType
                                RawDescriptionHelpFormatter, FileType
 from socket             import inet_ntop, inet_pton, AF_INET, AF_INET6
 from socket             import inet_ntop, inet_pton, AF_INET, AF_INET6
 from collections        import OrderedDict
 from collections        import OrderedDict
+from itertools          import chain
 
 
 import dns.reversename
 import dns.reversename
 import dns.rdataclass
 import dns.rdataclass
@@ -152,24 +153,28 @@ class ZoneGen(object):
       + $MAP <boolean>
       + $MAP <boolean>
       + $RANGE <start-addr> <stop-addr> [<offset> [<multiplier> [<mapaddr>]]]
       + $RANGE <start-addr> <stop-addr> [<offset> [<multiplier> [<mapaddr>]]]
       + $REVERSE_ZONE <zone-name> [<zone-name> ...]
       + $REVERSE_ZONE <zone-name> [<zone-name> ...]
+      + $INCLUDE <file-name>
 
 
-    At present $INCLUDE and $GENERATE are not supported: we don't really need the former,
-    and $RANGE is (intended as) a replacement for the latter.
+    At present $GENERATE is not supported: $RANGE is (intended as) a replacement.
     """
     """
 
 
-    def __init__(self, input, filename, now, reverse):
+    def __init__(self, input, now, reverse, opener):
         self.input      = input
         self.input      = input
-        self.filename   = filename
         self.now        = now
         self.now        = now
+        self.opener     = opener
         self.lines      = []
         self.lines      = []
         self.origin     = None
         self.origin     = None
         self.cur_origin = None
         self.cur_origin = None
         self.map        = OrderedDict()
         self.map        = OrderedDict()
         self.map_enable = False
         self.map_enable = False
         self.reverse    = []
         self.reverse    = []
-        logger.info("Compiling zone %s", filename)
+        last_filename   = None
         try:
         try:
-            for self.lineno, self.line in enumerate(input, 1):
+            while True:
+                self.lineno, self.line, self.filename = next(self.input)
+                if self.filename != last_filename:
+                    logger.info("Compiling %s", self.filename)
+                    last_filename = self.filename
                 self.line = self.line.rstrip()
                 self.line = self.line.rstrip()
                 part = self.line.partition(";")
                 part = self.line.partition(";")
                 token = part[0].split()
                 token = part[0].split()
@@ -191,6 +196,8 @@ class ZoneGen(object):
                     self.rr(name, addr, comment)
                     self.rr(name, addr, comment)
                     if self.map_enable:
                     if self.map_enable:
                         self.map_rr(name, addr, comment)
                         self.map_rr(name, addr, comment)
+        except StopIteration:
+            pass
         except Exception as e:
         except Exception as e:
             logger.error("{self.filename}:{self.lineno}: {e!s}: {self.line}\n".format(self = self, e = e))
             logger.error("{self.filename}:{self.lineno}: {e!s}: {self.line}\n".format(self = self, e = e))
             sys.exit(1)
             sys.exit(1)
@@ -251,8 +258,8 @@ class ZoneGen(object):
     def handle_MAP(self, cmd):
     def handle_MAP(self, cmd):
         self.map_enable = self.get_mapping_state(cmd)
         self.map_enable = self.get_mapping_state(cmd)
 
 
-    def handle_INCLUDE(self, name):
-        raise NotImplementedError("Not implemented")
+    def handle_INCLUDE(self, filename):
+        self.input = chain(self.opener(filename), self.input)
 
 
     def handle_GENERATE(self, name, *args):
     def handle_GENERATE(self, name, *args):
         raise NotImplementedError("Not implemented (try $RANGE)")
         raise NotImplementedError("Not implemented (try $RANGE)")
@@ -296,7 +303,7 @@ class ZoneGen(object):
                         z.find_rdataset(rname, PTR, create = True).add(rdata, ttl)
                         z.find_rdataset(rname, PTR, create = True).add(rdata, ttl)
                         break
                         break
                 else:
                 else:
-                    logger.warn("%29s (%-16s %s) does not match any given reverse zone", rname, addr, name)
+                    logger.warning("%29s (%-16s %s) does not match any given reverse zone", rname, addr, name)
 
 
 
 
 class ZoneHerd(object):
 class ZoneHerd(object):
@@ -307,13 +314,13 @@ class ZoneHerd(object):
     a confirmation dance when running as git {pre,post}-receive hooks
     a confirmation dance when running as git {pre,post}-receive hooks
     """
     """
 
 
-    def __init__(self, inputs, outdir, tempword = "RENMWO"):
+    def __init__(self, inputs, outdir, opener, tempword = "RENMWO"):
         self.names = OrderedDict()
         self.names = OrderedDict()
         atexit.register(self.cleanup)
         atexit.register(self.cleanup)
 
 
         now = int(time.time())
         now = int(time.time())
         reverse = OrderedDict()
         reverse = OrderedDict()
-        forward = [ZoneGen(lines, name, now, reverse) for lines, name in inputs]
+        forward = [ZoneGen(input, now, reverse, opener) for input in inputs]
 
 
         header = ";; Generated by zc at {time}, do not edit by hand\n\n".format(
         header = ";; Generated by zc at {time}, do not edit by hand\n\n".format(
             time = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(now)))
             time = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(now)))
@@ -368,13 +375,17 @@ class GitView(object):
                 self.commit = newsha
                 self.commit = newsha
                 break
                 break
         if self.commit is not None:
         if self.commit is not None:
-            tree = self.repo.commit(self.commit).tree
-            self.jcfg = json.loads(tree["config.json"].data_stream.read())
+            self.tree = self.repo.commit(self.commit).tree
+            self.jcfg = json.load(self.tree["config.json"].data_stream)
             log_level = self.jcfg.get("log-level", "warning").strip()
             log_level = self.jcfg.get("log-level", "warning").strip()
             self.stderr_logger.setLevel(log_levels[log_level])
             self.stderr_logger.setLevel(log_levels[log_level])
-            self.zone_blobs = [tree[name] for name in self.jcfg["zones"]]
+            self.zone_inputs = [self.opener(name) for name in self.jcfg["zones"]]
             self.log_user_hook_commit()
             self.log_user_hook_commit()
 
 
+    def opener(self, name):
+        for lineno, line in enumerate(self.tree[name].data_stream.read().decode().splitlines(), 1):
+            yield lineno, line, name
+
     def configure_logging(self):
     def configure_logging(self):
         self.stderr_logger = logging.StreamHandler()
         self.stderr_logger = logging.StreamHandler()
         self.stderr_logger.setLevel(logging.WARNING)
         self.stderr_logger.setLevel(logging.WARNING)
@@ -465,7 +476,14 @@ def cli_main():
 
 
     logging.basicConfig(format = "%(message)s", level = log_levels[args.log_level])
     logging.basicConfig(format = "%(message)s", level = log_levels[args.log_level])
 
 
-    herd = ZoneHerd(((input, input.name) for input in args.input), args.output_directory)
+    def opener(f):
+        if isinstance(f, str):
+            f = open(f, "r")
+        with f:
+            for lineno, line in enumerate(f, 1):
+                yield lineno, line, f.name
+
+    herd = ZoneHerd((opener(input) for input in args.input), args.output_directory, opener)
     herd.finish()
     herd.finish()
 
 
 
 
@@ -507,9 +525,7 @@ def pre_receive_main():
         if not stat.S_ISFIFO(os.fstat(fifo).st_mode):
         if not stat.S_ISFIFO(os.fstat(fifo).st_mode):
             raise RuntimeError("{} is not a FIFO!".format(gv.fifo_name))
             raise RuntimeError("{} is not a FIFO!".format(gv.fifo_name))
 
 
-        herd = ZoneHerd(((blob.data_stream.read().splitlines(), blob.name) for blob in gv.zone_blobs),
-                        gv.outdir,
-                        gv.commit)
+        herd = ZoneHerd(gv.zone_inputs, gv.outdir, gv.opener, gv.commit)
 
 
         logging.getLogger().removeHandler(gv.stderr_logger)
         logging.getLogger().removeHandler(gv.stderr_logger)
 
 
@@ -524,7 +540,7 @@ def pre_receive_main():
             t = time.time()
             t = time.time()
             if not select.select([fifo], [], [], remaining)[0]:
             if not select.select([fifo], [], [], remaining)[0]:
                 break               # Timeout
                 break               # Timeout
-            chunk = os.read(fifo, 1024)
+            chunk = os.read(fifo, 1024).decode()
             if chunk == "":
             if chunk == "":
                 break               # EOF
                 break               # EOF
             confirmation += chunk
             confirmation += chunk
@@ -533,11 +549,12 @@ def pre_receive_main():
                 herd.finish()       # Success
                 herd.finish()       # Success
                 if gv.postcmd:
                 if gv.postcmd:
                     logger.info("Running post-command %r", gv.postcmd)
                     logger.info("Running post-command %r", gv.postcmd)
-                    proc = subprocess.Popen(gv.postcmd, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
-                    for line in proc.stdout.read().splitlines():
-                        logger.info(">> %s", line)
-                    proc.stdout.close()
-                    proc.wait()
+                    with subprocess.Popen(gv.postcmd,
+                                          stdout = subprocess.PIPE,
+                                          stderr = subprocess.STDOUT,
+                                          text = True, errors = "backslashreplace") as proc:
+                        for line in proc.stdout:
+                            logger.info(">> %s", line.rstrip())
                 break
                 break
             remaining -= time.time() - t
             remaining -= time.time() - t
 
 
@@ -559,6 +576,7 @@ def post_receive_main():
         gv = GitView()
         gv = GitView()
         if gv.commit is not None:
         if gv.commit is not None:
             with open(gv.fifo_name, "w") as f:
             with open(gv.fifo_name, "w") as f:
+                logger.debug("Commit: %s", gv.commit)
                 f.write(gv.commit + "\n")
                 f.write(gv.commit + "\n")
     except Exception as e:
     except Exception as e:
         logger.error("%s", e)
         logger.error("%s", e)