import rp2
from machine import Pin
import time
from math import sqrt

# Our steppers have 200 full steps per turn.
# By default, the Pololu DRV8834 module is in 1/4-step mode,
# => one turn should be 800 (micro)steps.

@rp2.asm_pio( out_init=(rp2.PIO.OUT_LOW, rp2.PIO.OUT_LOW) )
def start():
    # Sequencer program (SM2): for each move, set both direction pins,
    # release both step generators at once, then wait until both report
    # completion before accepting the next direction word.
    wrap_target()
    pull()          # wait for a direction word from the TX FIFO (written by sm.put)
    out(pins, 32)   # bit 0 -> first direction pin, bit 1 -> second direction pin
    irq(0)          # release step0 (it waits on IRQ flag 0)
    irq(1)          # release step1 (it waits on IRQ flag 1)
    wait(1, irq, 2) # step0 raises IRQ flag 2 when its move is done
    wait(1, irq, 3) # step1 raises IRQ flag 3 when its move is done
    # (the former `set(isr, 1)` was removed: ISR is not a valid SET
    # destination, so that instruction used a reserved encoding)
    wrap()

@rp2.asm_pio( sideset_init=(rp2.PIO.OUT_LOW) )
def step0():
    # Step-pulse generator for the first motor (SM0, step pin on side-set).
    # Each move is one 32-bit word taken from the TX FIFO (DMA channel d0):
    #   upper 16 bits = step period - 3 (in SM cycles)
    #   lower 16 bits = number of steps - 1
    # One step costs 1 (mov) + (x + 1) (counting loop) + 1 (jmp y_dec) = x + 3
    # cycles; the step pin is high only during the mov cycle.
    wrap_target()
    pull()                  # pull a word from the TX FIFO (fed by the DMA) into the OSR
    out(isr, 16)            # OSR's upper 16 bits -> ISR (step period - 3); default out_shiftdir is SHIFT_LEFT, so MSBs come out first
    out(y, 16)              # OSR's lower 16 bits -> y (number of steps - 1)
    wait(1, irq, 0)         # wait until the sequencer (`start`) releases this move
    label("step")
    mov(x, isr).side(1)             # step pin high; reload the period counter
    label("counting")
    jmp(x_dec, "counting").side(0)  # step pin low; busy-wait for x + 1 cycles
    jmp(y_dec, "step")              # loops y + 1 times in total
    irq(2)                  # signal the sequencer that this move is finished
    wrap()
    
@rp2.asm_pio( sideset_init=(rp2.PIO.OUT_LOW) )
def step1():
    # Step-pulse generator for the second motor (SM1, step pin on side-set).
    # Same word format and timing as step0, but it is released by IRQ flag 1
    # and reports completion on IRQ flag 3:
    #   upper 16 bits = step period - 3 (in SM cycles)
    #   lower 16 bits = number of steps - 1
    # One step costs x + 3 cycles; the step pin is high only during the mov cycle.
    wrap_target()
    pull()                  # pull a word from the TX FIFO (fed by the DMA) into the OSR
    out(isr, 16)            # OSR's upper 16 bits -> ISR (step period - 3); default out_shiftdir is SHIFT_LEFT, so MSBs come out first
    out(y, 16)              # OSR's lower 16 bits -> y (number of steps - 1)
    wait(1, irq, 1)         # wait until the sequencer (`start`) releases this move
    label("step")
    mov(x, isr).side(1)             # step pin high; reload the period counter
    label("counting")
    jmp(x_dec, "counting").side(0)  # step pin low; busy-wait for x + 1 cycles
    jmp(y_dec, "step")              # loops y + 1 times in total
    irq(3)                  # signal the sequencer that this move is finished
    wrap()

class Steppers():
    """Drive a two-motor XY table through the PIO programs above.

    For a move of (dx, dy) mm, motor A (step0 / SM0) turns by dx + dy and
    motor B (step1 / SM1) by dx - dy.  Each step state machine is fed by its
    own DMA channel; SM2 runs ``start``, which sets both direction pins and
    then releases both step machines together.
    """

    def __init__(self, dir0Pin=27, step0Pin=29, step1Pin=6, endXPin=2, endYPin=4, M0Pin=7):
        """Configure state machines, DMA channels and pins.

        dir0Pin is the first of the consecutive direction pins written by
        ``start`` (bit 0 -> motor A, bit 1 -> motor B).  The position
        (self.x, self.y) is unknown (None) until homing() has run.
        """
        # PIO init: the step machines run at self.freq Hz, so step periods
        # are expressed in units of 1 / self.freq s.
        self.freq = 50000
        self.sm0 = rp2.StateMachine(0, step0, sideset_base=Pin(step0Pin), freq=self.freq)
        self.sm1 = rp2.StateMachine(1, step1, sideset_base=Pin(step1Pin), freq=self.freq)
        self.sm0.active(1)
        self.sm1.active(1)
        self.sm2 = rp2.StateMachine(2, start, out_base=Pin(dir0Pin), freq=10000000)
        self.sm2.active(1)
        # DMA init: one channel per step machine, each writing a single
        # 32-bit word into that machine's TX FIFO.
        self.d0 = rp2.DMA()
        c0 = self.d0.pack_ctrl(size=2,           # transfer 32-bit words
                               inc_write=False,  # always write the same FIFO register
                               treq_sel=0)       # paced by the DREQ of PIO0 SM0's TX FIFO
        self.d0.config(write=self.sm0,           # destination: SM0's TX FIFO
                       ctrl=c0,
                       trigger=False)
        self.d1 = rp2.DMA()
        c1 = self.d1.pack_ctrl(size=2,           # transfer 32-bit words
                               inc_write=False,  # always write the same FIFO register
                               treq_sel=1)       # paced by the DREQ of PIO0 SM1's TX FIFO
        self.d1.config(write=self.sm1,           # destination: SM1's TX FIFO
                       ctrl=c1,
                       trigger=False)
        # End-switch inputs; a switch counts as pressed when it reads 0.
        self.endX = Pin(endXPin, Pin.IN)
        self.endY = Pin(endYPin, Pin.IN)
        # Microstepping: M0 is left as an input (not driven), which presumably
        # keeps the driver in its default 1/4-step mode assumed by stepPerMm.
        # off() does not drive the pin while it is configured as an input.
        M0 = Pin(M0Pin, Pin.IN)
        M0.off()
        # Current position in mm; None until homed.
        self.x = None
        self.y = None
        # Geometric constant: microsteps per turn / (pulley teeth * belt pitch)
        self.stepPerMm = 800 / (20 * 2.5)
        # Default speeds in mm/s
        self.slowSpeed = 10     # 1 cm/s
        self.nomSpeed = 50      # 5 cm/s
        self.slowPer = int(self.freq / (self.slowSpeed * self.stepPerMm))
        self.nomPer = int(self.freq / (self.nomSpeed * self.stepPerMm))

    def moveEncode(self, stepNb, stepDuration):
        """Pack one move into the word format read by step0 / step1.

        stepNb: signed number of steps; the sign selects the direction.
        stepDuration: step period in state-machine cycles.

        Returns (direction, data).  direction is -1 for a zero-step move
        (data is then empty), 0 for stepNb > 0 and 1 for stepNb < 0.  data
        is 4 bytes sent by DMA as one little-endian 32-bit word:
        (steps - 1) in the low half-word, (period - 3) in the high one.
        The period is clamped to what the 16-bit field can hold.

        Raises ValueError if |stepNb| exceeds the 16-bit step counter.
        """
        if stepNb == 0:
            return -1, bytearray()
        direction = 0 if stepNb > 0 else 1
        count = abs(stepNb) - 1     # the PIO loop makes y + 1 steps
        if count > 0xFFFF:
            raise ValueError("too many steps for one move: %d" % stepNb)
        # The PIO loop spends 3 cycles per step outside its delay counter.
        # Clamp instead of letting out-of-range values wrap in the bytearray.
        period = min(max(int(stepDuration) - 3, 0), 0xFFFF)
        data = bytearray((count & 0xFF, count >> 8, period & 0xFF, period >> 8))
        return direction, data

    def moveStep(self, stepNb0, stepDuration0, stepNb1, stepDuration1):
        """Start a move of stepNb0 steps on motor A and stepNb1 on motor B.

        Durations are step periods in state-machine cycles.  A move with no
        steps on either motor does nothing.  The sequencer waits for both
        step machines to finish, so a move where exactly one motor has zero
        steps would stall it and shift every later move; that case raises
        ValueError before any hardware is touched.
        """
        dir0, data0 = self.moveEncode(stepNb0, stepDuration0)
        dir1, data1 = self.moveEncode(stepNb1, stepDuration1)
        if dir0 == -1 and dir1 == -1:
            return
        if dir0 == -1 or dir1 == -1:
            raise ValueError("both motors need at least one step per move")
        self.d0.config(read=data0, count=1, trigger=True)
        self.d1.config(read=data1, count=1, trigger=True)
        # bit 0: motor A direction, bit 1: motor B direction
        self.sm2.put(dir0 + 2 * dir1)

    def endXPressed(self):
        """Return True when the X end switch reads 0."""
        return self.endX.value() == 0

    def endYPressed(self):
        """Return True when the Y end switch reads 0."""
        return self.endY.value() == 0

    def homing(self):
        """Home against the end switches and define the origin.

        Moves 1 mm at a time in -Y (slow speed) until the Y switch is
        pressed, then in -X until the X switch is pressed, then moves by
        (10, 10) mm at nominal speed and calls that position (0, 0).
        """
        while not self.endYPressed():
            self.moveRelXY(0, -1, self.slowSpeed)
        while not self.endXPressed():
            self.moveRelXY(-1, 0, self.slowSpeed)
        self.moveRelXY(10, 10, self.nomSpeed)
        self.x = 0
        self.y = 0

    def _stepPlan(self, n, T):
        """Return (steps, period) for one motor making n steps in T seconds.

        A motor with nothing to do still gets one step at the shortest
        period, because the sequencer waits for both step machines and the
        PIO loop always makes at least one step.
        """
        if n == 0:
            # NOTE(review): this extra step leaves a one-microstep position
            # error on that motor; avoiding it needs a PIO program change.
            return 1, 0
        return n, abs(int(self.freq * T / n))

    def moveRelXY(self, dx, dy, speed):
        """Move by (dx, dy) mm at `speed` mm/s, then sleep for the move time.

        A move that rounds to zero steps on both motors does nothing.
        """
        nA = int((dx + dy) * self.stepPerMm)
        nB = int((dx - dy) * self.stepPerMm)
        if nA == 0 and nB == 0:
            return
        T = sqrt(dx*dx + dy*dy) / speed     # move duration in seconds
        nA, tA = self._stepPlan(nA, T)
        nB, tB = self._stepPlan(nB, T)
        self.moveStep(nA, tA, nB, tB)
        time.sleep(T)

    def moveAbsXY(self, x, y, speed=None):
        """Move to absolute position (x, y) mm and record it.

        speed defaults to self.nomSpeed (mm/s).  Raises RuntimeError if the
        position is unknown, i.e. homing() has not been run.
        """
        if self.x is None or self.y is None:
            raise RuntimeError("position unknown: call homing() first")
        if speed is None:
            speed = self.nomSpeed
        self.moveRelXY(x - self.x, y - self.y, speed)
        self.x = x
        self.y = y


if __name__ == "__main__":
    table = Steppers()
    # Driver sleep control: turned on only while the table is moving.
    nSLEEP = Pin(1, Pin.OUT)
    nSLEEP.on()
    try:
        table.homing()              # also sets table.x = table.y = 0
    finally:
        nSLEEP.off()                # never leave the drivers awake after an error
    print("Homed")
    while True:
        # Each move is two lines of input: the integer X target, then Y.
        newX = int(input())
        newY = int(input())
        nSLEEP.on()
        try:
            table.moveAbsXY(newX, newY)
        finally:
            nSLEEP.off()
        print("OK")
