#!/usr/bin/python3

from __future__ import annotations

import pymysql
import pymysql.cursors
import argparse

from pymysql import Connection
from pymysql.cursors import Cursor

desc = """
overview:
    The script ports the man_allowed_ips IPs to the allowed_ips property for
    all subscribers and domains.

    If any conflicts arise, the user is prompted to either skip the mapping
    for the current subscriber/domain, or append any new IPs to the existing
    allowed_ips property. Optionally, one of two flags ( -s, -a ) can be used
    to batch-automate this process: -s skips every conflicting
    subscriber/domain, while -a appends the new IPs for all of them.
"""

# Database connection settings; credentials are read from the option file
# passed to pymysql.connect() in connect_db().
db_host = "localhost"
db_port = 3306
database = "provisioning"

# Runtime state shared by the whole script.
args = None  # parsed command-line arguments, set by main()
is_verbose = False  # enables log_verbose() output, set by main()
is_first_connection = True  # connect_db() announces only the first connection


class Groups:
    """
    Container of the groups to be queried.

    Each instance attribute is one group: the attribute name is the group
    name ("subscriber", "domain") and its value is the dict holding that
    group's members.
    """

    def __init__(self):
        # Attribute order is the order in which get_group_names() yields them.
        self.subscriber = {}
        self.domain = {}

    def get_group(self, attr):
        """Return the member dict of the group called ``attr``."""
        selected = getattr(self, attr)
        return selected

    def get_group_names(self):
        """Return the names of all groups, in definition order."""
        return vars(self).keys()


class QueryResult:
    """One row of the IP query: id, group_id, attribute and value."""

    def __init__(self, _id, group_id, attribute, value):
        self.__id = _id
        self.group_id = group_id
        self.attribute = attribute
        self.value = value

    def attributes(self) -> tuple:
        """Return every stored field, in the order it was assigned."""
        fields = vars(self)
        return tuple(fields.values())


# Module-wide registry of Ip objects per group, shared by the fetch, update
# and commit steps.
groups: Groups = Groups()


class Ip:
    """
    Ip object.

    The IP object stores all the IP configurations for any given subscriber,
    or domain.
    """

    def __init__(self, _id, group):
        """Initialize the IP object."""
        self.__id = _id
        self.nid = _id
        self.__group = group
        self.__conflict_resolution_mode = 0
        self.allowed_ips_grp = set()
        self.man_allowed_ips_grp = set()
        self.allowed_ips_grp_id = 0
        self.man_allowed_ips_grp_id = 0

    def __str__(self) -> str:
        """String representation of the IP object."""
        allowed_ips_grp_id = (
            f"group_id {self.allowed_ips_grp_id}"
            if self.allowed_ips_grp_id
            else "no assigned group_id"
        )
        man_allowed_ips_grp_id = (
            f"group_id {self.man_allowed_ips_grp_id}"
            if self.man_allowed_ips_grp_id
            else "no assigned group_id"
        )

        return (
            f"Configuration for {self.__group} with ID {self.__id} \n"
            f"-> allowed_ips > {self.allowed_ips_grp} "
            f"with {allowed_ips_grp_id}\n"
            f"-> man_allowed_ips -> {self.man_allowed_ips_grp} "
            f"with {man_allowed_ips_grp_id}"
        )

    def append_to_attribute(self, attribute, value, group_id) -> None:
        """
        Set an attribute of the IP object.

        :param attribute: attribute of the IP object.
        :param value: value of the attribute.
        """
        property_set: set = getattr(self, attribute)
        property_set.add(value)
        setattr(self, f"{attribute}_id", group_id)

    def _append_ips(self, target: set, data: set) -> None:
        """
        Adds members from data set to the target set.

        :param target: Set that will be updated.
        :param data: Set containing members to be inserted into target.
        """
        log_verbose(
            f"Updating 'allowed_ips_grp' property for {self.__group} "
            f"with id {self.__id}"
        )
        log_verbose(
            f"This operation will add the following IPs {data - target}"
        )
        target.update(data)

    def update_ip_preferences(self, conflict_handler: int):
        """
        Update the IP preferences.

        :param conflict_handler: conflict handling mode
        """
        log(
            f"Inspecting IP preferences for {self.__group} with ID {self.__id}"
        )
        if not len(self.man_allowed_ips_grp):
            log(
                f"Skipping {self.__group}, 'man_allowed_ips_grp' property "
                f"is empty"
            )
        elif self.man_allowed_ips_grp.issubset(self.allowed_ips_grp):
            log(
                f"Skipping {self.__group}, all of 'man_allowed_ips_grp' IPs "
                f"are already included in 'allowed_ips_grp'"
            )
        else:
            if conflict_handler > 0:
                self.__conflict_resolution_mode = conflict_handler
                if len(self.man_allowed_ips_grp - self.allowed_ips_grp) != 0:
                    log_verbose(
                        "The following IPs found on man_allowed_ips_grp "
                        "are missing from allowed_ips_grp"
                    )
                    log_verbose(
                        f"{self.man_allowed_ips_grp - self.allowed_ips_grp}"
                    )
            else:
                missing_ips_grp = (
                    self.man_allowed_ips_grp - self.allowed_ips_grp
                )
                if is_verbose:
                    user_input = input(
                        f"Conflicting IPs associated with this {self.__group} "
                        f"have been found!\n"
                        f"The following IPs found on 'man_allowed_ips_grp' "
                        f"are missing from 'allowed_ips_grp':\n"
                        f"{missing_ips_grp}\n"
                        f"In order to resolve this issue, please choose one "
                        f"of the following options:\n"
                        f"1. skip       >>> "
                        f"the IP properties will remain unchanged\n"
                        f"2. append     >>> "
                        f"the IPs from 'man_allowed_ips_grp' will be "
                        f"appended to 'allowed_ips_grp'\n"
                        f"Please, select option 1, or 2: "
                    )
                else:
                    user_input = input(
                        f"Conflicting IPs associated with this {self.__group} "
                        f"have been found!\n"
                        f"In order to resolve this issue, please choose one "
                        f"of the following options:\n"
                        f"1. skip       >>> "
                        f"the IP properties will remain unchanged\n"
                        f"2. append     >>> "
                        f"the IPs from 'man_allowed_ips_grp' will be "
                        f"appended to 'allowed_ips_grp'\n"
                        f"Please, select option 1, or 2: "
                    )
                self.__conflict_resolution_mode = (
                    int(user_input) if user_input.isdigit() else 0
                )
            match self.__conflict_resolution_mode:
                case 1:
                    # skip
                    log(f"Skipping {self.__group}")
                    return
                case 2:
                    # append
                    log(f"Appending new IPs for the current {self.__group}")
                    self._append_ips(
                        self.allowed_ips_grp, self.man_allowed_ips_grp
                    )
                case _:
                    print(
                        f"Sorry, I don't know how to handle this action. "
                        f"The {self.__group} will be skipped."
                    )
        return

    def generate_ip_query(
        self, db_connection: Connection[Cursor]
    ) -> tuple[list[str], str]:
        """
        Generate query for the IP object, with update-first-then-insert logic.

        :return list containing IP insert and relation_update queries.
        """

        query_variables = {
            "subscriber": {
                "target_table": "provisioning.voip_usr_preferences",
                "id_column": "subscriber_id",
            },
            "domain": {
                "target_table": "provisioning.voip_dom_preferences",
                "id_column": "domain_id",
            },
        }

        target_table = query_variables[self.__group]["target_table"]
        id_column = query_variables[self.__group]["id_column"]

        group_id_relation_query = ""

        if self.allowed_ips_grp_id == 0:
            with db_connection.cursor() as cursor:
                # Get the next group_id in case the allowed_ips_grp has no IPs
                cursor.execute("""
                    SELECT IFNULL(MAX(group_id), 0) + 1
                    FROM provisioning.voip_allowed_ip_groups
                """)
                result = cursor.fetchone()
                if result is not None:
                    self.allowed_ips_grp_id = result[0]

                group_id_relation_query = f"""
                    INSERT IGNORE INTO {target_table}
                      ({id_column}, attribute_id, value)
                    SELECT DISTINCT {self.__id}, p.id,
                      {self.allowed_ips_grp_id}
                    FROM provisioning.voip_preferences p, {target_table} u
                    WHERE p.attribute = 'allowed_ips_grp'
                      AND NOT EXISTS (
                        SELECT 1
                        WHERE u.{id_column} = {self.__id}
                          AND u.attribute_id = p.id
                    );
                """
        if self.__conflict_resolution_mode == 2:
            insert_queries = [
                f"""
                INSERT IGNORE INTO provisioning.voip_allowed_ip_groups
                  (group_id, ipnet)
                SELECT
                  IFNULL(
                   (SELECT group_id
                    FROM {target_table} t
                    JOIN {target_table} u ON
                      t.id = u.{id_column}
                    JOIN provisioning.voip_preferences p ON
                      p.id = u.attribute_id
                    JOIN provisioning.voip_allowed_ip_groups v ON
                      v.group_id = u.value
                    WHERE p.attribute = 'allowed_ips_grp'
                      AND u.{id_column} = {self.__id}
                    LIMIT 1), {self.allowed_ips_grp_id}
                  ) AS group_id,
                  '{new_allowed_ip}' AS ipnet
                WHERE NOT EXISTS (
                  SELECT 1
                  FROM provisioning.voip_allowed_ip_groups
                  WHERE group_id = IFNULL(
                      (
                        SELECT group_id
                         FROM {target_table} t
                         JOIN {target_table} u ON
                           t.id = u.{id_column}
                         JOIN provisioning.voip_preferences p ON
                           p.id = u.attribute_id
                         JOIN provisioning.voip_allowed_ip_groups v ON
                           v.group_id = u.value
                         WHERE p.attribute = 'allowed_ips_grp'
                           AND u.{id_column} = {self.__id}
                         LIMIT 1
                      ),
                      {self.allowed_ips_grp_id}
                    )
                    AND ipnet = '{new_allowed_ip}'
                );
                """
                for new_allowed_ip in self.man_allowed_ips_grp
            ]
        else:
            insert_queries = [""]

        return insert_queries, group_id_relation_query


def logit(str, level, end) -> None:
    """General logger.

    Print ``str`` to stdout, indented by four spaces per ``level`` and
    prefixed with "-> ".

    :param str: message to log
    :param level: indentation level (0 = no indentation)
    :param end: string printed after the message (forwarded to print)
    :return:
    """
    indent = " " * (4 * level)
    print(f"{indent}-> {str!s}", end=end)


def log(str, level=0, end="\n") -> None:
    """Unconditionally print ``str`` through logit."""
    logit(str, level=level, end=end)


def log_verbose(str, level=0, end="\n") -> None:
    """Print ``str`` through logit, but only when verbose mode is on."""
    if is_verbose is not True:
        return
    logit(str, level, end)


def error(str, level=0, end="\n") -> None:
    """Print ``str`` through logit, prefixed with "Error: "."""
    message = "Error: " + str
    logit(message, level=level, end=end)


def connect_db() -> Connection[Cursor] | int:
    """
    Open a new connection to the provisioning database.

    Credentials are read from /etc/mysql/sipwise_extra.cnf and autocommit
    is switched off, so callers must commit explicitly. Only the first
    successful connection is announced (in verbose mode).

    :return: the connection object, or 0 if connecting failed (the error
        is printed).
    """
    global is_first_connection
    try:
        connection = pymysql.connect(
            database=database,
            host=db_host,
            port=db_port,
            read_default_file="/etc/mysql/sipwise_extra.cnf",
        )
        if is_first_connection:
            log_verbose(f"connected to {db_host}:{db_port} as 'sipwise'")
            is_first_connection = False
        connection.autocommit(False)
    except Exception as exc:
        error(
            f"could not connect to {db_host}:{db_port} as 'sipwise': {exc}"
        )
        return 0
    return connection


def generate_ip_query(group: str) -> str:
    """
    Build the SELECT that loads the IP preferences of every ``group`` member.

    The query yields one row per (member, IP) with the columns
    ``id, group_id, attribute, ipnet``, restricted to the
    'man_allowed_ips_grp' and 'allowed_ips_grp' preferences and ordered by
    member id. Only preferences whose value matches a group_id present in
    voip_allowed_ip_groups produce rows.

    :param group: "subscriber" or "domain"
    :return: the SQL query string.
    :raises KeyError: for any other group name.
    """
    table, preferences_table, id_column = {
        "subscriber": (
            "provisioning.voip_subscribers",
            "provisioning.voip_usr_preferences",
            "subscriber_id",
        ),
        "domain": (
            "provisioning.voip_domains",
            "provisioning.voip_dom_preferences",
            "domain_id",
        ),
    }[group]

    return f"""
        SELECT t.id, v.group_id, p.attribute, v.ipnet
        FROM {table} t,
             {preferences_table} u,
             provisioning.voip_preferences p,
             provisioning.voip_allowed_ip_groups v
        WHERE p.id = u.attribute_id
          AND t.id = u.{id_column}
          AND v.group_id = u.value
          AND p.attribute IN ('man_allowed_ips_grp', 'allowed_ips_grp')
        ORDER BY t.id;
    """


# Pre-built loading query for each group, keyed by group name.
queries = {
    group_name: generate_ip_query(group_name)
    for group_name in ("subscriber", "domain")
}


def initialize_ip_objects(results: tuple, group: str) -> None:
    """
    Register a fresh Ip object for every distinct id found in ``results``.

    :param results: query rows whose first column is the member id
    :param group: the name of the group the ids belong to
    """
    unique_ids = {row[0] for row in results}
    if unique_ids:
        groups.get_group(group).update(
            {member_id: Ip(member_id, group) for member_id in unique_ids}
        )


def extract_query_result_variables(group: str, result: tuple) -> QueryResult:
    """
    Extract variables from a query result tuple.

    :param group: the name of the group (not used; kept for callers)
    :param result: the query result tuple ``(id, group_id, attribute,
        value)``, or None
    :return: a QueryResult object; for ``None`` an empty result is returned
        (``group_id`` 0, matching Ip's "no assigned group_id" value).
    """
    if result is None:
        # QueryResult requires all four fields, group_id included.
        return QueryResult(_id="", group_id=0, attribute="", value="")
    return QueryResult(
        _id=result[0],
        group_id=result[1],
        attribute=result[2],
        value=result[3],
    )


def get_group_ip_information(group: str) -> None:
    """
    Fetch IP data for every group member, and assign it to its corresponding
    Ip object.

    A fresh Ip object is registered for every id returned by the group's
    query, then each row's IP is added to the matching set. Nothing happens
    when no connection can be opened (connect_db reports the error).

    :param group: the name of the group
    """
    db_connection = connect_db()
    if not isinstance(db_connection, Connection):
        return
    try:
        with db_connection.cursor() as cursor:
            cursor.execute(queries[group])
            results = cursor.fetchall()
    finally:
        # Close the connection even when the query raises.
        db_connection.close()

    initialize_ip_objects(results, group)
    for row in results:
        result = extract_query_result_variables(group, row)
        _id, group_id, attribute, value = result.attributes()
        ip = groups.get_group(group)[_id]
        ip.append_to_attribute(attribute, value, group_id)


def update_group_ips(group: str, conflict_handler: int) -> None:
    """
    Resolve the IP conflicts of every member of ``group``.

    :param group: the name of the group
    :param conflict_handler: the conflict handling strategy (0 = ask,
        1 = skip, 2 = append)
    """
    log_verbose(f"Updating IPs for all {group}s")
    members = groups.get_group(group)
    for member_ip in members.values():
        member_ip.update_ip_preferences(conflict_handler)
    log_verbose(f"Successfully updated IPs for all {group}s")


def get_conflict_resolution_mode(
    is_mode_skip: bool, is_mode_append: bool
) -> int:
    """
    Translate the command-line flags into a conflict resolution mode.

    'skip' wins if both flags are set.

    :param is_mode_skip: True to set conflict resolution mode to 'skip'.
    :param is_mode_append: True to set conflict resolution mode to 'append'.
    :return: 1 for skip, 2 for append, 0 to ask the user for every conflict.
    """
    mode = 0
    if is_mode_skip:
        log_verbose("'skip-all' mode has been enabled")
        mode = 1
    elif is_mode_append:
        log_verbose("'append-all' mode has been enabled")
        mode = 2
    return mode


def commit_ip_name_changes(group: str) -> None:
    """
    Commit changes to the database.

    Members are processed in ascending id order over a single connection,
    which is also handed to ``Ip.generate_ip_query``. All statements of one
    member are executed and then committed with a single ``commit()``. The
    connection is closed even if a statement fails.

    :param group: the name of the group
    """
    log_verbose(f"Commiting IP changes for all {group}s")
    db_connection = connect_db()
    if not isinstance(db_connection, Connection):
        # connect_db has already reported the error.
        return
    try:
        members = sorted(
            groups.get_group(group).items(), key=lambda item: int(item[0])
        )
        for _, ip in members:
            insert_queries, update_query = ip.generate_ip_query(db_connection)
            # IP inserts first, then the preference row linking the group.
            statements = [
                query
                for query in (*insert_queries, update_query)
                if query != ""
            ]
            if not statements:
                continue
            with db_connection.cursor() as cursor:
                for statement in statements:
                    cursor.execute(statement)
            db_connection.commit()
    finally:
        db_connection.close()
    log_verbose(f"Successfully commited IPs for all {group}s")


def main() -> None:
    """
    Main function.

    Parse the command line, check that the database is reachable (exit
    with status 1 otherwise), then for every group fetch the IP
    preferences, resolve the conflicts and commit the result.
    """
    global args
    global is_verbose
    argparser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter, description=desc
    )
    conflict_resolution_group = argparser.add_mutually_exclusive_group()
    conflict_resolution_group.add_argument(
        "-a",
        "--append-all",
        action="store_true",
        help="append all conflicting IPs to the existing allowed IP list",
    )
    conflict_resolution_group.add_argument(
        "-s",
        "--skip-all",
        action="store_true",
        help="skip all operations in case of conflicts",
    )
    argparser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="enable verbose output",
    )

    # Unknown arguments are tolerated and ignored.
    args, _unknown_args = argparser.parse_known_args()
    is_verbose = args.verbose
    is_skip_all = args.skip_all
    is_append_all = args.append_all

    log_verbose("verbose mode has been enabled")
    conflict_handler = get_conflict_resolution_mode(is_skip_all, is_append_all)

    # Probe the database once up front: stop early instead of failing again
    # for every group, and close the probe connection instead of leaking it.
    probe_connection = connect_db()
    if not isinstance(probe_connection, Connection):
        raise SystemExit(1)
    probe_connection.close()

    log("fetching data...")
    for group in groups.get_group_names():
        get_group_ip_information(group)
        update_group_ips(group, conflict_handler)
        commit_ip_name_changes(group)


if __name__ == "__main__":
    # Run the migration only when executed as a script, not on import.
    main()
