import io
import json
import math
import os
import sys
import tempfile
import threading
import urllib.request
import xml.etree.ElementTree as ET
from collections import defaultdict

# Resolve this file's directory once, as an absolute path, and derive every
# other location from it so imports and data files do not depend on the
# current working directory or on how the script was launched.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Make the sibling ``backend`` directory (home of kanjiXYZ) importable.
sys.path.insert(0, os.path.join(BASE_DIR, '..', 'backend'))

from flask import Flask, jsonify, render_template, request, send_file
from svg.path import parse_path, Path, Line, CubicBezier, QuadraticBezier, Arc, Move, Close

try:
    from serial.tools import list_ports as serial_ports
except ImportError as e:
    serial_ports = None
    print(f"[WARN] Serial tools not available: {e}")

try:
    from kanjiXYZ import CoreXYZMachine
    MACHINE_AVAILABLE = True
except ImportError as e:
    CoreXYZMachine = None
    MACHINE_AVAILABLE = False
    print(f"[WARN] Machine libraries not available: {e}")

# Shared hardware handle; None when the kanjiXYZ backend could not be imported.
# Created with auto_connect=False; the serial connection is made later by
# /api/connect via machine.connect().
machine = CoreXYZMachine(auto_connect=False) if CoreXYZMachine else None
# Serializes updates to plot_state between request handlers and the
# background plot worker thread started by start_plot().
plot_lock = threading.Lock()
# "plotting": whether a plot job is running; "message": latest status text
# reported to the UI.
plot_state = {"plotting": False, "message": "Idle"}

app = Flask(__name__)

# Local KANJIDIC2 dictionary used to build the English -> kanji search index.
KANJIDIC_PATH = os.path.join(BASE_DIR, "kanjidic2.xml")
# On-disk cache of KanjiVG codepoints, one 5-digit lowercase hex value per line.
KANJIVG_CACHE = os.path.join(BASE_DIR, "kanjivg_codes.txt")
# GitHub API listing of the full KanjiVG repository tree, used to discover
# which characters have an SVG.
KANJIVG_TREE_URL = (
    "https://api.github.com/repos/KanjiVG/kanjivg/git/trees/master?recursive=1"
)


def load_kanjivg_codes():
    """Return the set of Unicode codepoints (ints) that have a primary SVG in KanjiVG.

    The set is read from ``KANJIVG_CACHE`` when that file exists and parses
    to a non-empty set. Otherwise the repository tree is fetched from
    ``KANJIVG_TREE_URL``; variant files (``<code>-<variant>.svg``) are
    ignored. A successful, complete listing is written back to the cache.

    Returns ``None`` when no usable list could be obtained (fetch failed,
    listing was empty or truncated). Callers treat ``None`` as "don't filter".
    An empty result is never cached; caching it would filter out every kanji
    on all future starts.
    """
    if os.path.exists(KANJIVG_CACHE):
        try:
            with open(KANJIVG_CACHE, encoding="ascii") as f:
                cached = {int(line.strip(), 16) for line in f if line.strip()}
        except (OSError, ValueError) as e:
            print(f"[WARN] Ignoring unreadable KanjiVG cache ({e}); refetching.")
        else:
            if cached:
                return cached
            print("[WARN] KanjiVG cache is empty; refetching.")

    print("[INIT] Fetching KanjiVG file list from GitHub ...")
    try:
        with urllib.request.urlopen(KANJIVG_TREE_URL, timeout=15) as resp:
            data = json.load(resp)
    except Exception as e:
        print(f"[WARN] KanjiVG fetch failed ({e}); filter disabled.")
        return None
    if not isinstance(data, dict):
        print("[WARN] Unexpected KanjiVG tree response; filter disabled.")
        return None

    codes = set()
    for entry in data.get("tree", []):
        path = entry.get("path", "")
        if not path.startswith("kanji/") or not path.endswith(".svg"):
            continue
        stem = path[len("kanji/"):-len(".svg")]
        if "-" in stem:  # skip variant files like 065e5-Kaisho.svg
            continue
        try:
            codes.add(int(stem, 16))
        except ValueError:
            continue

    if not codes:
        print("[WARN] KanjiVG file list contained no kanji SVGs; filter disabled.")
        return None
    if data.get("truncated"):
        # An incomplete listing would wrongly hide kanji; don't use or persist it.
        print("[WARN] KanjiVG file list was truncated by GitHub; filter disabled.")
        return None

    try:
        with open(KANJIVG_CACHE, "w", encoding="ascii") as f:
            f.writelines(f"{c:05x}\n" for c in sorted(codes))
    except OSError as e:
        # Best effort: the fetched set is still usable for this run.
        print(f"[WARN] Could not write KanjiVG cache ({e}); will refetch next start.")
    else:
        print(f"[INIT] Cached {len(codes)} KanjiVG codepoints to {KANJIVG_CACHE}.")
    return codes


# Codepoints with a KanjiVG SVG, loaded at import time (may fetch from GitHub
# on first run); None means the list was unavailable and filtering is off.
KANJIVG_CODES = load_kanjivg_codes()


def in_kanjivg(char):
    """Report whether *char* has a primary KanjiVG SVG.

    When the KanjiVG list could not be loaded at startup (KANJIVG_CODES is
    None) every character is accepted, i.e. filtering is disabled.
    """
    known = KANJIVG_CODES
    return True if known is None else ord(char) in known


def parse_kanjidic():
    """Build English <-> kanji lookup tables from the KANJIDIC2 XML file.

    Only ``<meaning>`` elements without an ``m_lang`` attribute are used
    (the ones this module treats as English).

    Returns ``(english_to_kanji, kanji_to_meanings)``:
      * english_to_kanji: lower-cased meaning -> list of kanji, in file
        order, without duplicates.
      * kanji_to_meanings: kanji -> its meanings with original case; kanji
        with no usable meaning are omitted.
    Both dicts are empty when the file is missing or cannot be parsed, so a
    bad data file disables search instead of crashing startup.
    """
    if not os.path.exists(KANJIDIC_PATH):
        print(f"[WARN] Missing KANJIDIC2 data at {KANJIDIC_PATH}; search index disabled.")
        return {}, {}
    try:
        root = ET.parse(KANJIDIC_PATH).getroot()
    except (ET.ParseError, OSError) as e:
        print(f"[WARN] Could not parse KANJIDIC2 data at {KANJIDIC_PATH} ({e}); "
              "search index disabled.")
        return {}, {}

    english_to_kanji = defaultdict(list)
    kanji_to_meanings = {}
    for character in root.findall("character"):
        literal_el = character.find("literal")
        if literal_el is None or not literal_el.text:
            continue
        kanji = literal_el.text
        meanings = []
        for meaning in character.findall("reading_meaning/rmgroup/meaning"):
            if meaning.get("m_lang") is not None or not meaning.text:
                continue
            word = meaning.text.strip()
            if not word:  # whitespace-only text would index the empty string
                continue
            meanings.append(word)
            bucket = english_to_kanji[word.lower()]
            if kanji not in bucket:
                bucket.append(kanji)
        if meanings:
            kanji_to_meanings[kanji] = meanings
    return dict(english_to_kanji), kanji_to_meanings


def build_indexes():
    """Parse KANJIDIC2 and log how large the resulting indexes are.

    Returns the ``(english_to_kanji, kanji_to_meanings)`` pair produced by
    parse_kanjidic(), unchanged.
    """
    print("[INIT] Parsing KANJIDIC2 ...")
    english_index, kanji_index = parse_kanjidic()
    word_count = len(english_index)
    kanji_count = len(kanji_index)
    print(f"[INIT] Indexed {word_count} English words and {kanji_count} kanji.")
    return english_index, kanji_index


# Search indexes, built once at import time from KANJIDIC2.
ENGLISH_TO_KANJI, KANJI_TO_MEANINGS = build_indexes()

# SVG text keyed by KanjiVG filename; entries are never evicted, so it lives
# (and grows) for the whole process.
svg_cache = {}
# Base URL for raw KanjiVG SVG files (jsDelivr mirror of the GitHub repo).
KANJIVG_CDN = "https://cdn.jsdelivr.net/gh/KanjiVG/kanjivg@master/kanji/"


def download_svg_to_temp(filename):
    """Write the KanjiVG SVG *filename* to a new temporary file and return its path.

    The SVG text is downloaded from KANJIVG_CDN on first use and memoized in
    svg_cache. The temp file is created with ``delete=False``; the caller
    owns it and must remove it.

    Raises ValueError if *filename* is not a plain KanjiVG name (lowercase
    hex codepoint + ".svg", as produced by to_filename). Download errors
    from urllib propagate to the caller.
    """
    # The filename comes straight from request JSON and is appended to the
    # CDN URL; allowing only hex digits keeps it from steering the request
    # to another path or repository ("../", "?", "#", ...).
    stem = filename[:-4] if isinstance(filename, str) and filename.endswith(".svg") else ""
    if not stem or any(ch not in "0123456789abcdef" for ch in stem):
        raise ValueError(f"Invalid KanjiVG filename: {filename!r}")

    if filename not in svg_cache:
        with urllib.request.urlopen(KANJIVG_CDN + filename, timeout=10) as resp:
            svg_cache[filename] = resp.read().decode('utf-8')
    # The SVG text may contain non-ASCII characters, so write UTF-8 explicitly
    # instead of relying on the platform default encoding.
    with tempfile.NamedTemporaryFile(suffix='.svg', delete=False, mode='w',
                                     encoding='utf-8') as tmp:
        tmp.write(svg_cache[filename])
    return tmp.name






# --- English -> katakana transliteration (approximate) ----------------
# Ordered longest-first so digraphs match before single letters: all
# 3-letter patterns, then 2-letter, then single letters.
KATAKANA_RULES = [
    # 3-letter patterns
    ("tch", "ッチ"), ("sch", "シュ"),
    # C + ei (English /iː/, as in "Keith") — must precede 2-letter CV pairs
    ("bei", "ビー"), ("dei", "ディー"), ("fei", "フィー"),
    ("gei", "ギー"), ("hei", "ヒー"), ("kei", "キー"),
    ("lei", "リー"), ("mei", "ミー"), ("nei", "ニー"),
    ("pei", "ピー"), ("rei", "リー"), ("sei", "シー"),
    ("tei", "ティー"), ("wei", "ウィー"),
    # C + ee (English /iː/)
    ("bee", "ビー"), ("dee", "ディー"), ("fee", "フィー"),
    ("gee", "ジー"), ("kee", "キー"), ("lee", "リー"),
    ("mee", "ミー"), ("nee", "ニー"), ("pee", "ピー"),
    ("ree", "リー"), ("see", "シー"), ("tee", "ティー"),
    # 2-letter digraphs ("ck" as in "Jack" -> ジャック)
    ("sh", "シ"), ("ch", "チ"), ("th", "ス"), ("ph", "フ"), ("ng", "ング"),
    ("ck", "ック"),
    # long/diphthong vowels
    ("ou", "ウ"), ("oo", "ウー"), ("ee", "イー"), ("ei", "イー"), ("ai", "アイ"),
    ("au", "オー"), ("aw", "オー"), ("ea", "イー"), ("ie", "アイ"),
    # CV pairs
    ("ka", "カ"), ("ki", "キ"), ("ku", "ク"), ("ke", "ケ"), ("ko", "コ"),
    # c: hard before a/o/u (like k), soft before e/i (like s)
    ("ca", "カ"), ("ci", "シ"), ("cu", "ク"), ("ce", "セ"), ("co", "コ"),
    ("ga", "ガ"), ("gi", "ギ"), ("gu", "グ"), ("ge", "ゲ"), ("go", "ゴ"),
    ("sa", "サ"), ("si", "シ"), ("su", "ス"), ("se", "セ"), ("so", "ソ"),
    ("za", "ザ"), ("zi", "ジ"), ("zu", "ズ"), ("ze", "ゼ"), ("zo", "ゾ"),
    ("ta", "タ"), ("ti", "チ"), ("tu", "ツ"), ("te", "テ"), ("to", "ト"),
    ("da", "ダ"), ("di", "ディ"), ("du", "ドゥ"), ("de", "デ"), ("do", "ド"),
    ("na", "ナ"), ("ni", "ニ"), ("nu", "ヌ"), ("ne", "ネ"), ("no", "ノ"),
    ("ha", "ハ"), ("hi", "ヒ"), ("hu", "フ"), ("he", "ヘ"), ("ho", "ホ"),
    ("ba", "バ"), ("bi", "ビ"), ("bu", "ブ"), ("be", "ベ"), ("bo", "ボ"),
    ("pa", "パ"), ("pi", "ピ"), ("pu", "プ"), ("pe", "ペ"), ("po", "ポ"),
    ("fa", "ファ"), ("fi", "フィ"), ("fu", "フ"), ("fe", "フェ"), ("fo", "フォ"),
    ("ma", "マ"), ("mi", "ミ"), ("mu", "ム"), ("me", "メ"), ("mo", "モ"),
    ("ya", "ヤ"), ("yu", "ユ"), ("yo", "ヨ"),
    ("ra", "ラ"), ("ri", "リ"), ("ru", "ル"), ("re", "レ"), ("ro", "ロ"),
    ("la", "ラ"), ("li", "リ"), ("lu", "ル"), ("le", "レ"), ("lo", "ロ"),
    ("wa", "ワ"), ("wi", "ウィ"), ("we", "ウェ"), ("wo", "ウォ"),
    ("va", "ヴァ"), ("vi", "ヴィ"), ("vu", "ヴ"), ("ve", "ヴェ"), ("vo", "ヴォ"),
    ("ja", "ジャ"), ("ji", "ジ"), ("ju", "ジュ"), ("je", "ジェ"), ("jo", "ジョ"),
    # standalone vowels
    ("a", "ア"), ("i", "イ"), ("u", "ウ"), ("e", "エ"), ("o", "オ"),
    # orphan consonants get a default trailing vowel
    ("k", "ク"), ("s", "ス"), ("t", "ト"), ("n", "ン"),
    ("h", "フ"), ("m", "ム"), ("y", "イ"), ("r", "ル"),
    ("g", "グ"), ("z", "ズ"), ("d", "ド"), ("b", "ブ"),
    ("p", "プ"), ("f", "フ"), ("v", "ヴ"), ("j", "ジ"),
    ("l", "ル"), ("w", "ウ"), ("q", "ク"), ("x", "クス"), ("c", "ク"),
]


def to_katakana(name):
    """Transliterate an English name into katakana, roughly.

    The result is a starting point for the user to edit, not an accurate
    reading. At each position the first matching entry of KATAKANA_RULES
    wins; a doubled consonant first emits a small tsu (ッ); characters that
    no rule covers are dropped.
    """
    word = name.lower().strip()

    # Drop a silent ending: "gh" ("Leigh"), or "h" after a vowel ("Sarah").
    if word.endswith("gh"):
        word = word[:-2]
    elif len(word) >= 2 and word[-1] == "h" and word[-2] in "aeiou":
        word = word[:-1]

    pieces = []
    pos = 0
    length = len(word)
    while pos < length:
        current = word[pos]
        following = word[pos + 1] if pos + 1 < length else None
        if following == current and current in "bcdfgklmpqrstvwxz":
            # Emit ッ and step onto the second letter, which is then
            # transliterated normally on the next iteration.
            pieces.append("ッ")
            pos += 1
            continue
        kana, advance = next(
            ((out, len(pattern)) for pattern, out in KATAKANA_RULES
             if word.startswith(pattern, pos)),
            (None, 1),  # no rule matched: skip this character
        )
        if kana is not None:
            pieces.append(kana)
        pos += advance
    return "".join(pieces)


def to_filename(char):
    """Return the KanjiVG file name for *char*: its codepoint as at least five
    lowercase hex digits followed by ".svg" (e.g. 日 -> "065e5.svg")."""
    return "{:05x}.svg".format(ord(char))


def character_payload(char):
    """Describe *char* for the JSON API.

    Keys: "kanji" (the character), "filename" (its KanjiVG SVG name),
    "codepoint" ("U+XXXX", at least four uppercase hex digits) and
    "meanings" (English meanings from KANJIDIC2, or [] when unknown).
    """
    payload = {}
    payload["kanji"] = char
    payload["filename"] = to_filename(char)
    payload["codepoint"] = "U+%04X" % ord(char)
    payload["meanings"] = KANJI_TO_MEANINGS.get(char, [])
    return payload


def machine_status_payload():
    """Build the status dict served by /api/machine/status.

    Starts from ``machine.status()`` when the machine object provides it,
    otherwise from a "not connected" placeholder. It then adds "plotting"
    and "homing" (always False, because homing is not supported). While a
    plot is running, it replaces "message" with the plot's progress message.
    """
    if machine and hasattr(machine, "status"):
        # Copy in case status() hands back an internal dict: the keys added
        # below must not leak into the machine's own state.
        payload = dict(machine.status())
    else:
        payload = {
            "connected": False,
            "origin_set": False,
            "emergency_stop": False,
            "pos": {"x": 0.0, "y": 0.0, "z": 0.0},
            "ports": {"a": None, "b": None, "c": None},
            "message": "Not connected.",
        }

    # Snapshot both fields together so the worker thread can't update one
    # between the two reads.
    with plot_lock:
        plotting = plot_state["plotting"]
        plot_message = plot_state["message"]

    payload["plotting"] = plotting
    payload["homing"] = False
    if plotting:
        payload["message"] = plot_message
    return payload


def start_plot(svg_path):
    """Plot *svg_path* on a background daemon thread.

    Returns ``(True, None)`` once the job has been started, or
    ``(False, reason)`` if the machine is not connected or a plot is already
    running. After a successful start the worker owns *svg_path*. It deletes
    the file when finished and records the outcome in plot_state: "Done!",
    "Stopped." (emergency stop was hit), or "Plot failed: ...".
    """
    if not (machine and machine.connected):
        return False, "Not connected"
    with plot_lock:
        if plot_state["plotting"]:
            return False, "Already plotting"
        plot_state.update(plotting=True, message="Plotting...")

    def run_job():
        outcome = "Done!"
        try:
            machine.draw_svg(svg_path)
            if machine.emergency_stop:
                outcome = "Stopped."
        except Exception as exc:
            outcome = f"Plot failed: {exc}"
        finally:
            # Always clean up and release the "plotting" flag, whatever happened.
            try:
                os.unlink(svg_path)
            except OSError:
                pass
            with plot_lock:
                plot_state.update(plotting=False, message=outcome)

    worker_thread = threading.Thread(target=run_job, daemon=True)
    worker_thread.start()
    return True, None


@app.route('/')
def index():
    """Serve the single-page front end."""
    page = 'index.html'
    return render_template(page)


@app.route('/api/search')
def api_search():
    """Return kanji whose English meaning exactly matches ``?q=`` (case-insensitive).

    Only kanji that KanjiVG can draw are included; an empty query yields no
    results.
    """
    query = request.args.get('q', '').strip().lower()
    if query:
        results = [
            character_payload(kanji)
            for kanji in ENGLISH_TO_KANJI.get(query, [])
            if in_kanjivg(kanji)  # drop non-KanjiVG kanji
        ]
        print(f"[SEARCH] '{query}' -> {len(results)} match(es) in KanjiVG")
    else:
        results = []
    return jsonify({"query": query, "results": results})


@app.route('/api/convert')
def api_convert():
    """Describe the first character of ``?char=``.

    Responds 400 when no character is given and 404 when KanjiVG has no SVG
    for it.
    """
    raw = request.args.get('char', '').strip()
    if not raw:
        return jsonify({"error": "No character provided"}), 400
    char = raw[0]
    if in_kanjivg(char):
        payload = character_payload(char)
        print(f"[CONVERT] {char} -> {payload['filename']}")
        return jsonify(payload)
    return jsonify({"error": f"'{char}' (U+{ord(char):04X}) is not in KanjiVG"}), 404


@app.route('/api/name')
def api_name():
    """Transliterate ``?name=`` into katakana and describe each resulting character."""
    name = request.args.get('name', '').strip()
    if not name:
        return jsonify({"name": "", "katakana": "", "characters": []})
    # Keep only characters KanjiVG can draw. Standard katakana should all be
    # present, but the filter keeps the KanjiVG contract explicit.
    katakana = "".join(filter(in_kanjivg, to_katakana(name)))
    payloads = list(map(character_payload, katakana))
    print(f"[NAME] '{name}' -> '{katakana}' ({len(payloads)} chars)")
    response = {"name": name, "katakana": katakana, "characters": payloads}
    return jsonify(response)


@app.route('/api/ports')
def api_ports():
    """List serial ports (device and description) available for the motors."""
    if MACHINE_AVAILABLE and serial_ports:
        ports = [
            {'port': info.device, 'desc': info.description}
            for info in serial_ports.comports()
        ]
        return jsonify({'ports': ports})
    return jsonify({'ports': [], 'error': 'Machine libraries not installed'})

@app.route('/api/connect', methods=['POST'])
def api_connect():
    """Connect the three motors to the serial ports named in the JSON body.

    Body: {"a_port": ..., "b_port": ..., "c_port": ...}; all three are
    required. Refused while a plot is running. The response echoes the
    ports reported by machine.connect().
    """
    if not MACHINE_AVAILABLE:
        return jsonify({'error': 'Machine libraries not installed'}), 500
    if plot_state["plotting"]:
        return jsonify({'error': 'Cannot reconnect while plotting'}), 400
    data = request.get_json() or {}
    selected = [data.get(key) for key in ('a_port', 'b_port', 'c_port')]
    if not all(selected):
        return jsonify({'error': 'Select ports for motors A, B, and C.'}), 400
    port_a, port_b, port_c = selected
    try:
        ports = machine.connect(port_a=port_a, port_b=port_b, port_c=port_c)
        with plot_lock:
            plot_state["message"] = "Connected."
        body = {
            'connected': True,
            'ports': {'a': ports[0], 'b': ports[1], 'c': ports[2]},
            'message': 'Connected.',
        }
        return jsonify(body)
    except Exception as e:
        return jsonify({'connected': False, 'error': str(e)}), 500


@app.route('/api/disconnect', methods=['POST'])
def api_disconnect():
    """Disconnect the machine unless a plot is currently running."""
    busy = plot_state["plotting"]
    if busy:
        return jsonify({'error': 'Cannot disconnect while plotting.'}), 400
    if machine:
        machine.disconnect()
    with plot_lock:
        plot_state.update(message="Disconnected.")
    return jsonify({'connected': False})


@app.route('/api/machine/status')
def api_machine_status():
    """Report machine and plot status (see machine_status_payload)."""
    status = machine_status_payload()
    return jsonify(status)


@app.route('/api/calibration', methods=['GET', 'POST'])
def api_calibration():
    """Tell the UI that calibration cannot be changed at runtime."""
    body = {
        'supported': False,
        'message': 'Calibration is fixed inside backend/kanjiXYZ.py in the current setup.',
    }
    return jsonify(body)


@app.route('/api/origin', methods=['POST'])
def api_origin():
    """Ask the machine to set its home/origin (machine.set_home()).

    Responds 400 when no machine is connected.
    """
    if machine and machine.connected:
        machine.set_home()
        return jsonify({'origin_set': True})
    return jsonify({'error': 'Not connected'}), 400


@app.route('/api/jog', methods=['POST'])
def api_jog():
    """Move one axis by a relative distance.

    JSON body: {"axis": "x" | "y" | "z", "distance": <mm, default 1>}.
    Responds with the position after the move. Responds 400 when the machine
    is not connected, the emergency stop is active, the axis is unknown, or
    the distance is not a finite number.
    """
    if not machine or not machine.connected:
        return jsonify({'error': 'Not connected'}), 400
    if machine.emergency_stop:
        return jsonify({'error': 'Emergency stop active'}), 400
    # A missing, non-JSON or non-object body becomes {} (then rejected as an
    # invalid axis) instead of crashing with AttributeError.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    axis = data.get('axis')
    try:
        distance_mm = float(data.get('distance', 1))
    except (TypeError, ValueError):
        return jsonify({'error': 'Invalid distance'}), 400
    # float() accepts "nan" and "inf"; never send a non-finite move to the motors.
    if not math.isfinite(distance_mm):
        return jsonify({'error': 'Invalid distance'}), 400
    dx = dy = dz = 0
    if axis == 'x':
        dx = 1
    elif axis == 'y':
        dy = 1
    elif axis == 'z':
        dz = 1
    else:
        return jsonify({'error': 'Invalid axis'}), 400
    machine.move_rel(dx, dy, dz, distance_mm)
    return jsonify({'pos': {'x': machine.pos_x, 'y': machine.pos_y, 'z': machine.pos_z}})


@app.route('/api/home/xy', methods=['POST'])
def api_home_xy():
    """XY homing is not available; always answers 400."""
    reason = 'Home XY is not supported by the current CoreXYZMachine.'
    return jsonify({'error': reason}), 400


@app.route('/api/home/z', methods=['POST'])
def api_home_z():
    """Z homing is not available; always answers 400."""
    reason = 'Home Z is not supported by the current CoreXYZMachine.'
    return jsonify({'error': reason}), 400


@app.route('/api/plot', methods=['POST'])
def api_plot():
    """Start plotting the requested characters.

    JSON body: {"filenames": [...]}. Requires a connected machine with its
    origin set. A single file is plotted as-is; several are laid out by
    build_combined_svg(). The SVG is written to a temp file that start_plot()
    takes ownership of; if the plot cannot start, the file is removed here.
    Preparation failures (bad filename, download error, bad SVG) are
    reported as a JSON error instead of an unhandled exception.
    """
    if not machine or not machine.connected:
        return jsonify({'error': 'Not connected'}), 400
    if not machine.origin_set:
        return jsonify({'error': 'Set home/origin before plotting.'}), 400
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    filenames = data.get('filenames', [])
    if not filenames:
        return jsonify({'error': 'No characters to plot'}), 400

    try:
        if len(filenames) == 1:
            svg_path = download_svg_to_temp(filenames[0])
        else:
            svg_content, _, _ = build_combined_svg(filenames)
            with tempfile.NamedTemporaryFile(suffix='.svg', delete=False, mode='w',
                                             encoding='utf-8') as tmp:
                tmp.write(svg_content)
            svg_path = tmp.name
    except Exception as e:
        return jsonify({'error': str(e)}), 500

    started, error = start_plot(svg_path)
    if not started:
        try:
            os.unlink(svg_path)
        except OSError:
            pass
        return jsonify({'error': error}), 400
    return jsonify({'started': True})


@app.route('/api/plot/stop', methods=['POST'])
def api_plot_stop():
    """A graceful stop is not implemented; point callers at the emergency stop."""
    reason = 'Stop plot is not supported; use Emergency Stop instead.'
    return jsonify({'error': reason}), 400


@app.route('/api/emergency', methods=['POST'])
def api_emergency():
    """Call machine.disable_all() (if there is a machine) and flag the stop in the status message.

    Always answers {'stopped': True}.
    """
    if machine:
        machine.disable_all()
    with plot_lock:
        plot_state.update(message="Emergency stop active.")
    return jsonify({'stopped': True})


# Edge length of one character cell in SVG user units; presumably matches
# KanjiVG's 109x109 viewBox — TODO confirm.
CHAR_SIZE = 109
# Gap between adjacent character cells in the combined layout, same units.
CHAR_GAP = 15


def translate_d_string(d_string, offset_x, offset_y=0):
    """Shift SVG path data by (offset_x, offset_y) and return the new ``d`` string.

    The path is parsed with svg.path. Each Move, Line, CubicBezier,
    QuadraticBezier, Arc and Close segment is rebuilt with all of its points
    moved; an Arc keeps its radius, rotation and flags. Segments of any
    other type are left out of the result.
    """
    delta = complex(offset_x, offset_y)
    rebuilt = []
    for segment in parse_path(d_string):
        if isinstance(segment, Move):
            rebuilt.append(Move(to=segment.start + delta))
        elif isinstance(segment, Line):
            rebuilt.append(Line(start=segment.start + delta,
                                end=segment.end + delta))
        elif isinstance(segment, CubicBezier):
            rebuilt.append(CubicBezier(start=segment.start + delta,
                                       control1=segment.control1 + delta,
                                       control2=segment.control2 + delta,
                                       end=segment.end + delta))
        elif isinstance(segment, QuadraticBezier):
            rebuilt.append(QuadraticBezier(start=segment.start + delta,
                                           control=segment.control + delta,
                                           end=segment.end + delta))
        elif isinstance(segment, Arc):
            rebuilt.append(Arc(start=segment.start + delta,
                               radius=segment.radius,
                               rotation=segment.rotation,
                               arc=segment.arc,
                               sweep=segment.sweep,
                               end=segment.end + delta))
        elif isinstance(segment, Close):
            rebuilt.append(Close(start=segment.start + delta,
                                 end=segment.end + delta))
    return Path(*rebuilt).d()


def extract_paths_from_svg(svg_text):
    """Return the ``d`` attribute of every ``<path>`` element in *svg_text*.

    Paths are returned in document order. ``<path>`` matches with or
    without an XML namespace; elements whose ``d`` is missing or empty are
    skipped.
    """
    root = ET.fromstring(svg_text)
    paths = []
    for elem in root.iter():
        tag = elem.tag
        # Compare the exact local name: endswith('path') alone would also
        # accept other elements such as <mpath>. Non-str tags (comments or
        # processing instructions, if a parser keeps them) are skipped.
        if isinstance(tag, str) and (tag == 'path' or tag.endswith('}path')):
            d = elem.get('d')
            if d:
                paths.append(d)
    return paths


def compute_layout(n):
    """Place *n* character cells and return ``(positions, total_w, total_h)``.

    positions lists the top-left (x, y) of each cell in input order:
      - 1-3 chars: a single vertical column, top to bottom.
      - 4+ chars: a square-ish grid (ceil(sqrt(n)) columns) filled row by
        row; an incomplete last row is centered horizontally.
    Cells are CHAR_SIZE wide with CHAR_GAP between neighbours. n <= 0
    gives ([], 0, 0).
    """
    if n <= 0:
        return [], 0, 0

    step = CHAR_SIZE + CHAR_GAP

    def span(count):
        # Extent of `count` cells laid side by side with gaps between them.
        return count * CHAR_SIZE + (count - 1) * CHAR_GAP

    if n <= 3:
        return [(0, row * step) for row in range(n)], CHAR_SIZE, span(n)

    cols = math.ceil(math.sqrt(n))
    rows = math.ceil(n / cols)
    total_w, total_h = span(cols), span(rows)

    positions = []
    for row in range(rows):
        # Every row before the last is full, so row * cols cells precede it.
        count = min(cols, n - row * cols)
        left = (total_w - span(count)) / 2.0
        positions.extend((left + col * step, row * step) for col in range(count))
    return positions, total_w, total_h


def build_combined_svg(filenames):
    """Lay several KanjiVG characters out in one SVG document.

    Each file's SVG text is fetched from KANJIVG_CDN (memoized in
    svg_cache). Its path ``d`` strings are moved into the cell assigned by
    compute_layout(), and every path is emitted as an unfilled black stroke.

    Returns ``(svg_text, total_w, total_h)``. Raises ValueError, before
    anything is downloaded, if any filename is not a plain KanjiVG name
    (lowercase hex codepoint + ".svg"). Download errors from urllib
    propagate.
    """
    # Filenames come from request JSON and are appended to the CDN URL; only
    # hex-digit stems are allowed so they can't point at another path.
    for filename in filenames:
        stem = filename[:-4] if isinstance(filename, str) and filename.endswith(".svg") else ""
        if not stem or any(ch not in "0123456789abcdef" for ch in stem):
            raise ValueError(f"Invalid KanjiVG filename: {filename!r}")

    positions, total_w, total_h = compute_layout(len(filenames))

    all_paths = []
    for filename, (offset_x, offset_y) in zip(filenames, positions):
        if filename not in svg_cache:
            with urllib.request.urlopen(KANJIVG_CDN + filename, timeout=10) as resp:
                svg_cache[filename] = resp.read().decode('utf-8')
        for d in extract_paths_from_svg(svg_cache[filename]):
            all_paths.append(translate_d_string(d, offset_x, offset_y))

    svg_lines = [
        f'<svg xmlns="http://www.w3.org/2000/svg" '
        f'viewBox="0 0 {total_w} {total_h}" '
        f'width="{total_w}" height="{total_h}">',
    ]
    for d in all_paths:
        svg_lines.append(
            f'  <path d="{d}" fill="none" stroke="#000" stroke-width="3" '
            f'stroke-linecap="round" stroke-linejoin="round"/>')
    svg_lines.append('</svg>')
    return '\n'.join(svg_lines), total_w, total_h


@app.route('/api/simulate', methods=['POST'])
def api_simulate():
    """Return the pen strokes the machine would draw for the given characters.

    JSON body: {"filenames": [...]}. A single file is used as-is; several
    are laid out by build_combined_svg(). Response:
    {"strokes": [[[x, y], ...], ...]}, one point list per path as
    discretized by the machine helper. The temp SVG is always removed.
    """
    # A missing, non-JSON or non-object body is treated as "no filenames"
    # instead of crashing with AttributeError outside the try below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    filenames = data.get('filenames', [])
    if not filenames:
        return jsonify({'error': 'No filenames'}), 400
    # Check before downloading or writing anything that would be thrown away.
    if not machine:
        return jsonify({'error': 'Machine helper unavailable'}), 500
    try:
        if len(filenames) == 1:
            svg_path = download_svg_to_temp(filenames[0])
        else:
            svg_content, _, _ = build_combined_svg(filenames)
            with tempfile.NamedTemporaryFile(suffix='.svg', delete=False, mode='w',
                                             encoding='utf-8') as tmp:
                tmp.write(svg_content)
            svg_path = tmp.name

        try:
            strokes = [
                [[p[0], p[1]] for p in machine.discretize_path(d)]
                for d in machine.extract_svg_paths(svg_path)
            ]
        finally:
            os.unlink(svg_path)

        return jsonify({'strokes': strokes})
    except Exception as e:
        return jsonify({'error': str(e)}), 500


@app.route('/api/combine-svg', methods=['POST'])
def api_combine_svg():
    """Return the combined layout SVG for the given filenames as JSON.

    JSON body: {"filenames": [...]}. Response: {"svg": ..., "width": ...,
    "height": ...}; build failures are reported as a 500 JSON error.
    """
    # A missing, non-JSON or non-object body is treated as "no filenames"
    # rather than raising AttributeError.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    filenames = data.get('filenames', [])
    if not filenames:
        return jsonify({'error': 'No filenames provided'}), 400
    try:
        svg_content, total_w, total_h = build_combined_svg(filenames)
        return jsonify({'svg': svg_content, 'width': total_w, 'height': total_h})
    except Exception as e:
        return jsonify({'error': str(e)}), 500


@app.route('/api/combine-svg/download', methods=['POST'])
def api_combine_svg_download():
    """Return the combined layout as a downloadable ``combined.svg`` attachment.

    JSON body: {"filenames": [...]}. The SVG is served from an in-memory
    buffer, so no temporary file is left behind on disk for each request.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    filenames = data.get('filenames', [])
    if not filenames:
        return jsonify({'error': 'No filenames provided'}), 400
    try:
        svg_content, _, _ = build_combined_svg(filenames)
        return send_file(io.BytesIO(svg_content.encode('utf-8')),
                         mimetype='image/svg+xml',
                         as_attachment=True, download_name='combined.svg')
    except Exception as e:
        return jsonify({'error': str(e)}), 500


@app.route('/api/plot/combined', methods=['POST'])
def api_plot_combined():
    """Plot the given characters using the combined layout, even for one file.

    JSON body: {"filenames": [...]}. Requires a connected machine with its
    origin set. The combined SVG is written to a temp file that start_plot()
    takes ownership of; if the plot cannot start, the file is removed here.
    Layout or download failures are reported as a JSON error.
    """
    if not machine or not machine.connected:
        return jsonify({'error': 'Not connected'}), 400
    if not machine.origin_set:
        return jsonify({'error': 'Set home/origin before plotting.'}), 400
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    filenames = data.get('filenames', [])
    if not filenames:
        return jsonify({'error': 'No filenames'}), 400

    try:
        svg_content, _, _ = build_combined_svg(filenames)
    except Exception as e:
        return jsonify({'error': str(e)}), 500

    with tempfile.NamedTemporaryFile(suffix='.svg', delete=False, mode='w',
                                     encoding='utf-8') as tmp:
        tmp.write(svg_content)
    started, error = start_plot(tmp.name)
    if not started:
        try:
            os.unlink(tmp.name)
        except OSError:
            pass
        return jsonify({'error': error}), 400
    return jsonify({'started': True})


if __name__ == '__main__':
    # Localhost on port 5001 by default; HOST, PORT and FLASK_DEBUG=1
    # override these through the environment.
    host = os.environ.get('HOST', '127.0.0.1')
    port = int(os.environ.get('PORT', 5001))
    debug = os.environ.get('FLASK_DEBUG', '0') == '1'
    # Reloader off — presumably so module-level startup work (KanjiVG fetch,
    # machine object) isn't run twice; confirm before enabling it.
    app.run(host=host, port=port, debug=debug, use_reloader=False)
