import asyncio
import pandas as pd
from pydantic import BaseModel
from openai.types.responses import ResponseContentPartDoneEvent, ResponseTextDeltaEvent

from agents import (
    Agent,
    Runner,
    TResponseInputItem,
    RawResponsesStreamEvent,
    function_tool,
)
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment before
# any agent is built (presumably where the API credentials live — confirm).
load_dotenv()

# Read the user's intake log once, at import time, and serialise it as a JSON
# array with one object per row. The path is relative to the current working
# directory. The resulting string is embedded verbatim in the nutrition
# agent's prompt further down.
nutrition_df = pd.read_csv("data/example_intake_data.csv")
nutrition_json = nutrition_df.to_json(orient="records")

# --- Pydantic Models ---


class NutrientInformation(BaseModel):
    # Structured nutrient profile of one food; used as the output type of the
    # food info agent. Every numeric field is expressed per 100 g of the food
    # and, unless noted otherwise, measured in grams (g).
    #
    # NOTE: documented with comments rather than a class docstring on purpose:
    # pydantic includes a model's docstring in its generated JSON schema, so a
    # docstring here would change the schema handed to the model.

    food_name: str

    # -- Macronutrients (g per 100 g)
    carbohydrate: float
    fat: float  # total fat
    protein: float
    fiber: float
    alcohol: float

    # -- Typically milligrams (mg) per 100 g (conversion applied as needed)
    sodium: float

    # -- Sugar and fatty-acid breakdown (g per 100 g)
    sugar: float
    fatty_acids_monounsaturated: float
    fatty_acids_polyunsaturated: float
    fatty_acids_saturated: float


# --- Functions ---


@function_tool
def get_nutrient_information(food_name: str) -> dict:
    """Returns the nutrient information for a given food name

    Args:
        food_name: Full or partial name of the food to look up (case-insensitive).

    Returns:
        The first shortest-named row whose food name contains ``food_name``, as a dict. When nothing matches, a dict with the requested name and every nutrient set to None.
    """
    # NOTE: the summary line above is kept unchanged on purpose; function_tool
    # exposes the docstring to the model as the tool description.

    # Nutrient keys reported as None when the lookup finds nothing.
    nutrient_keys = (
        "carbohydrate",
        "fat",
        "protein",
        "fiber",
        "alcohol",
        "sodium",
        "sugar",
        "fatty_acids_monounsaturated",
        "fatty_acids_polyunsaturated",
        "fatty_acids_saturated",
    )
    not_found = {"food_name": food_name, **dict.fromkeys(nutrient_keys)}

    query = food_name.strip().lower()
    if not query:
        # An empty string is a substring of every name, so it would "match"
        # the whole table and return an unrelated food.
        return not_found

    nutrient_info = pd.read_csv("data/nutrient_standard_info.csv")

    # Case-insensitive substring match.
    #   regex=False: the query is free text, so characters such as "(" or "+"
    #                must be matched literally instead of being compiled as a
    #                regular expression (which can mis-match or raise).
    #   na=False:    rows with a missing food_name count as "no match" rather
    #                than putting NaN into the boolean mask, which would make
    #                the indexing below raise.
    is_match = (
        nutrient_info["food_name"]
        .str.lower()
        .str.contains(query, regex=False, na=False)
    )
    matches = nutrient_info[is_match]
    if matches.empty:
        return not_found

    # Prefer the shortest matching name (presumably the most generic entry).
    # argmin returns the first minimum, so ties resolve deterministically to
    # the earliest row in the file.
    shortest = matches["food_name"].str.len().to_numpy().argmin()
    return matches.iloc[shortest].to_dict()


# --- Specialized Agents ---

# Specialist that looks foods up through the get_nutrient_information tool and
# replies with a structured NutrientInformation object. It is offered to the
# main assistant as a handoff target.
# NOTE(review): the prompt asks for a plain-text "not found" sentence, while
# output_type requires a NutrientInformation with float fields; confirm how
# the not-found case is meant to be represented.
food_info_agent = Agent(
    name="Food Info Agent",
    model="gpt-4o-mini",
    handoff_description="Specialist agent for food nutrient composition",
    tools=[get_nutrient_information],
    output_type=NutrientInformation,
    instructions="""You are a nutrient composition agent.
    Using get_nutrient_information function, you are responsible for providing nutrient composition of a given food.
    If the food name is not found, return "No nutrient information found for the given food name"
    """,
)


# --- Main Nutrition Tracking Agent ---

# Entry-point agent: answers questions about the user's intake and can hand
# food-composition questions over to food_info_agent. The whole intake dataset
# (nutrition_json) is inlined into the prompt when this module is imported.
nutrition_agent = Agent(
    name="Nutrition Tracking Assistant",
    model="gpt-4o-mini",
    handoffs=[food_info_agent],
    instructions=f"""You are a nutrition tracking assistant.
    You are responsible answering questions about the user's nutritional intake.
    For questions related to food information, handoff to the food info agent.
    Here is the data: {nutrition_json}""",
)


def _read_user_message(prompt: str) -> "str | None":
    """Prompt on stdin until the user enters a message or asks to leave.

    Args:
        prompt: Text shown to the user before a line is read.

    Returns:
        The entered message with surrounding whitespace removed, or None when
        the user typed "quit" (any case) or stdin reached end-of-file.
    """
    while True:
        try:
            text = input(prompt).strip()
        except EOFError:
            # Closed stdin (Ctrl-D, exhausted pipe): no further message can
            # arrive, so treat it as a request to quit instead of crashing.
            print()
            return None
        if text.lower() == "quit":
            return None
        if text:
            return text
        # Blank line: ask again rather than sending an empty message.


async def main():
    """Run the interactive nutrition chat in the terminal.

    Reads a question, streams the agent's answer to stdout as it is generated,
    then keeps prompting for follow-ups with the full conversation history
    until the user types "quit" or closes stdin.
    """
    farewell = "👋 See you space cowboy!"

    # input() blocks the event loop, which is fine here: no agent run is in
    # flight while we wait for the user.
    msg = _read_user_message(
        "👋 Hi! We have your nutritional intake data. What do you want to know? "
    )
    if msg is None:
        print(farewell)
        return
    inputs: list[TResponseInputItem] = [{"content": msg, "role": "user"}]

    while True:
        # Every turn is started on nutrition_agent (not on whichever agent
        # finished the previous run), so follow-ups are always routed through
        # the tracking assistant first.
        result = Runner.run_streamed(
            nutrition_agent,
            input=inputs,
        )
        async for event in result.stream_events():
            # Only raw model events carry text; skip every other event type.
            if not isinstance(event, RawResponsesStreamEvent):
                continue
            data = event.data
            if isinstance(data, ResponseTextDeltaEvent):
                # Print each text fragment immediately, without a newline.
                print(data.delta, end="", flush=True)
            elif isinstance(data, ResponseContentPartDoneEvent):
                print("\n")

        # Carry the whole exchange (including this run's output) into the
        # next turn so the agent keeps the conversation context.
        inputs = result.to_input_list()
        print("\n")
        user_msg = _read_user_message("💬 Enter a message: ")
        if user_msg is None:
            print(farewell)
            break
        inputs.append({"content": user_msg, "role": "user"})


# Start the interactive chat only when this file is executed as a script, not
# when it is imported; asyncio.run drives the main() coroutine to completion.
if __name__ == "__main__":
    asyncio.run(main())
