diff options
author | Rob Austein <sra@hactrn.net> | 2017-05-21 22:13:00 -0400 |
---|---|---|
committer | Rob Austein <sra@hactrn.net> | 2017-05-21 22:13:00 -0400 |
commit | 54dc2f126d4921985211b1732d34feaaa5dcb1f8 (patch) | |
tree | 760ba1e97191804f0b3c63efeaf076224df479ea /zc |
First public version.
Diffstat (limited to 'zc')
-rwxr-xr-x | zc | 576 |
1 files changed, 576 insertions, 0 deletions
@@ -0,0 +1,576 @@ +#!/usr/bin/env python + +""" +Generate zone files from a simpl(er) flat text file. + +General intent here is to let users specify normal hosts in a simple +and compact format, with a few utilities we provide to automate +complex or repetitive stuff, including automatic generation of AAAA +RRs based on a mapping scheme from A RRs. + +After generating the text of the forward zone, we run it through +dnspython's zone parser, then generate reverse zones by translating +the A and AAAA RRs in the forward zone into the corresponding PTR RRs. +""" + +from dns.rdatatype import A, AAAA, SOA, NS, PTR +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter, FileType +from socket import inet_ntop, inet_pton, AF_INET, AF_INET6 +from collections import OrderedDict + +import dns.reversename +import dns.rdataclass +import dns.rdatatype +import dns.rdata +import dns.name +import dns.zone + +import logging.handlers +import subprocess +import logging +import atexit +import signal +import select +import fcntl +import stat +import time +import sys +import os + + +logger = logging.getLogger("zc") + +log_levels = OrderedDict((logging.getLevelName(i).lower(), i) + for i in (logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR)) + + +class Address(long): + """ + Addresses are integers with some extra code to handle conversion + to and from text strings. + """ + + def __new__(cls, x): + if cls is Address and issubclass(x.__class__, Address): + cls = x.__class__ + if isinstance(x, (str, unicode)): + if cls is Address: + cls = V6 if ":" in x else V4 + x = int(inet_pton(cls.af, str(x)).encode("hex"), 16) + return long.__new__(cls, x) + + @property + def _bytestring(self): + if self < 0: + raise ValueError("value out of range") + return "{0:0{1}x}".format(self, self.bits / 4).decode("hex") + + def __str__(self): + return inet_ntop(self.af, self._bytestring) + + @property + def bytes(self): + return tuple(ord(b) for b in self._bytestring) + + @property + def mask(self): + return (1 << self.bits) - 1 + + @classmethod + def is_instance(cls, obj): + return isinstance(obj, cls) + +class V4(Address): + af = AF_INET + bits = 32 + rrtype = "A" + +class V6(Address): + af = AF_INET6 + bits = 128 + rrtype = "AAAA" + + +class Prefix(object): + """ + Prefixes are an address and a length. + """ + + def __init__(self, x, y = None): + if isinstance(x, (str, unicode)) and y is None: + x, y = x.split("/") + self.net = Address(x) + self.len = int(y) + if self.len < 0 or self.len > self.net.bits: + raise ValueError("Prefix length {0.len!s} is out of range for prefix {0.net!s}".format(self)) + + def __cmp__(self, other): + return cmp(self.net, other.net) or cmp(self.len, other.len) + + def __hash__(self): + return hash(self.net) ^ hash(self.len) + + def __int__(self): + return self.net + + def __long__(self): + return self.net + + def __str__(self): + return "{0.net!s}/{0.len!s}".format(self) + + @property + def subnet_mask(self): + return (self.net.mask >> self.len) ^ self.net.mask + + @property + def host_mask(self): + return ~self.subnet_mask & self.net.mask + + def matches(self, addr): + return self.net.__class__ is addr.__class__ and (self.net ^ addr) & self.subnet_mask == 0 + + +class ZoneGen(object): + """ + Parse input file, line-by-line. Lines can be: + + * Host-address pairs (generate A or AAAA RRs) + * DNS RRs (unchanged) + * Comments, blank lines (unchanged) + * Control operations: + + $ORIGIN <dns-name> + + $TTL <ttl-value> + + $MAP_RULE <prefix> <format> + + $MAP <boolean> + + $RANGE <start-addr> <stop-addr> [<offset> [<multiplier> [<mapaddr>]]] + + $REVERSE_ZONE <zone-name> [<zone-name> ...] + + At present $INCLUDE and $GENERATE are not supported: we don't really need the former, + and $RANGE is (intended as) a replacement for the latter. + """ + + def __init__(self, input, filename, now, reverse): + self.input = input + self.filename = filename + self.now = now + self.lines = [] + self.origin = None + self.cur_origin = None + self.map = OrderedDict() + self.reverse = [] + logger.info("Compiling zone %s", filename) + try: + for self.lineno, self.line in enumerate(input, 1): + self.line = self.line.rstrip() + part = self.line.partition(";") + token = part[0].split() + if token and token[0].startswith("$"): + handler = getattr(self, "handle_" + token[0][1:], None) + if handler is None: + raise ValueError("Unrecognized control operation") + handler(*token[1:]) + elif len(token) != 2: + if len(token) >= 9 and "SOA" in token: + self.line = self.line.replace("@SERIAL@", str(now)) + token[token.index("@SERIAL@")] = str(now) + if len(token) > 0: + self.check_dns(token) + self.lines.append(self.line) + else: + comment = " ;" + part[2] if part[2] else "" + name, addr = token[0], Address(token[1]) + self.rr(name, addr, comment) + if self.map_enable: + self.map_rr(name, addr, comment) + except Exception as e: + logger.error("{self.filename}:{self.lineno}: {e!s}: {self.line}\n".format(self = self, e = e)) + sys.exit(1) + fn = self.origin.to_text(omit_final_dot = True) + logger.debug("Generated zone file %s:", fn) + for i, line in enumerate(self.lines, 1): + logger.debug("[%5d] %s", i, line) + logger.debug("End of generated zone file %s", fn) + self.text = "\n".join(self.lines) + "\n" + self.zone = dns.zone.from_text(self.text, relativize = False, filename = fn) + self.build_reverse(reverse) + + def check_dns(self, token): + try: + dns.name.from_text(token.pop(0)) + if token[0].isdigit(): + del token[0] + if token[0].upper() == "IN": + del token[0] + rdtype = dns.rdatatype.from_text(token.pop(0)) + dns.rdata.from_text(dns.rdataclass.IN, rdtype, " ".join(token), self.cur_origin) + except: + raise ValueError("Syntax error") + + def rr(self, name, addr, comment = ""): + self.lines.append("{name:<23s} {addr.rrtype:<7s} {addr!s}{comment}".format( + name = name, addr = addr, comment = comment)) + + def map_rr(self, name, addr, comment = ""): + for prefix, format in self.map.iteritems(): + if prefix.matches(addr): + self.rr(name, Address(format.format(addr.bytes)), comment) + break + + def to_file(self, f, relativize = None): + f.write(self.text) # "relativize" ignored, present only for dnspython API compatability + + def handle_ORIGIN(self, origin): + self.cur_origin = dns.name.from_text(origin) + if self.origin is None: + self.origin = self.cur_origin + self.lines.append("$ORIGIN {}".format(self.cur_origin.to_text())) + + def handle_TTL(self, ttl): + self.lines.append(self.line) + + def handle_MAP_RULE(self, prefix, format): + self.map[Prefix(prefix)] = format + + _bool_names = dict(yes = True, no = False, on = True, off = False, true = True, false = False) + + def get_mapping_state(self, token): + try: + return self._bool_names[token.lower()] + except: + raise ValueError("Unrecognized mapping state") + + def handle_MAP(self, cmd): + self.map_enable = self.get_mapping_state(cmd) + + def handle_INCLUDE(self, name): + raise NotImplementedError("Not implemented") + + def handle_GENERATE(self, name, *args): + raise NotImplementedError("Not implemented (try $RANGE)") + + def handle_RANGE(self, fmt, start, stop, offset = None, multiplier = None, mapaddr = None): + start = Address(start) + stop = Address(stop) + offset = start.bytes[-1] if offset is None else int(offset, 0) + multiplier = 1 if multiplier is None else int(multiplier, 0) + method = self.rr if mapaddr is None or not self.get_mapping_state(mapaddr) else self.map_rr + for i in xrange(stop - start + 1): + method(fmt.format(offset + i), start.__class__(start + i * multiplier)) + + def handle_REVERSE_ZONE(self, *names): + self.reverse.extend(dns.name.from_text(name) for name in names) + + def build_reverse(self, reverse): + + zones = [] + + for name in self.reverse: + if name not in reverse: + reverse[name] = dns.zone.Zone(name, relativize = False) + reverse[name].find_rdataset(rdtype = SOA, name = name, create = True).update( + self.zone.find_rdataset(rdtype = SOA, name = self.zone.origin)) + reverse[name].find_rdataset(rdtype = NS, name = name, create = True).update( + self.zone.find_rdataset(rdtype = NS, name = self.zone.origin)) + reverse[name].check_origin() + zones.append(reverse[name]) + + if not zones: + return + + for qtype in (A, AAAA): + for name, ttl, addr in self.zone.iterate_rdatas(qtype): + rname = dns.reversename.from_address(addr.to_text()) + rdata = name.to_wire() + rdata = dns.rdata.from_wire(self.zone.rdclass, PTR, rdata, 0, len(rdata)) + for z in zones: + if rname.is_subdomain(z.origin): + z.find_rdataset(rname, PTR, create = True).add(rdata, ttl) + break + else: + logger.warn("%29s (%-16s %s) does not match any given reverse zone", rname, addr, name) + + +class ZoneHerd(object): + """ + Collection of zones to be generated and written. This is a class + rather than a function to simplify doing all the real work up + front while deferring final installation until we've gone through + a confirmation dance when running as git {pre,post}-receive hooks + """ + + def __init__(self, inputs, outdir, tempword = "RENMWO"): + self.names = OrderedDict() + atexit.register(self.cleanup) + + now = int(time.time()) + reverse = OrderedDict() + forward = [ZoneGen(lines, name, now, reverse) for lines, name in inputs] + + header = ";; Generated by zc at {time}, do not edit by hand\n\n".format( + time = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(now))) + + os.chdir(outdir) + + pid = os.getpid() + + for z in reverse.values() + forward: + fn = z.origin.to_text(omit_final_dot = True) + tfn = ".~{}~{}~{}".format(pid, tempword, fn) + self.names[tfn] = fn + with open(tfn, "w") as f: + f.write(header) + z.to_file(f, relativize = False) + logger.info("Wrote %s", fn) + + def finish(self): + while self.names: + tfn, fn = self.names.popitem() + os.rename(tfn, fn) + logger.info("Installed %s", fn) + + def cleanup(self): + for tfn in self.names: + try: + os.unlink(tfn) + logger.debug("Unlinked %s", tfn) + except: + pass + + +class GitView(object): + """ + Wrapper around git code common to both hooks. + """ + + all_zeros = "0" * 40 + + def __init__(self): + import git, json + self.repo = git.Repo() + self.gcfg = self.repo.config_reader() + self.configure_logging() + self.outdir = self.gcfg.get_value("zc", "output-directory") + self.timeout = self.gcfg.get_value("zc", "hook-timeout", 15) + self.postcmd = self.gcfg.get_value("zc", "post-command", "").split() + self.commit = None + for line in sys.stdin: + oldsha, newsha, refname = line.split() + if refname == "refs/heads/master" and newsha != self.all_zeros: + self.commit = newsha + break + if self.commit is not None: + tree = self.repo.commit(self.commit).tree + self.jcfg = json.loads(tree["config.json"].data_stream.read()) + log_level = self.jcfg.get("log-level", "warning").strip() + self.stderr_logger.setLevel(log_levels[log_level]) + self.zone_blobs = [tree[name] for name in self.jcfg["zones"]] + self.log_user_hook_commit() + + def configure_logging(self): + self.stderr_logger = logging.StreamHandler() + self.stderr_logger.setLevel(logging.WARNING) + self.stderr_logger.setFormatter(logging.Formatter("%(name)s %(levelname)s %(message)s")) + logging.getLogger().addHandler(self.stderr_logger) + logging.getLogger().setLevel(logging.DEBUG) + log_level = self.gcfg.get_value("zc", "log-level", "warning") + log_file = self.gcfg.get_value("zc", "log-file", "/var/log/zc/zc.log") + log_hours = self.gcfg.get_value("zc", "log-file-hours", 24) + log_count = self.gcfg.get_value("zc", "log-file-count", 7) + if log_file: + self.file_logger = logging.handlers.TimedRotatingFileHandler( + filename = log_file, + interval = log_hours, + backupCount = log_count, + when = "H", + utc = True) + self.file_logger.setFormatter(logging.Formatter( + "%(asctime)-15s %(name)s [%(process)s] %(levelname)s %(message)s")) + self.file_logger.setLevel(log_levels[log_level]) + logging.getLogger().addHandler(self.file_logger) + else: + self.file_logger = None + + def log_user_hook_commit(self): + logger.debug("Original SSH command: %s", os.getenv("SSH_ORIGINAL_COMMAND")) + logger.debug("authorized_keys command: %s", os.getenv("GIT_REMOTE_ONLY_COMMAND")) + user = os.getenv("GIT_REMOTE_ONLY_COMMAND", "").split() + user = user[2] if len(user) > 2 else "unknown" + logger.info("User %s running %s processing commit %s", user, sys.argv[0], self.commit) + + @property + def fifo_name(self): + return os.path.join(self.outdir, ".zc.fifo") + + +def daemonize(): + """ + Detach from parent process, in this case git, so that can report + success to git when running as a pre-receive hook while sticking + around to handle final installation of our generated zone files. + + Not sure how much of the following ritual is necessary, but some + of it definitely is (git push hangs if we just fork() and _exit()). + Sacrifice the rubber chicken and move on. + """ + + sys.stdout.flush() + sys.stderr.flush() + old_action = signal.signal(signal.SIGHUP, signal.SIG_IGN) + if os.fork() > 0: + os._exit(0) + os.setsid() + fd = os.open(os.devnull, os.O_RDWR) + os.dup2(fd, 0) + os.dup2(fd, 1) + os.dup2(fd, 2) + if fd > 2: + os.close(fd) + signal.signal(signal.SIGHUP, old_action) + + +def cli_main(): + + """ + Entry point for command line use. + """ + + parser = ArgumentParser(formatter_class = type("HF", (ArgumentDefaultsHelpFormatter, + RawDescriptionHelpFormatter), {}), + description = __doc__) + + parser.add_argument("-o", "--output-directory", + default = ".", + help = "directory for output files") + + parser.add_argument("-l", "--log-level", + choices = tuple(log_levels), + default = "warning", + help = "how loudly to bark") + + parser.add_argument("input", + nargs = "+", + type = FileType("r"), + help = "input file") + + args = parser.parse_args() + + logging.basicConfig(format = "%(message)s", level = log_levels[args.log_level]) + + herd = ZoneHerd(((input, input.name) for input in args.input), args.output_directory) + herd.finish() + + +def pre_receive_main(): + """ + Entry point for git pre-receive hook. + + Do all the zone generation and write the files to disk under + temporary names, but defer final installation until we get + confirmation from the post-receive hook that git is done accepting + the push. Since git won't do this until after the pre-receive + hook exits, this hook has to daemonize itself after doing all the + real work, so that git can get on with the rest. + + This may be excessively paranoid, but git makes few promises about + what will happen if more than one push is active at the same time. + In theory, the lock on our FIFO is enough to force serialization, + but that can fail if, eg, somebody deletes the FIFO itself. So + our wakeup signal is receiving the commit hash through the FIFO + from the post-receive hook. + + If we don't get the right wakeup signal before a (configurable) + timeout expires, we clean up our output files and exit. + """ + + try: + gv = GitView() + if gv.commit is None: + logger.info("No commits on master branch, nothing to do") + sys.exit() + + if not os.path.exists(gv.fifo_name): + os.mkfifo(gv.fifo_name) + + fifo = os.open(gv.fifo_name, os.O_RDONLY | os.O_NONBLOCK) + + fcntl.flock(fifo, fcntl.LOCK_EX) + + if not stat.S_ISFIFO(os.fstat(fifo).st_mode): + raise RuntimeError("{} is not a FIFO!".format(gv.fifo_name)) + + herd = ZoneHerd(((blob.data_stream.read().splitlines(), blob.name) for blob in gv.zone_blobs), + gv.outdir, + gv.commit) + + logging.getLogger().removeHandler(gv.stderr_logger) + + daemonize() + + logger.info("Awaiting confirmation of commit %s before installing files", gv.commit) + + remaining = gv.timeout + confirmation = "" + + while remaining > 0: + t = time.time() + if not select.select([fifo], [], [], remaining)[0]: + break # Timeout + chunk = os.read(fifo, 1024) + if chunk == "": + break # EOF + confirmation += chunk + if gv.commit in confirmation.splitlines(): + logger.info("Commit %s confirmed", gv.commit) + herd.finish() # Success + if gv.postcmd: + logger.info("Running post-command %r", gv.postcmd) + proc = subprocess.Popen(gv.postcmd, stdout = subprocess.PIPE, stderr = subprocess.STDOUT) + for line in proc.stdout.read().splitlines(): + logger.info(">> %s", line) + proc.stdout.close() + proc.wait() + break + remaining -= time.time() - t + + except Exception as e: + logger.error("%s", e) + +def post_receive_main(): + """ + Entry point for git post-receive hook. + + Zone files have already been generated and written, daemonized + pre-receive hook process is just waiting for us to confirm that + git has finished accepting push of this commit, which we do by + sending our commit hash to the pre-receive daemon. + """ + + try: + gv = GitView() + if gv.commit is not None: + with open(gv.fifo_name, "w") as f: + f.write(gv.commit + "\n") + except Exception as e: + logger.error("%s", e) + + +def main(): + """ + Entry point, just dispatch based on how we were invoked. + """ + + jane = os.path.basename(sys.argv[0]) + + if jane == "pre-receive": + pre_receive_main() + + elif jane == "post-receive": + post_receive_main() + + else: + cli_main() + + +if __name__ == "__main__": + main() |