import numpy as np
import datetime
from dotenv import load_dotenv
import pandas as pd
import streamlit as st
from openai import OpenAI
import os
import plotly.express as px
import json



# Function to load CSV file
def load_data(file):
    """Parse a CSV source into a DataFrame.

    Parameters:
    - file: a path or file-like object accepted by ``pandas.read_csv``.

    Returns:
    - The parsed ``pandas.DataFrame``.
    """
    return pd.read_csv(file)


######### Define functions for data analysis and visualisation

# Function to generate glucose level versus time graph
def plot_glucose_graph(df, low_glucose=70, high_glucose=180):
    """Render a glucose-vs-time line chart with shaded low/normal/high zones in Streamlit.

    The caller's DataFrame is not modified: zone labelling is done on a copy.

    Parameters:
    - df: DataFrame with 'Time' and 'Glucose Level' columns.
    - low_glucose: lower bound of the normal zone (default 70).
    - high_glucose: upper bound of the normal zone (default 180).
    """
    import plotly.graph_objects as go

    # Work on a copy so the helper columns added below never leak into the caller's frame.
    data = df.copy()
    glucose = data['Glucose Level']

    # Create a line plot
    fig = px.line(data, x='Time', y='Glucose Level', title='Glucose Levels Over Time')

    # Add threshold lines for high and low glucose levels
    fig.add_hline(y=high_glucose, line_dash="dash", line_color="red", annotation_text="High Glucose Threshold")
    fig.add_hline(y=low_glucose, line_dash="dash", line_color="red", annotation_text="Low Glucose Threshold")

    # Label every reading by zone, using the same thresholds as the dashed lines.
    # Readings matching no condition (e.g. NaN) get 'unknown'.
    data['label'] = np.select(
        [glucose > high_glucose,
         (glucose >= low_glucose) & (glucose <= high_glucose),
         glucose < low_glucose],
        ['high', 'normal', 'low'],
        default='unknown',
    )
    # Consecutive readings with the same label share one group id, so each
    # contiguous run can be shaded as its own filled trace.
    data['group'] = data['label'].ne(data['label'].shift()).cumsum()

    fill_colors = {
        'normal': 'rgba(0, 250, 0, 0.2)',   # Green for normal
        'high': 'rgba(250, 0, 0, 0.2)',     # Red for high
        'low': 'rgba(255, 255, 0, 0.2)',    # Yellow for low
    }
    unknown_color = 'rgba(128, 128, 128, 0.2)'  # Gray for unknown or other labels

    # Shade each run down to zero with an invisible outline.
    for _, segment in data.groupby('group'):
        fig.add_traces(go.Scatter(
            x=segment['Time'],
            y=segment['Glucose Level'],
            line=dict(color='rgba(0,0,0,0)'),
            fill='tozeroy',
            fillcolor=fill_colors.get(segment['label'].iloc[0], unknown_color),
            showlegend=False,
        ))

    legend_colors = {
        "Normal range": 'rgba(0, 250, 0, 0.2)',
        "High levels (hyperglycemia)": 'rgba(250, 0, 0, 0.2)',
        "Low levels (hypoglycemia)": 'rgba(255, 255, 0, 0.2)'
    }

    # Legend-only dummy traces (no data points) describing the zone colours.
    for label, color in legend_colors.items():
        fig.add_trace(go.Scatter(
            x=[None], y=[None],
            mode='markers',
            marker=dict(size=10, color=color),
            legendgroup=label,
            showlegend=True,
            name=label
        ))

    # Customize the x-axis to show time correctly
    fig.update_xaxes(
        tickformat='%H:%M\n%e %b',  # Hour:Minute on the primary level, Day Month on the secondary level
        tickmode='auto',
        nticks=24  # Adjust this number based on the duration and interval you want
    )

    # Place the legend below the graph
    fig.update_layout(
        legend=dict(
            orientation="h",
            yanchor="top",
            y=-0.3,
            xanchor="center",
            x=0.5
        )
    )

    # Display the figure in Streamlit
    st.plotly_chart(fig)


def style_table(df, column_colors):
    def apply_column_color(col):
        color = column_colors.get(col.name)
        if color:
            return [f'background-color: {color}' if i == 0 else '' for i in range(len(col))]
        return ['' for _ in col]

    return df.style.set_properties(**{'text-align': 'center'}).apply(apply_column_color, axis=0)

    
# Function to calculate average glucose in mg/dL
def calculate_average_glucose(df):
    """Return the mean of the 'Glucose Level' column, rounded to 2 decimals."""
    return round(df['Glucose Level'].mean(), 2)


# Function to calculate glucose management indicator (GMI) in %
def calculate_gmi(df):
    """Estimate the Glucose Management Indicator (GMI, %) from mean glucose.

    GMI = 3.31 + 0.02392 * mean glucose. The coefficients require the mean in
    mg/dL. The mean comes from calculate_average_glucose, which rounds it to 2 dp.
    """
    mean_mg_dl = calculate_average_glucose(df)
    return round(3.31 + 0.02392 * mean_mg_dl, 2)

# compute the mean blood glucose levels during days and nights
def calculate_day_night_time_mbg(df):
    """Compute nighttime/daytime mean blood glucose and a glucose-excursion score.

    Nighttime is any reading whose hour is < 6 (00:00-05:59). Every other reading
    counts as daytime. The caller's DataFrame is not modified.

    Parameters:
    - df: DataFrame with a 'Glucose Level' column and timestamps either in a
      'Time' column or in a DatetimeIndex.

    Returns:
    - dict with 'Nighttime MBG', 'Daytime MBG' and 'MAGE'. A period with no
      readings yields NaN.

    Raises:
    - KeyError: if there is neither a 'Time' column nor a DatetimeIndex.

    Note: 'MAGE' here is a simplified proxy: the mean of the 5-sample rolling
    (max - min) ranges that exceed 1.0. It is not the standard MAGE definition,
    which is based on peak-to-nadir swings larger than one standard deviation.
    """
    # Read hours without overwriting the caller's 'Time' column.
    if "Time" in df.columns:
        hours = pd.to_datetime(df["Time"]).dt.hour.to_numpy()
    elif isinstance(df.index, pd.DatetimeIndex):
        hours = df.index.hour.to_numpy()
    else:
        raise KeyError("Time")

    glucose = df['Glucose Level']
    is_night = hours < 6  # unparseable times (NaN hour) compare False -> daytime

    nighttime_mbg = glucose[is_night].mean()
    daytime_mbg = glucose[~is_night].mean()

    # Excursion proxy: range of each trailing 5-sample window, keeping ranges > 1.0
    threshold = 1.0
    window = glucose.rolling(window=5, min_periods=1)
    excursions = (window.max() - window.min()).dropna()
    mage = excursions[excursions > threshold].mean()

    return {
        'Nighttime MBG': nighttime_mbg,
        'Daytime MBG': daytime_mbg,
        'MAGE': mage,
    }


def identify_hyperglycemia_episodes(df, threshold=180):
    """Find consecutive runs of glucose readings strictly above ``threshold``.

    An episode starts at the first reading above the threshold. It ends at the
    first following reading that is not above it, or at the last reading if the
    data ends mid-episode. The caller's DataFrame is not modified.

    Parameters:
    - df: DataFrame with 'Time' (datetimes or parseable strings) and 'Glucose Level'.
    - threshold: level a reading must exceed to count as hyperglycemic (default 180).

    Returns:
    - list of dicts with ISO-8601 'start' and 'end', an ISO-8601 duration string
      'episode_duration', and the episode's peak 'max_gluc' as a float.
    """
    times = pd.to_datetime(df['Time'])

    def _episode(start, end, peak):
        return {
            'start': start.isoformat(),
            'end': end.isoformat(),
            "episode_duration": (end - start).isoformat(),
            'max_gluc': float(peak),
        }

    episodes = []
    start_time = None
    max_gluc = -float('inf')
    last_time = None

    for stamp, level in zip(times, df['Glucose Level']):
        if level > threshold:
            if start_time is None:
                start_time = stamp  # a new episode begins
            max_gluc = max(max_gluc, level)
        elif start_time is not None:
            # First non-high reading closes the current episode
            episodes.append(_episode(start_time, stamp, max_gluc))
            start_time = None
            max_gluc = -float('inf')
        last_time = stamp

    # Episode still open when the data ends
    if start_time is not None:
        episodes.append(_episode(start_time, last_time, max_gluc))

    return episodes


    

# Function to calculate glucose variability in %
def calculate_glucose_variability(df):
    """Return the coefficient of variation (CV, %) of 'Glucose Level', rounded to 2 dp.

    CV = standard deviation / mean * 100. The unrounded mean is used, so the only
    rounding is applied to the final percentage. The standard deviation is pandas'
    default sample SD (ddof=1).
    """
    glucose = df['Glucose Level']
    mean_glucose = glucose.mean()
    std_glucose = glucose.std()
    glucose_variability = (std_glucose / mean_glucose) * 100
    return round(glucose_variability, 2)


# Function to calculate glucose range
def calculate_glucose_range(df):
    """Count readings, and their share, in each glucose band.

    Bands (mg/dL): Very High > 250, High 181-250, Target 70-180, Low 54-69,
    Very Low < 54. Missing readings (NaN) are excluded from both the counts and
    the denominator rather than being counted as in-range. An input with no
    readings yields zero counts and 0.0 % instead of raising ZeroDivisionError.

    Returns:
    - dict mapping band name -> (band label, count, percentage rounded to 2 dp).
    """
    levels = df['Glucose Level'].dropna()
    total_measurements = len(levels)

    very_high_count = int((levels > 250).sum())
    high_count = int(((levels > 180) & (levels <= 250)).sum())
    target_range_count = int(((levels >= 70) & (levels <= 180)).sum())
    low_count = int(((levels >= 54) & (levels < 70)).sum())
    very_low_count = int((levels < 54).sum())

    def _percent(count):
        if total_measurements == 0:
            return 0.0
        return round((count / total_measurements) * 100, 2)

    return {
        'Very High': ('>250 mg/dL', very_high_count, _percent(very_high_count)),
        'High': ('181-250 mg/dL', high_count, _percent(high_count)),
        'Target Range': ('70-180 mg/dL', target_range_count, _percent(target_range_count)),
        'Low': ('54-69 mg/dL', low_count, _percent(low_count)),
        'Very Low': ('<54 mg/dL', very_low_count, _percent(very_low_count))
    }


# Function to create stacked bar plot for glucose range
def create_glucose_range_plot(df):
    """Render one stacked bar showing the share of time in each glucose band.

    Band counts and percentages come from calculate_glucose_range(df). Segments
    of 5 % or less get no inline text label. The chart is shown with Streamlit
    and nothing is returned.
    """
    glucose_range_result = calculate_glucose_range(df)

    rows = []
    for range_label, (desc, count, percent) in glucose_range_result.items():
        rows.append({
            'Category': 'All Ranges',  # single shared x value so all bands stack in one bar
            'Percentage': percent,
            'Range': range_label,
            'Description': desc,
            'Count': count,
            'Text': range_label if percent > 5 else "",  # label only segments wide enough for text
        })
    data = pd.DataFrame(rows)

    # Colour per band
    colors = {
        'Very High': 'red',
        'High': 'pink',
        'Target Range': 'green',
        'Low': 'yellow',
        'Very Low': 'orange'
    }

    fig = px.bar(
        data,
        x='Category',
        y='Percentage',
        color='Range',
        color_discrete_map=colors,
        text='Text',
        hover_data={
            'Range': True,
            # d3-format spec. ':.2f%' is not a valid d3 specifier; the '%' unit
            # is already part of the axis label below.
            'Percentage': ':.2f',
            'Category': False,
            'Text': False
        },
        labels={'Percentage': 'Percentage of Time (%)'}
    )
    fig.update_traces(textposition='inside', insidetextanchor='middle')
    fig.update_layout(
        showlegend=True,
        xaxis_title='',
        yaxis_title='Percentage of Time (%)',
    )
    st.plotly_chart(fig)

def identify_columns(data):
    """Guess which columns hold the timestamps and the glucose readings.

    - Time column: the first non-numeric column whose values all parse as
      datetimes. Numeric columns are skipped because ``pd.to_datetime`` turns any
      number into an epoch timestamp. Without this, the glucose column itself
      could be picked as the time column.
    - Glucose column: the first numeric, non-boolean column (other than the time
      column) whose values all lie in [15, 500].

    Returns:
    - (time_col, glucose_col)

    Raises:
    - ValueError: if either column cannot be found.
    """
    time_col = next(
        (col for col in data.columns
         if not pd.api.types.is_numeric_dtype(data[col])
         and pd.to_datetime(data[col], errors='coerce').notna().all()),
        None,
    )
    glucose_col = next(
        (col for col in data.columns
         if col != time_col
         and pd.api.types.is_numeric_dtype(data[col])
         and not pd.api.types.is_bool_dtype(data[col])
         and data[col].between(15, 500).all()),
        None,
    )
    # Compare with None: a legitimate column name such as 0 is falsy.
    if time_col is None or glucose_col is None:
        raise ValueError("Necessary columns not found")
    return time_col, glucose_col
    

def calculate_glucose_metrics(df, lower_bound=70, upper_bound=140):
    """
    Calculate various glucose metrics for gestational diabetes management.

    Parameters:
    - df: DataFrame with a time column and a glucose column, detected with
      identify_columns. The caller's DataFrame is not modified.
    - lower_bound: Lower bound for Time in Range calculation (default 70 mg/dL).
    - upper_bound: Upper bound for Time in Range calculation (default 140 mg/dL).

    Returns:
    - A dictionary with keys 'mean', 'max', 'min', 'TIR', 'TAR', 'TBR', 'GMI',
      'GV', 'Nighttime MBG', 'Daytime MBG', 'MAGE', 'duration' and 'HBGI'.

    Raises:
    - ValueError: (from identify_columns) if the columns cannot be detected.
    """
    time_col, glucose_col = identify_columns(df)

    # Normalised working frame with the column names the helper functions expect
    # ('Time', 'Glucose Level'). 'Time' stays a column rather than becoming the
    # index, because calculate_day_night_time_mbg reads df["Time"].
    times = pd.to_datetime(df[time_col]).reset_index(drop=True)
    data = pd.DataFrame({
        'Time': times,
        'Glucose Level': df[glucose_col].reset_index(drop=True),
    })
    glucose = data['Glucose Level']

    # General stats
    gluc_stats_metrics = glucose.agg(['mean', 'max', 'min']).to_dict()

    # Time in / above / below range, as a percentage of readings (bounds inclusive for TIR)
    total = len(glucose)
    in_range = ((glucose >= lower_bound) & (glucose <= upper_bound)).sum()
    above = (glucose > upper_bound).sum()
    below = (glucose < lower_bound).sum()
    gluc_stats_metrics.update({
        'TIR': (in_range / total) * 100 if total != 0 else 0,
        'TAR': (above / total) * 100 if total != 0 else 0,
        'TBR': (below / total) * 100 if total != 0 else 0,
    })

    gluc_stats_metrics["GMI"] = calculate_gmi(data)
    gluc_stats_metrics["GV"] = calculate_glucose_variability(data)
    gluc_stats_metrics.update(calculate_day_night_time_mbg(data))

    # Overall time span covered by the sample
    gluc_stats_metrics["duration"] = times.max() - times.min()

    # HBGI (Kovatchev risk function, BG in mg/dL):
    #   f(BG) = 1.509 * (ln(BG) ** 1.084 - 5.381)
    # Only readings with f > 0 (BG above ~112.5 mg/dL) add 10 * f**2 of risk.
    # HBGI is the mean of that risk over all readings.
    f_bg = 1.509 * (np.log(glucose.astype(float)) ** 1.084 - 5.381)
    gluc_stats_metrics["HBGI"] = float((10 * f_bg.clip(lower=0) ** 2).mean())

    return gluc_stats_metrics

#--------------------------



######### Create the chatbot

def access_chatbot(metrics, personal_info_dict, episodes):
    """Render a Streamlit chat UI backed by OpenAI and primed with the user's glucose metrics.

    The system prompt is rebuilt on every call and kept as the single first entry
    of ``st.session_state.messages``. User and assistant turns follow it, and the
    whole history is sent to the API for each new prompt.

    Parameters:
    - metrics: dict providing the placeholders used below (mean, max, min, TIR,
      TAR, TBR, GMI, GV, Nighttime MBG, Daytime MBG, MAGE, HBGI, duration).
    - personal_info_dict: user details appended to the prompt as "key : value, ".
    - episodes: JSON-serialisable list of hyperglycemia episodes.
    """
    # Load environment variables from .env file, if not already loaded
    load_dotenv()

    # Access the API key
    api_key = os.getenv("API_KEY")
    client = OpenAI(api_key=api_key)

    # Initialize the model and the message history in the Session State
    if "openai_model" not in st.session_state:
        st.session_state["openai_model"] = "gpt-3.5-turbo"
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Initial prompt. Glucose values are in mg/dL: the analysis uses 70/180 mg/dL
    # thresholds and the mg/dL GMI formula.
    context = ("You are a health specialist assistant. The user is suffering from Gestational Diabetes Mellitus. You are here to help the user to understand her blood glucose metrics. "
               "Mean Glucose: {mean} mg/dL, Max Glucose: {max} mg/dL, Min Glucose: {min} mg/dL, Time in Range (TIR): {TIR}%, Time Above Range (TAR): {TAR}%, Time Below Range (TBR): {TBR}%, Glucose Management Indicator : {GMI}%. "
               "Glucose Variability : {GV}%, Nighttime Mean Blood Glucose : {Nighttime MBG} mg/dL, Daytime Mean Blood Glucose : {Daytime MBG} mg/dL, "
               "Mean Amplitude of Glycemic Excursions : {MAGE} mg/dL, High Blood Glucose Index : {HBGI}. We computed these metrics from a blood glucose dataset containing {duration} of data. "
               "The user is going to ask you questions about her blood glucose metrics. Be kind, and explain it as simple as possible, and give her insightful informations. "
               "Feel free to use emojis at the end of the sentence, to make cute answers. Plus, if your answer is long, feel free to make paragraphs. "
               "The conversation might not be related to blood glucose metrics only. "
               "Here are the personal information about the user (adapt your answers according to them) : ")
    # Fill the metric placeholders BEFORE appending user-supplied text, so braces
    # in personal info cannot break str.format or pull in other fields.
    context = context.format(**metrics)

    # Add the personal information
    context += "".join(f"{key} : {value}, " for key, value in personal_info_dict.items())

    # Add the hyperglycemia episodes
    episodes_json = json.dumps(episodes, indent=4)
    context += "\n Here are the hyperglycemia episodes, using a threshold of 180 mg/dL : " + episodes_json

    # Streamlit re-executes this function on every interaction. Keep exactly one,
    # up-to-date system prompt at the head of the history instead of appending a
    # new copy each time (which would grow the history and the tokens sent).
    history = [m for m in st.session_state.messages if m["role"] != "system"]
    st.session_state.messages = [{"role": "system", "content": context}] + history

    # Show the conversation (everything except the system prompt)
    for message in st.session_state.messages:
        if message["role"] != "system":
            with st.chat_message(message["role"], avatar="👩🏼‍⚕️" if message["role"] == "assistant" else "🌸"):
                st.markdown(message["content"])

    # Read a new user prompt, if any
    if prompt := st.chat_input("How can I help you ?"):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user", avatar="🌸"):
            st.markdown(prompt)

        # Send the full history (system prompt + turns) and stream the reply
        with st.chat_message("assistant", avatar="👩🏼‍⚕️"):
            stream = client.chat.completions.create(
                model=st.session_state["openai_model"],
                messages=[
                    {"role": m["role"], "content": m["content"]}
                    for m in st.session_state.messages
                ],
                stream=True,
            )
            response = st.write_stream(stream)

        # Store the assistant's response in the message history
        st.session_state.messages.append({"role": "assistant", "content": response})




def generate_insight(metrics, personal_info_dict, episodes, level):
    """Ask OpenAI for a one-shot written interpretation of the user's glucose metrics.

    Parameters:
    - metrics: dict providing the placeholders used below (mean, max, min, TIR,
      TAR, TBR, GMI, GV, Nighttime MBG, Daytime MBG, MAGE, HBGI, duration).
    - personal_info_dict: user details appended to the prompt as "key : value, ".
    - episodes: JSON-serialisable list of hyperglycemia episodes.
    - level: desired interpretation level (a string; lower-cased into the prompt).

    Returns:
    - The model's reply text.
    """
    # Load environment variables from .env file, in case it wasn't done earlier
    load_dotenv()

    # Access the API key
    api_key = os.getenv("API_KEY")
    client = OpenAI(api_key=api_key)

    # Glucose values are in mg/dL: the analysis uses 70/180 mg/dL thresholds
    # and the mg/dL GMI formula.
    context = ("You are a health specialist assistant. The user is suffering from Gestational Diabetes Mellitus. You are here to help the user to understand her blood glucose metrics. "
               "Mean Glucose: {mean} mg/dL, Max Glucose: {max} mg/dL, Min Glucose: {min} mg/dL, Time in Range (TIR): {TIR}%, Time Above Range (TAR): {TAR}%, Time Below Range (TBR): {TBR}%, Glucose Management Indicator : {GMI}%. "
               "Glucose Variability : {GV}%, Nighttime Mean Blood Glucose : {Nighttime MBG} mg/dL, Daytime Mean Blood Glucose : {Daytime MBG} mg/dL, "
               "Mean Amplitude of Glycemic Excursions : {MAGE} mg/dL, High Blood Glucose Index : {HBGI}. We computed these metrics from a blood glucose dataset containing {duration} of data. "
               "Don't list me the metrics : the user already know it. Only highlight me the most problematics ones. "
               "First : give her personalized feedback of her metrics : which one are dangerous/healthy and how it needs to be understood, and managed. "
               "Secondly : give her advices on how to manage her blood glucose levels according to the metrics. "
               "Feel free to use emojis, and make your explanation easily understandable. "
               "Use warning emojis for dangerous metric values. "
               "Here are the personal information about the user : ")

    # Fill the metric placeholders BEFORE appending user-supplied text, so braces
    # in personal info cannot break str.format or pull in other fields.
    context = context.format(**metrics)

    # Add the personal information
    context += "".join(f"{key} : {value}, " for key, value in personal_info_dict.items())

    # Add the hyperglycemia episodes
    episodes_json = json.dumps(episodes, indent=4)
    context += "\n Here are the hyperglycemia episodes, using a threshold of 180 mg/dL : " + episodes_json

    # Output instructions. They start on a new line so they do not run into the
    # JSON, and the required opening sentence is quoted so it is not merged with
    # the next instruction.
    level = level.lower()
    context += ("\n Can you adapt the insights genaration for " + str(level) + " level of interpretations. "
                "Can you write this insight folowing this format : Start the insights with this sentence exactly : "
                "\"Based on your blood glucose metrics, there are several areas that need attention:\" "
                "After this starting sentence, first explain all the problematic metrics. "
                "Then give the feedabck about the general metrics and hyperglycemia episodes. "
                "To finish, give some management advice.")

    completion = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": context},
            {"role": "user", "content": "."},
        ],
        temperature=0.5,  # The temperature defines the randomness or creativity of the generated response.
    )

    # completion.choices[0].message.content = model's response
    return completion.choices[0].message.content