#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
@author: Jochen Hinz
"""


from mesh import Triangulation

from scipy import sparse
from solve import solve_with_dirichlet_data
import numpy as np


def mass_matrix(mesh: Triangulation) -> sparse.csr_matrix:
  """
    Assemble the P1 mass matrix of `mesh`.

    Every triangle adds the constant 3x3 block

      detBK / 24 * [[2, 1, 1], [1, 2, 1], [1, 1, 2]]

    to the rows and columns addressed by its three vertex indices.

    parameters
    ----------
    mesh: triangulation providing `points`, `triangles` and `detBK`
          (one `detBK` value per triangle).

    returns
    -------
    Sparse matrix in CSR format of shape (npoints, npoints), one DOF per
    mesh point.
  """

  npoints = len(mesh.points)

  # assemble in LIL format, which supports direct assignment to sub-blocks,
  # and convert to CSR only once at the very end
  matrix = sparse.lil_matrix((npoints, npoints))

  # reference block shared by all triangles: 2 on the diagonal, 1 elsewhere
  reference_block = np.array([ [2., 1., 1.],
                               [1., 2., 1.],
                               [1., 1., 2.] ])

  # one iteration per triangle: vertex_ids = [i0, i1, i2] and its detBK value
  for vertex_ids, det in zip(mesh.triangles, mesh.detBK):
    scale = det / 24

    # np.ix_ builds the open mesh selecting the sub-block with rows
    # (i0, i1, i2) and columns (i0, i1, i2)
    block = np.ix_(vertex_ids, vertex_ids)
    matrix[block] += scale * reference_block

  return matrix.tocsr()


def stiffness_matrix(mesh: Triangulation) -> sparse.csr_matrix:
  """
    Assemble the P1 stiffness matrix of `mesh`.

    For every triangle the local 3x3 block has the entries

      detBK / 2 * (grad_glob[i] . grad_glob[j]),

    where grad_glob = gradshapeF @ BKinv holds one transformed gradient per
    row. The block is added to the rows and columns addressed by the
    triangle's three vertex indices.

    parameters
    ----------
    mesh: triangulation providing `points`, `triangles`, `detBK` and `BKinv`
          (one `detBK` value and one 2x2 `BKinv` matrix per triangle;
          presumably the determinant and inverse of the Jacobian of the
          affine map onto the triangle - confirm against the mesh module).

    returns
    -------
    Sparse matrix in CSR format of shape (npoints, npoints).
  """

  ndofs = len(mesh.points)

  # LIL format allows direct assignment to sub-blocks; converted to CSR on return
  A = sparse.lil_matrix((ndofs,)*2)

  # the local gradients as row-vectors
  # gradshapeF[i] gives the gradient of the i-th local basis function
  gradshapeF = np.array([ [-1, -1],
                          [1, 0],
                          [0, 1] ])

  for tri, detBK, BKinv in zip(mesh.triangles, mesh.detBK, mesh.BKinv):

    # this is equivalent to (BKinv.T @ gradshapeF.T).T (but more compact)
    grad_glob = gradshapeF @ BKinv

    # (grad_glob @ grad_glob.T)[i, j] is the dot product of the i-th and j-th
    # gradient, so one matrix product yields all nine local entries at once
    A[np.ix_(tri, tri)] += detBK / 2 * (grad_glob @ grad_glob.T)

  return A.tocsr()


def load_vector(mesh: Triangulation, F: float = 1.0) -> np.ndarray:
  """
    Assemble the load vector for a constant source term `F`.

    Every triangle adds F / 6 * detBK to each of its three vertex entries.

    parameters
    ----------
    mesh: triangulation providing `points`, `triangles` and `detBK`
          (one `detBK` value per triangle).
    F: constant value of the source term.

    returns
    -------
    Array of shape (npoints,).
  """

  load = np.zeros(len(mesh.points), dtype=float)

  # the factor F / 6 does not depend on the triangle
  weight = F / 6

  for vertex_ids, det in zip(mesh.triangles, mesh.detBK):
    # the three vertex indices of a triangle are distinct, hence the
    # fancy-indexed in-place addition reaches every entry exactly once
    load[vertex_ids] += weight * det

  return load


def assemble_neumann_rhs(mesh: Triangulation, mask: np.ndarray, g: float = 1.00) -> np.ndarray:
  """
    Assemble the right-hand side contribution of a constant Neumann datum `g`.

    For every selected boundary edge with end points (a, b), the value
    g / 2 * |b - a| is added to the entries of both end points. This is the
    exact integral of g times a P1 basis function over the edge.

    parameters
    ----------
    mesh: triangulation providing `points` and `lines`, where `lines` is an
          integer array of shape (nboundary_edges, 2) of vertex indices.
    mask: boolean array-like of shape (nboundary_edges,); edge `mesh.lines[i]`
          belongs to the Neumann boundary if and only if mask[i] is True.
    g: constant value of the Neumann datum.

    returns
    -------
    Array of shape (npoints,).

    raises
    ------
    ValueError: if `mask` does not contain exactly one entry per boundary edge.
  """
  mask = np.asarray(mask, dtype=np.bool_)
  if mask.shape != mesh.lines.shape[:1]:
    raise ValueError(
      f"mask must have shape {mesh.lines.shape[:1]}, found {mask.shape}"
    )

  # load per unit edge length carried by each of the edge's two end points
  local_neumann_load = g / 2 * np.ones(2)

  rhs = np.zeros(len(mesh.points), dtype=float)

  # retain only the boundary edges mesh.lines[i] if mask[i] is True
  neumann_lines = mesh.lines[mask]

  # loop over each line [index_of_a, index_of_b] and the corresponding points (a, b).
  # Neighbouring edges share a vertex, so the contributions are accumulated
  # edge by edge rather than with a single fancy-indexed assignment.
  for line, (a, b) in zip(neumann_lines, mesh.points[neumann_lines]):
    rhs[line] += local_neumann_load * np.linalg.norm(b - a)

  return rhs


def reaction_diffusion(mesh_size=0.05):
  """
    P1 FEM solution of the reaction-diffusion problem:

      -∆u + u = 1    in  Ω
            u = 0    on ∂Ω_D
         ∂n u = 1    on ∂Ω_N

    where Ω = (0, 1)^2 and ∂Ω_N is the bottom side of Ω.

    The mesh and the computed solution are plotted; nothing is returned.

    parameters
    ----------
    mesh_size: float value 0 < mesh_size < 1 tuning the mesh density.
               Smaller value => denser mesh.

  """

  # create a triangulation of the unit square by passing an array with
  # rows equal to the square's vertices in counter-clockwise direction.
  # The last vertex need not be repeated.
  mesh = Triangulation.from_polygon( np.array([ [0, 0],
                                                [1, 0],
                                                [1, 1],
                                                [0, 1] ]), mesh_size=mesh_size)

  # lines is an integer array of shape (nboundary_edges, 2) containing the indices
  # of the vertices that lie on the boundary edges.
  lines = mesh.lines

  # create a boolean mask of shape mesh.lines.shape[:1] that equals `True`
  # if both associated points' y-values satisfy abs(y) < 1e-10.
  # The absolute value has to be taken BEFORE the comparison; otherwise any
  # point with a negative y-value would be classified as lying on the bottom side.
  neumann_mask = (np.abs(mesh.points[lines][..., 1]) < 1e-10).all(axis=1)

  # the boundary elements that are part of the Dirichlet boundary correspond to the mask
  # that is the negation of `neumann_mask`
  dirichlet_mask = ~neumann_mask

  # plot the mesh
  mesh.plot()

  # make mass matrix
  M = mass_matrix(mesh)

  # make stiffness matrix
  A = stiffness_matrix(mesh)

  # assemble the load vector: volume source term plus Neumann boundary term
  rhs = load_vector(mesh, F=1) + assemble_neumann_rhs(mesh, neumann_mask, g=1)

  # the boundary vertices are the unique indices of the mesh's boundary edges restricted to the Dirichlet boundary
  bindices = np.unique(lines[dirichlet_mask])

  # use the `solve_with_dirichlet_data` method to solve the system under the boundary condition
  # (A + M is the discretisation of -∆u + u, the Dirichlet data is identically zero)
  solution = solve_with_dirichlet_data(A + M, rhs, bindices, np.zeros_like(bindices))

  mesh.tripcolor(solution)


# when executed as a script, run the reaction-diffusion example with its default mesh size
if __name__ == "__main__":
  reaction_diffusion()
