import argparse
import json
import time
from pathlib import Path
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus, TorqueMode
from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig


class ProperSlowSequencer:
    """Play back recorded joint-position sequences on a 6-motor Feetech arm.

    Movement speed is governed by the servos' Acceleration registers (see
    ``set_speed_mode``); individual steps may instead use a short Goal_Time
    for fast moves (see ``move_to_position_fast``).
    """

    def __init__(self, port="COM15"):   # this is found via the find_motors_bus_port script
        """Store the serial port and the motor table; no I/O happens here.

        Args:
            port: serial port of the motor bus.
        """
        self.port = port
        self.motor_bus = None
        # motor name -> [servo id, model]
        self.motors_config = {
            "shoulder_pan": [1, "sts3215"],
            "shoulder_lift": [2, "sts3215"],
            "elbow_flex": [3, "sts3215"],
            "wrist_flex": [4, "sts3215"],
            "wrist_roll": [5, "sts3215"],
            "gripper": [6, "sts3215"],
        }

    def connect(self):
        """Open the bus and widen every motor's angle limits to 0..4095.

        Per-motor limit-write failures are printed and skipped so one bad
        motor does not abort the connection.
        """
        print(f"Connecting to robot on {self.port}...")
        config = FeetechMotorsBusConfig(port=self.port, motors=self.motors_config)
        self.motor_bus = FeetechMotorsBus(config)
        self.motor_bus.connect()

        # REMOVE ALL LIMITS first
        # NOTE(review): these are written before "Lock" is set to 0 (that only
        # happens in set_speed_mode) — confirm the limits actually persist.
        print("REMOVING ALL LIMITS...")
        for motor_name in self.motors_config:
            try:
                self.motor_bus.write("Min_Angle_Limit", 0, motor_name)
                self.motor_bus.write("Max_Angle_Limit", 4095, motor_name)
                print(f"   {motor_name}: Limits removed")
            except Exception as e:
                print(f"   {motor_name}: {e}")

        print("Connected and ALL limits removed")

    def disconnect(self):
        """Disconnect the bus if one was created, then drop the reference.

        The bus's ``disconnect`` is skipped when the bus reports it is not
        connected (e.g. ``connect()`` failed after the bus object was
        created), so calling this from cleanup code does not raise a second
        error on top of the original one. Calling it twice is a no-op.
        """
        if self.motor_bus is None:
            return
        # NOTE(review): relies on the bus exposing ``is_connected``; if the
        # attribute is missing we assume connected, matching the old behaviour.
        if getattr(self.motor_bus, "is_connected", True):
            self.motor_bus.disconnect()
            print("Disconnected")
        self.motor_bus = None

    def set_speed_mode(self, speed_mode="slow"):
        """Set movement speed using Acceleration parameter (the correct way!)

        Also writes Mode=0, P=16, I=0, D=32 and Lock=0 to every motor.
        Unknown ``speed_mode`` values fall back to ``"slow"`` (with a notice).
        Per-motor write failures are printed and skipped.

        Args:
            speed_mode: one of "very_slow", "slow", "medium", "fast".
        """
        speed_settings = {
            "very_slow": {"acceleration": 50, "max_accel": 50, "description": "Very Slow (50)"},
            "slow": {"acceleration": 100, "max_accel": 100, "description": "Slow (100)"},
            "medium": {"acceleration": 150, "max_accel": 150, "description": "Medium (150)"},
            "fast": {"acceleration": 254, "max_accel": 254, "description": "Fast (254 - Official LeRobot)"},
        }

        if speed_mode not in speed_settings:
            print(f"Unknown speed mode {speed_mode!r}, using 'slow'")
            speed_mode = "slow"

        settings = speed_settings[speed_mode]

        print(f"\n SETTING SPEED: {settings['description']}")

        for motor_name in self.motors_config:
            try:
                # Apply EXACT official LeRobot settings except for speed
                self.motor_bus.write("Mode", 0, motor_name)  # Position Control
                self.motor_bus.write("P_Coefficient", 16, motor_name)  # Smooth movement
                self.motor_bus.write("I_Coefficient", 0, motor_name)
                self.motor_bus.write("D_Coefficient", 32, motor_name)

                # Unlock EPROM
                self.motor_bus.write("Lock", 0, motor_name)

                # THE KEY: Acceleration controls speed!
                self.motor_bus.write("Maximum_Acceleration", settings["max_accel"], motor_name)
                self.motor_bus.write("Acceleration", settings["acceleration"], motor_name)

                print(f"   {motor_name}: Speed set to {settings['description']}")
            except Exception as e:
                print(f"   {motor_name}: {e}")

    def torque_off(self):
        """Disable torque on all motors so the arm can be moved by hand."""
        self.motor_bus.write("Torque_Enable", TorqueMode.DISABLED.value)
        print("Torque OFF - move robot freely")

    def torque_on(self):
        """Enable torque on all motors so they hold/track goal positions."""
        self.motor_bus.write("Torque_Enable", TorqueMode.ENABLED.value)
        print("Torque ON - robot under control")

    def get_positions(self):
        """Read Present_Position for every motor.

        Returns:
            dict mapping motor name -> int position. A length-1 sequence
            returned by the bus read is unwrapped to its single element.
        """
        positions = {}
        for motor_name in self.motors_config:
            pos = self.motor_bus.read("Present_Position", motor_name)
            if hasattr(pos, '__len__') and len(pos) == 1:
                pos = pos[0]
            positions[motor_name] = int(pos)
        return positions

    def move_to_position_correct(self, positions):
        """Move to ``positions`` with speed controlled only by Acceleration.

        Goal_Time is written as 0 before each Goal_Position. This stops a
        Goal_Time left behind by ``move_to_position_fast`` from carrying over
        into this "smooth" move.

        Args:
            positions: mapping of motor name -> target position (cast to int).
        """
        for motor_name, position in positions.items():
            # Clear any Goal_Time written by a previous fast move.
            # NOTE(review): assumes 0 means "no time constraint" on the
            # STS3215 (its factory default) — confirm against the datasheet.
            self.motor_bus.write("Goal_Time", 0, motor_name)
            self.motor_bus.write("Goal_Position", int(position), motor_name)

    def move_to_position_fast(self, positions, time_ms=200):
        """Move FAST by writing a short Goal_Time before each Goal_Position.

        Args:
            positions: mapping of motor name -> target position (cast to int).
            time_ms: value written to Goal_Time for each motor (default 200).
        """
        for motor_name, position in positions.items():
            self.motor_bus.write("Goal_Time", time_ms, motor_name)
            self.motor_bus.write("Goal_Position", int(position), motor_name)

    @staticmethod
    def _format_positions(positions):
        """Render ``{motor: position}`` as ``"motor:  pos | ..."`` for display."""
        return " | ".join(f"{motor}:{int(position):4d}" for motor, position in positions.items())

    def play_sequence(self, name, speed_mode="slow"):
        """Play back ``sequences/<name>.json`` with Acceleration-based speed control.

        The file must hold a ``"sequence"`` list. Each step needs
        ``"position"`` (step number), ``"positions"`` (motor name -> target)
        and ``"duration"`` keys.

        Step number 1 moves to the start pose and waits 1.5 s. Other steps
        with ``duration == 0`` use the fast Goal_Time move and wait 0.4 s.
        All remaining steps use the Acceleration-limited move and wait 0.8 s.
        Ctrl+C during playback stops it without raising.

        Args:
            name: sequence file stem inside the ``sequences`` directory.
            speed_mode: passed to ``set_speed_mode``.
        """
        file_path = Path("sequences") / f"{name}.json"

        if not file_path.exists():
            print(f"Sequence file not found: {file_path}")
            return

        with file_path.open("r", encoding="utf-8") as f:
            data = json.load(f)

        sequence = data["sequence"]

        print(f"\n{'='*60}")
        print(f"PROPER SLOW PLAYBACK: {name}")
        print(f"{'='*60}")
        print(f"Positions: {len(sequence)}")

        # Set the speed mode
        self.set_speed_mode(speed_mode)

        # Show what we're going to do
        for step in sequence:
            pos_str = self._format_positions(step["positions"])
            duration = step["duration"]
            time_desc = "START" if duration == 0 and step["position"] == 1 else f"SMOOTH (target: {duration}s)"
            print(f"   {step['position']}. {pos_str} ({time_desc})")

        input(f"\nPress ENTER to start PROPER SLOW playback (Speed: {speed_mode})...")

        self.torque_on()

        print("\n EXECUTING PROPER SLOW MOVEMENT...")
        print("=" * 50)

        try:
            for step in sequence:
                pos = step["positions"]
                duration = step["duration"]
                position_num = step["position"]

                pos_str = self._format_positions(pos)

                if position_num == 1:
                    print(f"Position {position_num}: {pos_str} (moving to start)")
                    self.move_to_position_correct(pos)
                    time.sleep(1.5)  # Short pause for start position only
                elif duration == 0:
                    print(f"Position {position_num}: {pos_str} (SUPER FAST)")
                    self.move_to_position_fast(pos)  # Use old fast method
                    time.sleep(0.4)  # Minimal pause for fast movements
                else:
                    print(f"Position {position_num}: {pos_str} (SMOOTH ~{duration}s)")
                    self.move_to_position_correct(pos)  # Use new smooth method
                    # NOTE(review): the recorded ``duration`` is only displayed;
                    # the wait is a fixed 0.8 s — confirm this is intended.
                    time.sleep(0.8)  # Minimal pause for smooth movements

            print("\n PROPER SLOW SEQUENCE COMPLETE!")

        except KeyboardInterrupt:
            print("\n  PLAYBACK STOPPED")


def main():
    """Parse CLI arguments, connect to the arm and play back a sequence.

    Returns:
        0 on success, 1 if connecting or playback raised an exception.
    """
    parser = argparse.ArgumentParser(description="Proper Slow Position Sequence Player")
    parser.add_argument("--mode", choices=["play"], required=True, help="Only playback mode")
    parser.add_argument("--name", required=True, help="Sequence name")
    parser.add_argument("--speed", choices=["very_slow", "slow", "medium", "fast"], default="slow", help="Movement speed")
    parser.add_argument("--port", default="COM15", help="Serial port")

    args = parser.parse_args()

    player = ProperSlowSequencer(args.port)

    try:
        player.connect()

        if args.mode == "play":
            player.play_sequence(args.name, args.speed)

    except Exception as e:
        print(f" ERROR: {e}")
        import traceback
        traceback.print_exc()
        return 1

    finally:
        # Best-effort cleanup: an exception raised here (e.g. the bus never
        # finished connecting) would otherwise replace the error reported
        # above and discard the ``return 1`` exit code.
        try:
            player.disconnect()
        except Exception as e:
            print(f" WARNING: disconnect failed: {e}")

    return 0


if __name__ == "__main__":
    # ``exit`` is injected by the ``site`` module for interactive use and is
    # missing under ``python -S``; SystemExit is always available.
    raise SystemExit(main())
