#!/usr/bin/python3

from __future__ import annotations

import pymysql
import pymysql.cursors
import argparse

from pymysql import Connection
from pymysql.cursors import Cursor

# Description printed by ``--help``; main() passes it to argparse together
# with RawDescriptionHelpFormatter, so the line breaks below are shown as-is.
desc = """
overview:
    The script maps codec IDs for all subscribers, domains, and peers to their
    corresponding name while also matching the value of the codecs_filter flag
    to that of codecs_id_filter.

    If any conflicts arise, the user will be prompted with different options
    to handle these. Optionally, the user can use one of three flags
    ( -s, -a, -S ) to batch-automate this process. The behaviour of these
    three modes is, respectively, as follows: skip the mapping for the
    current subscriber/domain/peer, append any new codec names to the
    existing name list, or substitute the existing name list with the names
    corresponding to the updated ID list while also matching the value of
    the codecs_filter flag to that of codecs_id_filter.
"""

# Connection settings handed to pymysql.connect() by connect_db(), which
# additionally passes a MySQL option file via read_default_file.
db_host = "localhost"
db_port = 3306
database = "provisioning"

# Parsed command-line arguments; assigned by main().
args = None
# Verbose-output switch checked by log_verbose(); assigned by main().
is_verbose = False


class Groups:
    """Container holding one dictionary per group that is queried."""

    def __init__(self):
        # Assignment order is significant: get_group_names() reports the
        # attributes in the order in which they were created.
        self.subscriber, self.domain, self.peer = {}, {}, {}

    def get_group(self, attr):
        """Return the object stored under the group name ``attr``."""
        return getattr(self, attr)

    def get_group_names(self):
        """Return a keys view over the names of all groups."""
        return vars(self).keys()


class QueryResult:
    """One preference row reduced to its ID, attribute name and value."""

    def __init__(self, _id, attribute, value):
        self.id, self.attribute, self.value = _id, attribute, value

    def attributes(self) -> tuple:
        """Return the instance attribute values as a tuple, in the order
        in which they were assigned."""
        return tuple(vars(self).values())


# Module-wide registry: initialize_codec_objects() stores one Codec object per
# ID in the dictionary of the matching group.
groups = Groups()


class Codec:
    """
    Codec object.

    The codec object stores all the codec configurations for any given
    subscriber, peer, or domain.
    """

    codecs = {
        0: "PCMU",
        1: "FS-1016",
        2: "G721",
        3: "GSM",
        4: "G723",
        5: "DVI4",
        6: "DVI4",
        7: "LPC",
        8: "PCMA",
        9: "G722",
        10: "L16",
        11: "L16",
        12: "QCELP",
        13: "CN",
        14: "MPA",
        15: "G728",
        16: "DVI4",
        17: "DVI4",
        18: "G729",
        19: "CN",
        25: "CELLB",
        26: "JPEG",
        28: "nv",
        31: "H261",
        32: "MPV",
        33: "MP2T",
        34: "H263",
    }

    def __init__(self, _id, group):
        """Initialize the codec object."""
        self.id = _id
        self.group = group
        self.codecs_id_filter = 0
        self.codecs_filter = 0
        self.codecs_id_list = set()
        self.codecs_list = set()

    def __str__(self) -> str:
        """String representation of the codec object."""
        return (
            f"Configuration for {self.group} with ID {self.id} \n"
            f"-> codec_id_filter={self.codecs_id_filter}\n"
            f"-> codecs_id_list={self.codecs_id_list}\n"
            f"-> codecs_filter={self.codecs_filter}\n"
            f"-> codecs_list={self.codecs_list}"
        )

    def set_attribute(self, attribute, value) -> None:
        """
        Set an attribute of the codec object.

        :param attribute: attribute of the codec object.
        :param value: value of the attribute.
        """
        setattr(self, attribute, value)

    def _list_to_csv(self, codec_list: set) -> str:
        """Parse list into a comma separated value string."""
        items = ""
        for index, value in enumerate(codec_list):
            if index == len(codec_list) - 1:
                items += f"{value}"
            else:
                items += f"{value},"
        return items

    def _is_codec_list_to_be_updated(self):
        """Check if the codec ids list matches the codec names list"""
        for codec_id in self.codecs_id_list:
            if codec_id in self.codecs:
                if self.codecs[codec_id] not in self.codecs_list:
                    return True
            else:
                return True
        return False

    def _add_codecs_to_set(self):
        """
        Iterates over codec ids and adds their corresponding name to the
        codecs_list set.
        """
        for codec_id in self.codecs_id_list:
            if codec_id in self.codecs:
                if isinstance(self.codecs_list, str):
                    self.codecs_list = set()
                self.codecs_list.add(self.codecs[codec_id])
                log_verbose(
                    f"Mapping codec with ID {codec_id} as "
                    f"{self.codecs[codec_id]}"
                )
            else:
                log(
                    "WARNING: Mapping error. Codec ID could not be mapped "
                    "to a name. Skipping..."
                )

    def map_id_to_name(self, conflict_handler: int):
        """
        Map codec id to its corresponding name.

        :param conflict_handler: conflict handling mode
        """
        log_verbose(f"Now mapping {self.group} with ID {self.id}")
        if len(self.codecs_id_list) > 0:
            if len(self.codecs_list) == 0:
                self._add_codecs_to_set()
                self.codecs_filter = self.codecs_id_filter
                log_verbose(
                    f"Codec filter value has been changed to "
                    f"{self.codecs_filter}"
                )
            else:
                if not self._is_codec_list_to_be_updated():
                    log_verbose(
                        "Codecs already mapped. No changes will be made. "
                        "Skipping..."
                    )
                    return
                action = "0"
                if conflict_handler > 0:
                    action = str(conflict_handler)
                else:
                    action = input(
                        f"Conflicting codec names associated with this "
                        f"{self.group} have been found! "
                        f"In order to resolve this issue, please choose "
                        f"one of the following options:\n"
                        f"1. skip       >>> "
                        f"the codec names will be left as is\n"
                        f"2. append     >>> "
                        f"the new codec names will be appended to the "
                        f"existing codec name list\n"
                        f"3. substitute >>> "
                        f"the current codec names will be replaced by "
                        f"the ones specified by their ID\n"
                        f"Please, select option 1, 2, or 3: "
                    )
                match action:
                    case "1":
                        # skip
                        log(f"Skipping {self.group}")
                        return
                    case "2":
                        # append
                        log(
                            f"Appending new codec names for the current "
                            f"{self.group}"
                        )
                        self._add_codecs_to_set()
                    case "3":
                        # substitute
                        log(
                            f"Substituting codec names for the current "
                            f"{self.group}"
                        )
                        self.codecs_list.clear()
                        log_verbose("Existing codec list has been cleared")
                        self._add_codecs_to_set()
                        self.codecs_filter = self.codecs_id_filter
                        log_verbose(
                            f"Codec filter value has been changed to "
                            f"{self.codecs_filter}"
                        )
                    case _:
                        print(
                            f"Sorry, I don't know how to handle this action. "
                            f"The {self.group} will be skipped."
                        )

    def generate_codec_query(self) -> tuple[str, str]:
        """
        Generate codec query for the codec object, with
        update-first-then-insert logic.

        :return list containing codec update and insert queries.
        """

        query_variables = {
            "subscriber": {
                "target_table": "provisioning.voip_usr_preferences",
                "id_column": "subscriber_id",
            },
            "domain": {
                "target_table": "provisioning.voip_dom_preferences",
                "id_column": "domain_id",
            },
            "peer": {
                "target_table": "provisioning.voip_peer_preferences",
                "id_column": "peer_host_id",
            },
        }

        target_table = query_variables[self.group]["target_table"]
        id_column = query_variables[self.group]["id_column"]

        update_query = f"""
            UPDATE {target_table} u
            JOIN provisioning.voip_preferences p ON
              u.attribute_id = p.id
            SET u.value = CASE
              WHEN p.attribute = 'codecs_filter'
                THEN {self.codecs_filter}
              WHEN p.attribute = 'codecs_list'
                THEN '{self._list_to_csv(self.codecs_list)}'
              WHEN p.attribute = 'codecs_id_filter'
                THEN {self.codecs_id_filter}
              WHEN p.attribute = 'codecs_id_list'
                THEN '{self._list_to_csv(self.codecs_id_list)}'
            END
            WHERE u.{id_column} = {self.id}
              AND p.attribute IN (
                'codecs_filter',
                 'codecs_list',
                 'codecs_id_filter',
                 'codecs_id_list'
              );
        """

        insert_query = f"""
            INSERT INTO {target_table} ({id_column}, attribute_id, value)
            SELECT {self.id}, p.id, CASE
                WHEN p.attribute = 'codecs_filter'
                  THEN {self.codecs_filter}
                WHEN p.attribute = 'codecs_list'
                  THEN '{self._list_to_csv(self.codecs_list)}'
                WHEN p.attribute = 'codecs_id_filter'
                  THEN {self.codecs_id_filter}
                WHEN p.attribute = 'codecs_id_list'
                  THEN '{self._list_to_csv(self.codecs_id_list)}'
              END AS value
            FROM provisioning.voip_preferences p
            WHERE p.attribute IN (
                'codecs_filter',
                'codecs_list',
                'codecs_id_filter',
                'codecs_id_list'
              ) AND NOT EXISTS (
                SELECT 1
                FROM {target_table} u
                WHERE u.{id_column} = {self.id}
                  AND u.attribute_id = p.id
              );
        """
        return (update_query, insert_query)


def logit(str, level, end) -> None:
    """
    Low-level printer shared by all logging helpers.

    Writes the message to stdout behind an arrow marker, indented by four
    spaces per level.

    :param str: message to log
    :param level: indentation depth of the message
    :param end: end character handed to print()
    :return:
    """
    indent = " " * 4 * level
    line = "%s-> %s" % (indent, str)
    print(line, end=end)


def log(str, level=0, end="\n") -> None:
    """Unconditionally write a log message to the terminal."""
    logit(str, level=level, end=end)


def log_verbose(str, level=0, end="\n") -> None:
    """Write a log message to the terminal, but only in verbose mode."""
    # Reading the module-level flag needs no ``global`` declaration.
    if is_verbose is not True:
        return
    logit(str, level, end)


def error(str, level=0, end="\n") -> None:
    """Write a message prefixed with "Error: " to the terminal."""
    message = "Error: " + str
    logit(message, level, end)


def connect_db() -> Connection[Cursor] | int:
    """
    Open a connection to the configured database.

    Autocommit is switched off on the new connection. Any exception raised
    while connecting is reported through error() instead of being
    propagated.

    :return: the connection object or 0 if there is an error.
    """
    target = (db_host, db_port)
    settings = {
        "host": db_host,
        "port": db_port,
        "read_default_file": "/etc/mysql/sipwise_extra.cnf",
        "database": database,
    }

    try:
        connection = pymysql.connect(**settings)
        log_verbose("connected to %s:%s as 'sipwise'" % target)
        connection.autocommit(False)
        return connection
    except Exception as exc:
        failure = "could not connect to %s:%s as 'sipwise': %s" % (
            *target,
            exc,
        )
        error(failure)
        return 0


def fetch_subscriber_codec_ids_query() -> str:
    """
    Build the query that lists the codec preferences of all subscribers.

    The query selects one row per subscriber and codec preference
    (codecs_filter, codecs_list, codecs_id_filter, codecs_id_list) that is
    set, with the columns id, uuid, username, attribute, value.

    :return: the SQL text of the query; nothing is executed here.
    """
    # The filter column is ``p.attribute``, matching the domain and peer
    # queries; the former ``p.attributea`` was a typo.
    return """
        SELECT s.id, s.uuid, s.username, p.attribute, u.value
        FROM provisioning.voip_subscribers s,
             provisioning.voip_usr_preferences u,
             provisioning.voip_preferences p
        WHERE p.id = u.attribute_id
          AND s.id = u.subscriber_id
          AND p.attribute IN (
            'codecs_filter',
            'codecs_list',
            'codecs_id_filter',
            'codecs_id_list'
          )
        ORDER BY s.username, p.attribute;
    """


def fetch_domain_codec_ids_query() -> str:
    """
    Build the query that lists the codec preferences of all domains.

    The query selects one row per domain and codec preference that is set,
    with the columns id, domain, attribute, value.

    :return: the SQL text of the query; nothing is executed here.
    """
    query = """
        SELECT d.id, d.domain, p.attribute, u.value
        FROM provisioning.voip_domains d,
             provisioning.voip_dom_preferences u,
             provisioning.voip_preferences p
        WHERE p.id = u.attribute_id
          AND d.id = u.domain_id
          AND p.attribute IN (
            'codecs_filter',
            'codecs_list',
            'codecs_id_filter',
            'codecs_id_list'
          )
        ORDER BY d.domain, p.attribute;
    """
    return query


def fetch_peer_codec_ids_query() -> str:
    """
    Build the query that lists the codec preferences of all peers.

    The query selects one row per peer host and codec preference that is
    set, with the columns id, name, ip, attribute, value.

    :return: the SQL text of the query; nothing is executed here.
    """
    query = """
        SELECT h.id, h.name, h.ip, p.attribute, u.value
        FROM provisioning.voip_peer_hosts h,
             provisioning.voip_peer_preferences u,
             provisioning.voip_preferences p
        WHERE p.id = u.attribute_id
          AND h.id = u.peer_host_id
          AND p.attribute IN (
            'codecs_filter',
            'codecs_list',
            'codecs_id_filter',
            'codecs_id_list'
          )
        ORDER BY h.name, p.attribute;
    """
    return query


# Group name -> function returning the SQL text that fetches the codec
# preferences of that group.
queries = dict(
    subscriber=fetch_subscriber_codec_ids_query,
    domain=fetch_domain_codec_ids_query,
    peer=fetch_peer_codec_ids_query,
)


def initialize_codec_objects(results: tuple, group: str) -> None:
    """
    Register a fresh Codec object for every ID found in the query results.

    :param results: query result rows; the first column of a row is the ID
    :param group: the name of the group the rows belong to
    """
    unique_ids = {row[0] for row in results}
    for member_id in unique_ids:
        codec = Codec(member_id, group)
        groups.get_group(group)[member_id] = codec


def extract_query_result_variables(group: str, result: tuple) -> QueryResult:
    """
    Pick ID, attribute name and value out of a query result row.

    The position of the attribute column depends on the group: the
    subscriber and peer queries select five columns, the domain query four.
    The value column always follows the attribute column.

    :param group: the name of the group
    :param result: the query result tuple
    :return: a QueryResult object; all fields are empty strings for an
      unknown group
    """
    if group in ("subscriber", "peer"):
        attribute_column = 3
    elif group == "domain":
        attribute_column = 2
    else:
        return QueryResult(_id="", attribute="", value="")

    return QueryResult(
        _id=result[0],
        attribute=result[attribute_column],
        value=result[attribute_column + 1],
    )


def _parse_codec_value(attribute: str, value):
    """Convert a raw preference value into the form stored on a Codec."""
    if "list" not in attribute:
        # Filter flags are stored exactly as they were read.
        return value
    if not value:
        # NULL or empty list preference: an empty set rather than a string.
        return set()
    # Tolerate "0, 8" style values and stray commas.
    items = {item.strip() for item in value.split(",")} - {""}
    if "id" not in attribute:
        return items
    # isdecimal() accepts only what int() can parse; other entries are
    # dropped.
    return {int(item) for item in items if item.isdecimal()}


def get_group_codec_information(group: str) -> None:
    """
    Fetch codec data for every group member, and assign it to its corresponding
    Codec object.

    List preferences are stored as sets (of int for the ID list, of str for
    the name list); the filter preferences are stored as read. Nothing is
    fetched when the database connection cannot be established; connect_db()
    reports that failure itself.

    :param group: the name of the group
    """
    db_connection = connect_db()
    if not isinstance(db_connection, Connection):
        return

    try:
        with db_connection.cursor() as cursor:
            cursor.execute(queries[group]())
            results = cursor.fetchall()
    finally:
        # Close the connection even when the query fails.
        db_connection.close()

    initialize_codec_objects(results, group)
    members = groups.get_group(group)
    for row in results:
        result = extract_query_result_variables(group, row)
        _id, attribute, value = result.attributes()
        members[_id].set_attribute(
            attribute, _parse_codec_value(attribute, value)
        )


def map_codec_names(group: str, conflict_handler: int) -> None:
    """
    Run the ID-to-name mapping on every Codec object of a group.

    :param group: the name of the group
    :param conflict_handler: the conflict handling strategy, passed on
      unchanged to each Codec object
    """

    log_verbose(f"Mapping codec names for all {group}s")
    members = groups.get_group(group)
    for member in members.values():
        member.map_id_to_name(conflict_handler)
    log_verbose(f"Successfully mapped codec names for all {group}s")


def get_conflict_resolution_mode(
    is_mode_skip: bool, is_mode_append: bool, is_mode_substitute: bool
) -> int:
    """
    Translate the batch-mode flags into a conflict resolution mode.

    The flags are checked in the order skip, append, substitute, and the
    first one that is set decides the result.

    :param is_mode_skip: True to set conflict resolution mode to 'skip'.
    :param is_mode_append: True to set conflict resolution mode to 'append'.
    :param is_mode_substitute: True to set conflict resolution mode to
      'substitute'.
    :return: 1 for skip, 2 for append, 3 for substitute, 0 if no flag is
      set.
    """
    candidates = (
        (is_mode_skip, 1, "'skip-all' mode has been enabled"),
        (is_mode_append, 2, "'append-all' mode has been enabled"),
        (is_mode_substitute, 3, "'substitute-all' mode has been enabled"),
    )
    for enabled, mode, message in candidates:
        if enabled:
            log_verbose(message)
            return mode

    return 0


def commit_codec_name_changes(group: str) -> None:
    """
    Commit changes to the database.

    For every Codec object of the group the update and insert statements
    are executed and committed as one transaction. A single connection is
    used for the whole group. If a statement fails, the transaction of that
    member is rolled back, the connection is closed and the exception is
    re-raised; members committed before it stay committed.

    :param group: the name of the group
    """

    log_verbose(f"Commiting codec name changes for all {group}s")
    codecs = groups.get_group(group)
    if codecs:
        db_connection = connect_db()
        if not isinstance(db_connection, Connection):
            # connect_db() has already reported the failure.
            return
        try:
            with db_connection.cursor() as cursor:
                for codec in codecs.values():
                    update_query, insert_query = (
                        codec.generate_codec_query()
                    )
                    try:
                        cursor.execute(update_query)
                        cursor.execute(insert_query)
                    except Exception:
                        # Do not leave a half-applied member behind.
                        db_connection.rollback()
                        raise
                    db_connection.commit()
        finally:
            db_connection.close()
    log_verbose(f"Successfully commited codec names for all {group}s")


def main() -> None:
    """
    Parse the command line and run fetch, map and commit for every group.

    Unknown command-line arguments are ignored. The script exits with
    status 1 if the database cannot be reached.
    """
    global args
    global is_verbose
    argparser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter, description=desc
    )
    # The three batch modes exclude each other.
    conflict_resolution_group = argparser.add_mutually_exclusive_group()
    conflict_resolution_group.add_argument(
        "-a",
        "--append-all",
        action="store_true",
        help=(
            "append all conflicting codec names to the existing "
            "codec name list"
        ),
    )
    conflict_resolution_group.add_argument(
        "-s",
        "--skip-all",
        action="store_true",
        help="skip all operations in case of conflicts",
    )
    conflict_resolution_group.add_argument(
        "-S",
        "--substitute-all",
        action="store_true",
        help=(
            "substitute all conflicting codec names by the ones specified "
            "by their codec ID"
        ),
    )
    argparser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="enable verbose output",
    )

    args, _unknown = argparser.parse_known_args()
    is_verbose = args.verbose

    log_verbose("verbose mode has been enabled")
    conflict_handler = get_conflict_resolution_mode(
        args.skip_all, args.append_all, args.substitute_all
    )

    # Connectivity check: stop early when the database is unreachable
    # (connect_db() has already printed the reason) and release the probe
    # connection instead of leaving it open for the rest of the run.
    probe = connect_db()
    if not isinstance(probe, Connection):
        raise SystemExit(1)
    probe.close()

    log("fetching data...")
    for group in groups.get_group_names():
        get_group_codec_information(group)
        map_codec_names(group, conflict_handler)
        commit_codec_name_changes(group)


# Run main() only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
