import numpy as np
import pandas as pd


class HEICalculator:
    def __init__(self, base_path, nutrition_data):
        """
        Parameters:
            base_path (str): Directory path where food category CSVs are stored.
            nutrition_data (DataFrame): Daily nutritional intake data.
        """
        self.base_path = base_path
        self.nutrition_data = nutrition_data
        # Load food ID lists for each food category
        self.food_ids = {
            "total_fruits": self._load_food_ids("fruit"),
            "whole_fruits": self._load_food_ids("wholefruit"),
            "total_vegetables": self._load_food_ids("vegetable"),
            "greens_and_beans": self._load_food_ids("green_and_beans"),
            "whole_grains": self._load_food_ids("wholegrain"),
            "dairy": self._load_food_ids("dairy"),
            "protein_foods": self._load_food_ids("protein_food"),
            "seafood_plant_proteins": self._load_food_ids(
                "seafood_and_plant_protein_food"
            ),
            "refined_grains": self._load_food_ids("refined_grain"),
            "added_sugars": self._load_food_ids("added_sugars"),
        }
        # Define nutrient units and conversion factors
        self.nutrient_units = {
            "protein": "g",
            "alchohol": "g",
            "water": "g",
            "carbohydrates": "g",
            "fiber": "g",
            "sugar": "g",
            "fat": "g",
            "fatty_acids_saturated": "g",
            "fatty_acids_monounsaturated": "g",
            "fatty_acids_polyunsaturated": "g",
            "cholesterol": "mg",
            "vitamin_a": "IU",
            "vitamin_c": "mg",
            "beta_carotene": "mcg",
            "vitamin_e": "TAE",
            "vitamin_d": "mcg",
            "vitamin_k": "mcg",
            "thiamin": "mg",
            "riboflavin": "mg",
            "niacin": "mg",
            "vitamin_b6": "mg",
            "folate": "mcg",
            "vitamin_b12": "mcg",
            "calcium": "mg",
            "phosphorus": "mg",
            "magnesium": "mg",
            "iron": "mg",
            "zinc": "mg",
            "copper": "mg",
            "selenium": "mcg",
            "potassium": "mg",
            "sodium": "mg",
            "caffeine": "mg",
            "theobromine": "mg",
            "energy_kcal": "kcal",
            "pantothenic_acid": "mg",
            "vitamin_b1": "mg",
            "vitamin_b2": "mg",
        }
        self.conversion_factors = {
            "mg": 1000,
            "g": 1,
            "mcg": 1000000,
        }
        # Dictionary to hold any manual modifications to food category quantities
        self.modified_quantities = {}

    def _load_food_ids(self, food_cat):
        """Loads the food IDs from a CSV file in the base path."""
        food_id_df = pd.read_csv(f"{self.base_path}/{food_cat}.csv")
        return food_id_df["id"].values

    def actual_consumed_amount(self, df, columnName):
        """Calculates the actual consumed amount for a given nutrient column."""
        factor = self.conversion_factors[self.nutrient_units[columnName]]
        return df.apply(
            lambda x: x["eaten_quantity_in_gram"] * x[columnName] / (100 * factor),
            axis=1,
        ).sum()

    def actual_fatty_consumed(self, df, columnName):
        """
        Calculates the actual fatty acid consumed.
        Only considers rows where any fatty acids are present.
        """

        def fatty_prop_calc(row):
            if (
                row["fatty_acids_monounsaturated"]
                + row["fatty_acids_polyunsaturated"]
                + row["fatty_acids_saturated"]
                == 0
            ):
                return False
            return True

        mdf = df[
            [
                "eaten_quantity_in_gram",
                "fatty_acids_monounsaturated",
                "fatty_acids_polyunsaturated",
                "fatty_acids_saturated",
                "fat_eaten",
            ]
        ].fillna(0)
        df = df.copy()
        df["fatty_acids_present"] = mdf.apply(fatty_prop_calc, axis=1)
        df = df[df["fatty_acids_present"] == True]
        consumed = df.apply(lambda x: x["fat_eaten"] * x[columnName] / 100, axis=1)
        return consumed.sum()

    def calculate_food_quantities(self, df=None):
        """
        Quantifies the consumed amounts for each food category.
        If df is not provided, uses the instance nutrition_data.
        Incorporates any modifications made via modify_food_quantity().
        """
        if df is None:
            df = self.nutrition_data

        quantities = {}
        quantities["total_fruits"] = df[
            df["food_id"].isin(self.food_ids["total_fruits"])
        ]["eaten_quantity_in_gram"].sum()
        quantities["whole_fruits"] = df[
            df["food_id"].isin(self.food_ids["whole_fruits"])
        ]["eaten_quantity_in_gram"].sum()
        quantities["total_vegetables"] = df[
            df["food_id"].isin(self.food_ids["total_vegetables"])
        ]["eaten_quantity_in_gram"].sum()
        quantities["greens_and_beans"] = df[
            df["food_id"].isin(self.food_ids["greens_and_beans"])
        ]["eaten_quantity_in_gram"].sum()
        quantities["whole_grains"] = df[
            df["food_id"].isin(self.food_ids["whole_grains"])
        ]["eaten_quantity_in_gram"].sum()
        # For dairy, summing calcium instead of gram quantity
        quantities["dairy"] = df[df["food_id"].isin(self.food_ids["dairy"])][
            "calcium"
        ].sum()
        quantities["protein_foods"] = df[
            df["food_id"].isin(self.food_ids["protein_foods"])
        ]["eaten_quantity_in_gram"].sum()
        quantities["seafood_plant_proteins"] = df[
            df["food_id"].isin(self.food_ids["seafood_plant_proteins"])
        ]["eaten_quantity_in_gram"].sum()
        # Fatty acids are calculated as a tuple: (PUFA+MUFA, Saturated)
        pufa_mufa = self.actual_fatty_consumed(
            df, "fatty_acids_monounsaturated"
        ) + self.actual_fatty_consumed(df, "fatty_acids_polyunsaturated")
        sfa = self.actual_fatty_consumed(df, "fatty_acids_saturated")
        quantities["fatty_acids"] = (pufa_mufa, sfa)
        quantities["refined_grains"] = df[
            df["food_id"].isin(self.food_ids["refined_grains"])
        ]["eaten_quantity_in_gram"].sum()
        quantities["sodium"] = self.actual_consumed_amount(df, "sodium")
        quantities["added_sugars"] = 4 * self.actual_consumed_amount(df, "sugar")
        quantities["saturated_fats"] = 9 * self.actual_fatty_consumed(
            df, "fatty_acids_saturated"
        )

        # Apply any modifications
        for category, mod_value in self.modified_quantities.items():
            quantities[category] = mod_value

        return quantities

    def get_food_quantities(self):
        """Returns the current food quantities dictionary."""
        return self.calculate_food_quantities()

    def modify_food_quantity(self, category, new_quantity):
        """
        Registers a manual override for one food category.

        The value is stored in ``self.modified_quantities`` and replaces the
        computed quantity whenever calculate_food_quantities() runs.

        Parameters:
            category (str): Category key to override.
            new_quantity: Quantity to use in place of the computed one.
        """
        override = {category: new_quantity}
        self.modified_quantities.update(override)

    def calculate_quantity(self, df, category):
        """
        Calculates the quantity for a given food category from a provided DataFrame.
        This version assumes that the DataFrame is already filtered or correctly represents the foods
        for the given category, so it does not filter by food IDs.
        """
        if category in [
            "total_fruits",
            "whole_fruits",
            "total_vegetables",
            "greens_and_beans",
            "whole_grains",
            "protein_foods",
            "seafood_plant_proteins",
            "refined_grains",
        ]:
            return df["eaten_quantity_in_gram"].sum()
        elif category == "dairy":
            return df["calcium"].sum()
        elif category == "sodium":
            return self.actual_consumed_amount(df, "sodium")
        elif category == "added_sugars":
            return 4 * self.actual_consumed_amount(df, "sugar")
        elif category == "saturated_fats":
            return 9 * self.actual_fatty_consumed(df, "fatty_acids_saturated")
        elif category == "fatty_acids":
            pufa_mufa = self.actual_fatty_consumed(
                df, "fatty_acids_monounsaturated"
            ) + self.actual_fatty_consumed(df, "fatty_acids_polyunsaturated")
            sfa = self.actual_fatty_consumed(df, "fatty_acids_saturated")
            return (pufa_mufa, sfa)
        else:
            raise ValueError(f"Unknown food category: {category}")

    def score_food_category(self, food_quantity, total_energy, foodclass):
        """
        Calculates the HEI points and scaled ratio for a given food category.

        The scaled ratio is ``multiplier * food_quantity / total_energy`` with
        a multiplier of 1000 (100 for "added_sugars" and "saturated_fats"),
        rounded to 2 decimals. For "fatty_acids" it is the ratio of unsaturated
        to saturated fatty acids instead, taken as 0 when the saturated amount is 0.

        Points are assigned in whole steps: the range between the worst and
        best standard is cut into ``max_score - 1`` equal intervals by
        ``max_score`` thresholds, and each threshold reached earns one more point.

        * Adequacy (higher is better): ratio <= worst standard -> 0 points,
          ratio >= best standard -> ``max_score``.
        * Moderation (lower is better): ratio >= worst standard -> 0 points,
          ratio <= best standard -> ``max_score``.

        A ratio strictly between the worst standard and the next threshold
        earns 1 point. That bucket was previously matched by no branch and
        silently scored 0.

        Parameters:
            food_quantity: Quantity for the category; for "fatty_acids" a tuple
                ``(PUFA + MUFA, SFA)``.
            total_energy: Total energy intake used as the denominator. It is
                not used for "fatty_acids". A zero value is not guarded against.
            foodclass (str): Category key.

        Returns:
            tuple: ``(points, scaled_amount_ratio)``.

        Raises:
            ValueError: If ``foodclass`` is not a known category.
        """
        # foodclass -> (worst standard, best standard, max score, ratio multiplier).
        # worst < best marks an adequacy component, worst > best a moderation one.
        # A multiplier of None means the quantity is an (unsaturated, saturated) pair.
        standards = {
            # Adequacy components
            "total_fruits": (0, 141, 5, 1000),
            "whole_fruits": (0, 60, 5, 1000),
            "total_vegetables": (0, 160, 5, 1000),
            "greens_and_beans": (0, 29, 5, 1000),
            "whole_grains": (0, 31, 10, 1000),
            "dairy": (0, 412, 10, 1000),
            "protein_foods": (0, 15.6, 5, 1000),
            "seafood_plant_proteins": (0, 3.3, 5, 1000),
            "fatty_acids": (1.2, 2.5, 10, None),
            # Moderation components
            "refined_grains": (76, 32, 10, 1000),
            "sodium": (2, 1.1, 10, 1000),
            "added_sugars": (26, 6.5, 10, 100),
            "saturated_fats": (16, 8, 10, 100),
        }
        if foodclass not in standards:
            raise ValueError(f"Unknown food class: {foodclass}")
        min_q, max_q, max_score, multiplier = standards[foodclass]

        # max_score thresholds delimit max_score + 1 score buckets.
        quantity_levels = np.linspace(min_q, max_q, max_score)
        score_levels = np.linspace(0, max_score, max_score + 1)

        if multiplier is None:
            pufa_mufa, sfa = food_quantity
            # NOTE(review): with no saturated fat the ratio is forced to 0 and
            # therefore scores 0 points, even if unsaturated fat was eaten --
            # confirm this is the intended handling.
            amount_ratio = pufa_mufa / sfa if sfa != 0 else 0
        else:
            amount_ratio = multiplier * food_quantity / total_energy
        amount_ratio = round(amount_ratio, 2)

        if min_q < max_q:
            # Adequacy: higher intakes yield higher scores.
            if amount_ratio <= quantity_levels[0]:
                points = score_levels[0]
            elif amount_ratio >= quantity_levels[-1]:
                points = score_levels[-1]
            else:
                # One point per threshold reached (between 1 and max_score - 1 here).
                reached = int(np.sum(quantity_levels <= amount_ratio))
                points = score_levels[reached]
        else:
            # Moderation: lower intakes yield higher scores; thresholds are descending.
            if amount_ratio <= quantity_levels[-1]:
                points = score_levels[-1]
            elif amount_ratio >= quantity_levels[0]:
                points = score_levels[0]
            else:
                # One point per threshold the intake stays strictly below.
                stayed_below = int(np.sum(quantity_levels > amount_ratio))
                points = score_levels[stayed_below]

        return points, amount_ratio

    def calculate_scores(self):
        """
        Calculates HEI scores for all food categories based on the current nutrition data.
        It computes the total energy from the nutrition data and then scores each category.
        Returns a dictionary mapping each category to its (points, scaled ratio) and an overall total score.
        """
        total_energy = self.nutrition_data["energy_kcal_eaten"].sum()
        quantities = self.calculate_food_quantities()
        scores = {}
        overall_total = 0

        for category in [
            "total_fruits",
            "whole_fruits",
            "total_vegetables",
            "greens_and_beans",
            "whole_grains",
            "dairy",
            "protein_foods",
            "seafood_plant_proteins",
            "fatty_acids",
            "refined_grains",
            "sodium",
            "added_sugars",
            "saturated_fats",
        ]:
            qty = quantities.get(category)
            # Score the category using its respective scoring method
            pts, ratio = self.score_food_category(qty, total_energy, category)
            scores[category] = {"points": pts, "scaled_ratio": ratio}
            overall_total += pts

        scores["total_HEI_score"] = overall_total
        return scores
