import serial
import sys
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from collections import deque

# --- Configuration -----------------------------------------------------------
# Serial device path: the first command-line argument when given, otherwise a
# hard-coded default device.
PORT = (sys.argv[1:] or ["/dev/cu.usbmodem101"])[0]
BAUD = 115_200
WINDOW = 500  # number of most-recent samples kept and plotted

# A short read timeout makes readline() return (possibly empty) quickly instead
# of blocking the GUI when no data is pending.
ser = serial.Serial(port=PORT, baudrate=BAUD, timeout=0.05)

# Parallel rolling buffers: entry k of each deque describes the same sample.
ts, voltages, currents, spikes = (deque(maxlen=WINDOW) for _ in range(4))

# --- Figure ------------------------------------------------------------------
fig, (ax_v, ax_i) = plt.subplots(nrows=2, ncols=1, figsize=(10, 6), sharex=True)
fig.suptitle("LIF Neuron — Live", fontsize=14)

# Top panel: membrane potential, with a dashed line at the threshold value 1.0.
line_v, = ax_v.plot([], [], "b-", linewidth=1.2, label="Membrane V")
line_thresh = None  # not referenced elsewhere in this script

# Bottom panel: input current.
line_i, = ax_i.plot([], [], "g-", linewidth=1.2, label="Input I")

ax_v.set(ylabel="Membrane Potential", ylim=(-0.05, 1.15))
ax_v.axhline(y=1.0, color="r", linestyle="--", linewidth=0.8, label="Threshold")
ax_v.legend(loc="upper right", fontsize=8)

ax_i.set(ylabel="Input Current", ylim=(-0.05, 1.05), xlabel="Sample")
ax_i.legend(loc="upper right", fontsize=8)

# Vertical spike lines currently drawn on ax_v; rebuilt on every frame.
spike_markers = []

def update(frame, max_lines=20):
    """FuncAnimation callback: drain pending serial samples and redraw the plot.

    Reads newline-terminated records of the form ``t,i_in,v,spike`` from
    ``ser`` and appends them to the rolling buffers. Records with the wrong
    field count or unparsable numbers are skipped. It then refreshes both
    traces, the spike markers, and the x-limits.

    Args:
        frame: Frame number supplied by FuncAnimation (unused).
        max_lines: Maximum number of raw lines consumed per frame. Both valid
            and malformed lines count toward it, so a fast or noisy stream
            cannot stall the GUI.

    Returns:
        The updated line artists ``(line_v, line_i)``.
    """
    global spike_markers

    for _ in range(max_lines):
        raw = ser.readline()
        if not raw:
            break  # read timed out: nothing more pending right now
        parts = raw.decode("utf-8", errors="ignore").strip().split(",")
        if len(parts) != 4:
            continue
        try:
            int(parts[0])  # timestep field: validated but not plotted
            i_in = float(parts[1])
            v = float(parts[2])
            sp = int(parts[3])
        except ValueError:
            continue  # partial or corrupted record
        # Strictly increasing sample index. Using len(ts) here would saturate
        # at the deque's maxlen once it is full, collapsing every x value
        # onto the same number and producing a singular x-range.
        ts.append(ts[-1] + 1 if ts else 0)
        voltages.append(v)
        currents.append(i_in)
        spikes.append(sp)

    if len(ts) < 2:
        return line_v, line_i

    x = list(ts)
    line_v.set_data(x, list(voltages))
    line_i.set_data(x, list(currents))

    # Replace last frame's spike markers with ones for the current window.
    for marker in spike_markers:
        marker.remove()
    spike_markers = []
    for xi, sp in zip(x, spikes):
        if sp == 1:
            spike_markers.append(
                ax_v.axvline(x=xi, color="r", alpha=0.3, linewidth=1)
            )

    ax_v.set_xlim(x[0], x[-1])
    ax_i.set_xlim(x[0], x[-1])

    return line_v, line_i

# Keep a reference to the animation: matplotlib's animation docs require the
# FuncAnimation object to stay alive, or it is garbage-collected and stops.
ani = FuncAnimation(fig, update, interval=50, blit=False, cache_frame_data=False)
try:
    plt.tight_layout()
    plt.show()  # normally blocks until the figure window is closed
finally:
    # Release the serial port even if the GUI loop raises or is interrupted.
    ser.close()
