#!/usr/bin/python3
import argparse
import configparser
import logging
import os
import sys
import time

import pymysql.cursors

parser = argparse.ArgumentParser(
    prog="Quick mysql replica skipper",
    description="Try to skip known errors and restore replication",
)
# type=int makes argparse reject a non-numeric port with a usage error
# instead of a ValueError traceback further down.
parser.add_argument(
    "-p",
    "--port",
    dest="port",
    type=int,
    default=3306,
    help="MySQL port to connect to (default: %(default)s)",
)
# Parse the command line before touching the credentials file so that
# --help and usage errors work even when that file is missing/unreadable.
args = parser.parse_args()
cmdargs = vars(args)
myport = int(cmdargs["port"])

# MySQL credentials are taken from the [client] section of this file.
# interpolation=None: the file is not meant to use configparser's
# %-interpolation, and a '%' in the password would otherwise raise
# InterpolationSyntaxError when the value is read.
cnffile = "/etc/mysql/sipwise_extra.cnf"
config = configparser.ConfigParser(interpolation=None)
# read() silently skips files it cannot open and returns the ones it read.
if not config.read(cnffile):
    sys.exit("Cannot read %s" % cnffile)
try:
    mysqluser = config["client"]["user"]
    mysqlpassword = config["client"]["password"]
except KeyError as err:
    sys.exit("%s: [client] user/password not found (missing %s)" % (cnffile, err))

# Values of Last_Errno for which this script restarts the SQL thread by
# skipping one event (SET GLOBAL SQL_SLAVE_SKIP_COUNTER=1).  A skipped event
# is never applied on this replica, so every entry here is a deliberate
# decision to let the replica diverge from its master.  Keep the list short.

# 1062 (Duplicate entry) - the row being inserted already exists on the
# replica.  Skipping keeps the replica's existing row; this is acceptable
# only where such duplicates are known and expected.

# 0 - Last_Errno is 0 when no SQL-thread error is recorded (presumably, e.g.,
# a replica whose threads were stopped without failing).  NOTE(review):
# skipping in that state throws away an event that did not fail, so it is
# not harmless - confirm this entry is really intended.

safeskiperror = [1062, 0]


def get_envvars(
    env_file=".env",
    set_environ=True,
    ignore_not_found_error=False,
    exclude_override=(),
):
    """
    Read ``KEY=value`` lines from a shell-style env file.

    Each line is stripped of surrounding whitespace.  Blank lines and
    ``#`` comments (including indented ones) are skipped.  A leading
    ``export `` keyword is removed regardless of case.  Quotes around
    values are kept as-is.

    :param env_file: path of the file to read
    :param set_environ: when true, store each parsed pair in ``os.environ``
        (except keys listed in ``exclude_override``)
    :param ignore_not_found_error: return an empty list instead of raising
        when ``env_file`` does not exist
    :param exclude_override: keys that must not overwrite ``os.environ``;
        for these the current environment value is reported instead of the
        file's value
    :return: list of ``{"name": key, "value": value}`` dicts, in file order
    :raises FileNotFoundError: if ``env_file`` is missing and
        ``ignore_not_found_error`` is false
    :raises ValueError: if a non-comment line contains no ``=``
    """
    env_vars = []
    try:
        with open(env_file, encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                # Whitespace-only lines and indented comments are skipped
                # as well, because the check runs after stripping.
                if not line or line.startswith("#"):
                    continue
                # Remove a leading `export ` in any case ("export", "EXPORT").
                if line[:7].lower() == "export ":
                    line = line[7:].lstrip()
                try:
                    key, value = line.split("=", 1)
                except ValueError:
                    logging.error(
                        "get_envvars error parsing line: '%s'", line
                    )
                    raise
                if key in exclude_override:
                    # Keep the environment as-is and report its value.
                    env_vars.append({"name": key, "value": os.getenv(key)})
                else:
                    if set_environ:
                        os.environ[key] = value
                    env_vars.append({"name": key, "value": value})
    except FileNotFoundError:
        if not ignore_not_found_error:
            raise

    return env_vars


# Roles file of this installation; a missing file is a hard error.
neededinfo = get_envvars(
    env_file="/etc/default/ngcp-roles",
    ignore_not_found_error=False,
    set_environ=False,
)
# Collect this node's hostname and its peer's name (any surrounding double
# quotes removed), in the order they appear in the roles file.
clusternodes = []
for iteritem in neededinfo:
    varname = iteritem["name"]
    if varname not in ("NGCP_HOSTNAME", "NGCP_PEERNAME"):
        continue
    unquoted = iteritem["value"].strip('"')
    if varname == "NGCP_HOSTNAME":
        ngcpself = unquoted
    else:
        ngcppeer = unquoted
    clusternodes.append(unquoted)


def replicasinfo(nodes, mysqluser, mysqlpassword, dbname, port):
    """
    Collect replication status rows from every node.

    Runs ``show all slaves status`` on each host in ``nodes`` and returns
    the rows of all of them, concatenated in node order.

    :param nodes: iterable of host names to query
    :param mysqluser: MySQL user name
    :param mysqlpassword: MySQL password
    :param dbname: database to connect to
    :param port: MySQL TCP port
    :return: list of status rows as dicts (DictCursor); empty when
        ``nodes`` is empty
    """
    result = []
    for node in nodes:
        connection = pymysql.connect(
            host=node,
            user=mysqluser,
            password=mysqlpassword,
            db=dbname,
            charset="utf8mb4",
            cursorclass=pymysql.cursors.DictCursor,
            port=port,
        )
        try:
            with connection.cursor() as cursor:
                cursor.execute("show all slaves status")
                # Accumulate instead of overwriting so every node counts.
                result.extend(cursor.fetchall())
        finally:
            # Close even if the query fails, so no connection is leaked.
            connection.close()
    return result


def sqlsstatements(node, mysqluser, mysqlpassword, dbname, sqlstate, port):
    """
    Execute SQL statements in order on one server, with autocommit.

    All statements share one connection, so session settings made by an
    earlier statement apply to the later ones.

    :param node: host name to connect to
    :param mysqluser: MySQL user name
    :param mysqlpassword: MySQL password
    :param dbname: database to connect to
    :param sqlstate: sequence of SQL strings, executed in order
    :param port: MySQL TCP port
    :return: rows fetched after the last statement (DictCursor rows);
        ``()`` when ``sqlstate`` is empty, in which case no connection is
        opened
    """
    print("send following sql to MySQL")
    print(sqlstate)
    result = ()
    if not sqlstate:
        return result
    connection = pymysql.connect(
        host=node,
        user=mysqluser,
        password=mysqlpassword,
        db=dbname,
        charset="utf8mb4",
        cursorclass=pymysql.cursors.DictCursor,
        autocommit=True,
        port=port,
    )
    try:
        with connection.cursor() as cursor:
            for state in sqlstate:
                cursor.execute(state)
                # Only the last statement's rows are returned.
                result = cursor.fetchall()
    finally:
        # Close even if a statement fails, so no connection is leaked.
        connection.close()
    return result


def checkreplica(node, mysqluser, mysqlpassword, dbname, sqlstate, port):
    """
    Report whether a replication connection is running.

    ``sqlstate`` is executed through ``sqlsstatements``.  The first row of
    the last statement's result is read as a ``show slave status`` row.

    :return: dict with
        ``slaverunning`` - True when both Slave_IO_Running and
        Slave_SQL_Running are "Yes", otherwise False;
        ``errornum`` - the row's Last_Errno when the threads are not both
        running, otherwise False.  It is also False when no status row (or
        no thread-state columns) came back.
    """
    myresponse = sqlsstatements(
        node, mysqluser, mysqlpassword, dbname, sqlstate, port
    )
    returndict = {"slaverunning": False, "errornum": False}
    # No row at all (presumably no such connection is configured): report
    # "not running" instead of failing with IndexError on myresponse[0].
    if not myresponse:
        return returndict
    status = myresponse[0]
    if "Slave_IO_Running" not in status or "Slave_SQL_Running" not in status:
        return returndict
    if (
        status["Slave_IO_Running"] == "Yes"
        and status["Slave_SQL_Running"] == "Yes"
    ):
        returndict["slaverunning"] = True
    else:
        returndict["errornum"] = status.get("Last_Errno", False)
    return returndict


def _connection_literal(name):
    """Return *name* as a single-quoted SQL string literal."""
    # Doubling embedded single quotes is valid in every SQL mode; a
    # double-quoted string would become an identifier under ANSI_QUOTES.
    return "'%s'" % name.replace("'", "''")


def _status_sql(name):
    """Return statements that select connection *name* and show its status."""
    return (
        "SET @@default_master_connection = %s;" % _connection_literal(name),
        "show slave status;",
    )


def _skip_sql(name):
    """Return statements that skip one event on connection *name*."""
    return (
        "SET @@default_master_connection = %s;" % _connection_literal(name),
        "stop slave;",
        "SET GLOBAL SQL_SLAVE_SKIP_COUNTER=1;",
        "start slave;",
    )


def _check_connection(name):
    """Return checkreplica()'s status dict for connection *name* on localhost."""
    return checkreplica(
        node="localhost",
        mysqluser=mysqluser,
        mysqlpassword=mysqlpassword,
        dbname="provisioning",
        port=myport,
        sqlstate=_status_sql(name),
    )


def _is_safe_to_skip(errornum):
    """Return True if *errornum* is an error code listed in safeskiperror."""
    # checkreplica() uses False for "no error code".  False == 0 in Python,
    # so `False in [1062, 0]` is True; reject bools explicitly.
    return not isinstance(errornum, bool) and errornum in safeskiperror


# One status row per replication connection on this host.
allrepinfo = replicasinfo(
    nodes=["localhost"],
    mysqluser=mysqluser,
    mysqlpassword=mysqlpassword,
    dbname="provisioning",
    port=myport,
)
repid = 1
fakename = ""
conname = ""
for replica in allrepinfo:
    # "" is the unnamed default connection.  Reset per row so a row without
    # Connection_name never reuses the previous row's name.
    conname = replica.get("Connection_name", "")
    fakename = conname or "default"
    print("Checking replica %s for %s, " % (repid, fakename))
    repid = repid + 1
    if "Slave_IO_Running" not in replica or "Slave_SQL_Running" not in replica:
        continue
    # Re-read the status; the snapshot above may already be stale.
    isreplicarunning = _check_connection(conname)
    if isreplicarunning["slaverunning"]:
        print(
            "Both Slave_IO_Running and Slave_SQL_Running run fine, "
            "nothing to do for connection %s" % fakename
        )
        continue
    if "Last_Errno" not in replica:
        continue
    if replica.get("Last_IO_Errno", 0) > 0:
        print("IO errors are not supported EXIT")
        sys.exit(3)
    errornum = replica["Last_Errno"]
    print("Error detected in MySQL replication: %s" % errornum)
    if errornum not in safeskiperror:
        print("Error %s is not marked as safe for skip" % errornum)
        continue
    print("Error %s is safe for skip" % errornum)
    skipsql = _skip_sql(conname)
    print("Will execute following mysql code: %s" % " ".join(skipsql))
    # At most 19 skip attempts per connection (countertry runs 1..19).
    countertry = 1
    while not isreplicarunning["slaverunning"] and countertry < 20:
        sqlsstatements(
            node="localhost",
            mysqluser=mysqluser,
            mysqlpassword=mysqlpassword,
            dbname="provisioning",
            sqlstate=skipsql,
            port=myport,
        )
        # Give the restarted threads a moment before re-checking.
        time.sleep(2)
        isreplicarunning = _check_connection(conname)
        print("Got following response about slave status")
        print(isreplicarunning)
        countertry = countertry + 1
        if isreplicarunning["slaverunning"]:
            break
        # Each skip only clears the error seen at that moment.  If
        # replication now fails with an error that is not on the safe list,
        # stop instead of blindly skipping that event too.
        newerror = isreplicarunning["errornum"]
        if not _is_safe_to_skip(newerror):
            print("Error %s is not marked as safe for skip" % newerror)
            break
    if not isreplicarunning["slaverunning"]:
        print("Replication for connection %s is still not running" % fakename)
