import matplotlib.pyplot as plt
import numpy as np


class Dijkstra:
    def __init__(self):
        self.open_list = None
        self.close_list = None
        pass

    def draw_effect(self, map_grid, second_path, x, y, title="Path Heatmap"):
        """
        可视化显示路径规划结果热力图
        参数：
            map_grid (np.ndarray): 二维网格地图数据，0表示障碍，数值越大表示通行成本越高
            second_path (list/ndarray): 规划路径坐标列表，格式为[[y0,x0], [y1,x1], ...]
            x (int): 地图网格的横向维度
            y (int): 地图网格的纵向维度
            title (str): 图表标题，默认为"Path Heatmap"
        返回：
            matplotlib.figure.Figure: 生成的可视化图表对象
        """
        plt.imshow(map_grid, cmap=plt.cm.hot, interpolation='nearest', vmin=0, vmax=10)  # 绘制热力图
        plt.colorbar()
        plt.xlim(-1, x)  # x轴的范围
        plt.ylim(-1, y)

        # 设置坐标轴刻度
        my_x_ticks = np.arange(0, x, 1)  # x轴标号的范围
        my_y_ticks = np.arange(0, y, 1)
        plt.xticks(my_x_ticks)
        plt.yticks(my_y_ticks)
        second_path = np.array(second_path)
        plt.plot(second_path[:, 1:2], second_path[:, 0:1], '-')
        plt.title(title)
        # plt.grid(True)  # 开启栅格  可以不开启
        plt.show()  # 可视化

    def open_list_append(slef, x0, y0, x1, y1, h0, h1, map_grid, open_list):
        """
            处理相邻节点到开放列表的添加/更新操作（用于D*/A*类路径规划算法）
            参数：
                x0, y0: 当前节点的坐标
                x1, y1: 待处理的相邻节点坐标
                h0: 当前节点到起点的实际代价
                h1: 当前节点到相邻节点的移动代价
                map_grid: 二维网格地图，数值含义：
                    0=障碍, 3=在open_list, 4=在close_list, 5=起点, 6=终点
                open_list: 当前open列表，元素格式为[x, y, h, k, parent_x, parent_y]
            返回：
                open_list
        """
        map_row = len(map_grid)
        if map_row < 1:
            return
        map_col = len(map_grid[0])
        if map_col < 1:
            return
        if (0 <= x1 and x1 < map_row and 0 <= y1 and y1 < map_col and map_grid[x1, y1] != 4 and map_grid[
            x1, y1] != 0):  # 左边没有越界并且没有在closelist里面
            if map_grid[x1, y1] == 3:  # 如果是在open_list中,h要更新
                open_list = np.array(open_list)
                if (h1 + h0) < open_list[np.where((open_list[:, 0] == x1) & (open_list[:, 1] == y1)), 2]:
                    h = h1 + h0
                    k = h1 + h0
                    open_list[np.where((open_list[:, 0] == x1) & (open_list[:, 1] == y1)), 2] = h
                    open_list[np.where((open_list[:, 0] == x1) & (open_list[:, 1] == y1)), 3] = k
                    open_list[np.where((open_list[:, 0] == x1) & (open_list[:, 1] == y1)), 4] = x0
                    open_list[np.where((open_list[:, 0] == x1) & (open_list[:, 1] == y1)), 4] = y0
                open_list = list(open_list.tolist())

            else:  # 是new节点
                h = h1 + h0
                k = h1 + h0
                # open_list = list(open_list)
                open_list.append([x1, y1, h, k, x0, y0])
                map_grid[x1, y1] = 3

        return open_list

    def first_search(self, open_list, close_list, map_grid):  # 给出终点坐标，完成首次遍历
        """
            执行D*算法的首次搜索迭代，采用D算法遍历
            参数：
                open_list: 优先队列，元素格式为[x, y, h, k, parent_x, parent_y]
                close_list: 已探索节点列表
                map_grid: 二维网格地图
            返回：
                更新后的(open_list, close_list, map_grid)
        """
        # 采用D算法遍历
        # 选openlist中h最小的,将openlist按照h排序，取第一个，并删除第一个，将它放到close_list里面
        open_list = list(open_list)
        open_list.sort(key=lambda x: x[2])
        # open_list.pop(0)
        insert_list = open_list[0]  # 引入中间列表，用来存储每一次被选中的遍历的点
        x0 = int(insert_list[0])
        y0 = int(insert_list[1])
        open_list.pop(0)
        close_list.append(list(insert_list))
        map_grid[x0, y0] = 4  # 被加入到close_list里面

        # 找insert_list的邻域 ----->寻找顺序：从左边开始逆时针
        h0 = int(insert_list[2])

        x1 = x0
        y1 = y0 - 1
        h1 = 10
        open_list = self.open_list_append(x0, y0, x1, y1, h0, h1, map_grid, open_list)

        x1 = x0 - 1
        y1 = y0 - 1
        h1 = 14
        open_list = self.open_list_append(x0, y0, x1, y1, h0, h1, map_grid, open_list)

        x1 = x0 - 1
        y1 = y0
        h1 = 10
        open_list = self.open_list_append(x0, y0, x1, y1, h0, h1, map_grid, open_list)

        x1 = x0 - 1
        y1 = y0 + 1
        h1 = 14
        open_list = self.open_list_append(x0, y0, x1, y1, h0, h1, map_grid, open_list)

        x1 = x0
        y1 = y0 + 1
        h1 = 10
        open_list = self.open_list_append(x0, y0, x1, y1, h0, h1, map_grid, open_list)

        x1 = x0 + 1
        y1 = y0 + 1
        h1 = 14
        open_list = self.open_list_append(x0, y0, x1, y1, h0, h1, map_grid, open_list)

        x1 = x0 + 1
        y1 = y0
        h1 = 10
        open_list = self.open_list_append(x0, y0, x1, y1, h0, h1, map_grid, open_list)

        x1 = x0 + 1
        y1 = y0 - 1
        h1 = 14
        open_list = self.open_list_append(x0, y0, x1, y1, h0, h1, map_grid, open_list)

        return [open_list, close_list, map_grid]
