import asyncio
from typing import Optional

import pandas as pd
from agents import (
    Agent,
    Runner,
    TResponseInputItem,
    RawResponsesStreamEvent,
    function_tool,
)
from dotenv import load_dotenv
from openai.types.responses import ResponseContentPartDoneEvent, ResponseTextDeltaEvent
from pydantic import BaseModel, Field

from hei_calculator import HEICalculator

# Populate os.environ from a .env file, if present (presumably the OpenAI API
# key used by the agents below — confirm).
load_dotenv()

# The user's intake log, read once at import time from a path relative to the
# current working directory. calculate_HEI_score filters a copy of this
# DataFrame by day; the JSON form (a list of row records) is embedded verbatim
# in the main agent's instructions.
nutrition_df = pd.read_csv("data/example_intake_data.csv")
nutrition_json = nutrition_df.to_json(orient="records")

# --- Pydantic Models ---


class NutrientInformation(BaseModel):
    """Nutrient composition of one food, per 100 g.

    Nutrient fields are nullable because get_nutrient_information reports
    None for every nutrient when the food is not in the reference table.
    """

    # Units live in Field descriptions (not comments) so that they are part of
    # the JSON schema the agent fills in. Fields have no defaults, so they all
    # remain required; Optional only allows an explicit null.
    food_name: str = Field(description="Name of the food the values refer to.")
    carbohydrate: Optional[float] = Field(
        description="Carbohydrate in grams (g) per 100 g; null if unknown."
    )
    fat: Optional[float] = Field(
        description="Total fat in grams (g) per 100 g; null if unknown."
    )
    protein: Optional[float] = Field(
        description="Protein in grams (g) per 100 g; null if unknown."
    )
    fiber: Optional[float] = Field(
        description="Fiber in grams (g) per 100 g; null if unknown."
    )
    alcohol: Optional[float] = Field(
        description="Alcohol in grams (g) per 100 g; null if unknown."
    )
    sodium: Optional[float] = Field(
        description=(
            "Sodium per 100 g, typically in milligrams (mg); convert to mg if "
            "the source uses another unit; null if unknown."
        )
    )
    sugar: Optional[float] = Field(
        description="Sugar in grams (g) per 100 g; null if unknown."
    )
    fatty_acids_monounsaturated: Optional[float] = Field(
        description="Monounsaturated fatty acids in grams (g) per 100 g; null if unknown."
    )
    fatty_acids_polyunsaturated: Optional[float] = Field(
        description="Polyunsaturated fatty acids in grams (g) per 100 g; null if unknown."
    )
    fatty_acids_saturated: Optional[float] = Field(
        description="Saturated fatty acids in grams (g) per 100 g; null if unknown."
    )


class FoodCategoryMapping(BaseModel):
    # Structured output of the food category agent: one boolean flag per
    # category, True if the food fits that category and False otherwise.
    # The flags are independent, so a food may fit several categories at once.
    # NOTE(review): the names appear to mirror HEI components (cf.
    # calculate_HEI_score) — confirm the intended category definitions.
    # NOTE(review): these comments are not part of the JSON schema the model
    # sees; Field(description=...) would be needed to pass definitions to it.
    total_fruits: bool
    whole_fruits: bool
    total_vegetables: bool
    greens_and_beans: bool
    whole_grains: bool
    dairy: bool
    protein_foods: bool
    seafood_plant_proteins: bool
    refined_grains: bool
    added_sugars: bool


# --- Functions ---


@function_tool
def get_nutrient_information(food_name: str):
    """Return the nutrient composition of a food from the reference table.

    Matching is a case-insensitive, literal substring search on the
    ``food_name`` column of ``data/nutrient_standard_info.csv``. When several
    rows match, the one with the shortest name is returned.

    Args:
        food_name: Name, or part of the name, of the food to look up.

    Returns:
        The matching table row as a dict (missing cells become None). If no
        row matches, a dict holding the requested food_name and None for every
        nutrient.
    """
    nutrient_fields = (
        "carbohydrate",
        "fat",
        "protein",
        "fiber",
        "alcohol",
        "sodium",
        "sugar",
        "fatty_acids_monounsaturated",
        "fatty_acids_polyunsaturated",
        "fatty_acids_saturated",
    )
    not_found = {"food_name": food_name, **dict.fromkeys(nutrient_fields)}

    query = food_name.strip().lower()
    if not query:
        # An empty query is a substring of every name; don't return an
        # arbitrary row for it.
        return not_found

    nutrient_info = pd.read_csv("data/nutrient_standard_info.csv")

    # regex=False: food names often contain "(", "+", "%" etc., which would
    # otherwise be parsed as regex syntax (mis-matching or raising re.error).
    # na=False: rows with a missing name count as non-matches instead of
    # putting NaN into the boolean mask, which would make the indexing fail.
    mask = (
        nutrient_info["food_name"]
        .str.lower()
        .str.contains(query, regex=False, na=False)
    )

    # Prefer the shortest matching name; the stable sort keeps file order
    # among equally long names so the choice is deterministic.
    matches = nutrient_info[mask].sort_values(
        by="food_name", key=lambda names: names.str.len(), kind="stable"
    )
    if matches.empty:
        return not_found

    best = matches.iloc[0]
    # Report empty cells as None, consistent with the not-found result.
    return {
        column: (None if pd.isna(value) else value) for column, value in best.items()
    }


@function_tool
def calculate_HEI_score(day: str):
    """Calculate the HEI score for the user's food intake on a given day.

    Args:
        day: The day to score, in format YYYY-MM-DD.

    Returns:
        The result of ``HEICalculator.calculate_all_scores`` for that day's
        intake records, or a dict with an ``error`` message when ``day``
        cannot be parsed or there are no records for it.
    """
    try:
        target_day = pd.Timestamp(day)
    except ValueError:
        target_day = pd.NaT
    # pd.Timestamp returns NaT for some inputs (e.g. "") instead of raising.
    if pd.isna(target_day):
        return {"error": f"Invalid date {day!r}; expected format YYYY-MM-DD."}
    target_day = target_day.normalize()

    user_df = nutrition_df.copy()
    user_df["eaten_date"] = pd.to_datetime(user_df["eaten_date"])
    # Compare calendar days so entries that carry a time of day still match.
    # The frame handed to the calculator keeps the original timestamps.
    user_df = user_df[user_df["eaten_date"].dt.normalize() == target_day]
    if user_df.empty:
        # Nothing eaten that day: total energy would be 0, so don't ask the
        # calculator to score an empty day.
        return {"error": f"No intake records found for {day}."}

    total_energy = user_df["energy_kcal_eaten"].sum()
    calculator = HEICalculator(base_path="data/cross_verification_food_categories_HEI")
    return calculator.calculate_all_scores(user_df, total_energy)


# --- Specialized Agents ---

# Specialist for "what is in food X?" questions: looks the food up with the
# get_nutrient_information tool and finishes with a NutrientInformation object.
_NUTRIENT_COMPOSITION_INSTRUCTIONS = """You are a nutrient composition agent.
    Using get_nutrient_information function, you are responsible for providing nutrient composition of a given food.
    """

nutrient_composition_agent = Agent(
    name="Nutrient Composition Agent",
    handoff_description="Specialist agent for food nutrient composition",
    tools=[get_nutrient_information],
    output_type=NutrientInformation,
    instructions=_NUTRIENT_COMPOSITION_INSTRUCTIONS,
)

# Specialist that classifies a food into categories. It has no tools, so it
# answers from the model alone and finishes with a FoodCategoryMapping object.
_FOOD_CATEGORY_INSTRUCTIONS = """You are a food category agent.
    You are responsible for assigning the food category of a given food.
    """

food_category_agent = Agent(
    name="Food Category Agent",
    handoff_description="Specialist agent for food category",
    output_type=FoodCategoryMapping,
    instructions=_FOOD_CATEGORY_INSTRUCTIONS,
)

# Specialist for diet-quality questions: scores a day's intake with the
# calculate_HEI_score tool and replies in free text (no output_type set).
_DIET_QUALITY_INSTRUCTIONS = """You are a diet quality agent.
    You are responsible for calculating the diet quality of a user's daywise food intake.
    """

diet_quality_agent = Agent(
    name="Diet Quality Agent",
    handoff_description="Specialist agent for diet quality",
    tools=[calculate_HEI_score],
    instructions=_DIET_QUALITY_INSTRUCTIONS,
)


# --- Main Nutrition Tracking Agent ---

# Entry-point agent for every turn of main(). The user's whole intake log
# (nutrition_json, a list of row records) is embedded in its instructions,
# and specialist questions are routed to the agents listed in `handoffs`.
# NOTE(review): a handoff transfers control to the specialist, so the
# "No data found" instruction can only apply when this agent is running
# (e.g. on the next turn, since main() restarts here every turn).
nutrition_agent = Agent(
    name="Nutrition Tracking Assistant",
    instructions=f"""You are a nutrition tracking assistant.
    You are responsible for answering questions about the user's nutritional intake.
    For questions related to nutrient composition, hand off to the nutrient composition agent.
    For questions related to food category, hand off to the food category agent.
    If the nutrient composition agent doesn't find the food in the data, return "No data found for the given food name".
    For questions related to diet quality, hand off to the diet quality agent.
    Here is the user's food intake data (JSON records): {nutrition_json}""",
    handoffs=[food_category_agent, nutrient_composition_agent, diet_quality_agent],
)


async def main():
    """Run the interactive command-line chat with the nutrition agent.

    Each turn reads one line from stdin, streams the agent's text reply to
    stdout, and carries the whole conversation into the next turn via
    ``result.to_input_list()``. Typing ``quit`` (any case, surrounding
    whitespace ignored) or closing stdin ends the session.
    """
    inputs: list[TResponseInputItem] = []
    prompt = "👋 Hi! We have your nutritional intake data. What do you want to know? "

    while True:
        try:
            user_msg = input(prompt).strip()
        except EOFError:
            # stdin closed (Ctrl-D / end of piped input): exit like "quit".
            user_msg = "quit"
        if user_msg.lower() == "quit":
            print("👋 See you space cowboy!")
            return
        if not user_msg:
            # Nothing to send; ask again instead of submitting an empty turn.
            continue
        inputs.append({"content": user_msg, "role": "user"})
        prompt = "💬 Enter a message: "

        # Every turn starts at nutrition_agent with the full history, so it
        # can route the new question to a different specialist via handoffs.
        result = Runner.run_streamed(nutrition_agent, input=inputs)
        async for event in result.stream_events():
            # Only raw model events carry the text deltas printed below.
            if not isinstance(event, RawResponsesStreamEvent):
                continue
            data = event.data
            if isinstance(data, ResponseTextDeltaEvent):
                print(data.delta, end="", flush=True)
            elif isinstance(data, ResponseContentPartDoneEvent):
                print("\n")

        inputs = result.to_input_list()
        print("\n")


if __name__ == "__main__":
    # Script entry point: run the interactive chat loop under asyncio.run,
    # which creates and closes its own event loop.
    asyncio.run(main())
