#!/usr/bin/env python
# _*_ coding: UTF-8 _*_
"""
@Project : game_algorithm
@File : config.py
@Author: nudt-52
@Date: 2025/8/7 16:00
"""
import os
import gymnasium as gym
import numpy as np
import logging
import importlib

from corekit.config.base import TaskConfigAPI
from corekit.config.train import TrainTaskConfig
from corekit.utils.path import WORK_DIR
from corekit.utils.typings import AgentMap, Policies, PolicySpec

""""！！！选手需要在下方import要使用的强化学习智能体和规则智能体！！！"""
# 红方强化学习智能体
from agent.rl_agent.fighter_agent.Fighter import Agent as FighterAgent  # 需在此模块中有智能体脚本
from agent.rl_agent.sentry_agent.Sentry import Agent as SentryAgent  # 需在此模块中有智能体脚本
# 红方规则智能体
from agent.rule_agent.baseline_doctrine_add_mission.add_mission import Agent as MissionAgent
from agent.rule_agent.baseline_doctrine_add_mission.agent_doctrine_set import Agent as DoctrineAgent
from agent.rl_agent.sentry_agent.Escort import Agent as Escort_agent

# 蓝方强化学习智能体
# from agent.rl_agent.bomber_agent.bomber import Agent as BomberAgent
# from agent.rl_agent.fighter_agent.Fighter import Agent as BlueFighterAgent
# 蓝方规则智能体调用
# from agent.rule_agent.blue_airbat.blue_air import Agent as BlueAirBat


@TaskConfigAPI.register(mode='train', category='tune')
class LingYiConfig(TrainTaskConfig):
    # env = LingYiEnv

    @classmethod
    def get_frontend_config(cls) -> dict:
        # TODO 处理前端配置
        return {}

    @classmethod
    def get_space_config(cls, args: dict) -> dict:
        # 获取动作/状态空间类型
        act_space_type: str = args.get('act_space', 'MultiDiscrete')  # 动作空间类型
        obs_space_type: str = args.get('obs_space', 'Box')  # 状态空间类型

        # 动作/状态空间设置
        act_space_map: dict = {
            'MultiDiscrete': gym.spaces.MultiDiscrete([5, 2, 2, 2, 2]),
            'Discrete': gym.spaces.Discrete(80),
            'Box': gym.spaces.Box(low=0, high=1, shape=(5,)),
            'Tuple': gym.spaces.Tuple(gym.spaces.Discrete(80) for _ in range(1))
        }  # 动作空间映射
        obs_space_map: dict = {
            'MultiDiscrete': gym.spaces.Box(low=-np.inf, high=+np.inf, shape=(1908,)),
            'Discrete': gym.spaces.Box(low=-np.inf, high=+np.inf, shape=(1908,)),
            'Box': gym.spaces.Box(low=-np.inf, high=+np.inf, shape=(1908,)),
            'Tuple': gym.spaces.Tuple(gym.spaces.Box(low=-np.inf, high=+np.inf, shape=(1908,)) for _ in range(1))
        }  # 状态空间映射

        # 调整动作/状态空间参数设置
        act_space = act_space_map[act_space_type]
        obs_space = obs_space_map[obs_space_type]

        space_config: dict = {
            'act_space': act_space,
            'obs_space': obs_space
        }
        return space_config

    @classmethod
    def get_callback_config(cls, args: dict) -> dict:
        callback_config = {}
        return callback_config

    @classmethod
    def get_server_config(cls, args: dict) -> dict:
        """
        智能体环境（灵弈环境）配置
        """
        logger = logging.getLogger("Platform")
        logger.setLevel(level=logging.INFO)
        formatter_play = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                                           datefmt="%Y-%m-%d %H:%M:%S")
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        console.setFormatter(formatter_play)
        logger.propagate = False
        logger.addHandler(console)

        # 设置日志输出到文件---------可注释，即不输出到文件
        log_path = os.path.join(os.getcwd(), "logs")
        if not os.path.exists(log_path):
            os.mkdir(log_path)
        
        file_out = logging.FileHandler(r"./logs/simulate.log", mode='w', delay=False)
        file_out.setLevel(logging.INFO)
        file_out.setFormatter(formatter_play)
        logger.addHandler(file_out)

        temp_dict = {
            "game_server": {
                "public_key": "AAAAAA",    # 推演后台用户公钥
                "user_id": "testId",  # 推演后台用户id
                "simulate": {  # 推演环境过程配置
                    "clear_room": False,  # 是否删除服务器所有房间
                    "process_id": 1,  # 进程号 调试用

                    "main_client": True,  # 是主客户端, 开始推演, 设置推演倍速
                    "load_scenario": True,  # 是否需智能体端加载想定, 如需加载想定，一定初始新建重置房间，为False即可中途加入推演
                    "scenario": "蓝方初赛AI版（改进版）",  # "蓝方初赛AI版（改进版）"
                    "scen_check": True,  # 是否检查想定是否存在  确定想定存在的情况时可置为False

                    "episodes": 10,  # 灵弈推演重复运行想定最大局数；规则智能体开发使用此配置；
                    "compression": 4,  # 推演倍速，服务器：0:1, 1:5, 2:10, 3:20, 4:60, 5:'max'
                    "need_compression": True,  # 是主客户端，且需要由python端设置推演倍速
                    "decision_interval": 60,  # 智能体决策时间间隔（单位：秒）
                    "step_interval": 2,  # 智能体决策步输出信息间隔步数
                    "reset_episodes": 10,  # 每多少局删除房间后再新建房间

                    "training": True,  # 是否训练模式，agent的is_done函数（智能体自主判断一局推演结束）生效，否则不生效
                    "testing": True,  # True, 并且agent中代码出错直接抛出异常，中断运行, False: 已确认代码在服务器正常时可运行，出错则进行下一局

                    "events": False,  # 是否需要态势中接收到想定中设置的事件、触发器、条件和动作信息
                    "need_all_side": True,  # 训练过程中，本推演方是否需要其他推演方事件或实体单元信息

                    "save_result": False,  # 是否保存红蓝方对战结果到数据库
                },
                "server": {  # 服务器连接、通信配置
                    "server_ip": "127.0.0.1",  # 管理服务器ip地址，如本机配置推演服务器，则不用修改
                    "server_port": 20000,        # 管理服务器端口
                    "room_name":  "wargame",       # 房间名
                    "data_link_type":  3,       # 态势数据连接通信方式 DDS=0  组播=1  ZMQ:3
                    "continuous": 0,  # 服务器间断推演:0 仿真在智能体计算、决策期间暂停等待，决策完毕后再开始推演一段时间（智能体决策间隔时间）
                                      # 服务器连续推演:1 不等待智能体决策，可能造成决策的依据态势慢于当前推演态势
                    "save_situation_type": 0,  # 0:不保存  1:仅完整帧，9：仅状态帧,  99：所有收到态势
                },
                "logger": logger,  # logging日志
            }
        }
        cls.add_RL_agents(temp_dict['game_server'])
        return temp_dict

    @classmethod
    def add_RL_agents(cls, train_config):
        """
        把强化学习智能体配置，放入配置中
        """
        rl_config = cls.get_rl_agent()
        train_config["RL_agents"] = rl_config

    @classmethod
    def get_rl_agent(cls):
        """
        用户需要在下方配置使用的强化学习智能体以及控制的单元，智能体类
        """
        rl_dict: dict = \
            {
                "红方": {
                    "agents": {
                        # 'redFighterT50': {  # 强化学习智能体描述--此为战斗机，用于空战学习训练
                        #         "class": FighterAgent,  # 智能体类，用户基于基类RLAgentBase开发的子类，
                        #         "unit": ["米格-29型 #1", "米格-29型 #2", "米格-29型 #3","米格-29型 #4",
                        #                  "米格-29型 #5", "米格-29型 #6"],
                        #         # 单元名列表，智能体控制的单元，强化学习会给每个单元建立一个智能体
                        #         "contact": [],  # 情报名列表，情报参数, 可以客户端里智能体调度动态配置，示例：['空中#1', '空中#2']
                        #         "point": []  # 参考点名列表，参考点参数, 可以客户端里智能体调度动态配置，示例：['RP-112', 'RP-113', 'RP-114']
                        #     },
                        # 可以继续添加红方强化学习智能体，注意：智能体越多可能越不容易奖励值收敛，不容易训练出好的效果
                        'redSentryA100': {  # 强化学习智能体描述--此为战斗机，用于空战学习训练
                            "class": SentryAgent,  # 智能体类，用户基于基类RLAgentBase开发的子类，
                            "unit": ["A-100 Premier 预警机 [伊尔-476] #1"],
                            # 单元名列表，智能体控制的单元，强化学习会给每个单元建立一个智能体
                            "contact": [],  # 情报名列表，情报参数, 可以客户端里智能体调度动态配置，示例：['空中#1', '空中#2']
                            "point": []  # 参考点名列表，参考点参数, 可以客户端里智能体调度动态配置，示例：['RP-112', 'RP-113', 'RP-114']
                        },
                        # 可以继续添加红方强化学习智能体，注意：智能体越多可能越不容易奖励值收敛，不容易训练出好的效果
                    },
                    "enemy": "蓝方"  # 敌方，用于强化学习训练时状态输入和奖励计算
                },
                # "蓝方": {
                #     "agents": {
                #         'blueFighterF22': {  # 强化学习智能体描述
                #                 "class": BlueFighterAgent,  # 智能体类，用户基于基类RLAgentBase开发的子类，
                #                 "unit": ["F-22A型 #1", "F-22A型 #2"],  # 单元名列表，强化学习会给每个单元建立一个智能体
                #                 "contact": [],  # 情报名列表，情报参数, 可以客户端里智能体调度动态配置
                #                 "point": []  # 参考点列表，参考点参数, 可以客户端里智能体调度动态配置
                #             },
                #         'blueFighterF35': {  # 强化学习智能体描述
                #             "class": BlueFighterAgent,  # 智能体类，用户基于基类RLAgentBase开发的子类，
                #             "unit": ["F-35FA 型 #1", "F-35C 型 #2"],  # 单元名列表，强化学习会给每个单元建立一个智能体
                #             "contact": [],  # 情报名列表，情报参数, 可以客户端里智能体调度动态配置
                #             "point": []  # 参考点列表，参考点参数, 可以客户端里智能体调度动态配置
                #         }
                #     },
                #     "enemy": "红方"  # 敌方，用于强化学习训练时状态输入和奖励计算
                # },
            }
        return rl_dict

    @classmethod
    def get_rule_dict(self):
        """
        规则智能体配置
        用户需要在下方配置使用的规则智能体以及单元参数、情报或参考点（区域），智能体类

        以下'redFighterT50'示例为：在智能体开发训练平台中进行多局强化学习智能体推理应用（而不是客户端中导入智能体），可调试查看智能体效能和输出；
                操作方式为将get_rl_agent函数中的强化学习配置移动到函数get_rule_dict函数中；
                需确认强化学习智能体中已实现加载神经网络模型（可查看例子Fighter.py中，self.do_ai_apply设置为True，这样加载训练好的模型，
                在step函数中进行模型决策）
        """
        rule_dict = {
            "红方": {
                'redmission':  # 规则智能体描述---可以用户自定义
                    {
                        "class": MissionAgent,  # 智能体类，用户基于基类RuleAgentBase开发的子类
                        "unit": [],  # 单元名列表, 可以客户端里智能体调度动态配置，这里此规则智能体无需单元参数
                        "contact": [],  # 情报名列表, 可以客户端里智能体调度动态配置
                        "point": []  # 参考点名列表, 可以客户端里智能体调度动态配置
                    },
                'reddoctrine':  # 规则智能体描述---可以用户自定义
                    {
                        "class": DoctrineAgent,  # 智能体类，用户基于基类RuleAgentBase开发的子类
                        "unit": [],  # 单元名列表, 可以客户端里智能体调度动态配置，这里此规则智能体无需单元参数
                        "contact": [],  # 情报名列表, 可以客户端里智能体调度动态配置
                        "point": []  # 参考点名列表, 可以客户端里智能体调度动态配置
                    },
                'redescort':  # 规则智能体描述---可以用户自定义
                    {
                        "class": Escort_agent,  # 智能体类，用户基于基类RuleAgentBase开发的子类
                        "unit": ['Su-75 Checkmate 战斗机 #1','Su-75 Checkmate 战斗机 #2','Su-75 Checkmate 战斗机 #3'],  # 单元名列表, 可以客户端里智能体调度动态配置，这里此规则智能体无需单元参数
                        "contact": [],  # 情报名列表, 可以客户端里智能体调度动态配置
                        "point": []  # 参考点名列表, 可以客户端里智能体调度动态配置
                    },
                # 'redSentryA100': {  # 强化学习智能体描述--此为战斗机，用于空战学习训练
                #     "class": SentryAgent,  # 智能体类，用户基于基类RLAgentBase开发的子类，
                #     "unit": [],  # "A-100 Premier 预警机 [伊尔-476] #1"
                #     # 单元名列表，智能体控制的单元，强化学习会给每个单元建立一个智能体
                #     "contact": [],  # 情报名列表，情报参数, 可以客户端里智能体调度动态配置，示例：['空中#1', '空中#2']
                #     "point": []  # 参考点名列表，参考点参数, 可以客户端里智能体调度动态配置，示例：['RP-112', 'RP-113', 'RP-114']
                # },
            },
        }
        return rule_dict

    @classmethod
    def get_rule_config(cls, args):
        """生成规则智能体"""
        config_data: dict = {'rule_config': cls.get_rule_dict()}
        return config_data

    @classmethod
    def get_multiagent_dict(cls, args: dict):
        RL_name_list = []
        full_list = []
        for item in args['agent_map']:
            RL_name_list.append(args['agent_map'][item].agent_name)
        id_list = cls.map_dict(RL_name_list)
        full_list.append(RL_name_list)
        full_list.append(id_list)
        RL_agent_info = cls.get_rule_agent_info(full_list, args)
        return RL_agent_info

    @classmethod
    def get_train_config(self, args: dict) -> dict:
        """"！！！选手可以在下面配置训练算法参数(可选)！！！"""
        train_config: dict = {
            # 训练参数
            "alg": "PPO",  # 为训练使用的强化学习算法，默认为PPO算法；
            """
            可以选择的方案包括SAC,DQN,PPO,QMIX,IMPALA,TD3,DDPG（不推荐，速度较慢）强化学习算法，选手可以根据自己的要求自定义配置算法；
            需要注意的是如果选手使用的为SAC算法时强化学习智能体的动作空间需要为Discrete类型，
            在初始化过程中动作空间需要使用gym.spaces.Discrete进行初始化；
            """
            "engine": "LingYiWarGame",  # 引擎名
            "num_workers": 1,  # 房间数量，(超过训练计算机性能上限容易卡死); 配置超过1一般"training_mode"需配置"normal"或"cluster"
            "max_step": 180,  # 每一局推演运行的最大决策次数
            "train_batch_size": 120,  # 进行策略更新时用于训练的样本数量；即决策多少次训练、保存一次模型
            "training_mode": "local",  # local:本地训练, normal:本地训练(不可调试), cluster:集群训练, 需在配置文件中指定集群head节点ip,port
            "api": "tune",  # 测试分布式新增
            "framework": "tf",  # 机器学习框架；'tf'代表'tensorflow', 或者'torch'代表'pytorch'
            "lr": 5e-03,  # 强化学习过程中策略更新的学习率;
            "no_done_at_end": False,  # 是否在episode结束时忽略done信号；设置为True, 即使环境触发done状态（通常表示任务完成或失败）。
                                      # 训练过程仍会继续执行，不会立即终止
            "gamma": 0.95,  # 折扣因子，用于均衡训练过程中未来奖励的当前价值，决定当前智能体在训练过程中即使奖励和未来奖励的重要程度
            "lambda": 0.9,  # 用于控制优势估计的偏差和方差之间的权衡；
            "sgd_minibatch_size": 10,  # 决定每次更新模型参数时使用的样本数量；
            "store_model_interval": 1,  # 模型保存间隔
            "store_model_num": 1,  # 保存模型的最大数量
            "load_pretrained": False,  # 设置是否读取预训练模型
            "rule_engine": False,  # 规则引擎触发按钮
            # 停止条件
            "stop_iters": 10000,  # 设置停止前的最大训练模型次数；和之前参数"train_batch_size"（即决策多少次训练、保存一次模型）有关；
            "stop_timesteps": 1200000,  # 设置停止前的最大决策次数；当某次训练模型时，发现决策步数不小于此配置数据时，结束训练
            "stop_reward": 10000,  # 设置停止前的平均奖励值；当平均回合奖励达到此值后，结束训练
            # 硬件资源
            "num_envs_per_worker": 1,
            "num_cpus_per_worker": 1,
            "num_gpus_per_worker": 0,
            "num_cpus": 6,
            "num_gpus": 0,  # 分配到训练程序的gpu数量，可以是小数，例如：0.1
            # 检查点配置
            "checkpoint_config": {
                "num_to_keep": 3, "checkpoint_frequency": 1, "checkpoint_score_attribute": 'training_iteration'
            },
            "log_path": os.path.join(WORK_DIR, "log")  # 日志路径args
        }  # 训练默认设置
        if len(args['agent_map']) == 0:
            args['rollout_fragment_length'] = 1000
            train_config['train_batch_size'] = 1000
        return train_config

    @classmethod
    def get_alg_config(cls, args: dict):
        if args['alg'] == 'DQN' or args['alg'] == 'DDPG' \
                or args['alg'] == 'QMIX' or args['alg'] == 'TD3':
            args['num_steps_sampled_before_learning_starts'] = args['learning_starts']
            del args['learning_starts']

    @classmethod
    def get_class_file_path(cls, class_agent):
        """
        获取类定义所在的源文件路径。
        参数:
            cls (type): 需要查询的类类型（例如 MyClass）
        返回:
            str: 类定义的文件路径（例如 '/path/to/mymodule.py'）
        """
        module_name = class_agent.__module__
        try:
            # 动态导入模块
            module = importlib.import_module(module_name)
            # 获取模块的文件路径
            file_path = getattr(module, "__file__", None)
            if file_path:
                file_path = os.path.dirname(file_path)
                return os.path.abspath(file_path)  # 返回绝对路径
            else:
                return "无法获取%s的文件路径（可能是内置模块或动态生成的模块）" % module_name
        except ModuleNotFoundError:
            return "模块 %s 未找到" % module_name

    @classmethod
    def get_multiagent_config(cls, args: dict) -> dict:
        """通过该接口读取强化学习智能体"""
        rl_agent_config = cls.get_rl_agent()
        """用于保存所有智能体"""
        from engine.LingYi.entitys.deduce_RL import Deduction
        agents = Deduction.create_rl_agents(rl_agent_config, args)

        if len(agents) != 0:
            agent_map: AgentMap = {}  # dict[str, Agent_instance]
            for index, agent in enumerate(agents):
                agent_id = agent.uid
                agent_map[agent_id] = agent

            policy_module_path = {}  # 策略ID-->策略存储模型路径

            # agent description列表
            agent_des_list = []
            agent_map_policy_dict = {}
            policies: Policies = {}
            for index, agent in enumerate(agents):
                agent_id = agent.uid
                if agent.agent_type not in agent_des_list:
                    policy_id: str = "policy_%d" % len(agent_des_list)  # 智能体对应的策略
                    agent_des_list.append(agent.agent_type)
                else:
                    policy_index = agent_des_list.index(agent.agent_type)
                    policy_id: str = "policy_%d" % policy_index  # 智能体对应的策略
                module_path = cls.get_class_file_path(type(agent))
                policy_module_path[policy_id] = os.path.join(module_path, 'model', agent.agent_type)
                agent_map_policy_dict[agent_id] = policy_id
                if policy_id not in policies.keys():
                    policies[policy_id] = PolicySpec(
                        observation_space=agent.observation_space,
                        action_space=agent.action_space,
                        config={
                            "model": {"custom_model": agent.model},
                        }
                    )
                    policies[policy_id].agent_type = agent.agent_type  # 另外存储策略的算法描述

            def policy_mapping_fn(_agent_id, episode, worker, **kwargs):
                return agent_map_policy_dict[_agent_id]

            def is_policy_to_train(pid, batch=None):
                return pid == list(policies.keys())

            multiagent_config = {
                "agent_map": agent_map,
                "policies": policies,
                "policy_mapping_fn": policy_mapping_fn,
                "policies_to_train": list(policies.keys()),  # policies_to_train,
                "policy_module_path": policy_module_path  # 用于存储策略的模型路径
            }
            return multiagent_config
        elif len(agents) == 0:
            # 全部为规则智能体的情况，不存在强化学习智能体
            policies: Policies = {}
            """生成一个示例策略，运行过程中不进行训练"""
            policies['0'] = PolicySpec(observation_space=gym.spaces.Box(low=-np.inf, high=+np.inf, shape=(316,)),
                                       action_space=gym.spaces.MultiDiscrete([3, 2, 2]),
                                       )

            def policy_mapping_fn(_agent_id, episode, worker, **kwargs):
                return agent_map_policy_dict[_agent_id]

            def is_policy_to_train(pid, batch=None):
                return pid == list(policies.keys())

            multiagent_config = {
                "agent_map": {},
                "policies": policies,
                "policy_mapping_fn": policy_mapping_fn,
                "policies_to_train": list(policies.keys())  # policies_to_train,
            }
            return multiagent_config
