# _*_ coding: utf-8 _*_
# @Author : wm
# @Date : 2024-10-22
# @Version :V0.01
# @description :  生成特定格式的态势数据
import random
from discrete_simulation.time_window_calculation import *
from tools.common_functions import *
from agent.MOPSO import BinaryMOPSO
import numpy as np

# Scenario configuration (module-level values read by the generator functions below)
num_units = 5  # number of ground-to-air missile battalions ("units")
num_targets = 50  # number of incoming targets
cur_time = 0  # current simulation time (presumably seconds, since disappear_time adds hours * 3600 -- confirm)
period_duration = 600  # duration of the dynamic phase (unit not shown here -- confirm)
fire_control_channel_num = 2  # fire-control channels per unit
latitude_range = [35.9, 43.2]  # latitude range of the scenario area, [min, max]
longitude_range = [115.3, 121.2]  # longitude range of the scenario area, [min, max]


# Line/circle intersection used to decide whether a unit can engage a target
def find_intersection_points(theta_degrees, location, center, radius):
    """
    Intersect a target's ground track with a unit's interception circle.

    The computation is done in a flat latitude/longitude plane: ``radius`` is
    converted from kilometres to degrees with a fixed 111 km per degree, and
    the track is the infinite straight line through ``location`` along the
    heading ``theta_degrees``.

    The line is solved in parametric form (start point + t * direction)
    instead of slope/intercept form. The slope form needs tan() of the
    heading, which blows up for tracks heading due north or south and makes
    the result numerically meaningless; the parametric form is well behaved
    for every heading.

    :param theta_degrees: heading of the incoming target in degrees, measured
        clockwise from north (90 means due east).
    :param location: position of the incoming target, ``[latitude, longitude]``.
    :param center: position of the unit, ``[latitude, longitude]``.
    :param radius: interception radius of the unit, in kilometres.
    :return: ``[]`` when the line misses the circle, otherwise a list of two
        ``[latitude, longitude]`` points. The point with the larger longitude
        comes first (ties are broken by the larger latitude). A tangent line
        yields the same point twice.
    """
    center_lon, center_lat = center[1], center[0]
    start_lon, start_lat = location[1], location[0]
    radius_deg = radius / 111  # kilometres -> degrees

    # Unit direction vector of the track in (longitude, latitude) axes.
    heading = math.radians(theta_degrees)
    dir_lon = math.sin(heading)
    dir_lat = math.cos(heading)

    # Substitute P(t) = start + t * dir into |P - center|^2 = r^2.
    # Because dir has unit length this reduces to t^2 + 2*half_b*t + c = 0.
    offset_lon = start_lon - center_lon
    offset_lat = start_lat - center_lat
    half_b = dir_lon * offset_lon + dir_lat * offset_lat
    c = offset_lon ** 2 + offset_lat ** 2 - radius_deg ** 2
    quarter_discriminant = half_b ** 2 - c

    # A negative discriminant means the line never touches the circle.
    if quarter_discriminant < 0:
        return []

    root = math.sqrt(quarter_discriminant)
    points = [
        [start_lat + t * dir_lat, start_lon + t * dir_lon]
        for t in (-half_b + root, -half_b - root)
    ]
    # Keep the ordering callers have always received: larger longitude first.
    points.sort(key=lambda point: (point[1], point[0]), reverse=True)
    return points


def generate_unit_location():
    """
    Randomly place ``num_units`` units inside the scenario area.

    The latitude range is cut into ``num_units`` equal bands and one unit is
    placed in each band; the longitude is drawn from the fixed strip
    117.1 - 118.7.

    :return: dict keyed by unit id (``u00001``, ``u00002``, ...) whose values
        hold the unit position ``loc`` as ``[latitude, longitude]`` and its
        ``intercept_radius`` in kilometres.
    """
    lat_low = latitude_range[0]
    lat_high = latitude_range[1]
    band_width = (lat_high - lat_low) / num_units

    units = {}
    for band in range(num_units):
        # Latitude is drawn first, then longitude, inside this unit's own band.
        lat = random.uniform(lat_low + band_width * band, lat_low + band_width * (band + 1))
        lon = random.uniform(117.1, 118.7)
        units[f"u{band + 1:05}"] = {
            "loc": [lat, lon],
            "intercept_radius": 80,  # kilometres
        }
    return units


def find_point_randomly(i, intercept_radius):
    """
    Draw a random start position for a target in latitude band ``i``.

    :param i: zero-based index of the latitude band (one band per unit).
    :param intercept_radius: interception radius in kilometres; the point is
        placed at least this far (converted at 111 km per degree) east of
        longitude 118.7.
    :return: random point ``[latitude, longitude]`` inside that region.
    """
    lat_low = latitude_range[0]
    band_width = (latitude_range[1] - lat_low) / num_units

    lat = random.uniform(lat_low + band_width * i, lat_low + band_width * (i + 1))
    lon = random.uniform(118.7 + intercept_radius / 111, longitude_range[1])
    return [lat, lon]


# Build the situation data of the incoming targets
def generate_target_info(unit_info):
    """
    Randomly generate the incoming targets, spread over the units' latitude bands.

    ``num_targets`` is split across the units with ``divide_into_even_parts``.
    Every target starts east of the unit strip (see ``find_point_randomly``)
    and flies towards a random vanishing point west of it, within the same
    latitude band.

    :param unit_info: unit positions and interception radii, keyed by unit id.
    :return: dict keyed by target id (``t00001``, ...) describing each target.
    """
    lat_low = latitude_range[0]
    band_width = (latitude_range[1] - lat_low) / num_units
    per_unit_counts = divide_into_even_parts(num_targets, num_units)

    targets = {}
    serial = 1
    for band, unit in enumerate(unit_info.values()):
        radius_km = unit['intercept_radius']
        for _ in range(per_unit_counts[band]):
            # Start far enough away that the target belongs to this unit but
            # still leaves enough time for an interception.
            start = find_point_randomly(band, radius_km)

            threat = random.randint(0, 100)  # threat value, [0, 100]
            weapon_cap = threat // 10 + 1  # grows with the threat value
            vanish_lat = random.uniform(lat_low + band_width * band,
                                        lat_low + band_width * (band + 1))
            vanish_lon = random.uniform(longitude_range[0], 117.1 - radius_km / 111)
            goal = [vanish_lat, vanish_lon]  # point where the target disappears
            speed = 500  # km/h
            heading = cal_azi(start, goal)  # heading, angle from north
            travel = get_horizontal_distance(start, goal)

            tid = f"t{serial:05}"
            targets[tid] = {
                "name": f"missile_{serial}",
                "id": tid,
                "location": start,
                "height": random.randint(100, 100),
                "value": random.randint(10, 100),
                "velocity": speed,  # km/h
                "theta_degrees": heading,  # heading, angle from north
                "threat": threat,  # threat value, [0, 100]
                "max_num_of_weapons": weapon_cap,  # max weapons assignable to this target
                "goal_location": goal,  # vanishing point
                "disappear_time": cur_time + travel / speed * 3600  # time the target disappears
            }
            serial += 1
    return targets


# Build the full description of every unit
def generate_unit_info(target_info, unit_location):
    """
    Complete the unit records and rate every unit against every target.

    ``unit_location`` is deep-copied, so the caller's dict is not modified.
    Each unit gets fixed/random equipment attributes plus two tables keyed by
    target id:

    * ``interception_ability``: 1 if the target's track crosses the unit's
      interception circle (see ``find_intersection_points``), else 0.
    * ``interception_probability``: a random value from
      ``random.uniform(0, 1)`` when the ability is 1, else 0.

    Both tables are always attached, even when ``target_info`` is empty, so
    later lookups of these keys cannot fail on a target-less scenario.

    :param target_info: target records keyed by target id.
    :param unit_location: unit positions and radii keyed by unit id.
    :return: the completed ``unit_info`` dict.
    """
    unit_info = copy.deepcopy(unit_location)
    for unit in unit_info.values():
        unit["height"] = random.randint(100, 100)  # unit altitude (always 100)
        unit["fire_control_channel_num"] = fire_control_channel_num
        unit["ammunition_amount"] = random.randint(0, 10)  # available rounds, [0, 10]
        unit["trajectory_speed"] = 800  # interceptor speed, km/h

        ability = {}
        probability = {}
        for target_id, target in target_info.items():
            crossings = find_intersection_points(target['theta_degrees'], target['location'],
                                                 unit['loc'], unit['intercept_radius'])
            if crossings:
                ability[target_id] = 1
                probability[target_id] = random.uniform(0, 1)
            else:
                ability[target_id] = 0
                probability[target_id] = 0

        unit["interception_ability"] = ability
        unit["interception_probability"] = probability
    return unit_info


def generate_agent_info(situation_info):
    """
    Assign every target to its nearest unit and compute interception windows.

    A target is listed under the unit with the smallest horizontal distance
    to it, but only when that unit's ``interception_ability`` for the target
    is 1. The resulting mapping is then passed through
    ``intercept_time_window_calculation``.

    :param situation_info: situation dict with ``unit_info`` and ``target_info``.
    :return: the value returned by ``intercept_time_window_calculation``.
    """
    unit_info = situation_info["unit_info"]
    target_info = situation_info["target_info"]

    agent_info = {unit_id: [] for unit_id in unit_info}

    for target_id, target in target_info.items():
        position = target['location']
        nearest_unit_id, nearest_unit = min(
            unit_info.items(),
            key=lambda item: get_horizontal_distance(item[1]["loc"], position),
        )
        entry = {
            "id": target_id,
            "interception_ability": nearest_unit["interception_ability"][target_id],
            "interception_probability": nearest_unit["interception_probability"][target_id],
            "max_num_of_weapons": target["max_num_of_weapons"],
            "location": target["location"],
            "velocity": target["velocity"],
            "theta_degrees": target["theta_degrees"],
        }
        if entry["interception_ability"] == 1:
            agent_info[nearest_unit_id].append(entry)

    return intercept_time_window_calculation(cur_time, agent_info, unit_info)


def generate_agent_info_alg(situation_info):
    """
    Build ``agent_info`` from a weapon-target assignment found by BinaryMOPSO.

    The solver receives the unit-by-target interception probability matrix,
    the target values rescaled so the largest becomes 10, and each unit's
    ammunition amount. Every non-zero entry of the returned solution assigns
    the target in that column to the unit in that row; an assignment is kept
    only when the unit's ``interception_ability`` for the target is 1.

    :param situation_info: situation dict with ``unit_info`` and ``target_info``.
    :return: the value returned by ``intercept_time_window_calculation``.
    """
    unit_info = situation_info["unit_info"]
    target_info = situation_info["target_info"]

    # Rows follow unit_info order, columns follow each unit's probability table.
    p_matrix = np.array([list(unit['interception_probability'].values())
                         for unit in unit_info.values()])
    ammo = np.array([unit['ammunition_amount'] for unit in unit_info.values()])

    values = np.array([target['value'] for target in target_info.values()])
    if values.size > 0:
        normal_V = values / np.max(values) * 10
    else:
        print("无可打击目标或者目标没有价值属性")
        # Fall back to random target values in [1, 9], one per matrix column.
        normal_V = np.random.randint(1, 10, p_matrix.shape[1])

    solver = BinaryMOPSO(p_matrix, normal_V, ammo, swarm_size=20, max_iter=10)
    best_solution, _ = solver.run()

    target_id_list = list(target_info.keys())
    agent_info = {}
    for row, (unit_id, unit) in enumerate(unit_info.items()):
        assigned = []
        for col in np.where(best_solution[row])[0]:
            target_id = target_id_list[col]
            target = target_info[target_id]
            entry = {
                "id": target_id,
                "interception_ability": unit["interception_ability"][target_id],
                "interception_probability": unit["interception_probability"][target_id],
                "max_num_of_weapons": target["max_num_of_weapons"],
                "location": target["location"],
                "velocity": target["velocity"],
                "theta_degrees": target["theta_degrees"],
            }
            if entry["interception_ability"] == 1:
                assigned.append(entry)
        agent_info[unit_id] = assigned

    return intercept_time_window_calculation(cur_time, agent_info, unit_info)


def generate_agent_info_alg_dynamic(situation_info, delete_target):
    """
    Dynamic variant of the BinaryMOPSO assignment that skips removed targets.

    Targets that appear in the units' interception tables but are no longer
    present in ``target_info`` are treated as deleted. Before solving, the
    matrices are reduced to

    * units that can intercept at least one target (non-zero ability row), and
    * targets that are still present and interceptable by at least one unit.

    The ammunition vector is reduced to the same units so that it stays
    aligned with the rows handed to the solver.

    NOTE: this function modifies ``situation_info`` in place: filtered-out
    targets are popped from ``target_info`` and units with no interceptable
    target are popped from ``unit_info``.

    :param situation_info: situation dict with ``unit_info`` and ``target_info``.
    :param delete_target: accepted for interface compatibility but ignored; the
        set of deleted targets is always derived from ``target_info`` itself
        (the previous implementation discarded this argument as well).
    :return: the value returned by ``intercept_time_window_calculation``. The
        mapping passed to it has an entry (possibly empty) for every unit that
        was present on entry.
    """
    unit_info = situation_info["unit_info"]
    target_info = situation_info["target_info"]

    # Rows/columns are mapped back to ids through the actual dict keys, so the
    # mapping stays correct even after units or targets were removed earlier.
    unit_ids = list(unit_info)
    column_ids = list(unit_info[unit_ids[0]]['interception_probability']) if unit_ids else []

    agent_info = {unit_id: [] for unit_id in unit_ids}
    p_matrix = np.array([list(unit_info[unit_id]['interception_probability'].values())
                         for unit_id in unit_ids])
    a_matrix = np.array([list(unit_info[unit_id]['interception_ability'].values())
                         for unit_id in unit_ids])
    ammo = np.array([unit_info[unit_id]['ammunition_amount'] for unit_id in unit_ids])

    # Units that can intercept something / targets that are alive and reachable.
    row_usable = a_matrix.any(axis=1)
    col_present = np.array([target_id in target_info for target_id in column_ids], dtype=bool)
    col_usable = a_matrix.any(axis=0) & col_present

    keep_rows = np.where(row_usable)[0]
    keep_cols = np.where(col_usable)[0]

    # Reduce the solver inputs consistently.
    p_encode = p_matrix[keep_rows][:, keep_cols]
    ammo_encode = ammo[keep_rows]

    # Drop everything that was filtered out from the shared situation data.
    for col in np.where(~col_usable)[0]:
        target_info.pop(column_ids[col], None)
    for row in np.where(~row_usable)[0]:
        unit_info.pop(unit_ids[row], None)

    # Target ids in exactly the column order of p_encode.
    target_id_list = [column_ids[col] for col in keep_cols]
    V_encode = np.array([target_info[target_id]['value'] for target_id in target_id_list])
    if V_encode.size > 0:
        normal_V = V_encode / np.max(V_encode) * 10
    else:
        print("无可打击目标或者目标没有价值属性")
        # Fall back to random target values in [1, 9], one per matrix column.
        normal_V = np.random.randint(1, 10, p_encode.shape[1])

    solver = BinaryMOPSO(p_encode, normal_V, ammo_encode, swarm_size=20, max_iter=10)
    best_solution, _ = solver.run()

    for row, unit_index in enumerate(keep_rows):
        unit_id = unit_ids[unit_index]
        unit = unit_info[unit_id]
        for col in np.where(best_solution[row])[0]:
            target_id = target_id_list[col]
            target = target_info[target_id]
            entry = {
                "id": target_id,
                "interception_ability": unit["interception_ability"][target_id],
                "interception_probability": unit["interception_probability"][target_id],
                "max_num_of_weapons": target["max_num_of_weapons"],
                "location": target["location"],
                "velocity": target["velocity"],
                "theta_degrees": target["theta_degrees"],
            }
            if entry["interception_ability"] == 1:
                agent_info[unit_id].append(entry)

    return intercept_time_window_calculation(cur_time, agent_info, unit_info)


# Assemble the complete situation
def generate_situation_info():
    """
    Generate a full random situation.

    Units are placed first, targets are generated around them, and the unit
    records are then completed against those targets.

    :return: dict with ``cur_time``, ``period_duration``, ``target_info`` and
        ``unit_info``.
    """
    unit_location = generate_unit_location()
    target_info = generate_target_info(unit_location)
    unit_info = generate_unit_info(target_info, unit_location)
    return {
        "cur_time": cur_time,
        "period_duration": period_duration,
        "target_info": target_info,
        "unit_info": unit_info,
    }


def agent_decision_proccess(agent_decision, situation_info):
    """
    Convert an agent's decision matrix into the ``agent_info`` dict format.

    Row ``i`` of the matrix refers to unit ``u{i+1:05}`` and column ``j`` to
    target ``t{j+1:05}``, the same zero-padded ids the generators produce. A
    positive entry assigns that target to that unit, and the entry itself is
    stored as ``max_num_of_weapons``. The unit's ``interception_ability`` is
    copied into the entry but deliberately not checked: the decision is
    assumed to respect it already.

    :param agent_decision: 2-D sequence, one row per unit and one column per
        target.
    :param situation_info: situation dict with ``unit_info`` and ``target_info``.
    :return: the value returned by ``intercept_time_window_calculation``.
    :raises ValueError: if the matrix dimensions do not match the number of
        units and targets in the situation (including an empty matrix for a
        non-empty situation).
    """
    unit_info = situation_info['unit_info']
    target_info = situation_info['target_info']

    agent_info = {unit_id: [] for unit_id in unit_info}

    # The matrix dimensions must match the situation before it is interpreted.
    unit_count = len(agent_decision)
    target_count = len(agent_decision[0]) if unit_count else 0
    if unit_count != len(unit_info) or target_count != len(target_info):
        raise ValueError("决策结果与态势信息不符，请检查！")

    for i, row in enumerate(agent_decision):
        unit_id = f"u{i + 1:05}"
        for j, weapon_count in enumerate(row):
            if weapon_count > 0:
                target_id = f"t{j + 1:05}"
                target = target_info[target_id]
                agent_info[unit_id].append({
                    "id": target_id,
                    "interception_ability": unit_info[unit_id]["interception_ability"][target_id],
                    "interception_probability": unit_info[unit_id]["interception_probability"][target_id],
                    "max_num_of_weapons": weapon_count,
                    "location": target["location"],
                    "velocity": target["velocity"],
                    "theta_degrees": target["theta_degrees"],
                })

    return intercept_time_window_calculation(cur_time, agent_info, unit_info)


# Situation data. NOTE: this runs at import time, so importing the module
# generates a fresh random situation.
situation_info = generate_situation_info()
# Agent info derived from that situation (nearest-unit assignment)
agent_info = generate_agent_info(situation_info)

# Hard-coded sample decision matrix with five rows.
# NOTE(review): presumably one row per unit and one column per target, in the
# format agent_decision_proccess() consumes -- confirm against the caller.
agent_dscision = [
    [1, 6, 0, 6, 3, 0, 0, 3, 30, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 25, 3, 0, 0, 32, 0, 0, 0, 0, 0, 11, 24, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0,
     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12, 24, 0, 0, 0, 0, 0, 0, 0, 0,
     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 12, 8, 9, 23, 4, 0, 0,
     12, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35,
     0, 0, 0, 0, 0, 0, 6, 13, 0, 0, 15, 21]
]
