#!/usr/bin/env python3

from __future__ import annotations

import math
from pathlib import Path


# Overall panel size in SVG user units; svg_doc() emits these as millimetres.
PANEL = 1000.0
# Border between the panel edge and the opening that holds the pieces.
FRAME = 0.0
OPENING = PANEL - FRAME * 2
# The opening is cut into a GRID x GRID array of square cells.
GRID = 4
CELL = OPENING / GRID
PIECE = CELL  # alias of CELL; not referenced by the code visible here
# Interlocking tab profile: TOOTH_SPAN is the tab width where it leaves the
# piece edge, TOOTH_NECK the width at its far end, TOOTH_DEPTH how far the
# far end sits from the edge line.
TOOTH_SPAN = 66.0
TOOTH_NECK = 30.0
TOOTH_DEPTH = 22.0

# Output folders, all under the parent of this script's directory.
SVG_DIR = Path(__file__).resolve().parent.parent / "svg"
PIECES_DIR = SVG_DIR / "pieces"
MASKS_DIR = SVG_DIR / "piece_masks"
OUTLINES_DIR = SVG_DIR / "piece_outlines"
PREVIEW_DIR = SVG_DIR.parent / "preview"

CENTER = PANEL / 2
# Square inner frame spanning INNER_MIN..INNER_MAX on both axes.
INNER_MIN = 188.0
INNER_MAX = PANEL - INNER_MIN
INNER_SIZE = INNER_MAX - INNER_MIN
# Centres of the four corner boxes, listed row by row starting top-left.
CORNER_CENTERS = [(x, y) for y in (286.0, 714.0) for x in (286.0, 714.0)]


def fmt(value: float) -> str:
    """Format *value* compactly for use in an SVG attribute.

    The value is rounded to three decimal places. Trailing zeros and a dangling
    decimal point are then stripped, so ``12.500`` becomes ``"12.5"`` and
    ``100.000`` becomes ``"100"``.

    Negative values that round to zero (including ``-0.0`` itself, and tiny
    negative float residue) would otherwise come out as ``"-0"``. They are
    normalised to ``"0"`` so the emitted numbers are canonical and stable.
    """
    text = f"{value:.3f}".rstrip("0").rstrip(".")
    # "-0.000" strips down to "-0"; collapse negative zero to a plain "0".
    return "0" if text == "-0" else text


def polygon(points, fill="black", stroke="none", stroke_width=0, extra=""):
    """Build an SVG ``<polygon>`` element as a string.

    Args:
        points: Iterable of ``(x, y)`` vertices.
        fill: Value written to the ``fill`` attribute.
        stroke: Value written to the ``stroke`` attribute; ``"none"`` means unstroked.
        stroke_width: Written as ``stroke-width`` only when a stroke is set and
            this value is truthy.
        extra: Raw attribute text appended verbatim when non-empty.
    """
    markup = f'points="{point_string(points)}" fill="{fill}" stroke="{stroke}"'
    if stroke != "none" and stroke_width:
        markup += f' stroke-width="{fmt(stroke_width)}"'
    if extra:
        markup += " " + extra
    return "<polygon " + markup + " />"


def rect(x, y, w, h, fill="black", stroke="none", stroke_width=0, extra=""):
    """Build an SVG ``<rect>`` element as a string.

    Args:
        x, y: Top-left corner.
        w, h: Width and height.
        fill: Value written to the ``fill`` attribute.
        stroke: Value written to the ``stroke`` attribute; ``"none"`` means unstroked.
        stroke_width: Written as ``stroke-width`` only when a stroke is set and
            this value is truthy.
        extra: Raw attribute text appended verbatim when non-empty.
    """
    box = (("x", x), ("y", y), ("width", w), ("height", h))
    geometry = " ".join(f'{name}="{fmt(amount)}"' for name, amount in box)
    markup = f'<rect {geometry} fill="{fill}" stroke="{stroke}"'
    if stroke != "none" and stroke_width:
        markup += f' stroke-width="{fmt(stroke_width)}"'
    if extra:
        markup += " " + extra
    return markup + " />"


def svg_doc(elements, view_box=(0, 0, PANEL, PANEL), width=None, height=None, background=None):
    """Wrap element strings in a complete standalone SVG document.

    Args:
        elements: Iterable of SVG element strings, emitted in order.
        view_box: ``(x, y, w, h)`` for the ``viewBox`` attribute.
        width, height: Physical size in millimetres; each defaults to the
            matching view-box dimension.
        background: When not ``None``, a filled rect covering the whole view
            box is emitted before *elements*.

    Returns:
        The document text, lines joined by ``"\\n"`` (no trailing newline).
    """
    vx, vy, vw, vh = view_box
    doc_w = width if width is not None else vw
    doc_h = height if height is not None else vh
    root_tag = (
        f'<svg xmlns="http://www.w3.org/2000/svg" '
        f'width="{fmt(doc_w)}mm" height="{fmt(doc_h)}mm" '
        f'viewBox="{fmt(vx)} {fmt(vy)} {fmt(vw)} {fmt(vh)}">'
    )
    lines = ['<?xml version="1.0" encoding="UTF-8"?>', root_tag]
    if background is not None:
        lines.append(rect(vx, vy, vw, vh, fill=background))
    lines.extend(elements)
    lines.append("</svg>")
    return "\n".join(lines)


def write_svg(name: str, elements, background=None):
    """Write *elements* as a full-panel SVG document to ``SVG_DIR / name`` (UTF-8, newline-terminated)."""
    document = svg_doc(elements, background=background)
    target = SVG_DIR / name
    target.write_text(f"{document}\n", encoding="utf-8")


def rotate_point(point, angle_deg, origin=(CENTER, CENTER)):
    """Rotate *point* about *origin* by *angle_deg* degrees.

    Applies the standard 2-D rotation matrix to the offset of *point* from
    *origin*. The default origin is the panel centre.

    Returns:
        The rotated ``(x, y)`` tuple.
    """
    ox, oy = origin
    px, py = point
    theta = math.radians(angle_deg)
    cos_t = math.cos(theta)
    sin_t = math.sin(theta)
    dx, dy = px - ox, py - oy
    return (ox + dx * cos_t - dy * sin_t, oy + dx * sin_t + dy * cos_t)


def rotate_points(points, angle_deg, origin=(CENTER, CENTER)):
    """Return a new list with every point of *points* rotated by *angle_deg* degrees about *origin*."""
    rotated = []
    for vertex in points:
        rotated.append(rotate_point(vertex, angle_deg, origin=origin))
    return rotated


def regular_polygon(cx, cy, radius, sides, rotation_deg=0):
    """Return the vertices of a regular polygon inscribed in a circle.

    Args:
        cx, cy: Circle centre.
        radius: Circumradius (centre-to-vertex distance).
        sides: Number of vertices.
        rotation_deg: Angle of the first vertex, in degrees from the +x axis.

    Returns:
        List of ``sides`` ``(x, y)`` tuples, evenly spaced by ``360 / sides`` degrees.
    """
    vertices = []
    for k in range(sides):
        theta = math.radians(rotation_deg + 360 * k / sides)
        vertices.append((cx + radius * math.cos(theta), cy + radius * math.sin(theta)))
    return vertices


def star_points(cx, cy, outer_r, inner_r, points, rotation_deg=0):
    """Return the vertices of a *points*-pointed star centred on ``(cx, cy)``.

    Vertices alternate between *outer_r* (tips, even indices) and *inner_r*
    (valleys, odd indices). They are spaced ``180 / points`` degrees apart,
    starting at *rotation_deg*, for ``2 * points`` vertices in total.
    """
    radii = (outer_r, inner_r)
    angles = (math.radians(rotation_deg + 180 * k / points) for k in range(points * 2))
    return [
        (cx + radii[k % 2] * math.cos(theta), cy + radii[k % 2] * math.sin(theta))
        for k, theta in enumerate(angles)
    ]


def diamond(cx, cy, radius):
    """Return the 4 vertices of a diamond (square standing on a corner) centred on ``(cx, cy)``.

    The first vertex is *radius* to the right of the centre; the rest follow
    at 90-degree steps.
    """
    return regular_polygon(cx, cy, radius, sides=4)


def triangle(cx, cy, radius, rotation_deg=0):
    """Return the 3 vertices of an equilateral triangle inscribed in a circle of *radius*.

    The first vertex sits at *rotation_deg* degrees from the +x axis.
    """
    return regular_polygon(cx, cy, radius, sides=3, rotation_deg=rotation_deg)


def wide_segment(p1, p2, width):
    """Return the 4 corners of a band of thickness *width* centred on segment p1 -> p2.

    The corners are offset from each endpoint along the segment's unit
    normal by ``width / 2``, ordered p1+n, p1-n, p2-n, p2+n.

    Raises:
        ValueError: If *p1* and *p2* coincide (no direction to offset along).
    """
    (ax, ay), (bx, by) = p1, p2
    seg_len = math.hypot(bx - ax, by - ay)
    if seg_len == 0:
        raise ValueError("Segment length cannot be zero.")
    half = width / 2
    # Unit normal (perpendicular to the segment) scaled to half the band width.
    off_x = -(by - ay) / seg_len * half
    off_y = (bx - ax) / seg_len * half
    return [
        (ax + off_x, ay + off_y),
        (ax - off_x, ay - off_y),
        (bx - off_x, by - off_y),
        (bx + off_x, by + off_y),
    ]


def tile_origin(row, col):
    """Return the ``(x, y)`` top-left corner of grid cell (*row*, *col*) inside the frame opening."""
    x = FRAME + CELL * col
    y = FRAME + CELL * row
    return x, y


def point_string(points):
    """Serialise ``(x, y)`` pairs into an SVG ``points`` value: ``"x1,y1 x2,y2 ..."``."""
    pairs = [fmt(px) + "," + fmt(py) for px, py in points]
    return " ".join(pairs)


def right_edge_sign(row, col):
    """Return the tab direction for the vertical edge on the right of cell (*row*, *col*).

    Returns 0 for the last column, which has no right-hand neighbour (straight
    border edge). Otherwise returns +1 or -1 in a checkerboard pattern on
    ``row + col``.
    """
    if col < GRID - 1:
        return -1 if (row + col) % 2 else 1
    return 0


def bottom_edge_sign(row, col):
    """Return the tab direction for the horizontal edge below cell (*row*, *col*).

    Returns 0 for the last row, which has no neighbour below (straight border
    edge). Otherwise returns +1 or -1.
    """
    if row < GRID - 1:
        # NOTE(review): 2 * col is always even, so col never affects the parity
        # and the sign alternates by row only. If a checkerboard like
        # right_edge_sign's (row + col) was intended, confirm before changing:
        # it would alter every piece shape.
        return -1 if (row + 2 * col) % 2 else 1
    return 0


def piece_signs(row, col):
    """Return the tab sign for each side of piece (*row*, *col*).

    Top and left reuse the neighbour's bottom/right sign, negated. The edge
    helpers also mirror each other's offset direction, so both pieces sharing
    an edge produce the same tab geometry. Sides on the outer border
    (row 0 / col 0, and the last row/col via the sign helpers) get 0.

    Returns:
        Dict with keys ``"top"``, ``"right"``, ``"bottom"``, ``"left"``.
    """
    top = -bottom_edge_sign(row - 1, col) if row != 0 else 0
    left = -right_edge_sign(row, col - 1) if col != 0 else 0
    return {
        "top": top,
        "right": right_edge_sign(row, col),
        "bottom": bottom_edge_sign(row, col),
        "left": left,
    }


def top_edge_points(x0, x1, y, sign):
    """Return the vertices of a horizontal edge at *y* running from *x0* to *x1*.

    With ``sign == 0`` the edge is straight. Otherwise a trapezoidal tab
    (TOOTH_SPAN wide at the edge, TOOTH_NECK wide at its tip) is centred on
    the edge. The tip sits TOOTH_DEPTH toward -y for +1 and toward +y for -1.
    When this is a piece's top edge, +1 therefore protrudes away from the piece.
    """
    if sign == 0:
        return [(x0, y), (x1, y)]
    mid = (x0 + x1) / 2
    half_span = TOOTH_SPAN / 2
    half_neck = TOOTH_NECK / 2
    tip_y = y - sign * TOOTH_DEPTH
    return [
        (x0, y),
        (mid - half_span, y),
        (mid - half_neck, tip_y),
        (mid + half_neck, tip_y),
        (mid + half_span, y),
        (x1, y),
    ]


def bottom_edge_points(x0, x1, y, sign):
    """Return the vertices of a horizontal edge at *y* running from *x0* to *x1*.

    With ``sign == 0`` the edge is straight. Otherwise a trapezoidal tab
    (TOOTH_SPAN wide at the edge, TOOTH_NECK wide at its tip) is centred on
    the edge. The tip sits TOOTH_DEPTH toward +y for +1 and toward -y for -1.
    When this is a piece's bottom edge, +1 therefore protrudes away from the piece.
    """
    if sign == 0:
        return [(x0, y), (x1, y)]
    mid = (x0 + x1) / 2
    half_span = TOOTH_SPAN / 2
    half_neck = TOOTH_NECK / 2
    tip_y = y + sign * TOOTH_DEPTH
    return [
        (x0, y),
        (mid - half_span, y),
        (mid - half_neck, tip_y),
        (mid + half_neck, tip_y),
        (mid + half_span, y),
        (x1, y),
    ]


def right_edge_points(x, y0, y1, sign):
    """Return the vertices of a vertical edge at *x* running from *y0* to *y1*.

    With ``sign == 0`` the edge is straight. Otherwise a trapezoidal tab
    (TOOTH_SPAN wide at the edge, TOOTH_NECK wide at its tip) is centred on
    the edge. The tip sits TOOTH_DEPTH toward +x for +1 and toward -x for -1.
    When this is a piece's right edge, +1 therefore protrudes away from the piece.
    """
    if sign == 0:
        return [(x, y0), (x, y1)]
    mid = (y0 + y1) / 2
    half_span = TOOTH_SPAN / 2
    half_neck = TOOTH_NECK / 2
    tip_x = x + sign * TOOTH_DEPTH
    return [
        (x, y0),
        (x, mid - half_span),
        (tip_x, mid - half_neck),
        (tip_x, mid + half_neck),
        (x, mid + half_span),
        (x, y1),
    ]


def left_edge_points(x, y0, y1, sign):
    """Return the vertices of a vertical edge at *x* running from *y0* to *y1*.

    With ``sign == 0`` the edge is straight. Otherwise a trapezoidal tab
    (TOOTH_SPAN wide at the edge, TOOTH_NECK wide at its tip) is centred on
    the edge. The tip sits TOOTH_DEPTH toward -x for +1 and toward +x for -1.
    When this is a piece's left edge, +1 therefore protrudes away from the piece.
    """
    if sign == 0:
        return [(x, y0), (x, y1)]
    mid = (y0 + y1) / 2
    half_span = TOOTH_SPAN / 2
    half_neck = TOOTH_NECK / 2
    tip_x = x - sign * TOOTH_DEPTH
    return [
        (x, y0),
        (x, mid - half_span),
        (tip_x, mid - half_neck),
        (tip_x, mid + half_neck),
        (x, mid + half_span),
        (x, y1),
    ]


def piece_outline(row, col):
    """Return the outline of piece (*row*, *col*) as a list of ``(x, y)`` vertices.

    The outline is traced clockwise on screen (SVG's y axis points down):
    - top edge left to right,
    - right edge top to bottom,
    - bottom edge right to left,
    - left edge bottom to top.

    The shared corner between consecutive edges is not repeated. The final
    vertex, however, is the starting corner again.
    """
    left_x, top_y = tile_origin(row, col)
    right_x = left_x + CELL
    bottom_y = top_y + CELL
    sign = piece_signs(row, col)

    outline = top_edge_points(left_x, right_x, top_y, sign["top"])
    outline += right_edge_points(right_x, top_y, bottom_y, sign["right"])[1:]
    # Bottom and left helpers run left->right / top->bottom, so walk them backwards.
    outline += bottom_edge_points(left_x, right_x, bottom_y, sign["bottom"])[::-1][1:]
    outline += left_edge_points(left_x, top_y, bottom_y, sign["left"])[::-1][1:]
    return outline


def layer_border_chain():
    """Return the diamond chain running along all four sides of the panel.

    The chain has two passes:
    - large diamonds (radius 39) at the medallion positions;
    - small diamonds (radius 18) at the positions in between.

    Each pass emits the top/bottom pair for every position, then the
    left/right pair. The chain line is centred 59 units in from the frame edge.
    """
    top_y = FRAME + 59
    bottom_y = PANEL - top_y
    side_x = top_y
    far_x = PANEL - side_x

    def band(positions, radius):
        across = [polygon(diamond(p, y, radius)) for p in positions for y in (top_y, bottom_y)]
        down = [polygon(diamond(x, p, radius)) for p in positions for x in (side_x, far_x)]
        return across + down

    return band([246, 374, 500, 626, 754], 39) + band([310, 437, 563, 690], 18)


def layer_border_insets():
    """Return 8-pointed stars set into the large border medallions on all four sides.

    Stars are emitted as top/bottom pairs for each position, then as
    left/right pairs, at the same positions as layer_border_chain's radius-39
    diamonds.
    """
    top_y = FRAME + 59
    bottom_y = PANEL - top_y
    side_x = top_y
    far_x = PANEL - side_x
    medallions = [246, 374, 500, 626, 754]

    def star(x, y):
        return polygon(star_points(x, y, 22, 10, 8, rotation_deg=22.5))

    across = [star(p, y) for p in medallions for y in (top_y, bottom_y)]
    down = [star(x, p) for p in medallions for x in (side_x, far_x)]
    return across + down


def layer_inner_frame_outer():
    """Return the solid square spanning INNER_MIN..INNER_MAX on both axes, as a one-element list."""
    corner = INNER_MIN
    side = INNER_SIZE
    return [rect(corner, corner, side, side)]


def layer_inner_frame_inner():
    """Return the inner-frame square inset 16 units on every side, as a one-element list."""
    inset = 16
    corner = INNER_MIN + inset
    side = INNER_SIZE - 2 * inset
    return [rect(corner, corner, side, side)]


def layer_corner_boxes_outer():
    """Return a 136 x 136 square centred on each of CORNER_CENTERS."""
    half = 68
    return [rect(cx - half, cy - half, 2 * half, 2 * half) for cx, cy in CORNER_CENTERS]


def layer_corner_boxes_inner():
    """Return a 104 x 104 square centred on each of CORNER_CENTERS."""
    half = 52
    return [rect(cx - half, cy - half, 2 * half, 2 * half) for cx, cy in CORNER_CENTERS]


def layer_corner_diamonds():
    """Return a radius-52 diamond centred on each of CORNER_CENTERS."""
    shapes = []
    for centre_x, centre_y in CORNER_CENTERS:
        shapes.append(polygon(diamond(centre_x, centre_y, 52)))
    return shapes


def layer_corner_rosettes():
    """Return an 8-pointed star (outer radius 24, inner 11) centred on each of CORNER_CENTERS."""
    shapes = []
    for centre_x, centre_y in CORNER_CENTERS:
        star = star_points(centre_x, centre_y, 24, 11, 8, rotation_deg=22.5)
        shapes.append(polygon(star))
    return shapes


def layer_arm_bands(width=38):
    """Return the pair of converging straight bands in the north arm, repeated at 90-degree turns.

    Args:
        width: Band thickness. Callers use it both for the bands and for
            narrower inset strips drawn over them.

    Returns:
        Polygons ordered by rotation (0, 90, 180, 270), left band before right band.
    """
    left_arm = wide_segment((456, 248), (495, 410), width)
    right_arm = wide_segment((544, 248), (505, 410), width)
    return [
        polygon(rotate_points(arm, turn))
        for turn in (0, 90, 180, 270)
        for arm in (left_arm, right_arm)
    ]


def layer_arm_rosettes():
    """Return a radius-32 diamond at (CENTER, 252) plus its copies rotated 90, 180 and 270 degrees about the centre."""
    north = diamond(CENTER, 252, 32)
    return [polygon(rotate_points(north, turn)) for turn in (0, 90, 180, 270)]


def layer_center_octagon_outer():
    """Return the outer central octagon (circumradius 118, rotated 22.5 degrees) as a one-element list."""
    octagon = regular_polygon(CENTER, CENTER, 118, 8, rotation_deg=22.5)
    return [polygon(octagon)]


def layer_center_octagon_inner():
    """Return the inner central octagon (circumradius 92, rotated 22.5 degrees) as a one-element list."""
    octagon = regular_polygon(CENTER, CENTER, 92, 8, rotation_deg=22.5)
    return [polygon(octagon)]


def layer_center_star():
    """Return the central 8-pointed star (outer radius 74, inner 35) as a one-element list."""
    star = star_points(CENTER, CENTER, 74, 35, 8, rotation_deg=22.5)
    return [polygon(star)]


def layer_center_core():
    """Return the small central 8-pointed star (outer radius 28, inner 13) as a one-element list."""
    star = star_points(CENTER, CENTER, 28, 13, 8, rotation_deg=22.5)
    return [polygon(star)]


def layer_center_satellites():
    """Return eight radius-11 diamonds on a ring 58 units from the centre, every 45 degrees.

    The first diamond sits directly above the centre.
    """
    shapes = []
    for turn in range(0, 360, 45):
        sx, sy = rotate_point((CENTER, CENTER - 58), turn)
        shapes.append(polygon(diamond(sx, sy, 11)))
    return shapes


def guide_piece_layout():
    """Return the cutting guide: panel and opening outlines plus every piece outline (row-major)."""
    dark_line = {"stroke": "#111111", "stroke_width": 2}
    elems = [
        rect(0, 0, PANEL, PANEL, fill="white", **dark_line),
        rect(FRAME, FRAME, OPENING, OPENING, fill="none", **dark_line),
    ]
    elems.extend(
        polygon(piece_outline(r, c), fill="none", stroke="#666666", stroke_width=1.4)
        for r in range(GRID)
        for c in range(GRID)
    )
    return elems


def write_frame_profiles():
    """Write the frame outer profile, frame pocket and backer panel SVGs (one filled rect each)."""
    profiles = (
        ("frame_outer_profile.svg", rect(0, 0, PANEL, PANEL)),
        ("frame_inner_pocket.svg", rect(FRAME, FRAME, OPENING, OPENING)),
        ("backer_panel.svg", rect(0, 0, PANEL, PANEL)),
    )
    for filename, shape in profiles:
        write_svg(filename, [shape])


def build_preview_layers():
    """Generate every artwork layer.

    Returns:
        Dict mapping layer key to a list of SVG element strings. All elements
        use the default ``fill="black"`` so they can be recoloured later.
    """
    return {
        "border_chain": layer_border_chain(),
        "border_insets": layer_border_insets(),
        "inner_outer": layer_inner_frame_outer(),
        "inner_inner": layer_inner_frame_inner(),
        "corner_boxes_outer": layer_corner_boxes_outer(),
        "corner_boxes_inner": layer_corner_boxes_inner(),
        "corner_diamonds": layer_corner_diamonds(),
        "corner_rosettes": layer_corner_rosettes(),
        "arm_bands": layer_arm_bands(width=40),
        # Same geometry as arm_bands, narrower, drawn on top as an inlay strip.
        "arm_insets": layer_arm_bands(width=18),
        "arm_rosettes": layer_arm_rosettes(),
        "center_oct_outer": layer_center_octagon_outer(),
        "center_oct_inner": layer_center_octagon_inner(),
        "center_star": layer_center_star(),
        "center_core": layer_center_core(),
        "center_satellites": layer_center_satellites(),
    }


def recolor(elements, fill):
    """Return copies of *elements* with every ``fill="black"`` replaced by ``fill="<fill>"``.

    Elements with any other fill are passed through unchanged.
    """
    needle = 'fill="black"'
    replacement = f'fill="{fill}"'
    return [markup.replace(needle, replacement) for markup in elements]


def make_full_preview(layers):
    """Compose the coloured full-panel preview from the generated layers.

    Args:
        layers: Mapping of layer key to a list of black-filled SVG element
            strings, as produced by build_preview_layers().

    Returns:
        List of SVG element strings. SVG paints them in list order, so later
        entries sit on top. The piece outlines are stroked last so they appear
        above all the artwork.
    """
    base_wood = "#4d382d"
    dark_wood = "#2c211d"
    tan = "#d7b970"
    sand = "#f1e6c5"
    copper = "#8a5b3c"
    accent = "#a4472f"

    elems = [
        rect(0, 0, PANEL, PANEL, fill="#171313"),
        rect(FRAME, FRAME, OPENING, OPENING, fill=base_wood),
        # Unfilled rect whose 32-wide stroke paints a dark band near the edge.
        rect(22, 22, PANEL - 44, PANEL - 44, fill="none", stroke="#2a2422", stroke_width=32),
        *recolor(layers["border_chain"], tan),
        *recolor(layers["border_insets"], dark_wood),
        *recolor(layers["inner_outer"], tan),
        *recolor(layers["inner_inner"], base_wood),
        *recolor(layers["corner_boxes_outer"], tan),
        *recolor(layers["corner_boxes_inner"], dark_wood),
        *recolor(layers["corner_diamonds"], tan),
        *recolor(layers["corner_rosettes"], dark_wood),
        *recolor(layers["arm_bands"], tan),
        *recolor(layers["arm_insets"], copper),
        *recolor(layers["arm_rosettes"], tan),
        *recolor(layers["center_oct_outer"], tan),
        *recolor(layers["center_oct_inner"], copper),
        *recolor(layers["center_star"], tan),
        *recolor(layers["center_satellites"], sand),
        *recolor(layers["center_core"], accent),
    ]

    for row in range(GRID):
        for col in range(GRID):
            elems.append(polygon(piece_outline(row, col), fill="none", stroke="#221917", stroke_width=2))
    return elems


def write_piece_masks():
    """Write one full-panel SVG per piece into MASKS_DIR, containing only that piece's filled outline.

    Files are named ``piece_NN.svg`` and numbered from 01 in row-major order.
    """
    for index in range(GRID * GRID):
        row, col = divmod(index, GRID)
        mask = svg_doc([polygon(piece_outline(row, col))], background=None)
        target = MASKS_DIR / f"piece_{index + 1:02d}.svg"
        target.write_text(mask + "\n", encoding="utf-8")


def write_piece_outlines():
    """Write one tightly cropped, stroke-only outline SVG per piece into OUTLINES_DIR.

    Each outline is translated so its bounding box starts at (0, 0). The
    document is sized to that bounding box. Files are named
    ``piece_NN_outline.svg`` and numbered from 01 in row-major order.
    """
    for index in range(GRID * GRID):
        row, col = divmod(index, GRID)
        outline = piece_outline(row, col)
        xs = [x for x, _ in outline]
        ys = [y for _, y in outline]
        left, top = min(xs), min(ys)
        box_w = max(xs) - left
        box_h = max(ys) - top
        shifted = [(x - left, y - top) for x, y in outline]
        document = svg_doc(
            [polygon(shifted, fill="none", stroke="black", stroke_width=1.2)],
            view_box=(0, 0, box_w, box_h),
            width=box_w,
            height=box_h,
            background=None,
        )
        target = OUTLINES_DIR / f"piece_{index + 1:02d}_outline.svg"
        target.write_text(document + "\n", encoding="utf-8")


def write_piece_views(preview_elements):
    """Write one preview SVG per piece into PIECES_DIR.

    Each file shows the full coloured preview clipped to the piece outline. The
    outline is drawn on top, with a ``Pnn`` label in the corner. The view box is
    the piece's bounding box padded by an 8-unit margin. Files are named
    ``piece_NN.svg`` and numbered from 01 in row-major order.
    """
    margin = 8
    for index in range(GRID * GRID):
        row, col = divmod(index, GRID)
        piece_id = index + 1
        outline = piece_outline(row, col)
        xs = [x for x, _ in outline]
        ys = [y for _, y in outline]
        view_x = min(xs) - margin
        view_y = min(ys) - margin
        view_w = max(xs) - min(xs) + 2 * margin
        view_h = max(ys) - min(ys) + 2 * margin
        clip = point_string(outline)
        root_tag = (
            f'<svg xmlns="http://www.w3.org/2000/svg" '
            f'width="{fmt(view_w)}mm" height="{fmt(view_h)}mm" '
            f'viewBox="{fmt(view_x)} {fmt(view_y)} {fmt(view_w)} {fmt(view_h)}">'
        )
        label = (
            f'<text x="{fmt(view_x + 10)}" y="{fmt(view_y + 22)}" '
            'font-family="Helvetica, Arial, sans-serif" font-size="16" '
            f'fill="#f7f1de">P{piece_id:02d}</text>'
        )
        lines = [
            '<?xml version="1.0" encoding="UTF-8"?>',
            root_tag,
            "<defs>",
            f'<clipPath id="piece-clip"><polygon points="{clip}" /></clipPath>',
            "</defs>",
            '<g clip-path="url(#piece-clip)">',
            *preview_elements,
            "</g>",
            f'<polygon points="{clip}" fill="none" stroke="#111111" stroke-width="2.2" />',
            label,
            "</svg>",
        ]
        target = PIECES_DIR / f"piece_{piece_id:02d}.svg"
        target.write_text("\n".join(lines) + "\n", encoding="utf-8")


def main():
    """Create the output directories, then write every frame, layer, guide and per-piece SVG.

    NOTE(review): PREVIEW_DIR is created here, but nothing visible in this
    module writes into it. Confirm whether it is still needed.
    """
    for directory in (SVG_DIR, PIECES_DIR, MASKS_DIR, OUTLINES_DIR, PREVIEW_DIR):
        directory.mkdir(parents=True, exist_ok=True)

    layers = build_preview_layers()

    write_frame_profiles()
    write_svg("guide_frame_opening.svg", guide_piece_layout())

    layer_files = (
        ("layer_border_chain.svg", "border_chain"),
        ("layer_border_insets.svg", "border_insets"),
        ("layer_inner_frame_outer.svg", "inner_outer"),
        ("layer_inner_frame_inner.svg", "inner_inner"),
        ("layer_corner_boxes_outer.svg", "corner_boxes_outer"),
        ("layer_corner_boxes_inner.svg", "corner_boxes_inner"),
        ("layer_corner_diamonds.svg", "corner_diamonds"),
        ("layer_corner_rosettes.svg", "corner_rosettes"),
        ("layer_arm_bands.svg", "arm_bands"),
        ("layer_arm_insets.svg", "arm_insets"),
        ("layer_arm_rosettes.svg", "arm_rosettes"),
        ("layer_center_octagon_outer.svg", "center_oct_outer"),
        ("layer_center_octagon_inner.svg", "center_oct_inner"),
        ("layer_center_star.svg", "center_star"),
        ("layer_center_core.svg", "center_core"),
        ("layer_center_satellites.svg", "center_satellites"),
    )
    for filename, key in layer_files:
        write_svg(filename, layers[key])

    preview_elements = make_full_preview(layers)
    write_svg("guide_preview.svg", preview_elements)
    write_piece_masks()
    write_piece_outlines()
    write_piece_views(preview_elements)


if __name__ == "__main__":
    # Generate the SVG files only when run as a script, not when imported.
    main()
