"""
La classe MLP è pronta e implementa una rete neurale a più livelli con propagazione in avanti e all'indietro.
Caratteristiche principali:
- Inizializzazione:
    - Specifica del numero di strati nascosti e dei neuroni per ogni strato.
    - I pesi e i bias vengono inizializzati casualmente.
- Propagazione in avanti:
    - Utilizza la funzione sigmoid come default, oppure la SoftMax se attivata tramite parametro.
- Backpropagation:
    - Aggiorna i pesi e i bias usando il gradiente della funzione di perdita (Mean Squared Error o cross-entropy per softmax).
- Training:
    - Addestra la rete per un numero specificato di epoche.
- Predizione:
    - Ritorna le classi predette (indice del neurone con attivazione massima).
"""

import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Binarizer
import matplotlib.pyplot as plt

# Function to load and preprocess the dataset
def load_custom_mnist_data(image_path, image_size=(28, 28), num_classes=10, samples_per_class=20):
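    """
    Slice a single grid image into individual digit samples.

    Assumed layout (inferred from the slicing below): one horizontal band of
    image_size[0] pixel rows per class, stacked top to bottom in class order,
    with samples_per_class tiles of width image_size[1] laid out left to right.
    """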
    source_image = Image.open(image_path).convert("L")
    source_array = np.array(source_image)
    X = []
    y = []
    for class_label in range(num_classes):
        for sample_index in range(samples_per_class):
            row_start = class_label * image_size[0]
            col_start = sample_index * image_size[1]
            digit_image = source_array[row_start:row_start + image_size[0], col_start:col_start + image_size[1]]
            X.append(digit_image.flatten())
            y.append(class_label)
    X = np.array(X)
    y = np.array(y)
    binarizer = Binarizer(threshold=127)
    X = binarizer.fit_transform(X)
    return train_test_split(X, y, test_size=0.2, random_state=42)

# Perceptron implementation
class Perceptron:
    def __init__(self, input_size, learning_rate=0.01):
        self.weights = np.random.rand(input_size)
        self.bias = np.random.rand(1)
        self.learning_rate = learning_rate

    def activation_function(self, x):
        return 1 if x >= 0 else 0

    def predict(self, X):
        linear_output = np.dot(X, self.weights) + self.bias
        return np.array([self.activation_function(x) for x in linear_output])

    def train(self, X, y, epochs=20):
        for epoch in range(epochs):
            errors = 0
            for i in range(len(X)):
                y_pred = self.activation_function(np.dot(X[i], self.weights) + self.bias)
                error = y[i] - y_pred
                self.weights += self.learning_rate * error * X[i]
                self.bias += self.learning_rate * error
                if error != 0:
                    errors += 1
            print(f"Epoch {epoch + 1}/{epochs}, Errori: {errors}")
            if errors == 0:
                print("Addestramento completato senza errori.")
                break
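
# A minimal usage sketch for the Perceptron (a hypothetical example, not part
# of the pipeline below): an AND gate is linearly separable, so training
# should converge to zero errors.
#
#   X_and = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
#   y_and = np.array([0, 0, 0, 1])
#   p = Perceptron(input_size=2, learning_rate=0.1)
#   p.train(X_and, y_and, epochs=20)
#   p.predict(X_and)   # expected after convergence: array([0, 0, 0, 1])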

# Multi-class perceptron (one-vs-rest)
class MultiClassPerceptron:
    def __init__(self, input_size, num_classes=10, learning_rate=0.01):
        self.perceptrons = [Perceptron(input_size, learning_rate) for _ in range(num_classes)]

    def train(self, X, y, epochs=20):
        for class_label in range(len(self.perceptrons)):
            print(f"Addestramento per la classe {class_label}")
            binary_labels = np.where(y == class_label, 1, 0)
            self.perceptrons[class_label].train(X, binary_labels, epochs)

    def predict(self, X):
        scores = np.array([perc.predict(X) for perc in self.perceptrons]).T
        return np.argmax(scores, axis=1)
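
# Note on the one-vs-rest scheme above: each binary perceptron outputs 0 or 1,
# and np.argmax returns the first index of the maximum, so when several
# perceptrons fire the tie resolves to the lowest class index.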

# Multi-layer perceptron (MLP) implementation
class MLP:
    def __init__(self, input_size, hidden_layers, output_size, learning_rate=0.01, output_activation="sigmoid"):
        self.learning_rate = learning_rate
        self.output_activation = output_activation  # "sigmoid" or "softmax"
        self.layers = []
        layer_sizes = [input_size] + hidden_layers + [output_size]
        for i in range(len(layer_sizes) - 1):
            weight_matrix = np.random.randn(layer_sizes[i], layer_sizes[i + 1]) * 0.1
            bias_vector = np.zeros((1, layer_sizes[i + 1]))
            self.layers.append({'weights': weight_matrix, 'bias': bias_vector})

    def activation_function(self, x, derivative=False, use_softmax=False):
        if use_softmax:
            if derivative:
                # With cross-entropy, the softmax derivative is folded into the
                # output error, so the backward pass treats it as identity.
                return np.ones_like(x)
            else:
                # Numerically stable softmax
                exp_shifted = np.exp(x - np.max(x, axis=1, keepdims=True))
                return exp_shifted / np.sum(exp_shifted, axis=1, keepdims=True)
        else:
            if derivative:
                # Note: x here is the already-activated output a = sigmoid(z),
                # so the derivative is a * (1 - a).
                return x * (1 - x)
            else:
                return 1 / (1 + np.exp(-x))

    def forward_propagation(self, X):
        activations = [X]
        for i, layer in enumerate(self.layers):
            net_input = np.dot(activations[-1], layer['weights']) + layer['bias']
            # Use softmax on the last layer when requested
            if i == len(self.layers) - 1 and self.output_activation == "softmax":
                activation = self.activation_function(net_input, use_softmax=True)
            else:
                activation = self.activation_function(net_input)
            activations.append(activation)
        return activations

    def backward_propagation(self, activations, y_true):
        # Output layer: with softmax + cross-entropy the delta is simply
        # y_pred - y_true; with sigmoid + MSE it also includes the sigmoid
        # derivative.
        if self.output_activation == "softmax":
            delta = activations[-1] - y_true
        else:
            delta = (activations[-1] - y_true) * self.activation_function(activations[-1], derivative=True)
        deltas = [delta]
        # Deltas for the hidden layers: every hidden layer uses sigmoid, so the
        # sigmoid derivative always applies here (the softmax special case only
        # concerns the output delta computed above).
        for i in range(len(self.layers) - 2, -1, -1):
            deriv = self.activation_function(activations[i + 1], derivative=True)
            delta = deltas[-1].dot(self.layers[i + 1]['weights'].T) * deriv
            deltas.append(delta)
        deltas.reverse()
        # Update weights and biases
        for i in range(len(self.layers)):
            self.layers[i]['weights'] -= self.learning_rate * activations[i].T.dot(deltas[i])
            self.layers[i]['bias'] -= self.learning_rate * np.sum(deltas[i], axis=0, keepdims=True)
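
    # For reference, the recursion implemented above (a restatement of the
    # standard backpropagation equations, with a_i = activations[i] and
    # sigmoid hidden activations):
    #   delta_i = (delta_{i+1} @ W_{i+1}.T) * a_{i+1} * (1 - a_{i+1})
    #   dL/dW_i = a_i.T @ delta_i,   dL/db_i = column-wise sum of delta_i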

    def train(self, X, y, epochs=50):
        for epoch in range(epochs):
            activations = self.forward_propagation(X)
            self.backward_propagation(activations, y)
            if epoch % 10 == 0:
                if self.output_activation == "softmax":
                    # Cross-entropy loss (epsilon guards against log(0))
                    loss = -np.mean(np.sum(y * np.log(activations[-1] + 1e-12), axis=1))
                else:
                    loss = np.mean((activations[-1] - y) ** 2)
                print(f"Epoch {epoch}/{epochs}, Loss: {loss:.4f}")

    def predict(self, X):
        activations = self.forward_propagation(X)
        return np.argmax(activations[-1], axis=1)

# Function to display the results
def display_predictions(X_test, y_test, y_pred, accuracy, model_name, num_samples=10):
    num_cols = 5
    num_rows = (num_samples + num_cols - 1) // num_cols
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(10, 5))
    axes = axes.flatten()
    for i, ax in enumerate(axes[:num_samples]):
        ax.imshow(X_test[i].reshape(28, 28), cmap='gray', interpolation='nearest')
        ax.set_title(f"Pred: {y_pred[i]}\nTrue: {y_test[i]}", fontsize=8)
        ax.axis('off')
    for ax in axes[num_samples:]:
        ax.axis('off')
    plt.suptitle(f"Modello: {model_name}, Accuratezza: {accuracy:.2f}%", fontsize=12)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show(block=True)

# Main function
def main():
    image_path = "mnist_grid_560x560.png"
    print("Caricamento del dataset...")
    X_train, X_test, y_train, y_test = load_custom_mnist_data(image_path)
    
    print("Addestramento MultiClassPerceptron...")
    perceptron_model = MultiClassPerceptron(X_train.shape[1], num_classes=10, learning_rate=0.01)
    perceptron_model.train(X_train, y_train, epochs=20)
    y_pred_perceptron = perceptron_model.predict(X_test)
    accuracy_perceptron = np.mean(y_pred_perceptron == y_test) * 100
    print(f"Accuratezza MultiClassPerceptron: {accuracy_perceptron:.2f}%")
    display_predictions(X_test, y_test, y_pred_perceptron, accuracy_perceptron, "MultiClassPerceptron")

    print("Addestramento MLP...")
    y_train_one_hot = np.eye(10)[y_train]
    mlp_model = MLP(X_train.shape[1], hidden_layers=[64], output_size=10, learning_rate=0.001, output_activation="softmax")
    mlp_model.train(X_train, y_train_one_hot, epochs=200)
    y_pred_mlp = mlp_model.predict(X_test)
    accuracy_mlp = np.mean(y_pred_mlp == y_test) * 100
    print(f"Accuratezza MLP: {accuracy_mlp:.2f}%")
    display_predictions(X_test, y_test, y_pred_mlp, accuracy_mlp, "MLP")

if __name__ == "__main__":
    main()