import numpy as np
import geopandas as gpd
from shapely.geometry import Point, Polygon, box, MultiPolygon
from pyproj import CRS
import rtree
import matplotlib.pyplot as plt
import time
import warnings

warnings.filterwarnings('ignore', category=UserWarning)  # 忽略投影警告


class MultiRegionCoverageCalculator:
    def __init__(self, polygons_coords, grid_size_m=1000, buffer_m=5000):
        """
        基于经纬度的多区域网格化与覆盖计算

        参数:
        polygons_coords: 多个多边形经纬度坐标列表 [[(lng1,lat1), (lng2,lat2),...], [...], ...]
        grid_size_m: 网格大小(米)
        buffer_m: 多边形外扩距离(米)
        """
        # 创建地理多边形集合
        self.polygons = [Polygon(coords) for coords in polygons_coords]
        self.multi_polygon = MultiPolygon(self.polygons)

        # 确定合适的UTM投影（以第一个多边形为中心）
        self.utm_crs = self._determine_utm_crs(self.polygons[0])

        # 创建GeoDataFrame
        self.gdf = gpd.GeoDataFrame(geometry=self.polygons, crs="EPSG:4326")

        # 转换为UTM投影以进行米制计算
        self.gdf_utm = self.gdf.to_crs(self.utm_crs)
        self.utm_polygons = list(self.gdf_utm.geometry)

        # 创建带缓冲区的外扩多边形
        self.buffered_utm_polygons = [poly.buffer(buffer_m) for poly in self.utm_polygons]

        # 网格参数
        self.grid_size = grid_size_m

        # 生成网格
        self.grid_gdf = self._create_grid()

        # 构建空间索引
        self.idx = self._build_spatial_index()

    def _determine_utm_crs(self, polygon):
        """根据多边形中心点确定合适的UTM投影"""
        centroid = polygon.centroid
        lon, lat = centroid.x, centroid.y

        # UTM带号计算
        utm_zone = int((lon + 180) / 6) + 1
        hemisphere = 'north' if lat >= 0 else 'south'

        # 返回UTM CRS
        return CRS(f"+proj=utm +zone={utm_zone} +{hemisphere} +ellps=WGS84 +datum=WGS84 +units=m +no_defs")

    def _create_grid(self):
        """创建覆盖所有多边形的网格"""
        # 获取所有多边形的合并边界框
        bounds = self.gdf_utm.total_bounds
        minx, miny, maxx, maxy = bounds

        # 计算网格数量
        num_x = int(np.ceil((maxx - minx) / self.grid_size))
        num_y = int(np.ceil((maxy - miny) / self.grid_size))

        grid_cells = []
        cell_id = 0

        # 生成网格
        for i in range(num_x):
            for j in range(num_y):
                # 计算网格边界（UTM坐标）
                x1 = minx + i * self.grid_size
                y1 = miny + j * self.grid_size
                x2 = x1 + self.grid_size
                y2 = y1 + self.grid_size

                # 创建网格矩形
                grid_rect = box(x1, y1, x2, y2)

                # 检查网格是否与任何关键区域相交
                for poly_idx, poly in enumerate(self.utm_polygons):
                    if poly.intersects(grid_rect):
                        # 计算网格中心点
                        center_x = (x1 + x2) / 2
                        center_y = (y1 + y2) / 2

                        # 确定网格是否完全在当前多边形内
                        is_fully_inside = poly.contains(grid_rect)

                        # 计算网格权重（中心点离多边形中心越近权重越高）
                        centroid = poly.centroid
                        distance = Point(center_x, center_y).distance(centroid)
                        max_distance = Point(minx, miny).distance(centroid)
                        weight = max(0.1, 1.0 - (distance / max_distance))

                        # 创建网格几何
                        grid_geometry = Polygon([
                            (x1, y1), (x2, y1), (x2, y2), (x1, y2), (x1, y1)
                        ])

                        grid_cells.append({
                            'id': cell_id,
                            'polygon_idx': poly_idx,  # 所属多边形索引
                            'bounds': (x1, y1, x2, y2),
                            'center_utm': (center_x, center_y),
                            'is_fully_inside': is_fully_inside,
                            'weight': weight,
                            'geometry': grid_geometry
                        })
                        cell_id += 1
                        break  # 一个网格只属于一个多边形

        # 创建GeoDataFrame
        grid_gdf = gpd.GeoDataFrame(grid_cells, crs=self.utm_crs)

        # 添加WGS84坐标（经纬度）
        grid_gdf['geometry_wgs'] = grid_gdf.geometry.to_crs("EPSG:4326")

        # 计算网格中心点的经纬度
        grid_gdf['center_lng'] = grid_gdf.geometry_wgs.apply(lambda geom: geom.centroid.x)
        grid_gdf['center_lat'] = grid_gdf.geometry_wgs.apply(lambda geom: geom.centroid.y)

        return grid_gdf

    def _build_spatial_index(self):
        """构建网格空间索引"""
        idx = rtree.index.Index()

        for i, row in self.grid_gdf.iterrows():
            bounds = row.geometry.bounds
            idx.insert(i, bounds)

        return idx

    def calculate_coverage(self, lng, lat, radius_m):
        """
        计算预警机覆盖比例

        参数:
        lng: 预警机经度
        lat: 预警机纬度
        radius_m: 探测半径(米)

        返回:
        coverage_ratios: 各区域覆盖比例列表 [0.0-1.0, ...]
        total_coverage: 总覆盖比例 (0.0-1.0)
        covered_grids: 被覆盖的网格索引列表
        """
        # 创建预警机点（WGS84）
        awacs_point_wgs = Point(lng, lat)

        # 转换为UTM投影
        awacs_gdf = gpd.GeoDataFrame(geometry=[awacs_point_wgs], crs="EPSG:4326")
        awacs_gdf_utm = awacs_gdf.to_crs(self.utm_crs)
        awacs_point_utm = awacs_gdf_utm.geometry.iloc[0]

        # 计算探测范围边界框（用于空间索引查询）
        minx, miny, maxx, maxy = awacs_point_utm.buffer(radius_m).bounds

        # 使用空间索引查询可能被覆盖的网格
        candidate_indices = list(self.idx.intersection((minx, miny, maxx, maxy)))
        candidate_grids = self.grid_gdf.loc[candidate_indices]

        # 初始化各区域统计
        region_stats = {
            idx: {'total_weight': 0.0, 'covered_weight': 0.0}
            for idx in range(len(self.polygons))
        }
        covered_grids = []

        # 检查每个候选网格是否真正被覆盖
        for idx, row in candidate_grids.iterrows():
            # 计算网格中心点到预警机的距离
            center_point = Point(row.center_utm)
            distance = awacs_point_utm.distance(center_point)

            # 如果网格在探测范围内
            if distance <= radius_m:
                # 考虑距离衰减
                decay_factor = 1.0 - (distance / radius_m) ** 2
                weighted_value = row['weight'] * decay_factor

                # 更新区域统计
                region_idx = row['polygon_idx']
                region_stats[region_idx]['total_weight'] += row['weight']
                region_stats[region_idx]['covered_weight'] += weighted_value

                covered_grids.append(idx)

        # 计算各区域覆盖比例
        coverage_ratios = []
        for idx in range(len(self.polygons)):
            total = region_stats[idx]['total_weight']
            covered = region_stats[idx]['covered_weight']
            coverage_ratios.append(covered / total if total > 0 else 0.0)

        # 计算总覆盖比例
        total_weight = sum(stat['total_weight'] for stat in region_stats.values())
        total_covered = sum(stat['covered_weight'] for stat in region_stats.values())
        total_coverage = total_covered / total_weight if total_weight > 0 else 0.0

        return coverage_ratios, total_coverage, covered_grids

    def visualize(self, awacs_lng=None, awacs_lat=None, radius_m=None, coverage_data=None):
        """可视化关键区域和覆盖情况"""
        fig, ax = plt.subplots(figsize=(12, 10))

        # 绘制所有关键区域
        colors = plt.cm.tab10.colors  # 使用不同颜色区分不同区域
        for idx, poly in enumerate(self.polygons):
            patch = plt.Polygon(
                list(poly.exterior.coords),
                closed=True,
                fill=True,
                color=colors[idx % len(colors)],
                alpha=0.3,
                edgecolor=colors[idx % len(colors)],
                linewidth=1.5,
                label=f'region {idx + 1}'
            )
            ax.add_patch(patch)

        # 绘制网格
        self.grid_gdf['geometry_wgs'].plot(
            ax=ax, facecolor='none', edgecolor='gray', linewidth=0.5, alpha=0.3
        )

        # 绘制覆盖情况
        if coverage_data:
            _, _, covered_grids = coverage_data
            covered_gdf = self.grid_gdf.loc[covered_grids]

            # 按区域绘制不同颜色的覆盖网格
            for idx in range(len(self.polygons)):
                region_covered = covered_gdf[covered_gdf['polygon_idx'] == idx]
                if not region_covered.empty:
                    region_covered['geometry_wgs'].plot(
                        ax=ax,
                        color=colors[idx % len(colors)],
                        alpha=0.6,
                        edgecolor=colors[idx % len(colors)],
                        label=f'region {idx + 1} covered'
                    )

        # 绘制预警机位置和探测范围
        if awacs_lng and awacs_lat and radius_m:
            # 绘制预警机位置
            ax.scatter(awacs_lng, awacs_lat, color='red', s=100,
                       label='sentry', edgecolors='black', zorder=10)

            # 创建探测范围圆（UTM坐标）
            awacs_point_wgs = Point(awacs_lng, awacs_lat)
            awacs_gdf = gpd.GeoDataFrame(geometry=[awacs_point_wgs], crs="EPSG:4326")
            awacs_gdf_utm = awacs_gdf.to_crs(self.utm_crs)

            # 创建圆形并转换回WGS84
            detection_circle_utm = awacs_gdf_utm.geometry.iloc[0].buffer(radius_m)
            detection_gdf_utm = gpd.GeoDataFrame(geometry=[detection_circle_utm], crs=self.utm_crs)
            detection_gdf_wgs = detection_gdf_utm.to_crs("EPSG:4326")

            # 绘制探测范围
            detection_gdf_wgs.plot(
                ax=ax, facecolor='none', edgecolor='red', linestyle='--',
                linewidth=1.5, label=f'detect range ({radius_m / 1000:.0f}km)'
            )

        # 设置图形属性
        ax.set_title('coverage analysis', fontsize=16)
        ax.set_xlabel('lon')
        ax.set_ylabel('lat')
        plt.legend(loc='upper right')
        plt.tight_layout()
        plt.show()


# ========================
# 使用示例
# ========================

if __name__ == "__main__":
    # 1. 定义多个不规则关键区域多边形（经纬度坐标）
    regions = [
        # H机场区域
        [
            (-96.5, 18.3), (-94.2, 19.6), (-92.23, 17.99),
            (-93.11, 16.3), (-95.8, 17.2)
        ],
        # 世宗区域
        [
            (-90.99, 21.87), (-89.53, 21.36), (-89.75, 20.18),
            (-91.11, 20.0), (-91.74, 20.1)
        ],
        [(-85.78,12.86), (-84.26,12.52), (-85.4,11.8)]
    ]

    # 2. 创建多区域覆盖计算器
    print("创建多区域覆盖计算器...")
    start_time = time.time()
    calculator = MultiRegionCoverageCalculator(
        polygons_coords=regions,
        grid_size_m=10000,  # 2公里网格
        buffer_m=10000  # 5公里外扩
    )
    print(f"创建完成，耗时: {time.time() - start_time:.2f}秒")
    print(f"生成的网格数量: {len(calculator.grid_gdf)}")

    # 3. 定义预警机位置和探测范围 <POINT (-87.258 15.571)>
    awacs_lng, awacs_lat = -84.51 ,12.27# -92.45, 19
    detection_radius = 449000.19716271957077  # 30公里探测半径

    # 4. 计算覆盖比例
    print("\n计算覆盖比例...")
    start_time = time.time()
    coverage_ratios, total_coverage, covered_grids = calculator.calculate_coverage(
        awacs_lng, awacs_lat, detection_radius
    )
    print(f"计算完成，耗时: {(time.time() - start_time) * 1000:.2f}毫秒")

    # 打印各区域覆盖情况
    for idx, ratio in enumerate(coverage_ratios):
        print(f"区域 {idx + 1} 覆盖比例: {ratio:.2%}")
    print(f"总覆盖比例: {total_coverage:.2%}")
    print(f"被覆盖网格数: {len(covered_grids)}/{len(calculator.grid_gdf)}")

    # 5. 可视化结果
    print("\n生成可视化...")
    calculator.visualize(
        awacs_lng=awacs_lng,
        awacs_lat=awacs_lat,
        radius_m=detection_radius,
        coverage_data=(coverage_ratios, total_coverage, covered_grids)
    )
