classdef meqagconfun_jacobian_test < meq_jacobian_test
  % Checks the analytical jacobians returned by the meqagconfun constraint
  % functions against finite difference estimates.
  %
  % [+MEQ MatlabEQuilibrium Toolbox+] Swiss Plasma Center EPFL Lausanne 2022. All rights reserved.
  
  properties
    verbosity = 0;
    tok = 'ana';
  end
  
  properties (ClassSetupParameter)
    % One class setup run per entry; the doublet cases are appended from
    % the list kept in meq_jacobian_test.
    shot = struct('circular',1,'diverted',2,'diverted2',3,'squashed',5,...
      meq_jacobian_test.doublets_under_test{:});
  end

  properties (TestParameter)
    % Only the baseline error check is enabled. The convergence variant is
    % left out because of conditioning problems: the residual is small but
    % its precision is poor.
    convergence_test = struct('false',false);
  end
  
  methods(TestClassSetup)
    function testSetup(testCase, shot)
      % Compute the FGS initial guess (FBT solution) for the requested
      % shot at time 0.
      testCase.get_fgs_LLX(shot,0);
    end
  end
  
  methods(Test,TestTags={'Jacobian'})
    function meqagconfun_finite_difference_test(testCase,convergence_test)
      % Compare the analytical derivatives of every meqagconfun with
      % finite difference estimates.
      %
      % The outer loop runs over the constraint functions listed by
      % meqagconfun() (ag, Ip, bp, ...), skipping 'fbtlegacy'. The inner
      % loop runs over the index of the constrained array entry (ag(1),
      % ag(2), ...).
      %
      % All calls go through the local wrapper meqagconfun_Co, which takes
      % the constraint target value Co as an explicit last argument so
      % that the derivative with respect to Co can be checked as well.

      % Struct with one function handle per constraint type
      confuns = meqagconfun();

      % Inputs shared by all constraint functions
      L  = testCase.L;
      LX = testCase.LX;
      LY = LX;
      
      % Basis function quantities, also needed by the ag functions
      [~,TpDg,ITpDg] = L.bfct(1,L.bfp,LY.Fx,LY.F0,LY.F1,LY.Opy,L.ry,L.iry);

      fun = @meqagconfun_Co;

      % Wrapper argument list. Entry 1 (constraint descriptor), entry
      % end-1 (constraint index) and entry end (target value) are
      % placeholders here and are overwritten inside the loops.
      x0 = {{confuns.Ip,1,'Ip',1},L,LX,LY.F0,LY.F1,LY.rA,LY.dr2FA,LY.dz2FA,LY.drzFA, ...
        LY.ag,LY.Fx,LY.Opy,TpDg,ITpDg,1,LX.IpD(1)};

      % Differentiated inputs, listed in the order in which meqagconfun
      % returns its derivatives: {name, position in the wrapper arguments}
      wrt = {'F0',4; 'F1',5; 'ag',10; 'Fx',11; 'TpDg',13; 'ITpDg',14; ...
        'rA',6; 'dr2FA',7; 'dz2FA',8; 'drzFA',9; 'Co',16};
      names_in = wrt(:,1).';
      iargin   = [wrt{:,2}];

      % Settings that are identical for every constraint and index
      names_out  = {func2str(fun)};
      iargout    = 1;
      jacfd_args = {'szF0',{1}};

      funnames = fieldnames(confuns);
      for k = 1:numel(funnames)
        name = funnames{k};
        if strcmp(name,'fbtlegacy')
          continue
        end

        % LXfield: name of the LX field holding the constraint target
        % ni:      number of admissible values of the constraint index
        if strcmp(name,'ag')
          LXfield = 'ag';
          ni = L.ng;
        elseif strcmp(name,'qA')
          LXfield = 'qA';
          ni = LX.nA;
        else
          if L.nD>1
            LXfield = [name,'D'];
          else
            LXfield = name;
          end
          ni = L.nD;
        end

        % Constraint descriptor: function handle and LX field name. The
        % 2nd and 4th elements are ignored.
        x0{1} = {confuns.(name),1,LXfield,1};

        for ii = 1:ni
          x0{end-1} = ii;
          x0{end}   = LX.(LXfield)(ii);

          % One perturbation scale per differentiated input, based on its
          % infinity norm; inputs that are identically zero get scale 1
          maxabs = cellfun(@(x) max(abs(x(:))),x0(iargin));
          dxscal = 10.^(log10(maxabs));
          dxscal(dxscal == 0) = 1;
        
          % Analytical derivatives: the first output (residual) is
          % dropped, one derivative per differentiated input is kept
          J0 = cell(1, numel(iargin));
          [~, J0{:}] = fun(x0{:});

          if convergence_test
            % Convergence check: scale the meshDeltaFx steps per input
            epsval = arrayfun(@(x) x*testCase.meshDeltaFx, dxscal, 'UniformOutput', false);
          else
            % Baseline error check: scale deltaFx per input
            epsval = num2cell(dxscal*testCase.deltaFx);
          end

          % Hand over to the generic comparison routine
          testCase.test_analytical_jacobian(...
            fun,x0,J0,jacfd_args,iargout,names_out,iargin,names_in,epsval);
        end
      end
    end
  end
end

function varargout = meqagconfun_Co(agconc,L,LX,F0,F1,rA,dr2FA,dz2FA,drzFA,ag,Fx,Opy,TpDg,ITpDg,ii,Co)
% Wrapper that exposes the constraint target value Co as an explicit input,
% so that the derivative of a meqagconfun with respect to Co can be
% estimated by finite differences.
%
% agconc is a cell array: agconc{1} is the constraint function handle and
% agconc{3} the name of the LX field that holds the target values. Co is
% written to entry ii of that field before the constraint function is
% called. All outputs requested by the caller are forwarded unchanged.
target_field = agconc{3};
LX.(target_field)(ii) = Co;
confun = agconc{1};
[varargout{1:nargout}] = confun(L,LX,F0,F1,rA,dr2FA,dz2FA,drzFA,ag,Fx,Opy,TpDg,ITpDg,ii);
end
