import asyncio
import pandas as pd

# from typing import Any
from datetime import datetime
from dataclasses import dataclass
from pydantic import BaseModel
from openai.types.responses import ResponseContentPartDoneEvent, ResponseTextDeltaEvent

from agents import (
    Agent,
    Runner,
    TResponseInputItem,
    RawResponsesStreamEvent,
    function_tool,
    RunContextWrapper,
    GuardrailFunctionOutput,
    InputGuardrailTripwireTriggered,
    input_guardrail,
    RunHooks,
)

from nutrition_agents.price_search_agent import price_search_agent

# from hei_calculator import HEICalculator
from hei_calculator_v2 import HEICalculator
from nutrition_agents.openfoodfacts_api import openfoodfacts_agent
from dotenv import load_dotenv

# Populate os.environ from a local .env file before any agent is constructed or
# run (presumably this holds the OpenAI API key — confirm).
load_dotenv()

# Baseline intake data, loaded once at import time. The DataFrame is the default
# for UserInfo.nutrition_df, and its JSON form (nutrition_json) is referenced by
# the main agent's instructions.
nutrition_df = pd.read_csv("data/example_intake_data.csv")
nutrition_json = nutrition_df.to_json(orient="records")

# --- Pydantic Models ---


class NutrientInformation(BaseModel):
    # Nutrient composition of a single food, expressed per 100 g of that food.
    # Unless stated otherwise every amount is in grams (g) per 100 g.
    # This model is the structured output of the nutrient composition agent and
    # the `nutrient_info` argument of update_nutrition_df.
    # NOTE(review): update_nutrition_df divides calcium by 1000 exactly like
    # sodium, which suggests calcium is really in mg — confirm the unit.
    food_id: int
    food_name: str
    carbohydrate: float  # g / 100 g
    fat: float  # total fat, g / 100 g
    protein: float  # g / 100 g
    fiber: float  # g / 100 g
    alcohol: float  # g / 100 g
    sodium: float  # typically mg / 100 g (conversion applied as needed)
    calcium: float  # documented as g / 100 g (see NOTE above)
    sugar: float  # g / 100 g
    fatty_acids_monounsaturated: float  # g / 100 g
    fatty_acids_polyunsaturated: float  # g / 100 g
    fatty_acids_saturated: float  # g / 100 g


class FoodCategoryMapping(BaseModel):
    # HEI food-category membership flags for one food.
    # Each field is True when the food fits that category and False otherwise;
    # a food may belong to several categories at once.
    total_fruits: bool
    whole_fruits: bool
    total_vegetables: bool
    greens_and_beans: bool
    whole_grains: bool
    dairy: bool
    protein_foods: bool
    seafood_plant_proteins: bool
    refined_grains: bool
    added_sugars: bool


# --- Dataclasses ---


@dataclass
class UserInfo:
    """Run context shared with tools and hooks through ``RunContextWrapper``."""

    name: str
    total_tokens: int = 0
    # Stamped by DemoLifecycleHook.on_agent_start; None until an agent starts.
    session_start: datetime | None = None
    # NOTE: the default is the module-level DataFrame loaded at import time and
    # is shared by every instance that does not pass its own. The tools in this
    # file copy it or rebind it (pd.concat) rather than mutating it in place.
    nutrition_df: pd.DataFrame = nutrition_df
    # Grams of the most recent food recorded by update_nutrition_df. Declared
    # here so that assignment no longer creates an ad-hoc attribute.
    food_amount: float | None = None


# --- Functions ---


@function_tool
def get_nutrient_information(food_name: str):
    """Returns the nutrient information for a given food name"""
    # (The docstring above is the tool description shown to the LLM.)
    # Returns the best-matching row of data/nutrient_standard_info.csv as a dict,
    # or a dict of None nutrients when nothing matches.
    standard_nutrient_info = pd.read_csv("data/nutrient_standard_info.csv")

    # Case-insensitive plain substring match. regex=False keeps names such as
    # "yogurt (plain)" or "c++" from being parsed as regular expressions (which
    # raised or mis-matched before). na=False treats rows with a missing
    # food_name as non-matches instead of breaking the boolean mask.
    lowered_names = standard_nutrient_info["food_name"].str.lower()
    nutrient_info = standard_nutrient_info[
        lowered_names.str.contains(food_name.lower(), regex=False, na=False)
    ]

    # Prefer the shortest (most generic) matching name. A stable sort keeps the
    # CSV order among equal-length names, so the choice is deterministic.
    nutrient_info = nutrient_info.sort_values(
        by="food_name", key=lambda names: names.str.len(), kind="stable"
    )
    if nutrient_info.empty:
        return {
            "food_id": None,
            "food_name": food_name,
            "carbohydrate": None,
            "fat": None,
            "protein": None,
            "fiber": None,
            "alcohol": None,
            "sodium": None,
            "calcium": None,
            "sugar": None,
            "fatty_acids_monounsaturated": None,
            "fatty_acids_polyunsaturated": None,
            "fatty_acids_saturated": None,
        }
    return nutrient_info.iloc[0].to_dict()


# @function_tool
# def calculate_HEI_score(
#     wrapper: RunContextWrapper[UserInfo],
#     day: str,
#     # food_category: FoodCategoryMapping = None,
# ):
#     """Calculates the HEI score for a given day (in format YYYY-MM-DD)"""
#     if wrapper and wrapper.context:
#         print(f"📊 Calculating HEI score for {wrapper.context.name} on {day}")

#     user_df = wrapper.context.nutrition_df.copy()
#     user_df["eaten_date"] = pd.to_datetime(user_df["eaten_date"])
#     user_df = user_df[user_df["eaten_date"] == (day)]
#     calculator = HEICalculator(
#         base_path="data/cross_verification_food_categories_HEI", nutrition_data=user_df
#     )
#     all_scores = calculator.calculate_scores()

#     return all_scores


@function_tool
def calculate_HEI_score(
    wrapper: RunContextWrapper[UserInfo],
    day: str,
    new_food_categories: FoodCategoryMapping | None = None,
):
    """Calculates the HEI score for a given day (in format YYYY-MM-DD)

    Args:
        day: Day to score, formatted as YYYY-MM-DD.
        new_food_categories: HEI food-category flags for foods newly added with
            update_nutrition_df, or null when no new food was added.
    """
    # The annotation is explicitly nullable: a plain `FoodCategoryMapping = None`
    # yields a strict tool schema in which the model cannot send null.
    if wrapper and wrapper.context:
        print(f"📊 Calculating HEI score for {wrapper.context.name} on {day}")

    if new_food_categories is not None:
        print("food_category: ", new_food_categories)

    # Copy so the datetime conversion below does not alter the context DataFrame.
    user_df = wrapper.context.nutrition_df.copy()
    user_df["eaten_date"] = pd.to_datetime(user_df["eaten_date"])
    user_df = user_df[user_df["eaten_date"] == day]
    calculator = HEICalculator(
        base_path="data/cross_verification_food_categories_HEI", nutrition_data=user_df
    )
    all_scores = calculator.calculate_scores()

    # update_nutrition_df tags the rows it adds with food_group_cname == "custom".
    new_foods = user_df[user_df["food_group_cname"] == "custom"]

    if new_food_categories is not None and not new_foods.empty:
        food_quantities = calculator.get_food_quantities()
        # NOTE(review): one category mapping is applied to *all* custom foods on
        # this day, so several new foods with different categories are lumped together.
        for category, belongs in new_food_categories.model_dump().items():
            if not belongs:
                continue
            new_foods_category_quantity = calculator.calculate_quantity(
                new_foods, category
            )
            print(
                "food_quantities for category", category, food_quantities[category]
            )
            food_quantities[category] = (
                food_quantities[category] + new_foods_category_quantity
            )
            calculator.modify_food_quantity(category, food_quantities[category])
        # Re-score with the adjusted category quantities.
        all_scores = calculator.calculate_scores()

    return all_scores


@function_tool
def update_nutrition_df(
    wrapper: RunContextWrapper[UserInfo],
    day: str,
    food_name: str,
    eaten_quantity_in_gram: float,
    nutrient_info: NutrientInformation,
):
    """Updates the nutrition_df with the new food information

    Args:
        day: Day the food was eaten, formatted as YYYY-MM-DD.
        food_name: Name of the food to record.
        eaten_quantity_in_gram: Eaten portion, in grams.
        nutrient_info: Nutrient composition of the food per 100 g.
    """
    # Nutrient values are per 100 g; scale them to the eaten portion once.
    portion = eaten_quantity_in_gram / 100

    # Atwater general factors (kcal/g): carbohydrate 4, protein 4, fat 9,
    # alcohol 7. The alcohol term was previously missing, which under-counted
    # energy for alcoholic foods and drinks.
    energy_from_food = portion * (
        nutrient_info.carbohydrate * 4
        + nutrient_info.protein * 4
        + nutrient_info.fat * 9
        + nutrient_info.alcohol * 7
    )

    new_row = {
        "food_id": nutrient_info.food_id,
        # calculate_HEI_score selects these rows via food_group_cname == "custom".
        "food_group_cname": "custom",
        "food_name": food_name,
        "eaten_at": f"{day} 00:00:00",
        "eaten_date": day,
        "eaten_quantity_in_gram": eaten_quantity_in_gram,
        "energy_kcal_eaten": energy_from_food,
        "carb_eaten": nutrient_info.carbohydrate * portion,
        "fat_eaten": nutrient_info.fat * portion,
        "protein_eaten": nutrient_info.protein * portion,
        "fiber_eaten": nutrient_info.fiber * portion,
        "alcohol_eaten": nutrient_info.alcohol * portion,
        # NOTE(review): salt_eaten is set equal to sodium (mg -> g). Salt is
        # usually ~2.5x sodium by mass — confirm what the intake data expects.
        "salt_eaten": nutrient_info.sodium * portion / 1000,
        "sodium": nutrient_info.sodium * portion / 1000,
        # NOTE(review): divided by 1000 like sodium (mg -> g), although
        # NutrientInformation documents calcium in grams — confirm the unit.
        "calcium": nutrient_info.calcium * portion / 1000,
        "sugar": nutrient_info.sugar * portion,
        "fatty_acids_monounsaturated": nutrient_info.fatty_acids_monounsaturated
        * portion,
        "fatty_acids_polyunsaturated": nutrient_info.fatty_acids_polyunsaturated
        * portion,
        "fatty_acids_saturated": nutrient_info.fatty_acids_saturated * portion,
    }
    print("new_row:\n", new_row)

    # Rebind (rather than mutate) so DataFrames shared with other objects stay intact.
    wrapper.context.nutrition_df = pd.concat(
        [wrapper.context.nutrition_df, pd.DataFrame([new_row])], ignore_index=True
    )
    wrapper.context.food_amount = eaten_quantity_in_gram
    # Persist a snapshot for inspection/debugging.
    wrapper.context.nutrition_df.to_csv("data/test_nutrition_df.csv", index=False)
    return "Updated nutrition_df"


# --- Specialized Agents ---

# Looks a food up through the get_nutrient_information tool and answers with a
# structured NutrientInformation object.
nutrient_composition_agent = Agent[UserInfo](
    name="Nutrient Composition Agent",
    handoff_description="Specialist agent for food nutrient composition",
    tools=[get_nutrient_information],
    output_type=NutrientInformation,
    instructions="""You are a nutrient composition agent.
    Using get_nutrient_information function, you are responsible for providing nutrient composition of a given food.
    """,
)

# Tool-less classifier: maps a food onto HEI food-category flags
# (FoodCategoryMapping) purely from the model's own knowledge.
food_category_agent = Agent[UserInfo](
    name="Food Category Agent",
    handoff_description="Specialist agent for food category",
    output_type=FoodCategoryMapping,
    instructions="""You are a food category agent.
    You are responsible for assigning the food category of a given food.
    Keep in mind that the food category would be used to calculate the HEI diet quality score.
    Hence, the food category should be as accurate and make sense w.r.t HEI relevant food categories.
    """,
)

# Computes the day-level HEI score. The food category agent is exposed as a
# tool (not a handoff) so its FoodCategoryMapping result can be forwarded to
# calculate_HEI_score within the same agent turn.
diet_quality_agent = Agent[UserInfo](
    name="Diet Quality Agent",
    handoff_description="Specialist agent for diet quality",
    instructions="""You are a diet quality agent.
    You are responsible for calculating the diet quality of a user's daywise food intake.
    If a new food is provided, then use the food category agent to determine the food categories and
    pass the food categories class object to the calculate_HEI_score function, else the food categories should be None.
    """,
    tools=[
        food_category_agent.as_tool(
            tool_name="determine_food_category",
            tool_description="Determine the food category of a given food",
        ),
        calculate_HEI_score,
    ],
)

# Records a new food in the user's nutrition_df: looks up its composition with
# the nutrient composition agent, then calls update_nutrition_df.
# The prompt only references tools this agent actually has, and it matches
# update_nutrition_df's real signature (food name, grams eaten, day, nutrient
# info). It previously asked for an openfoodfacts agent that is not wired in
# here, and for a food-category mapping that update_nutrition_df cannot accept.
# Food categories are determined later by the diet quality agent when the HEI
# score is computed.
# TODO: add openfoodfacts_agent.as_tool(...) here if barcode lookups are wanted.
update_nutrition_df_agent = Agent[UserInfo](
    name="Update Nutrition Dataframe Agent",
    instructions="""You are nutrition tracking agent.
    You are responsible for updating the nutrition_df with the new food information.
    You would be given a food name, the eaten quantity in grams and a day.
    You need to find the nutritional information of the given food using the nutrient composition agent.
    Add the new food to the user's nutrition_df using the update_nutrition_df function.
    To this tool call, pass the nutrient information as an argument in the same format.
    """,
    handoff_description="Specialist agent for updating nutrition_df",
    model="gpt-4o-mini",
    tools=[
        nutrient_composition_agent.as_tool(
            tool_name="determine_nutrient_composition",
            tool_description="Determine the nutrient composition of a given food",
        ),
        update_nutrition_df,
    ],
)


# --- Guardrails ---


class MedicalAdviceGuardrailOutput(BaseModel):
    # Structured verdict returned by the medical-advice guardrail agent.
    is_medical_advice: bool  # True -> the input guardrail's tripwire fires
    reasoning: str  # short justification for the verdict


# Classifier behind the medical-advice input guardrail.
# The prompt pieces are joined by implicit string concatenation, so each piece
# must end with a space. The "diet quality" clarification had been inserted
# mid-sentence, producing "...not medical advice.that goes beyond general
# information".
medical_advice_guardrail_agent = Agent(
    name="Medical Advice Guardrail",
    instructions=(
        "Check if the user's request includes medical advice or medication info "
        "that goes beyond general information. Diet quality is not medical advice. "
        "Return is_medical_advice as True if it does, "
        "along with a brief reasoning. Otherwise, return False."
    ),
    output_type=MedicalAdviceGuardrailOutput,
)


@input_guardrail
async def medical_advice_guardrail(
    ctx: RunContextWrapper[None], agent: Agent, input: str | list[TResponseInputItem]
) -> GuardrailFunctionOutput:
    """Trip the guardrail when the classifier agent flags the input as medical advice.

    Runs ``medical_advice_guardrail_agent`` on the same input (sharing the
    caller's context) and reports its verdict as the guardrail output.
    """
    classification = await Runner.run(
        medical_advice_guardrail_agent, input, context=ctx.context
    )
    verdict = classification.final_output_as(MedicalAdviceGuardrailOutput)
    flagged = verdict.is_medical_advice
    return GuardrailFunctionOutput(output_info=verdict, tripwire_triggered=flagged)


# --- Lifecycle Hooks ---


class DemoLifecycleHook(RunHooks):
    """Run hooks that log each agent start and stamp the session start time."""

    def __init__(self, user_info: UserInfo):
        self.user_info = user_info

    async def on_agent_start(
        self, wrapper: RunContextWrapper[UserInfo], agent: Agent
    ) -> None:
        started_at = datetime.now()
        # Only the first agent start sets session_start. Previously every agent
        # start/handoff overwrote it, so it never marked the session's beginning.
        if wrapper.context.session_start is None:
            wrapper.context.session_start = started_at
        # wrapper.context.total_tokens += wrapper.usage.total_tokens
        print(f"🎯 [LIFECYCLE] AGENT {agent.name} HAS STARTED AT {started_at}")

    # async def on_agent_end(
    #     self, wrapper: RunContextWrapper[UserInfo], agent: Agent, output: Any
    # ) -> None:
    #     if agent.name == "Diet Quality Agent":
    #         # tokens = wrapper.usage.total_tokens
    #         # wrapper.context.total_tokens += tokens
    #         print(
    #             f"🎯 [LIFECYCLE] AGENT {agent.name} HAS ENDED WITH OUTPUT. HERE IS THE OUTPUT \n: {output}. "
    #             # f"Current usage: {wrapper.usage}. Tokens used so far: {wrapper.context.total_tokens}"
    #         )


# --- Main Nutrition Tracking Agent ---

def _nutrition_agent_instructions(
    wrapper: RunContextWrapper[UserInfo], agent: Agent[UserInfo]
) -> str:
    """Render the main agent's system prompt from the run's current intake data.

    The data is rendered on every run instead of once at import time, so rows
    added by update_nutrition_df are visible to the assistant on later turns.
    Falls back to the import-time snapshot when the run has no context.
    """
    if wrapper.context is None:
        data_json = nutrition_json
    else:
        data_json = wrapper.context.nutrition_df.to_json(orient="records")
    # NOTE(review): handoffs do not return control, so "handoff to the
    # update_nutrition_df agent and then ... the diet quality agent" cannot chain
    # within one run; exposing those agents via .as_tool() would — confirm intent.
    return f"""You are a nutrition tracking assistant.
    You are responsible answering questions about the user's nutritional intake.
    For questions related to nutrient composition, handoff to the nutrient composition agent.
    The agent will return the nutrient composition of the food, so you can use it to answer the user's question.
    For questions related to food category, handoff to the food category agent.
    If the nutrient composition agent doesnt find the food in the data, return "No data found for the given food name"
    For questions related to price, handoff to the food price research agent.
    But for finding the nutritional information of a barcoded food item, handoff to the openfoodfacts agent.
    For questions related to diet quality, handoff to the diet quality agent.
    For questions related to updating the nutrition_df, handoff to the update_nutrition_df agent.
    If the user asks what diet quality would be if a new food was eaten on a particular day,
    then handoff to the update_nutrition_df agent and then calculate the HEI diet quality using the diet quality agent.
    Here is the data: {data_json}"""


nutrition_agent = Agent[UserInfo](
    name="Nutrition Tracking Assistant",
    # Dynamic instructions: the SDK calls this with (run_context, agent) each run.
    instructions=_nutrition_agent_instructions,
    handoffs=[
        food_category_agent,
        nutrient_composition_agent,
        diet_quality_agent,
        price_search_agent,
        openfoodfacts_agent,
        update_nutrition_df_agent,
    ],
    input_guardrails=[medical_advice_guardrail],
    model="gpt-4o-mini",
)

# --- Main ---

# Hooks attached to every run. NOTE(review): this UserInfo is a different object
# from the one main() passes as the run context; on_agent_start reads
# wrapper.context, so this instance is not the one updated during a run.
demo_hooks = DemoLifecycleHook(UserInfo(name="John Doe"))


async def main():
    """Interactive CLI chat loop around ``nutrition_agent``.

    Streams the assistant's text to stdout, carries the conversation history
    across turns, and exits when the user types ``quit`` (case-insensitive).
    """
    user_info = UserInfo(name="John Doe", nutrition_df=nutrition_df)

    msg = input(
        "👋 Hi! We have your nutritional intake data. What do you want to know? "
    )
    if msg.strip().lower() == "quit":
        return
    inputs: list[TResponseInputItem] = [{"content": msg, "role": "user"}]

    while True:
        try:
            result = Runner.run_streamed(
                nutrition_agent,
                input=inputs,
                hooks=demo_hooks,
                context=user_info,
            )
            async for event in result.stream_events():
                # Only raw model events are printed; run-item and agent-update
                # events are skipped.
                if not isinstance(event, RawResponsesStreamEvent):
                    continue
                data = event.data
                if isinstance(data, ResponseTextDeltaEvent):
                    print(data.delta, end="", flush=True)
                elif isinstance(data, ResponseContentPartDoneEvent):
                    print("\n")
        except InputGuardrailTripwireTriggered:
            print(
                "\n⚠️ GUARDRAIL TRIGGERED ⚠️ : I'm not allowed to provide medical advice."
            )
            # The guardrail receives the whole input list. Keeping the rejected
            # message in the history would re-trigger it on every later turn,
            # so drop it and continue from the last accepted history.
            inputs = inputs[:-1]
        else:
            inputs = result.to_input_list()

        print("\n")
        user_msg = input("💬 Enter a message: ")
        if user_msg.strip().lower() == "quit":
            print("👋 See you space cowboy!")
            break
        inputs.append({"content": user_msg, "role": "user"})


if __name__ == "__main__":
    asyncio.run(main())
