#!/usr/bin/env python3

import argparse
import datetime
import json
import os
import re
import sys
import time
import traceback

import pytimeparse
import redis

# constants

# Kamailio proxy config, scanned by init_redis_schema() for the db_redis
# "keys" modparam that describes the usrloc key layout
KAMAILIO_CONFIG = "/etc/kamailio/proxy/kamailio.cfg"
# encodings tried, in order, when decoding raw KeyDB keys/values on export
ENCODINGS = ["utf-8", "latin1"]


# sort-of constants

# cutoff to detect permanent registrations: 5 years into the future
# NOTE: evaluated once at import time, so a very long-running process keeps
# the cutoff computed at startup
PERM_CUTOFF = time.time() + 86400 * 365 * 5


# global variables

# active schema class (one of the redisVersionN classes); set by
# init_redis_schema()
redis_schema = None
# pseudo-index name -> list of entry columns; filled by init_redis_schema()
usrloc_keys = {}


def init_redis_schema(args):
    """Detect the usrloc key layout and select the KeyDB schema version.

    Scans KAMAILIO_CONFIG for the db_redis "keys" modparam, fills the
    module-global ``usrloc_keys`` mapping (pseudo-index name -> list of
    entry columns) and sets the module-global ``redis_schema`` to the
    schema class matching the detected (or --version-overridden) version.

    Raises:
        SystemError: no db_redis keys definition found in the config.
        ValueError: the schema version is not supported.
    """
    global redis_schema, usrloc_keys

    # parse usrloc pseudo-indexes from config

    version = None

    with open(KAMAILIO_CONFIG, "r") as fp:
        # modparam("db_redis", "keys", "1:location=entry:ruid&usrdom:
        # username,domain&timer:partition,keepalive&expiry:expires&master:")
        matcher = re.compile(
            r'^modparam\s*\(\s*"db_redis"\s*,\s*"keys"\s*,\s*'
            r'"(?:(.*?):)?location=(.*?)"\s*\)$'
        )
        matches = None
        for line in fp:
            matches = matcher.search(line.strip())
            if matches:
                break

        if not matches:
            raise SystemError(
                "unable to find db_redis keys definition in Kamailio config"
            )

        # the version prefix is optional in the config; absent means "0"
        version = matches[1]
        if version is None:
            version = "0"

        # parse out definition string: "&"-separated "name:col1,col2" items
        keys = matches[2].split("&")
        for key in keys:
            (kname, kcols) = key.split(":")
            if kcols == "":
                usrloc_keys[kname] = []
            else:
                usrloc_keys[kname] = kcols.split(",")

    # define supported KeyDB schema versions

    class redisVersion0:
        # legacy schema: keys carry no prefix and 'expires' is stored as a
        # "%Y-%m-%d %H:%M:%S" datetime string
        prefix = ""

        @staticmethod
        def convert_out(entry):
            if "expires" in entry:
                entry["expires"] = datetime.datetime.fromtimestamp(
                    entry["expires"]
                ).strftime("%Y-%m-%d %H:%M:%S")

    class redisVersion1:
        # schema 1: keys carry a "1:" prefix and 'expires' stays numeric
        prefix = "1:"

        @staticmethod
        def convert_out(entry):
            pass

    redis_schemas = {
        "0": redisVersion0,
        "1": redisVersion1,
    }

    # see which one to use; an explicit --version overrides the config

    if args.version is not None:
        version = args.version
    if version not in redis_schemas:
        raise ValueError(f"invalid KeyDB schema version ('{version}')")
    redis_schema = redis_schemas[version]


def convert_in(entry):
    """Normalize entry['expires'] to an integer UNIX timestamp in place.

    A missing field becomes 0 (permanent); numeric strings are converted
    directly; otherwise the value is parsed as "%Y-%m-%d %H:%M:%S".
    """
    if "expires" not in entry:
        entry["expires"] = 0
        return

    raw = entry["expires"]
    try:
        # common case: plain integer (possibly as a string)
        entry["expires"] = int(raw)
    except ValueError:
        # legacy case: datetime string
        stamp = datetime.datetime.strptime(raw, "%Y-%m-%d %H:%M:%S")
        entry["expires"] = int(stamp.timestamp())


def convert_out(entry):
    """Apply the active schema's outbound conversion to entry, in place."""
    schema = redis_schema
    schema.convert_out(entry)


def do_import(redis_db, input_file, allow_flush, allow_merge, flt):
    """Import registration entries from a JSON list file into KeyDB.

    Args:
        redis_db: redis.Redis connection to write to.
        input_file: readable file object containing a JSON list of dicts.
        allow_flush: flush the whole DB before importing.
        allow_merge: allow importing into a non-empty DB.
        flt: regFilter used to skip and/or mangle entries.

    Returns True on success, False when the DB is non-empty and neither
    flushing nor merging was allowed.

    Raises SyntaxError on malformed input.
    """
    if allow_flush:
        redis_db.flushdb()
    elif not allow_merge:
        ret = redis_db.dbsize()
        if ret > 0:
            print(
                "KeyDB DB is not empty and neither flushdb nor merge was "
                "allowed",
                file=sys.stderr,
            )
            return False

    # hand-roll the reading to avoid having to load the entire JSON list
    first = True
    buf = ""
    while True:
        eof = False
        chunk = input_file.read(1024)
        if chunk:
            buf = buf + chunk
        else:
            eof = True

        # we should always be leading at the beginning of a dict, therefore
        # discard all whitespace
        buf = buf.lstrip()

        # discard leading '['
        # BUG FIX: guard against an empty buffer — a zero-length input file
        # previously raised IndexError here instead of finishing cleanly
        if first:
            if buf and buf[0] == "[":
                buf = buf[1:]
            first = False
            buf = buf.lstrip()

        if not buf:
            # success, consumed all data
            break

        if eof:
            # something wrong - reached end of file, but still data left in
            # the buffer
            raise SyntaxError(f"leftover JSON data at end of buffer: {buf}")

        # consume all dicts we have in the buffer
        while buf:
            # we should have a dict
            if buf[0] != "{":
                raise SyntaxError("invalid JSON format (no dict found)")
            # find end of dict (entry separator, then possible list endings)
            pos = buf.find("},\n")
            if pos == -1:
                # perhaps end of list?
                pos = buf.find("}\n")
                if pos == -1:
                    pos = buf.find("}]\n")
            if pos == -1:
                # read more input
                break
            # extract dict and discard from buffer
            onedict = buf[0:pos] + "}"
            buf = buf[pos:]
            # also discard dict end from buffer
            buf = buf.lstrip("}],\n ")

            # parse and process
            entry = json.loads(onedict)

            try:
                convert_in(entry)

                if not flt.is_allowed(entry):
                    continue

                flt.mangle(entry)

                # construct pseudo indexes from the parsed usrloc key layout
                indexes = []
                for usrloc_key in usrloc_keys.items():
                    (keyname, subkeys_list) = usrloc_key
                    full_key_list = []
                    for subkey in subkeys_list:
                        full_key_list.append(str(entry[subkey]))
                    full_key = redis_schema.prefix + "location:" + keyname
                    if full_key_list:
                        full_key = full_key + "::" + ":".join(full_key_list)
                    indexes.append(full_key)

                # first index is the entry key
                entry_key = indexes[0]

                # write entry
                convert_out(entry)
                redis_db.hset(entry_key, mapping=entry)

                # write pseudo indexes
                for index in indexes[1:]:
                    redis_db.sadd(index, entry_key)

            except Exception:
                # best-effort import: report the failed entry and keep going
                print(
                    f"Could not write entry '{entry}' to KeyDB",
                    file=sys.stderr,
                )
                traceback.print_exc(file=sys.stderr)

    return True


def do_export(redis_db, output_file, flt):
    """Export registration entries from KeyDB as a JSON list.

    Iterates the "location:master" set with SSCAN, reads each entry hash,
    decodes keys/values, filters/mangles via flt, and streams the result
    to output_file as a JSON list. Returns True.
    """

    def _decode(raw):
        # Return raw decoded with the FIRST encoding in ENCODINGS that
        # succeeds, or None if none work.
        # BUG FIX: the old loop kept trying further encodings after a
        # successful decode, so latin1 (which accepts any byte sequence)
        # always overwrote a correct utf-8 decode, mangling UTF-8 data.
        for encoding in ENCODINGS:
            try:
                return raw.decode(encoding)
            except (UnicodeDecodeError, LookupError):
                continue
        return None

    # export entry by entry and fake the JSON list
    print("[", file=output_file)

    cur = 0
    first = True

    while True:
        (cur, eles) = redis_db.sscan(
            redis_schema.prefix + "location:master", cursor=cur, count=10000
        )

        for ele in eles:
            try:
                entry_raw = redis_db.hgetall(ele)

                entry = {}
                for k, v in entry_raw.items():
                    ke = _decode(k)
                    ve = _decode(v)
                    if ke is None or ve is None:
                        raise TypeError(
                            f"Failed to convert {str(k)} or {str(v)} "
                            "into a string"
                        )
                    entry[ke] = ve

                convert_in(entry)

                if not flt.is_allowed(entry):
                    continue

                flt.mangle(entry)

                # separate entries with a comma to form a valid JSON list
                if not first:
                    print(",", file=output_file)

                json.dump(entry, output_file, indent=2)

                first = False

            except Exception:
                # best-effort export: report the failed key and keep going
                print(
                    f"Could not read key '{ele}' from KeyDB", file=sys.stderr
                )
                traceback.print_exc(file=sys.stderr)

        # cursor 0 signals the end of the SSCAN iteration
        if cur == 0:
            break

    print("\n]", file=output_file)

    return True


def is_perm(entry):
    """Return True if entry is a permanent registration.

    Permanent means: no 'expires' field, an expiry of 0, or an expiry at or
    beyond PERM_CUTOFF (roughly 5 years in the future).
    """
    if "expires" not in entry:
        return True
    exp = entry["expires"]
    return exp == 0 or exp >= PERM_CUTOFF


class regFilter:
    """Filter/mangler for registration entries, driven by CLI arguments.

    Honors --permanent (include only / exclude permanent registrations) and
    the mutually exclusive --shift-expiry / --set-expiry options.
    """

    def __init__(self, args):
        """Build the filter from parsed argparse arguments.

        Raises ValueError for an invalid --permanent value, an unparsable
        time period, or when both expiry options are given.
        """
        self.__allow_perm = True
        self.__allow_non_perm = True
        self.__shift_exp = 0
        self.__set_exp = None

        if args.permanent:
            if args.permanent == "only":
                self.__allow_non_perm = False
            elif args.permanent == "exclude":
                self.__allow_perm = False
            else:
                # unreachable via argparse choices, but kept as a safety net
                # (typo fix: message previously said "--permenent")
                raise ValueError('invalid value for "--permanent"')

        if args.shift_expiry:
            self.__shift_exp = pytimeparse.parse(args.shift_expiry)
            if not self.__shift_exp:
                raise ValueError('invalid time period for "--shift-expiry"')
        if args.set_expiry:
            self.__set_exp = pytimeparse.parse(args.set_expiry)
            if not self.__set_exp:
                raise ValueError('invalid time period for "--set-expiry"')
        if self.__set_exp and self.__shift_exp:
            raise ValueError(
                'cannot use both "--shift-expiry" and "--set-expiry"'
            )

    def is_allowed(self, entry):
        """Return False if entry is excluded by the --permanent setting."""
        if is_perm(entry):
            if not self.__allow_perm:
                return False
        else:
            if not self.__allow_non_perm:
                return False

        return True

    def mangle(self, entry):
        """Apply --shift-expiry/--set-expiry to non-permanent entries."""
        if not is_perm(entry):
            if self.__shift_exp:
                entry["expires"] = entry["expires"] + self.__shift_exp
            elif self.__set_exp:
                entry["expires"] = int(time.time()) + self.__set_exp


def main():
    """Entry point: parse CLI arguments and run an import or export.

    Returns a shell exit code: 0 on success, 1 on failure, 2 on an
    invalid operation.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("operation", help='"import" or "export"')
    parser.add_argument(
        "--ip", help="KeyDB server address", default="localhost"
    )
    parser.add_argument(
        "--port", help="KeyDB server port", type=int, default=6379
    )
    parser.add_argument(
        "--db", help="database used for operation", type=int, default=20
    )
    parser.add_argument("--file", help="IN or OUT file")
    parser.add_argument(
        "--flushdb",
        help="flush existing contents of DB before importing",
        action="store_true",
    )
    parser.add_argument(
        "--merge",
        help="merge imported records with existing DB",
        action="store_true",
    )
    parser.add_argument(
        "--overwrite",
        help="allow overwriting existing files when exporting",
        action="store_true",
    )
    parser.add_argument(
        "--permanent",
        help="include or exclude permanent registrations",
        choices=["only", "exclude"],
    )
    parser.add_argument("--shift-expiry", help="shift expiry into the future")
    parser.add_argument(
        "--set-expiry",
        help="set expiry into the future based on current " "time",
    )
    parser.add_argument(
        "--version", help="DB schema version", choices=["0", "1"]
    )

    args = parser.parse_args()

    if args.flushdb and args.merge:
        print(
            "The options --flushdb and --merge are mutually exclusive",
            file=sys.stderr,
        )
        return 1

    init_redis_schema(args)

    redis_db = redis.Redis(host=args.ip, port=args.port, db=args.db)

    flt = regFilter(args)
    success = False

    if args.operation == "import":
        input_file = sys.stdin
        if args.file:
            input_file = open(args.file, "r")
        try:
            success = do_import(
                redis_db, input_file, args.flushdb, args.merge, flt
            )
        finally:
            # only close what we opened; previously sys.stdin was closed
            # too, and the handle leaked if do_import raised
            if input_file is not sys.stdin:
                input_file.close()
    elif args.operation == "export":
        output_file = sys.stdout
        if args.file:
            if os.path.exists(args.file) and not args.overwrite:
                print(
                    f"'{args.file}' already exists and overwriting it was "
                    "not allowed",
                    file=sys.stderr,
                )
                return 1
            output_file = open(args.file, "w")
        try:
            success = do_export(redis_db, output_file, flt)
        finally:
            # same rationale: never close sys.stdout, never leak the file
            if output_file is not sys.stdout:
                output_file.close()
    else:
        print(f"Invalid operation '{args.operation}'", file=sys.stderr)
        print(file=sys.stderr)
        parser.print_help(file=sys.stderr)
        return 2

    if success:
        return 0
    return 1


sys.exit(main())
