classdef fgs_fge_jacobian_comp_test < meq_jacobian_test
  % Compares FGS and FGE semi-analytical Jacobian with finite-difference Jacobian
  %
  % [+MEQ MatlabEQuilibrium Toolbox+]

  %    Copyright 2022-2025 Swiss Plasma Center EPFL
  %
  %   Licensed under the Apache License, Version 2.0 (the "License");
  %   you may not use this file except in compliance with the License.
  %   You may obtain a copy of the License at
  %
  %       http://www.apache.org/licenses/LICENSE-2.0
  %
  %   Unless required by applicable law or agreed to in writing, software
  %   distributed under the License is distributed on an "AS IS" BASIS,
  %   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  %   See the License for the specific language governing permissions and
  %   limitations under the License.

  properties
    % NOTE(review): these properties are not read in this file; presumably they
    % configure the test setup performed by meq_jacobian_test — confirm there.
    verbosity = 0;
    tok = 'ana';
    t = 0;
  end

  methods
    function test_code_jacobian_vs_fd(testCase,L,LX,convergence_test)
      % Compare the analytical Jacobians returned by fgeF with finite differences.
      %
      % Inputs
      %   L, LX            - MEQ structures used to build the operator handle and the
      %                      evaluation point x0 = L.LX2x(LX) (presumably supplied by the
      %                      meq_jacobian_test parameterization — confirm)
      %   convergence_test - true : only Jx is checked, with step sizes testCase.meshDeltaFx
      %                      false: Jx and Ju (and Jxdot when L.nrD is nonzero) are checked
      %                             with step sizes testCase.deltaFx; in addition, rowmask
      %                             and the Jacobian function handle (dojacF) are checked
      %                             for consistency with Jx.
      %
      % The convergence test is skipped (assumption failure) for code 'fge'.

      testCase.assumeFalse(strcmp(testCase.code,'fge') && convergence_test,'Skipping convergence test for fgeF');

      % Get F and the evaluation point (x0, u0, xdot0)
      [F,x0,u0,xdot0] = get_F_x0_u0(testCase,L,LX);
      dojacu = ~convergence_test;
      dojacxdot = ~convergence_test && L.nrD;

      % Compute analytical jacobian
      opts = optsF('dojacx',true,'dojacu',dojacu,'dojacxdot',dojacxdot);
      [~,~,Jx,Ju,Jxdot,rowmask] = F(x0,u0,xdot0,opts);

      % Augment Jxdot if needed
      if dojacxdot
        % Static equations do not depend on xdot, so the Jxdot returned by the FGE operators only
        % contains the rows of the residuals of the dynamic equations (L.ind.irD). To compare with
        % the finite-difference estimate, restore the full nN x nN matrix (zero static rows).
        Jxdot_ = zeros(L.nN,L.nN);
        Jxdot_(L.ind.irD,:) = Jxdot;
      end

      if ~convergence_test
        mask = logical(rowmask);
        Jxdiag = diag(rowmask);
        % Check that Jx(mask,:) has only diagonal elements corresponding to values in rowmask
        testCase.verifyEqual(full(Jx(mask,:)),Jxdiag(mask,:),...
          'Content of rowmask and Jx do not agree')

        % Check Jacobian function handle
        %   Get jacobian function handle
        opts = optsF('dojacF',true);
        [~,~,JxF] = F(x0,u0,xdot0,opts);

        % Rebuild the full Jx matrix column by column by applying the handle to unit vectors
        Jx_ = zeros(L.nN,L.nN);
        dx_ = eye(L.nN);
        for ii = 1:L.nN
          Jx_(:,ii) = JxF(dx_(:,ii));
        end

        % Compare full jacobians and function handle values
        % full() keeps the tolerance a plain double scalar even when Jx is sparse
        abstol = 100*eps*full(max(abs(Jx(:))));
        testCase.verifyEqual(full(Jx),Jx_,'AbsTol',abstol,'Jx full matrix and function handle values do not match');
      end

      % Check derivatives with respect to x (and u, xdot when requested) against finite differences
      names_out = {[testCase.code,'F']};
      iargout = 1;
      if dojacu && dojacxdot
        names_in = {'x','u','xdot'};
        iargin = [1,2,3];
        J0 = {Jx,Ju,Jxdot_};
      elseif dojacu
        names_in = {'x','u'};
        iargin = [1,2];
        J0 = {Jx,Ju};
      else
        names_in = {'x'};
        iargin = 1;
        J0 = {Jx};
      end

      jacfd_args = {'szF0',{L.nN}};

      if convergence_test
        % Test convergence
        epsval = repmat({testCase.meshDeltaFx},1,numel(iargin));
      else
        % Test baseline error
        epsval = repmat({testCase.deltaFx},1,numel(iargin));
      end

      % Call generic function
      testCase.test_analytical_jacobian(...
        F,{x0,u0,xdot0,optsF},J0,jacfd_args,iargout,names_out,iargin,names_in,epsval);
    end

    function [F,x0,u0,xdot0] = get_F_x0_u0(testCase,L,LX)
      % Get operator function handle and x, u, xdot estimates based on LX
      %
      % Outputs
      %   F     - handle @(x,u,xdot,opts) calling fgeF with LX and LYp, where LYp is a copy of
      %           LX with t decreased by 1
      %   x0    - state vector L.LX2x(LX)
      %   u0    - input vector: [Va;Co;IniD] for 'fge', [Ia;Iu;Co] for 'fgs'
      %   xdot0 - zeros(L.nN,1); for 'fgs' the handle does not forward xdot to fgeF
      %
      % Raises an error if testCase.code is neither 'fgs' nor 'fge'.

      % Setup dt for evolutive equations
      LYp = LX;
      LYp.t = LX.t - 1; % fgeF with u as argument requires dt>0
      x0 = L.LX2x(LX);
      % Assigned for both codes so that all declared outputs are always set
      xdot0 = zeros(L.nN,1);
      switch testCase.code
        case 'fge'
          % u = [Va;Co;Ini]
          u0 = [LX.Va(:);LX2Co(L,LX);LX.IniD];
          F = @(x,u,xdot,opts) fgeF(x,L,LX,LYp,opts,[],u,xdot);
        case 'fgs'
          % u = [Ia;Iu;LX2Co(L,LX)]
          u0 = [LX.Ia(:);LX.Iu(:);LX2Co(L,LX)];
          F = @(x,u,xdot,opts) fgeF(x,L,LX,LYp,opts,[],u);
        otherwise
          error('Only supported options for code are ''fgs'' and ''fge''');
      end
    end
  end

end
