import os
import time
import threading
import math
import argparse
import xml.etree.ElementTree as ET
from svg.path import parse_path
import motor

# CoreXYZ machine controller (hardware access via the new `motor` API)
class CoreXYZMachine:
    """Controller for a three-motor CoreXY plotter with a Z (pen) axis.

    Motor A (``m1``) and motor B (``m2``) are mixed CoreXY-style to produce
    X/Y motion (A = Y + X, B = Y - X); motor C (``m3``) drives Z. Public
    coordinates are millimetres from the origin set by :meth:`set_origin`
    (the back-left-top corner); larger Z is lower, so the pen-down height is
    numerically greater than the pen-up height.

    If the motors cannot be opened, ``connected`` stays False and every
    hardware call becomes a no-op while position bookkeeping still works.
    """

    # Workspace limits (mm), enforced by move_absolute() once the origin is set.
    X_MAX_MM = 180.0
    Y_MAX_MM = 180.0
    Z_MAX_MM = 65.0

    # Drive geometry used by execute_kinematics().
    CAPSTAN_DIA_MM = 19.0  # X/Y capstan diameter: one rev moves pi * dia mm
    Z_PITCH_MM = 2.0       # Z screw pitch
    Z_STARTS = 4.0         # presumably the thread-start count (lead = pitch * starts)

    # svg.path segment class names that contribute no drawn geometry here.
    _NON_DRAWING_SEGMENTS = frozenset({"Move", "Close"})
    _HEX_DIGITS = frozenset("0123456789abcdef")

    def __init__(self):
        """Open the three motor ports and zero them; on failure stay disconnected."""
        self.current_factor = 0.45
        # Motion limits passed to every move: these are our limits, don't change! :(
        self.velocity = 2.0
        self.accel = 1.0

        self.pos_x, self.pos_y, self.pos_z = 0.0, 0.0, 0.0
        self.origin_set = False
        self.emergency_stop = False
        self.serial_lock = threading.Lock()
        self.connected = False
        # Defined up front so a failed connection leaves None, not missing attributes.
        self.m1 = self.m2 = self.m3 = None

        try:
            self.m1 = motor.Motor("/dev/cu.usbmodem1101")  # Motor A
            self.m2 = motor.Motor("/dev/cu.usbmodem1201")  # Motor B
            self.m3 = motor.Motor("/dev/cu.usbmodem1301")  # Motor C
            self.connected = True
            print("Connected...")
        except Exception as e:
            # Top-level hardware boundary: report and continue in offline mode.
            print(f"Failed: {e}")
            return

        with self.serial_lock:
            for m in self._motors():
                m.set_steps_per_unit(200)
                m.set_current(self.current_factor)
                m.set_position(0.0)
                m.set_target_position(0.0, self.velocity, self.accel)

    def _motors(self):
        """Return the motor handles in A, B, C order."""
        return (self.m1, self.m2, self.m3)

    def set_origin(self):
        """Declare the current physical position as (0, 0, 0) and clear the E-stop.

        Zeroes the software position and each motor's position/target, then
        re-applies the operating current. disable_all() sets the current to 0,
        and clearing ``emergency_stop`` without re-energising would leave the
        motors unpowered while the software believes it can move.
        """
        self.pos_x, self.pos_y, self.pos_z = 0.0, 0.0, 0.0
        self.origin_set = True
        self.emergency_stop = False

        if self.connected:
            with self.serial_lock:
                for m in self._motors():
                    m.set_position(0.0)
                for m in self._motors():
                    m.set_target_position(0.0, self.velocity, self.accel)
                # Energise last, after target == position, so a stale target
                # from before the E-stop is never driven to.
                for m in self._motors():
                    m.set_current(self.current_factor)

        print("Origin set to X:0, Y:0, Z:0 at Back-Left-Top (red sticker)")

    def move_rel(self, dx, dy, dz, dist):
        """Move by ``(dx, dy, dz) * dist`` mm from the current position.

        The move is ignored while the E-stop is active.
        """
        if self.emergency_stop:
            return
        self.move_absolute(self.pos_x + dx * dist,
                           self.pos_y + dy * dist,
                           self.pos_z + dz * dist)

    def move_absolute(self, x, y, z):
        """Command a move to (x, y, z) mm; returns without waiting for motion.

        Once the origin is set, the target is clamped to the workspace limits.
        The move is ignored while the E-stop is active.
        """
        if self.emergency_stop:
            return

        # Hardware clamping (only meaningful once the origin is known).
        if self.origin_set:
            x = max(0.0, min(self.X_MAX_MM, x))
            y = max(0.0, min(self.Y_MAX_MM, y))
            z = max(0.0, min(self.Z_MAX_MM, z))

        self.pos_x, self.pos_y, self.pos_z = x, y, z
        self.execute_kinematics(self.pos_x, self.pos_y, self.pos_z)

    def wait_for_move(self):
        """Block until no motor reports ``is_moving``.

        Polls every 10 ms. Returns immediately when not connected, and as soon
        as ``emergency_stop`` is set.
        """
        if not self.connected:
            return
        while not self.emergency_stop:
            with self.serial_lock:
                states = [m.get_states() for m in self._motors()]
            if not any(s["is_moving"] for s in states):
                break
            time.sleep(0.01)

    def execute_kinematics(self, x_mm, y_mm, z_mm):
        """Convert a Cartesian target (mm) to motor revolutions and command it.

        CoreXY mixing: A = Y + X and B = Y - X. Y and Z are negated before
        mixing (presumably to match motor direction). No-op when not connected.
        """
        if not self.connected:
            return
        xy_mm_to_rev = 1.0 / (self.CAPSTAN_DIA_MM * math.pi)  # one rev = capstan circumference
        z_mm_to_rev = 1.0 / (self.Z_PITCH_MM * self.Z_STARTS)  # one rev = screw lead

        x_rev = x_mm * xy_mm_to_rev
        y_rev = -y_mm * xy_mm_to_rev
        z_rev = -z_mm * z_mm_to_rev

        targets = (y_rev + x_rev, y_rev - x_rev, z_rev)
        with self.serial_lock:
            for m, target in zip(self._motors(), targets):
                m.set_target_position(target, self.velocity, self.accel)

    def disable_all(self):
        """Latch the E-stop and set every motor's current to 0.

        The latch is cleared again only by set_origin().
        """
        self.emergency_stop = True
        if self.connected:
            with self.serial_lock:
                for m in self._motors():
                    m.set_current(0)
        print("Motors Disabled...")

    # extract SVG logic
    def extract_svg_paths(self, svg_filepath):
        """Return the non-empty ``d`` attributes of every ``<path>`` element.

        Args:
            svg_filepath: a filename or file object accepted by ``ET.parse``.

        The check matches on the local tag name, with or without a namespace,
        so that elements such as ``<mpath>`` are not picked up.
        """
        root = ET.parse(svg_filepath).getroot()
        path_strings = []
        for elem in root.iter():
            if elem.tag.rsplit('}', 1)[-1] != 'path':
                continue
            d_attr = elem.attrib.get('d')
            if d_attr:
                path_strings.append(d_attr)
        return path_strings

    # change <d> paths into toolpaths
    def discretize_path(self, d_string, num_points=15):
        """Sample an SVG path ``d`` string at evenly spaced arc-length positions.

        Move/Close segments are skipped, and the remaining segments are
        treated as one continuous curve.

        Args:
            d_string: SVG path data (parsed with ``svg.path.parse_path``).
            num_points: number of samples; the first and last are the start
                and end of the drawn geometry. ``1`` yields only the start
                point, and values below 1 yield ``[]``.

        Returns:
            A list of ``(x, y)`` tuples in SVG user units. It is empty when the
            path has no drawing segments.
        """
        if num_points < 1:
            return []
        parsed_path = parse_path(d_string)
        draw_segments = [seg for seg in parsed_path
                         if type(seg).__name__ not in self._NON_DRAWING_SEGMENTS]
        if not draw_segments:
            return []

        # length() can be costly for curved segments, so compute each one once.
        seg_lengths = [seg.length() for seg in draw_segments]
        total_length = sum(seg_lengths)

        if total_length < 1e-5:
            # Degenerate (zero-length) path: every sample is the start point.
            pt = draw_segments[0].point(0)
            return [(pt.real, pt.imag)] * num_points

        denom = max(num_points - 1, 1)  # avoids ZeroDivisionError when num_points == 1
        points = []
        for i in range(num_points):
            target_len = (i / denom) * total_length
            current_len = 0.0
            for seg, seg_len in zip(draw_segments, seg_lengths):
                # The small tolerance guarantees the final sample lands on the last segment.
                if current_len + seg_len >= target_len - 1e-6:
                    seg_t = (target_len - current_len) / seg_len if seg_len > 0 else 0.0
                    # The tolerance can push t slightly past 1 on tiny segments.
                    pt = seg.point(min(seg_t, 1.0))
                    points.append((pt.real, pt.imag))
                    break
                current_len += seg_len
        return points

    def generate_transit_line(self, x1, y1, x2, y2, steps=8):
        """Return ``steps`` evenly spaced points from (x1, y1) to (x2, y2), inclusive.

        With ``steps == 1`` only the end point is returned; ``steps <= 0`` gives ``[]``.
        """
        points = []
        for i in range(steps):
            t = i / float(steps - 1) if steps > 1 else 1.0
            points.append((x1 + (x2 - x1) * t, y1 + (y2 - y1) * t))
        return points

    @classmethod
    def _kanji_hex_code(cls, kanji_input):
        """Map user input to the ``kanji/<stem>.svg`` file stem.

        A single character becomes its code point as 5-digit lower-case hex.
        Longer input is lower-cased. If it is a bare hex number it is also
        zero-padded to five digits, so ``"4e00"`` finds ``04e00.svg``. Any
        other input is returned lower-cased as-is.
        """
        text = kanji_input.strip()
        if len(text) == 1:
            return f"{ord(text):05x}"
        code = text.lower()
        if code and set(code) <= cls._HEX_DIGITS:
            return f"{int(code, 16):05x}"
        return code

    @staticmethod
    def _fit_scale(svg_w, svg_h, canvas_w, canvas_h):
        """Return the largest uniform scale that fits a w x h box in the canvas.

        Zero-extent dimensions are ignored. A perfectly straight horizontal
        or vertical glyph therefore still fills the canvas along its length.
        1.0 is returned only when both extents are zero.
        """
        ratios = [c / s for c, s in ((canvas_w, svg_w), (canvas_h, svg_h)) if s > 0]
        return min(ratios) if ratios else 1.0

    #draw routine
    def draw_kanji_routine(self, kanji_input, z_up=10.0, z_down=60.0, pad=10.0):
        """Draw the glyph ``kanji/<code>.svg`` next to this file; blocks until done.

        Args:
            kanji_input: a single character (looked up by code point), or a
                hex code / file stem such as ``"4e00"`` or ``"04e00"``.
            z_up: pen-up Z height (mm) for travel between strokes.
            z_down: pen-down Z height (mm) while drawing.
            pad: margin (mm) kept clear on every side of the XY workspace.

        If the SVG is missing or has no drawable geometry, the method returns
        without moving. It aborts as soon as ``emergency_stop`` is set.
        Otherwise it finishes by returning to X/Y 0 and raising Z to 0.
        """
        script_dir = os.path.dirname(os.path.abspath(__file__))
        hex_code = self._kanji_hex_code(kanji_input)
        svg_file_path = os.path.join(script_dir, "kanji", f"{hex_code}.svg")

        if not os.path.exists(svg_file_path):
            print(f"Error: {svg_file_path} not found.")
            return

        print("Chotto matte kudasai... Generating...")
        paths = self.extract_svg_paths(svg_file_path)
        all_strokes = [self.discretize_path(d) for d in paths]

        xs = [x for stroke in all_strokes for x, _ in stroke]
        ys = [y for stroke in all_strokes for _, y in stroke]
        if not xs:
            print(f"Error: {svg_file_path} has no drawable paths.")
            return
        min_x, max_x = min(xs), max(xs)
        min_y, max_y = min(ys), max(ys)
        svg_w = max_x - min_x
        svg_h = max_y - min_y

        canvas_w = self.X_MAX_MM - 2 * pad
        canvas_h = self.Y_MAX_MM - 2 * pad
        scale = self._fit_scale(svg_w, svg_h, canvas_w, canvas_h)
        # Shift so the scaled glyph is centred inside the padded canvas.
        off_x = pad + (canvas_w - svg_w * scale) / 2
        off_y = pad + (canvas_h - svg_h * scale) / 2

        def to_machine(rx, ry):
            """Map SVG user units to machine millimetres."""
            return (rx - min_x) * scale + off_x, (ry - min_y) * scale + off_y

        print("Moving Z height...")
        self.move_absolute(self.pos_x, self.pos_y, z_up)
        self.wait_for_move()

        current_mx, current_my = self.pos_x, self.pos_y
        total = len(all_strokes)

        for index, stroke_points in enumerate(all_strokes):
            if self.emergency_stop:
                break
            if not stroke_points:
                # A path with no drawing segments: skip it, keep the remaining strokes.
                continue
            print(f"Drawing Stroke: {index + 1}/{total}...")

            start_mx, start_my = to_machine(*stroke_points[0])

            # Transit to the stroke start with the pen up.
            for px, py in self.generate_transit_line(current_mx, current_my, start_mx, start_my):
                if self.emergency_stop:
                    return
                self.move_absolute(px, py, z_up)
                self.wait_for_move()

            # Pen down at the stroke start.
            self.move_absolute(start_mx, start_my, z_down)
            self.wait_for_move()
            current_mx, current_my = start_mx, start_my
            time.sleep(0.1)

            for rx, ry in stroke_points[1:]:
                if self.emergency_stop:
                    return
                current_mx, current_my = to_machine(rx, ry)
                self.move_absolute(current_mx, current_my, z_down)
                self.wait_for_move()

            # Lift the pen where the stroke ended.
            self.move_absolute(current_mx, current_my, z_up)
            self.wait_for_move()
            time.sleep(0.1)

        print("Complete...")
        if self.emergency_stop:
            return
        for px, py in self.generate_transit_line(current_mx, current_my, 0.0, 0.0):
            self.move_absolute(px, py, z_up)
            self.wait_for_move()

        # Park at the origin with Z fully raised.
        self.move_absolute(0.0, 0.0, 0.0)
        self.wait_for_move()
        print("Job Finished.")


# Tkinter GUI wrapper, for use when no Flask server is running.
import tkinter as tk

class CoreXYZApp:
    """Tkinter front-end for jogging, setting the origin and drawing kanji.

    Machine calls are made on the machine object passed in. The status label
    is refreshed by polling the machine every 100 ms. Drawing runs on a
    daemon thread so the window stays responsive. Only one drawing job runs
    at a time, and jogging is ignored while it runs.
    """

    def __init__(self, root, machine):
        self.root = root
        self.machine = machine
        # Thread of the most recent drawing job (None until the first DRAW).
        self._draw_thread = None
        self.root.title("KANJIXYZ")
        self.root.geometry("400x700")

        self.setup_ui()
        self.update_gui_loop()

    def setup_ui(self):
        """Build the instructions, origin/draw controls, jog pad and E-stop button."""
        inst_frame = tk.Frame(self.root, pady=10)
        inst_frame.pack(fill="x", padx=10)

        instructions = (
            "1. Move Z axis to middle first *IMPORTANT*\n"
            "2. Move X/Y axis to the back left corner (see sticker)\n"
            "3. Move the Z axis as high as possible (hit limit switch)\n"
            "4. Move X/Y gently into the back left corner again\n"
            "5. Click 'SET ORIGIN'"
        )
        tk.Label(inst_frame, text=instructions, font=("Arial", 10), fg="darkblue", justify="left").pack(pady=5)

        self.origin_btn = tk.Button(inst_frame, text="SET ORIGIN (0, 0, 0)\nAT BACK-LEFT",
                                    font=("Arial", 11, "bold"), bg="lightblue",
                                    command=self.set_origin_ui)
        self.origin_btn.pack(fill="x", pady=5)

        kanji_frame = tk.Frame(inst_frame, pady=5)
        kanji_frame.pack(fill="x")
        tk.Label(kanji_frame, text="Kanji or Hex:", font=("Arial", 11)).pack(side=tk.LEFT)
        self.kanji_ent = tk.Entry(kanji_frame, width=12, font=("Arial", 12))
        self.kanji_ent.pack(side=tk.LEFT, padx=5)

        # Disabled until an origin is set (see set_origin_ui).
        self.draw_btn = tk.Button(inst_frame, text="DRAW SVG",
                                  font=("Arial", 11, "bold"), bg="#e0e0e0", state=tk.DISABLED,
                                  command=self.start_draw_kanji)
        self.draw_btn.pack(fill="x", pady=5)

        input_frame = tk.Frame(self.root, pady=5)
        input_frame.pack()
        tk.Label(input_frame, text="Step Distance (mm):", font=("Arial", 12)).pack(side=tk.LEFT)
        self.step_ent = tk.Entry(input_frame, width=8, font=("Arial", 12))
        self.step_ent.insert(0, "10.0")
        self.step_ent.pack(side=tk.LEFT, padx=5)

        self.status_label = tk.Label(self.root, text="POSITION UNCALIBRATED",
                                     font=("Arial", 16, "bold"), fg="orange", pady=10)
        self.status_label.pack()

        pad_frame = tk.Frame(self.root)
        pad_frame.pack(pady=5)
        tk.Button(pad_frame, text="Y- (Back)", width=10, height=2,
                  command=lambda: self.trigger_move_rel(0, -1, 0)).grid(row=0, column=1, pady=5)
        tk.Button(pad_frame, text="X- (Left)", width=10, height=2,
                  command=lambda: self.trigger_move_rel(-1, 0, 0)).grid(row=1, column=0, padx=5)
        tk.Button(pad_frame, text="X+ (Right)", width=10, height=2,
                  command=lambda: self.trigger_move_rel(1, 0, 0)).grid(row=1, column=2, padx=5)
        tk.Button(pad_frame, text="Y+ (Front)", width=10, height=2,
                  command=lambda: self.trigger_move_rel(0, 1, 0)).grid(row=2, column=1, pady=5)

        z_frame = tk.Frame(self.root, pady=10)
        z_frame.pack()
        tk.Button(z_frame, text="Z- (Up)", width=12, height=2, fg="darkgreen",
                  command=lambda: self.trigger_move_rel(0, 0, -1)).grid(row=0, column=0, padx=10)
        tk.Button(z_frame, text="Z+ (Down)", width=12, height=2, fg="darkred",
                  command=lambda: self.trigger_move_rel(0, 0, 1)).grid(row=0, column=1, padx=10)

        tk.Button(self.root, text="DISABLE MOTORS / E-STOP", fg="white", bg="black", font=("Arial", 10, "bold"),
                  command=self.machine.disable_all).pack(pady=15)

    def set_origin_ui(self):
        """Set the machine origin and enable the DRAW button."""
        self.machine.set_origin()
        self.origin_btn.config(text="RECALIBRATE ORIGIN", bg="lightgreen")
        self.draw_btn.config(state=tk.NORMAL, bg="#d8b4e2")

    @staticmethod
    def _parse_step(text):
        """Return ``text`` as a finite float, or None if it is not one.

        NaN and infinity are rejected. They would otherwise reach the motors
        as NaN/inf, or be clamped by move_absolute into a jump to a workspace
        limit.
        """
        try:
            value = float(text)
        except ValueError:
            return None
        return value if math.isfinite(value) else None

    def _is_drawing(self):
        """Return True while a drawing job's thread is still running."""
        return self._draw_thread is not None and self._draw_thread.is_alive()

    def trigger_move_rel(self, dx, dy, dz):
        """Jog by ``(dx, dy, dz)`` times the step distance entered in the GUI."""
        if self._is_drawing():
            # A jog would fight the drawing thread for the same axes.
            print("Busy drawing... jog ignored.")
            return
        dist = self._parse_step(self.step_ent.get())
        if dist is None:
            print("Error: Invalid step distance entered.")
            return
        self.machine.move_rel(dx, dy, dz, dist)

    def start_draw_kanji(self):
        """Start drawing the entered kanji on a background daemon thread.

        Refuses to start without an origin, or while a previous job is still running.
        """
        if not self.machine.origin_set:
            print("Set origin before drawing kudasai!")
            return
        if self._is_drawing():
            print("Already drawing... please wait for the current job.")
            return
        kanji_input = self.kanji_ent.get()
        self._draw_thread = threading.Thread(target=self.machine.draw_kanji_routine,
                                             args=(kanji_input,), daemon=True)
        self._draw_thread.start()

    def update_gui_loop(self):
        """Refresh the status label from machine state, then reschedule in 100 ms."""
        if self.machine.emergency_stop:
            self.status_label.config(text="MOTORS DISABLED", fg="red")
        elif self.machine.origin_set:
            self.status_label.config(text=f"X:{self.machine.pos_x:.2f}  Y:{self.machine.pos_y:.2f}  Z:{self.machine.pos_z:.2f}", fg="blue")
        else:
            self.status_label.config(text=f"JOGGING (Uncalibrated)\nX:{self.machine.pos_x:.2f}  Y:{self.machine.pos_y:.2f}  Z:{self.machine.pos_z:.2f}", fg="orange")

        self.root.after(100, self.update_gui_loop)

# CLI and API access: any of --headless/--draw/--move runs without the GUI; otherwise the GUI starts.
if __name__ == "__main__":
    import sys

    parser = argparse.ArgumentParser(description="KanjiXYZ")
    parser.add_argument('--headless', action='store_true', help="API mode (run without the GUI)")
    parser.add_argument('--draw', type=str, help="Kanji to draw")
    parser.add_argument('--origin', action='store_true', help="Set origin automatically on startup")
    parser.add_argument('--move', nargs=3, type=float, metavar=('X', 'Y', 'Z'), help="Move to absolute X Y Z")

    args = parser.parse_args()

    if args.headless or args.draw or args.move:
        machine = CoreXYZMachine()

        if not machine.connected:
            print("No hardware...")
            sys.exit(1)

        try:
            if args.origin:
                machine.set_origin()

            if args.move:
                machine.move_absolute(args.move[0], args.move[1], args.move[2])
                machine.wait_for_move()

            if args.draw:
                if not machine.origin_set:
                    print("No origin set... Drawing anyways...")
                    machine.set_origin()
                machine.draw_kanji_routine(args.draw)
        except KeyboardInterrupt:
            # Ctrl+C mid-job: cut motor current instead of leaving the motors
            # energised and driving to the last commanded target.
            machine.disable_all()
            sys.exit(130)

    # run GUI
    else:
        root = tk.Tk()
        machine = CoreXYZMachine()
        app = CoreXYZApp(root, machine)
        root.mainloop()