# -*- coding:utf-8 -*-
# Author: wuruizhuo

"""
想定：“尤卡坦半岛争夺战_左侧有蓝色导弹和黄色飞机_后期.scen”，该想定是比赛想定运行到后期时的中间保存想定，可以方便看效果
如果直接用在比赛想定"想定3-30版.scen"中，可以等在火力集中和敌方飞机已经进入射程时使用，保证附近打击目标价值高，体现基于价值目标的优化算法作用
算法：多目标粒子群优化算法
"""
import logging
from engine.LingYi.entitys.global_util import *
from engine.LingYi.entitys.player import Player
import engine.LingYi.entitys.geo as geo
import math
import engine.LingYi.entitys.data.query_data as qd

import numpy as np
from openpyxl import load_workbook, Workbook
import os
import json
import time

from agent.rule_agent.baseline_WTA.MOPSO import BinaryMOPSO


class Agent(Player):
    def __init__(self, side_name, params):
        Player.__init__(self, side_name)
        logging.info('%s:%s play game' % (side_name, self.__class__.__name__))
        self.params = params
        self.group_guid = []  # wta所使用的编组单元guid列表
        self.wta_result = []  # 单步武器分配结果
        self.wta_log = []  # 武器分配下发记录
        self.situation_info = {}  # 收集的态势信息
        self.delete_target_id_lst = []  # 删除已经打击过的单元
        self.group_name = []  # WTA编组的单元名称
        self.blue_score_dict = {}  # 蓝方分值表，存储的所有单元的分值
        self.red_score_dict = {}  # 红方分值表，存储的所有单元的分值
        self.attack_capable_target = []  # 可以开火的单元
        self.check_is_shoot_down = []  # 检查是否击落的单元列表
        self.shoot_down_target = []  # 击落和拦截的单元
        self.check_KeyMsg = []  # 已经使用过的关键事件
        self.execution_time = []  # 记录算法运行时间
        self.wta_is_valid = None  # wta运行结果是都有效
        self.wta_is_done = None  # wta是否结束
        self.save_situation_mode = None  # 是否保存态势的模式，不进行推演

    def initial(self, situation):
        # 初始化参数
        self.wta_result = []
        self.wta_log = []
        self.attack_capable_target = []  # 可以开火的单元
        self.check_is_shoot_down = []  # 检查是否击落的单元列表
        self.shoot_down_target = []  # 击落和拦截的单元
        self.check_KeyMsg = []  # 已经使用过的关键事件
        self.execution_time = []  # 记录算法运行时间
        self.delete_target_id_lst = []  # 删除已经打击过的单元
        self.wta_is_valid = False
        self.wta_is_done = False
        self.save_situation_mode = False  # 是否保存态势的模式，不进行推演
        # self.set_agent_decision_time(10)  # 智能体调度时取消注释，设置决策间隔

        # 导入态势
        self.situation_info = {
            "cur_time": 0,  # 这个参数截击引擎中需要，在这里没有用到，可以去掉
            "period_duration": 30000,  # 这个参数也没有用到，可以去掉
            "target_info": {},  # 来袭目标信息
            "unit_info": {}  # 己方单元信息
        }
        # 设置导弹营条令谨慎开火,避免导弹营自动开火
        if self.client_info['unit']:
            self.group_guid = list(self.client_info['unit'].keys())
            unit_name_list = [unit['name'] for unit in self.client_info['unit'].values()]
            # print('单元列表为%s' % unit_name_list)
            self.agent_message_text(text='单元列表为%s' % unit_name_list)
        else:
            print('未配置单元，请至少选择一个单元')
            self.agent_message_text(text='未配置单元，请至少选择一个单元')
        # 设置WTA编组单元对空谨慎开火
        for guid in self.group_guid:
            unit = self.get_unit(guid)
            unit.doctrine_weapon_control_status_air(WeaponControlStatus.Tight)

        # 加载分值表excel文件，并存储红蓝方分值
        current_file_path = os.path.dirname(__file__)
        score_path = os.path.join(current_file_path, '分值表.xlsx')
        wb = load_workbook(score_path)
        if "蓝方" in wb.sheetnames:
            sheet = wb["蓝方"]  # 打开sheet名称为“蓝方”的sheet
        else:
            print("sheet不存在")
        # 转换为二维列表
        data_dict = {}
        for row in sheet.iter_rows(min_row=2, values_only=True):
            key = row[0]
            value = row[1]
            data_dict[key] = value
        data_dict.pop("总分")  # 最后一个是总分值，不计入内
        self.blue_score_dict = data_dict

        if "红方" in wb.sheetnames:
            sheet = wb["红方"]  # 打开sheet名称为“红方”的sheet
        else:
            print("sheet不存在")
        data_dict = {}
        for row in sheet.iter_rows(min_row=2, values_only=True):
            key = row[0]
            value = row[1]
            data_dict[key] = value
        data_dict.pop("总分")  # 最后一个是总分值，不计入内
        self.red_score_dict = data_dict

    def step(self, time_elapse, situation):
        # 这一行用于获取该步的起始时间，如果不需要计时，可以注释
        start_time = time.time()
        # 更新客户端选择的单元列表
        if self.client_info['unit']:
            self.group_guid = list(self.client_info['unit'].keys())
            unit_name_list = [unit['name'] for unit in self.client_info['unit'].values()]
            print('单元列表为%s' % unit_name_list)
            self.agent_message_text(text='单元列表为%s' % unit_name_list)
        else:
            self.group_guid = []
            print('未配置单元，请至少选择一个单元')
        # 评估时获取蓝方态势，可以在训练时获取蓝方态势用于效果评估，该智能体暂时没有使用，选手可以自行使用
        # blue_player = self.current_players["蓝方"]
        # blue_KeyMsg = blue_player.KeyMsg
        # blue_weapon_destroy = blue_KeyMsg.get(99, {})
        """效果评估信息筛选"""
        # 先判断是否开火，KeyMsg的87为武器开火事件
        fire_event_eval(self, situation)
        # 检查已经开火的事件中，武器是否已经销毁以及销毁情况
        weapon_des_event_eval(self)

        # 重置打击列表，每一轮推演的打击列表都只代表当前时间步打击计划，该计划不具备时间延续性
        self.wta_result = []

        """态势收集部分，收集target_info和unit_info，并存入situation_info中"""
        # 获取推演时间
        # self.situation_info["cur_time"] += time_elapse  # time_elapse格式变化，这个参数仅有记录时间的作用，不会影响算法执行
        # 获取敌方态势
        target_info = get_target_info(self, situation)
        # 获取己方态势
        # 首先删除已经被摧毁的单元
        situation_data = situation[0] if isinstance(situation[0], dict) else []
        for unit in situation_data.values():
            if isinstance(unit, dict):
                if (unit.get('name', '') in self.group_name) and (
                        unit.get('mounts', '') == 'DELETE'):
                    self.group_name.remove(unit.get('name', ''))
        # 其次获得仍存活的己方单元态势
        unit_info = get_unit_info(self, situation, target_info)
        # 将态势信息存入situation_info中，situation_info信息仅有记录作用，便于保存态势，与截击引擎呼应使用，对灵弈中的算法执行无影响
        self.situation_info["target_info"] = target_info
        self.situation_info["unit_info"] = unit_info

        # 这里是可以保存加载想定后的初始态势，可以用于在截击引擎中模拟灵弈的态势进行仿真，如果不需要保存态势可以注释
        if self.save_situation_mode:
            save_situation_2json(self)
            self.wta_is_done = True
            return

        # 获取wta编组中所有单元的所有弹药数量
        remaining_ammo = 0
        for unit_guid in (k for k in unit_info.keys() if k):
            for weapon_dbid in (k for k in unit_info[unit_guid].keys() if k):
                remaining_ammo += unit_info[unit_guid][weapon_dbid].get('ammunition_amount', 0)

        # 已方弹药是否全部打完/敌人消失为判断分配方案是否结束的条件
        if remaining_ammo > 0:
            if target_info:
                self.agent_message_text(text='发现空中目标')
                # WTA武器分配和弹药数量分配
                p_matrix, a_matrix, v, ammo, available_alloc = get_attack_prob(unit_info, target_info)
                # 判断是否有可以打击的目标
                if np.any(a_matrix) > 0:
                    """算法优化部分，这里使用了多目标粒子群优化算法求解武器目标分配结果"""
                    # 运行算法，swarm_size和max_iter可以自己设置
                    solver = BinaryMOPSO(p_matrix, v, ammo, swarm_size=10, max_iter=5)
                    best_solution, best_obj = solver.run()
                    # 获得可以执行打击的实际打击方案，排除无法打击和可以打击但概率为0的情况
                    solution = best_solution * a_matrix * p_matrix.astype(bool)
                    self.wta_result = process_matrix_solution(solution, available_alloc)
                    # self.wta_result = wta_close_dist(self.situation_info)   # 这个算法是最近距离分配算法，智能体一未使用到，可以作为参考
                    """指令下发部分"""
                    # 下发武器分配方案
                    # 判断是否有可以打击的目标,如果可以打击，将下发打击指令
                    num_total_attack = len(self.wta_result)
                    if num_total_attack > 0:
                        self.wta_is_valid = True  # 这个标志位暂时没有什么作用，在静态WTA中有效，多阶段WTA无效
                        self.agent_message_text(text='分配方案下发')
                        print("分配方案下发")
                        run_attack_command(self)
                else:
                    self.agent_message_text(text='目标不可打击')
                    print("目标不可打击")
            else:
                print("未探测到空中目标")
                self.agent_message_text(text='未探测到空中目标')
        else:
            self.wta_is_done = True

        # 终止条件判定，为了验证拦截概率，暂时用地空导弹营(KN-06型地空导弹 [Pongae-5型防空导弹])作为结束标志，
        # 实际使用时，可以设计当所有单元均被摧毁时结束该智能体
        unit_exist = [
            unit.get('name', '') for unit in situation_data.values()
            if isinstance(unit, dict)
               and (unit.get('name', '') == "地空导弹营(KN-06型地空导弹 [Pongae-5型防空导弹])")  # Pr.2235.0 “戈尔什科夫海军元帅”级护卫舰  地空导弹营(KN-06型地空导弹 [Pongae-5型防空导弹])
        ]
        if not unit_exist:
            self.wta_is_done = True
        # 这三行用于获取该步的终止时间，并通过与起始时间相减获得运行时间，如果不需要计时，可以注释
        end_time = time.time()
        execution_time = end_time - start_time
        self.execution_time.append(execution_time)

    def deduction_end(self):
        # 一局推演结束后将调用此接口
        """效果评估"""
        # 获得可以打击的所有目标单元列表
        attack_capable_target = list(set(self.attack_capable_target))
        # 计算可打击的所有单元价值之和，计算成功拦截的所有目标价值之和
        target_value = 0
        shoot_down_value = 0
        for i in attack_capable_target:
            target_value += i[1]
            if i[0] in self.shoot_down_target:
                shoot_down_value += i[1]
        target_num = len(attack_capable_target)  # 这是所有可打击的目标数量
        shoot_down_num = len(self.shoot_down_target)  # 这是所有成功拦截的目标数量
        # 拦截率和拦截价值率计算
        if target_num:
            interception_rate = shoot_down_num / target_num
            interception_value_rate = shoot_down_value / target_value
        else:
            interception_rate = 0
            interception_value_rate = 0
        print(f"拦截率是 {interception_rate} \n 拦截价值率是 {interception_value_rate}")

        # 保存算法运行时间到同目录的execution_time.xlsx文件中
        file_name = 'execution_time.xlsx'
        file_path = os.path.join(os.path.dirname(__file__), file_name)
        if not os.path.exists(file_path):
            wb = Workbook()
            ws = wb.active
            headers = ['单步运行时间']
            ws.append(headers)
            wb.save(file_path)
        append_to_excel(file_path, self.execution_time)
        print("保存单步运行时间至文件%s" % file_path)

        # 保存算法拦截率和拦截价值率到同目录的interception_data.xlsx文件中，如果运行多次可以追加到同一个文件中
        file_name = 'interception_data.xlsx'
        file_path = os.path.join(os.path.dirname(__file__), file_name)
        if not os.path.exists(file_path):
            wb = Workbook()
            ws = wb.active
            headers = ['拦截率', '拦截价值率']
            ws.append(headers)
            wb.save(file_path)
        append_to_excel(file_path, [interception_rate, interception_value_rate])
        print("保存拦截结果至文件%s" % file_path)

        logging.info('%s:%s play game result:%d' % (self.Name, self.__class__.__name__, self.TotalScore))

    def is_done(self):
        # 静态WTA中，方案下发后，不再下发
        # 多阶段WTA中，通过判断单元列表self.group_name中的某些单元是否存活，决定是否结束智能体
        if self.wta_is_done:
            return True
        return False


def get_unit_info(agent_obj, situation, target_info):
    """
    遍历situation，获取己方的态势信息
    :return:包含地导营位置和拦截半径的字典
    """
    unit_info = {}
    for unit_guid in agent_obj.group_guid:
        unit = agent_obj.get_unit(unit_guid)
        if not unit:
            continue
        # 获取该单元挂载的对空导弹
        if isinstance(unit.Mounts, dict):
            weapon_dbid_list = [
                (weapons.get('WeaponDBID', 0), weapons.get('CurrentLoad', 0)
                 )
                for mount in unit.Mounts.values()
                for weapons in mount.get("WeaponRecs", {}).values()
                if weapons.get('CurrentLoad', 0) >= 1
                   and weapons.get("AirRangeMax", 0) > 0
            ]
        else:
            weapon_dbid_list = []
        # 将DBID相同的武器合并，武器数量相加
        weapon_dict = {}
        for item in weapon_dbid_list:
            weapon_dict[(item[0])] = weapon_dict.get((item[0]), 0) + item[1]
        # 查询火控通道数
        if isinstance(unit.Sensors, dict):
            sensor_dbid = [(sensor.get("Component", "").get("DBID", ""))
                           for sensor in unit.Sensors.values() if sensor.get("Type", 0) == 2001]
            # 导弹营查询火控通道
            fire_control_channel_num = qd.query_max_contacts_illuminate(sensor_dbid)
        else:
            fire_control_channel_num = 0
        # 对不同类型的weapon分类
        unit_weapon_info = {}
        # 遍历该单元下所有武器，得到对导弹和飞机的拦截能力和拦截概率
        for weapon_dbid, ammunition_amount in weapon_dict.items():
            ability, probability = {}, {}
            # 获取己方导弹信息
            weapon_info = qd.get_weapon_info(weapon_dbid)
            # 遍历所有的来袭导弹，得到拦截能力和拦截概率
            for missile_guid, missile_info in ((g, i) for (g, i) in target_info.items() if
                                               i.get("type", '') == "missile"):
                missile = situation[1][missile_guid]
                unit_target_d = get_horizontal_distance((unit.Lat, unit.Lon),
                                                        (missile.get('Lat', 0), missile.get('Lon', 0)))
                can_intercept = False
                intercept_prob = 0
                # 先判断射程、高度和速度是否在导弹打击范围内，再调用开火条件检查接口，可以节约计算成本
                if pre_check_weapon_fire_condition(weapon_info, missile, unit_target_d):
                    capable_weapon = agent_obj.check_weapon_fire_condition((unit_guid, missile_guid))
                    if weapon_dbid in capable_weapon.keys():
                        can_intercept = True
                        agent_obj.attack_capable_target.append((missile_guid, missile_info['value'],))
                        intercept_prob = calculate_possibility(weapon_dbid, missile, contact_proficiency_level=2,
                                                               unit_location=(unit.Lat, unit.Lon),
                                                               target_location=(
                                                                   missile.get('Lat', 0), missile.get('Lon', 0)),
                                                               unit_target_distance=unit_target_d)
                ability[missile_guid] = can_intercept
                probability[missile_guid] = intercept_prob

            # 遍历所有的来袭飞机，得到拦截能力和拦截概率
            for air_guid, air_info in ((g, i) for (g, i) in target_info.items() if i.get("type", '') == "air"):
                air = situation[1][air_guid]
                unit_target_d = get_horizontal_distance((unit.Lat, unit.Lon), (air.get('Lat', 0), air.get('Lon', 0)))
                can_intercept = False
                intercept_prob = 0
                # 先判断射程、高度和速度是否在导弹打击范围内，再调用开火条件检查接口，可以节约计算成本
                if pre_check_weapon_fire_condition(weapon_info, air, unit_target_d):
                    capable_weapon = agent_obj.check_weapon_fire_condition((unit_guid, air_guid))
                    if weapon_dbid in capable_weapon.keys():
                        can_intercept = True
                        agent_obj.attack_capable_target.append((air_guid, air_info['value'],))
                        intercept_prob = calculate_possibility(weapon_dbid, air, contact_proficiency_level=2,
                                                               unit_location=(unit.Lat, unit.Lon),
                                                               target_location=(air.get('Lat', 0), air.get('Lon', 0)),
                                                               unit_target_distance=unit_target_d)
                ability[air_guid] = can_intercept
                probability[air_guid] = intercept_prob
            trajectory_speed = qd.get_weapon_launch_speed(weapon_dbid)  # # 获取发射速度，单位：海里/小时，即节
            intercept_radius = weapon_info[0]  # 获取拦截半径，单位: 海里  ,weapon_info的第一个是AirRangeMax拦截范围
            # 构建该己方单元的态势信息
            unit_weapon_info[weapon_dbid] = {
                "loc": (unit.Lat, unit.Lon),  # 纬度，经度信息，使用元组存储坐标
                "intercept_radius": intercept_radius * geo.NM2KM,  # 拦截半径，单位：m
                "height": getattr(unit, 'CurrentAltitude', 0),  # 高度，单位：m
                "fire_control_channel_num": fire_control_channel_num,  # 火控通道数接口
                "ammunition_amount": ammunition_amount,  # 剩余弹药数量
                "trajectory_speed": trajectory_speed * geo.NM2KM,  # 发射速度，单位：公里/时
                "interception_ability": ability,  # 拦截能力
                "interception_probability": probability  # 拦截概率
            }
        unit_info[unit_guid] = unit_weapon_info

    return unit_info


def get_target_info(agent_obj, situation):
    """
    整合导弹和飞机的信息
    :param agent_obj,
    :param situation: 部分态势信息
    :return: target_info
    """
    missile_info = get_missile_info(agent_obj, situation)
    air_info = get_plane_info(agent_obj, situation)
    target_info = {**missile_info, **air_info}
    return target_info


def get_missile_info(agent_obj, situation):
    """
    获取来袭导弹的信息
    :param agent_obj,
    :param situation: 部分态势信息
    :return: missile_info 目标信息
    """
    missile_info = {}
    missile_id_list = []
    # 获取情报接触的导弹guid列表, 排除已经进行打击的来袭导弹
    if isinstance(situation[1], dict):
        targets = (t for t in situation[1].items() if not isinstance(t, str))
        missile_id_list.extend(t[0] for t in targets if
                               (t[1].get('Type', -1) == ContactType.Missile.value) and (
                                       t[0] not in agent_obj.delete_target_id_lst))
        # and (t[1].get('Side', '') == '蓝方' or t[1].get('IDStatus', -1) >= 3))
    if missile_id_list:
        for missile_id in missile_id_list:
            missile = situation[1][missile_id]
            goal = predict_attack_target(missile, situation[0])
            airpok = 0.7  # 默认导弹的平均命中概率是0.7
            missile_value = 30  # 给导弹的默认价值为30
            threat = 80  # 给导弹的默认威胁度是80
            if goal:
                goal_location = [goal.get('lat', 0), goal.get('lon', 0)]
                for name, value in agent_obj.red_score_dict.items():
                    if name in goal.get('name', ''):
                        missile_value = value * airpok  # 拦截价值
                        threat = value * airpok  # 威胁值
                        if threat > 100:
                            threat = 100
                        break
            else:
                goal_location = None
            missile_info[missile_id] = {
                "name": missile.get("Name", ''),
                "type": "missile",
                "dbid": missile.get("DBID", 0),
                "id": missile_id,
                "location": [missile.get("Lat", 0), missile.get("Lon", 0)],
                "height": missile.get('CurrentAltitude', 0),
                "velocity": missile.get('CurrentSpeed', 0),  # 来袭速度，公里/时
                "head": missile.get('CurrentHeading', 0),  # 导弹航向角
                "theta_degrees": missile.get('CurrentHeading', 0) + 90,  # 目标来袭方向与正北方向的夹角
                "threat": threat,  # 来袭目标的威胁值，[0,100]
                "max_num_of_weapons": 2,  # 可分配给该目标的最多的武器数量
                "goal_location": goal_location,  # 目标进攻位置预测
                "disappear_time": missile.get("Age", 300),  # 目标消失时间
                "value": missile_value  # 目标打击价值
            }

    return missile_info


# 获取敌方飞机的态势信息
def get_plane_info(agent_obj, situation):
    """
    获取来袭飞机的信息
    :param agent_obj,
    :param situation: 部分态势信息
    :return: plane_info 目标信息
    """
    plane_info = {}
    plane_id_list = []
    # 获取情报接触的飞机guid列表, 排除已经进行打击的来袭飞机
    if isinstance(situation[1], dict):
        targets = (t for t in situation[1].items() if not isinstance(t, str))
        plane_id_list.extend(t[0] for t in targets if
                             (t[1].get('Type', -1) == ContactType.Air.value) and (
                                     t[0] not in agent_obj.delete_target_id_lst))
        # and (t[1].get('Side', '') == '蓝方' or t[1].get('IDStatus', -1) >= 3))
    if plane_id_list:
        for air_id in plane_id_list:
            air = situation[1][air_id]
            # air = self.get_contact(air_id)
            air_value = 30  # 给飞机的默认价值为30
            # 理论上需要# 当'IDStatus'值大于3时，飞机DBID信息才可用，但为了算法效果，先假设已知态势信息DBID
            # if air['IDStatus'] >= 3:
            air_data = qd.get_aircraft_info(air.get("DBID", 0))
            for name, value in agent_obj.blue_score_dict.items():
                if name in air_data[2]:
                    air_value = value
                    break
            # 根据飞机的类型计算飞机的威胁度，0-100
            air_type = get_air_type(air.get('Type', 0))
            threat = type_threat(air_type) * 100
            goal = get_closest_target(air, situation[0])
            if goal:
                goal_location = [goal.get('lat', 0), goal.get('lon', 0)]
            else:
                goal_location = None
            plane_info[air_id] = {
                "name": air.get("Name", ''),
                "type": "air",
                "dbid": air.get("DBID", 0),
                "id": air_id,
                "location": [air.get("Lat", 0), air.get("Lon", 0)],
                "height": air.get("CurrentAltitude", 0),
                "velocity": air.get("CurrentSpeed", 0),  # 来袭速度，公里/时
                "head": air.get("CurrentHeading", 0),  # 导弹航向角
                "theta_degrees": air.get("CurrentHeading", 0) + 90,  # 目标来袭方向与正北方向的夹角
                "threat": threat,  # 来袭目标的威胁值，[0,100]
                "max_num_of_weapons": 2,  # 可分配给该目标的最多的武器数量, 与来袭目标的威胁值有关，[0,10]
                "goal_location": goal_location,  # 目标进攻位置预测
                "disappear_time": air.get("Age", 300),  # 目标消失时间
                "value": air_value  # 目标打击价值
            }

    return plane_info


def fire_event_eval(agent_obj, situation):
    """
    这是开火事件的评估信息收集
    """
    # 以下是条令设置为谨慎开火时，同时有单元自动开火和wta智能体手动开火的评估信息收集
    # 查找self.group_guid单元列表的开火事件
    fire_event = agent_obj.KeyMsg.get(87, {})
    for event in (e for e in fire_event if e not in agent_obj.check_KeyMsg):
        if event.get('FiringUnitID', '') in agent_obj.group_guid:
            # 保存开火条目，用于检查该武器的销毁情况
            fire_item = [event.get('FiringUnitID', ''), event.get('WeaponDBID', ''),
                         event.get('TargetContactID', ''),event.get('WeaponID', '')]
            agent_obj.check_is_shoot_down.append(fire_item)
            # 检查是否有单元自动开火的目标，将其纳入已经开火的单元列表self.delete_target_id_lst中，避免重复开火
            target_guid = event.get('TargetContactID', '')
            if target_guid not in agent_obj.delete_target_id_lst:
                agent_obj.delete_target_id_lst.append(event.get('TargetContactID', ''))
            # 检查是否有单元自动开火的目标，将其纳入可打击单元列表self.attack_capable_target中，避免因为时间窗口遗漏目标
            # self.attack_capable_target是元组列表，包含(guid, t_value,)，所以这里还要计算自动开火的目标价值，
            # 这里的目标价值计算与后面的一致，可以看后面的注释
            if not any(item[0] == target_guid for item in agent_obj.attack_capable_target):
                if target_guid in agent_obj.situation_info.get("target_info", {}).keys():
                    t_value = agent_obj.situation_info.get("target_info", {}).get(target_guid, {}).get('value', 0)
                    agent_obj.attack_capable_target.append((event.get('TargetContactID', ''), t_value,))
                elif target_guid in situation[1].keys():
                    if situation[1][target_guid].get('Type', -1) == ContactType.Missile.value:
                        t_value = 30  # 给导弹的默认价值为30
                        missile = situation[1][target_guid]
                        goal = predict_attack_target(missile, situation[0])
                        airpok = 0.7  # 导弹的平均命中概率是0.7
                        if goal:
                            for name, value in agent_obj.red_score_dict.items():
                                if name in goal.get('name', ''):
                                    t_value = value * airpok  # 拦截价值
                                    break
                        agent_obj.attack_capable_target.append((event.get('TargetContactID', ''), t_value,))
                    elif situation[1][target_guid].get('Type', -1) == ContactType.Air.value:
                        air = situation[1][target_guid]
                        t_value = 30  # 给飞机的默认价值为30
                        air_data = qd.get_aircraft_info(air.get("DBID", 0))
                        for name, value in agent_obj.blue_score_dict.items():
                            if name in air_data[2]:
                                t_value = value
                                break
                        agent_obj.attack_capable_target.append((event.get('TargetContactID', ''), t_value,))
                    else:
                        # 因为这些单元还具有对地对潜打击能力，所以会打击其他类型的目标，需要把这个目标移除
                        agent_obj.check_is_shoot_down.remove(fire_item)
                        agent_obj.delete_target_id_lst.remove(event.get('TargetContactID', ''))
            # 如果该事件已经使用过，则不再重复检查该事件
            agent_obj.check_KeyMsg.append(event)

    # 以下为条令设置限制开火时，仅有wta智能体手动开火情况下的评估信息收集
    # fire_event = self.KeyMsg.get(87, {})
    # if len(self.wta_result) != 0:
    #     for wta_item in self.wta_result:
    #         for event in fire_event:
    #             # 判断武器开火事件中是否有下发武器分配方案，如果是，则在wta_item的末尾加上True标志，使得wta_item的长度由3增加到4
    #             if event.get('FiringUnitID', '') == wta_item[0] and event.get('WeaponDBID', '') == wta_item[1] \
    #                     and event.get('TargetContactID', '') == wta_item[2]:
    #                 wta_item.append(True)
    #                 break
    #         self.wta_log.append(wta_item)
    #         self.check_is_shoot_down.append(wta_item)


def weapon_des_event_eval(agent_obj):
    """
    这是武器销毁事件的评估信息收集
    """
    # 以下是条令设置为谨慎开火时，同时有单元自动开火和wta智能体手动开火的评估信息收集
    weapon_destroy = agent_obj.KeyMsg.get(99, {})
    for event in (e for e in weapon_destroy if e not in agent_obj.check_KeyMsg):
        for fire_item in agent_obj.check_is_shoot_down:
            # 检查目标ID
            if event.get('TargetContactID', '') == fire_item[2]:
                # reason字段含义：
                # 原因 1-被浪费(浪费的武器被定义为当在武器飞行之前或飞行过程中目标平台被其他方法损害) 2-被拦截 3-命中目标 4-未命中目标
                if event.get('Reason', -1) == 3:
                    # 目标单元被击中，添加到击中的目标列表中：
                    agent_obj.shoot_down_target.append(fire_item[2])
                elif event.get('Reason', -1) == 1 or event.get('Reason', -1) == 2 or event.get('Reason', -1) == 4:
                    # 目标单元没有被击中，则将该目标从已经删除的列表中移除，重新分配武器
                    if fire_item[2] in agent_obj.delete_target_id_lst:
                        agent_obj.delete_target_id_lst.remove(fire_item[2])
                agent_obj.check_is_shoot_down.remove(fire_item)
                agent_obj.check_KeyMsg.append(event)
                break

    # 以下为条令设置限制开火时，仅有wta智能体手动开火情况下的评估信息收集
    # weapon_destroy = self.KeyMsg.get(99, {})
    # for event in (e for e in weapon_destroy if e not in self.check_KeyMsg):
    #     for wta_item in self.check_is_shoot_down:
    #         # 需要wta下发的武器分配方案已经执行，并且武器销毁时间的单元ID、武器ID和目标ID均一致，才能判断该武器销毁事件来源于武器分配方案
    #         if len(wta_item) == 4 and event.get('FiringUnitID', '') == wta_item[0] and \
    #                 event.get('WeaponDBID', '') == wta_item[1] \
    #                 and event.get('TargetContactID', '') == wta_item[2]:
    #             # reason字段含义：
    #             # 原因 1-被浪费(浪费的武器被定义为当在武器飞行之前或飞行过程中目标平台被其他方法损害) 2-被拦截 3-命中目标 4-未命中目标
    #             if event.get('Reason', -1) == 3:
    #                 # 目标单元被击中，添加到击中的目标列表中：
    #                 self.shoot_down_target.append(wta_item[2])
    #             elif event.get('Reason', -1) == 1 or event.get('Reason', -1) == 2 or event.get('Reason', -1) == 4:
    #                 # 目标单元没有被击中，则将该目标从已经删除的列表中移除，重新分配武器
    #                 self.delete_target_id_lst.remove(wta_item[2])
    #             self.check_is_shoot_down.remove(wta_item)
    #             self.check_KeyMsg.append(event)
    #             break


def wta_close_dist(situation_info):
    """
    武器目标分配决策，基于最小距离原则和弹药分配策略
    :param situation_info: 包含单位信息和目标信息的态势字典
    :return: 分配完成的agent信息字典
    """
    unit_info = situation_info["unit_info"]
    target_info = situation_info["target_info"]

    # 初始化agent信息结构
    agent_info = {unit_id: [] for unit_id in unit_info}
    # 构建目标分配关系
    for target_id, target_data in target_info.items():
        # 寻找最近单位
        closest_unit_id, closest_unit = min(
            unit_info.items(),
            key=lambda item: get_horizontal_distance(item[1]["loc"], target_data["location"])
        )

        # 构建导弹信息
        missile_info = {
            "id": target_id,
            "interception_ability": closest_unit["interception_ability"][target_id],
            "interception_probability": closest_unit["interception_probability"][target_id],
            **{k: target_data[k] for k in (
                "max_num_of_weapons", "location", "velocity",
                "theta_degrees", "threat", "goal_location"
            )}
        }

        if missile_info["interception_ability"] == 1:
            agent_info[closest_unit_id].append(missile_info)

    # 弹药分配
    # 优先级排序逻辑：威胁等级 > 距离近的优先
    for unit_guid, targets in ((k, v) for k, v in agent_info.items() if v):
        target_sort = sorted(targets, key=lambda x: (
            -x['threat'], get_horizontal_distance(x['location'], x['goal_location'])))
        agent_info[unit_guid] = target_sort

        remaining_ammo = unit_info[unit_guid].get('ammunition_amount', 0)
        allocation_list = []
        # 第一轮分配：高优先级目标
        for target in target_sort:
            if remaining_ammo <= 0:
                break
            if target['threat'] >= 85:
                alloc = min(target['max_num_of_weapons'], remaining_ammo)
                target['alloc_ammo'] = alloc
                # self.wta_result[unit_guid][target['ID']] = alloc
                allocation_list.append(target['id'])
                remaining_ammo -= alloc

        # 第二轮分配：剩余目标
        for target in target_sort:
            if remaining_ammo <= 0:
                break
            if target['id'] not in allocation_list:
                alloc = min(1, remaining_ammo)
                target['alloc_ammo'] = alloc
                # self.wta_result[unit_guid][target['id']] = alloc
                allocation_list.append(target['id'])
                remaining_ammo -= alloc

    return agent_info


def get_horizontal_distance(geo_point1, geo_point2):
    """
    求两点的水平距离   Haversine公式
    :param geo_point1: tuple, (lat, lon), 例：(40.9, 140.0)
    :param geo_point2: tuple, (lat, lon), 例：(40.9, 142.0)
    :return: float, KM
    """
    pi = 3.1415926
    degree2radian = pi / 180.0
    earth_radius = 6371137  # 地球平均半径
    lat1 = geo_point1[0] * degree2radian
    lon1 = geo_point1[1] * degree2radian
    lat2 = geo_point2[0] * degree2radian
    lon2 = geo_point2[1] * degree2radian

    difference = lat1 - lat2
    mdifference = lon1 - lon2
    distance = 2 * math.asin(math.sqrt(math.pow(math.sin(difference / 2), 2)
                                       + math.cos(lat1) * math.cos(lat2)
                                       * math.pow(math.sin(mdifference / 2), 2)))
    distance = distance * earth_radius / 1000
    return distance


def run_attack_command(agent_obj):
    """
    下发打击指令，当条令设置为限制开火时，只能通过手动开火接口进行打击
    :param agent_obj
    """
    # 发射弹药数量
    qty_num = 1
    # 依次下发打击指令
    for attack_com in agent_obj.wta_result:
        unit_guid = attack_com[0]
        target_guid = attack_com[2]
        weapon_dbid = attack_com[1]
        unit = agent_obj.get_unit(unit_guid)
        unit.attack_weapon_allocate_to_target(
            target_guid,
            weapon_dbid,
            qty_num
        )
        print(f"单元{unit_guid}打击目标{target_guid}")
        agent_obj.agent_message_text(text=f"单元{unit_guid}打击目标{target_guid}")
        # 添加已经打击过的单元guid，避免重复打击
        agent_obj.delete_target_id_lst.append(target_guid)


def calculate_allocation(agent_obj, unit_info):
    """
    执行分配算法,该算法针对最近距离目标分配，暂时没有使用
    """
    for unit_guid, targets in ((k, v) for k, v in agent_obj.wta_result.items() if v):
        target_sort = sorted(targets, key=lambda x: (
            -x['threat'], get_horizontal_distance(x['location'], x['goal_location'])))
        agent_obj.wta_result[unit_guid] = target_sort

        remaining_ammo = unit_info[unit_guid].get('ammunition_amount', 0)
        allocation_list = []
        # 第一轮分配：高优先级目标
        for target in target_sort:
            if remaining_ammo <= 0:
                break
            if target['threat'] >= 85:
                alloc = min(target['max_num_of_weapons'], remaining_ammo)
                target['alloc_ammo'] = alloc
                # self.wta_result[unit_guid][target['ID']] = alloc
                allocation_list.append(target['id'])
                remaining_ammo -= alloc

        # 第二轮分配：剩余目标
        for target in target_sort:
            if remaining_ammo <= 0:
                break
            if target['id'] not in allocation_list:
                alloc = min(1, remaining_ammo)
                target['alloc_ammo'] = alloc
                # self.wta_result[unit_guid][target['id']] = alloc
                allocation_list.append(target['id'])
                remaining_ammo -= alloc
    return


def calculate_possibility(weapon_id, contact, contact_proficiency_level, unit_location,
                          target_location, unit_target_distance):
    """
    计算武器打击目标的概率
    :param: unit_target_distance, 海里
    """
    if "IDStatus" not in contact or "Type" not in contact:
        return 0
    IDStatus = contact.get("IDStatus", -1)
    contact_dbid = 0
    if IDStatus >= 3:
        contact_dbid = contact.get("DBID", 0)
    contact_type = contact.get("Type", 0)
    weapon_info = qd.get_weapon_info(weapon_id)
    attack_rate = 0  # 攻击成功概率
    if contact_type == 0 or contact_type == 1:
        # 目标是飞机或者导弹
        # 导弹打飞机： ActiveUnit_Weaponry.cs, method_29(float float_0, ref Weapon weapon_6)
        # Weapon.cs vmethod_146(ActiveUnit activeUnit_3, Scenario scenario_1, bool bool_34, ref List<string> list_4)
        air_range = weapon_info[0]
        if not air_range > 0:
            return 0
        base_pok = weapon_info[1]
        attack_rate = base_pok
        target_speed_max = weapon_info[8]
        if contact_type == 0:
            # 目标是飞机
            # Weapon.cs:   activeUnit_Weaponry5.method_29(float_52, ref weapon_);
            # 如果飞机致盲导弹失败  ActiveUnit_Weaponry.cs,5130行 obj2.vmethod_146(activeUnit_5, scenario_, false, ref list_2);
            if qd.is_missile(weapon_id):
                # 是导弹，发射导弹单元到目标距离对命中率修正
                launch_distance = unit_target_distance
                dis_change = launch_distance / air_range
                weapon_propulsion_type = qd.get_weapon_propulsion_type(weapon_id)
                if weapon_propulsion_type == 5001 or weapon_propulsion_type == 5003:
                    propulsion_rate = 0.5
                else:
                    propulsion_rate = 0.75
                if dis_change > propulsion_rate:
                    attack_rate = base_pok * propulsion_rate + base_pok * (1 - propulsion_rate) * \
                                  (1 - (dis_change - propulsion_rate) / (1 - propulsion_rate))
            if target_speed_max > 0:
                # 命中概率(经目标速度修正)
                speed_rate = 0
                target_speed = contact.get('CurrentSpeed', 0)
                if target_speed > target_speed_max:
                    speed_rate = 50
                elif target_speed > target_speed_max * 0.8:
                    speed_rate = 25
                elif target_speed > target_speed_max * 0.7:
                    speed_rate = 15
                elif target_speed > target_speed_max * 0.6:
                    speed_rate = 10
                elif target_speed > target_speed_max * 0.5:
                    speed_rate = 5
                if speed_rate:
                    attack_rate = round(attack_rate - speed_rate)
                    if attack_rate < 0:
                        attack_rate = 0
            if IDStatus >= 3:
                air_info = qd.get_aircraft_info(contact_dbid)
                agility, crew = air_info[3], air_info[8]
                weapon_detect_by_aircraft = True  # 默认飞机能够探测到攻击的导弹, 实际应根据是否已探测到进行判断
                if crew > 0 and agility > 0 and weapon_detect_by_aircraft:
                    actual_agility = get_aircraft_kinematics_agility(agility, contact, contact_proficiency_level)
                    # 假定飞机会90度左右躲避导弹，机动系数(强转向攻击无影响)
                    actual_agility = round(actual_agility, 1)
                    attack_rate = attack_rate - actual_agility * 10
            elif IDStatus >= 2:
                # if contact["IDEN"].startswith("类型："):
                #     type_name = contact["IDEN"][3:]  # 实体具体类型
                #     type_id = qd.get_unit_type_id("Aircraft", type_name)
                type_id = contact.get('ActType', -1)
                agility = estimate_agility(type_id)
                attack_rate = attack_rate - agility * 10
        if contact.get('CurrentAltitude', 0) > 0 and not qd.is_weapon_capableVSSeaskimmer(weapon_id):
            # 命中概率(掠海攻击修正量)
            target_alt = contact.get('CurrentAltitude', 0)
            sea_rate = 30
            if target_alt >= 91.44:
                sea_rate = 0
            elif target_alt >= 60.96:
                sea_rate = 5
            elif target_alt >= 30.48:
                sea_rate = 15
            attack_rate -= sea_rate
        if contact_type == 1:
            # 目标是导弹
            heading_sub = heading_sub_from_unit(contact, None, target_location, unit_location)
            heading_weak_rate = 1.0
            if heading_sub > 180:
                heading_sub = 360 - heading_sub
            if heading_sub > 90:
                heading_sub = 180 - heading_sub
            heading_weak_rate = 1 - heading_sub * 0.5 / 90
            attack_rate *= heading_weak_rate

            # 武器信号特征修正
            radar_signa, infra_signa = {}, {}
            section1 = qd.get_unit_signatures("Weapon", weapon_id)
            if 5002 in section1:
                radar_signa = section1[5002]
            if 4001 in section1:
                infra_signa = section1[4001]

            if 45 < heading_sub < 315:
                if 45 <= heading_sub <= 135 or 225 <= heading_sub <= 315:
                    signa_radar = get_missile_side_signa(contact_dbid, contact, unit_location, radar_signa, 5002)
                    signa_infra = get_missile_side_signa(contact_dbid, contact, unit_location, infra_signa, 4001)
                elif 135 < heading_sub < 225:
                    signa_radar = get_missile_rear_signa(contact_dbid, contact, unit_location, radar_signa, 5002)
                    signa_infra = get_missile_rear_signa(contact_dbid, contact, unit_location, infra_signa, 4001)
            else:
                signa_radar = get_missile_front_signa(contact_dbid, contact, unit_location, radar_signa, 5002)
                signa_infra = get_missile_front_signa(contact_dbid, contact, unit_location, infra_signa, 4001)
            # 武器速度和信号特征修改
            attack_rate = weapon_speed_signa_attack_rate(weapon_id, target_speed_max, attack_rate,
                                                         contact.get('CurrentSpeed', 0),
                                                         signa_radar, signa_infra)
        attack_rate /= 100
        if attack_rate < 0:
            attack_rate = 0
    return attack_rate


def estimate_agility(air_type):
    agility = 2
    if air_type == 2001 or air_type == 2002 or air_type == 3001 or air_type == 3002 or air_type == 3401:
        # 战斗机、攻击机
        agility = 4.5
    elif air_type == 3101:
        # 轰炸机
        agility = 2.0
    elif air_type == 4002 or air_type == 4003:
        # 预警机
        agility = 0.5
    elif air_type == 8201 or air_type == 8202:
        # 无人机
        agility = 0.5
    elif air_type == 4001 or 7003 <= air_type <= 7005:
        # 电子战或侦察机
        agility = 2.5
    return agility


def get_aircraft_kinematics_agility(agility, contact, proficiency_level):
    kinematics_agility = round(get_agility_with_alt(agility, contact), 1)
    if proficiency_level == 0:
        kinematics_agility *= 0.3
    elif proficiency_level == 1:
        kinematics_agility *= 0.5
    elif proficiency_level == 2:
        kinematics_agility *= 0.8
    elif proficiency_level == 3:
        pass
    elif proficiency_level == 4:
        kinematics_agility *= 1.2
    weight_rate = qd.get_load_weight_rate(contact.get("DBID", 0))
    kinematics_agility = 0.4 * kinematics_agility + 0.6 * kinematics_agility * (1 - weight_rate)
    return kinematics_agility


def get_agility_with_alt(agility, contact):
    alt = contact.get('CurrentAltitude', 0)
    if alt <= 3000:
        return agility
    weak_rate = 0.5
    air_dbid = contact.get("DBID", 0)
    if 4001 in qd.get_unit_codes("Aircraft", air_dbid):
        weak_rate = 0.25
    alt_to_max = qd.get_aircraft_ceiling(air_dbid) - contact.get('CurrentAltitude', 0)
    alt_to_3000 = contact.get('CurrentAltitude', 0) - 3000
    if alt_to_max < 0.001:
        agility_down = 9999999
        # print("aircraft_ceiling %d" % qd.get_aircraft_ceiling(air_dbid))
        # print("contact_altitude %d" % contact.get('CurrentAltitude', 0))
        # print(air_dbid)
    else:
        agility_down = agility * weak_rate * (alt_to_3000 / alt_to_max)
    return max(agility * (1 - weak_rate), agility - agility_down)


def heading_sub_from_unit(contact, unit, target_location=None, unit_location=None):
    contact_heading = contact.get('CurrentHeading', 0)
    if target_location is None:
        target_location = contact.get("Lat", 0), contact.get("Lon", 0)
    if unit_location is None:
        unit_location = unit.get("Lat", 0), unit.get("Lon", 0)
    to_unit_head = geo.get_azimuth(target_location, unit_location)
    return geo.normal_angle(to_unit_head - contact_heading)


def get_missile_side_signa(missile_id, missile_contact, unit_location, unit_signa_info, signa_type):
    """
    返回前信号特征
    xsection.cs
    float method_9(ActiveUnit activeUnit_0)
    """
    signa = 1
    if "Side" in unit_signa_info:
        signa = unit_signa_info["Side"]
    if signa_type != 3001 and signa_type != 4001:
        return signa
    engine_type = qd.get_weapon_propulsion_type(missile_id)
    if engine_type == 5001:
        # 火箭发动机引擎
        signa *= 5
    alt_rate = signa_alt_speed_rate(missile_contact.get('CurrentAltitude', 0), missile_contact.get('Current'
                                                                                                   ''
                                                                                                   'Speed', 0))
    if alt_rate > 1.0:
        signa *= alt_rate
    return signa


def signa_alt_speed_rate(alt, speed):
    """
    xsection.cs
    float smethod_0(double double_0, double double_1)
    """
    a_rate = 1718
    alt_m = alt / 0.3048
    if alt_m <= 36152.0:
        alt_m = 518.6 - 3.56 * alt_m / 1000
    elif 36152.0 < alt_m <= 82345.0:
        alt_m = 389.98
    elif 82345.0 < alt_m <= 155348.0:
        alt_m = 389.98 + 1.645 * (alt_m - 82345.0) / 1000.0
    elif 155348.0 < alt_m <= 175346.0:
        alt_m = 508.788
    elif 175346.0 < alt_m <= 262448.0:
        alt_m = 508.788 - 2.46888 * (alt_m - 175346.0) / 1000.0
    else:
        alt_m = 508.788
    num4 = math.sqrt(1.4 * a_rate * alt_m)
    num4 = num4 * 60.0 * 0.8689755962687 / 88.0
    return speed / num4


def get_missile_rear_signa(missile_id, missile_contact, unit_location, unit_signa_info, signa_type):
    """
    返回前信号特征
    xsection.cs
    float method_9(ActiveUnit activeUnit_0)
    """
    signa = 1
    if "Rear" in unit_signa_info:
        signa = unit_signa_info["Rear"]
    if signa_type != 3001 and signa_type != 4001:
        return signa
    engine_type = qd.get_weapon_propulsion_type(missile_id)
    if engine_type == 5001:
        # 火箭发动机引擎
        signa *= 5
    alt_rate = signa_alt_speed_rate(missile_contact.get('CurrentAltitude', 0), missile_contact.get('CurrentSpeed', 0))
    if alt_rate > 1.0:
        signa *= alt_rate
    return signa


def get_missile_front_signa(missile_id, missile_contact, unit_location, unit_signa_info, signa_type):
    """
    返回前信号特征
    xsection.cs
    float method_9(ActiveUnit activeUnit_0)
    """
    front_signa = 1
    if "Front" in unit_signa_info:
        front_signa = unit_signa_info["Front"]
    if signa_type != 3001 and signa_type != 4001:
        return front_signa
    engine_type = qd.get_weapon_propulsion_type(missile_id)
    if engine_type == 5001:
        # 火箭发动机引擎
        front_signa *= front_signa
    alt_rate = signa_alt_speed_rate(missile_contact.get('CurrentAltitude', 0), missile_contact.get('CurrentSpeed', 0))
    if alt_rate > 1.0:
        front_signa *= alt_rate
    return front_signa


def weapon_speed_signa_attack_rate(weapon_id, target_speed_max, attack_rate, target_speed, radar_signa, infra_signa):
    """
    武器拦截，速度和信号特征修改正
    Weapon.cs
    float method_273(int int_21, float float_52, float float_53, float float_54, StringBuilder stringBuilder_0)
    """
    weak_rate = 0
    if target_speed > target_speed_max:
        weak_rate = 50
    elif target_speed > target_speed_max * 0.8:
        weak_rate = 25
    elif target_speed > target_speed_max * 0.7:
        weak_rate = 15
    elif target_speed > target_speed_max * 0.6:
        weak_rate = 10
    elif target_speed > target_speed_max * 0.5:
        weak_rate = 5
    # 信号特征修正---之后有时间修改
    return max(1, attack_rate - weak_rate)


def get_attack_prob(unit_info, target_info):
    """
    筛选可以打击的目标和对应的武器，并生成拦截概率矩阵和对应的索引
    """
    p_matrix = []
    a_matrix = []
    weapon_list = []
    for unit_guid, unit in unit_info.items():
        for weapon_dbid, weapon in unit.items():
            a_matrix.append(list(weapon['interception_ability'].values()))
            p_matrix.append(list(weapon['interception_probability'].values()))
            weapon_list.append((unit_guid, weapon_dbid))

    p_matrix = np.array(p_matrix)
    a_matrix = np.array(a_matrix)

    # 获取非零行索引
    non_zero_rows = np.where(a_matrix.any(axis=1))[0]
    # 获取非零列索引
    non_zero_cols = np.where(a_matrix.any(axis=0))[0]

    # 过滤矩阵
    a_encode = a_matrix[non_zero_rows][:, non_zero_cols]
    p_encode = p_matrix[non_zero_rows][:, non_zero_cols]

    m = p_encode.shape[0]
    n = p_encode.shape[1]

    target_list = list(target_info.keys())
    available_alloc = {'weapon_list': weapon_list,
                       'target_list': target_list,
                       'non_zero_rows': non_zero_rows,
                       'non_zero_cols': non_zero_cols}
    # 获取每个目标的拦截价值
    ori_v = []
    # p_matrix = np.random.rand(m, n) * 0.8 + 0.1  # 拦截概率在0.1-0.9之间
    for target in target_info.values():
        ori_v.append(target['value'])
    v = np.array(ori_v)
    v_encode = v[non_zero_cols]
    # 将价值归一化到0-10之间
    if v_encode.size > 0:
        normal_v = v_encode / np.max(v_encode) * 10
    else:
        normal_v = np.random.randint(1, 10, n)  # 目标价值1-10
    ammo = np.ones(m, dtype=int)  # 每个武器最多使用1次

    return p_encode, a_encode, normal_v, ammo, available_alloc


def process_matrix_solution(solution, available_alloc):
    """
    将打击方案从矩阵格式转换为单元-武器-目标对
    :param, solution: array, 己方单元*来袭目标
    :param, unit_info  己方单元信息
    :param, target_info  目标单元信息
    :return, wta_result list，[[己方单元guid,己方单元武器dbid,打击目标单元dbid]]
    [['zjj7bf-0hn937r864rgp', 3467, 'iqfusl-0hnbn7gvic3bk'],
    ['zjj7bf-0hn9c004pi181', 3467, 'iqfusl-0hnbn7gvj5vka'],
    ['zjj7bf-0hn937r864oab', 1426, 'iqfusl-0hnbn7gvhqgrv'],
    ['zjj7bf-0hn937r864pqu', 1426, 'iqfusl-0hnbn7gvifptj'],
    ['zjj7bf-0hn937r864pqu', 3105, 'iqfusl-0hnbn7gvjicih']]
    """
    # 转换后的打击方案
    wta_result = []
    # 根据打击矩阵获得对应的单元、武器和目标
    indices = np.argwhere(solution == 1)
    for pair in indices:
        weapon_index = available_alloc['non_zero_rows'][pair[0]]
        target_index = available_alloc['non_zero_cols'][pair[1]]
        wta_result.append(
            [available_alloc['weapon_list'][weapon_index][0], available_alloc['weapon_list'][weapon_index][1],
             available_alloc['target_list'][target_index]])
    return wta_result


def save_situation_2json(agent_obj):
    # 将初始时刻的态势保存到智能体目录的situation_data.json文件中
    situation_info = agent_obj.situation_info
    file_path = os.path.join(os.path.dirname(__file__), 'situation_data.json')
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(situation_info, f, indent=4)
        print("保存初始态势至文件%s" % file_path)
    return


def predict_attack_target(missile, units):
    """
    寻找距离导弹航线最近的防御单元

    参数:
    missile_lat (float): 导弹纬度坐标
    missile_lon (float): 导弹经度坐标
    heading_deg (float): 导弹航向角（0-360度，正北为0顺时针）
    units (list[dict]): 防御单元列表，每个单元包含'latitude'和'longitude'

    返回:
    dict: 距离航线最近的单元信息（若无单元返回None）
    """
    R = 6371000  # 地球半径（米）

    missile_lat = missile.get('Lat', 0)
    missile_lon = missile.get('Lon', 0)
    heading_deg = missile.get('CurrentHeading', 0)

    # 计算导弹航向方向向量（单位向量）
    theta_rad = math.radians(heading_deg)
    dx = math.sin(theta_rad)  # 东方向分量
    dy = math.cos(theta_rad)  # 北方向分量

    # 保存最近结果
    closest_distance = None
    closest_unit = None

    # 计算导弹位置的纬度半径缩放因子
    cos_m_lat = math.cos(math.radians(missile_lat))

    for unit in units.values():
        # 将单元坐标转换为相对导弹位置的平面坐标
        dlat = (unit.get('lat', 0) - missile_lat) * math.pi / 180
        dlon = (unit.get('lon', 0) - missile_lon) * math.pi / 180

        # 转换为平面坐标（米）
        x = dlon * R * cos_m_lat
        y = dlat * R

        # 计算点到航线的垂直距离
        distance = abs(dy * x - dx * y)

        # 更新最近单元
        if (closest_distance is None) or (distance < closest_distance):
            closest_distance = distance
            closest_unit = unit

    return closest_unit


def get_closest_target(air, units):
    """
    寻找距离飞机最近的单元作为其目标单元
    params:
    air 来袭飞机单元
    units (list[dict]): 防御单元列表，每个单元包含'latitude'和'longitude'
    返回:
    dict: 距离航线最近的单元信息（若无单元返回None）
    """
    # 保存最近结果
    closest_distance = None
    closest_unit = None
    for unit in units.values():
        # 计算两点距离
        distance = geo.get_horizontal_distance((unit.get('lat', 0), unit.get('lon', 0)),
                                               (air.get('Lat', 0), air.get('Lon', 0)))
        # 更新最近单元
        if (closest_distance is None) or (distance < closest_distance):
            closest_distance = distance
            closest_unit = unit
    return closest_unit


def get_air_type(air_type):
    """
    获得飞机类别分类
    参数：
        air_type:态势中的飞机单元分类
    返回：
        态势飞机的分类
    """
    if air_type == 3101:
        # 轰炸机
        aircraft_type = 'bomber'
    elif air_type == 4002 or air_type == 4003:
        # 预警机、指挥机 （ACP）
        aircraft_type = 'awacs'
    elif air_type == 4001 or 7003 <= air_type <= 7005:
        # 电子情报（Electronic Intelligence)、电子战、侦察、区域监视、海上巡逻
        aircraft_type = 'elint'
    elif air_type == 8001:
        # 加油机
        aircraft_type = 'tanker'
    elif air_type == 3001 or air_type == 3002:
        # 攻击机、防空压制
        aircraft_type = 'attack'
    elif air_type == 2001 or air_type == 2002 or air_type == 3401:
        # 战斗机、多用途飞机、战场空中拦截（BAI/ CAS）  2001 2002 3401
        aircraft_type = 'fighter'
    elif air_type == 6001:
        # 反潜作战
        aircraft_type = 'asw'
    elif air_type == 8201 or air_type == 8202:
        # 无人机、无人作战飞行器
        aircraft_type = 'uav'
    else:
        aircraft_type = "unknown"
    return aircraft_type


def type_threat(aircraft_type):
    """
    根据目标的价值得分表计算出的目标类型威胁度字典type_threat_dict
    威胁度字典可自定义

    参数：
        aircraft_type:飞机类别
    返回：
        飞机类别对应的类别威胁度
    """
    type_threat_dict = {
        'bomber': 1,  # 轰炸机
        'awacs': 0.42,  # 预警机
        'elint': 0.51,  # 电子情报（Electronic Intelligence)、电子战、侦察、区域监视、海上巡逻
        'tanker': 0.41,  # 加油机
        'attack': 0.125,  # 攻击机
        'fighter': 0.15,  # 战斗机
        'asw': 0.0625,  # 反潜作战
        'uav': 0.05  # 无人机
    }
    return type_threat_dict.get(aircraft_type.lower(), 0.5)  # 默认威胁值设置为0.5


def pre_check_weapon_fire_condition(weapon_info, target, unit_target_d):
    """
    预开火条件检查
    :params weapon_info 己方武器信息
    :params target，目标信息
    :params unit_target_d，单元和目标的距离，单位 海里
    """
    intercept_radius = weapon_info[0]  # 拦截半径，单位: 海里  ,weapon_info的第一个是AirRangeMax拦截范围
    target_speed_max = weapon_info[8]  # 最大拦截速度，单位：海里/消失
    target_altitude_max = weapon_info[9]  # 最大拦截高度，单位：米
    target_altitude_min = weapon_info[10]  # 最小拦截高度，单位：米
    altitude = target.get('CurrentAltitude', 0)
    speed = target.get('CurrentSpeed', 99999999)
    if unit_target_d < intercept_radius and target_altitude_min < altitude < target_altitude_max \
            and speed < target_speed_max:
        fire_flag = True
    else:
        fire_flag = False
    return fire_flag


def append_to_excel(file_path, data):
    """
    向excel文件追加数据
    :param file_path str 文件名称，
    :param data list 待追加的数据
    """
    wb = load_workbook(file_path)
    ws = wb["Sheet"]
    # 写入新数据
    ws.append(data)
    wb.save(file_path)
    return
