import numpy as np
from itertools import product


def compute_polynomial_weights_local_Lagrange(p: int) -> np.ndarray:
  r"""
    Given the polynomial order p >= 1, compute the weights with respect
    to the canonical basis {1, y, y^2, ..., y^p, x, xy, ..., x^p} of P_p(Khat)
    for each basis function phi_i with phi_i(p_j) = \delta_ij.
    Here, the p_j are the points (x, y) \in X \times X, with X = {0, 1/p, 2/p, ..., 1}
    and x + y <= 1. The nodes are ordered like the monomials: the k-th node is
    (i_k / p, j_k / p), where (i_k, j_k) is the k-th multi-index in the
    canonical python ordering described below.

    Parameters
    ----------
    p : `int`
      The polynomial order. It is converted with `int(p)` and must be >= 1.

    Returns
    -------
    weights : np.ndarray
      A matrix of shape N_p x N_p, N_p = (p + 1)(p + 2) / 2, whose i-th column
      contains the polynomial weights of \phi_i in the canonical python ordering
      (a_00, a_01, a_02, ... a_0p, a_10, a_11, ... a_p0), where a_ij is the
      coefficient of x^i y^j.

    Raises
    ------
    ValueError
      If p < 1.
  """
  p = int(p)
  if p < 1:
    # Explicit check instead of `assert`: asserts are stripped under `python -O`.
    raise ValueError(f"polynomial order p must be >= 1, got {p}")

  # Dimension of P_p in two variables.
  Np = (p + 1) * (p + 2) // 2

  # Row L holds the multi-index (i, j) of the monomial x^i y^j with i + j <= p,
  # in the order produced by itertools.product (i outer loop, j inner loop).
  multi_indices = np.array(
    [multi_index for multi_index in product(range(p + 1), repeat=2) if sum(multi_index) <= p],
    dtype=int,
  )

  # Row k holds the k-th node p_k = (x_k, y_k) = multi_indices[k] / p.
  P = multi_indices / p

  # Generalized Vandermonde matrix: M[k, L] = x_k^{i_L} * y_k^{j_L}, i.e. column L
  # is the L-th canonical monomial evaluated at all nodes. Broadcasting replaces
  # the per-column loop; numpy evaluates 0.0 ** 0 as 1.0, as required.
  x, y = P[:, 0], P[:, 1]
  powers_x, powers_y = multi_indices[:, 0], multi_indices[:, 1]
  M = x[:, None] ** powers_x[None, :] * y[:, None] ** powers_y[None, :]

  # phi_L(p_k) = delta_kL, so the L-th right-hand side is the L-th unit vector.
  Rhs = np.eye(Np)

  # np.linalg.solve accepts several right hand sides as a matrix; the solution W
  # satisfies M @ W = I, so column L holds the coefficients of phi_L.
  return np.linalg.solve(M, Rhs)


if __name__ == '__main__':
  for p in range(1, 5):
    # Exponent pairs (i, j) of the monomials x^i y^j with i + j <= p, listed in
    # the canonical python ordering (i outer, j inner).
    multi_indices = tuple(
      (i, j) for i, j in product(range(p + 1), repeat=2) if i + j <= p
    )

    # Round to 7 decimal places so the printed matrices stay compact.
    myweights = np.round(compute_polynomial_weights_local_Lagrange(p), 7)

    # Strip the outer parentheses of the tuple; print one weight vector
    # (one column of myweights) per paragraph.
    powers_text = str(multi_indices)[1:-1]
    weights_text = '\n\n'.join(str(column) for column in myweights.T)

    message = (
      'With respect to the canonical polynomial basis with powers\n\n'
      f' {powers_text},\n\n'
      f'the weights of the nodal basis functions of order {p} are given by: \n\n'
      f'{weights_text}.\n\n'
    )
    print(message)