import math
import random
import matplotlib.pyplot as plt
import numpy as np
import copy

random.seed(0)
np.random.seed(0)

x_max=6
x_min=0
J_Num=5
I_Num=5
Threat_v=[random.random() * 100 for i in range(J_Num)]
Weapon_n=[random.randint(0,100) for i in range(I_Num)]
Crack_res=[random.randint(0,1000) for i in range(J_Num)]
Feasible_ij=np.random.randint(0,2,size=(I_Num,J_Num))
Interception_Pro=np.random.rand(I_Num,J_Num)
X_num=np.random.randint(x_min,x_max,size=(I_Num,J_Num))




# x为公式里的x1,y为公式里面的x2
class SA:
    def __init__(self, iter=100, T0=100, Tf=10, alpha=0.99):
        self.iter = iter  # 内循环迭代次数,即为L =100
        self.alpha = alpha  # 降温系数，alpha=0.99
        self.T0 = T0  # 初始温度T0为100
        self.Tf = Tf  # 温度终值Tf为0.01
        self.T = T0  # 当前温度
        self.x = np.random.randint(0,6,size=(I_Num,J_Num))  # 随机生成100个x的值
        self.most_best = []
        """
        random()这个函数取0到1之间的小数
        如果你要取0-10之间的整数（包括0和10）就写成 (int)random()*11就可以了，11乘以零点多的数最大是10点多，最小是0点多
        该实例中x1和x2的绝对值不超过5（包含整数5和-5），（random() * 11 -5）的结果是-6到6之间的任意值（不包括-6和6）
        （random() * 10 -5）的结果是-5到5之间的任意值（不包括-5和5），所有先乘以11，取-6到6之间的值，产生新解过程中，用一个if条件语句把-5到5之间（包括整数5和-5）的筛选出来。
        """
        self.history = {'f': [], 'T': []}

    def func(self,x):  # 函数优化问题
        res = 0
        for j in range(J_Num):
            pro_num = 1
            for i in range(I_Num):
                pro_num = pro_num * ((1 - Interception_Pro[i, j]) ** x[i, j])
            res += Threat_v[j] * pro_num
        return res
    def generate_new(self, x):  # 扰动产生新解的过程
        x_new=copy.deepcopy(x)
        while True:
            for i in range(I_Num):
                while True:
                    for j in range(J_Num):
                        if Feasible_ij[i,j]==1:
                                x_new[i,j] = max(min(x[i,j] + int(self.T/self.T0 * (random.randint(x_max-x_min*2,x_max))),x_max),x_min)
                        else:
                            x_new[i,j] = 0
                    if np.sum(x_new[i,:])<=Weapon_n[i]:
                        break
            temp = 0
            for j in range(J_Num):
                if np.sum(x_new[:,j])>Crack_res[j]:
                    temp=1
                    break
            if temp==0: break
        return x_new

    def Metrospolis(self, f, f_new):  # Metropolis准则
        if f_new <= f:
            return 1
        else:
            p = math.exp((f - f_new) / self.T)
            if random.random() < p:
                return 1
            else:
                return 0

    # def best(self):  # 获取最优目标函数值
    #     f_list = []  # f_list数组保存每次迭代之后的值
    #     for i in range(self.iter):
    #         f = self.func(self.x[i], self.y[i])
    #         f_list.append(f)
    #     f_best = min(f_list)
    #
    #     idx = f_list.index(f_best)
    #     return f_best, idx  # f_best,idx分别为在该温度下，迭代L次之后目标函数的最优解和最优解的下标

    def run(self):
        count = 0
        self.x=self.generate_new(self.x)
        record_x=[]
        record_f=[]
        # 外循环迭代，当前温度小于终止温度的阈值
        while self.T > self.Tf:

            # 内循环迭代100次
            for i in range(self.iter):
                f =self.func(self.x)  # f为迭代一次后的值
                x_new = self.generate_new(self.x)  # 产生新解
                f_new = self.func(x_new)  # 产生新值
                if self.Metrospolis(f, f_new):  # 判断是否接受新值
                    self.x = x_new  # 如果接受新值，则把新值的x,y存入x数组和y数组
                    f=f_new
                record_x.append(self.x)
                record_f.append(f)
            # 迭代L次记录在该温度下最优解
            # ft, _ = self.best()
            self.history['f'].append(f_new)
            self.history['T'].append(self.T)
            # 温度按照一定的比例下降（冷却）
            self.T = self.T * self.alpha
            count += 1
            print(min(record_f),"---",record_f.index(min(record_f)))
        t_idex=record_f.index(min(record_f))
        x=record_x[t_idex]
        return x
            # 得到最优解
        # f_best, idx = self.best()
        # print(f"F={f_best}, x={self.x[idx]}, y={self.y[idx]}")



sa = SA()
X_value=sa.run() #todo 算法输出值

# plt.plot(sa.history['T'], sa.history['f'])
# plt.title('SA')
# plt.xlabel('T')
# plt.ylabel('f')
# plt.gca().invert_xaxis()
# plt.show()