#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
@author: Jochen Hinz
Updated 28.03.2025 by Moahmed Ben Abdelouahab 
"""

import numpy as np
from functools import wraps
import matplotlib.pyplot as plt



# Shortcut for vectorisation: indexing with ``_`` is the same as indexing with
# ``np.newaxis`` and inserts a new axis of length one.
# >>> arr = np.ones((10,), dtype=int)
# >>> print(arr.shape)
#     (10,)
# >>> print(arr[:, np.newaxis].shape)
#     (10, 1)
# >>> print(arr[:, _].shape)
#     (10, 1)
# NOTE: ``_`` is also the conventional throwaway name; rebinding it at module
# level (e.g. a ``for _ in ...`` loop outside any function) would shadow this
# alias for the rest of the module.
_ = np.newaxis


def frozen(array: np.ndarray) -> np.ndarray:
    """
    Mark an array as read-only and return it.

    The input is passed through ``np.asarray``. For a plain ``np.ndarray``
    no copy is made, so the caller's own array becomes read-only (the freeze
    happens in place and the return value may be ignored). Other array-likes
    (lists, tuples, ...) are converted to a new array, and that new array is
    frozen and returned.

    Example
    -------
    >>> arr = np.zeros((3,), dtype=int)
    >>> arr[0] = 1                      # writable as usual
    >>> out = frozen(arr)
    >>> out is arr
    True
    >>> arr[0] = 2
    Traceback (most recent call last):
        ...
    ValueError: assignment destination is read-only
    """
    result = np.asarray(array)
    # Clear the WRITEABLE flag; any later item assignment raises ValueError.
    result.setflags(write=False)
    return result


def freeze(fn):
    """
    Decorate ``fn`` so that its return value is made read-only via ``frozen``.

    The decorated function keeps the name, docstring and other metadata of
    ``fn`` (through ``functools.wraps``); only its result is post-processed.

    Example
    -------
    >>> @freeze
    ... def multiply(arr, val):
    ...     return val * arr
    >>> out = multiply(np.ones((3,), dtype=int), 2)
    >>> out
    array([2, 2, 2])
    >>> out[0] = 10
    Traceback (most recent call last):
        ...
    ValueError: assignment destination is read-only
    """
    @wraps(fn)
    def _frozen_result(*args, **kwargs):
        result = fn(*args, **kwargs)
        return frozen(result)

    return _frozen_result




def plot_convergence(mesh_sizes: np.ndarray, errors_list: np.ndarray, var_list: np.ndarray, label: str, title_suffix: str = "H1") -> None:
    r"""
    Plot errors against mesh sizes on log-log axes and annotate the observed
    order of convergence between consecutive refinements.

    For each dataset the order between refinements i and i+1 is

        p_i = log(e_i / e_{i+1}) / log(h_i / h_{i+1}),

    where h are the mesh sizes and e the errors. The orders are printed to
    stdout and written on the plot next to the point (h_{i+1}, e_{i+1}).

    Args:
    - mesh_sizes (numpy.ndarray): 1-D array of mesh sizes h, shared by all datasets.
    - errors_list (numpy.ndarray): Error values, one row per dataset; every row
      has one entry per mesh size. A single dataset may also be passed as a
      flat 1-D sequence.
    - var_list (numpy.ndarray): One entry per dataset, shown in the legend as
      ``label=<entry>``. Numeric entries are formatted with 4 decimals, other
      entries (e.g. strings) with ``str``. Datasets beyond the shorter of
      ``errors_list`` and ``var_list`` are ignored.
    - label (str): Name of the variable that distinguishes the datasets.
    - title_suffix (str): Name of the error norm; the title reads
      ``"<title_suffix> convergence"``. Defaults to ``"H1"``.

    Returns:
    - None. The figure is shown with ``plt.show()``.

    Raises:
    - ValueError: if the rows of ``errors_list`` do not have one entry per mesh size.
    """
    mesh_sizes = np.asarray(mesh_sizes, dtype=float)
    errors_list = np.asarray(errors_list, dtype=float)

    # Accept a single dataset given as a flat sequence of errors; iterating a
    # 1-D array would otherwise yield scalars and fail on ``errors[:-1]``.
    if errors_list.ndim == 1 and errors_list.size > 0:
        errors_list = errors_list[np.newaxis, :]

    # Validate before creating a figure: mismatched lengths would otherwise
    # fail late in ``plt.loglog`` or silently broadcast into meaningless orders.
    if errors_list.size > 0 and errors_list.shape[-1] != mesh_sizes.size:
        raise ValueError(
            f"each row of errors_list must have {mesh_sizes.size} entries "
            f"(one per mesh size), got {errors_list.shape[-1]}"
        )

    plt.figure(figsize=(8, 6))

    # Log of successive mesh-size ratios; identical for every dataset.
    log_h_ratios = np.log(mesh_sizes[:-1] / mesh_sizes[1:])

    for errors, var in zip(errors_list, var_list):
        convergence_order = np.log(errors[:-1] / errors[1:]) / log_h_ratios
        print(f"Convergence order {label}={var} (between consecutive refinements):", convergence_order)

        # ``var_list`` may contain labels as well as numbers; the fixed-point
        # format spec only applies to numbers, so fall back to ``str``.
        try:
            var_str = f'{var:.4f}'
        except (TypeError, ValueError):
            var_str = str(var)

        line, = plt.loglog(mesh_sizes, errors, marker='o', label=f'{label}={var_str}')

        # Write each order next to the finer point of the refinement step it describes.
        for h, err, order in zip(mesh_sizes[1:], errors[1:], convergence_order):
            plt.text(h, err, f'{order:.2f}', fontsize=8, color=line.get_color(),
                     verticalalignment='center', horizontalalignment='center',
                     bbox=dict(facecolor="white", edgecolor='black', boxstyle='round,pad=0.3'))

    plt.xlabel('Mesh Size')
    plt.ylabel('Error')
    plt.title(f'{title_suffix} convergence')
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    plt.legend()
    plt.show()


