from util import np, _, plot_convergence
from quad import seven_point_gauss_6
from mesh import Triangulation
from integrate import compute_H1_norm_difference, \
                      stiffness_with_diffusivity_iter, \
                      assemble_matrix_from_iterables
from solve import solve_with_dirichlet_data


def _sector_boundary_points(beta, mesh_size):
  """
    Return the vertices of the polygonal boundary of the radius 1 circular
    sector with opening angle `beta` as an (n, 2) array.

    The polygon starts at the origin, runs along the x-axis to (1, 0), follows
    the sampled outer arc counter-clockwise up to (cos β, sin β) and returns
    along the second straight edge towards the origin. Every vertex appears
    exactly once.
  """
  # Clamp to >= 2 samples so that the corners (1, 0) and (cos β, sin β) are
  # always part of the polygon, even for coarse meshes / small angles.
  nsteps0 = max(int(1 / mesh_size), 2)
  nsteps_arc = max(int(beta / mesh_size), 2)

  # from (0, 0) to (1, 0)
  segment0 = np.stack([np.linspace(0, 1, nsteps0), np.zeros(nsteps0)], axis=1)

  # outer arc; angle 0 is dropped because (1, 0) already ends segment0
  xi_arc = np.linspace(0, beta, nsteps_arc)[1:]
  arc = np.stack([np.cos(xi_arc), np.sin(xi_arc)], axis=1)

  # from (cos β, sin β) towards (0, 0); both end points are dropped because
  # they already belong to `arc` and `segment0`, respectively
  segment1 = np.linspace(0, 1, nsteps0)[::-1][1:-1][:, _] * arc[-1:]

  return np.concatenate([segment0, arc, segment1])


def _polar_angle(x):
  """
    Polar angle θ of each row of the (n, 2) array of points `x`, reduced
    modulo 2π so that angles beyond π are non-negative instead of wrapped
    to negative values (needed for opening angles β > π).
  """
  return np.arctan2(x[:, 1], x[:, 0]) % (2 * np.pi)


def _sector_exact_solution(beta):
  """
    Return callables ``(exact, dexact)`` evaluating, in an (n, 2) array of
    points, the exact solution u = r^α sin(α θ), α = π/β, and its gradient
    (an (n, 2) array).

    u vanishes on both straight edges (θ = 0 and θ = β) and equals
    sin(π/β θ) on the unit arc. The gradient is not defined at the origin
    (division by r there).
  """
  exponent = np.pi / beta

  def radius(x):
    return np.linalg.norm(x, axis=1)

  def exact(x):
    return radius(x) ** exponent * np.sin(exponent * _polar_angle(x))

  def dexact(x):
    r = radius(x)[:, _]
    angle = exponent * _polar_angle(x)[:, _]
    # radial unit vector and its counter-clockwise rotation by 90°
    e_r = x / r
    e_theta = np.stack([-e_r[:, 1], e_r[:, 0]], axis=1)
    # ∇u = ∂_r u e_r + (1/r) ∂_θ u e_θ = α r^(α-1) (sin(αθ) e_r + cos(αθ) e_θ)
    return exponent * r ** (exponent - 1) * (np.sin(angle) * e_r + np.cos(angle) * e_theta)

  return exact, dexact


def solve_annulus(beta: float, mesh_size: float = 0.1, *, plot: bool = True) -> float:
  r"""
    P1 FEM solution of the Laplace problem:

          -∆u = 0          in  Ω
            u = sin(π/β θ) on ∂Ω

    where Ω is a triangular approximation of the radius 1 circular sector with
    opening angle β (corner at the origin, straight edges along θ = 0 and
    θ = β). The exact solution is u = r^(π/β) sin(π/β θ); for β > π the domain
    is non-convex and ∇u is singular at the origin.

    parameters
    ----------
    beta: float value 0 < beta < 2π, the opening angle of the sector.
    mesh_size: float value 0 < mesh_size < 1 tuning the mesh density.
               Smaller value => denser mesh.
    plot: if True (default), plot the mesh and the approximate solution.

    returns
    -------
    The H1 semi-norm of the difference between the approximate and the exact
    solution (also printed).

    raises
    ------
    ValueError: if beta or mesh_size lies outside the admissible range.
  """
  # explicit checks instead of `assert`, which is stripped under `python -O`
  if not 0 < beta < 2 * np.pi:
    raise ValueError('Expected 0 < beta < 2π, found {}.'.format(beta))
  if not 0 < mesh_size < 1:
    raise ValueError('Expected 0 < mesh_size < 1, found {}.'.format(mesh_size))

  mesh = Triangulation.from_polygon(_sector_boundary_points(beta, mesh_size),
                                    mesh_size=mesh_size)

  if plot:
    mesh.plot()

  quadrule = seven_point_gauss_6()

  # Dirichlet data: the boundary function u = sin(π/β θ) in the boundary points
  bindices = mesh.boundary_indices
  data = np.sin(np.pi / beta * _polar_angle(mesh.points[bindices]))

  # NOTE(review): the argument lists of the project helpers below are assumed
  # (mesh + quadrule for the iterable, mesh + iterables for the assembly,
  # (matrix, rhs, freezedofs, data) for the solver) -- confirm against
  # integrate.py / solve.py.
  Aiter = stiffness_with_diffusivity_iter(mesh, quadrule)
  S = assemble_matrix_from_iterables(mesh, Aiter)

  # homogeneous right-hand side (-∆u = 0)
  rhs = np.zeros(len(mesh.points))

  solution_approx = solve_with_dirichlet_data(S, rhs, bindices, data)

  if plot:
    mesh.tripcolor(solution_approx)

  exact, dexact = _sector_exact_solution(beta)

  # a0=0, a1=1 presumably drops the L2 part and keeps the gradient part,
  # giving the H1 semi-norm reported below
  dnorm_H1 = compute_H1_norm_difference(mesh, quadrule, solution_approx, exact, dexact, a0=0, a1=1)

  print('The H1 semi-norm of the difference between the approximate and the exact'
        ' solution equals {:.6}.\n'.format(dnorm_H1))

  return dnorm_H1


if __name__ == '__main__':
  # successively halved mesh sizes for the convergence study
  mesh_sizes = (0.1, 0.05, 0.025, 0.0125)

  # a convex (β = π/2) and a non-convex (β = 3π/2) sector
  betas = (np.pi / 2, 3 * np.pi / 2)

  # errors_list[i][j] holds the error for betas[i] on mesh_sizes[j]
  errors_list = []
  for sector_angle in betas:
    errors_list.append([solve_annulus(sector_angle, mesh_size=h) for h in mesh_sizes])

  plot_convergence(mesh_sizes, errors_list, betas, label='β')
