#!/usr/bin/python3

import MySQLdb
import MySQLdb.cursors
from pprint import pprint
from time import time
import argparse

# Help text shown by argparse (used as the parser description in main()).
desc = """
This script does the following:
 1) checks for the same 'contract_sound_set' preference used
    in multiple subscribers of the same customer, that use different
    'system' sound sets following the inheritance:
    own > subscriber_profile > domain
    as the result, the contract sound set is created as a copy for
    each of those occurrences except the first one
 2) subscriber's assigned contract sound sets that use
    own > subscriber_profile > domain 'sound_set' preferences, have
    the 'sound_set' assigned as a 'parent' of the respective contract sound set
"""

# Database endpoint used by connect_db(). No user/password is set here;
# presumably they come from the option file read in connect_db() — confirm.
db_host = "localhost"
db_port = 3306
database = "provisioning"

# Shared connection handle. It starts as an empty placeholder and is
# replaced by the MySQLdb connection in connect_db(). The functions that
# run queries read it as a global.
db_conn = {}

# argparse namespace. Assigned in main() and read by process_records()
# (args.apply).
args = None


def logit(str, level, end):
    """Print *str* after an arrow indented by four spaces per *level*.

    *end* is handed to print() as the line terminator, so callers can
    continue the same line (e.g. with a timing suffix).
    """
    indent = " " * (4 * level)
    print(f"{indent}-> {str}", end=end)


def log(str, level=0, end="\n"):
    """Print an informational line via logit().

    *level* sets the indentation depth; *end* is the line terminator.
    """
    logit(str, level=level, end=end)


def error(str, level=0, end="\n"):
    """Print *str* via logit() with an ``Error: `` prefix.

    *level* and *end* have the same meaning as for log().
    """
    message = "Error: " + str
    logit(message, level, end)


def connect_db():
    """Connect to the provisioning database and publish it as db_conn.

    Connection options are read from /etc/mysql/sipwise_extra.cnf. Autocommit
    is switched off so that main() decides whether the changes are committed
    or rolled back.

    Returns the connection object, or 0 if any step raised (the error is
    logged). If MySQLdb.connect() itself fails, db_conn keeps its previous
    value.
    """
    global db_conn
    try:
        db_conn = MySQLdb.connect(
            host=db_host,
            port=db_port,
            read_default_file="/etc/mysql/sipwise_extra.cnf",
            database=database,
        )
        log(f"connected to {db_host}:{db_port} as 'sipwise'")
        db_conn.autocommit(False)
        log("autocommit=false")
    except Exception as exc:
        error(
            f"could not connect to {db_host}:{db_port} as 'sipwise': {exc}"
        )
        return 0
    return db_conn


def fetch_customers_query():
    """Return the SELECT that feeds fetch_customers_data().

    The query yields one row for every subscriber that has a
    'contract_sound_set' preference. The columns, in order, are:

    * subscriber id and customer (account) id;
    * the contract sound set id;
    * the effective system sound set, chosen by COALESCE in the order
      own > subscriber profile > domain preference;
    * parent_id, reseller_id, name and description of the contract sound
      set;
    * the id of the 'contract_sound_set' preference.

    Rows are ordered by customer and then subscriber. Subscribers whose
    effective sound set does not match a voip_sound_sets row are dropped by
    the inner JOIN.
    """
    sql = """
SELECT s.id AS subscriber_id, s.account_id AS customer_id,
       vcp.value as contract_sound_set,
       COALESCE(usr_value, prof_value, dom_value) AS sound_set,
       pss.parent_id, pss.reseller_id, pss.name, pss.description,
       contract_pref_id
  FROM voip_subscribers s
  JOIN (SELECT p.id as contract_pref_id, up.subscriber_id, up.value
          FROM voip_preferences p
          LEFT JOIN voip_usr_preferences up ON
            up.attribute_id = p.id
          WHERE p.attribute = 'contract_sound_set') vcp ON
            vcp.subscriber_id = s.id
  LEFT JOIN (SELECT up.subscriber_id, up.value as usr_value
               FROM voip_preferences p
               LEFT JOIN voip_usr_preferences up ON
                 up.attribute_id = p.id
               WHERE p.attribute = 'sound_set') vup ON
                 vup.subscriber_id = s.id
  LEFT JOIN (SELECT dp.domain_id, dp.value as dom_value
               FROM voip_preferences p
               LEFT JOIN voip_dom_preferences dp ON
                 dp.attribute_id = p.id
               WHERE p.attribute = 'sound_set') vdp ON
                 vdp.domain_id = s.domain_id
  LEFT JOIN (SELECT pp.profile_id, pp.value as prof_value
               FROM voip_preferences p
               LEFT JOIN voip_prof_preferences pp
                 FORCE INDEX (profidattrid_idx) ON
                 pp.attribute_id = p.id
               WHERE p.attribute = 'sound_set') vpp ON
                 vpp.profile_id = s.profile_id
  JOIN voip_sound_sets ss ON
    ss.id = COALESCE(usr_value, prof_value, dom_value)
  JOIN voip_sound_sets pss ON
    pss.id = vcp.value
ORDER BY customer_id, subscriber_id
"""
    return sql


def sound_sets_insert_query():
    """Return the parameterized INSERT for one non-default contract sound set.

    Placeholders, in order: name, description, reseller_id, contract_id,
    parent_id. contract_default is always written as 0.
    """
    return (
        "\n"
        "INSERT INTO voip_sound_sets\n"
        "(name, description, reseller_id, contract_id, "
        "parent_id, contract_default)\n"
        "VALUES\n"
        "(%s, %s, %s, %s, %s, 0)\n"
    )


def sound_sets_parent_update_query():
    """Return the UPDATE that sets a sound set's parent.

    Placeholders, in order: new parent_id, id of the sound set to change.
    """
    return (
        "\n"
        "UPDATE voip_sound_sets\n"
        "SET parent_id = %s\n"
        "WHERE id = %s\n"
    )


def subscriber_contract_ss_pref_update_query():
    """Return the UPDATE that rewrites one subscriber preference value.

    Placeholders, in order: new value, preference attribute id,
    subscriber id.
    """
    return (
        "\n"
        "UPDATE voip_usr_preferences\n"
        "SET value = %s\n"
        "WHERE attribute_id = %s\n"
        "AND subscriber_id = %s\n"
    )


def fetch_customers_data():
    """Load contract sound set usage and group it per customer.

    Runs fetch_customers_query() and returns a mapping of the form
    ``{customer_id: {contract_sound_set_id: entry}}``. Each entry holds:

    * ``sub_ids`` / ``sound_sets``: parallel lists. Position 0 is the first
      subscriber seen and its effective sound set. Later positions are only
      the subscribers whose sound set differs from that first one.
    * ``distinct_sound_set``: the first subscriber's sound set.
    * ``name``, ``description``, ``reseller_id``, ``parent_id``: columns of
      the contract sound set.
    * ``contract_pref_id``: id of the 'contract_sound_set' preference.

    Rows arrive ordered by customer. When the customer id changes (and
    after the last row), cleanup_ok_records() prunes entries of the finished
    customer that need no change.
    """
    grouped = {}
    current_customer = 0

    cur = db_conn.cursor()
    cur.execute(fetch_customers_query())

    for record in cur:
        (
            sub_id,
            customer_id,
            contract_ss_id,
            ss_id,
            parent_id,
            reseller_id,
            name,
            description,
            contract_pref_id,
        ) = record

        by_contract_ss = grouped.setdefault(customer_id, {})
        if contract_ss_id not in by_contract_ss:
            by_contract_ss[contract_ss_id] = {
                "sub_ids": [sub_id],
                "sound_sets": [ss_id],
                "distinct_sound_set": ss_id,
                "name": name,
                "description": description,
                "reseller_id": reseller_id,
                "parent_id": parent_id,
                "contract_pref_id": contract_pref_id,
            }
        entry = by_contract_ss[contract_ss_id]

        # The first subscriber's sound set is the reference; only
        # subscribers that deviate from it are recorded additionally.
        if ss_id != entry["distinct_sound_set"]:
            entry["sub_ids"].append(sub_id)
            entry["sound_sets"].append(ss_id)

        # A new customer id means the previous customer is complete.
        if customer_id != current_customer:
            if current_customer > 0:
                cleanup_ok_records(grouped, current_customer)
            current_customer = customer_id

    if cur.rowcount > 0:
        cleanup_ok_records(grouped, current_customer)

    cur.close()
    return grouped


def cleanup_ok_records(customers, customer_id):
    """Drop *customer_id*'s contract sound set entries that need no fix.

    An entry needs no fix when its contract sound set already has a parent
    and all of its subscribers share one sound set. If no entries remain,
    the customer is removed from *customers*. *customers* is modified in
    place. A KeyError is raised if *customer_id* is not present.
    """
    remaining = {}
    for contract_ss_id, entry in customers[customer_id].items():
        already_ok = entry["parent_id"] and len(entry["sound_sets"]) == 1
        if not already_ok:
            remaining[contract_ss_id] = entry

    if remaining:
        customers[customer_id] = remaining
    else:
        customers.pop(customer_id)


def process_records(customers):
    """Plan the sound set fixes for *customers* and apply them with --apply.

    *customers* is the mapping built by fetch_customers_data(). For every
    contract sound set entry, position 0 of ``sound_sets``/``sub_ids`` is
    the reference subscriber:

    * the reference sound set is scheduled as the parent of the contract
      sound set, unless that set already has a parent;
    * each later position whose sound set differs from the reference gets
      a new contract sound set "<name> copy#<idx>". Its parent is that
      subscriber's sound set, and the subscriber is queued to be switched
      to it.

    The plan is always logged. Database writes happen only when
    ``args.apply`` is true.
    """
    copies = []
    copy_subscribers = []
    pref_updates = []
    parent_updates = []
    contract_pref_id = None

    for customer_id, by_contract_ss in customers.items():
        log(f"customer_id={customer_id}", 1)
        for contract_sound_set_id, data in by_contract_ss.items():
            ss_name = data["name"]
            log(
                f"contract_sound_set={contract_sound_set_id} "
                f"name='{ss_name}'",
                2,
            )
            sound_sets = data["sound_sets"]
            sub_ids = data["sub_ids"]
            reference_ss = data["distinct_sound_set"]
            contract_pref_id = data["contract_pref_id"]
            parent_id = data["parent_id"]

            for idx, sound_set_id in enumerate(sound_sets):
                sub_id = sub_ids[idx]
                if idx > 0:
                    # Deviating subscriber: it gets its own copy.
                    if sound_set_id == reference_ss:
                        continue
                    log(
                        f"create a copy #{idx} parent "
                        f"sound_set={sound_set_id} sub_id={sub_id}",
                        3,
                    )
                    copies.append(
                        (
                            f"{ss_name} copy#{idx}",
                            data["description"],
                            data["reseller_id"],
                            customer_id,
                            sound_set_id,
                        )
                    )
                    copy_subscribers.append(sub_id)
                elif parent_id:
                    # An existing parent is never overwritten.
                    log(
                        f"contract sound set already has a "
                        f"parent={parent_id}",
                        3,
                    )
                else:
                    log(
                        f"set parent sound_set={sound_set_id} "
                        f"sub_id={sub_id}",
                        3,
                    )
                    parent_updates.append(
                        (sound_set_id, contract_sound_set_id)
                    )

    if not copies and not parent_updates:
        log("nothing to do")
        return

    log(
        f"prepared sound_set_inserts={len(copies)} "
        f"sound_set_parent_updates={len(parent_updates)}"
    )

    if args.apply:
        insert_contract_sound_sets(
            copies, copy_subscribers, pref_updates, contract_pref_id
        )
        set_sound_set_parents(parent_updates)


def insert_contract_sound_sets(inserts, subs, pref_updates, contract_pref_id):
    """Insert copied contract sound sets and switch subscribers to them.

    :param inserts: parameter tuples for sound_sets_insert_query(), one per
        copy: (name, description, reseller_id, contract_id, parent_id).
    :param subs: subscriber ids. ``subs[i]`` is switched to the sound set
        created from ``inserts[i]``.
    :param pref_updates: list that receives one
        (contract_sound_set_id, contract_pref_id, subscriber_id) tuple per
        copy (appended in place). The whole list is then executed with
        subscriber_contract_ss_pref_update_query().
    :param contract_pref_id: attribute id of the 'contract_sound_set'
        preference.
    :raises ValueError: if *inserts* and *subs* differ in length.

    The id of each new row is taken from ``cursor.lastrowid`` right after
    its own single-row INSERT. The previous version derived the ids
    arithmetically from the last id, with a hard-coded step of 2 for the
    first id, and so linked subscribers to wrong sound sets when
    auto_increment_increment was not 2.

    NOTE(review): only the voip_sound_sets row is created here. Nothing
    else belonging to the original contract sound set is duplicated by this
    function — confirm this is intended.
    """
    if len(inserts) == 0:
        return
    if len(inserts) != len(subs):
        raise ValueError(
            f"inserts ({len(inserts)}) and subs ({len(subs)}) "
            f"must have the same length"
        )

    log("inserting copied contract sound sets...", 0, "")
    started = time()
    cursor = db_conn.cursor()
    try:
        insert_query = sound_sets_insert_query()
        new_ids = []
        for row in inserts:
            cursor.execute(insert_query, row)
            # AUTO_INCREMENT id generated by this single-row INSERT.
            new_ids.append(cursor.lastrowid)
        print(" ok={0:.3f}s".format(time() - started))

        log("preparing preference updates...", 0, "")
        started = time()
        pref_updates.extend(
            (contract_ss_id, contract_pref_id, sub_id)
            for contract_ss_id, sub_id in zip(new_ids, subs)
        )
        print(" ok={0:.3f}s".format(time() - started))

        log("updating subscriber contract_sound_set preferences...", 0, "")
        started = time()
        cursor.executemany(
            subscriber_contract_ss_pref_update_query(), pref_updates
        )
        print(" ok={0:.3f}s".format(time() - started))
    finally:
        cursor.close()


def set_sound_set_parents(updates):
    """Execute sound_sets_parent_update_query() for each pair in *updates*.

    Each element of *updates* is (parent_id, contract_sound_set_id). Nothing
    is done when *updates* is empty.
    """
    if len(updates) > 0:
        log("updating contract sound sets parents...", 0, "")
        t0 = time()
        cur = db_conn.cursor()
        cur.executemany(sound_sets_parent_update_query(), updates)
        cur.close()
        elapsed = time() - t0
        print(f" ok={elapsed:.3f}s")


def main():
    """Parse the command line, run the fix-up and commit or roll back.

    Without ``--apply`` the changes are computed and logged, but the
    transaction is rolled back. Unknown command-line arguments are ignored,
    as before. Exits with status 1 if the database connection cannot be
    established (connect_db() logs the reason). Any error while fetching or
    processing rolls the transaction back before it is re-raised. The
    connection is always closed.
    """
    global args
    argparser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter, description=desc
    )
    argparser.add_argument(
        "--apply",
        action=argparse.BooleanOptionalAction,
        # Must be a plain str: argparse %-formats and wraps it for --help.
        # A trailing comma here used to turn it into a tuple and crash -h.
        help=(
            "apply the changes, otherwise no db changes are stored "
            "by default"
        ),
    )
    args, _unknown = argparser.parse_known_args()
    apply = args.apply

    if not apply:
        log(
            "[dry run, no changes are applied, "
            "use --apply to commit the changes]"
        )

    if not connect_db():
        # Without a connection, db_conn is not usable; stop here instead
        # of failing later with an AttributeError.
        raise SystemExit(1)

    try:
        log("fetching data...", 0, "")
        started = time()
        customers = fetch_customers_data()
        print(" ok={0:.3f}s".format(time() - started))

        log("processing data...")
        started = time()
        process_records(customers)
        log("done {0:.3f}s".format(time() - started))

        if apply:
            db_conn.commit()
        else:
            db_conn.rollback()
    except Exception:
        # Never leave a partially applied migration behind.
        db_conn.rollback()
        raise
    finally:
        db_conn.close()


if __name__ == "__main__":
    # Run the migration only when executed as a script, not on import.
    main()
