/*
 * Muse EEG Headband — On-Device BLE Streaming + Blink Counter
 * =============================================================
 * Connects directly to a Muse 2 EEG headband via Bluetooth LE on the
 * XIAO ESP32-S3.  No PC or phone needed — everything runs on the board.
 *
 * What it does:
 *   - Scans for a "Muse*" BLE device and connects automatically.
 *   - Subscribes to the AF7 and AF8 EEG channels (frontal electrodes).
 *   - Detects eye blinks by looking for sudden large amplitude spikes
 *     after subtracting a slow-moving baseline.
 *   - Shows three UI modes on the OLED, cycled with the D2 button:
 *       1. Blink Counter  – large digit count of blinks detected.
 *       2. Calibrating    – quiet-noise measurement for blink threshold.
 *       3. Waveform       – live scrolling EEG trace for AF7 and AF8.
 *
 * Hardware:
 *   OLED SDA -> D4 (Grove I2C)
 *   OLED SCL -> D5 (Grove I2C)
 *   Button   -> D2 (active low)
 *
 * Dependencies (platformio.ini):
 *   h2zero/NimBLE-Arduino@^1.4.3
 *   adafruit/Adafruit SSD1306@^2.5.11
 *   adafruit/Adafruit GFX Library@^1.11.10
 */

#include <Arduino.h>
#include <Wire.h>
#include <atomic>
#include <cmath>
#include <cstring>
#include <NimBLEDevice.h>
#include <Adafruit_GFX.h>
#include <Adafruit_SSD1306.h>

namespace {

// ---- Hardware pins ----------------------------------------------------------
constexpr int    I2C_SDA           = D4;
constexpr int    I2C_SCL           = D5;
constexpr int    BUTTON_PIN        = D2;
constexpr bool   BUTTON_ACTIVE_LOW = true;
constexpr uint32_t BUTTON_DEBOUNCE_MS = 25;

// ---- OLED -------------------------------------------------------------------
constexpr int     SCREEN_W   = 128;
constexpr int     SCREEN_H   = 64;
constexpr int     OLED_RESET = -1;
constexpr uint8_t OLED_ADDR  = 0x3C;

// ---- Muse 2 BLE UUIDs -------------------------------------------------------
// These are published in the Muse SDK and various open-source projects.
constexpr const char* MUSE_SERVICE_UUID = "0000fe8d-0000-1000-8000-00805f9b34fb";
constexpr const char* MUSE_CTRL_UUID    = "273e0001-4c4d-454d-96be-f03bac821358";
constexpr const char* MUSE_TP9_UUID     = "273e0003-4c4d-454d-96be-f03bac821358";
constexpr const char* MUSE_AF7_UUID     = "273e0004-4c4d-454d-96be-f03bac821358";
constexpr const char* MUSE_AF8_UUID     = "273e0005-4c4d-454d-96be-f03bac821358";
constexpr const char* MUSE_TP10_UUID    = "273e0006-4c4d-454d-96be-f03bac821358";
constexpr const char* MUSE_RAUX_UUID    = "273e0007-4c4d-454d-96be-f03bac821358";

// ---- Muse packet constants --------------------------------------------------
constexpr float  UV_SCALE       = 0.48828125f;  // ADC count -> microvolts
constexpr size_t EEG_PACKET_LEN = 20;           // bytes per BLE notify
constexpr uint16_t EEG_QUEUE_LEN = 128;         // BLE task -> main-loop ring buffer
// Each notification is a 16-bit sequence number followed by 12 samples packed
// at 12 bits each; the packet decoder relies on exactly this layout.
static_assert(16 + 12 * 12 == EEG_PACKET_LEN * 8,
              "EEG packet must hold a 16-bit sequence + 12 x 12-bit samples");

// ---- BLE connection parameters ----------------------------------------------
// BLE units: connection interval in 1.25 ms steps, supervision timeout in
// 10 ms steps, connect-time scan interval/window in 0.625 ms steps.
// A tighter interval means lower latency but more frequent radio events.
constexpr uint16_t MUSE_CONN_ITVL_MIN  = 24;   // 30 ms
constexpr uint16_t MUSE_CONN_ITVL_MAX  = 40;   // 50 ms
constexpr uint16_t MUSE_CONN_LATENCY   = 0;
constexpr uint16_t MUSE_CONN_TIMEOUT   = 3000;  // 30 s supervision timeout
constexpr uint16_t MUSE_MIN_ACCEPT_TIMEOUT = 1000;  // reject peer updates below 10 s
constexpr uint16_t MUSE_CONN_SCAN_ITVL = 24;    // 15 ms
constexpr uint16_t MUSE_CONN_SCAN_WINDOW = 24;  // 15 ms (window == interval: scan continuously)

// Connection profile used while the link is being established. It uses
// shorter intervals (15-30 ms) and a shorter supervision timeout (10 s) than
// the steady-state profile above, which is requested right after connecting.
constexpr uint16_t MUSE_CONNECT_ITVL_MIN = 12;  // 15 ms
constexpr uint16_t MUSE_CONNECT_ITVL_MAX = 24;  // 30 ms
constexpr uint16_t MUSE_CONNECT_TIMEOUT  = 1000; // 10 s

// Keep both profiles inside the BLE limits: interval 6..3200 (7.5 ms..4 s),
// supervision timeout 10..3200 (100 ms..32 s), and
// timeout_ms > (1 + latency) * interval_max_ms * 2, i.e. T * 4 > (1 + L) * I
// in the raw units above.
static_assert(MUSE_CONN_ITVL_MIN >= 6 && MUSE_CONN_ITVL_MIN <= MUSE_CONN_ITVL_MAX &&
              MUSE_CONN_ITVL_MAX <= 3200, "steady-state interval out of BLE range");
static_assert(MUSE_CONNECT_ITVL_MIN >= 6 && MUSE_CONNECT_ITVL_MIN <= MUSE_CONNECT_ITVL_MAX &&
              MUSE_CONNECT_ITVL_MAX <= 3200, "connect interval out of BLE range");
static_assert(MUSE_CONN_TIMEOUT >= 10 && MUSE_CONN_TIMEOUT <= 3200,
              "steady-state supervision timeout out of BLE range");
static_assert(MUSE_CONNECT_TIMEOUT >= 10 && MUSE_CONNECT_TIMEOUT <= 3200,
              "connect supervision timeout out of BLE range");
static_assert(MUSE_CONN_TIMEOUT * 4 > (1 + MUSE_CONN_LATENCY) * MUSE_CONN_ITVL_MAX,
              "steady-state supervision timeout too short for interval/latency");
static_assert(MUSE_CONNECT_TIMEOUT * 4 > (1 + MUSE_CONN_LATENCY) * MUSE_CONNECT_ITVL_MAX,
              "connect supervision timeout too short for interval/latency");

// ---- Blink detection parameters ---------------------------------------------
constexpr float    HP_BASELINE_ALPHA      = 0.005f;  // slow baseline tracking
constexpr float    BLINK_BASE_THRESHOLD_UV = 60.0f;  // minimum spike to consider
constexpr float    BLINK_NOISE_MULT       = 2.8f;    // threshold = noise * this
constexpr uint32_t BLINK_MIN_INTERVAL_MS  = 300;     // fastest natural blink rate
constexpr uint32_t BLINK_REARM_MS         = 120;     // signal must settle before next blink
constexpr float    BLINK_REARM_FRAC       = 0.50f;
constexpr float    BLINK_NO_CONTACT_ABS_UV = 420.0f; // above this = poor contact
constexpr uint32_t BLINK_NO_CONTACT_HOLD_MS = 900;
constexpr uint32_t BLINK_HOLD_MS          = 300;     // on-screen "BLINK!" hold time

// ---- Band power computation -------------------------------------------------
constexpr uint16_t BAND_BUF_SIZE   = 128;
constexpr float    EEG_FS_HZ       = 256.0f;
constexpr uint32_t BAND_UPDATE_MS  = 350;
constexpr uint32_t CALIBRATION_MIN_MS = 2500;

// ---- Channel IDs ------------------------------------------------------------
constexpr uint8_t CH_AF7  = 1;
constexpr uint8_t CH_AF8  = 2;
constexpr uint8_t CH_TP9  = 3;
constexpr uint8_t CH_TP10 = 4;
constexpr uint8_t CH_RAUX = 5;
constexpr uint8_t CH_COUNT = 5;

// ---- Waveform display layout ------------------------------------------------
constexpr int GRAPH1_TOP    = 12;
constexpr int GRAPH1_BOTTOM = 36;
constexpr int GRAPH2_TOP    = 38;
constexpr int GRAPH2_BOTTOM = 62;
constexpr int GRAPH_HALF_H  = 10;
constexpr int AF7_CENTER_Y  = 24;
constexpr int AF8_CENTER_Y  = 50;
// Each trace (center +/- GRAPH_HALF_H) must fit inside its graph band on screen.
static_assert(AF7_CENTER_Y - GRAPH_HALF_H >= GRAPH1_TOP &&
              AF7_CENTER_Y + GRAPH_HALF_H <= GRAPH1_BOTTOM, "AF7 trace exceeds its graph");
static_assert(AF8_CENTER_Y - GRAPH_HALF_H >= GRAPH2_TOP &&
              AF8_CENTER_Y + GRAPH_HALF_H <= GRAPH2_BOTTOM, "AF8 trace exceeds its graph");
static_assert(GRAPH2_BOTTOM < SCREEN_H, "graphs must fit on the screen");

// ---- Global objects ---------------------------------------------------------
Adafruit_SSD1306 display(SCREEN_W, SCREEN_H, &Wire, OLED_RESET);

// BLE handles. gTargetDevice is a heap copy of the matching advertisement made
// in the scan callback; it and the characteristic pointers (which come from
// gClient's discovered service) are released/cleared in teardownConnectionState().
NimBLEAdvertisedDevice*    gTargetDevice   = nullptr;
NimBLEClient*              gClient         = nullptr;
NimBLERemoteCharacteristic* gCtrlChar      = nullptr;
NimBLERemoteCharacteristic* gAF7Char       = nullptr;
NimBLERemoteCharacteristic* gAF8Char       = nullptr;
NimBLERemoteCharacteristic* gTP9Char       = nullptr;
NimBLERemoteCharacteristic* gTP10Char      = nullptr;
NimBLERemoteCharacteristic* gRAuxChar      = nullptr;

// Link-state flags. They are written both from the NimBLE client callbacks
// (BLE task) and from loop(), so they are atomics rather than plain bools.
std::atomic<bool> gConnected{false};
std::atomic<bool> gNeedsReconnect{false};
std::atomic<bool> gStreamingStarted{false};
uint32_t gNextConnectAttemptMs = 0;

// Ring buffer for EEG packets: filled by the BLE notify callback (BLE task),
// drained by loop(). Head/tail are only accessed while holding gEegMux.
struct EegPacket { uint8_t channelId; uint8_t payload[EEG_PACKET_LEN]; };
EegPacket        gEegQueue[EEG_QUEUE_LEN];
volatile uint16_t gEegHead    = 0;
volatile uint16_t gEegTail    = 0;
volatile uint32_t gEegDropped = 0;
portMUX_TYPE     gEegMux      = portMUX_INITIALIZER_UNLOCKED;

// Waveform ring buffers (one column per screen pixel)
int16_t gWaveAF7[SCREEN_W] = {0};
int16_t gWaveAF8[SCREEN_W] = {0};
uint8_t gWaveIdxAF7 = 0;
uint8_t gWaveIdxAF8 = 0;
static_assert(SCREEN_W <= 256, "uint8_t waveform indices require SCREEN_W <= 256");

float gBaselineAF7 = 0.0f,  gBaselineAF8 = 0.0f;
float gScaleAF7    = 140.0f, gScaleAF8    = 140.0f;
float gLastAF7     = 0.0f,   gLastAF8     = 0.0f;

volatile uint32_t gSamplesAF7 = 0, gSamplesAF8 = 0;
volatile uint32_t gNotifyPktsAF7 = 0, gNotifyPktsAF8 = 0, gNotifyPktsOther = 0;
volatile int      gGapLastDisconnectReason = 0;
volatile uint32_t gGapDisconnectCount = 0;

uint32_t gConnStartMs = 0;
uint32_t gLastConnParamReqMs = 0;

// Blink detector state
uint32_t gBlinkCount         = 0;
uint32_t gLastBlinkMs        = 0;
uint32_t gBlinkActiveUntilMs = 0;
float    gBlinkThresholdUv   = BLINK_BASE_THRESHOLD_UV;
float    gNoiseAbsAF7        = 22.0f;
float    gNoiseAbsAF8        = 22.0f;
float    gHpAF7              = 0.0f;
float    gHpAF8              = 0.0f;
bool     gBlinkArmed         = true;
uint32_t gBlinkBelowSinceMs  = 0;
uint32_t gHighNoiseSinceMs   = 0;
bool     gPoorContact        = false;
uint32_t gLastBlinkDebugMs   = 0;

// UI state
enum class UiMode : uint8_t { BlinkCounter = 0, Calibrating = 1, Waveform = 2 };
UiMode   gUiMode           = UiMode::BlinkCounter;
bool     gButtonStableState = HIGH;
bool     gButtonLastRead    = HIGH;
uint32_t gButtonLastEdgeMs  = 0;

// Calibration accumulators
uint32_t gCalibrationStartMs   = 0;
uint32_t gCalibrationSamples   = 0;
float    gCalibrationAbsSum    = 0.0f;
float    gCalibrationAbsSqSum  = 0.0f;

// Band power (computed every 350 ms from a 128-sample window of AF7)
float    gBandBuf[BAND_BUF_SIZE] = {0.0f};
uint16_t gBandWrite  = 0;
uint16_t gBandCount  = 0;
uint32_t gLastBandUpdateMs = 0;
float    gBandDelta = 0.0f, gBandTheta = 0.0f, gBandAlpha = 0.0f;
float    gBandBeta  = 0.0f, gBandGamma = 0.0f;

// ---- Utility ----------------------------------------------------------------

// Limit v to the closed range [lo, hi]. A NaN input fails both comparisons
// and is returned unchanged.
float clampf(float v, float lo, float hi) {
    return (v < lo) ? lo : ((v > hi) ? hi : v);
}

// Short lowercase name for a UI mode, used in serial logs and the
// disconnected screen. Unrecognized values map to "?".
const char* uiModeName(UiMode mode) {
    const char* label = "?";
    if (mode == UiMode::BlinkCounter) {
        label = "blink-counter";
    } else if (mode == UiMode::Calibrating) {
        label = "calibrating";
    } else if (mode == UiMode::Waveform) {
        label = "waveform";
    }
    return label;
}

// Map a Bluetooth HCI error code (the low byte of a controller-reported
// disconnect reason) to a short human-readable name for serial logging.
// Codes are from the Bluetooth Core Specification, Vol 1, Part F
// ("Controller Error Codes"); anything not listed returns "Unknown".
// Includes the codes a battery-powered headband commonly produces when it
// powers off or runs low on resources (0x14, 0x15).
const char* hciReasonToString(uint8_t reason) {
    switch (reason) {
        case 0x05: return "Authentication Failure";
        case 0x08: return "Connection Timeout";
        case 0x13: return "Remote User Terminated";
        case 0x14: return "Remote Low Resources";
        case 0x15: return "Remote Power Off";
        case 0x16: return "Local Host Terminated";
        case 0x1A: return "Unsupported Remote Feature";
        case 0x1F: return "Unspecified Error";
        case 0x22: return "LL Response Timeout";
        case 0x28: return "Instant Passed";
        case 0x3B: return "Unacceptable Conn Params";
        case 0x3D: return "MIC Failure";
        case 0x3E: return "Connection Failed to Establish";
        default:   return "Unknown";
    }
}

// Dense 0-based index for a channel ID, in the order AF7, AF8, TP9, TP10, RAUX.
// Returns -1 for an ID that is not one of the known channels.
int channelIndex(uint8_t channelId) {
    static constexpr uint8_t kOrder[] = {CH_AF7, CH_AF8, CH_TP9, CH_TP10, CH_RAUX};
    constexpr int kCount = static_cast<int>(sizeof(kOrder) / sizeof(kOrder[0]));
    for (int pos = 0; pos < kCount; ++pos) {
        if (kOrder[pos] == channelId) return pos;
    }
    return -1;
}

// ---- Blink detector ---------------------------------------------------------

// Return the blink detector to its start-up state: default 22 uV noise
// floors, cleared filter outputs, armed, no poor-contact condition, and all
// timestamps cleared. The blink count and the calibrated threshold are not
// touched here.
void resetBlinkDetector() {
    gNoiseAbsAF7 = gNoiseAbsAF8 = 22.0f;
    gHpAF7 = gHpAF8 = 0.0f;

    gBlinkArmed  = true;
    gPoorContact = false;

    gBlinkBelowSinceMs = gHighNoiseSinceMs = gLastBlinkDebugMs = 0;
}

// Called for every decoded EEG sample on AF7 or AF8.
// hpUv is the high-pass filtered value (DC baseline already removed).
//
// Per-channel state maintained here:
//   - gHpAF7 / gHpAF8            latest high-passed sample of each channel.
//   - gNoiseAbsAF7 / gNoiseAbsAF8 EMA of |hp| that skips samples >= 350 uV so
//                                blinks do not inflate it; drives the adaptive
//                                blink threshold.
//   - sContactAbsAF7 / AF8       EMA of |hp| with no spike gate; drives only
//                                the poor-contact check.
// The contact check needs its own un-gated average. The gated noise EMA only
// ever mixes in values below 350 uV, so starting from a normal floor it can
// never rise above 350 uV. It could therefore never exceed
// BLINK_NO_CONTACT_ABS_UV (420 uV), and the poor-contact warning never fired.
//
// AF7 and AF8 arrive in separate BLE packets. While one channel's sample is
// processed, the other channel's gHp value is simply its most recent sample.
void updateBlinkFromSample(uint8_t channelId, float hpUv) {
    // Un-gated contact level per channel. It is not cleared by
    // resetBlinkDetector(); it re-converges to the live signal with a time
    // constant of roughly 1 / kContactAlpha samples.
    static float sContactAbsAF7 = 22.0f;
    static float sContactAbsAF8 = 22.0f;
    constexpr float kContactAlpha = 0.006f;

    float* hp      = nullptr;
    float* noise   = nullptr;
    float* contact = nullptr;
    if (channelId == CH_AF7) {
        hp = &gHpAF7; noise = &gNoiseAbsAF7; contact = &sContactAbsAF7;
    } else if (channelId == CH_AF8) {
        hp = &gHpAF8; noise = &gNoiseAbsAF8; contact = &sContactAbsAF8;
    } else {
        return;
    }

    const uint32_t now  = millis();
    *hp                 = hpUv;
    const float absHp   = fabsf(hpUv);

    // Noise floor tracks the running mean absolute amplitude, ignoring large
    // spikes (which are likely blinks themselves).
    const float alpha = (gUiMode == UiMode::Calibrating) ? 0.03f : 0.006f;
    if (absHp < 350.0f) {
        *noise = (1.0f - alpha) * (*noise) + alpha * absHp;
    }
    // The contact level deliberately includes large samples. Sustained large
    // amplitude is what the poor-contact test is meant to catch.
    *contact = (1.0f - kContactAlpha) * (*contact) + kContactAlpha * absHp;

    // During calibration just accumulate statistics, no blink detection.
    if (gUiMode == UiMode::Calibrating) {
        gCalibrationSamples++;
        gCalibrationAbsSum    += absHp;
        gCalibrationAbsSqSum  += absHp * absHp;
        return;
    }

    // If both channels show very high amplitude for a sustained period,
    // report poor contact so the UI can warn the user.
    const float noiseMean   = 0.5f * (gNoiseAbsAF7 + gNoiseAbsAF8);
    const float contactMean = 0.5f * (sContactAbsAF7 + sContactAbsAF8);
    if (contactMean > BLINK_NO_CONTACT_ABS_UV) {
        if (gHighNoiseSinceMs == 0) gHighNoiseSinceMs = now;
        if ((now - gHighNoiseSinceMs) >= BLINK_NO_CONTACT_HOLD_MS) gPoorContact = true;
    } else {
        gHighNoiseSinceMs = 0;
        gPoorContact      = false;
    }

    // Adaptive threshold: the higher the ambient noise, the more a spike
    // must exceed it to be counted as a blink.
    const float threshold         = fmaxf(gBlinkThresholdUv, noiseMean * BLINK_NOISE_MULT);
    const float secondaryThreshold = threshold * 0.60f;
    const float rearmThreshold    = fmaxf(14.0f, threshold * BLINK_REARM_FRAC);
    const bool  bothBelow         = (fabsf(gHpAF7) < rearmThreshold && fabsf(gHpAF8) < rearmThreshold);

    // Re-arm: wait for the signal to settle below rearmThreshold before
    // accepting the next blink, to avoid double-counting one blink.
    if (!gBlinkArmed) {
        if (bothBelow) {
            if (gBlinkBelowSinceMs == 0) gBlinkBelowSinceMs = now;
            else if ((now - gBlinkBelowSinceMs) >= BLINK_REARM_MS) gBlinkArmed = true;
        } else {
            gBlinkBelowSinceMs = 0;
        }
    }

    // Trigger: one channel must exceed the full threshold; the other must
    // exceed the secondary threshold (strong + supportive evidence).
    const float absA7 = fabsf(gHpAF7), absA8 = fabsf(gHpAF8);
    const bool  trigger = ((absA7 >= threshold && absA8 >= secondaryThreshold) ||
                           (absA8 >= threshold && absA7 >= secondaryThreshold));

    if (!gPoorContact && gBlinkArmed && (now - gLastBlinkMs) >= BLINK_MIN_INTERVAL_MS && trigger) {
        gBlinkCount++;
        gLastBlinkMs        = now;
        gBlinkActiveUntilMs = now + BLINK_HOLD_MS;
        gBlinkArmed         = false;
        gBlinkBelowSinceMs  = 0;
        Serial.printf("[blink] count=%lu thr=%.1f noise=%.1f a7=%.1f a8=%.1f\n",
                      static_cast<unsigned long>(gBlinkCount), threshold, noiseMean, absA7, absA8);
    }

    // Periodic debug line for serial monitor
    if (now - gLastBlinkDebugMs >= 1200) {
        gLastBlinkDebugMs = now;
        Serial.printf("[blink_dbg] mode=%s armed=%d poor=%d thr=%.1f noise=%.1f a7=%.1f a8=%.1f\n",
                      uiModeName(gUiMode), gBlinkArmed ? 1 : 0, gPoorContact ? 1 : 0,
                      threshold, noiseMean, absA7, absA8);
    }
}

// ---- Band power (DFT over 128-sample AF7 window) ----------------------------

// Append one high-passed AF7 sample to the band-power ring buffer. Once the
// buffer is full, each new sample overwrites the oldest one.
void pushBandSample(float hpUv) {
    const uint16_t slot = gBandWrite;
    gBandBuf[slot] = hpUv;
    gBandWrite = static_cast<uint16_t>((slot + 1) % BAND_BUF_SIZE);

    const bool filling = gBandCount < BAND_BUF_SIZE;
    if (filling) {
        gBandCount = static_cast<uint16_t>(gBandCount + 1);
    }
}

// Compute power in each EEG band using a windowed DFT.
// Runs infrequently (every 350 ms) so the O(N*K) cost is acceptable.
void updateBandPowers() {
    if (gBandCount < BAND_BUF_SIZE) return;
    const uint32_t now = millis();
    if (now - gLastBandUpdateMs < BAND_UPDATE_MS) return;
    gLastBandUpdateMs = now;

    float delta = 0, theta = 0, alpha = 0, beta = 0, gamma = 0;

    for (int k = 1; k <= 22; ++k) {
        float re = 0, im = 0;
        for (int n = 0; n < BAND_BUF_SIZE; ++n) {
            const uint16_t idx = static_cast<uint16_t>((gBandWrite + n) % BAND_BUF_SIZE);
            // Hann window reduces spectral leakage
            const float w   = 0.5f - 0.5f * cosf((2.0f * PI * n) / (BAND_BUF_SIZE - 1));
            const float x   = gBandBuf[idx] * w;
            const float ang = (2.0f * PI * k * n) / BAND_BUF_SIZE;
            re += x * cosf(ang);
            im -= x * sinf(ang);
        }
        const float p = re * re + im * im;
        const float f = (k * EEG_FS_HZ) / BAND_BUF_SIZE;

        if      (f >= 1.0f && f <  4.0f) delta += p;
        else if (f >= 4.0f && f <  8.0f) theta += p;
        else if (f >= 8.0f && f < 13.0f) alpha += p;
        else if (f >= 13.0f && f < 30.0f) beta  += p;
        else if (f >= 30.0f && f <= 45.0f) gamma += p;
    }

    const float total = delta + theta + alpha + beta + gamma;
    if (total > 0.0f) {
        gBandDelta = delta / total; gBandTheta = theta / total;
        gBandAlpha = alpha / total; gBandBeta  = beta  / total;
        gBandGamma = gamma / total;
    }
}

// ---- Calibration ------------------------------------------------------------

// Enter calibration mode. The accumulators are cleared, the blink detector is
// reset, and timing starts now. While in this mode, updateBlinkFromSample()
// accumulates |hp| statistics instead of detecting blinks.
void startCalibration() {
    gUiMode = UiMode::Calibrating;

    gCalibrationAbsSqSum = 0.0f;
    gCalibrationAbsSum   = 0.0f;
    gCalibrationSamples  = 0;
    gCalibrationStartMs  = millis();

    resetBlinkDetector();
    Serial.println("[cal] started: keep still, eyes open, press D2 again to finish");
}

// Leave calibration mode and switch to the waveform view.
//
// If at least CALIBRATION_MIN_MS elapsed and samples were collected, a new
// blink threshold is derived from the mean and spread of |hp|, and the noise
// floors are seeded from the mean. Otherwise the threshold is left unchanged
// and the noise floors fall back to 22 uV.
void finishCalibration() {
    const uint32_t elapsedMs = millis() - gCalibrationStartMs;
    const bool haveEnoughData = (gCalibrationSamples > 0) && (elapsedMs >= CALIBRATION_MIN_MS);
    float noiseFloor = 22.0f;

    if (!haveEnoughData) {
        Serial.println("[cal] skipped: too short or no EEG samples");
    } else {
        const float count   = static_cast<float>(gCalibrationSamples);
        const float meanAbs = gCalibrationAbsSum / count;
        const float meanSq  = gCalibrationAbsSqSum / count;
        // Standard deviation of |hp|; the max() guards against tiny negative
        // values caused by float rounding.
        const float sigma   = sqrtf(fmaxf(0.0f, meanSq - meanAbs * meanAbs));
        gBlinkThresholdUv   = clampf(fmaxf(BLINK_BASE_THRESHOLD_UV, meanAbs * 2.8f + sigma * 0.8f), 55.0f, 220.0f);
        noiseFloor          = fmaxf(14.0f, meanAbs);
        Serial.printf("[cal] done: samples=%lu meanAbs=%.1f sigma=%.1f threshold=%.1f\n",
                      static_cast<unsigned long>(gCalibrationSamples), meanAbs, sigma, gBlinkThresholdUv);
    }

    resetBlinkDetector();
    gNoiseAbsAF7 = noiseFloor;
    gNoiseAbsAF8 = noiseFloor;
    gUiMode      = UiMode::Waveform;
}

// Switch to the blink counter view and restart detection from a clean state.
// The blink count and calibrated threshold are kept.
void goToBlinkCounterMode() {
    resetBlinkDetector();
    gUiMode = UiMode::BlinkCounter;
    Serial.println("[mode] blink-counter");
}

// ---- Button -----------------------------------------------------------------

// Advance the UI on a debounced button press:
// blink counter -> calibrating -> waveform -> blink counter.
void handleButtonPressed() {
    if (gUiMode == UiMode::Calibrating) {
        finishCalibration();
    } else if (gUiMode == UiMode::Waveform) {
        goToBlinkCounterMode();
    } else {
        // BlinkCounter (or any unexpected value) starts a new calibration.
        startCalibration();
    }
}

// Poll the D2 button with debouncing. A raw level change restarts the
// debounce window. Once the level has been stable for BUTTON_DEBOUNCE_MS and
// differs from the accepted state, it is accepted. A transition into the
// pressed level triggers handleButtonPressed() exactly once.
void updateButton() {
    const bool     level = static_cast<bool>(digitalRead(BUTTON_PIN));
    const uint32_t t     = millis();

    if (level != gButtonLastRead) {
        gButtonLastRead   = level;
        gButtonLastEdgeMs = t;
    }

    const bool settled = (t - gButtonLastEdgeMs) >= BUTTON_DEBOUNCE_MS;
    const bool changed = (level != gButtonStableState);
    if (!settled || !changed) return;

    gButtonStableState = level;
    const bool isPressedLevel = BUTTON_ACTIVE_LOW ? (level == LOW) : (level == HIGH);
    if (isPressedLevel) {
        handleButtonPressed();
    }
}

// ---- EEG packet parsing -----------------------------------------------------

// Extract a 12-bit unsigned value starting at an arbitrary bit offset, reading
// bits MSB-first within each byte (big-endian bit order). Muse packs 12
// samples into each 20-byte BLE notification using 12 bits each.
// The caller must ensure bits [bitOffset, bitOffset + 12) lie inside `data`.
uint16_t read12Bit(const uint8_t* data, size_t bitOffset) {
    uint16_t acc = 0;
    const size_t end = bitOffset + 12;
    for (size_t pos = bitOffset; pos < end; ++pos) {
        const unsigned shift = 7u - static_cast<unsigned>(pos & 7u);
        const unsigned bit   = (static_cast<unsigned>(data[pos >> 3]) >> shift) & 1u;
        acc = static_cast<uint16_t>((acc << 1) | bit);
    }
    return acc;
}

// Reset all per-session signal state before a new stream starts. This covers
// waveform buffers and filters, diagnostic counters, blink count and
// threshold, calibration accumulators, the band-power window, and the shared
// packet queue.
void resetWaveState() {
    // Waveform display buffers and per-channel filter/scale state.
    memset(gWaveAF7, 0, sizeof(gWaveAF7));
    memset(gWaveAF8, 0, sizeof(gWaveAF8));
    gWaveIdxAF7  = 0;       gWaveIdxAF8  = 0;
    gBaselineAF7 = 0.0f;    gBaselineAF8 = 0.0f;
    gScaleAF7    = 140.0f;  gScaleAF8    = 140.0f;
    gLastAF7     = 0.0f;    gLastAF8     = 0.0f;

    // Diagnostic counters.
    gSamplesAF7 = 0;
    gSamplesAF8 = 0;
    gNotifyPktsAF7 = 0;
    gNotifyPktsAF8 = 0;
    gNotifyPktsOther = 0;

    // Blink detector, including the count and the calibrated threshold.
    gBlinkCount         = 0;
    gLastBlinkMs        = 0;
    gBlinkActiveUntilMs = 0;
    gBlinkThresholdUv   = BLINK_BASE_THRESHOLD_UV;
    resetBlinkDetector();

    // Calibration accumulators.
    gCalibrationStartMs  = 0;
    gCalibrationSamples  = 0;
    gCalibrationAbsSum   = 0.0f;
    gCalibrationAbsSqSum = 0.0f;

    // Band-power window and results.
    memset(gBandBuf, 0, sizeof(gBandBuf));
    gBandWrite        = 0;
    gBandCount        = 0;
    gLastBandUpdateMs = 0;
    gBandDelta = 0.0f;
    gBandTheta = 0.0f;
    gBandAlpha = 0.0f;
    gBandBeta  = 0.0f;
    gBandGamma = 0.0f;

    // Packet queue shared with the BLE notify callback.
    portENTER_CRITICAL(&gEegMux);
    gEegHead    = 0;
    gEegTail    = 0;
    gEegDropped = 0;
    portEXIT_CRITICAL(&gEegMux);
}

// Decode one 20-byte EEG packet (16-bit sequence number, then 12 samples of
// 12 bits each) and feed every AF7/AF8 sample through baseline removal, the
// waveform display buffer, the band-power window (AF7 only) and the blink
// detector. Packets for any other channel have no consumer and are ignored.
void processChannelPacket(uint8_t channelId, const uint8_t* data) {
    float*   baseline;
    float*   scale;
    float*   last;
    int16_t* wave;
    uint8_t* waveIdx;
    volatile uint32_t* sampleCount;

    switch (channelId) {
        case CH_AF7:
            baseline = &gBaselineAF7; scale = &gScaleAF7; last = &gLastAF7;
            wave = gWaveAF7; waveIdx = &gWaveIdxAF7; sampleCount = &gSamplesAF7;
            break;
        case CH_AF8:
            baseline = &gBaselineAF8; scale = &gScaleAF8; last = &gLastAF8;
            wave = gWaveAF8; waveIdx = &gWaveIdxAF8; sampleCount = &gSamplesAF8;
            break;
        default:
            return;  // nothing downstream consumes other channels
    }

    constexpr size_t kSeqBits    = 16;
    constexpr size_t kSampleBits = 12;
    for (size_t s = 0; s < 12; ++s) {
        const uint16_t raw = read12Bit(data, kSeqBits + s * kSampleBits);
        const float uv     = (static_cast<float>(raw) - 2048.0f) * UV_SCALE;

        // Exponential moving average for baseline removal (high-pass equivalent)
        *baseline = (1.0f - HP_BASELINE_ALPHA) * (*baseline) + HP_BASELINE_ALPHA * uv;
        const float hp = uv - *baseline;

        // Auto-scale the waveform display toward the typical signal amplitude.
        *scale = 0.999f * (*scale) + 0.001f * fabsf(hp) * 6.0f;
        *scale = clampf(*scale, 40.0f, 420.0f);

        *last          = hp;
        wave[*waveIdx] = static_cast<int16_t>(hp);
        *waveIdx       = static_cast<uint8_t>((*waveIdx + 1) % SCREEN_W);
        (*sampleCount)++;

        if (channelId == CH_AF7) pushBandSample(hp);
        updateBlinkFromSample(channelId, hp);
    }
}

// ---- ISR ring buffer --------------------------------------------------------

// Copy one EEG notification into the shared ring buffer. Called from the BLE
// notify callback. Short notifications are ignored. When the queue is full the
// packet is dropped and gEegDropped is incremented; one slot is always left
// empty so head == tail means "empty".
void enqueueEegPacket(uint8_t channelId, uint8_t* data, size_t len) {
    if (len < EEG_PACKET_LEN) return;

    portENTER_CRITICAL(&gEegMux);
    const uint16_t slot  = gEegHead;
    const uint16_t after = static_cast<uint16_t>((slot + 1) % EEG_QUEUE_LEN);
    if (after == gEegTail) {
        gEegDropped++;
    } else {
        gEegQueue[slot].channelId = channelId;
        memcpy(gEegQueue[slot].payload, data, EEG_PACKET_LEN);
        gEegHead = after;
    }
    portEXIT_CRITICAL(&gEegMux);
}

// Process up to 40 queued EEG packets per call, which bounds the time spent per
// loop(). Each packet is copied out under the lock and decoded after the lock
// is released.
void drainEegQueue() {
    for (int budget = 40; budget > 0; --budget) {
        EegPacket pkt;
        portENTER_CRITICAL(&gEegMux);
        const bool empty = (gEegTail == gEegHead);
        if (!empty) {
            pkt = gEegQueue[gEegTail];
            gEegTail = static_cast<uint16_t>((gEegTail + 1) % EEG_QUEUE_LEN);
        }
        portEXIT_CRITICAL(&gEegMux);

        if (empty) return;
        processChannelPacket(pkt.channelId, pkt.payload);
    }
}

// ---- BLE notify callback (runs in BLE task context) -------------------------

// Notification handler for the EEG characteristics. It identifies the source
// channel by characteristic pointer, bumps the matching diagnostic counter and
// queues the payload for loop(). Unrecognized characteristics are ignored.
void eegNotifyCb(NimBLERemoteCharacteristic* characteristic, uint8_t* data, size_t len, bool) {
    uint8_t channelId = 0;
    volatile uint32_t* counter = nullptr;

    if (characteristic == gAF7Char) {
        channelId = CH_AF7;  counter = &gNotifyPktsAF7;
    } else if (characteristic == gAF8Char) {
        channelId = CH_AF8;  counter = &gNotifyPktsAF8;
    } else if (characteristic == gTP9Char) {
        channelId = CH_TP9;  counter = &gNotifyPktsOther;
    } else if (characteristic == gTP10Char) {
        channelId = CH_TP10; counter = &gNotifyPktsOther;
    } else if (characteristic == gRAuxChar) {
        channelId = CH_RAUX; counter = &gNotifyPktsOther;
    } else {
        return;
    }

    ++(*counter);
    enqueueEegPacket(channelId, data, len);
}

// ---- BLE scanning & connection ----------------------------------------------

// Scan callback. When an advertiser's name starts with "Muse", keep a heap
// copy of it in gTargetDevice (replacing any earlier one) and stop scanning.
class MuseAdvertisedDeviceCallbacks : public NimBLEAdvertisedDeviceCallbacks {
    void onResult(NimBLEAdvertisedDevice* advertisedDevice) override {
        const auto name = advertisedDevice->getName();
        if (name.rfind("Muse", 0) != 0) return;  // not a Muse headband

        Serial.printf("Found Muse: %s\n", name.c_str());
        delete gTargetDevice;  // no-op when nullptr
        gTargetDevice = new NimBLEAdvertisedDevice(*advertisedDevice);
        NimBLEDevice::getScan()->stop();
    }
};
MuseAdvertisedDeviceCallbacks gScanCallbacks;

class MuseClientCallbacks : public NimBLEClientCallbacks {
    void onConnect(NimBLEClient* client) override {
        gConnected = true; gNeedsReconnect = false; gConnStartMs = millis();
        NimBLEConnInfo info = client->getConnInfo();
        Serial.printf("Muse connected: itvl=%.2fms\n", info.getConnInterval() * 1.25f);
    }
    void onDisconnect(NimBLEClient*) override {
        gConnected = false; gStreamingStarted = false; gNeedsReconnect = true;
        Serial.println("Muse disconnected");
    }
    bool onConnParamsUpdateRequest(NimBLEClient*, const ble_gap_upd_params* p) override {
        // Accept only reasonable connection parameter updates from the headband.
        return (p->itvl_min >= 6 && p->itvl_min <= p->itvl_max &&
                p->itvl_max <= 120 && p->latency <= 30 &&
                p->supervision_timeout >= MUSE_MIN_ACCEPT_TIMEOUT &&
                p->supervision_timeout <= 3200);
    }
};

// Custom GAP event hook (registered via NimBLEDevice::setCustomGapHandler in
// setup()). It records and logs the reason for every link disconnect.
//
// NimBLE reports controller reasons as BLE_HS_ERR_HCI_BASE (0x200) plus the
// HCI error code; host-originated reasons are small BLE_HS_E* values. Only the
// former are decoded with hciReasonToString(). Truncating a host code to
// 8 bits would alias it onto an unrelated HCI name; for example
// BLE_HS_ETIMEOUT_HCI (0x13) would print as "Remote User Terminated".
int museGapEventHandler(ble_gap_event* event, void*) {
    if (!event) return 0;
    if (event->type == BLE_GAP_EVENT_DISCONNECT) {
        const int reason = event->disconnect.reason;
        gGapLastDisconnectReason = reason;
        gGapDisconnectCount++;

        constexpr int kHciErrBase = 0x200;  // BLE_HS_ERR_HCI_BASE
        const bool isHciReason = reason >= kHciErrBase && reason < kHciErrBase + 0x100;
        const char* name = isHciReason
            ? hciReasonToString(static_cast<uint8_t>(reason - kHciErrBase))
            : "host error";
        Serial.printf("[GAP] disconnect reason=0x%03X (%s)\n",
                      static_cast<unsigned>(reason), name);
    }
    return 0;
}

// Ask the peer to switch to the steady-state connection parameters
// (MUSE_CONN_*). `why` only labels the log line. This does nothing unless a
// client exists and is connected.
void requestPreferredConnParams(const char* why) {
    const bool linkUp = (gClient != nullptr) && gClient->isConnected();
    if (!linkUp) return;

    gLastConnParamReqMs = millis();
    Serial.printf("[conn] requesting preferred params (%s)\n", why);
    gClient->updateConnParams(MUSE_CONN_ITVL_MIN, MUSE_CONN_ITVL_MAX,
                              MUSE_CONN_LATENCY, MUSE_CONN_TIMEOUT);
}

// Send a text command to the Muse control characteristic.
// The Muse protocol wraps the command as: [length+1][cmd bytes]['\n']
// Commands longer than 16 characters are rejected. Write-without-response is
// tried first when supported, then write-with-response.
// Returns true if either write succeeded.
bool writeMuseCmd(const char* cmd) {
    if (gCtrlChar == nullptr || cmd == nullptr) return false;
    const size_t cmdLen = strlen(cmd);
    if (cmdLen > 16) return false;

    uint8_t frame[20] = {0};
    frame[0] = static_cast<uint8_t>(cmdLen + 1);
    memcpy(frame + 1, cmd, cmdLen);
    frame[cmdLen + 1] = '\n';
    const size_t frameLen = cmdLen + 2;

    if (gCtrlChar->canWriteNoResponse() && gCtrlChar->writeValue(frame, frameLen, false)) return true;
    return gCtrlChar->canWrite() && gCtrlChar->writeValue(frame, frameLen, true);
}

// Send preset p21 as a fixed raw byte frame: [0x04]['p' '2' '1']['\n'].
// These are the same bytes writeMuseCmd("p21") would build. Write-without-
// response is tried first when supported, then write-with-response.
// Returns true if either write succeeded.
bool sendMusePreset21() {
    static const uint8_t kPresetFrame[5] = {0x04, 0x70, 0x32, 0x31, 0x0A};
    if (gCtrlChar == nullptr) return false;

    const size_t frameLen = sizeof(kPresetFrame);
    if (gCtrlChar->canWriteNoResponse() && gCtrlChar->writeValue(kPresetFrame, frameLen, false)) return true;
    return gCtrlChar->canWrite() && gCtrlChar->writeValue(kPresetFrame, frameLen, true);
}

// Drop all per-connection BLE state. This deletes the client via
// NimBLEDevice::deleteClient and frees the stored advertised device. The
// cached characteristic pointers, which belonged to that client's service, are
// cleared.
void teardownConnectionState() {
    gStreamingStarted = false;
    gConnStartMs      = 0;

    if (gClient != nullptr) {
        NimBLEDevice::deleteClient(gClient);
        gClient = nullptr;
    }

    delete gTargetDevice;  // no-op when nullptr
    gTargetDevice = nullptr;

    gCtrlChar = nullptr;
    gTP9Char  = nullptr;
    gAF7Char  = nullptr;
    gAF8Char  = nullptr;
    gTP10Char = nullptr;
    gRAuxChar = nullptr;
}

// Connect to gTargetDevice, locate the Muse service and characteristics,
// subscribe to AF7/AF8 notifications, select preset p21, and send "d" to start
// streaming.
//
// Returns true once streaming has been started. Returns false if there is no
// target, a client cannot be created, both connect attempts fail, or any later
// setup step fails (the link is disconnected first in that case). The caller
// (loop()) then tears down the connection state and rescans.
//
// This call blocks: up to two connect attempts (setConnectTimeout(20), which is
// seconds in NimBLE-Arduino 1.x), GATT discovery, and short delays between
// control writes.
bool connectToMuse() {
    if (!gTargetDevice) return false;

    NimBLEScan* scan = NimBLEDevice::getScan();
    if (scan && scan->isScanning()) scan->stop();

    bool connected = false;
    for (int attempt = 1; attempt <= 2; ++attempt) {
        gClient = NimBLEDevice::createClient();
        if (!gClient) {
            Serial.println("Failed to create BLE client");
            return false;
        }
        // `true`: the client owns and deletes the callbacks object.
        gClient->setClientCallbacks(new MuseClientCallbacks(), true);
        gClient->setConnectionParams(MUSE_CONNECT_ITVL_MIN, MUSE_CONNECT_ITVL_MAX,
                                     MUSE_CONN_LATENCY, MUSE_CONNECT_TIMEOUT,
                                     MUSE_CONN_SCAN_ITVL, MUSE_CONN_SCAN_WINDOW);
        gClient->setConnectTimeout(20);

        if (gClient->connect(gTargetDevice)) { connected = true; break; }

        Serial.printf("Connect failed (attempt %d/2)\n", attempt);
        NimBLEDevice::deleteClient(gClient); gClient = nullptr;
        delay(200);
    }
    if (!connected || !gClient) return false;

    // Switch to the steady-state connection profile and request a larger data length.
    requestPreferredConnParams("post_connect");
    gClient->setDataLen(120);

    NimBLERemoteService* museService = gClient->getService(MUSE_SERVICE_UUID);
    if (!museService) { Serial.println("Muse service not found"); gClient->disconnect(); return false; }

    gCtrlChar = museService->getCharacteristic(MUSE_CTRL_UUID);
    gTP9Char  = museService->getCharacteristic(MUSE_TP9_UUID);
    gAF7Char  = museService->getCharacteristic(MUSE_AF7_UUID);
    gAF8Char  = museService->getCharacteristic(MUSE_AF8_UUID);
    gTP10Char = museService->getCharacteristic(MUSE_TP10_UUID);
    gRAuxChar = museService->getCharacteristic(MUSE_RAUX_UUID);

    if (!gCtrlChar || !gAF7Char || !gAF8Char) {
        Serial.println("Required characteristics not found"); gClient->disconnect(); return false;
    }
    if (!gAF7Char->canNotify() || !gAF8Char->canNotify()) {
        Serial.println("AF7/AF8 notify not supported"); gClient->disconnect(); return false;
    }

    if (!gAF7Char->subscribe(true, eegNotifyCb) || !gAF8Char->subscribe(true, eegNotifyCb)) {
        Serial.println("Failed to subscribe to EEG"); gClient->disconnect(); return false;
    }

    delay(120);

    // Select preset p21 (up to 3 tries). A failure is logged, and streaming is
    // still attempted.
    bool presetOk = false;
    for (int i = 1; i <= 3 && !presetOk; ++i) {
        presetOk = sendMusePreset21();
        if (!presetOk) delay(60);
    }
    if (!presetOk) Serial.println("Warning: preset p21 command failed");

    // Start the data stream with "d" (up to 5 tries). This is the only place
    // the sketch starts streaming, so a failure here is treated like the other
    // setup failures: the link is dropped and loop() reconnects.
    bool startOk = false;
    for (int i = 1; i <= 5 && !startOk; ++i) {
        startOk = writeMuseCmd("d");
        if (!startOk) delay(80);
    }
    if (!startOk) {
        Serial.println("Failed to send stream start command"); gClient->disconnect(); return false;
    }

    resetWaveState();
    gStreamingStarted = true; gConnected = true; gNeedsReconnect = false;
    Serial.println("Muse stream started");
    return true;
}

void startScan() {
    NimBLEScan* scan = NimBLEDevice::getScan();
    if (scan->isScanning()) return;
    scan->setAdvertisedDeviceCallbacks(&gScanCallbacks, false);
    scan->setInterval(45); scan->setWindow(15);
    scan->setActiveScan(true); scan->setDuplicateFilter(true);
    scan->start(0, nullptr, false);
}

// ---- OLED rendering ---------------------------------------------------------

// Map a sample (uV) to a screen row around centerY. The sample is normalized by
// `scale` (floored at 20 uV), clamped to [-1, 1], and scaled to +/- GRAPH_HALF_H
// pixels. Positive samples plot upward; the offset is truncated toward zero.
int sampleToY(float sample, float scale, int centerY) {
    const float span   = (scale < 20.0f) ? 20.0f : scale;
    const float norm   = clampf(sample / span, -1.0f, 1.0f);
    const int   offset = static_cast<int>(norm * GRAPH_HALF_H);
    return centerY - offset;
}

// Draw one scrolling trace as SCREEN_W - 1 connected line segments. writeIdx is
// the next slot to be overwritten (the oldest sample), so walking forward from
// it draws the trace left-to-right in time order. Rows are clamped to
// [top, bottom].
void drawWave(const int16_t* wave, uint8_t writeIdx, float scale, int centerY, int top, int bottom) {
    auto columnY = [&](int x) -> int {
        const uint8_t idx = static_cast<uint8_t>((writeIdx + x) % SCREEN_W);
        const int y = sampleToY(static_cast<float>(wave[idx]), scale, centerY);
        return static_cast<int>(clampf(static_cast<float>(y), top, bottom));
    };

    int prevY = columnY(0);
    for (int x = 1; x < SCREEN_W; ++x) {
        const int y = columnY(x);
        display.drawLine(x - 1, prevY, x, y, SSD1306_WHITE);
        prevY = y;
    }
}

// Redraw the OLED for the current state, at most every 50 ms.
//   - not connected : scanning status, blink count, mode, button hint
//   - Calibrating   : instructions, elapsed seconds, current noise estimate
//   - Waveform      : alpha/beta/gamma percentages, AF7 and AF8 traces
//   - BlinkCounter  : large blink count plus a status line
// At text size 1 each character is 6 px wide, so a 128 px row fits 21
// characters. Longer strings wrap onto the next line and overlap it.
void drawUi() {
    static uint32_t lastDraw = 0;
    const uint32_t  now      = millis();
    if (now - lastDraw < 50) return;  // cap at 20 fps
    lastDraw = now;

    display.clearDisplay();
    display.setTextColor(SSD1306_WHITE);
    display.setTextSize(1);
    display.setCursor(0, 0);

    if (!gConnected) {
        display.print("Muse: scanning...");
        display.setCursor(0, 12); display.print("Blinks: "); display.print(static_cast<unsigned long>(gBlinkCount));
        display.setCursor(0, 24); display.print("Mode: "); display.print(uiModeName(gUiMode));
        display.setCursor(0, 36); display.print("Btn D2: calibrate");
    } else if (gUiMode == UiMode::Calibrating) {
        const uint32_t sec = (now - gCalibrationStartMs) / 1000UL;
        display.print("CALIBRATING");
        display.setCursor(0, 12); display.print("Keep eyes open");
        // 20 characters: fits on one line (22 would wrap into the "t=" row).
        display.setCursor(0, 22); display.print("Don't blink/be still");
        display.setCursor(0, 34); display.print("t="); display.print(static_cast<unsigned long>(sec)); display.print("s");
        display.setCursor(0, 46); display.print("noise=");
        display.print(static_cast<int>((gNoiseAbsAF7 + gNoiseAbsAF8) * 0.5f + 0.5f)); display.print("uV");
        display.setCursor(0, 56); display.print("Press D2 -> waveform");
    } else if (gUiMode == UiMode::Waveform) {
        display.print("AF7/AF8");
        display.setCursor(52, 0);
        display.print("A"); display.print(static_cast<int>(gBandAlpha * 100 + 0.5f));
        display.print(" B"); display.print(static_cast<int>(gBandBeta  * 100 + 0.5f));
        display.print(" G"); display.print(static_cast<int>(gBandGamma * 100 + 0.5f));
        if (gPoorContact) { display.setCursor(104, 56); display.print("BAD"); }
        display.drawLine(0, GRAPH1_TOP - 1, SCREEN_W - 1, GRAPH1_TOP - 1, SSD1306_WHITE);
        display.drawLine(0, GRAPH2_TOP - 1, SCREEN_W - 1, GRAPH2_TOP - 1, SSD1306_WHITE);
        drawWave(gWaveAF7, gWaveIdxAF7, gScaleAF7, AF7_CENTER_Y, GRAPH1_TOP, GRAPH1_BOTTOM);
        drawWave(gWaveAF8, gWaveIdxAF8, gScaleAF8, AF8_CENTER_Y, GRAPH2_TOP, GRAPH2_BOTTOM);
        display.setCursor(0, 56); display.print("D2->Blink");
    } else {
        // Blink counter: big number takes centre stage
        display.print("BLINK COUNT");
        display.setTextSize(4);
        display.setCursor(12, 18);
        display.print(static_cast<unsigned long>(gBlinkCount));
        display.setTextSize(1);
        display.setCursor(0, 56);
        if (gPoorContact)                 display.print("CHECK HEADBAND");
        else if (now < gBlinkActiveUntilMs) display.print("BLINK!");
        else                               display.print("Press D2: Cal");
    }

    display.display();
}

}  // namespace

// ---- Arduino entry points ---------------------------------------------------

void setup() {
    Serial.begin(115200);
    delay(200);

    // Button: record the idle level so the first loop() does not see a fake edge.
    pinMode(BUTTON_PIN, BUTTON_ACTIVE_LOW ? INPUT_PULLUP : INPUT);
    const bool idleLevel = static_cast<bool>(digitalRead(BUTTON_PIN));
    gButtonStableState = idleLevel;
    gButtonLastRead    = idleLevel;
    gButtonLastEdgeMs  = millis();

    // OLED splash screen. On failure only a log line is printed and setup continues.
    Wire.begin(I2C_SDA, I2C_SCL);
    const bool oledReady = display.begin(SSD1306_SWITCHCAPVCC, OLED_ADDR);
    if (oledReady) {
        display.clearDisplay();
        display.setTextSize(1);
        display.setTextColor(SSD1306_WHITE);
        display.setCursor(0, 0);
        display.println("Muse EEG Viewer");
        display.setCursor(0, 10);
        display.println("Scanning for Muse...");
        display.display();
    } else {
        Serial.println("OLED init failed at 0x3C");
    }

    // BLE stack: MTU, TX power, disconnect-reason logging, then start scanning.
    NimBLEDevice::init("");
    NimBLEDevice::setMTU(185);
    NimBLEDevice::setPower(ESP_PWR_LVL_P9);
    NimBLEDevice::setCustomGapHandler(museGapEventHandler);
    startScan();

    Serial.println("Scanning for Muse 2...");
    Serial.printf("Button D2 cycles: blink-counter -> calibrate -> waveform\n");
}

void loop() {
    updateButton();

    // The BLE stack may drop the link while gConnected is still set; treat a
    // client that no longer reports a connection as a disconnect.
    const bool phantomLink = gClient && !gClient->isConnected() && gConnected;
    if (phantomLink) {
        gConnected      = false;
        gNeedsReconnect = true;
    }

    // Reconnect requested: drop everything and rescan, with a short back-off.
    if (gNeedsReconnect) {
        gNeedsReconnect = false;
        teardownConnectionState();
        gNextConnectAttemptMs = millis() + 1200;
        startScan();
    }

    drainEegQueue();
    updateBandPowers();

    // A Muse has been found and the back-off has expired: try to connect.
    const bool haveTarget = !gConnected && !gNeedsReconnect && gTargetDevice != nullptr;
    if (haveTarget && millis() >= gNextConnectAttemptMs) {
        Serial.println("Attempting Muse connection...");
        const bool ok = connectToMuse();
        if (!ok) {
            Serial.println("Connect failed, rescanning...");
            teardownConnectionState();
            gNextConnectAttemptMs = millis() + 2500;
            startScan();
        }
    }

    drawUi();
    delay(8);
}
