classdef linear_solver_test < meq_solver_test
  % Unit tests for the linear solver entry points invert_jacobian and
  % solve_linear_problem.
  %
  % The system data (A_, b_, m_, tol_, Pinv_) and the check_result method
  % used below are provided by the meq_solver_test base class.
  %
  % [+MEQ MatlabEQuilibrium Toolbox+] Swiss Plasma Center EPFL Lausanne 2022. All rights reserved.

  properties(ClassSetupParameter)
    % NOTE(review): presumably the number of right-hand sides used by the
    % base-class setup to build b_ -- confirm in meq_solver_test.
    nb = {1,20};
  end

  properties(TestParameter)
    % Every GMRES variant, including ones that invalid for some (A,b) forms
    algoGMRES = {'direct_inversion','matlab_gmres','sim','giv','aya','qr'};
    % Whether solve_linear_problem is given an initial guess
    hasX0 = {false,true};
  end

  methods (Test, TestTags = {'Unit'})
    function test_invert_jacobian(testCase,algoGMRES)
      % Call invert_jacobian with the requested algorithm. Algorithms that
      % are not admissible for the current form of A_ (matrix or function
      % handle) and b_ (one or several columns) must raise the
      % 'invert_jacobian:algoGMRES' error; all others must yield a result
      % accepted by check_result.

      J     = testCase.A_;
      rhs   = testCase.b_;
      mkryl = testCase.m_;
      tol   = testCase.tol_;
      Pinv  = testCase.Pinv_;

      params = make_solver_params(algoGMRES, mkryl, tol);

      % Admissible algorithms, mirroring the rules inside invert_jacobian
      [supported, setting] = supported_algos(J, rhs);

      if ismember(params.algoGMRES, supported)
        sol = invert_jacobian(J, rhs, Pinv, params);
        testCase.check_result(sol);
      else
        diag = sprintf('algoGMRES=%s not compatible with %s', params.algoGMRES, setting);
        testCase.verifyError(@() invert_jacobian(J, rhs, Pinv, params), ...
          'invert_jacobian:algoGMRES', diag);
      end
    end

    function test_solve_linear_problem(testCase,hasX0)
      % Solve the base-class system with solve_linear_problem, with or
      % without an initial guess. The GMRES variant is picked so that it
      % supports the number of right-hand sides: 'aya' for a single column,
      % 'qr' otherwise.

      J     = testCase.A_;
      rhs   = testCase.b_;
      mkryl = testCase.m_;
      tol   = testCase.tol_;
      Pinv  = testCase.Pinv_;

      if size(rhs,2) ~= 1
        algo = 'qr';
      else
        algo = 'aya';
      end

      params = make_solver_params(algo, mkryl, tol);
      params.userowmask = false;

      x0 = [];
      if hasX0
        x0 = eye(size(rhs)); % simple guess with the shape of b, enough for tests
      end

      sol = solve_linear_problem(J, rhs, x0, Pinv, [], [], params);
      testCase.check_result(sol);
    end
  end
end

function P = make_solver_params(algo, mkryl, tol)
% Build the solver parameter structure shared by both tests
% (fields in the order algoGMRES, mkryl, epsilon_res).
P = struct();
P.algoGMRES   = algo;
P.mkryl       = mkryl;
P.epsilon_res = tol;
end

function [algos, setting] = supported_algos(J, rhs)
% Return the algoGMRES values accepted for this (J, rhs) combination and a
% text description of the combination, e.g. 'J as a matrix and 1 RHS'.
% These rules are a copy of the checks performed by invert_jacobian.
if size(rhs,2) == 1
  algos   = {'matlab_gmres','sim','giv','aya'};
  nrhsTxt = '1 RHS';
else
  algos   = {'sim','giv','qr'};
  nrhsTxt = 'multiple RHS';
end
if isa(J,'function_handle')
  jTxt = 'a function handle';
else
  % direct_inversion is only accepted when J is an explicit matrix
  algos = [{'direct_inversion'}, algos];
  jTxt  = 'a matrix';
end
setting = sprintf('J as %s and %s', jTxt, nrhsTxt);
end
