#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Filename: 20250508_PlotNeffTidy3D.py
Created: 2025-05-08
Author: Claudio JARAMILLO
Email: claudio.jaramilloconcha@epfl.ch
"""

import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from typing import List
import os
from datetime import datetime


def read_and_sort_data(file_path: str) -> pd.DataFrame:
    """Read a mode-solver CSV export and return it sorted by ``core_width``.

    The first five lines of the file are skipped (``skiprows=5``) before the
    header row is parsed — presumably a metadata preamble written by the
    exporter; TODO confirm against the export format.

    Args:
        file_path: Path to the CSV file.

    Returns:
        The parsed table, sorted ascending by its ``core_width`` column.

    Raises:
        ValueError: If the file cannot be opened or parsed, or has no
            ``core_width`` column. The underlying exception is chained as
            ``__cause__`` so the original traceback is preserved.
    """
    try:
        df = pd.read_csv(file_path, skiprows=5)
        return df.sort_values("core_width")
    except (OSError, ValueError, KeyError) as e:
        # OSError: missing/unreadable file; ValueError: pandas parse errors
        # (ParserError, EmptyDataError) and decode errors; KeyError: the
        # ``core_width`` column is absent.
        raise ValueError(f"Error processing {file_path}: {e}") from e


def _padded_range(series: pd.Series, padding: float) -> tuple:
    """Return ``(min, max)`` of *series* widened by ``padding`` times its span."""
    lo, hi = float(series.min()), float(series.max())
    span = hi - lo
    if span == 0:
        # A single point (or constant data) has zero span, which would give
        # identical low/high limits; use a margin relative to the value
        # itself (or 1.0 around zero) so the limits stay distinct.
        span = abs(lo) if lo != 0 else 1.0
    pad = span * padding
    return (lo - pad, hi + pad)


def calculate_axis_limits(
    dataframes: List[pd.DataFrame], padding: float = 0.1
) -> dict:
    """Calculate x/y axis limits covering all data, with relative padding.

    Args:
        dataframes: Frames with ``core_width`` (x) and ``value`` (y) columns.
        padding: Fraction of each axis' data span added on both sides
            (default 0.1, i.e. 10 %).

    Returns:
        ``{"xlim": (lo, hi), "ylim": (lo, hi)}``, or an empty dict when there
        is no non-NaN data to bound.
    """
    if not dataframes:
        return {}

    # NaNs are dropped so they cannot poison the min/max computation.
    core_widths = pd.concat([df["core_width"] for df in dataframes]).dropna()
    values = pd.concat([df["value"] for df in dataframes]).dropna()
    if core_widths.empty or values.empty:
        return {}

    return {
        "xlim": _padded_range(core_widths, padding),
        "ylim": _padded_range(values, padding),
    }


def plot_data(
    dataframes: List[pd.DataFrame], filenames: List[str], save: bool = True
) -> plt.Figure:
    """Plot effective index versus core width for several datasets.

    Each dataframe is drawn as one line series from its ``core_width`` and
    ``value`` columns, labelled with the matching entry of *filenames*. The
    padded axis limits from ``calculate_axis_limits`` are applied before the
    layout pass and before saving, so exported files match the on-screen
    figure.

    Args:
        dataframes: One dataframe per series, each with ``core_width`` and
            ``value`` columns.
        filenames: Legend label for each series, in the same order.
        save: When true, write timestamped ``.pdf`` and ``.png`` copies of the
            figure next to this script (or into the current working directory
            when the script path is unavailable, e.g. in a notebook).

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If *dataframes* and *filenames* differ in length.
    """
    # zip() would silently drop unmatched series; fail loudly instead.
    if len(dataframes) != len(filenames):
        raise ValueError(
            f"Got {len(dataframes)} dataframes but {len(filenames)} labels"
        )

    # Dimensions
    mm_to_inch = 1 / 25.4
    fig_width = 89 * mm_to_inch  # Nature single column width
    fig, ax = plt.subplots(figsize=(fig_width, fig_width * 0.75))

    # Colours
    nc_colors = [
        "#006DDB",  # Dark Blue
        "#DC3220",  # Red
        "#0072B2",  # Blue
        "#D56E00",  # Vermilion
        "#009E73",  # Bluish Green
        "#CC79A7",  # Reddish Purple
        "#E69F00",  # Orange
        "#56B4E9",  # Sky Blue
        "#009292",  # Teal
        "#F0E442",  # Yellow
    ]

    # Markers and lines
    markers = ["o", "s", "^", "D", "v", "p", "*", "h", "X", "<"]
    line_styles = ["-", "--", "-.", ":"]

    # Formatting style
    for i, (df, label) in enumerate(zip(dataframes, filenames)):
        color = nc_colors[i % len(nc_colors)]
        marker = markers[i % len(markers)]
        # The line style only advances once every marker has been used, so
        # the first len(markers) series are all solid lines.
        line_style = line_styles[(i // len(markers)) % len(line_styles)]

        ax.plot(
            df["core_width"],
            df["value"],
            marker=marker,
            markersize=4.5,
            markeredgewidth=0.5,
            markeredgecolor="white",
            linestyle=line_style,
            linewidth=1.1,
            color=color,
            label=label,
            alpha=0.9,
        )

    # Text elements
    ax.set_xlabel("Core width (μm)", fontsize=8, fontfamily="sans-serif", labelpad=2)
    ax.set_ylabel(
        "Effective index (n$_{eff}$)", fontsize=8, fontfamily="sans-serif", labelpad=2
    )

    # Tick params
    ax.tick_params(
        axis="both",
        which="major",
        direction="in",
        length=2.5,
        width=0.5,
        labelsize=7,
        pad=2,
    )

    # Legend
    legend = ax.legend(
        frameon=True,
        loc="best",
        fontsize=7,
        handlelength=1.8,
        borderpad=0.3,
        handletextpad=0.4,
        borderaxespad=0.5,
        framealpha=0.9,
    )
    legend.get_frame().set_linewidth(0.5)
    legend.get_frame().set_edgecolor("0.8")
    legend.get_frame().set_facecolor("0.97")

    # Grid
    ax.grid(True, which="major", linestyle=":", linewidth=0.4, color="0.8", alpha=0.4)

    # Adjust spines (border lines)
    for spine in ax.spines.values():
        spine.set_linewidth(0.5)
        spine.set_color("0.4")

    # Axis limits must be set before layout/saving; otherwise the exported
    # files keep matplotlib's autoscaled limits while the screen shows the
    # padded ones.
    axis_limits = calculate_axis_limits(dataframes)
    if axis_limits:
        ax.set_xlim(axis_limits["xlim"])
        ax.set_ylim(axis_limits["ylim"])

    # Tight layout with padding (on this figure, not pyplot's "current" one)
    fig.tight_layout(pad=0.8)

    # Save figure if requested
    if save:
        # Save next to the script; ``__file__`` is undefined when this code is
        # run interactively, so fall back to the working directory.
        try:
            out_dir = Path(os.path.abspath(__file__)).parent
        except NameError:
            out_dir = Path.cwd()

        # Create timestamped filename
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        pdf_path = out_dir / f"mode_solver_plot_{timestamp}.pdf"
        # with_suffix only touches the extension, unlike str.replace, which
        # would also rewrite a ".pdf" occurring inside a directory name.
        png_path = pdf_path.with_suffix(".png")

        # Save in multiple formats
        fig.savefig(pdf_path, dpi=300, bbox_inches="tight")
        fig.savefig(png_path, dpi=600, bbox_inches="tight")

        print(f"Saved plots to:\nPDF: {pdf_path}\nPNG: {png_path}")

    return fig


def plot_mode_solver_results(file_paths: List[str]) -> None:
    """Load each CSV in *file_paths*, plot them together, and show the figure.

    Files that ``read_and_sort_data`` rejects (by raising ``ValueError``) are
    reported on stdout and skipped. Each series is labelled with its file
    stem. If no file loads, a message is printed and nothing is plotted.

    Args:
        file_paths: Paths of the CSV files to plot.
    """
    dataframes = []
    filenames = []

    for file_path in file_paths:
        # Keep the try body to the one call that is expected to fail.
        try:
            df = read_and_sort_data(file_path)
        except ValueError as e:
            print(e)
            continue
        dataframes.append(df)
        filenames.append(Path(file_path).stem)

    if not dataframes:
        print("No valid data to plot")
        return

    fig = plot_data(dataframes, filenames)

    # Set axis limits on the figure we just built, rather than relying on
    # pyplot's implicit "current axes".
    ax = fig.axes[0]
    axis_limits = calculate_axis_limits(dataframes)
    if axis_limits:
        ax.set_xlim(axis_limits["xlim"])
        ax.set_ylim(axis_limits["ylim"])

    plt.show()


if __name__ == "__main__":
    # Plain relative names resolve against the current working directory on
    # every OS. The previous r".\name.csv" form only worked on Windows: on
    # POSIX the backslash is a literal filename character.
    files_to_plot = [
        "dataTE_pump_1.csv",
        "dataTM_signal_1.csv",
    ]
    plot_mode_solver_results(files_to_plot)
