from .cluster_hierarchique import ClusterHierarchique
from .clustering import Clustering

class ClusteringHierarchique(Clustering):
    """ Clustering hiérarchique. """

    liens = {
        'single': min,
        'complete': max,
    }

    def __init__(self, type_lien, dist_f):
        """
            :param str lien: le type de distance entre deux clusters,\
            'single' ou 'complete'.
            :param dist_f: la fonction de distance entre deux données.
        """

        super().__init__()
        self.dist_f = dist_f
        # Permet d'utiliser min ou max de manière générique en fonction du
        # paramètre type_lien.
        self.lien = self.liens[type_lien]

    def fusionne_clusters(self, cluster1, cluster2):
        """ Fusionne deux clusters.

            Le nouveau cluster contiendra ``cluster1`` à gauche et ``cluster2``\
            à droite.
            
            :param cluster1: un noeud qui ira à droite du nouveau cluster.
            :param cluster2: un noeud qui ira à gauche du nouveau cluster.
            :return: le nouveau cluster.
        """
        
        donnees = cluster1.donnees + cluster2.donnees
        return ClusterHierarchique(donnees, cluster1, cluster2)
    
    def calcule_distance(self, cluster1, cluster2):
        """ Calcule la distance entre deux clusters. """
        distances = []
        for donnee1 in cluster1.donnees:
            for donnee2 in cluster1.donnees:
                distances.append(self.dist_f(donnee1, donnee2))

        return self.lien(distances)

    def initialise_clusters(self, donnees):
        """ Initialise les clusters. 

            :param list donnees: les données à regrouper dans des clusters.
        """
        # Construit les clusters terminaux : un par donnée.
        # Les clusters seront ensuite fusionnés pour créer la hiérarchie.
        self.clusters = [ClusterHierarchique([donnee]) for donnee in donnees]

    def fini(self, anciens_clusters):
        """ Teste si les clusters ont changé par rapport aux anciens clusters.

            C'est le cas si leur nombre a diminué.

            :param list anciens_clusters: la liste des anciens clusters.
        """
        return len(self.clusters) == len(anciens_clusters)

    def revise_clusters(self):
        """ Révise les clusters. """
        if len(self.clusters) == 1:
            return

        # Calcule la distance entre chaque paire de clusters.
        distances = []
        for cluster1 in self.clusters:
            for cluster2 in self.clusters:
                if cluster1 != cluster2:
                    distance = self.calcule_distance(cluster1, cluster2)
                    distances.append((distance, cluster1, cluster2))

        # Trouve les deux clusters les plus proches. 
        paire = min(distances, key=lambda x: x[0])
        cluster1 = paire[1]
        cluster2 = paire[2]

        # Fusionne ces deux clusters. 
        nouveau_cluster = self.fusionne_clusters(cluster1, cluster2)
        self.clusters.remove(cluster1)
        self.clusters.remove(cluster2)
        self.clusters.append(nouveau_cluster)

    def affiche_clusters(self):
        """ Affiche les clusters découverts par l'algorithme."""

        print('\n'.join([str(cluster) for cluster in self.clusters]))