#!/usr/bin/env python3
import argparse
import asyncio
import base64
import json
import os
import random
import signal
import sys
from concurrent.futures import ThreadPoolExecutor
from pprint import pformat
from typing import Any, Dict, List
from uuid import uuid4

import redis
import redis.asyncio as aioredis
from redis.asyncio.client import PubSub
from yaml import BaseLoader, load

from ngcp_task_agent.request import Request
from ngcp_task_agent.response import Response

# Path of the redis backend configuration. When VIRTUAL_ENV is set, a path
# relative to the current working directory is used instead of the system one.
CONFIG_FILE = (
    'etc/backends/redis.yml'
    if os.getenv('VIRTUAL_ENV') is not None
    else '/etc/ngcp-task-agent/backends/redis.yml'
)

# Local node name; used as the default request source.
HOSTNAME = os.uname().nodename


class Client:
    """Base class for task client.

    Attributes:
        req_channel (str): control channel the request is published on.
            By default is 'ngcp-task-agent-redis'.
        request (Request): all the parameters required to run a task.
        redis_config (Dict[str, Any]): the redis backend config, loaded
            from CONFIG_FILE.
        running (bool): whether the client is running or not.
        redis (aioredis.Redis): the redis connector (None until connected).
        pubsub (PubSub): the pubsub redis connector.
        resp_channel_done (bool): whether response collection is finished.
        resp_channel_status (str): overall status of the response channel:
            'none' until a result is known, then 'done' or 'error'.
        resp_channel_data (str): response channel data (declared only; not
            set by this class).
        resp_nodes_list (List[str]): nodes expected to reply, taken from the
            data of the first 'accepted' reply.
        nodes_accepted (Dict[str, bool]): nodes that already sent 'accepted'.
        resp_list (List[str]): response status received from each node.
        resp_count (int): number of responses counted so far.
        resp_accepted (bool): True once an 'accepted' reply carrying the
            node list has been received.
        retcode (int): process return code (0 done, 1 rejected, 2 error).
        verbose (bool): enable verbose output.
        quiet (bool): disable output.
        loop_timeout (int): number of empty polls seen so far; used as a way
            out of the read loop.

    """
    req_channel: str
    request: Request
    redis_config: Dict[str, Any]
    running: bool
    redis: aioredis.Redis = None  # type: ignore
    pubsub: PubSub
    resp_channel_done: bool
    resp_channel_status: str
    resp_channel_data: str
    resp_nodes_list: List[str]
    nodes_accepted: Dict[str, bool]
    resp_list: List[str]
    resp_count: int
    resp_accepted: bool
    retcode: int
    verbose: bool
    quiet: bool
    loop_timeout: int

    def __init__(self) -> None:
        """Create a client and load the redis backend config.

        Raises:
            OSError: if CONFIG_FILE cannot be opened.

        """
        self.request = Request()
        self.req_channel = 'ngcp-task-agent-redis'
        self.running: bool = False
        self.resp_channel_done: bool = False
        self.resp_channel_status = 'none'
        self.read_config()
        self.resp_count = 0
        self.resp_list = []
        self.resp_nodes_list = []
        self.nodes_accepted = {}
        self.resp_accepted = False
        self.verbose = False
        self.quiet = False
        self.loop_timeout = 0
        # Reported if the run ends before a result is determined (e.g. it
        # is interrupted by a signal); such a run is treated as an error.
        self.retcode = 2

    def read_config(self) -> None:
        """Load the redis backend config into ``self.redis_config``.

        The YAML file is parsed with ``BaseLoader``, so every scalar is
        loaded as a string.

        Raises:
            OSError: if CONFIG_FILE cannot be opened.

        """
        with open(CONFIG_FILE, 'r', encoding='utf-8') as file:
            self.redis_config = load(file, BaseLoader)

    async def _on_accepted(self, resp: Response) -> None:
        """Handles accepted responses.

        The first 'accepted' reply with non-empty data provides the list of
        nodes the client should expect a response from.

        Args:
            resp (Response Obj): The response message

        Raises:
            ValueError: if the same node sent 'accepted' twice.

        """
        # If node 'src' already accepted, raise an error
        if resp.src in self.nodes_accepted:
            self.resp_list.append('error')
            self.resp_count = self.resp_count + 1
            raise ValueError('Multiple accepted received!')

        # Save that node 'src' accepted
        self.nodes_accepted[resp.src] = True
        # Intercept the first 'accept' reply and get the list of nodes
        # where the response should come from.
        if not self.resp_accepted and len(resp.data) > 0:
            self.resp_nodes_list = resp.data
            self.resp_accepted = True

    async def _on_done(self, resp: Response) -> None:
        """Handles done responses.

        Args:
            resp (Response Obj): The response message

        Raises:
            ValueError: if the message is a partial chunk (chunked replies
                are not supported).

        """
        if (
            resp.chunk == resp.chunks and
            self.resp_accepted and self.resp_count >= 0
        ):
            # 'done' reply received after the 'accepted' one.
            self.resp_list.append(resp.status)
            self.resp_count = self.resp_count + 1
        elif resp.chunk != resp.chunks:
            # Do not support chunked replies
            self.resp_list.append('error')
            raise ValueError('Unsupported chunked message received')

    async def _on_reject(self, resp: Response) -> None:
        """Handles reject responses.

        The reply is only counted once an 'accepted' reply was received.

        Args:
            resp (Response Obj): The response message

        """
        if self.resp_accepted and self.resp_count >= 0:
            self.resp_list.append(resp.status)
            self.resp_count = self.resp_count + 1

    async def _on_error(self, resp: Response) -> None:
        """Handles error responses; always counted as a reply.

        Args:
            resp (Response Obj): The response message

        """
        self.resp_list.append(resp.status)
        self.resp_count = self.resp_count + 1

    async def _on_unknown(self, resp: Response) -> None:
        """Handles responses whose status has no dedicated handler.

        Args:
            resp (Response Obj): The response message

        Raises:
            ValueError: always, the status is not supported.

        """
        raise ValueError(f'Unknown response status={resp.status}')

    async def _check_remaining_responses(self, resp: Response) -> None:
        """Collects overall responses status.

        The overall status is computed once the number of counted responses
        matches the expected node list, or when no 'accepted' reply has
        been received at all. It is 'done' only if every collected reply is
        'done', otherwise 'error'; collection is then marked finished.

        Args:
            resp (Response Obj): The response message (unused)

        """
        if (
            self.resp_count == len(self.resp_nodes_list) or
            not self.resp_accepted
        ):
            all_done = all(x == 'done' for x in self.resp_list)
            self.resp_channel_status = 'done' if all_done else 'error'
            self.resp_channel_done = True

    async def _print_responses(self, resp: Response) -> None:
        """Prints responses from the channel.

        Nothing is printed in quiet mode. Verbose mode dumps every
        response; normal mode prints a short summary of every response
        except 'accepted' ones.

        Args:
            resp (Response Obj): The response message

        """
        if self.quiet:
            return
        if self.verbose:
            print(f'response({resp.ref}):\n{pformat(vars(resp))}')
        elif resp.status != 'accepted':
            lines = [
                f'==> from "{resp.src}" [{resp.datetime}]',
                f'status: {resp.status}',
            ]
            if resp.data:
                lines.append(f'data: {resp.data}')
            print('\n'.join(lines) + '\n')

    async def _read_from_channel(self) -> None:
        """Read and dispatch responses from the feedback channel.

        Polls ``self.pubsub`` until the client stops running or response
        collection is finished. Every decoded message is printed,
        dispatched to the handler for its status, and then checked against
        the expected number of replies. An invalid response (TypeError /
        ValueError) or a redis error ends collection with status 'error'.

        """
        handlers: Dict[str, Any] = {
            'accepted': self._on_accepted,
            'done': self._on_done,
            'rejected': self._on_reject,
            'error': self._on_error,
        }
        while self.running and not self.resp_channel_done:
            await asyncio.sleep(0.01)
            # Status of the message being processed, for error reporting;
            # stays a placeholder if the payload could not be decoded.
            status = '<undecoded>'
            try:
                response = await self.pubsub.get_message(
                    ignore_subscribe_messages=True,
                    timeout=0.0
                )

                # None reply is when a task is skipped because it does not
                # match the destination. We need a way out in case of
                # infinite loop: after 5 empty polls collection finishes
                # with whatever status was reached so far.
                # NOTE(review): each None also bumps resp_count, and with
                # timeout=0.0 a poll may return None simply because nothing
                # arrived yet; confirm this is intended.
                if response is None:
                    if self.loop_timeout < 5:
                        self.resp_count = self.resp_count + 1
                        self.loop_timeout = self.loop_timeout + 1
                    else:
                        self.resp_channel_done = True
                    continue

                data = redis.utils.str_if_bytes(response['data'])
                resp = Response()
                resp.__dict__.update(json.loads(data))
                status = resp.status

                await self._print_responses(resp)
                handler = handlers.get(resp.status, self._on_unknown)
                await handler(resp)

                await self._check_remaining_responses(resp)

            except (TypeError, ValueError) as err:
                if self.verbose:
                    print(f'{type(err)} during response'
                          f' "{status}": {err}')
                # An invalid response always ends collection with an
                # error, independently of the output verbosity.
                self.resp_channel_status = 'error'
                self.resp_channel_done = True
            except redis.exceptions.RedisError as err:
                # Without this the task would die silently and a caller
                # waiting on resp_channel_done would never return.
                if self.running:
                    print(f'Error: Redis error while reading responses: {err}')
                self.resp_channel_status = 'error'
                self.resp_channel_done = True

    async def shutdown(self) -> None:
        """Gracefully shuts down the client.

        Marks the client as not running, closes the redis connector (if
        any) and cancels every other task of the running event loop. The
        calling task itself is not cancelled, so a caller awaiting
        ``shutdown()`` continues normally after it returns.

        """
        self.running = False
        if self.redis is not None:
            await self.redis.aclose()

        current = asyncio.current_task()
        for task in asyncio.all_tasks():
            if task is not current:
                task.cancel()


def arg_parsing() -> argparse.Namespace:
    """Parse the command line arguments passed to the script.

    On invalid or missing mandatory arguments argparse prints the usage
    and an error message and exits the process with status 2.

    Returns:
        argparse.Namespace: the parsed arguments

    """
    parser = argparse.ArgumentParser(
        description='Invoke a task sending a request to ngcp-task-agent',
        formatter_class=argparse.RawTextHelpFormatter
    )
    # --data / --data-from-file / --binary-from-file are mutually exclusive,
    # and so are --quiet / --verbose.
    group_data = parser.add_mutually_exclusive_group()
    group_verb = parser.add_mutually_exclusive_group()
    # Define arguments
    parser.add_argument('--task', '-t', dest='task', type=str,
                        help='The task name to invoke, it must be a task\n'
                        'configured under /etc/ngcp-task-agent/tasks',
                        required=True)
    parser.add_argument('--dst', '-d', dest='dst', type=str,
                        help='Destination hostname where the task should\n'
                        'be run. Some examples:\n\n'
                        'dst="prx01a" - delivers the request to prx01a node\n'
                        'dst="prx*" - delivers the request to all prx nodes\n'
                        'dst="*|state=active" - delivers the request to all\n'
                        'active nodes\n'
                        'dst="*|status=online" delivers the request to all\n'
                        'nodes with online status\n'
                        'dst="*|state=active;role=proxy" - delivers the\n'
                        'request to all active proxy nodes\n'
                        'dst="*|state=active;role=proxy+lb" - delivers the\n'
                        'request to all active proxy and lb nodes\n'
                        'dst="prx01a,lb*,db01?|state=active" - delivers the\n'
                        'request to prx01a, all lb nodes\n'
                        'and the active db01 node', required=True)
    parser.add_argument('--src', '-s', dest='src', type=str,
                        help='Source hostname. Default value is local\n'
                        'hostname')
    parser.add_argument('--fb_channel', '-fb', dest='fb_channel', type=str,
                        help='Feedback channel to be used.\n'
                        'If not set, the value is autogenerated')
    group_data.add_argument('--data', '-D', dest='data', type=str,
                            help='Data to be passed to the task request in\n'
                            'text format.')
    group_data.add_argument('--data-from-file', '-F', dest='file_path',
                            type=str, help='Data to be passed to the task\n'
                            'request via a file path.')
    group_data.add_argument('--binary-from-file', '-B', dest='binary_path',
                            type=str, help='Binary data to be passed to the\n'
                            'task request via a file path, encoded as base64\n'
                            'string')
    group_verb.add_argument('--quiet', '-q', dest='quiet',
                            action='store_true', help='Enable quiet output.',
                            default=False)
    group_verb.add_argument('--verbose', '-v', dest='verbose',
                            action='store_true', help='Enable verbose output.',
                            default=False)
    parser.add_argument('--options', '-o', dest='options', metavar='KEY=VALUE',
                        nargs='+', default=['dst_nodes_in_accepted=1'],
                        help='Set task options as key=value pairs.\n'
                        'Task options are sent in request.\n'
                        '(do not put spaces before or after the = sign).\n'
                        'NOTE: option "dst_nodes_in_accepted" is enabled\n'
                        'by default.')

    # argparse reports invalid/missing arguments itself (usage + message on
    # stderr, then SystemExit(2)), so no extra handling is needed here.
    return parser.parse_args()


def parse_key_value_options(kv_options: str) -> tuple[Any, str]:
    """Split a ``key=value`` option into its key and value.

    Only the first ``=`` separates key and value, so the value may itself
    contain ``=`` (``'a=b=c'`` -> ``('a', 'b=c')``). Blanks around the key
    are removed; the value is kept verbatim. An option without ``=`` yields
    an empty value instead of failing.

    Args:
        kv_options (str): The key=value pair

    Returns:
        tuple[Any, str]: tuple of key, value

    """
    key, _, value = kv_options.partition('=')
    return (key.strip(), value)


async def main(params: argparse.Namespace) -> None:
    """Async main function: send a task request and collect the replies.

    Uses the module-level ``client``. On exit ``client.retcode`` holds the
    return code, based on the channel response:
        0 - done (or no result collected)
        1 - rejected
        2 - error (including redis / input-file failures)

    Args:
        params (argparse.Namespace): the parsed parameters

    """
    # dst_nodes_in_accepted=1 is enabled by default; user supplied options
    # are applied on top of it and may override it.
    request_options: Dict[str, Any] = {'dst_nodes_in_accepted': '1'}

    # Parse a series of key=value pairs, if any.
    for item in params.options or []:
        key, value = parse_key_value_options(item)
        request_options[key] = value

    if not params.src:
        params.src = HOSTNAME

    if not params.fb_channel:
        params.fb_channel = f'fb_ctrl_{random.randint(0, 10000)}'

    # Read the request payload before connecting to redis, so that an
    # unreadable file is reported cleanly instead of escaping as a
    # traceback with the connection already open.
    data = None
    try:
        if params.data:
            data = params.data
        elif params.file_path:
            with open(params.file_path, 'r', encoding='utf-8') as path:
                data = path.read()
        elif params.binary_path:
            with open(params.binary_path, 'rb') as path:
                data = bytes.decode(base64.b64encode(path.read()))
    except (OSError, UnicodeDecodeError) as err:
        print(f'Error: cannot read request data: {err}')
        client.retcode = 2
        return

    try:
        redis_cfg = client.redis_config
        [host, port, use_db] = [redis_cfg[k]
                                for k in ['host', 'port', 'db']]
        client.redis = await aioredis.Redis.from_url(
            f'redis://{host}:{port}/{use_db}'
        )
        client.pubsub = client.redis.pubsub()
    except redis.exceptions.ConnectionError as err:
        print(f'Error: Redis connection creation error: {err}')
        client.retcode = 2
        await client.shutdown()
        return

    # here we prepare the request to the task
    client.request.uuid = str(uuid4())
    client.request.task = params.task
    client.request.dst = params.dst
    client.request.src = params.src
    client.request.options = {
        'feedback_channel': params.fb_channel,
        **request_options,
    }
    if data is not None:
        client.request.data = data

    client.verbose = params.verbose
    client.quiet = params.quiet

    try:
        await client.pubsub.subscribe(params.fb_channel)
    except redis.exceptions.ConnectionError as err:
        print(f'Error: Redis subscribe connection error: {err}')
        client.retcode = 2
        await client.shutdown()
        return

    client.running = True
    # Keep a reference: the event loop only holds weak references to tasks,
    # and the reference is also used below to detect a dead reader.
    reader = asyncio.create_task(client._read_from_channel())

    if client.verbose:
        print(
            f'request({client.req_channel}):\n{pformat(vars(client.request))}'
        )

    try:
        await client.redis.publish(client.req_channel,
                                   json.dumps(vars(client.request)))
    except redis.exceptions.ConnectionError as err:
        print(f'Error: Redis publish connection error: {err}')
        client.retcode = 2
        await client.shutdown()
        return

    # Wait until the reader finishes collecting responses. Also stop if the
    # reader task ended on its own (e.g. an unexpected exception), otherwise
    # this loop would never terminate.
    while not client.resp_channel_done and not reader.done():
        await asyncio.sleep(0.01)

    if not client.resp_channel_done:
        if not reader.cancelled() and reader.exception() is not None:
            print(f'Error: response reader failed: {reader.exception()!r}')
        client.resp_channel_status = 'error'

    # response status -> return code; unknown statuses map to 0
    status_to_rc = {'done': 0, 'none': 0, 'rejected': 1, 'error': 2}
    # Set before shutdown() so the code is recorded even if shutdown is
    # interrupted.
    client.retcode = status_to_rc.get(client.resp_channel_status, 0)
    await client.shutdown()


if __name__ == '__main__':
    params: argparse.Namespace = arg_parsing()
    client = Client()
    loop = asyncio.new_event_loop()
    executor = ThreadPoolExecutor()
    loop.set_default_executor(executor)
    for signame in ('SIGINT', 'SIGTERM', 'SIGHUP'):
        # On a termination signal, run the graceful shutdown on this loop.
        loop.add_signal_handler(
            getattr(signal, signame),
            lambda: loop.create_task(client.shutdown())
        )
    try:
        loop.run_until_complete(main(params))
    except asyncio.CancelledError:
        # main() was cancelled by client.shutdown() (e.g. on a signal);
        # the exit status is whatever client.retcode holds.
        pass
    finally:
        # Cancel and drain leftover tasks so closing the loop does not
        # destroy still-pending tasks.
        pending = asyncio.all_tasks(loop)
        for task in pending:
            task.cancel()
        if pending:
            loop.run_until_complete(
                asyncio.gather(*pending, return_exceptions=True)
            )
        loop.close()
    # Exit outside of `finally`: an unexpected exception from main() must
    # propagate with its traceback instead of being masked by sys.exit().
    # Fall back to 2 (error) if no return code was ever recorded.
    sys.exit(getattr(client, 'retcode', 2))
