classdef meq_jacobian_test < meq_test
  % Superclass for MEQ jacobian tests
  %
  % Provides helpers to build FGS/FGE equilibria (get_L_LX, get_LY,
  % get_fgs_LLX) and to compare analytical jacobians with finite difference
  % estimates (test_analytical_jacobian), either at a single perturbation
  % amplitude or through a convergence study in the perturbation amplitude.
  %
  % [+MEQ MatlabEQuilibrium Toolbox+] Swiss Plasma Center EPFL Lausanne 2022. All rights reserved.
  
  % define properties
  properties (Access = protected)
    fontsize = 12;           % default fontsize for figures produced by the do_figure method
    tol = 5.e-4;             % relative tolerance for jacobian error check
                             % (relative to the largest analytical jacobian entry)
    convergence = 2;         % expected convergence rate
    tolConvergence = 0.1;    % tolerance on the convergence rate
    tolAccept = 5.e-10;      % bypass the convergence test if the error is 
                             % below tolAccept

    % flux perturbation
    deltaFx = 1e-6;
    % parameter for deltaFx convergence
    meshDeltaFx = logspace(-6,-5,3);

    % Equilibrium data: L and LX are filled by get_fgs_LLX; LY and Ie are
    % not set in this class (presumably filled by subclasses)
    L; LX; LY; Ie;

    % Code to be tested ('fgs' or 'fge')
    code;

    % FGS/FGE parameters
    fbttol = 1e-6;     % Tolerance for FBT initial equilibrium
    algoNL = 'all-nl'; % Option for state variant
    algoGMRES = 'aya'; % GMRES variant
  end

  properties (Dependent)
    % Multiple domain cases with tok=ana have higher grid resolution and
    % lower Iy values so we need to change the amplitude of the
    % perturbation too. All getters below read testCase.L.nD, so L must be
    % set before they are accessed.
    % current perturbation
    deltaIyp;
    deltaIe;
    % parameter for deltaI convergence
    meshDeltaIy;
    meshDeltaIe;
  end

  properties (Constant,Hidden)
    doublets_under_test = ...
      {'doublet',82,'droplets',84,'doublet_with_mantle_current',88};
  end

  % Get methods for dependent properties (see definition of properties)
  methods
    function v = get.deltaIyp(testCase)
      % Plasma current perturbation amplitude (smaller for multi-domain cases)
      if testCase.L.nD>1, v = 1e-2;
      else,               v = 1e-1;
      end
    end

    function v = get.deltaIe(testCase)
      % External current perturbation amplitude (smaller for multi-domain cases)
      if testCase.L.nD>1, v = 1e-2;
      else,               v = 1e-1;
      end
    end

    function v = get.meshDeltaIy(testCase)
      % Plasma current perturbation amplitudes for convergence studies
      if testCase.L.nD>1, v = logspace(-3,-2,3);
      else,               v = logspace(-1, 0,3);
      end
    end

    function v = get.meshDeltaIe(testCase)
      % External current perturbation amplitudes for convergence studies
      if testCase.L.nD>1, v = logspace(-3,-2,3);
      else,               v = logspace(-1, 0,3);
      end
    end
  end
  
  % general methods common to all tests
  methods
    function [L,LX] = get_L_LX(testCase,shot,t,varargin)
      % Get L and LX for testCase.code ('fgs' or 'fge') with specified settings
      % Additional name/value pairs in VARARGIN are appended after the common
      % parameters so they can override them.

      % Common parameters
      PP = {testCase.tok,shot,t,...
           'debug',testCase.verbosity,...
           'tol',testCase.fbttol,...          % Tolerance for FBT initial equilibrium
           'tolF',1e-10,...                   % Tolerance for solveF
           'algoNL',testCase.algoNL,...       % Set state variant
           'algoF','jfnk',...                 % Use JFNK method (often faster)
           'anajac',true,...                  % Use analytical jacobian
           'algoGMRES',testCase.algoGMRES,... % Option for GMRES variant
           'usepreconditioner',true,...       % Preconditioner
           'selu','e','nu',20,...             % Vessel settings
           };

      % include varargin for the case that one wants to overwrite these defaults
      PP = [PP, varargin]; 

      switch testCase.code
        case 'fge'
          [L,LX] = fge(PP{:});
        case 'fgs'
          [L,LX] = fgs(PP{:});
        otherwise
          error('Only supported options for code are ''fgs'' and ''fge''');
      end
    end

    function [LY] = get_LY(testCase,L,LX)
      % Run the time loop of testCase.code (fget for 'fge', fgst for 'fgs')
      switch testCase.code
        case 'fge'
          LY = fget(L,LX);
        case 'fgs'
          LY = fgst(L,LX);
        otherwise
          error('Only supported options for code are ''fgs'' and ''fge''');
      end
    end

    %% initialisation
    function get_fgs_LLX(testCase,shot,t)
      % Computes FGS initial guess (FBT solution) and stores it in
      % testCase.L and testCase.LX
      [testCase.L,testCase.LX] = ...
        fgs(testCase.tok,shot,t,'debug',testCase.verbosity,'tol',testCase.fbttol);
    end
    
    %% accept error
    % check the slope and the convergence error
    function [acceptAll] = accept_convergence_test(testCase,slope,err,tol_error)
      % Decide whether a convergence study passes.
      % SLOPE contains one fitted slope per column of ERR (errors, one row
      % per perturbation amplitude). A column passes if its slope is at
      % least testCase.convergence - testCase.tolConvergence, or if at most
      % one of its errors exceeds TOL_ERROR (default testCase.tolAccept).
      % TOL_ERROR may be a scalar or a column with one entry per row of ERR.
      % Returns true only if every column passes.
      if nargin < 4
        tol_error = testCase.tolAccept;
      end

      slope_expect = testCase.convergence;
      tol_slope = testCase.tolConvergence;

      accept = slope(:) >= slope_expect - tol_slope;         % If slope is larger than expected, pass
      accept = accept | (sum(err(:,:) > tol_error,1) < 2).'; % If at most 1 point is above tol_error, pass
      acceptAll = all(accept);
    end

    %% Test error on analytical jacobian
    function test_analytical_jacobian(testCase,fun,x0,J0,jacfd_args,...
        iargout,names_out,iargin,names_in,epsvals)
      % Compare an analytical jacobian with finite difference estimates
      %
      % Performs comparison of the finite difference approximation (from
      % jacfd) of the jacobian of function FUN at point X0 (cell array of
      % input arguments) with the analytical value of the jacobian J0.
      % JACFD_ARGS is a cell array of extra arguments forwarded to jacfd.
      % The argument IARGOUT is used to select a subset of the return
      % arguments of the function FUN. The cell array of strings NAMES_OUT
      % is the list of names of these output arguments used for diagnostics
      % purposes.
      % Similarly IARGIN is used to select a subset of the input arguments
      % to be varied. The cell array of strings NAMES_IN is the list of
      % names of these input arguments. The cell array EPSVALS contains the
      % list of perturbation amplitudes for each argument: if a scalar is
      % given, a simple comparison is performed; if a vector is provided,
      % a convergence study is performed instead.
      % J0 should be a cell array of jacobians with as many rows as elements
      % in IARGOUT and as many columns as elements in IARGIN. A scalar entry
      % stands for an array filled with that value.

      % Ensure inactive domains aren't changed by finite difference perturbation
      iF0 = iargin(strcmp(names_in,'F0'));
      iF1 = iargin(strcmp(names_in,'F1'));
      if ~isempty(iF0) && ~isempty(iF1)
        nD = numel(x0{iF0});
        if nD>1
          fun = @(varargin) fun_fixed_inactive_domains(fun,x0{iF0},x0{iF1},iF0,iF1,varargin{:});
        end
      end

      % Evaluate F
      nargoutF = max(iargout);
      F0 = cell(1,nargoutF);
      [F0{:}] = fun(x0{:});
      F0 = F0(iargout);

      % Loop over input arguments
      nin = numel(iargin);
      nout = numel(iargout);
      for iin = 1:nin
        epsval = epsvals{iin};

        if isscalar(epsval)
          % Standard comparison between analytical and numerical jacobian
          
          % Evaluate F jacobian
          J0_fd = cell(1,nargoutF);
          [J0_fd{:}] = jacfd(fun,x0,jacfd_args{:},'iargin',iargin(iin),'epsval',epsval);
          J0_fd = J0_fd(iargout);

          % Compare values
          for iout = 1:nout
            if isempty(F0{iout}), continue; end
            % Compute best case (minimum) absolute error on jacobian
            % Since the FD estimate is based on the difference of close numbers
            % there is a lower limit on the absolute precision that the FD
            % estimate can reach, we will discard points for which the error is
            % close to that number
            J0_err_est = eps(abs(F0{iout}))./(2*epsval);
            % Avoid inf/nan
            J0_err_est(~isfinite(J0_err_est)) = 0;
            J0_fd{iout}(~isfinite(J0_fd{iout})) = 0;
            
            % Check analytical jacobian values
            if isscalar(J0{iout,iin})
              % this is done such that the verification is also valid if the
              % jacobian is an array of zeros but the out_derivs contains only
              % a single zero for simplicity; expand to the size of the
              % numerical jacobian, which is what verifyEqual compares against
              J0{iout,iin} = repmat(J0{iout,iin},size(J0_fd{iout}));
            end

            message = sprintf('Comparison of analytical and numerical jacobian of %s with respect to %s failed',names_out{iout},names_in{iin});
            abstol = max(100*max(abs(J0_err_est(:))),testCase.tol*max(abs(J0{iout,iin}(:))));
            testCase.verifyEqual(J0_fd{iout},full(J0{iout,iin}),'AbsTol',abstol,message);
          end
        else
          % Convergence test
          
          mesh = epsval;
          nmesh = numel(mesh);
          % Initialize error vectors
          J0_err = cell(1,nout);
          for iout = 1:nout
            J0_err{iout} = zeros(nmesh,1);
          end
        
          % Perform scan in perturbation amplitude
          for imesh = 1:nmesh
            % Evaluate F jacobian
            J0_fd = cell(1,nargoutF);
            [J0_fd{:}] = jacfd(fun,x0,jacfd_args{:},'iargin',iargin(iin),'epsval',mesh(imesh));
            J0_fd = J0_fd(iargout);

            % Compute infinite norm of error
            for iout = 1:nout
              % Skip outputs/jacobians that are not provided for this input
              if isempty(F0{iout}) || isempty(J0{iout,iin}), continue; end
              % Avoid inf/nan
              J0_fd{iout}(~isfinite(J0_fd{iout})) = 0;
              % Compute error (full() so that sparse jacobians give full errors)
              J0_err{iout}(imesh) = full(max(abs(J0_fd{iout}(:) - J0{iout,iin}(:))));
            end
          end

          % Compute convergence slope
          for iout = 1:nout
            if isempty(F0{iout}), continue; end

            % Compute best case (minimum) absolute error on jacobian
            % Since the FD estimate is based on the difference of close numbers
            % there is a lower limit on the absolute precision that the FD
            % estimate can reach, we will discard points for which the error is
            % close to that number
            J0_err_est = eps(max(abs(F0{iout}(:))))./(2*mesh(:));

            % Discard points with too low error.
            mask = J0_err{iout} > 100*J0_err_est;
            slope=testCase.compute_slope(mesh(mask),J0_err{iout}(mask));

            % Plot if requested
            if testCase.verbosity
              meq_jacobian_test.do_figure(mesh,...
                J0_err{iout},1,(iout-1)+(iin-1)*nout,...
                sprintf('Convergence of %s',names_out{iout}),...
                sprintf('\\Delta %s',names_in{iin}),'error',...
                testCase.fontsize);
            end

            % Verify slope
            message = sprintf('Convergence of numerical jacobian of %s with respect to %s failed',names_out{iout},names_in{iin});
            acceptAll = testCase.accept_convergence_test(slope,J0_err{iout},100*J0_err_est);
            testCase.verifyTrue(acceptAll,message);
          end
        end
      end
    end

  end

  methods(Static,Access = protected)
    %% regressions  
    % method for computing gradient slopes
    function [slope]=compute_slope(mesh,vectErrordFdF)
      % Fit vectErrordFdF ~ C*mesh^slope for each column of vectErrordFdF.
      % Do a power law fit, using a linear least-squares fit in log-log
      fit = [ones(size(mesh(:))),log10(mesh(:))]\log10(vectErrordFdF(:,:));
      % Reshape result to match dimensions 2-end of vectErrordFdF
      sz = size(vectErrordFdF);
      sz(end+1:3) = 1; % At least 3 dimensions
      slope = reshape(fit(2,:),sz(2:end));
    end
    
    %% extract indices to test
    % this method is used for extracting a subset of indices within each
    % domain in Y [vacuum (Opy==0) and plasma (Opy==[1,nD])] for testing. 
    % ntest indices within each domain of Y are chosen randomly with uniform distribution.
    function [testIdList] = extractIdtoTest(ntest,Opy,rngSeed,rngType)
      % Returns a column of linear indices into Opy.
      % If ntest>0, ntest indices are drawn (uniformly, possibly repeated)
      % for each value 0..max(Opy(:)); the random generator is seeded with
      % rngSeed (and rngType when given) beforehand. Otherwise all indices
      % 1..numel(Opy) are returned.
      if ntest > 0
        % select a subset of nodes for testing
        if nargin > 3
          rng(rngSeed,rngType);
        elseif nargin > 2
          rng(rngSeed); % seed was previously ignored when rngType was omitted
        end
        nD = max(Opy(:));
        testIdList = zeros(ntest*(nD+1),1);
        for ii = 0:nD
          domainId = find(Opy==ii);
          testIdList(ntest*ii+1:(ii+1)*ntest) = domainId(randi([1,numel(domainId)],ntest,1));
        end
      else
        % test all nodes
        testIdList = (1:numel(Opy)).';
      end
    end
  
    %% plots
    function do_figure(x,y,dim2,figoffset,titl,labelx,labely,fontsize)
      % Plot y(:,j,i) against x on log-log axes, one figure per i
      % (figure number figoffset+i), one curve per j.
      % A scalar dim2 is interpreted as [1,dim2].
      if(numel(dim2)==1), dim2=[1,dim2]; end
      for i=1:dim2(1)
        figure(figoffset+i);clf;
        set(gca,'fontsize',fontsize,'Xscale','log','Yscale','log');
        hold on;
        for j=1:dim2(2)
          plot(x,y(:,j,i),'linewidth',3);
        end
        hold off;
        title(sprintf('%s %i\n',titl,i-1));
        xlabel(labelx);
        ylabel(labely);
        grid on;
      end
      drawnow
    end    
  end
end

function varargout = fun_fixed_inactive_domains(fun,F0,F1,iF0,iF1,varargin)
% Generic helper function that will not vary F0/F1 for inactive domains
%
% Calls FUN(VARARGIN{:}) after restoring, in the input arguments at
% positions IF0 and IF1, the entries of the domains that are inactive at the
% reference point (F0 == F1) to their reference values F0/F1, so that
% finite difference perturbations do not change them.
% The reset is only applied when both argument positions are present.
% All outputs requested by the caller are forwarded from FUN.
if numel(varargin) >= max(iF0,iF1)
  maskD = (F0 == F1); % Inactive domains
  varargin{iF0}(maskD) = F0(maskD); % restore reference F0
  varargin{iF1}(maskD) = F1(maskD); % restore reference F1
end
% Call function
[varargout{1:nargout}] = fun(varargin{:});
end
