# Librerie necessarie per:
import numpy as np                                      # la gestione degli array multidimensionali e operazioni numeriche
from PIL import Image                                   # la lettura e manipolazione delle immagini
from sklearn.model_selection import train_test_split    # dividere i dati in training e test set
from sklearn.preprocessing import Binarizer             # binarizzare i pixel delle immagini
import matplotlib.pyplot as plt                         # la visualizzazione dei dati e delle predizioni

# Funzione per caricare e preprocessare il dataset dall'immagine sorgente
def load_custom_mnist_data(image_path, image_size=(28, 28), num_classes=10, samples_per_class=20):
    """
    Carica i dati da un'immagine sorgente (es. mnist_grid_560x560.png) organizzata
    in righe e colonne, dove ogni riga rappresenta una classe (0-9) e le colonne
    contengono esempi di quella classe.

    1. Divide l'immagine in sezioni 28x28 (una per ogni esempio).
    2. Assegna etichette alle immagini in base alla loro riga.
    3. Binarizza i pixel (>127 = 1, altrimenti 0) per ridurre il rumore.
    4. Divide i dati in training e test set (80% training, 20% test).

    Output: X_train, X_test (feature binarizzate), y_train, y_test (etichette).
    """
    source_image = Image.open(image_path).convert("L")  # Scala di grigi
    source_array = np.array(source_image)
    X = []  # Lista per i dati immagine (appena appiattiti)
    y = []  # Lista per le etichette (classi 0-9)
    
    for class_label in range(num_classes):              # Itera sulle righe (classi)
        for sample_index in range(samples_per_class):   # Itera sugli esempi (colonne)
            row_start = class_label * image_size[0]
            col_start = sample_index * image_size[1]
            digit_image = source_array[row_start:row_start + image_size[0], col_start:col_start + image_size[1]]
            X.append(digit_image.flatten())     # Aggiungi l'immagine appiattita
            y.append(class_label)               # Aggiungi l'etichetta corrispondente
    
    X = np.array(X)
    y = np.array(y)
    
    # Binarizza i dati per ridurre il rumore (pixel > 127 -> 1, altrimenti 0)
    binarizer = Binarizer(threshold=127)
    X = binarizer.fit_transform(X)
    
    # Divide il dataset in 2: training (80%) e test set (20%)
    return train_test_split(X, y, test_size=0.2, random_state=42)

# Classe Percettrone
class Perceptron:
    """
    Implementa un singolo percettrone per la classificazione binaria.
    """
    def __init__(self, input_size, learning_rate=0.01):
        """
        Inizializza i pesi casuali per ogni feature dell'input e il bias.
        """
        self.weights = np.random.rand(input_size)
        self.bias = np.random.rand(1)
        self.learning_rate = learning_rate  # Tasso di apprendimento
    
    def activation_function(self, x):
        """
        Applica una funzione a soglia: restituisce 1 se x >= 0, altrimenti 0.
        """
        return 1 if x >= 0 else 0

    def predict(self, X):
        """
        Calcola le predizioni per un array di input:
        1. Calcola l'output lineare (prodotto scalare + bias).
        2. Applica la funzione di attivazione a soglia.
        """
        linear_output = np.dot(X, self.weights) + self.bias
        return np.array([self.activation_function(x) for x in linear_output])

    def train(self, X, y, epochs=20):
        """
        Addestra il percettrone iterando sugli esempi:
        - L'obiettivo è aggiornare i pesi e il bias per minimizzare l'errore di predizione.
        
        Algoritmo di aggiornamento:
        1. Per ogni esempio nel dataset, calcola l'errore:
           errore = y_true - y_pred
        2. Aggiorna i pesi per correggere l'errore:
           w_i = w_i + learning_rate * errore * x_i
        3. Aggiorna il bias per correggere l'errore:
           b = b + learning_rate * errore

        I pesi vengono incrementati proporzionalmente all'importanza della feature (x_i)
        e all'entità dell'errore. Se l'errore è zero, i pesi non vengono aggiornati.
        """
        for epoch in range(epochs):
            errors = 0  # Conta gli errori per epoch
            for i in range(len(X)):
                y_pred = self.activation_function(np.dot(X[i], self.weights) + self.bias)
                error = y[i] - y_pred  # Calcola l'errore
                
                # Aggiorna i pesi e il bias in base all'errore
                self.weights += self.learning_rate * error * X[i]
                self.bias += self.learning_rate * error
                
                if error != 0:
                    errors += 1

            print(f"Epoch {epoch + 1}/{epochs}, Errori: {errors}")
            if errors == 0:
                print("Addestramento completato senza errori.")
                break

# Classe MultiClassPerceptron
class MultiClassPerceptron:
    """
    Gestisce più percettroni (uno per classe) per classificazione multi-classe.
    """
    def __init__(self, input_size, num_classes=10, learning_rate=0.01):
        self.num_classes = num_classes  # Numero totale di classi
        self.perceptrons = [Perceptron(input_size, learning_rate) for _ in range(num_classes)]

    def train(self, X, y, epochs=20):
        """
        Addestra un percettrone per ogni classe:
        - Utilizza etichette binarie per distinguere la classe target dalle altre.
        """
        for class_label in range(self.num_classes):
            print(f"Addestramento per la classe {class_label}")
            binary_labels = np.where(y == class_label, 1, 0)  # Crea etichette binarie
            self.perceptrons[class_label].train(X, binary_labels, epochs)

    def predict(self, X):
        """
        Predice la classe per ciascun input:
        1. Ottiene i punteggi da tutti i percettroni.
        2. Restituisce la classe con il punteggio più alto.
        """
        scores = np.array([perc.predict(X) for perc in self.perceptrons]).T
        return np.argmax(scores, axis=1)

# Funzione per mostrare predizioni
def display_predictions(X_test, y_test, y_pred, accuracy, dataset_name, num_samples=10):
    """
    Mostra alcune immagini del dataset di test insieme alle predizioni del modello e ai valori reali.
    Inoltre, visualizza l'accuratezza complessiva del modello e il nome del dataset.

    Argomenti:
    ----------
    - X_test: array numpy, contiene le immagini del dataset di test.
    - y_test: array numpy, contiene le etichette reali delle immagini di test.
    - y_pred: array numpy, contiene le predizioni del modello per X_test.
    - accuracy: float, rappresenta l'accuratezza complessiva del modello (%).
    - dataset_name: str, il nome o una descrizione del dataset in uso.
    - num_samples: int, numero di esempi da mostrare (default: 10).

    Scopo:
    ------
    Visualizzare alcune immagini del dataset di test in una griglia, con le predizioni
    e i valori reali annotati, per confrontare la performance del modello.
    """

    # Numero di colonne nella griglia
    num_cols = 5  # Fissa il numero di colonne della griglia a 5

    # Calcola il numero di righe necessario per visualizzare 'num_samples' immagini
    num_rows = (num_samples + num_cols - 1) // num_cols

    # Creazione di una griglia di sottotrame per le immagini
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(10, 5))

    # Appiattisce l'array delle assi per iterare facilmente su di esso
    axes = axes.flatten()

    # Itera sulle immagini e sui rispettivi assi per mostrarle
    for i, ax in enumerate(axes[:num_samples]):  # Considera solo il numero richiesto di campioni
        # Ricostruisce l'immagine 28x28 a partire dal vettore appiattito in X_test[i]
        ax.imshow(X_test[i].reshape(28, 28), cmap='gray', interpolation='nearest')

        # Imposta il titolo della sottotrama con la predizione e l'etichetta reale
        ax.set_title(f"Pred: {y_pred[i]}\nTrue: {y_test[i]}", fontsize=8)

        # Disattiva gli assi per migliorare la visualizzazione
        ax.axis('off')

    # Disattiva eventuali assi vuoti nella griglia (se num_samples < num_cols * num_rows)
    for ax in axes[num_samples:]:
        ax.axis('off')

    # Imposta un titolo generale per la figura con l'accuratezza e il nome del dataset
    plt.suptitle(f"Dataset: {dataset_name}, Accuratezza: {accuracy:.2f}%", fontsize=12)

    # Regola i margini per evitare sovrapposizioni e migliorare la leggibilità
    plt.tight_layout(rect=[0, 0, 1, 0.95])

    # Mostra la figura e blocca l'esecuzione fino a quando la finestra non viene chiusa
    plt.show(block=True)


# Funzione principale
def main():
    """
    1. Carica il dataset da un'immagine (es. mnist_grid_560x560.png) precedentemente creata.
    2. Addestra un modello MultiClassPerceptron su 20 campioni per classe.
    3. Valuta l'accuratezza del modello sui dati di test.
    4. Mostra alcune predizioni visivamente.
    """
    image_path = "mnist_grid_560x560.png"
    print("Caricamento del dataset...")
    X_train, X_test, y_train, y_test = load_custom_mnist_data(image_path, image_size=(28, 28), num_classes=10, samples_per_class=20)
    
    print("Inizializzazione del modello...")
    input_size = X_train.shape[1]
    multi_class_perceptron = MultiClassPerceptron(input_size, num_classes=10, learning_rate=0.01)
    
    print("Avvio dell'addestramento...")
    multi_class_perceptron.train(X_train, y_train, epochs=20)
    
    print("Valutazione del modello...")
    y_pred = multi_class_perceptron.predict(X_test)
    accuracy = np.mean(y_pred == y_test)
    print(f"Accuratezza: {accuracy * 100:.2f}%")
    
    print("Visualizzazione di alcuni esempi...")
    display_predictions(X_test, y_test, y_pred, accuracy, "MNIST Dataset con 20 campioni per classe")

if __name__ == "__main__":
    main()
