from machine import Pin, I2C
import math
import neopixel
import time


# GPIO data pin for each LED strand. List position becomes the strand index
# used by every animation, so reordering pins reorders strands in the effects.
PINS = [1, 2, 3, 4, 9, 5, 6, 7, 8]
# Pixels per strand.
NUM = 21

# Global brightness cap, 0-255: every channel is scaled by BRIGHTNESS / 255.
# 40 is about 16% of a full 8-bit channel.
BRIGHTNESS = 40

# Mode shown at power-up. Options: "calm", "fire", "disco", or "panic".
# It must appear in MODE_SEQUENCE (its index is looked up at startup).
START_MODE = "calm"
# Modes rotate in this order, each shown for MODE_DURATION_MS milliseconds.
MODE_SEQUENCE = ("calm", "fire", "disco", "panic")
MODE_DURATION_MS = 30000

# Animation colors are written as RGB. The physical LEDs expect GRB.
# NOTE(review): MicroPython's neopixel.NeoPixel already maps (r, g, b) tuples
# to GRB on the wire by default, so swapping here as well may cancel out --
# confirm against the actual LEDs.
COLOR_ORDER = "GRB"

# Keep this true if the first LED is physically hidden or too close to the PCB.
FIRST_LED_OFF = True

# Pause after each frame, in seconds (passed to time.sleep).
FRAME_DELAY = 0.02
# Entries in the shared sine/color tables. Renderers index these tables with
# "& 255", so TABLE_SIZE must stay 256.
TABLE_SIZE = 256
# Calm mode uses longer tables (indexed with % CALM_TABLE_SIZE), so one calm
# cycle spans more frames.
CALM_TABLE_SIZE = 512

# Hall sensor on I2C bus 1: 7-bit device address and SDA/SCL GPIO numbers.
HALL_ADDRESS = 0x35
HALL_SDA_PIN = 26
HALL_SCL_PIN = 27
# Minimum time between hall sensor reads, in milliseconds.
HALL_READ_INTERVAL_MS = 100
# Sum of absolute per-axis changes (raw sensor counts) between two reads that
# counts as movement and sends the display back to calm mode.
HALL_MOVEMENT_THRESHOLD = 15

# Palettes are lists of (r, g, b) keyframes, each channel 0-255.
# make_color_table() eases between consecutive entries and also wraps from the
# last entry back to the first.
# NOTE(review): because of that wrap, a final entry that repeats the first
# (the "seamless return" rows below) is not needed for seamlessness; instead
# it produces a constant-color hold for 1/len(palette) of the loop. Confirm
# the hold is intended.
CALM_PALETTE = [
    (0x02, 0x06, 0x48),  # deep royal blue
    (0x00, 0x24, 0x18),  # dark green shadow
    (0x00, 0x12, 0x88),  # blue accent
    (0x18, 0x00, 0x68),  # dark violet
    (0x58, 0x00, 0xB8),  # clear purple
    (0xA8, 0x00, 0xE8),  # pronounced violet-magenta
    (0x00, 0x38, 0x28),  # deep teal-green
    (0xC8, 0x18, 0xFF),  # bright purple glow
    (0x68, 0x00, 0xC8),  # saturated purple return
    (0x18, 0x00, 0x68),  # dark violet return
    (0x02, 0x06, 0x48),  # seamless return to blue
]

# Fire mode color loop: dark ember up through orange and back down.
FIRE_LOOP_PALETTE = [
    (0x18, 0x00, 0x00),  # dark ember
    (0x50, 0x00, 0x00),  # deep red
    (0x95, 0x0B, 0x00),  # hot red
    (0xE0, 0x30, 0x00),  # orange-red
    (0xFF, 0x68, 0x00),  # orange lick
    (0xA0, 0x10, 0x00),  # cooling red
    (0x35, 0x00, 0x00),  # fading ember
    (0x18, 0x00, 0x00),  # seamless loop back to dark
]

# Disco mode color loop: saturated, high-contrast colors.
DISCO_PALETTE = [
    (0xFF, 0x00, 0xFF),  # magenta
    (0x00, 0xFF, 0xFF),  # cyan
    (0xFF, 0xFF, 0x00),  # yellow
    (0x00, 0xFF, 0x44),  # acid green
    (0xFF, 0x33, 0x00),  # hot orange
    (0x5B, 0x00, 0xFF),  # violet
    (0xFF, 0xFF, 0xFF),  # white flash
    (0xFF, 0x00, 0xFF),  # seamless loop back to magenta
]


def _open_strip(pin_number):
    """Return a NUM-pixel NeoPixel driver on GPIO ``pin_number`` (output)."""
    return neopixel.NeoPixel(Pin(pin_number, Pin.OUT), NUM)


# One driver per entry in PINS; a strip's list position is its strand index.
strips = [_open_strip(pin_number) for pin_number in PINS]


def ease_in_out_sine(t):
    """Map ``t`` onto a sine-shaped ease-in/ease-out curve.

    For t between 0 and 1 the result rises from 0 to 1, passing 0.5 at
    t = 0.5, with zero slope at both ends.
    """
    return (1 - math.cos(t * math.pi)) / 2


def make_wave_table(size):
    """Sample one full sine period into ``size`` bytes scaled to 0-255.

    Entry ``i`` is ``int((sin(2*pi*i/size) + 1) * 0.5 * 255)``. It starts at
    127, peaks at 255 a quarter of the way through, and bottoms out near 0 at
    three quarters, so indexing modulo ``size`` loops without a seam.
    """
    full_turn = 2 * math.pi
    wave = bytearray(size)

    for index in range(size):
        height = math.sin(full_turn * index / size)
        wave[index] = int((height + 1) * 0.5 * 255)

    return wave


def make_color_table(palette, size):
    """Build a looping eased color gradient through ``palette``.

    Returns a bytearray of ``size`` colors packed as 3 bytes each (r, g, b),
    so entry ``i`` occupies bytes ``3*i`` to ``3*i + 2``. The palette entries
    are spread evenly over the table. Within each segment the color eases
    (sine in/out) from one entry to the next. The final segment blends back
    into ``palette[0]``, so the table loops. Blended channel values are
    truncated with int() and must fit in 0-255.
    """
    table = bytearray(size * 3)
    count = len(palette)

    for step in range(size):
        position = step * count / size
        segment = int(position)
        next_segment = (segment + 1) % count
        blend = ease_in_out_sine(position - segment)

        start = palette[segment]
        end = palette[next_segment]

        base = step * 3
        for channel in range(3):
            table[base + channel] = int(
                start[channel] + (end[channel] - start[channel]) * blend
            )

    return table


# Lookup tables are built once at startup so each frame only does integer
# table indexing.
# Sine table (values 0-255) used for the disco "blob" offsets.
WAVE = make_wave_table(TABLE_SIZE)
# Longer sine table driving the calm brightness wave.
CALM_WAVE = make_wave_table(CALM_TABLE_SIZE)
# Packed RGB color loops (3 bytes per entry) for each mode.
CALM_COLORS = make_color_table(CALM_PALETTE, CALM_TABLE_SIZE)
FIRE_LOOP_COLORS = make_color_table(FIRE_LOOP_PALETTE, TABLE_SIZE)
DISCO_COLORS = make_color_table(DISCO_PALETTE, TABLE_SIZE)


def to_led_order(rgb, order=None):
    """Reorder an (r, g, b) color into the channel order the LEDs expect.

    Args:
        rgb: color as an indexable (r, g, b) triple.
        order: channel-order string; any permutation of "RGB" (case
            insensitive), e.g. "GRB" or "BRG". Defaults to COLOR_ORDER.

    Returns:
        ``rgb`` itself for "RGB"; otherwise a new tuple whose channels follow
        ``order``.

    Raises:
        ValueError: if ``order`` is not a permutation of "RGB". A mistyped
            COLOR_ORDER then fails loudly instead of silently sending RGB.

    NOTE(review): MicroPython's neopixel.NeoPixel already converts (r, g, b)
    tuples to GRB by default; confirm this extra reordering matches the
    hardware.
    """
    if order is None:
        order = COLOR_ORDER

    # Fast paths first: this runs once per pixel per frame.
    if order == "GRB":
        return (rgb[1], rgb[0], rgb[2])
    if order == "RGB":
        return rgb

    order = order.upper()
    if sorted(order) != ["B", "G", "R"]:
        raise ValueError("unsupported COLOR_ORDER: %r" % (order,))

    return tuple(rgb["RGB".index(channel)] for channel in order)


def apply_level(color, level):
    """Scale ``color`` by ``level`` and BRIGHTNESS, then convert to LED order.

    ``level`` acts as a 0-255 fraction on top of the global BRIGHTNESS cap.
    Each channel becomes ``channel * (BRIGHTNESS * level) // (255 * 255)``
    using integer floor division.
    """
    factor = BRIGHTNESS * level
    red = color[0] * factor // 65025
    green = color[1] * factor // 65025
    blue = color[2] * factor // 65025
    return to_led_order((red, green, blue))


def apply_brightness(color):
    """Scale an RGB color by BRIGHTNESS / 255 and convert it to LED order.

    Uses integer floor division per channel.
    """
    red, green, blue = color[0], color[1], color[2]
    scaled = (
        red * BRIGHTNESS // 255,
        green * BRIGHTNESS // 255,
        blue * BRIGHTNESS // 255,
    )
    return to_led_order(scaled)


def table_color(table, index):
    """Return entry ``index`` of a packed RGB table as an (r, g, b) tuple.

    ``table`` stores colors as consecutive 3-byte runs, the layout produced
    by make_color_table().
    """
    base = 3 * index
    red = table[base]
    green = table[base + 1]
    blue = table[base + 2]
    return red, green, blue


def clear_all():
    """Set every pixel on every strip to black and push each strip out."""
    black = (0, 0, 0)
    for strip in strips:
        for position in range(NUM):
            strip[position] = black
        strip.write()


def parse_hall_data(data):
    """Decode three signed 12-bit axis readings from a 6-byte sensor read.

    Each axis is a high byte shifted left by 4 plus a 4-bit low nibble:
      * X: data[0] and the high nibble of data[4]
      * Y: data[1] and the low nibble of data[4]
      * Z: data[2] and the low nibble of data[5]
    Raw values above 2047 are treated as two's-complement negatives.

    Returns:
        (x, y, z) integers; for byte input each lies in -2048..2047.
    """
    def signed12(raw):
        return raw - 4096 if raw > 2047 else raw

    x_raw = (data[0] << 4) | ((data[4] & 0xF0) >> 4)
    y_raw = (data[1] << 4) | (data[4] & 0x0F)
    z_raw = (data[2] << 4) | (data[5] & 0x0F)

    return (signed12(x_raw), signed12(y_raw), signed12(z_raw))


def setup_hall_sensor():
    """Open I2C bus 1 and send the hall sensor its configuration bytes.

    Returns the I2C bus object on success and prints "HALL_READY". If
    creating the bus or the configuration write raises, prints
    "HALL_ERROR: ..." and returns None so the animation can run without
    the sensor.
    """
    try:
        bus = I2C(1, sda=Pin(HALL_SDA_PIN), scl=Pin(HALL_SCL_PIN), freq=400000)
        # NOTE(review): presumably a register pointer (0x10) followed by two
        # configuration values -- confirm against the sensor datasheet.
        config = bytes([0x10, 0x28, 0x15])
        bus.writeto(HALL_ADDRESS, config)
        print("HALL_READY")
        return bus
    except Exception as error:
        # Best effort: a missing or faulty sensor just disables the feature.
        print("HALL_ERROR:", error)
        return None


def read_hall_vector(sensor):
    """Read 6 bytes from the hall sensor over ``sensor`` and decode (x, y, z)."""
    raw = sensor.readfrom(HALL_ADDRESS, 6)
    return parse_hall_data(raw)


def hall_moved(current, previous):
    """Return True when the field changed by at least HALL_MOVEMENT_THRESHOLD.

    The change is the sum of absolute per-axis differences between two
    (x, y, z) readings. With no previous reading there is nothing to compare,
    so the result is False.
    """
    if previous is None:
        return False

    dx = abs(current[0] - previous[0])
    dy = abs(current[1] - previous[1])
    dz = abs(current[2] - previous[2])
    return dx + dy + dz >= HALL_MOVEMENT_THRESHOLD


# Animation state. The main loop advances these once per frame; the
# show_*_frame() functions only read them.
# NOTE(review): phase_a, phase_b, phase_c and color_phase are not read or
# advanced anywhere else in this file as seen -- likely leftovers.
phase_a = 0
phase_b = 96
phase_c = 180
color_phase = 0
# Calm mode: positions in CALM_WAVE (brightness) and CALM_COLORS (hue).
calm_wave_phase = 0
calm_color_phase = 0
# Fire mode: position in FIRE_LOOP_COLORS.
# NOTE(review): spark_phase is advanced by the main loop but no renderer in
# this file reads it.
fire_phase = 0
spark_phase = 140
# Frame counter; the main loop advances fire_phase only on odd counts.
fire_frame_count = 0
# Disco mode: color-loop position and the opposing "blob" wave position.
disco_phase = 0
disco_blob_phase = 72
# Panic mode: fast diagonal-pattern phase and per-frame gate counter that
# drives the strobe and ping-pong.
panic_phase = 0
panic_gate = 0
# Frame counter; the main loop advances calm_color_phase only on odd counts.
calm_frame_count = 0
# Mode rotation state. MODE_SEQUENCE.index raises ValueError if START_MODE
# is not in MODE_SEQUENCE.
current_mode = START_MODE
mode_index = MODE_SEQUENCE.index(START_MODE)
next_mode_at = time.ticks_add(time.ticks_ms(), MODE_DURATION_MS)
# I2C bus for the hall sensor, or None if setup failed (the main loop then
# skips movement detection).
hall_sensor = setup_hall_sensor()
last_hall_read = time.ticks_ms()
# Previous (x, y, z) reading; None until the first successful read.
last_hall_vector = None


def first_active_pixel(strip):
    """Return the first pixel index an animation should draw on ``strip``.

    When FIRST_LED_OFF is set, pixel 0 is forced to black and 1 is returned,
    so renderers skip it. Otherwise every pixel is drawn and 0 is returned.
    """
    if not FIRST_LED_OFF:
        return 0

    strip[0] = (0, 0, 0)
    return 1


def show_calm_frame():
    """Draw one frame of calm mode on every strip.

    Brightness follows CALM_WAVE and hue follows CALM_COLORS. Each is indexed
    by its global phase plus per-strand and per-bead offsets, modulo
    CALM_TABLE_SIZE, so the pattern drifts across the curtain as the main
    loop advances calm_wave_phase and calm_color_phase.
    """
    for strand, strip in enumerate(strips):
        wave_base = calm_wave_phase + strand * 8
        color_base = calm_color_phase + strand * 40

        for bead in range(first_active_pixel(strip), NUM):
            # One broad brightness wave. The level stays within 34..82; the
            # gentle range avoids visible stepping/flicker at the low end.
            wave = CALM_WAVE[(wave_base + bead * 20) % CALM_TABLE_SIZE]
            level = 34 + wave * 48 // 255

            # A separate, broad color wave drifting through purple/blue/green.
            rgb = table_color(CALM_COLORS, (color_base + bead * 8) % CALM_TABLE_SIZE)
            strip[bead] = apply_level(rgb, level)

        strip.write()


def show_fire_frame():
    """Draw one frame of fire mode on every strip.

    Every pixel shows a FIRE_LOOP_COLORS entry at constant BRIGHTNESS; only
    the table index moves. Subtracting ``bead * 9`` makes the colors march
    toward higher bead numbers as fire_phase increases. ``strand * 13``
    staggers neighboring strands.
    """
    for strand, strip in enumerate(strips):
        strand_offset = fire_phase + strand * 13

        for bead in range(first_active_pixel(strip), NUM):
            index = (strand_offset - bead * 9) & 255
            strip[bead] = apply_brightness(table_color(FIRE_LOOP_COLORS, index))

        strip.write()


def show_disco_frame():
    """Draw one frame of disco mode on every strip.

    The DISCO_COLORS index is the sum of several deliberately clashing terms,
    taken modulo 256:
      * a moving color phase
      * a WAVE "blob" value drifting the opposite way
      * a checkerboard offset that flips every 16 disco_phase steps
      * a per-strand "kick" offset for 8 of every 32 phase values
    """
    checker_shift = disco_phase // 16

    for strand, strip in enumerate(strips):
        blob_base = disco_blob_phase + strand * 37
        kick = 72 if ((disco_phase + strand * 9) & 31) < 8 else 0
        color_base = disco_phase + strand * 47 + kick

        for bead in range(first_active_pixel(strip), NUM):
            blob = WAVE[(blob_base - bead * 19) & 255]
            checker = 96 if (strand + bead + checker_shift) & 1 else 0
            index = (color_base + bead * 29 + blob + checker) & 255
            strip[bead] = apply_brightness(table_color(DISCO_COLORS, index))

        strip.write()


def show_panic_frame():
    """Draw one frame of panic mode: red/white alarm patterns with strobes.

    Layers, from highest priority down:
      * A two-bead white ping. Its position bounces between bead 0 and bead
        NUM - 1 as panic_gate increases, shifted two beads per strand
        (wrapping). It lights on every fourth gate value.
      * A white spark on one bead per strand whenever
        (panic_gate + strand) is a multiple of 8.
      * White and red bands from two diagonals moving with panic_phase.
      * Underneath, a bright-red/dark-red checker.
    All colors pass through apply_brightness().
    """
    bounce_span = (NUM - 1) * 2
    ping_position = panic_gate % bounce_span
    if ping_position >= NUM:
        # Second half of the cycle: walk back down toward bead 0.
        ping_position = bounce_span - ping_position

    strobe_on = (panic_gate & 3) == 0
    white = (0xFF, 0xFF, 0xFF)

    for strand, strip in enumerate(strips):
        ping = (ping_position + strand * 2) % NUM
        ping_neighbor = (ping + 1) % NUM
        spark_on = ((panic_gate + strand) & 7) == 0
        spark_bead = (panic_gate * 3 + strand * 5) % NUM

        for bead in range(first_active_pixel(strip), NUM):
            diagonal = (panic_phase + strand * 19 + bead * 11) & 255
            reverse = (panic_phase * 2 - strand * 37 + bead * 23) & 255

            if strobe_on and bead in (ping, ping_neighbor):
                color = white
            elif spark_on and bead == spark_bead:
                color = white
            elif diagonal < 46:
                color = white
            elif reverse < 90:
                color = (0xFF, 0x00, 0x18)
            elif ((strand + bead + panic_gate) & 3) == 0:
                color = (0xFF, 0x00, 0x00)
            else:
                color = (0x28, 0x00, 0x00)

            strip[bead] = apply_brightness(color)

        strip.write()


clear_all()
print("MODE_READY:" + current_mode)

# Index of "calm" in the rotation. A hall-triggered reset jumps here, so the
# sequence resumes from calm's real position instead of assuming it is first.
calm_mode_index = MODE_SEQUENCE.index("calm") if "calm" in MODE_SEQUENCE else 0

# show_panic_frame() consumes panic_gate via "& 3", "& 7", "% NUM" and
# "% ((NUM - 1) * 2)". Wrapping at a common multiple of all of them keeps the
# ping-pong and strobes continuous. A plain "& 255" wrap is not a multiple of
# (NUM - 1) * 2 for NUM = 21, and would make the ping jump mid-sweep.
PANIC_GATE_PERIOD = max((NUM - 1) * 2, 1) * NUM * 8

while True:
    now = time.ticks_ms()

    # Poll the hall sensor at most every HALL_READ_INTERVAL_MS. Movement
    # forces calm mode and restarts its full duration.
    if hall_sensor is not None and time.ticks_diff(now, last_hall_read) >= HALL_READ_INTERVAL_MS:
        try:
            hall_vector = read_hall_vector(hall_sensor)

            if hall_moved(hall_vector, last_hall_vector):
                if current_mode != "calm":
                    print("HALL_MOVEMENT -> calm")

                current_mode = "calm"
                mode_index = calm_mode_index
                next_mode_at = time.ticks_add(now, MODE_DURATION_MS)

            last_hall_vector = hall_vector
        except Exception as error:
            # Best effort: report the failed read and retry next interval.
            print("HALL_READ_ERROR:", error)

        last_hall_read = now

    # Advance to the next mode once its slot has elapsed (ticks_diff handles
    # tick-counter wrap-around).
    if time.ticks_diff(now, next_mode_at) >= 0:
        mode_index = (mode_index + 1) % len(MODE_SEQUENCE)
        current_mode = MODE_SEQUENCE[mode_index]
        next_mode_at = time.ticks_add(now, MODE_DURATION_MS)
        print("MODE:", current_mode)

    if current_mode == "fire":
        show_fire_frame()
    elif current_mode == "disco":
        show_disco_frame()
    elif current_mode == "panic":
        show_panic_frame()
    else:
        show_calm_frame()

    # Calm moves as slow broad waves; other modes keep their own stronger motion.
    calm_frame_count = (calm_frame_count + 1) & 255
    calm_wave_phase = (calm_wave_phase + 1) % CALM_TABLE_SIZE
    if calm_frame_count & 1:
        calm_color_phase = (calm_color_phase + 1) % CALM_TABLE_SIZE
    fire_frame_count = (fire_frame_count + 1) & 255
    if fire_frame_count & 1:
        fire_phase = (fire_phase + 1) & 255
        spark_phase = (spark_phase + 1) & 255
    disco_phase = (disco_phase + 13) & 255
    disco_blob_phase = (disco_blob_phase - 9) & 255
    panic_phase = (panic_phase + 29) & 255
    panic_gate = (panic_gate + 1) % PANIC_GATE_PERIOD

    time.sleep(FRAME_DELAY)
