#!/usr/bin/env python3

import argparse
import logging
import re
from sys import exit

import redis

# Prefix prepended to every Redis key this script touches and to the table
# name looked up in the Kamailio db_redis "keys" modparam.
loc_prefix = "1:"

parser = argparse.ArgumentParser(
    description="Fixes pseudo-indexes used in Redis "
    "to make location lookups faster"
)
parser.add_argument("-v", "--verbose", action="store_true")
args = parser.parse_args()
verbose = args.verbose

# INFO-level messages are only emitted when --verbose is given; errors are
# still reported through logging's default handling otherwise.
if verbose:
    logging.basicConfig(level=logging.INFO)

# Read the Redis host and the location database number from the
# ngcp-usr-location configuration file.  The context manager guarantees
# the file is closed even if reading fails.
with open("/etc/ngcp-usr-location/ngcp-usr-location.conf", "r") as f:
    filecontent = f.read()
patternredisip = re.compile(r"^REDIS_IP=(\S.*)$", re.MULTILINE)
patternredisdb = re.compile(r"^REDIS_LOC_DB=(\d+)$", re.MULTILINE)
ip = patternredisip.findall(filecontent)
db = patternredisdb.findall(filecontent)
if not (ip and db):
    logging.error("Config file is unparsable")
    exit(1)
# If a setting appears more than once, the first occurrence is used.
ip = ip[0]
db = db[0]
r = redis.Redis(host=ip, port=6379, db=db, decode_responses=True)

# Extract the index definitions of the location table from the db_redis
# "keys" modparam in the Kamailio proxy configuration.
with open("/etc/kamailio/proxy/kamailio.cfg", "r") as f:
    filecontent = f.read()
# The prefix is escaped so that it is always matched literally, whatever
# characters it contains.
patternmodparam = re.compile(
    r'^modparam\("db_redis"\s*,\s*"keys"\s*,\s*"'
    + re.escape(loc_prefix)
    + r'location=([^"]+)',
    re.MULTILINE,
)
modparam = patternmodparam.findall(filecontent)
if not modparam:
    logging.error("Could not find db_redis modparam in kamailio.cfg")
    exit(1)

# Only the first matching modparam line is considered.
modparam = modparam[0]

# maps: index name -> list of column names making up that index
# (an empty list when the definition names no columns).
maps = {}
ent_keys = []  # NOTE(review): not used anywhere in the visible script
defs = modparam.split("&")
for d in defs:
    # Each definition must look like "<name>:<col>[,<col>...]" with exactly
    # one colon; report anything else instead of failing with a traceback.
    name, sep, parts = d.partition(":")
    if not sep or ":" in parts:
        logging.error(
            "Malformed key definition '%s' in db_redis modparam", d
        )
        exit(1)
    # The "entry" definition describes the entries themselves, not an
    # index, so it is skipped.
    if name != "entry":
        maps[name] = parts.split(",") if parts else []


def redis_evalsha(idx, cmd, sha, numkeys, key, *args):
    """Run the preloaded script for one key and log what it changed.

    The script identified by ``sha`` is executed through the module-level
    Redis connection with ``numkeys``, ``key`` and ``args`` passed along
    unchanged.  Its integer result is read as a bitmask: bit 0 and bit 1
    select which of the "Added ..." messages below is logged.

    idx  -- index name, only used in the log messages
    cmd  -- script source, only used to print an equivalent keydb-cli
            command line in verbose mode
    """
    result = r.evalsha(sha, numkeys, key, *args)

    if verbose:
        joined_args = " ".join(args)
        logging.info("Raw command which could be send via keydb-cli:")
        logging.info(
            "keydb-cli -n %s --raw eval '%s' %s %s %s",
            db,
            cmd,
            numkeys,
            key,
            joined_args,
        )

    index_bit = result & 1
    master_bit = result & 2

    if index_bit and master_bit:
        logging.info(
            "Added %s to its %s index and also its master index", key, idx
        )
    elif index_bit:
        logging.info("Added %s to its %s index", key, idx)
    elif master_bit:
        logging.info("Added %s to its master index", idx)


# Lua script executed inside Redis.
#   KEYS[1]   hash whose fields are read with HMGET
#   KEYS[2]   set that receives the computed set key
#   ARGV[1]   prefix of the computed set key
#   ARGV[2..] optional hash fields; their values, joined with ":", are
#             appended to ARGV[1] to form the computed set key
# Returns a bitmask: +1 if KEYS[1] was newly added to the computed set,
# +2 if the computed set key was newly added to KEYS[2].
_lua_lines = (
    "  local ret = 0;",
    "  local parts = {};",
    "  if ARGV[2] then",
    '    parts = redis.call(      "HMGET", KEYS[1], unpack(ARGV, 2)    );',
    "  end;",
    '  local key = ARGV[1]..table.concat(parts, ":");',
    '  if redis.call(    "SADD", key, KEYS[1]  ) ~= 0 then',
    "    ret = ret + 1;",
    "  end;",
    '  if redis.call(    "SADD", KEYS[2], key  ) ~= 0 then',
    "    ret = ret + 2;",
    "  end;",
    "  return ret;",
)
# Every line, including the last one, is terminated by a newline.
rawcommand = "".join(line + "\n" for line in _lua_lines)

# Load the script once; every call below refers to it by its SHA.
sha = r.script_load(rawcommand)

# Per index, the arguments that do not depend on the entry being processed:
# (index name, master index set, index key prefix, index columns).
# The key prefix ends in "::" only when the index has columns.
index_args = [
    (
        sname,
        loc_prefix + "location::index::" + sname,
        loc_prefix + "location:" + sname + ("::" if parts else ""),
        parts,
    )
    for sname, parts in maps.items()
]

# Run the script for every location entry against every configured index.
for key in r.scan_iter(loc_prefix + "location:entry::*"):
    for sname, master_set, key_prefix, columns in index_args:
        redis_evalsha(
            sname,
            rawcommand,
            sha,
            2,
            key,
            master_set,
            key_prefix,
            *columns,
        )
