#!/usr/local/bin/python

# $Id$

# Copyright (C) 2013  Dragon Research Labs ("DRL")
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND DRL DISCLAIMS ALL WARRANTIES WITH
# REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
# AND FITNESS.  IN NO EVENT SHALL DRL BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
# LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
# OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
# PERFORMANCE OF THIS SOFTWARE.

# Preliminary script to work out what's involved in building an
# SQLite3 database of RP objects.  We haven't bothered with this until
# now in rcynic, because we mostly just walk the filesystem tree, but
# LTA and some of the ideas Tim is playing with require a lot of
# lookups based on things that are not the URIs we use as filenames,
# so some kind of indexing may become necessary.  Given the complexity
# of building any kind of real index over RFC 3779 resources,
# otherwise fine lightweight tools like the Python shelve library
# probably won't cut it here, and I don't want to add a dependency on
# MySQL on the RP side (yet?), so let's see what we can do with SQLite3.

import os
import sys
import sqlite3
import rpki.POW
import rpki.x509
import rpki.resource_set

# Map from rcynic filename extension (the text after the final ".",
# see os.path.splitext in initialize_database) to the rpki.x509 class
# used to parse objects of that type.  The extension itself is what
# gets stored in the "fn2" column of the object table.
fn2map = dict(cer = rpki.x509.X509,
              crl = rpki.x509.CRL,
              mft = rpki.x509.SignedManifest,
              roa = rpki.x509.ROA,
              gbr = rpki.x509.Ghostbuster)

# Store IP addresses as BLOBs of "_" followed by the raw address bytes.
# The leading "_" is never a digit, which is how the RangeVal converter
# below tells addresses apart from ASNs (stored as plain integers).
# SQLite sorts every INTEGER before every BLOB, so ASN ranges and
# address ranges never compare as overlapping.
# NOTE(review): BLOBs are compared with memcmp(), so a 4-byte IPv4 value
# and a 16-byte IPv6 value CAN compare as overlapping; range queries that
# must not mix address families need an extra guard (e.g. on length()).
sqlite3.register_adapter(rpki.POW.IPAddress,
                         lambda x: buffer("_" + x.toBytes()))

# Inverse of the adapter above, applied to columns declared "RangeVal"
# (the connection is opened with PARSE_DECLTYPES): values that read back
# as all-digit strings are ASNs, anything else is "_" + address bytes.
sqlite3.register_converter("RangeVal",
                           lambda s: long(s) if s.isdigit() else rpki.POW.IPAddress.fromBytes(s[1:]))


def main():
  """
  Build the database from the rcynic tree, run the smoke tests against
  it, and close the connection afterwards even if a test blows up.
  """
  db, cur = initialize_database()
  try:
    test(cur)
  finally:
    db.close()


def test(cur):
  print "Testing range functions"
  for fn2 in [None] + fn2map.keys():
    if fn2 is not None:
      print
      print "Restricting search to type", fn2
    print
    print "Looking for range that should include adrilankha and psg again"
    for r in find_by_range(cur, "147.28.0.19", "147.28.0.62", fn2):
      print r, r.uris
    print
    print "Looking for range that should include adrilankha"
    for r in find_by_range(cur, "147.28.0.19", "147.28.0.19", fn2):
      print r, r.uris
    print
    print "Looking for range that should include ASN 3130"
    for r in find_by_range(cur, 3130, 3130, fn2):
      print r, r.uris
  print
  print "Moving on to resource sets"
  for expr in ("147.28.0.19-147.28.0.62",
               "3130",
               "2001:418:1::19/128",
               "147.28.0.19-147.28.0.62,198.180.150.50/32",
               "3130,147.28.0.19-147.28.0.62,198.180.150.50/32",
               "2001:418:1::62/128,198.180.150.50/32,2001:418:8006::50/128",
               "147.28.0.19-147.28.0.62,2001:418:1::19/128,2001:418:1::62/128,198.180.150.50/32,2001:418:8006::50/128"):
    print
    print "Trying", expr
    for r in find_by_resource_bag(cur, rpki.resource_set.resource_bag.from_str(expr)):
      print r, r.uris
    for fn2 in fn2map:
      print
      print "Trying", fn2, expr
      for r in find_by_resource_bag(cur, rpki.resource_set.resource_bag.from_str(expr), fn2):
        print r, r.uris
    

def find_by_ski(cur, ski, fn2 = None):
  """
  Return the objects whose recorded Subject Key Identifier equals ski.

  For signed objects the recorded SKI is that of the embedded EE
  certificate; CRLs have none.  fn2, if given, restricts the search to
  one object type (a key of fn2map).
  """
  ski_query = """
                      SELECT id, fn2, der
                      FROM object
                      WHERE ski = ?
                      """
  key = buffer(ski)
  return find_results(cur, fn2, ski_query, key)


def find_by_aki(cur, aki, fn2 = None):
  """
  Return the objects whose recorded Authority Key Identifier equals aki.

  CRLs record their own AKI; other objects record the AKI of their
  certificate (or embedded EE certificate).  fn2, if given, restricts
  the search to one object type (a key of fn2map).
  """
  aki_query = """
                      SELECT id, fn2, der
                      FROM object
                      WHERE aki = ?
                      """
  key = buffer(aki)
  return find_results(cur, fn2, aki_query, key)


# It's easiest to understand overlap conditions by understanding
# non-overlap, then inverting and applying De Morgan's law.  Ranges
# A and B do not overlap if either A.min > B.max or A.max < B.min;
# therefore they do overlap if A.min <= B.max and A.max >= B.min.

def _parse_range_value(value):
  """
  Convert a textual range endpoint to the type stored in the range table:
  an all-digit string becomes an ASN (long), any other string is parsed
  as an IP address.  Non-string values are returned unchanged.
  """
  if isinstance(value, (str, unicode)):
    return long(value) if value.isdigit() else rpki.POW.IPAddress(value)
  return value


def find_by_range(cur, range_min, range_max = None, fn2 = None):
  """
  Return the objects holding any resource that overlaps [range_min, range_max].

  range_min and range_max may be ASNs (int/long or all-digit strings) or
  IP addresses (rpki.POW.IPAddress or address strings).  range_max
  defaults to range_min, which gives a single-point lookup.  fn2, if
  given, restricts the search to one object type (a key of fn2map).

  Stored range R overlaps the probe iff probe.min <= R.max and
  probe.max >= R.min.  IP addresses are stored as "_"-prefixed BLOBs,
  which SQLite compares with memcmp(); without an extra guard, a 4-byte
  IPv4 probe could compare as lying inside a 16-byte IPv6 range, or the
  reverse.  The length() test keeps matches within one address family.
  Rows holding ASNs (INTEGERs) pass the guard unconditionally: SQLite
  sorts every INTEGER before every BLOB, so the overlap test alone keeps
  ASN and address ranges apart.
  """
  if range_max is None:
    range_max = range_min
  range_min = _parse_range_value(range_min)
  range_max = _parse_range_value(range_max)
  assert isinstance(range_min, (int, long, rpki.POW.IPAddress))
  assert isinstance(range_max, (int, long, rpki.POW.IPAddress))
  return find_results(cur, fn2,
                      """
                      SELECT object.id, fn2, der
                      FROM object, range
                      WHERE ? <= max AND ? >= min AND object.id = range.id
                      AND (typeof(min) = 'integer' OR length(min) = length(?))
                      """,
                      range_min,
                      range_max,
                      range_min)


def find_by_resource_bag(cur, bag, fn2 = None):
  """
  Return the objects holding any resource that overlaps any range in bag.

  bag is an rpki.resource_set.resource_bag and must contain at least one
  ASN, IPv4, or IPv6 range.  fn2, if given, restricts the search to one
  object type (a key of fn2map).

  Each range in the bag contributes one overlap clause, and the clauses
  are OR-ed together.  Each clause also requires the stored range to be
  either an INTEGER (an ASN) or a BLOB of the same length as the probe
  value.  IP addresses are stored as "_"-prefixed BLOBs compared with
  memcmp(), so without this guard an IPv4 range could match an IPv6 row,
  or the reverse.
  """
  assert bag.asn or bag.v4 or bag.v6
  clauses = []
  params = []
  for rset in (bag.asn, bag.v4, bag.v6):
    if rset:
      for r in rset:
        clauses.append("(? <= max AND ? >= min AND "
                       "(typeof(min) = 'integer' OR length(min) = length(?)))")
        # Placeholder order: overlap lower bound, overlap upper bound,
        # then the family guard's probe value.
        params.extend((r.min, r.max, r.min))
  query = """
    SELECT object.id, fn2, der
    FROM object, range
    WHERE object.id = range.id AND (%s)
    """ % " OR ".join(clauses)
  return find_results(cur, fn2, query, *params)


def find_results(cur, fn2, query, *args):
  """
  Run query and return the matching objects, parsed and annotated with
  their URIs.

  query must select (object.id, fn2, der) and end inside its WHERE
  clause.  If fn2 is given, " AND fn2 = ?" is appended to restrict the
  results to that object type.  Results are grouped by object id, so each
  object appears once even if several of its ranges matched.

  Each returned object gets a .uris list of every URI it was stored
  under.  Its .uri is set to the sole URI when there is exactly one, and
  to None otherwise.
  """
  if fn2 is not None:
    assert fn2 in fn2map
    args += (fn2,)
    query += " AND fn2 = ?"
  cur.execute(query + " GROUP BY object.id", args)
  # Materialize every row first: the cursor is reused below for URI lookups.
  rows = cur.fetchall()
  found = []
  for object_id, object_type, der in rows:
    obj = fn2map[object_type](DER = der)
    cur.execute("SELECT uri FROM uri WHERE id = ?", (object_id,))
    obj.uris = [row[0] for row in cur.fetchall()]
    if len(obj.uris) == 1:
      obj.uri = obj.uris[0]
    else:
      obj.uri = None
    found.append(obj)
  return found

def _create_schema(cur):
  """
  Create the object, uri and range tables and their secondary indexes.
  """

  cur.execute('''
              CREATE TABLE object (
                id INTEGER PRIMARY KEY NOT NULL,
                der BLOB NOT NULL,
                fn2 TEXT NOT NULL,
                ski BLOB,
                aki BLOB,
                UNIQUE (der))
              ''')

  cur.execute('''
              CREATE TABLE uri (
                id INTEGER NOT NULL,
                uri TEXT NOT NULL,
                UNIQUE (uri),
                FOREIGN KEY (id) REFERENCES object(id)
                        ON DELETE CASCADE
                        ON UPDATE CASCADE)
              ''')

  cur.execute("CREATE INDEX uri_index ON uri(id)")

  cur.execute('''
              CREATE TABLE range (
                id INTEGER NOT NULL,
                min RangeVal NOT NULL,
                max RangeVal NOT NULL,
                UNIQUE (id, min, max),
                FOREIGN KEY (id) REFERENCES object(id)
                        ON DELETE CASCADE
                        ON UPDATE CASCADE)
              ''')

  cur.execute("CREATE INDEX range_index ON range(min, max)")


def _object_keys(obj, fn2):
  """
  Return (ski, aki, cer) for a parsed object of type fn2.

  CRLs have no SKI and no certificate; their AKI comes from the CRL
  itself.  For everything else, cer is the certificate itself (for
  "cer") or the first certificate embedded in the signed object, and the
  keys are taken from it.  A certificate whose AKI cannot be extracted
  gets aki None.
  """
  if fn2 == "crl":
    return None, buffer(obj.get_AKI()), None
  if fn2 == "cer":
    cer = obj
  else:
    cer = rpki.x509.X509(POW = obj.get_POW().certs()[0])
  ski = buffer(cer.get_SKI())
  try:
    aki = buffer(cer.get_AKI())
  except Exception:
    aki = None
  return ski, aki, cer


def initialize_database(db_name = "rp-sqlite.db",
                        rcynic_root = os.path.expanduser("~/rpki/subvert-rpki.hactrn.net/trunk/"
                                                         "rcynic/rcynic-data/unauthenticated"),
                        delete_old_db = True,
                        spinner = 100):
  """
  Build an SQLite3 database of every parseable RP object under rcynic_root.

  Each file whose extension is a key of fn2map is parsed and stored once
  per distinct DER blob in the object table.  The URI under which it was
  found is recorded in the uri table as "rsync://" plus the path relative
  to rcynic_root.  The RFC 3779 resources of its certificate go into the
  range table.  Files that cannot be parsed are skipped.

  db_name:        SQLite database file to create.
  rcynic_root:    directory to walk.  A trailing slash is fine.
  delete_old_db:  remove any existing db_name first, so the schema is
                  created from scratch.
  spinner:        if non-zero, write a progress spinner to stderr every
                  `spinner` objects.

  Returns (connection, cursor) after committing.
  """

  # Defaults for the database name and rcynic root are wired in for now;
  # fix this later if the overall approach seems usable.  Might even end
  # up just being an in-memory SQL database, who knows?

  if delete_old_db:
    try:
      os.unlink(db_name)
    except OSError:
      # Best effort: usually the file just doesn't exist yet.
      pass

  db = sqlite3.connect(db_name, detect_types = sqlite3.PARSE_DECLTYPES)
  db.text_factory = str

  cur = db.cursor()
  cur.execute("PRAGMA foreign_keys = on")

  _create_schema(cur)

  nobj = 0

  for root, dirs, files in os.walk(rcynic_root):
    for fn in files:
      fn = os.path.join(root, fn)
      fn2 = os.path.splitext(fn)[1][1:]

      if fn2 not in fn2map:
        continue

      try:
        obj = fn2map[fn2](DER_file = fn)
      except Exception:
        # Unparseable object: skip it rather than abort the whole load.
        continue

      if spinner and nobj % spinner == 0:
        sys.stderr.write("\r%s %d..." % ("|\\-/"[(nobj // spinner) & 3], nobj))

      nobj += 1

      ski, aki, cer = _object_keys(obj, fn2)

      der = buffer(obj.get_DER())
      # relpath copes with rcynic_root given with or without a trailing slash.
      uri = "rsync://" + os.path.relpath(fn, rcynic_root).replace(os.sep, "/")

      try:
        cur.execute("INSERT INTO object (der, fn2, ski, aki) VALUES (?, ?, ?, ?)",
                    (der, fn2, ski, aki))
        rowid = cur.lastrowid

      except sqlite3.IntegrityError:
        # Same DER already stored (seen under another URI): reuse its row.
        cur.execute("SELECT id FROM object WHERE der = ? AND fn2 = ?", (der, fn2))
        rows = cur.fetchall()
        assert len(rows) == 1
        rowid = rows[0][0]

      else:
        # Only a newly inserted object needs its resources recorded.
        if cer is not None:
          bag = cer.get_3779resources()
          for rset in (bag.asn, bag.v4, bag.v6):
            if rset is not None:
              cur.executemany("REPLACE INTO range (id, min, max) VALUES (?, ?, ?)",
                              ((rowid, i.min, i.max) for i in rset))

      cur.execute("INSERT INTO uri (id, uri) VALUES (?, ?)",
                  (rowid, uri))

  if spinner:
    sys.stderr.write("\r= %d objects, committing..." % nobj)

  db.commit()

  if spinner:
    sys.stderr.write("done.\n")

  return db, cur

# Script entry point: build the database, then run the smoke tests.
if __name__ == "__main__":
  main()