import serial
import datetime
import numpy as np
import math
import time


class NodeController:
    """Drive up to four stepper nodes through a one-byte serial protocol.

    Every command is a single byte holding one two-bit code per node (see
    ``lookup_cmd``).  After each command one status byte is read back and
    decoded into ``vec_enabled`` and ``vec_endstop``.

    The controller tracks the commanded position twice: ``vec_steps`` holds
    integer step counts, ``vec_pos`` holds the same position in user units
    (``steps / steps_per_unit``).

    Can be used as a context manager: the port is opened on entry and closed
    on exit.
    """

    # Two-bit command code sent for a unit step of -1, 0 or +1 on one node.
    lookup_cmd = {-1: 0b01,
                  0: 0b00,
                  1: 0b10}

    def __init__(self, port, steps_per_unit, steps_max):
        """Create a controller; the serial port is not opened yet.

        Args:
            port: Serial port name, passed to ``serial.Serial`` by ``connect``.
            steps_per_unit: Per-node number of steps per user unit.  Its
                length defines the number of nodes.
            steps_max: Per-node step count assigned to the position reached
                by ``homing``.

        Raises:
            ValueError: If more than four nodes are requested.
        """
        n = len(steps_per_unit)
        if n > 4:
            raise ValueError("Max 4 nodes are supported for now.")
        self.n = n
        self.vec_pos = np.zeros((n,), np.float64)
        self.vec_steps = np.zeros((n,), np.int32)
        # Builtin ``bool`` instead of ``np.bool``: that alias was removed in
        # NumPy 1.24 and raises AttributeError there.
        self.vec_enabled = np.zeros((n,), bool)
        self.vec_endstop = np.zeros((n,), bool)
        self.steps_per_unit = np.array(steps_per_unit, np.float64)
        self.steps_max = np.array(steps_max, np.int32)
        self.port = port
        self.serial = None
        self.recv = 0

    def _pos_to_steps(self, pos):
        """Convert positions in user units to int32 step counts (truncating)."""
        scaled = np.asarray(pos, dtype=np.float64) * self.steps_per_unit
        return scaled.astype(np.int32)

    def step(self, step_vals):
        """Send one unit step per node and refresh the status flags.

        Args:
            step_vals: One value per node, each -1, 0 or +1.

        Raises:
            ValueError: If ``step_vals`` does not have one entry per node.
            KeyError: If an entry is not -1, 0 or +1.
            TimeoutError: If no status byte arrives before the read timeout.
        """
        if len(step_vals) != self.n:
            raise ValueError(
                "Expected %d step values, got %d." % (self.n, len(step_vals)))
        # Pack the two-bit codes, node 0 in the lowest bits.
        msg = 0
        for i, v in enumerate(step_vals):
            msg |= self.lookup_cmd[v] << (i * 2)
        self.serial.write(bytes([msg]))
        self.serial.flush()
        # Count the step only once the command has actually been written.
        self.vec_steps += step_vals
        # naive way of doing it: take only 1 byte
        reply = self.serial.read()
        if not reply:
            # A timed-out read returns an empty result; indexing it would
            # only produce an uninformative IndexError.
            raise TimeoutError("No status byte received from the controller.")
        self.recv = reply[0]
        self.refresh_status()

    def step_repeat(self, step_vals, n_repeat, delay=0.001):
        """Call ``step(step_vals)`` ``n_repeat`` times, sleeping ``delay``
        seconds after each call."""
        for _ in range(n_repeat):
            self.step(step_vals)
            time.sleep(delay)

    def refresh_status(self):
        """Decode the last received status byte into the per-node flags.

        Node ``i`` owns the two bits at offset ``(i + 4 - n) * 2``.  A set
        high bit is stored as "enabled"; a *cleared* low bit is stored as
        "endstop reached".
        """
        for i in range(self.n):
            status = (self.recv >> ((i + 4 - self.n) * 2)) & 0b11
            self.vec_enabled[i] = (status & 0b10) != 0
            self.vec_endstop[i] = (status & 0b01) == 0

    def homing(self):
        """Step each node in the +1 direction until its endstop flag is set,
        then set the step counters to ``steps_max``.

        NOTE: there is no travel limit; if an endstop flag never gets set
        this loops until ``step`` raises.
        """
        vec_step = np.zeros((self.n,), np.int32)
        for i in range(self.n):
            vec_step[:] = 0
            vec_step[i] = 1
            while not self.vec_endstop[i]:
                self.step(vec_step)
                time.sleep(0.0001)
        self.override_steps(self.steps_max)

    def override_steps(self, steps):
        """Set the step counters to ``steps`` and derive ``vec_pos`` from
        them, without moving anything."""
        self.vec_steps[:] = steps
        self.vec_pos[:] = self.vec_steps / self.steps_per_unit

    def override_pos(self, pos):
        """Set the position to ``pos`` (user units) and derive the step
        counters from it, without moving anything."""
        self.vec_pos[:] = pos
        # Assign in place so vec_steps stays an int32 array, using the same
        # truncating conversion as update_pos.
        self.vec_steps[:] = self._pos_to_steps(self.vec_pos)

    def goto(self, pos, feedrate):
        """Move in a straight line to ``pos`` at ``feedrate``.

        The commanded position is interpolated against elapsed time.  At
        most one step per node is sent per loop iteration, so if the link
        cannot keep up with the requested feedrate the remaining steps are
        sent after the nominal duration has elapsed.

        Args:
            pos: Target position in user units.
            feedrate: Speed along the path in units per second; must be > 0.

        Raises:
            ValueError: If ``feedrate`` is not positive or ``pos`` contains
                non-finite values.
        """
        if not feedrate > 0:
            raise ValueError("feedrate must be positive.")
        target = np.array(pos, dtype=np.float64)
        if not np.all(np.isfinite(target)):
            raise ValueError("pos must contain only finite values.")

        pos_init = self.vec_pos.copy()
        delta = target - pos_init
        duration = np.linalg.norm(delta) / feedrate

        # Monotonic clock: elapsed time must not jump with wall-clock changes.
        t_start = time.monotonic()
        elapsed = 0.0
        while elapsed < duration:
            elapsed = time.monotonic() - t_start
            # Clamp so the last iteration cannot overshoot the target.
            fraction = min(elapsed / duration, 1.0)
            self.update_pos(pos_init + delta * fraction)

        # Land exactly on the target and send any steps still outstanding.
        self.vec_pos[:] = target
        while np.any(self._pos_to_steps(self.vec_pos) != self.vec_steps):
            self.update_pos(target)

    def update_pos(self, pos):
        """Record ``pos`` as the current position and, if the step counters
        lag behind it, send a single unit step towards it on each such node."""
        self.vec_pos[:] = pos
        diff = self._pos_to_steps(self.vec_pos) - self.vec_steps
        if np.any(diff != 0):
            self.step(np.sign(diff))

    def connect(self):
        """Open the serial port and send a null command to read the status.

        Raises:
            ValueError: If the controller is already connected.
        """
        if self.serial is not None:
            raise ValueError("Already connected!")
        self.serial = serial.Serial(self.port, timeout=0.1)
        try:
            # send a null command
            self.step([0] * self.n)
        except Exception:
            # Do not leave a half-initialised, open port behind.
            self.close()
            raise

    def close(self):
        """Close the serial port if it is open; safe to call repeatedly."""
        if self.serial is not None:
            try:
                self.serial.close()
            finally:
                # Forget the handle so connect() can be called again.
                self.serial = None

    def __enter__(self):
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
