import os
import shutil
import random
import cv2
import glob
import yaml
from collections import defaultdict
from autodistill_grounded_sam import GroundedSAM
from autodistill.detection import CaptionOntology
from autodistill_yolov8 import YOLOv8

###############################
###  AUTODISTILL PROCESS!!! ###
###############################
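# Pipeline overview:
#   1. GroundedSAM (the base model) auto-labels the images in each class folder
#      under auto_dataset/ using a text prompt.
#   2. The per-class labels are merged into final_dataset/ and their class ids
#      are renumbered to match the order of 'names' in data.yaml.
#   3. The merged pool is split ~90/10 into train and valid sets.
#   4. A YOLOv8 target model is trained (distilled) on the resulting dataset.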

if __name__ == '__main__':

    # Define Folders and Images
    main_dir = "auto_dataset"
    final_dir = "final_dataset"
    current_directory = os.getcwd()
    train_dir = os.path.join(final_dir, "train")
    valid_dir = os.path.join(final_dir, "valid")
    data_dict = defaultdict(list)

    # Create the output directory tree; exist_ok avoids failures on re-runs
    os.makedirs(main_dir, exist_ok=True)
    os.makedirs(os.path.join(final_dir, 'images'), exist_ok=True)
    os.makedirs(os.path.join(final_dir, 'labels'), exist_ok=True)
    for dir_path in [train_dir, valid_dir]:
        os.makedirs(os.path.join(dir_path, 'images'), exist_ok=True)
        os.makedirs(os.path.join(dir_path, 'labels'), exist_ok=True)

    # Sort the class folders so the class order is deterministic and matches data.yaml
    for class_dir in sorted(glob.glob(os.path.join(main_dir, '*'))):
        if not os.path.isdir(class_dir):
            continue
        class_name = os.path.basename(class_dir)
        print(class_dir)
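        # CaptionOntology maps a text prompt for the base model to the label written
        # into the dataset: every "tree" detection in this folder is labeled with the
        # folder's name, so each class folder yields its own single-class dataset.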
        
        base_model = GroundedSAM(ontology=CaptionOntology({"tree": class_name}))
        try:
            base_model.label(input_folder=class_dir, output_folder=os.path.join(main_dir, class_name))
        except Exception as exc:
            print(f"Labeling failed for {class_name}: {exc}")
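        # After labeling, the class folder is expected to contain autodistill's
        # output (train/ and valid/ splits, each with images/ and labels/ in YOLO
        # format); the merge step further below relies on that layout.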

        for img_path in glob.glob(os.path.join(class_dir, '*')):
            img = cv2.imread(img_path)
            if img is not None:
                # Record only paths that are actually readable images
                data_dict[class_name].append(img_path)

        print("Classes discovered so far:", list(data_dict.keys()))

    # Build the data.yaml that YOLOv8 expects: class names, class count,
    # and paths to the train and valid splits
    yaml_dict = {
        'names': list(data_dict.keys()),
        'nc': len(data_dict),
        'train': os.path.join(current_directory, train_dir),
        'val': os.path.join(current_directory, valid_dir)
    }

    with open(os.path.join(final_dir, 'data.yaml'), 'w') as file:
        yaml.dump(yaml_dict, file, default_flow_style=False)
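
    # The resulting data.yaml looks roughly like this (hypothetical class names;
    # the real entries depend on the folders found under auto_dataset/):
    #   names: ['class_a', 'class_b']
    #   nc: 2
    #   train: /absolute/path/to/final_dataset/train
    #   val: /absolute/path/to/final_dataset/valid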

    ###############################
    ###   RE-NUMBER CLASSES!!!  ###
    ###############################
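
    # Each YOLO label line has the form "<class_id> <x_center> <y_center> <width> <height>"
    # with coordinates normalized to [0, 1]. Because each per-class run used a
    # single-class ontology, its label files all carry class id 0; the loop below
    # rewrites that first token to the class's index in data.yaml's 'names' list.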

    # Loop over each class in the same order as data.yaml's 'names' list so the
    # class index written here matches the YAML
    for count, folder in enumerate(data_dict.keys()):
        print("Assigning class index to class name: " + folder)
        print("Class index: " + str(count))
        for split in ['train', 'valid']:
            source_folder_path = os.path.join(main_dir, folder, split)
            if not os.path.isdir(source_folder_path):
                continue
            # Copy the 'images' and 'labels' subfolders into final_dir
            for subfolder in ['images', 'labels']:
                source_subfolder_path = os.path.join(source_folder_path, subfolder)
                if not os.path.exists(source_subfolder_path):
                    continue
                # Loop over each file in the subfolder and copy it to the destination
                for file_name in os.listdir(source_subfolder_path):
                    source_file_path = os.path.join(source_subfolder_path, file_name)
                    dest_file_path = os.path.join(final_dir, subfolder, file_name)
                    if not os.path.isfile(source_file_path):  # skip directories
                        continue
                    if subfolder == 'labels':
                        # Rewrite the class id (first value) on every annotation line
                        with open(source_file_path, 'r') as annot_file:
                            lines = annot_file.readlines()
                        lines = [str(count) + line[line.find(' '):] if ' ' in line else line
                                 for line in lines]
                        with open(dest_file_path, 'w') as annot_file:
                            annot_file.writelines(lines)
                    else:
                        shutil.copy2(source_file_path, dest_file_path)  # preserves file metadata

    ###############################
    ### SPLIT DATASET PROCESS!!!###
    ###############################
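
    # Note: the split is done over the merged pool of all classes, so the
    # train/valid class balance is whatever the shuffle produces; it is not
    # stratified per class.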

    # Get a list of all the images and annotations
    image_files = [f for f in os.listdir(os.path.join(final_dir, "images")) if os.path.isfile(os.path.join(final_dir, "images", f))]
    print("LENGTH OF IMAGES: " + str(len(image_files)))
    annot_files = [f for f in os.listdir(os.path.join(final_dir, "labels")) if os.path.isfile(os.path.join(final_dir, "labels", f))]
    print("LENGTH OF ANNOTATIONS: " + str(len(annot_files)))

    # Assume that each image has a corresponding annotation with the same name
    # (minus the extension), shuffle the list and split into train and validation sets
    random.shuffle(image_files)
    valid_count = max(1, int(len(image_files) * 0.1))  # keep at least one validation image
    valid_files = image_files[:valid_count]
    train_files = image_files[valid_count:]

    # Move the files to the appropriate folders
    for filename in valid_files:
        shutil.move(os.path.join(final_dir, "images", filename), os.path.join(valid_dir, 'images', filename))
        annot_filename = os.path.splitext(filename)[0] + ".txt"
        if annot_filename in annot_files:
            shutil.move(os.path.join(final_dir, "labels", annot_filename), os.path.join(valid_dir, 'labels', annot_filename))

    for filename in train_files:
        shutil.move(os.path.join(final_dir, "images", filename), os.path.join(train_dir, 'images', filename))
        annot_filename = os.path.splitext(filename)[0] + ".txt"
        if annot_filename in annot_files:
            shutil.move(os.path.join(final_dir, "labels", annot_filename), os.path.join(train_dir, 'labels', annot_filename))
    # Remove the now-empty staging folders used during the merge
    try:
        os.rmdir(os.path.join(final_dir, 'images'))
        os.rmdir(os.path.join(final_dir, 'labels'))
    except OSError:
        pass

    # Distill: train the YOLOv8 target model on the dataset produced above
    target_model = YOLOv8("yolov8s.pt")
    target_model.train(os.path.join(final_dir, 'data.yaml'), epochs=30)
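    # Assuming Ultralytics defaults, training artifacts (weights such as best.pt,
    # plots, and metrics) are written under runs/detect/train*/ in the working directory.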

    # run inference on the new model
    pred = target_model.predict(os.path.join(valid_dir, 'images', "your-image.jpg"), confidence=0.5)
    print(pred)
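    # "your-image.jpg" is a placeholder; point it at an actual file under
    # final_dataset/valid/images before running this inference step.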