classdef bf_jacobian_test < meq_jacobian_test
  % test the analytical jacobian of the basis function sets
  %
  % Concrete subclasses supply the abstract parameters `shot` (class setup)
  % and `bfct`/`bfp` (method setup, combined sequentially). For each basis
  % function mode tested below (1, 2, 3, 5, 6, 7 and 91), the analytical
  % derivatives returned by the matching derivative mode (mode+10, or 92 for
  % mode 91) are compared with finite-difference jacobians by the inherited
  % test_analytical_jacobian helper, either at the baseline step (deltaFx) or
  % over the steps in meshDeltaFx when convergence_test is true.
  %
  % [+MEQ MatlabEQuilibrium Toolbox+]

  %    Copyright 2022-2025 Swiss Plasma Center EPFL
  %
  %   Licensed under the Apache License, Version 2.0 (the "License");
  %   you may not use this file except in compliance with the License.
  %   You may obtain a copy of the License at
  %
  %       http://www.apache.org/licenses/LICENSE-2.0
  %
  %   Unless required by applicable law or agreed to in writing, software
  %   distributed under the License is distributed on an "AS IS" BASIS,
  %   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  %   See the License for the specific language governing permissions and
  %   limitations under the License.

  properties
    verbosity = 0;
    tok = 'ana';
    bfct_; % basis function set name of the current method-setup parameter
    bfp_;  % basis function parameters of the current method-setup parameter
  end
    
  properties(ClassSetupParameter, Abstract)
    shot
  end
  
  properties (MethodSetupParameter, Abstract)
    bfct
    bfp
  end

  properties (TestParameter)
    convergence_test = struct('false',false,'true',true);
  end
  
  methods(TestClassSetup, ParameterCombination='sequential')
    function setup(testCase,shot)
      % Compute FGS initial guess (FBT solution)
      t = 0;
      testCase.get_fgs_LLX(shot,t);
    end
  end

  methods(TestMethodSetup, ParameterCombination='sequential')
    function bfsetup(testCase,bfct,bfp)
      % Store the current basis function set and its parameters
      testCase.bfct_ = bfct;
      testCase.bfp_  = bfp;
    end
  end
  
  methods(Test,TestTags={'Jacobian'})
    function bf_mode1_gradient(testCase,convergence_test)
      % Mode 1 outputs (Tyg, TpDg, ITpDg) w.r.t. Fx, F0 and F1

      % Check shot/bfct/bfp is appropriate
      [step,delta] = testCase.bf_fd_step(convergence_test);
      testCase.check_bfct(1,delta);

      % Prepare inputs
      L = testCase.L;
      LX = testCase.LX;
      Fx = LX.Fx;
      F0 = LX.F0;
      F1 = LX.F1;
      Opy = LX.Opy;

      % Prepare arguments
      fun = str2func(testCase.bfct_);
      x0 = {1,testCase.bfp_,Fx,F0,F1,Opy,L.ry,L.iry};

      % Get analytical gradients
      Tyg = fun(x0{:});
      [dTygdFy,dTygdF0,dTygdF1,dITygdF0,dITygdF1] = fun(11,x0{2:end});
      
      % Sizes
      ng = size(Tyg,2);
      ny = L.ny;
      nx = L.nx;
      nD = L.nD;

      % Convert to derivatives of Tyg, TpDg, ITpDg
      % dTygdFy is given on the ny points of the y grid; scatter it onto the
      % x grid through the L.lxy mask (only the "diagonal" entries are nonzero)
      dTygdFx   = zeros(ny,ng,nx);
      dTpDgdFx  = zeros(nD,ng,nx);
      dTpDgdF1  = zeros(nD,ng,nD);
      dTpDgdF0  = zeros(nD,ng,nD);
      dITpDgdFx = zeros(nD,ng,nx);
      dITpDgdF1 = zeros(nD,ng,nD);
      dITpDgdF0 = zeros(nD,ng,nD);
      iy2x = find(L.lxy);
      ind = sub2ind([ny,ng,nx],repmat((1:ny).',1,ng),repmat(1:ng,ny,1),repmat(iy2x(:),1,ng));
      dTygdFx(ind) = dTygdFy;
      % Per-domain quantities: restrict y-grid contributions to points with Opy == iD
      for iD = 1:L.nD
        mask = Opy(:) == iD;
        dTpDgdFx (iD,:,L.lxy) = (dTygdFy.*mask).';
        dTpDgdF0 (iD,:,    :) = sum(dTygdF0.*mask, 1);
        dTpDgdF1 (iD,:,    :) = sum(dTygdF1.*mask, 1);
        dITpDgdFx(iD,:,L.lxy) = (Tyg.*mask).';
        dITpDgdF0(iD,:,    :) = sum(dITygdF0.*mask, 1);
        dITpDgdF1(iD,:,    :) = sum(dITygdF1.*mask, 1);
      end

      % Check jacobians of all mode 1 outputs
      names_out   = {'Tyg','TpDg','ITpDg'};
      iargout = [1,2,3];
      names_in = {'Fx','F0','F1'};
      iargin = [3,4,5];
      J0 = {  dTygdFx,  dTygdF0,  dTygdF1; ...
             dTpDgdFx, dTpDgdF0, dTpDgdF1; ...
            dITpDgdFx,dITpDgdF0,dITpDgdF1};
      
      % Same step(s) for every input (convergence mesh or baseline step)
      epsval = repmat({step},1,numel(iargin));
      % Call generic function
      testCase.test_analytical_jacobian(...
        fun,x0,J0,{},iargout,names_out,iargin,names_in,epsval);
    end

    function bf_mode2_gradient(testCase,convergence_test)
      % Mode 2 outputs (gQDg, IgQDg) w.r.t. FQ, F0 and F1

      % Check shot/bfct/bfp is appropriate
      [step,delta] = testCase.bf_fd_step(convergence_test);
      testCase.check_bfct(2,delta);

      % Prepare inputs
      L = testCase.L;
      LX = testCase.LX;
      FQ = L.pQ.^2;
      F0 = LX.F0;
      F1 = LX.F1;

      % Remove FQ=1 for bffbt (singular point)
      if strcmp(testCase.bfct_,'bffbt'), FQ = FQ(1:end-1); end

      % Prepare arguments
      fun = str2func(testCase.bfct_);
      x0 = {2,testCase.bfp_,FQ,F0,F1};

      % Get analytical gradients
      [dgQdFQ_,dgQdF0,dgQdF1,dIgQdFQ_,dIgQdF0,dIgQdF1] = fun(12,x0{2:end});
      
      % Sizes
      nD = L.nD;
      nQ = numel(FQ);
      if nD==1
        ng = size(dgQdFQ_,2);
      else
        ng = size(dgQdFQ_,3);
      end

      % Include non-diagonal terms for jacobians with respect to FQ
      % (mode 12 only returns d(output at FQ(iQ))/dFQ(iQ); expand to a full
      % nQ x nD x ng x nQ tensor with zeros off the diagonal)
      dgQdFQ = zeros(nQ,nD,ng,nQ);
      dIgQdFQ = dgQdFQ;
      ind = sub2ind([nQ,nD*ng,nQ],repmat((1:nQ).',1,nD*ng),repmat(1:nD*ng,nQ,1),repmat((1:nQ).',1,nD*ng));
      dgQdFQ(ind)  =  dgQdFQ_;
      dIgQdFQ(ind) = dIgQdFQ_;
      if nD==1
        dgQdFQ  = reshape( dgQdFQ,[nQ,ng,nQ]);
        dIgQdFQ = reshape(dIgQdFQ,[nQ,ng,nQ]);
      end

      % Check jacobians of all mode 2 outputs
      names_out = {'gQDg','IgQDg'};
      iargout = [1,2];
      names_in = {'FQ','F0','F1'};
      iargin = [3,4,5];
      J0 = { dgQdFQ, dgQdF0, dgQdF1;
            dIgQdFQ,dIgQdF0,dIgQdF1};
      
      % Same step(s) for every input (convergence mesh or baseline step)
      epsval = repmat({step},1,numel(iargin));
      % Call generic function
      testCase.test_analytical_jacobian(...
        fun,x0,J0,{},iargout,names_out,iargin,names_in,epsval);
    end

    function bf_mode3_gradient(testCase,convergence_test)
      % Mode 3 outputs (aPpg, aTTpg, aPg, ahqTg) w.r.t. F0 and F1

      % Check shot/bfct/bfp is appropriate
      [step,delta] = testCase.bf_fd_step(convergence_test);
      testCase.check_bfct(3,delta);

      % Prepare inputs
      L = testCase.L;
      LX = testCase.LX;
      F0 = LX.F0;
      F1 = LX.F1;

      % Prepare arguments (fPg/fTg from mode 0, unit coefficients ag)
      fun = str2func(testCase.bfct_);
      [fPg,fTg] = fun(0,testCase.bfp_);
      ag = ones(numel(fPg),1);
      x0 = {3,testCase.bfp_,ag,F0,F1,fPg,fTg,L.idsx};

      % Get analytical gradients
      [daPpgdF0,daTTpgdF0,daPgdF0,dahqTgdF0,...
       daPpgdF1,daTTpgdF1,daPgdF1,dahqTgdF1] = fun(13,x0{2:end});
      
      % Sizes forwarded to the jacobian helper through the 'szF0' option
      ng = numel(fPg);
      szF0 = repmat({ng},1,8);

      % Check jacobians of all mode 3 outputs
      names_out = {'aPpg','aTTpg','aPg','ahqTg'};
      iargout = [1,2,3,4];
      names_in = {'F0','F1'};
      iargin = [4,5];
      J0 = { daPpgdF0, daPpgdF1;
            daTTpgdF0,daTTpgdF1;
              daPgdF0,  daPgdF1;
            dahqTgdF0,dahqTgdF1};
          
      jacfd_args = {'szF0',szF0};
      
      % Same step(s) for every input (convergence mesh or baseline step)
      epsval = repmat({step},1,numel(iargin));
      % Call generic function
      testCase.test_analytical_jacobian(...
        fun,x0,J0,jacfd_args,iargout,names_out,iargin,names_in,epsval);
    end

    function bf_mode5_gradient(testCase,convergence_test)
      % Mode 5 outputs (g0g, Ig0g) w.r.t. F0 and F1

      % Check shot/bfct/bfp is appropriate
      [step,delta] = testCase.bf_fd_step(convergence_test);
      testCase.check_bfct(5,delta);

      % Prepare inputs
      LX = testCase.LX;
      F0 = LX.F0;
      F1 = LX.F1;

      % Prepare arguments
      fun = str2func(testCase.bfct_);
      x0 = {5,testCase.bfp_,[],F0,F1};

      % Get analytical gradients
      [dg0gdF0,dg0gdF1,dIg0gdF0,dIg0gdF1] = fun(15,x0{2:end});

      % Check jacobians of all mode 5 outputs
      names_out = {'g0g','Ig0g'};
      iargout = [1,2];
      names_in = {'F0','F1'};
      iargin = [4,5];
      J0 = { dg0gdF0, dg0gdF1; ...
            dIg0gdF0,dIg0gdF1};
      
      % Same step(s) for every input (convergence mesh or baseline step)
      epsval = repmat({step},1,numel(iargin));
      % Call generic function
      testCase.test_analytical_jacobian(...
        fun,x0,J0,{},iargout,names_out,iargin,names_in,epsval);
    end

    function bf_mode6_gradient(testCase,convergence_test)
      % First mode 6 output (Qqg) w.r.t. F0, F1, r0 and ir0

      % Check shot/bfct/bfp is appropriate
      [step,delta] = testCase.bf_fd_step(convergence_test);
      testCase.check_bfct(6,delta);

      % Prepare inputs
      L = testCase.L;
      LX = testCase.LX;
      F0 = LX.F0;
      F1 = LX.F1;
      % One radius per domain: LX.rA for the first LX.nA domains, LX.rB for
      % domains LX.nA+1..LX.nB, remaining entries left at 1
      r0 = ones(L.nD,1);
      r0(1:LX.nA) = LX.rA(1:LX.nA);
      r0(LX.nA+1:LX.nB) = LX.rB(1:LX.nB-LX.nA);

      % Prepare arguments
      fun = str2func(testCase.bfct_);
      x0 = {6,testCase.bfp_,[],F0,F1,r0,1./r0,L.idsx};

      % Get analytical gradients
      [dQqgdF0,dQqgdF1,dQqgdr0,dQqgdir0] = fun(16,x0{2:end});

      % Check jacobians of first mode 6 output only
      names_out = {'Qqg'};
      iargout = 1;
      names_in = {'F0','F1','r0','ir0'};
      iargin = [4,5,6,7];
      J0 = {dQqgdF0,dQqgdF1,dQqgdr0,dQqgdir0};
      
      % Same step(s) for every input (convergence mesh or baseline step)
      epsval = repmat({step},1,numel(iargin));
      % Call generic function
      testCase.test_analytical_jacobian(...
        fun,x0,J0,{},iargout,names_out,iargin,names_in,epsval);
    end

    function bf_mode7_gradient(testCase,convergence_test)
      % First mode 7 output (Qcg) w.r.t. F0 and F1

      % Check shot/bfct/bfp is appropriate
      [step,delta] = testCase.bf_fd_step(convergence_test);
      testCase.check_bfct(7,delta);

      % Prepare inputs
      LX = testCase.LX;
      F0 = LX.F0;
      F1 = LX.F1;

      % Prepare arguments
      fun = str2func(testCase.bfct_);
      x0 = {7,testCase.bfp_,[],F0,F1};

      % Get analytical gradients
      [dQcgdF0,dQcgdF1] = fun(17,x0{2:end});

      % Check jacobians of first mode 7 output only
      names_out = {'Qcg'};
      iargout = 1;
      names_in = {'F0','F1'};
      iargin = [4,5];
      J0 = {dQcgdF0,dQcgdF1};
      
      % Same step(s) for every input (convergence mesh or baseline step)
      epsval = repmat({step},1,numel(iargin));
      % Call generic function
      testCase.test_analytical_jacobian(...
        fun,x0,J0,{},iargout,names_out,iargin,names_in,epsval);
    end
    
    function bf_mode91_gradient(testCase,convergence_test)
      % Mode 91 outputs (G, IG) w.r.t. Fx, F0 and F1

      % Check shot/bfct/bfp is appropriate
      [step,delta] = testCase.bf_fd_step(convergence_test);
      testCase.check_bfct(91,delta);

      % Prepare inputs
      L  = testCase.L;
      LX = testCase.LX;
      Fx = LX.Fx;
      F0 = LX.F0;
      F1 = LX.F1;

      % Prepare arguments
      fun = str2func(testCase.bfct_);
      x0 = {91,testCase.bfp_,Fx,F0,F1};

      % Get analytical gradients
      [dGdFx_,dGdF0,dGdF1,dIGdFx_,dIGdF0,dIGdF1] = fun(92,x0{2:end});
      
      % Convert dGdFx, dIGdFx: mode 92 returns only the pointwise (diagonal)
      % derivatives; expand each basis function to a full nx x 1 x nx slice
      ng = size(dGdFx_,2);
      dGdFx = zeros(L.nx, ng, L.nx);
      dIGdFx = dGdFx;
      for ig=1:ng
         dGdFx(:, ig, :) = reshape(diag( dGdFx_(:, ig)), L.nx, 1, L.nx);
        dIGdFx(:, ig, :) = reshape(diag(dIGdFx_(:, ig)), L.nx, 1, L.nx);
      end

      % Check jacobians of both mode 91 outputs
      names_out = {'G','IG'};
      iargout = [1,2];
      names_in = {'Fx','F0','F1'};
      iargin = [3,4,5];
      J0 = {dGdFx,dGdF0,dGdF1;dIGdFx,dIGdF0,dIGdF1};
      
      % Same step(s) for every input (convergence mesh or baseline step)
      epsval = repmat({step},1,numel(iargin));
      % Call generic function
      testCase.test_analytical_jacobian(...
        fun,x0,J0,{},iargout,names_out,iargin,names_in,epsval);
    end
  end
  

  % Helper methods
  methods
    function check_bfct(testCase,mode,deltaFx)
      % Verify that the current shot/bfct/bfp combination is appropriate and
      % skip (assume) tests where the finite-difference perturbation deltaFx
      % could move an evaluation point across an interpolation knot.

      fun = testCase.bfct_;
      par = testCase.bfp_;
      
      isdoubletbf = ismember(fun, {'bfgenD', 'bfdoublet'});
      % Multi-domain cases (nD > 1) may only be tested with bfgenD or bfdoublet
      testCase.assertFalse(testCase.L.nD >  1 && ~isdoubletbf,...
        'Do not test single-domain bf for multi-domain cases');
      % Do not test bfgenD or bfdoublet for single domain cases
      testCase.assertFalse(testCase.L.nD == 1 &&  isdoubletbf,...
        'Test bfgenD/bfdoublet only for multi-domain cases');
      
      % Check deltaFx for bf3imex/bfsp
      % For bf sets with interpolation, we want to avoid that a grid point jumps
      % from one interpolation cell to the next due to the perturbation as this
      % would create a large error when computing the finite difference jacobian
      switch fun
        case 'bf3imex'
          knots = linspace(0,1,size(par.gNg,1));
        case 'bfsp'
          knots = [par.tauP; par.tauT];
        otherwise
          return;
      end

      % bf3imex/bfsp are single-domain sets (enforced above), so FA/FB are
      % expected to be scalars here
      switch mode
        case {1,91}
          FxA = testCase.LX.Fx(:) - testCase.LX.FA;
          FBA = testCase.LX.FB    - testCase.LX.FA;

          % Distance of every grid point to every knot (points x knots)
          testCase.assumeTrue(min(abs(FxA - knots(:).'*FBA),[],[1,2])>deltaFx,...
            sprintf('Skipping test for bfct=%s due to point being too close to interpolation knot', fun));
        case 2
          FQ = testCase.L.pQ.^2;
          % Distance of every FQ value to every knot (points x knots); force
          % column/row shapes so the result does not depend on pQ's orientation
          testCase.assumeTrue(min(abs(FQ(:) - knots(:).'),[],[1,2])>deltaFx,...
            sprintf('Skipping test for bfct=%s due to point being too close to interpolation knot', fun));
      end
    end
  end

  methods (Access=private)
    function [step,delta] = bf_fd_step(testCase,convergence_test)
      % Finite-difference step(s) for one test method.
      %   step  : testCase.meshDeltaFx for a convergence test, testCase.deltaFx
      %           otherwise; passed (per input) as epsval to test_analytical_jacobian
      %   delta : max(meshDeltaFx) for a convergence test, deltaFx otherwise;
      %           passed to check_bfct as the largest perturbation applied
      if convergence_test
        step  = testCase.meshDeltaFx;
        delta = max(step);
      else
        step  = testCase.deltaFx;
        delta = step;
      end
    end
  end
end
