"""
LIF Neuron Visualizer
Usage: python3 visualizer.py /dev/cu.usbmodem1101
"""

import sys
import serial
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.path as mpath
import matplotlib.patches as mpatches
from matplotlib.animation import FuncAnimation
from matplotlib.patches import Circle, Ellipse, PathPatch
from collections import deque
import numpy as np

# Serial device path: first command-line argument if given, else the default.
PORT = sys.argv[1] if sys.argv[1:] else "/dev/cu.usbmodem1101"
BAUD = 115200  # serial baud rate
WINDOW = 400  # number of most recent samples kept for the scrolling plots

# The port is opened at import time; the short read timeout bounds how long
# each readline() call in update() can block.
ser = serial.Serial(port=PORT, baudrate=BAUD, timeout=0.05)

# Biological scale
V_REST_MV = -70.0    # resting potential (mV); a normalized voltage of 0 maps here
V_THRESH_MV = -55.0  # firing threshold (mV)
V_SPIKE_MV = 40.0    # value plotted for samples flagged as spikes (mV)
I_MAX_PA = 500.0     # current (pA) that a normalized input of 1.0 maps to
# Normalized voltage that maps onto V_THRESH_MV.
# NOTE(review): presumably the firmware's normalized spike threshold -- confirm.
V_THRESH_NORM = 0.7


def to_mv(v_norm, v_thresh_norm=V_THRESH_NORM):
    """Convert a normalized membrane voltage to millivolts.

    The mapping is linear: 0 maps to ``V_REST_MV`` and *v_thresh_norm* maps
    to ``V_THRESH_MV``.  Values outside that range are extrapolated, not
    clamped.

    Args:
        v_norm: Normalized membrane voltage.
        v_thresh_norm: Normalized value that corresponds to the threshold
            voltage.  Defaults to ``V_THRESH_NORM``.

    Returns:
        The voltage in millivolts.
    """
    return V_REST_MV + (v_norm / v_thresh_norm) * (V_THRESH_MV - V_REST_MV)


def to_pa(i_norm):
    """Convert a normalized input current to picoamperes (1.0 -> I_MAX_PA)."""
    return i_norm * I_MAX_PA

# Data buffers
# Running totals since start-up: samples received and spikes seen.
count = spike_count = 0
# Rolling plot window: sample index, membrane voltage (mV), input current (pA).
xs, voltages_mv, currents_pa = (deque(maxlen=WINDOW) for _ in range(3))
# Sample indices of the most recent spikes (at most 50 are remembered).
spike_xs = deque(maxlen=50)

# Colors
# Global dark theme plus the palette shared by the neuron diagram and plots.
plt.style.use("dark_background")
BG = "#060612"  # figure and axes background
CYAN = "#00e5ff"  # soma outline and membrane-voltage trace
GREEN = "#00e676"  # dendrites and input-current trace
RED = "#ff1744"  # threshold line, spike ring and axon terminals
ORANGE = "#ff9100"  # axon and myelin sheaths
GREY = "#333333"  # axis spines, ticks and subtitle
WHITE = "#e0e0e0"  # title and axis labels
PURPLE = "#b388ff"  # nucleus and nucleolus

# --- Layout: neuron on top (big), plots below ---
fig = plt.figure(figsize=(14, 9), facecolor=BG)
gs = gridspec.GridSpec(
    nrows=3,
    ncols=2,
    height_ratios=[3, 1.8, 1.8],
    hspace=0.3,
    wspace=0.25,
)

# Every panel spans both grid columns.  The current plot shares the voltage
# plot's x axis.
ax_neuron, ax_v = (fig.add_subplot(gs[row, :], facecolor=BG) for row in (0, 1))
ax_i = fig.add_subplot(gs[2, :], sharex=ax_v, facecolor=BG)


# ============================================================
# NEURON DIAGRAM (organic, curved, big)
# ============================================================
# Fixed data-space canvas with equal x/y scaling and no visible axes.
ax_neuron.set_aspect("equal")
ax_neuron.set_xlim(-8, 10)
ax_neuron.set_ylim(-4.5, 4.5)
ax_neuron.axis("off")


def organic_curve(pts, num=40):
    """Resample a polyline into *num* points and add a small wobble.

    The control points in *pts* are treated as evenly spaced along a
    parameter running from 0 to 1 and are resampled by piecewise-linear
    interpolation (``np.interp``).  A fixed sine/cosine offset of amplitude
    0.03 is then added to x and y respectively, so the curve does not pass
    exactly through the control points.

    Args:
        pts: Sequence of ``(x, y)`` control points.
        num: Number of samples to return.

    Returns:
        A tuple ``(x, y)`` of arrays, each of length *num*.
    """
    control = np.array(pts)
    knots = np.linspace(0, 1, len(control))
    samples = np.linspace(0, 1, num)

    wobble = 0.03
    x_curve = np.interp(samples, knots, control[:, 0]) + np.sin(samples * np.pi * 3) * wobble
    y_curve = np.interp(samples, knots, control[:, 1]) + np.cos(samples * np.pi * 2.5) * wobble
    return x_curve, y_curve


def draw_branch(ax, pts, base_width=2.0, color=GREEN, base_alpha=0.5):
    """Plot a branch through *pts* on *ax* that thins and fades along its length.

    The path returned by ``organic_curve`` (30 samples) is drawn as separate
    two-point segments so that each one can carry its own line width and
    alpha.  Width falls towards 30% of *base_width* and alpha towards 60% of
    *base_alpha* from the first point to the last.

    Returns:
        The list of line artists created by ``ax.plot``, in drawing order.
    """
    curve_x, curve_y = organic_curve(pts, 30)
    total = len(curve_x)
    segments = []
    for idx in range(total - 1):
        progress = idx / total  # 0 at the first point, approaching 1 at the last
        (segment,) = ax.plot(
            [curve_x[idx], curve_x[idx + 1]],
            [curve_y[idx], curve_y[idx + 1]],
            color=color,
            linewidth=base_width * (1 - 0.7 * progress),
            alpha=base_alpha * (1 - 0.4 * progress),
            solid_capstyle="round",
        )
        segments.append(segment)
    return segments


# -- Dendrite trees (left side, organic branching) --
# Each entry is (control points, base line width).  Branches are drawn in the
# order listed: trunk first, then its sub-branches.
_DENDRITE_BRANCHES = [
    # upper dendrite trunk
    ([(-1.8, 0.4), (-2.8, 1.0), (-3.5, 1.8), (-4.2, 2.8)], 3.0),
    # upper sub-branches
    ([(-3.2, 1.5), (-3.8, 2.5), (-4.5, 3.2)], 1.5),
    ([(-3.5, 1.8), (-4.4, 2.0), (-5.2, 2.5)], 1.2),
    ([(-4.2, 2.8), (-4.8, 3.5)], 1.0),
    ([(-4.2, 2.8), (-5.0, 3.0)], 0.8),
    ([(-3.8, 2.5), (-4.0, 3.3)], 0.7),
    ([(-5.2, 2.5), (-5.8, 3.0)], 0.6),
    ([(-5.2, 2.5), (-5.6, 2.2)], 0.5),
    # middle dendrite
    ([(-1.8, 0.0), (-3.0, 0.1), (-4.0, -0.2), (-5.0, 0.3)], 2.5),
    ([(-4.0, -0.2), (-4.8, -0.8), (-5.5, -0.5)], 1.2),
    ([(-5.0, 0.3), (-5.8, 0.8)], 0.8),
    ([(-5.0, 0.3), (-5.7, 0.0)], 0.6),
    # lower dendrite trunk
    ([(-1.8, -0.4), (-2.8, -1.0), (-3.5, -1.8), (-4.2, -2.8)], 3.0),
    # lower sub-branches
    ([(-3.2, -1.5), (-3.8, -2.5), (-4.5, -3.2)], 1.5),
    ([(-3.5, -1.8), (-4.4, -2.0), (-5.2, -2.5)], 1.2),
    ([(-4.2, -2.8), (-4.8, -3.5)], 1.0),
    ([(-4.2, -2.8), (-5.0, -3.0)], 0.8),
    ([(-5.2, -2.5), (-5.8, -2.8)], 0.5),
]

# Flat list of every dendrite line segment, so update() can restyle them all.
dendrite_lines = []
for _branch_pts, _branch_width in _DENDRITE_BRANCHES:
    dendrite_lines.extend(
        draw_branch(ax_neuron, _branch_pts, base_width=_branch_width)
    )

# -- Synapse boutons on dendrite tips --
# One small disc per listed dendrite tip; update() changes their alpha and
# radius with the input current.
tip_positions = [
    (-4.8, 3.5), (-5.0, 3.0), (-4.0, 3.3), (-5.8, 3.0), (-5.6, 2.2),
    (-5.8, 0.8), (-5.7, 0.0), (-5.5, -0.5),
    (-4.8, -3.5), (-5.0, -3.0), (-4.5, -3.2), (-5.8, -2.8),
]
synapse_dots = [
    Circle(tip, 0.12, facecolor=GREEN, edgecolor=GREEN, alpha=0.3, linewidth=0.5)
    for tip in tip_positions
]
for _bouton in synapse_dots:
    ax_neuron.add_patch(_bouton)

# -- Soma (irregular blob, not a perfect circle) --
def _closed_outline(xs, ys):
    """Build a closed polygon Path through the points ``(xs[i], ys[i])``."""
    verts = list(zip(xs, ys))
    verts.append(verts[0])  # repeat the first vertex for the CLOSEPOLY code
    codes = [mpath.Path.MOVETO]
    codes.extend([mpath.Path.LINETO] * (len(verts) - 2))
    codes.append(mpath.Path.CLOSEPOLY)
    return mpath.Path(verts, codes)


# Soma outline: a base radius modulated by three sine/cosine harmonics.
theta = np.linspace(0, 2 * np.pi, 80)
r_soma = 1.4 + 0.15 * np.sin(3 * theta) + 0.1 * np.cos(5 * theta) + 0.08 * np.sin(7 * theta)
soma_x = r_soma * np.cos(theta)
soma_y = r_soma * np.sin(theta) * 0.95  # slightly squished
soma_path = _closed_outline(soma_x, soma_y)
soma = PathPatch(
    soma_path,
    facecolor=(0, 0.1, 0.2, 0.4),
    edgecolor=CYAN,
    linewidth=2.5,
)
ax_neuron.add_patch(soma)

# Nucleus: a smaller wobbly outline centred slightly off the soma centre.
theta_n = np.linspace(0, 2 * np.pi, 50)
r_nuc = 0.5 + 0.05 * np.sin(4 * theta_n)
nuc_x = 0.1 + r_nuc * np.cos(theta_n)
nuc_y = 0.15 + r_nuc * np.sin(theta_n) * 0.85
nuc_path = _closed_outline(nuc_x, nuc_y)
nucleus = PathPatch(
    nuc_path,
    facecolor=(0.05, 0.12, 0.25, 0.6),
    edgecolor=PURPLE,
    linewidth=1.2,
)
ax_neuron.add_patch(nucleus)

# Nucleolus: tiny disc inside the nucleus.
nucleolus = Circle((0.15, 0.2), 0.13, facecolor=PURPLE, alpha=0.3, edgecolor="none")
ax_neuron.add_patch(nucleolus)

# -- Axon hillock (tapered transition from soma to axon) --
# Filled band, symmetric about y = 0, narrowing from the soma to the axon.
hillock_x = [1.3, 1.6, 2.0, 2.3]
hillock_y_top = [0.5, 0.35, 0.25, 0.2]
hillock_y_bot = [-half_height for half_height in hillock_y_top]
ax_neuron.fill_between(
    hillock_x,
    hillock_y_bot,
    hillock_y_top,
    color=(0, 0.1, 0.2, 0.3),
    edgecolor=CYAN,
    linewidth=1.2,
)

# -- Axon (long, with organic waviness) --
axon_t = np.linspace(0, 1, 100)
axon_x = 2.3 + axon_t * 6.5
axon_y = np.sin(axon_t * np.pi * 4) * 0.08  # two shallow sine periods
ax_neuron.plot(
    axon_x,
    axon_y,
    color=ORANGE,
    linewidth=2.0,
    alpha=0.5,
    solid_capstyle="round",
)

# -- Myelin sheaths (rounded organic capsules, not boxes) --
# One ellipse per sheath, centred on the axon at each listed x position.
myelin_positions = [3.0, 4.2, 5.4, 6.6]
myelin_patches = [
    Ellipse(
        (center_x, 0),
        0.8,
        0.55,
        facecolor=(0.15, 0.1, 0, 0.25),
        edgecolor=ORANGE,
        linewidth=1.2,
        alpha=0.6,
    )
    for center_x in myelin_positions
]
for _sheath in myelin_patches:
    ax_neuron.add_patch(_sheath)

# Nodes of Ranvier: a small marker midway between neighbouring sheaths.
for _left_x, _right_x in zip(myelin_positions, myelin_positions[1:]):
    ax_neuron.plot((_left_x + _right_x) / 2, 0, "o", color=CYAN, markersize=3, alpha=0.3)

# -- Axon terminals (synaptic boutons) --
terminal_positions = [(8.2, 0.6), (8.5, 0.0), (8.2, -0.6)]
terminals = []
for _center in terminal_positions:
    # short branch from the axon tip out to the bouton
    ax_neuron.plot(
        [7.8, _center[0]],
        [0, _center[1]],
        color=ORANGE,
        linewidth=1.2,
        alpha=0.4,
    )
    _bouton = Circle(_center, 0.2, facecolor=BG, edgecolor=RED, linewidth=2, alpha=0.5)
    ax_neuron.add_patch(_bouton)
    terminals.append(_bouton)

# Vesicles: three small dots inside each terminal.  They are added after all
# terminals, in a second pass, so the order of patches on the axes is
# terminals first, then vesicles.
_vesicle_offsets = [(-0.05, 0.05), (0.06, -0.04), (-0.02, -0.06)]
for _cx, _cy in terminal_positions:
    for _dx, _dy in _vesicle_offsets:
        ax_neuron.add_patch(
            Circle((_cx + _dx, _cy + _dy), 0.04, facecolor=RED, alpha=0.2, edgecolor="none")
        )

# -- Spike ring (hidden, expands on spike) --
# Starts invisible (zero line width and alpha); update() animates it.
spike_ring = Circle((0, 0), 1.8, fill=False, edgecolor=RED, linewidth=0, alpha=0)
ax_neuron.add_patch(spike_ring)

# -- Labels --
# Static captions: (x, y, text, colour, font size, alpha, italic?)
_LABELS = [
    (-5.5, 4.0, "dendrites", GREEN, 8, 0.5, True),
    (0, -2.2, "soma", CYAN, 8, 0.4, True),
    (0.15, 0.55, "nucleus", PURPLE, 6, 0.4, False),
    (4.8, -0.9, "myelin sheath", ORANGE, 7, 0.4, True),
    (1.7, -0.9, "axon hillock", CYAN, 6, 0.35, True),
    (8.5, -1.3, "synaptic\nterminals", RED, 7, 0.4, True),
]
for _lx, _ly, _caption, _tint, _size, _opacity, _italic in _LABELS:
    # Only pass a font style when the caption is italic, leaving the default
    # untouched otherwise.
    _style_kw = {"style": "italic"} if _italic else {}
    ax_neuron.text(_lx, _ly, _caption, ha="center", color=_tint,
                   fontsize=_size, alpha=_opacity, **_style_kw)

# Live readouts, rewritten by update() on every frame.
_readout_font = {"fontweight": "bold", "family": "monospace"}

# voltage readout inside soma
v_text = ax_neuron.text(0, -0.3, "-70 mV", ha="center", va="center",
                        color=CYAN, fontsize=16, **_readout_font)

# input current label near dendrites
i_text = ax_neuron.text(-5.5, -4.0, "0 pA", ha="center", color=GREEN,
                        fontsize=11, **_readout_font)


# ============================================================
# PLOTS
# ============================================================
def _style_trace_axes(ax):
    """Grey ticks and left/bottom spines; hide the top and right spines."""
    ax.tick_params(colors=GREY, labelsize=7)
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    for side in ("bottom", "left"):
        ax.spines[side].set_color(GREY)


# Membrane-voltage trace with dashed threshold and dotted resting reference lines.
(line_v,) = ax_v.plot([], [], color=CYAN, linewidth=1.5, alpha=0.9)
ax_v.axhline(y=V_THRESH_MV, color=RED, linestyle="--", linewidth=1, alpha=0.4)
ax_v.axhline(y=V_REST_MV, color=GREY, linestyle=":", linewidth=0.8, alpha=0.3)
ax_v.set_ylabel("Membrane (mV)", color=WHITE, fontsize=9)
ax_v.set_ylim(-78, 50)
_style_trace_axes(ax_v)
ax_v.text(0.01, 0.9, "V_mem", transform=ax_v.transAxes,
          color=CYAN, fontsize=8, fontweight="bold", alpha=0.4)
ax_v.text(0.99, 0.82, f"{V_THRESH_MV:.0f} mV", transform=ax_v.transAxes,
          color=RED, fontsize=7, alpha=0.35, ha="right")

# Input-current trace.
(line_i,) = ax_i.plot([], [], color=GREEN, linewidth=1.5, alpha=0.9)
ax_i.set_ylabel("I_syn (pA)", color=WHITE, fontsize=9)
ax_i.set_ylim(-10, I_MAX_PA + 30)
ax_i.set_xlabel("samples (10 ms each)", color=GREY, fontsize=8)
_style_trace_axes(ax_i)
ax_i.text(0.01, 0.85, "I_syn", transform=ax_i.transAxes,
          color=GREEN, fontsize=8, fontweight="bold", alpha=0.4)

# Title
fig.text(0.5, 0.975, "Leaky Integrate-and-Fire Neuron", ha="center",
         color=WHITE, fontsize=15, fontweight="bold")
# Subtitle.  update() rewrites fig.texts[-1] on every frame, so this must stay
# the last text added to the figure.
fig.text(0.5, 0.955, "XIAO ESP32-S3  \u2502  EC11 Rotary Encoder  \u2502  \u03C4 = 2.2s  \u2502  V_thresh = -55 mV",
         ha="center", color=GREY, fontsize=8)

# Animation state shared with update() through `global`:
# fill_collection -- shaded area under the voltage trace, rebuilt each frame
# spike_fade      -- frames left in the spike flash (set to 10 on a spike)
fill_collection = None
spike_fade = 0


def _read_samples(max_lines=30):
    """Move pending serial samples into the plot buffers.

    Reads at most *max_lines* lines per call so a single animation frame
    cannot stall on a long backlog.  Each line is expected to hold four
    comma-separated fields: the second is the normalized input current, the
    third the normalized membrane voltage and the fourth a 0/1 spike flag
    (the first field is not used).  Malformed lines are skipped.

    Returns:
        The number of spike samples appended during this call.
    """
    global count, spike_count

    new_spikes = 0
    for _ in range(max_lines):
        raw = ser.readline()
        if not raw:
            break  # nothing was read (presumably the serial timeout expired)
        fields = raw.decode("utf-8", errors="ignore").strip().split(",")
        if len(fields) != 4:
            continue
        try:
            i_in = float(fields[1])
            v_norm = float(fields[2])
            spiked = int(fields[3]) == 1
        except ValueError:
            continue  # partial or corrupted line

        # Everything parsed: append to all buffers together so they stay
        # the same length.
        xs.append(count)
        if spiked:
            # A spike sample is drawn as a fixed-height peak.
            voltages_mv.append(V_SPIKE_MV)
            spike_xs.append(count)
            spike_count += 1
            new_spikes += 1
        else:
            voltages_mv.append(to_mv(v_norm))
        currents_pa.append(to_pa(i_in))
        count += 1
    return new_spikes


def update(frame):
    """FuncAnimation callback: ingest new samples and restyle every artist.

    Args:
        frame: Frame value supplied by FuncAnimation; not used.

    Returns:
        An empty list (the animation is created with ``blit=False``).
    """
    global fill_collection, spike_fade

    new_spikes = _read_samples()

    if len(xs) < 2:
        return []

    x = list(xs)
    v_list = list(voltages_mv)
    i_list = list(currents_pa)

    line_v.set_data(x, v_list)
    line_i.set_data(x, i_list)

    # Rebuild the shaded area between the resting level and the trace.
    if fill_collection is not None:
        fill_collection.remove()
    fill_clipped = [max(v_mv, V_REST_MV) for v_mv in v_list]
    fill_collection = ax_v.fill_between(x, V_REST_MV, fill_clipped, color=CYAN, alpha=0.05)

    ax_v.set_xlim(x[0], x[-1])
    ax_i.set_xlim(x[0], x[-1])

    cur_v = v_list[-1]
    cur_i = i_list[-1]
    # Both fractions are clamped to [0, 1] because they feed alpha values,
    # which matplotlib rejects outside that range.
    cur_v_norm = max(0, min((cur_v - V_REST_MV) / (V_THRESH_MV - V_REST_MV), 1.0))
    i_frac = max(0.0, min(cur_i / I_MAX_PA, 1.0))
    near_threshold = cur_v_norm > 0.85

    # --- Soma glow: blue at rest, shifting to red near threshold ---
    r = cur_v_norm * 0.8
    g = 0.08 + 0.08 * (1 - cur_v_norm)
    b = 0.25 + 0.55 * (1 - cur_v_norm)
    soma.set_facecolor((r, g, b, 0.4 + 0.3 * cur_v_norm))
    soma.set_edgecolor(RED if near_threshold else CYAN)
    soma.set_linewidth(2.5 + 4 * cur_v_norm)

    # --- Dendrite glow: brighter with more input current ---
    for dot in synapse_dots:
        dot.set_alpha(0.15 + 0.85 * i_frac)
        dot.set_radius(0.08 + 0.08 * i_frac)
    target_alpha = 0.3 + 0.6 * i_frac
    for seg in dendrite_lines:
        # move halfway from the current alpha to the target each frame
        seg.set_alpha(min(1.0, seg.get_alpha() * 0.5 + 0.5 * target_alpha))

    # --- Voltage text ---
    v_text.set_text(f"{cur_v:+.0f} mV")
    v_text.set_color(RED if near_threshold else CYAN)
    v_text.set_fontsize(16 + 4 * cur_v_norm)

    # --- Input text ---
    i_text.set_text(f"{cur_i:.0f} pA")
    i_text.set_alpha(0.4 + 0.6 * i_frac)

    # --- Spike animation: a 10-frame flash that fades out ---
    if new_spikes > 0:
        spike_fade = 10

    if spike_fade > 0:
        a = spike_fade / 10.0
        # soma flash (overrides the glow colours set above)
        soma.set_facecolor((1, 0.15, 0.05, 0.5 + 0.4 * a))
        soma.set_edgecolor((1, 0.3, 0.1))
        soma.set_linewidth(3 + 5 * a)
        # spike ring expands
        spike_ring.set_linewidth(3 * a)
        spike_ring.set_alpha(a * 0.6)
        spike_ring.set_radius(1.8 + (10 - spike_fade) * 0.3)
        # Terminals and myelin were created with a patch-level alpha, which
        # matplotlib applies in place of the alpha channel of the face/edge
        # colours.  Clear it during the flash so the fading RGBA alphas
        # below take effect; the idle branch restores it.
        for t in terminals:
            t.set_alpha(None)
            t.set_facecolor((1, 0.1, 0.1, a * 0.8))
            t.set_edgecolor((1, 0.3, 0.1))
        for m in myelin_patches:
            m.set_alpha(None)
            m.set_facecolor((0.4, 0.2, 0, 0.3 * a))
            m.set_edgecolor((1, 0.6, 0, a))
        spike_fade -= 1
    else:
        # idle: hide the ring and restore the resting appearance
        spike_ring.set_linewidth(0)
        spike_ring.set_alpha(0)
        for t in terminals:
            t.set_facecolor(BG)
            t.set_edgecolor(RED)
            t.set_alpha(0.5)
        for m in myelin_patches:
            m.set_facecolor((0.15, 0.1, 0, 0.25))
            m.set_edgecolor(ORANGE)
            m.set_alpha(0.6)

    # --- Firing rate over the last 200 samples (10 ms each) ---
    rate_window = 200
    sample_period_s = 0.01
    recent = sum(1 for sx in spike_xs if sx > count - rate_window)
    # reported as 0 until more than one full window has been received
    rate_hz = recent / (rate_window * sample_period_s) if count > rate_window else 0

    # Title area firing info.  The marker is built outside the f-string
    # because a backslash escape inside an f-string *expression* is a
    # SyntaxError before Python 3.12.
    spike_marker = "  \u26a1" if new_spikes > 0 else ""
    fig.texts[-1].set_text(
        f"XIAO ESP32-S3  \u2502  EC11 Rotary Encoder  \u2502  \u03C4 = 2.2s  \u2502  "
        f"Spikes: {spike_count}  \u2502  Rate: {rate_hz:.1f} Hz"
        f"{spike_marker}"
    )

    return []


# Keep a module-level reference to the animation so it is not
# garbage-collected while the window is open.
ani = FuncAnimation(fig, update, interval=40, blit=False, cache_frame_data=False)
plt.subplots_adjust(top=0.94, bottom=0.05, left=0.06, right=0.97, hspace=0.3)
try:
    plt.show()
finally:
    # Release the serial port even if the GUI loop exits with an exception
    # (for example Ctrl-C in the terminal).
    ser.close()
