from .cluster_mean import ClusterMean
from .clustering import Clustering

class ClusteringKMeans(Clustering):
    """ K-means clustering. """

    def __init__(self, k, dist_f):
        """
            :param k: le nombre de clusters à construire.
            :param dist_f: la fonction de distance entre deux données.
        """

        super().__init__()
        self.k = k
        self.dist_f = dist_f

    def noyaux(self, clusters):
        """ Extrait les noyaux d'une liste de clusters. 

           :param list clusters: une liste de clusters dont les noyaux doivent\
           être retournés.
           :return: la liste des noyaux des clusters.
        """

        return [cluster.noyau for cluster in clusters]

    def initialise_clusters(self, donnees):
        """ Initialise les clusters. 

            :param list donnees: les données à regrouper dans des clusters.
        """
        if len(donnees) < self.k:
            raise Exception('Il faut au moins {} données'.format(self.k))

        # Crée les clusters autour des noyaux, qui sont les premières k données.
        noyaux = [(donnees[i], str(i + 1)) for i in range(self.k)]
        self.clusters = [ClusterMean([noyau[0]], noyau[1]) for noyau in noyaux]

        # Ajoute toutes les autres données au premier cluster.
        self.clusters[0].ajoute_donnees(donnees[self.k:])

    def fini(self, anciens_clusters):
        """ Teste si les clusters ont changé par rapport aux anciens clusters.

            C'est le cas si les noyaux ont changé depuis l'itération précédente.

            :param list anciens_clusters: la liste des anciens clusters.
        """
        return self.noyaux(self.clusters) == self.noyaux(anciens_clusters)

    def revise_clusters(self):
        """ Révise les clusters. """
        # Extrait toutes les données des anciens clusters, sauf les noyaux.
        donnees = []
        for cluster in self.clusters:
            donnees.extend([d for d in cluster.donnees if d != cluster.noyau])

        # Réinitialise les nouveaux clusters aux noyaux des anciens clusters.
        for cluster in self.clusters:
            cluster.vide(garde_noyau=True)

        # Assigne chaque donnée au cluster du noyau duquel il est le 
        # plus proche.
        for donnee in donnees:
            distances = [(self.dist_f(donnee, cluster.noyau), cluster) 
                         for cluster in self.clusters]
            cluster = min(distances, key=lambda x: x[0])[1]
            cluster.ajoute_donnee(donnee)

        # Recentre le noyau de chaque nouveau cluster.
        for cluster in self.clusters:
            cluster.centre(self.dist_f)

    def affiche_clusters(self):
        """ Affiche les clusters construits par l'algorithme."""

        print('\n'.join([str(cluster) for cluster in self.clusters]))