import functools
import time

import cv2
import numpy as np
import screen_brightness_control
from keras.models import load_model


# Monitor-brightness control code.
# NOTE(review): this constant is not referenced anywhere in this file; brightness
# is changed through screen_brightness_control.set_brightness instead. Possibly a
# leftover from a lower-level monitor API approach — confirm no other module
# imports it before removing.
SET_MONITOR_BRIGHTNESS: int = 0x1007cc




# HSV bands treated as "red". OpenCV stores 8-bit hue in [0, 179], and red sits at
# BOTH ends of that hue circle, so a single 0-10 band misses the upper reds
# (~170-179). Built once at import time instead of on every frame.
_RED_HSV_BANDS = (
    (np.array([0, 100, 100]), np.array([10, 255, 255])),
    (np.array([170, 100, 100]), np.array([180, 255, 255])),
)


def _red_mask(hsv):
    """Return a binary mask of the pixels in *hsv* that fall in any red band."""
    mask = None
    for lower, upper in _RED_HSV_BANDS:
        band = cv2.inRange(hsv, lower, upper)
        mask = band if mask is None else cv2.bitwise_or(mask, band)
    return mask


def _largest_blob_centroid(mask):
    """Return the integer (x, y) centroid of the largest contour in *mask*, or None."""
    # Indexing with [-2] selects the contour list under both OpenCV 3.x
    # (image, contours, hierarchy) and 4.x (contours, hierarchy).
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return None
    moments = cv2.moments(max(contours, key=cv2.contourArea))
    if moments["m00"] == 0:
        # Zero-area contour: the centroid is undefined, so skip this frame.
        return None
    return (
        int(moments["m10"] / moments["m00"]),
        int(moments["m01"] / moments["m00"]),
    )


def track_red_light(duration=2.0, camera_index=0):
    """Track the largest red blob in a camera feed and return its path.

    Frames are read from camera ``camera_index`` until ``duration`` seconds
    have elapsed, the camera stops returning frames, or the user presses
    ``q`` in the preview window. For every frame in which red is found, the
    integer centroid of the largest red contour is appended to the path. The
    path so far is drawn in red on a live "Red Path" preview window.

    Args:
        duration: Seconds to track for. The default of 2 matches the original
            hard-coded behaviour.
        camera_index: Device index passed to ``cv2.VideoCapture``.

    Returns:
        List of ``(x, y)`` pixel coordinates in frame space, in capture order.
        The list is empty if no red was seen or the camera could not be read.
    """
    cap = cv2.VideoCapture(camera_index)
    red_path = []
    # monotonic() is immune to wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + duration

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
            centroid = _largest_blob_centroid(_red_mask(hsv))
            if centroid is not None:
                red_path.append(centroid)

            # Overlay the path collected so far onto the live frame.
            for start, end in zip(red_path, red_path[1:]):
                cv2.line(frame, start, end, (0, 0, 255), thickness=5)

            cv2.imshow("Red Path", frame)

            if time.monotonic() > deadline:
                break

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always release the camera and close the preview, even if a frame
        # fails to process.
        cap.release()
        cv2.destroyAllWindows()

    return red_path


@functools.lru_cache(maxsize=None)
def _load_classifier(model_path):
    """Load the Keras model at *model_path* once and reuse it on later calls.

    The cache is keyed by path. In practice only a handful of paths are used,
    so the cache stays small, but each cached model stays in memory for the
    lifetime of the process.
    """
    return load_model(model_path)


def _render_path(red_path, canvas_shape):
    """Draw *red_path* as a 5 px white polyline on a black uint8 canvas."""
    canvas = np.zeros(canvas_shape, dtype=np.uint8)
    for start, end in zip(red_path, red_path[1:]):
        cv2.line(canvas, start, end, (255, 255, 255), thickness=5)
    return canvas


def classify_path(red_path, *, model_path="classify.keras", canvas_shape=(480, 640, 3)):
    """Classify a traced path with the gesture model.

    The path is rendered white-on-black onto a canvas of ``canvas_shape``,
    resized to 100x100, given a leading batch axis, and passed to the Keras
    model at ``model_path``. The model is loaded from disk only on the first
    call for a given path; later calls reuse it.

    Args:
        red_path: Sequence of integer ``(x, y)`` points, such as the list
            returned by ``track_red_light``. Fewer than two points render an
            all-black image.
        model_path: Path to the saved Keras model.
        canvas_shape: ``(height, width, channels)`` of the drawing canvas.

    Returns:
        The label at the argmax of the model output: ``'lumos'``, ``'nox'``
        or ``'None'`` (the string, not the ``None`` object).
    """
    canvas = _render_path(red_path, canvas_shape)

    # cv2.resize takes (width, height); add a batch dimension for predict().
    input_image = np.expand_dims(cv2.resize(canvas, (100, 100)), axis=0)

    prediction = _load_classifier(model_path).predict(input_image)

    # NOTE(review): label order is assumed to match the output-unit order the
    # model was trained with — confirm against the training code.
    classes = ['lumos', 'nox', 'None']
    return classes[int(np.argmax(prediction))]

def adjust_brightness(classification_result):
    """Change the screen brightness according to a classification label.

    ``'lumos'`` calls ``screen_brightness_control.set_brightness(100)`` and
    ``'nox'`` calls ``set_brightness(20)``. Any other value, including the
    ``'None'`` label, leaves the screen untouched. Always returns None.
    """
    # Labels are checked in this order with ==; the first match wins.
    for label, level in (("lumos", 100), ("nox", 20)):
        if classification_result == label:
            screen_brightness_control.set_brightness(level)
            return


def main():
    """Run one track -> classify -> adjust-brightness cycle, printing progress."""
    print("Tracking red light for 2 seconds...")
    path_points = track_red_light()

    print("Classifying image...")
    label = classify_path(path_points)
    print("Classification result:", label)

    print("Adjusting screen brightness...")
    adjust_brightness(label)


if __name__ == "__main__":
    # Run a single capture/classify/adjust cycle only when executed as a script.
    main()