diff options
Diffstat (limited to 'rpkid/rpki/resource_set.py')
-rw-r--r-- | rpkid/rpki/resource_set.py | 138 |
1 files changed, 108 insertions, 30 deletions
diff --git a/rpkid/rpki/resource_set.py b/rpkid/rpki/resource_set.py index be39df75..0bc31ef2 100644 --- a/rpkid/rpki/resource_set.py +++ b/rpkid/rpki/resource_set.py @@ -301,7 +301,7 @@ class resource_set(list): canonical = False - def __init__(self, ini = None): + def __init__(self, ini = None, allow_overlap = False): """ Initialize a resource_set. """ @@ -316,24 +316,30 @@ class resource_set(list): self.parse_rfc3779_tuple(ini) elif isinstance(ini, list): self.extend(ini) - else: - assert ini is None or (isinstance(ini, str) and ini == ""), "Unexpected initializer: %s" % str(ini) - self.canonize() + elif ini is not None and ini != "": + raise ValueError("Unexpected initializer: %s" % str(ini)) + self.canonize(allow_overlap) - def canonize(self): + def canonize(self, allow_overlap = False): """ Whack this resource_set into canonical form. """ assert not self.inherit or not self if not self.canonical: self.sort() - for i in xrange(len(self) - 2, -1, -1): - if self[i].max + 1 == self[i+1].min: + i = 0 + while i + 1 < len(self): + if allow_overlap and self[i].max + 1 >= self[i+1].min: + self[i] = type(self[i])(self[i].min, max(self[i].max, self[i+1].max)) + del self[i+1] + elif self[i].max + 1 == self[i+1].min: self[i] = type(self[i])(self[i].min, self[i+1].max) - self.pop(i + 1) - if __debug__: - for i in xrange(0, len(self) - 1): - assert self[i].max < self[i+1].min, "Resource overlap: %s %s" % (self[i], self[i+1]) + del self[i+1] + else: + i += 1 + for i in xrange(0, len(self) - 1): + if self[i].max >= self[i+1].min: + raise rpki.exceptions.ResourceOverlap("Resource overlap: %s %s" % (self[i], self[i+1])) self.canonical = True def append(self, item): @@ -425,18 +431,24 @@ class resource_set(list): del set2[0] return type(self)(result) + __or__ = union + def intersection(self, other): """ Set intersection for resource sets. """ return self._comm(other)[2] + __and__ = intersection + def difference(self, other): """ Set difference for resource sets. """ return self._comm(other)[0] + __sub__ = difference + def symmetric_difference(self, other): """ Set symmetric difference (XOR) for resource sets. @@ -444,6 +456,8 @@ class resource_set(list): com = self._comm(other) return com[0].union(com[1]) + __xor__ = symmetric_difference + def contains(self, item): """ Set membership test for resource sets. @@ -468,6 +482,8 @@ class resource_set(list): hi = mid return lo < len(self) and self[lo].min <= min and self[lo].max >= max + __contains__ = contains + def issubset(self, other): """ Test whether self is a subset (possibly improper) of other. @@ -477,12 +493,26 @@ class resource_set(list): return False return True + __le__ = issubset + def issuperset(self, other): """ Test whether self is a superset (possibly improper) of other. """ return other.issubset(self) + __ge__ = issuperset + + def __lt__(self, other): + return not self.issuperset(other) + + def __gt__(self, other): + return not self.issubset(other) + + __eq__ = list.__eq__ + + __ne__ = list.__ne__ + @classmethod def from_sql(cls, sql, query, args = None): """ @@ -730,6 +760,26 @@ class resource_bag(object): return self @classmethod + def from_str(cls, text, allow_overlap = False): + """ + Parse a comma-separated text string into a resource_bag. Not + particularly efficient, fix that if and when it becomes an issue. + """ + asns = [] + v4s = [] + v6s = [] + for word in text.split(","): + if "." in word: + v4s.append(word) + elif ":" in word: + v6s.append(word) + else: + asns.append(word) + return cls(asn = resource_set_as(",".join(asns), allow_overlap) if asns else None, + v4 = resource_set_ipv4(",".join(v4s), allow_overlap) if v4s else None, + v6 = resource_set_ipv6(",".join(v6s), allow_overlap) if v6s else None) + + @classmethod def from_rfc3779_tuples(cls, exts): """ Build a resource_bag from intermediate form generated by RFC 3779 @@ -773,21 +823,49 @@ class resource_bag(object): Compute intersection with another resource_bag. valid_until attribute (if any) inherits from self. """ - return self.__class__(self.asn.intersection(other.asn), - self.v4.intersection(other.v4), - self.v6.intersection(other.v6), + return self.__class__(self.asn & other.asn, + self.v4 & other.v4, + self.v6 & other.v6, self.valid_until) + __and__ = intersection + def union(self, other): """ Compute union with another resource_bag. valid_until attribute (if any) inherits from self. """ - return self.__class__(self.asn.union(other.asn), - self.v4.union(other.v4), - self.v6.union(other.v6), + return self.__class__(self.asn | other.asn, + self.v4 | other.v4, + self.v6 | other.v6, + self.valid_until) + + __or__ = union + + def difference(self, other): + """ + Compute difference against another resource_bag. valid_until + attribute (if any) inherits from self + """ + return self.__class__(self.asn - other.asn, + self.v4 - other.v4, + self.v6 - other.v6, + self.valid_until) + + __sub__ = difference + + def symmetric_difference(self, other): + """ + Compute symmetric difference against another resource_bag. + valid_until attribute (if any) inherits from self + """ + return self.__class__(self.asn ^ other.asn, + self.v4 ^ other.v4, + self.v6 ^ other.v6, self.valid_until) + __xor__ = symmetric_difference + def __str__(self): s = "" if self.asn: @@ -1095,25 +1173,25 @@ if __name__ == "__main__": v1 = r1._comm(r2) v2 = r2._comm(r1) assert v1[0] == v2[1] and v1[1] == v2[0] and v1[2] == v2[2] - for i in r1: assert r1.contains(i) and r1.contains(i.min) and r1.contains(i.max) - for i in r2: assert r2.contains(i) and r2.contains(i.min) and r2.contains(i.max) - for i in v1[0]: assert r1.contains(i) and not r2.contains(i) - for i in v1[1]: assert not r1.contains(i) and r2.contains(i) - for i in v1[2]: assert r1.contains(i) and r2.contains(i) - v1 = r1.union(r2) - v2 = r2.union(r1) + for i in r1: assert i in r1 and i.min in r1 and i.max in r1 + for i in r2: assert i in r2 and i.min in r2 and i.max in r2 + for i in v1[0]: assert i in r1 and i not in r2 + for i in v1[1]: assert i not in r1 and i in r2 + for i in v1[2]: assert i in r1 and i in r2 + v1 = r1 | r2 + v2 = r2 | r1 assert v1 == v2 print "x|y:", v1, testprefix(v1) - v1 = r1.difference(r2) - v2 = r2.difference(r1) + v1 = r1 - r2 + v2 = r2 - r1 print "x-y:", v1, testprefix(v1) print "y-x:", v2, testprefix(v2) - v1 = r1.symmetric_difference(r2) - v2 = r2.symmetric_difference(r1) + v1 = r1 ^ r2 + v2 = r2 ^ r1 assert v1 == v2 print "x^y:", v1, testprefix(v1) - v1 = r1.intersection(r2) - v2 = r2.intersection(r1) + v1 = r1 & r2 + v2 = r2 & r1 assert v1 == v2 print "x&y:", v1, testprefix(v1) |