# $Id$
#
# Copyright (C) 2015-2016  Parsons Government Services ("PARSONS")
# Portions copyright (C) 2013-2014  Dragon Research Labs ("DRL")
# Portions copyright (C) 2009-2012  Internet Systems Consortium ("ISC")
# Portions copyright (C) 2007-2008  American Registry for Internet Numbers ("ARIN")
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notices and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND PARSONS, DRL, ISC, AND ARIN
# DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS.  IN NO EVENT
# SHALL PARSONS, DRL, ISC, OR ARIN BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER
# RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF
# CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

"""
Trivial RPKI up-down protocol root server.
"""

import os
import time
import logging
import argparse

import rpki.resource_set
import rpki.up_down
import rpki.x509
import rpki.http_simple
import rpki.config
import rpki.exceptions
import rpki.relaxng
import rpki.sundial
import rpki.log
import rpki.daemonize
import rpki.publication

from lxml.etree import Element, SubElement

logger = logging.getLogger(__name__)


class ReplayTracker(object):
    """
    Holder object for the CMS timestamp used in replay protection.

    The single attribute, cms_timestamp, starts out as None and is
    expected to be updated by whoever is handed this tracker.
    """

    def __init__(self):
        # None until a timestamp has been recorded.
        self.cms_timestamp = None



class main(object):


    def root_newer_than_subject(self):
        return self.rpki_root_cert.mtime > os.stat(self.rpki_subject_cert_file).st_mtime


    def get_subject_cert(self):
        """
        Load the subject certificate from self.rpki_subject_cert_file.

        Returns the rpki.x509.X509 object, or None if loading raised
        IOError (for example because the file does not exist yet).
        """
        path = self.rpki_subject_cert_file
        try:
            cert = rpki.x509.X509(Auto_file = path)
            logger.debug("Read subject cert %s", path)
        except IOError:
            cert = None
        return cert


    def set_subject_cert(self, cert):
        """
        Store cert, DER-encoded, in self.rpki_subject_cert_file,
        overwriting any previous content.
        """
        path = self.rpki_subject_cert_file
        logger.debug("Writing subject cert %s, SKI %s", path, cert.hSKI())
        with open(path, "wb") as out:
            der = cert.get_DER()
            out.write(der)


    def del_subject_cert(self):
        """
        Delete the subject certificate file.

        Any OSError from the removal (e.g. file missing) propagates to
        the caller.
        """
        path = self.rpki_subject_cert_file
        logger.debug("Deleting subject cert %s", path)
        os.unlink(path)


    def get_subject_pkcs10(self):
        """
        Load the saved subject PKCS #10 request from self.rpki_subject_pkcs10.

        Returns the rpki.x509.PKCS10 object, or None if loading raised
        IOError (for example because no request has been saved yet).
        """
        path = self.rpki_subject_pkcs10
        try:
            request = rpki.x509.PKCS10(Auto_file = path)
            logger.debug("Read subject PKCS #10 %s", path)
        except IOError:
            request = None
        return request


    def set_subject_pkcs10(self, pkcs10):
        """
        Store pkcs10, DER-encoded, in self.rpki_subject_pkcs10,
        overwriting any previous content.
        """
        path = self.rpki_subject_pkcs10
        logger.debug("Writing subject PKCS #10 %s", path)
        with open(path, "wb") as out:
            der = pkcs10.get_DER()
            out.write(der)


    def del_subject_pkcs10(self):
        """
        Delete the saved subject PKCS #10 request file, if any.

        This is best-effort: an OSError from the removal (such as the
        file not existing) is deliberately ignored.
        """
        path = self.rpki_subject_pkcs10
        logger.debug("Deleting subject PKCS #10 %s", path)
        try:
            os.unlink(path)
        except OSError:
            # Nothing to clean up, or cleanup failed; either way carry on.
            pass


    def issue_subject_cert_maybe(self, new_pkcs10):
        """
        Return the subject certificate, (re)issuing it first if needed.

        new_pkcs10 is the PKCS #10 request supplied with the current
        query, or None when the query did not carry one (in which case
        the request previously saved on disk, if any, is used).

        The existing subject certificate is revoked and replaced when
        any of the following holds:

          - a new PKCS #10 was supplied and differs from the saved one;
          - the certificate's notAfter is within rpki_subject_regen of now;
          - the root certificate is newer than the subject certificate file.

        Returns a tuple (subject_cert, pubd_msg):

          - (existing cert, None) when the existing certificate is kept;
          - (None, None) when there is no certificate and no PKCS #10
            request to issue one from;
          - (new cert, publication query element) when a certificate was
            issued; the query also carries the fresh CRL and manifest and
            still has to be sent to pubd by the caller.
        """
        now = rpki.sundial.now()
        subject_cert = self.get_subject_cert()
        # Remember the hash of the certificate currently on disk; it is
        # attached below as the "hash" attribute of the publish PDU for
        # the replacement certificate.
        # NOTE(review): .encode("hex") is a Python 2-only codec.
        if subject_cert is None:
            subject_cert_hash = None
        else:
            subject_cert_hash = rpki.x509.sha256(subject_cert.get_DER()).encode("hex")
        old_pkcs10 = self.get_subject_pkcs10()
        # Trigger 1: the child sent a request that differs from the saved one.
        # The new request is saved even if there is no certificate to revoke.
        if new_pkcs10 is not None and new_pkcs10 != old_pkcs10:
            self.set_subject_pkcs10(new_pkcs10)
            if subject_cert is not None:
                logger.debug("PKCS #10 changed, regenerating subject certificate")
                self.revoke_subject_cert(now)
                subject_cert = None
        # Trigger 2: the certificate is inside its regeneration window.
        if subject_cert is not None and subject_cert.getNotAfter() <= now + self.rpki_subject_regen:
            logger.debug("Subject certificate has reached expiration threshold, regenerating")
            self.revoke_subject_cert(now)
            subject_cert = None
        # Trigger 3: the root certificate changed after the subject was written.
        if subject_cert is not None and self.root_newer_than_subject():
            logger.debug("Root certificate has changed, regenerating subject")
            self.revoke_subject_cert(now)
            subject_cert = None
        # No trigger fired: keep the existing certificate, nothing to publish.
        if subject_cert is not None:
            return subject_cert, None
        pkcs10 = old_pkcs10 if new_pkcs10 is None else new_pkcs10
        if pkcs10 is None:
            logger.debug("No PKCS #10 request, can't generate subject certificate yet")
            return None, None
        # The subject certificate gets the root certificate's full RFC 3779
        # resource set.
        resources = self.rpki_root_cert.get_3779resources()
        notAfter = now + self.rpki_subject_lifetime
        logger.info("Generating subject cert %s with resources %s, expires %s",
                    self.rpki_subject_cert_uri, resources, notAfter)
        req_key = pkcs10.getPublicKey()
        req_sia = pkcs10.get_SIA()
        # Advance self.serial_number; the new value is used for this certificate.
        self.next_serial_number()
        subject_cert = self.rpki_root_cert.issue(
            keypair     = self.rpki_root_key,
            subject_key = req_key,
            serial      = self.serial_number,
            sia         = req_sia,
            aia         = self.rpki_root_cert_uri,
            crldp       = self.rpki_root_crl_uri,
            resources   = resources,
            notBefore   = now,
            notAfter    = notAfter)
        # Write the certificate to disk before building the CRL and manifest:
        # generate_crl_and_manifest() re-reads it from the file.
        self.set_subject_cert(subject_cert)
        pubd_msg = Element(rpki.publication.tag_msg, nsmap = rpki.publication.nsmap,
                           type = "query", version = rpki.publication.version)
        pdu = SubElement(pubd_msg, rpki.publication.tag_publish, uri = self.rpki_subject_cert_uri)
        pdu.text = subject_cert.get_Base64()
        if subject_cert_hash is not None:
            pdu.set("hash", subject_cert_hash)
        # Appends the CRL, manifest and (if changed) root cert PDUs to pubd_msg.
        self.generate_crl_and_manifest(now, pubd_msg)
        return subject_cert, pubd_msg


    def generate_crl_and_manifest(self, now, pubd_msg):
        """
        Generate a fresh CRL and manifest, write both to their local
        files, and append the corresponding publish PDUs to pubd_msg.

        now is the timestamp used for thisUpdate/notBefore; pubd_msg is
        the publication query element to extend in place.  A publish PDU
        for the root certificate itself is appended as well whenever its
        hash differs from the one recorded in self.rpki_root_cert_hash.

        Side effects: advances self.serial_number and self.crl_number,
        prunes old entries from self.revoked, rewrites the CRL and
        manifest files, and may update self.rpki_root_cert_hash.
        Returns None; sending pubd_msg is left to the caller.
        """
        # Re-read the subject certificate from disk; this is None when the
        # file is absent (e.g. after del_subject_cert()), in which case the
        # manifest lists only the CRL.
        subject_cert = self.get_subject_cert()
        # New serial for the manifest's EE certificate, new number shared by
        # the CRL and the manifest.
        self.next_serial_number()
        self.next_crl_number()
        # Drop revocation entries older than twice the regeneration interval.
        # Entries are only ever appended, so the oldest is at the front.
        while self.revoked and self.revoked[0][1] + 2 * self.rpki_subject_regen < now:
            del self.revoked[0]
        crl = rpki.x509.CRL.generate(
            keypair             = self.rpki_root_key,
            issuer              = self.rpki_root_cert,
            serial              = self.crl_number,
            thisUpdate          = now,
            nextUpdate          = now + self.rpki_subject_regen,
            revokedCertificates = self.revoked)
        # Hash of the CRL file being replaced (None if there is none yet);
        # must be taken before the file is overwritten below.
        crl_hash = self.read_hash_maybe(self.rpki_root_crl_file)
        logger.debug("Writing CRL %s", self.rpki_root_crl_file)
        with open(self.rpki_root_crl_file, "wb") as f:
            f.write(crl.get_DER())
        pdu = SubElement(pubd_msg, rpki.publication.tag_publish, uri = self.rpki_root_crl_uri)
        pdu.text = crl.get_Base64()
        if crl_hash is not None:
            pdu.set("hash", crl_hash)
        # Manifest entries are (file name, object) pairs keyed by the last
        # path component of each object's URI.
        manifest_content = [(os.path.basename(self.rpki_root_crl_uri), crl)]
        if subject_cert is not None:
            manifest_content.append((os.path.basename(self.rpki_subject_cert_uri), subject_cert))
        # The manifest is signed with a freshly generated key whose EE
        # certificate is issued by the root with an "inheritance" resource bag.
        manifest_resources = rpki.resource_set.resource_bag.from_inheritance()
        manifest_keypair = rpki.x509.RSA.generate()
        manifest_cert = self.rpki_root_cert.issue(
            keypair     = self.rpki_root_key,
            subject_key = manifest_keypair.get_public(),
            serial      = self.serial_number,
            sia         = (None, None, self.rpki_root_manifest_uri, self.rrdp_notification_uri),
            aia         = self.rpki_root_cert_uri,
            crldp       = self.rpki_root_crl_uri,
            resources   = manifest_resources,
            notBefore   = now,
            notAfter    = now + self.rpki_subject_lifetime,
            is_ca       = False)
        # The manifest reuses the CRL number as its own serial.
        manifest = rpki.x509.SignedManifest.build(
            serial         = self.crl_number,
            thisUpdate     = now,
            nextUpdate     = now + self.rpki_subject_regen,
            names_and_objs = manifest_content,
            keypair        = manifest_keypair,
            certs          = manifest_cert)
        # As with the CRL: hash the old file first, then overwrite it.
        mft_hash = self.read_hash_maybe(self.rpki_root_manifest_file)
        logger.debug("Writing manifest %s", self.rpki_root_manifest_file)
        with open(self.rpki_root_manifest_file, "wb") as f:
            f.write(manifest.get_DER())
        pdu = SubElement(pubd_msg, rpki.publication.tag_publish, uri = self.rpki_root_manifest_uri)
        pdu.text = manifest.get_Base64()
        if mft_hash is not None:
            pdu.set("hash", mft_hash)
        # Publish the root certificate only when its hash differs from the
        # last one recorded.  rpki_root_cert_hash starts as None, so the
        # first call always publishes it (without a "hash" attribute).
        cer_hash = rpki.x509.sha256(self.rpki_root_cert.get_DER()).encode("hex")
        if cer_hash != self.rpki_root_cert_hash:
            pdu = SubElement(pubd_msg, rpki.publication.tag_publish, uri = self.rpki_root_cert_uri)
            pdu.text = self.rpki_root_cert.get_Base64()
            if self.rpki_root_cert_hash is not None:
                pdu.set("hash", self.rpki_root_cert_hash)
            self.rpki_root_cert_hash = cer_hash


    @staticmethod
    def read_hash_maybe(fn):
        try:
            with open(fn, "rb") as f:
                return rpki.x509.sha256(f.read()).encode("hex")
        except IOError:
            return None


    def revoke_subject_cert(self, now):
        self.revoked.append((self.get_subject_cert().getSerial(), now))


    def publish(self, q_msg):
        if q_msg is None:
            return
        assert len(q_msg) > 0

        if not all(q_pdu.get("hash") is not None for q_pdu in q_msg):
            logger.debug("Some publication PDUs are missing hashes, checking published data...")
            q = Element(rpki.publication.tag_msg, nsmap = rpki.publication.nsmap,
                        type = "query", version = rpki.publication.version)
            SubElement(q, rpki.publication.tag_list)
            published_hash = dict((r.get("uri"), r.get("hash")) for r in self.call_pubd(q))
            for q_pdu in q_msg:
                q_uri = q_pdu.get("uri")
                if q_pdu.get("hash") is None and published_hash.get(q_uri) is not None:
                    logger.debug("Updating hash of %s to %s from previously published data", q_uri, published_hash[q_uri])
                    q_pdu.set("hash", published_hash[q_uri])

        r_msg = self.call_pubd(q_msg)
        if len(q_msg) != len(r_msg):
            raise rpki.exceptions.BadPublicationReply("Wrong number of response PDUs from pubd: sent %s, got %s" % (len(q_msg), len(r_msg)))


    def call_pubd(self, q_msg):
        """
        Send q_msg to pubd over the simple HTTP client and return the
        reply message.

        Each PDU's URI is logged before sending.  The reply is passed
        through rpki.publication.raise_if_error() before being returned.
        """
        for pdu in q_msg:
            logger.info("Sending %s to pubd", pdu.get("uri"))

        client_args = dict(
            proto_cms_msg = rpki.publication.cms_msg,
            client_key    = self.rootd_bpki_key,
            client_cert   = self.rootd_bpki_cert,
            client_crl    = self.rootd_bpki_crl,
            server_ta     = self.bpki_ta,
            server_cert   = self.pubd_bpki_cert,
            url           = self.pubd_url,
            q_msg         = q_msg,
            replay_track  = self.pubd_replay_tracker)
        reply = rpki.http_simple.client(**client_args)

        rpki.publication.raise_if_error(reply)
        return reply


    def compose_response(self, r_msg, pkcs10 = None):
        """
        Fill in up-down response r_msg with a single class element
        describing the root's resources, and publish anything that had
        to be regenerated along the way.

        pkcs10 is the child's certificate request, if the query carried
        one.  The class element contains the subject certificate (when
        one exists) and always the root (issuer) certificate.
        """
        subject_cert, pubd_msg = self.issue_subject_cert_maybe(pkcs10)
        bag = self.rpki_root_cert.get_3779resources()

        root_url = str(rpki.up_down.multi_uri(self.rpki_root_cert_uri))
        asns     = str(bag.asn)
        ipv4     = str(bag.v4)
        ipv6     = str(bag.v6)
        until    = str(bag.valid_until)
        rc = SubElement(r_msg, rpki.up_down.tag_class,
                        class_name            = self.rpki_class_name,
                        cert_url              = root_url,
                        resource_set_as       = asns,
                        resource_set_ipv4     = ipv4,
                        resource_set_ipv6     = ipv6,
                        resource_set_notafter = until)

        if subject_cert is not None:
            subject_url = str(rpki.up_down.multi_uri(self.rpki_subject_cert_uri))
            c = SubElement(rc, rpki.up_down.tag_certificate, cert_url = subject_url)
            c.text = subject_cert.get_Base64()

        # Encode the issuer first, then attach the element, so a failure
        # while encoding leaves no empty issuer element behind.
        issuer_b64 = self.rpki_root_cert.get_Base64()
        issuer = SubElement(rc, rpki.up_down.tag_issuer)
        issuer.text = issuer_b64

        self.publish(pubd_msg)


    def handle_list(self, q_msg, r_msg):
        self.compose_response(r_msg)


    def handle_issue(self, q_msg, r_msg):
        """
        Handle an up-down "issue" query: parse the Base64 PKCS #10
        request from the text of the first PDU of q_msg and compose the
        response around it.
        """
        # This is where we'd check q_msg[0].get("class_name") if this weren't rootd.
        request_pdu = q_msg[0]
        pkcs10 = rpki.x509.PKCS10(Base64 = request_pdu.text)
        self.compose_response(r_msg, pkcs10)


    def handle_revoke(self, q_msg, r_msg):
        """
        Handle an up-down "revoke" query.

        The first PDU of q_msg names the class and the SKI to revoke.
        If the current subject certificate matches that SKI it is added
        to the revocation list, the certificate and saved PKCS #10 files
        are deleted, the request PDU is echoed into r_msg, and a new CRL
        and manifest are generated and published.

        Raises rpki.exceptions.NotInDatabase if there is no subject
        certificate or its SKI does not match.
        """
        q_pdu      = q_msg[0]
        class_name = q_pdu.get("class_name")
        ski        = q_pdu.get("ski")
        logger.debug("Revocation requested for class %s SKI %s", class_name, ski)

        subject_cert = self.get_subject_cert()
        if subject_cert is None:
            logger.debug("No subject certificate, nothing to revoke")
            raise rpki.exceptions.NotInDatabase

        cert_ski = subject_cert.gSKI()
        if cert_ski != ski:
            logger.debug("Subject certificate has different SKI %s, not revoking", cert_ski)
            raise rpki.exceptions.NotInDatabase

        logger.debug("Revoking certificate %s", ski)
        now = rpki.sundial.now()
        pubd_msg = Element(rpki.publication.tag_msg, nsmap = rpki.publication.nsmap,
                           type = "query", version = rpki.publication.version)

        # Revoke while the certificate file still exists, then clean up.
        self.revoke_subject_cert(now)
        self.del_subject_cert()
        self.del_subject_pkcs10()

        # The response echoes the revoked class name and SKI.
        SubElement(r_msg, q_pdu.tag, class_name = class_name, ski = ski)

        self.generate_crl_and_manifest(now, pubd_msg)
        self.publish(pubd_msg)


    # Need to do something about mapping exceptions to up-down error
    # codes, right now everything shows up as "internal error".
    #
    #exceptions = {
    #  rpki.exceptions.ClassNameUnknown                    : 1201,
    #  rpki.exceptions.NoActiveCA                          : 1202,
    #  (rpki.exceptions.ClassNameUnknown, revoke_pdu)      : 1301,
    #  (rpki.exceptions.NotInDatabase,    revoke_pdu)      : 1302 }
    #
    # Might be that what we want here is a subclass of
    # rpki.exceptions.RPKI_Exception which carries an extra data field
    # for the up-down error code, so that we can add the correct code
    # when we instantiate it.
    #
    # A few of these errors are also schema violations, which means
    # we'd have to catch them before validating or pick them out of a
    # message that failed validation or otherwise break current
    # modularity.  Maybe an optional pre-validation check method hook in
    # rpki.x509.XML_CMS_object which we can use to intercept such things?


    def handler(self, request, q_der):
        """
        HTTP handler for a single up-down protocol request.

        request is the HTTP request object used to send the reply;
        q_der is the DER-encoded CMS query.

        The query is unwrapped against the BPKI trust anchor and child
        certificate, checked for replay, and dispatched to the method
        named "handle_" + the message type.  An exception raised during
        replay checking or dispatch is logged and turned into an up-down
        error response.  Any exception outside that (unwrapping the
        query, building or sending the reply) results in an HTTP 500.
        """
        try:
            q_cms = rpki.up_down.cms_msg(DER = q_der)
            q_msg = q_cms.unwrap((self.bpki_ta, self.child_bpki_cert))
            q_type = q_msg.get("type")
            logger.info("Serving %s query", q_type)
            # The reply swaps sender and recipient of the query.
            r_msg = Element(rpki.up_down.tag_message, nsmap = rpki.up_down.nsmap,
                            version = rpki.up_down.version,
                            sender  = q_msg.get("recipient"), recipient = q_msg.get("sender"),
                            type = q_type + "_response")
            try:
                self.rpkid_cms_timestamp = q_cms.check_replay(self.rpkid_cms_timestamp, request.path)
                # An unknown message type raises AttributeError here and is
                # reported to the client as an up-down error response.
                getattr(self, "handle_" + q_type)(q_msg, r_msg)
            except Exception as e:
                logger.exception("Exception processing up-down %s message", q_type)
                rpki.up_down.generate_error_response_from_exception(r_msg, e, q_type)
            request.send_cms_response(rpki.up_down.cms_msg().wrap(
                r_msg, self.rootd_bpki_key, self.rootd_bpki_cert,
                self.rootd_bpki_crl if self.include_bpki_crl else None))
        except Exception as e:
            logger.exception("Unhandled exception processing up-down message")
            request.send_error(500, "Unhandled exception %s: %s" % (e.__class__.__name__, e))


    def next_crl_number(self):
        if self.crl_number is None:
            try:
                crl = rpki.x509.CRL(DER_file = self.rpki_root_crl_file)
                self.crl_number = crl.getCRLNumber()
            except:
                self.crl_number = 0
        self.crl_number += 1
        return self.crl_number


    def next_serial_number(self):
        if self.serial_number is None:
            subject_cert = self.get_subject_cert()
            if subject_cert is not None:
                self.serial_number = subject_cert.getSerial() + 1
            else:
                self.serial_number = 0
        self.serial_number += 1
        return self.serial_number


    def __init__(self):
        """
        Program entry point: parse arguments and configuration, optionally
        daemonize, load keys/certificates and settings from the "rootd"
        configuration section, then start the HTTP server with
        self.handler registered for "/".
        """
        # Counters are seeded lazily from on-disk state on first use; see
        # next_serial_number() and next_crl_number().
        self.serial_number = None
        self.crl_number = None
        self.revoked = []
        self.rpkid_cms_timestamp = None
        self.pubd_replay_tracker = ReplayTracker()

        # Run with the process time zone forced to UTC.
        os.environ["TZ"] = "UTC"
        time.tzset()

        self.cfg = rpki.config.argparser(section = "rootd", doc = __doc__)
        self.cfg.add_boolean_argument("--foreground", default = False,
                                      help = "do not daemonize")
        # Fixed: was "os.pat.join", which raised AttributeError at startup.
        self.cfg.add_argument("--pidfile",
                              default = os.path.join(rpki.daemonize.default_pid_directory,
                                                     "rootd.pid"),
                              help = "override default location of pid file")
        self.cfg.add_logging_arguments()
        # Fixed: was "parser.parse_args()", but no name "parser" exists here.
        # NOTE(review): assumes rpki.config.argparser exposes its underlying
        # argparse parser as .argparser -- confirm against rpki.config.
        args = self.cfg.argparser.parse_args()

        self.cfg.configure_logging(args = args, ident = "rootd")

        self.cfg.set_global_flags()

        if not args.foreground:
            rpki.daemonize.daemon(pidfile = args.pidfile)

        # BPKI material used to authenticate the CMS exchanges.
        self.bpki_ta                 = rpki.x509.X509(Auto_update = self.cfg.get("bpki-ta"))
        self.rootd_bpki_key          = rpki.x509.RSA( Auto_update = self.cfg.get("rootd-bpki-key"))
        self.rootd_bpki_cert         = rpki.x509.X509(Auto_update = self.cfg.get("rootd-bpki-cert"))
        self.rootd_bpki_crl          = rpki.x509.CRL( Auto_update = self.cfg.get("rootd-bpki-crl"))
        self.child_bpki_cert         = rpki.x509.X509(Auto_update = self.cfg.get("child-bpki-cert"))

        # The pubd BPKI certificate is optional.
        if self.cfg.has_option("pubd-bpki-cert"):
            self.pubd_bpki_cert        = rpki.x509.X509(Auto_update = self.cfg.get("pubd-bpki-cert"))
        else:
            self.pubd_bpki_cert        = None

        self.http_server_host        = self.cfg.get("server-host", "")
        self.http_server_port        = self.cfg.getint("server-port")

        self.rpki_class_name         = self.cfg.get("rpki-class-name")

        self.rpki_root_key           = rpki.x509.RSA( Auto_update = self.cfg.get("rpki-root-key-file"))
        self.rpki_root_cert          = rpki.x509.X509(Auto_update = self.cfg.get("rpki-root-cert-file"))
        self.rpki_root_cert_uri      = self.cfg.get("rpki-root-cert-uri")
        # None until the root certificate has been published once; see
        # generate_crl_and_manifest().
        self.rpki_root_cert_hash     = None

        self.rpki_root_manifest_file = self.cfg.get("rpki-root-manifest-file")
        self.rpki_root_manifest_uri  = self.cfg.get("rpki-root-manifest-uri")

        self.rpki_root_crl_file      = self.cfg.get("rpki-root-crl-file")
        self.rpki_root_crl_uri       = self.cfg.get("rpki-root-crl-uri")

        self.rpki_subject_cert_file  = self.cfg.get("rpki-subject-cert-file")
        self.rpki_subject_cert_uri   = self.cfg.get("rpki-subject-cert-uri")
        self.rpki_subject_pkcs10     = self.cfg.get("rpki-subject-pkcs10-file")
        self.rpki_subject_lifetime   = rpki.sundial.timedelta.parse(self.cfg.get("rpki-subject-lifetime", "8w"))
        # The regeneration interval defaults to half the certificate lifetime.
        self.rpki_subject_regen      = rpki.sundial.timedelta.parse(self.cfg.get("rpki-subject-regen",
                                                                                 self.rpki_subject_lifetime.convert_to_seconds() / 2))

        self.include_bpki_crl        = self.cfg.getboolean("include-bpki-crl", False)

        self.pubd_url                = self.cfg.get("pubd-contact-uri")

        self.rrdp_notification_uri   = self.cfg.get("rrdp-notification-uri")

        rpki.http_simple.server(host     = self.http_server_host,
                                port     = self.http_server_port,
                                handlers = (("/", self.handler, rpki.up_down.allowed_content_types),))