diff options
Diffstat (limited to 'rpki/xml_utils.py')
-rw-r--r-- | rpki/xml_utils.py | 54 |
1 files changed, 46 insertions, 8 deletions
diff --git a/rpki/xml_utils.py b/rpki/xml_utils.py index c276ce98..9b443d0b 100644 --- a/rpki/xml_utils.py +++ b/rpki/xml_utils.py @@ -32,11 +32,15 @@ XML utilities. """ +import logging import xml.sax import lxml.sax import lxml.etree import rpki.exceptions +logger = logging.getLogger(__name__) + + class sax_handler(xml.sax.handler.ContentHandler): """ SAX handler for RPKI protocols. @@ -56,6 +60,7 @@ class sax_handler(xml.sax.handler.ContentHandler): """ Initialize SAX handler. """ + xml.sax.handler.ContentHandler.__init__(self) self.text = "" self.stack = [] @@ -64,18 +69,21 @@ class sax_handler(xml.sax.handler.ContentHandler): """ Redirect startElementNS() events to startElement(). """ + return self.startElement(name[1], attrs) def endElementNS(self, name, qname): """ Redirect endElementNS() events to endElement(). """ + return self.endElement(name[1]) def characters(self, content): """ Accumulate a chuck of element content (text). """ + self.text += content def startElement(self, name, attrs): @@ -111,6 +119,7 @@ class sax_handler(xml.sax.handler.ContentHandler): Handle endElement() events. Mostly this means handling any accumulated element text. """ + text = self.text.encode("ascii").strip() self.text = "" self.stack[-1].endElement(self.stack, name, text) @@ -120,6 +129,7 @@ class sax_handler(xml.sax.handler.ContentHandler): """ Create a one-off SAX parser, parse an ETree, return the result. """ + self = cls() lxml.sax.saxify(elt, self) return self.result @@ -128,6 +138,7 @@ class sax_handler(xml.sax.handler.ContentHandler): """ Handle top-level PDU for this protocol. """ + assert name == self.name and attrs["version"] == self.version return self.pdu() @@ -154,6 +165,7 @@ class base_elt(object): """ Default startElement() handler: just process attributes. """ + if name not in self.elements: assert name == self.element_name, "Unexpected name %s, stack %s" % (name, stack) self.read_attrs(attrs) @@ -162,6 +174,7 @@ class base_elt(object): """ Default endElement() handler: just pop the stack. """ + assert name == self.element_name, "Unexpected name %s, stack %s" % (name, stack) stack.pop() @@ -169,12 +182,14 @@ class base_elt(object): """ Default toXML() element generator. """ + return self.make_elt() def read_attrs(self, attrs): """ Template-driven attribute reader. """ + for key in self.attributes: val = attrs.get(key, None) if isinstance(val, str) and val.isdigit() and not key.endswith("_handle"): @@ -187,6 +202,7 @@ class base_elt(object): """ XML element constructor. """ + elt = lxml.etree.Element(self.xmlns + self.element_name, nsmap = self.nsmap) for key in self.attributes: val = getattr(self, key, None) @@ -201,6 +217,7 @@ class base_elt(object): """ Constructor for Base64-encoded subelement. """ + if value is not None and not value.empty(): lxml.etree.SubElement(elt, self.xmlns + name, nsmap = self.nsmap).text = value.get_Base64() @@ -208,6 +225,7 @@ class base_elt(object): """ Convert a base_elt object to string format. """ + return lxml.etree.tostring(self.toXML(), pretty_print = True, encoding = "us-ascii") @classmethod @@ -215,6 +233,7 @@ class base_elt(object): """ Generic PDU constructor. """ + self = cls() for k, v in kargs.items(): if isinstance(v, bool): @@ -235,6 +254,7 @@ class text_elt(base_elt): """ Extract text from parsed XML. """ + base_elt.endElement(self, stack, name, text) setattr(self, self.text_attribute, text) @@ -242,6 +262,7 @@ class text_elt(base_elt): """ Insert text into generated XML. """ + elt = self.make_elt() elt.text = getattr(self, self.text_attribute) or None return elt @@ -258,6 +279,7 @@ class data_elt(base_elt): that sub-elements are Base64-encoded using the sql_template mechanism. """ + if name in self.elements: elt_type = self.sql_template.map.get(name) assert elt_type is not None, "Couldn't find element type for %s, stack %s" % (name, stack) @@ -271,6 +293,7 @@ class data_elt(base_elt): Default element generator for SQL-based objects. This assumes that sub-elements are Base64-encoded DER objects. """ + elt = self.make_elt() for i in self.elements: self.make_b64elt(elt, i, getattr(self, i, None)) @@ -280,6 +303,7 @@ class data_elt(base_elt): """ Construct a reply PDU. """ + if r_pdu is None: r_pdu = self.__class__() self.make_reply_clone_hook(r_pdu) @@ -297,6 +321,7 @@ class data_elt(base_elt): """ Overridable hook. """ + pass def serve_fetch_one(self): @@ -304,6 +329,7 @@ class data_elt(base_elt): Find the object on which a get, set, or destroy method should operate. """ + r = self.serve_fetch_one_maybe() if r is None: raise rpki.exceptions.NotFound @@ -313,12 +339,14 @@ class data_elt(base_elt): """ Overridable hook. """ + cb() def serve_post_save_hook(self, q_pdu, r_pdu, cb, eb): """ Overridable hook. """ + cb() def serve_create(self, r_msg, cb, eb): @@ -371,6 +399,7 @@ class data_elt(base_elt): """ Handle a get action. """ + r_pdu = self.serve_fetch_one() self.make_reply(r_pdu) r_msg.append(r_pdu) @@ -380,6 +409,7 @@ class data_elt(base_elt): """ Handle a list action for non-self objects. """ + for r_pdu in self.serve_fetch_all(): self.make_reply(r_pdu) r_msg.append(r_pdu) @@ -389,12 +419,14 @@ class data_elt(base_elt): """ Overridable hook. """ + cb() def serve_destroy(self, r_msg, cb, eb): """ Handle a destroy action. """ + def done(): db_pdu.sql_delete() r_msg.append(self.make_reply()) @@ -406,19 +438,17 @@ class data_elt(base_elt): """ Action dispatch handler. """ - dispatch = { "create" : self.serve_create, - "set" : self.serve_set, - "get" : self.serve_get, - "list" : self.serve_list, - "destroy" : self.serve_destroy } - if self.action not in dispatch: + + method = getattr(self, "serve_" + self.action, None) + if method is None: raise rpki.exceptions.BadQuery("Unexpected query: action %s" % self.action) - dispatch[self.action](r_msg, cb, eb) + method(r_msg, cb, eb) def unimplemented_control(self, *controls): """ Uniform handling for unimplemented control operations. """ + unimplemented = [x for x in controls if getattr(self, x, False)] if unimplemented: raise rpki.exceptions.NotImplementedYet("Unimplemented control %s" % ", ".join(unimplemented)) @@ -432,6 +462,7 @@ class msg(list): """ Handle top-level PDU. """ + if name == "msg": assert self.version == int(attrs["version"]) self.type = attrs["type"] @@ -445,6 +476,7 @@ class msg(list): """ Handle top-level PDU. """ + assert name == "msg", "Unexpected name %s, stack %s" % (name, stack) assert len(stack) == 1 stack.pop() @@ -453,14 +485,16 @@ class msg(list): """ Convert msg object to string. """ + return lxml.etree.tostring(self.toXML(), pretty_print = True, encoding = "us-ascii") def toXML(self): """ Generate top-level PDU. """ + elt = lxml.etree.Element(self.xmlns + "msg", nsmap = self.nsmap, version = str(self.version), type = self.type) - elt.extend([i.toXML() for i in self]) + elt.extend(i.toXML() for i in self) return elt @classmethod @@ -468,6 +502,7 @@ class msg(list): """ Create a query PDU. """ + self = cls(args) self.type = "query" return self @@ -477,6 +512,7 @@ class msg(list): """ Create a reply PDU. """ + self = cls(args) self.type = "reply" return self @@ -485,10 +521,12 @@ class msg(list): """ Is this msg a query? """ + return self.type == "query" def is_reply(self): """ Is this msg a reply? """ + return self.type == "reply" |