classdef fgelin_test < meq_test
  % Test of linearized fge with open loop response to different perturbations
  % Test fge NL against fge linear with fget and matlab lsim
  %
  % [+MEQ MatlabEQuilibrium Toolbox+]

  %    Copyright 2022-2025 Swiss Plasma Center EPFL
  %
  %   Licensed under the Apache License, Version 2.0 (the "License");
  %   you may not use this file except in compliance with the License.
  %   You may obtain a copy of the License at
  %
  %       http://www.apache.org/licenses/LICENSE-2.0
  %
  %   Unless required by applicable law or agreed to in writing, software
  %   distributed under the License is distributed on an "AS IS" BASIS,
  %   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  %   See the License for the specific language governing permissions and
  %   limitations under the License.

  properties
    verbosity = 0; % >0 makes the comparison plots visible (they are always drawn)
    dt = 1e-4;  % Simulation dt for fge [s]
    L_    % FGE L structure (linearized around the first time slice), set in setup_fge
    LX_   % FGE input structure LX on the simulation time base, set in setup_fge
    pert_ % Perturbation type used by fge_ss_lin (see apply_perturbation)
  end

  properties (ClassSetupParameter)
    casename = {'stable_CDE','stable_fixed_IP','stable_fixed_ag','stable_fixed_IP_full_pp','doublet_fixed_IP','stable_TCV_CDE'}
  end

  properties(TestParameter)
    algoNL = {'all-nl','all-nl-Fx','Newton-GS'}
  end

  methods
    function [shot,cde,tok,t0,tint,pert,iterq,agcon,usecs] = setup_fge_sim(testCase,casename)
      % Return the simulation parameters associated with a test case name
      %
      %   [shot,cde,tok,t0,tint,pert,iterq,agcon,usecs] = setup_fge_sim(testCase,casename)
      %
      % Outputs:
      %   shot  - shot number passed to fge
      %   cde   - current diffusion equation option ('' for none)
      %   tok   - tokamak identifier passed to fge
      %   t0    - initial simulation time [s]
      %   tint  - simulation duration [s]
      %   pert  - perturbation type (see apply_perturbation)
      %   iterq - value of the 'iterq' fge parameter
      %   agcon - value of the 'agcon' fge parameter
      %   usecs - if true, add {'icsint',true,'ilim',3} to the fge parameters
      %
      % The test is filtered (assumption failure) if the tokamak is not
      % available according to meq_test.check_tok.

      % Defaults (ANAMAK circular without CDE)
      shot = 1;
      cde = '';
      tok = 'ana';
      t0 = 0;
      tint = 5e-3;
      pert = 'all_step';
      iterq = 0;
      agcon = {'Ip','bp','qA'};
      usecs = false;
      %
      switch casename
        case 'stable_CDE'
          % ANAMAK circular with CDE
          cde = 'OhmTor_rigid_0D';
          pert = 'Va_step';
          agcon = {'bp','qA'};
        case 'stable_fixed_IP'
          % ANAMAK circular without CDE
        case 'stable_fixed_ag'
          % ANAMAK circular without CDE and prescribed ag
          pert = 'Va_sin'; % No point in perturbing Ip/bp/qA
          agcon = {'ag','ag','ag'};
        case 'stable_fixed_IP_full_pp'
          % ANAMAK circular without CDE with iterq>0 (post-processing only)
          iterq = 20;
        case 'doublet_fixed_IP'
          % ANAMAK doublet without CDE
          shot = 82;
          usecs = true;
        case 'stable_TCV_CDE'
          % TCV limited with CDE
          shot = 54653;
          cde = 'OhmTor_rigid_0D';
          tok = 'tcv';
          t0 = 0.1;
          pert = 'Va_step';
          agcon = {'bp','qA'};
        otherwise
          % Do not silently run the defaults for a misspelled case name
          error('fgelin_test:unknownCase','Unknown test case ''%s''',casename);
      end
      % Check this TOK choice
      [ok,msg] = meq_test.check_tok(tok);
      testCase.assumeTrue(ok,msg);
    end
  end

  methods(TestClassSetup)
    function setup_fge(testCase,casename)
      % Setup FGE L structure with initial condition LX
      %
      % Builds the FGE L/LX structures for the requested case, linearizes
      % around the first time slice with fgel and stores L, LX and the
      % perturbation type on the test case instance.

      % Get simulation parameters
      [shot, cde, tok, t0, tint, pert, iterq, agcon, usecs] = testCase.setup_fge_sim(casename);

      dt_ = testCase.dt;
      t = (t0:dt_:t0 + tint);

      %% Compute also interpolation for Fn,Brn,Bzn in the center of the tokamak, and z=0,r=min(rl)
      % A first fge call is only needed to access the limiter geometry L.G.rl/zl
      [L] = fge(tok,shot,t);
      rn = [mean([max(L.G.rl),min(L.G.rl)]);min(L.G.rl)];
      zn = [mean([max(L.G.zl),min(L.G.zl)]);0];

      %% FGE parameters
      PP = {'algoNL','all-nl','selu','e','nu',20,'tolF',1e-10,'cde',cde,...
            'infct',@qintmex,'rn',rn,'zn',zn,'iterq',iterq,'agcon',agcon,...
            'ilu_droptol',1e-4};
      if usecs
        PP = [PP,{'icsint',true,'ilim',3}];
      end

      %% get fge structure
      % L, LX structure for linear and non linear sim
      [L,LX] = fge(tok,shot,t,PP{:});
      L.G.Vadelay(:) = 0; % Non-zero delay is incompatible with MATLAB's lsim since returned state-space models assume zero delay

      % Compute linearization
      L = fgel(L,meqxk(LX,1));

      %% Setup class instance
      testCase.L_ = L;
      testCase.LX_ = LX;
      testCase.pert_ = pert;
    end
  end

  methods(Test,TestTags={'fgelin'})
    function fge_ss_lin(testCase)
      % Compare FGE linear and non-linear simulations
      %
      % Three simulations are run with the same perturbed inputs:
      %   LYnl : non-linear fget simulation
      %   LYli : linearized fget simulation (fget(...,'lin',true))
      %   LYss : MATLAB lsim of the discrete-time state-space model from fgess
      % Both linear results are compared to LYnl field by field with an
      % absolute tolerance rtol*(typical value of the field).
      %
      % The main sources of disagreement even for relatively small
      % perturbations are the shift of the axes or boundary points from one
      % grid cell to the next and the change of the plasma domain
      % identification matrix Opy. While the effect of the first can be
      % somewhat mitigated using cs interpolation, the second one always
      % produces jumps in many quantities (such as rIp, zIp, Ft, Wk, li,
      % ag) while the time derivative remains close.
      % While one could design tests that do not yield such shifts relative
      % to the computational grid, some of the current tests yield a more
      % realistic comparison of linear and non-linear sims and the
      % tolerance has been increased to reflect this.

      % Get pre-computed L and LX structures for FGE
      L = testCase.L_;
      LX = testCase.LX_;

      % Get perturbation type
      pert = testCase.pert_;

      t = LX.t;
      nt = numel(t);
      dt_ = testCase.dt;

      %% Set inputs to exact comparison
      % Hold every constraint input listed in L.agconc(:,3) at its initial value
      LX0 = meqxk(LX,1);
      for ii = 1:size(L.agconc,1)
        field = L.agconc{ii,3};
        LX.(field) = repmat(LX0.(field),1,nt);
      end

      %% Perturbation to inputs
      LX = fgelin_test.apply_perturbation(L,LX,pert);

      %% NL sim
      LYnl = fget(L,LX);
      testCase.assertTrue(meq_test.check_convergence(L,LX.t,LYnl),'NL sim did not converge for all time slices, aborting...')

      %% State-space linearized system A,B,C,D
      L = fgel(L, LX0);
      sys = fgess(L); % Continuous time
      [sysd,~,x0,~,L.lin.x0dotL,y0] = fgess(L,dt_); % Discrete time

      %% compare poles
      polec = esort(pole(sys));
      poled = dsort(pole(sysd));

      % info
      fprintf('tok:%s, shot=%d, growthrate = %2.2f/s\n',L.P.tok,L.P.shot,polec(1)/(2*pi));
      % compare discrete-time pole to real one (first-order map z ~ 1 + s*dt)
      testCase.verifyEqual(polec(1),(poled(1)-1)/dt_,'RelTol',0.03)

      %% Linear sim with fget stepper
      LYli = fget(L,LX, 'lin', true);
      testCase.assertTrue(meq_test.check_convergence(L,LX.t,LYli),'fgelin sim did not converge for all time slices, aborting...')

      %% Linear sim with matlab stepper
      % Set inputs for linear simulation
      tliss = t;
      u = zeros(numel(tliss),numel(sys.InputName));

      % Get constraints
      Co = LX2Co(L,LX);

      % precompute dCodt
      dCodt = [diff(Co')/dt_;zeros(1,size(Co,1))]; % Add the last value at 0 since diff remove dimension
      u(:,sys.InputGroup.Co) = Co';
      u(:,sys.InputGroup.dCodt) = dCodt;

      if L.icde, u(:,sys.InputGroup.Ini) = LX.Ini; end

      % fget time stepper for feedforward voltage assumes
      %  x[n] = A x[n-1] + B VA[n]
      % Shift input voltages for machine precision comparison
      u(:,sys.InputGroup.Va) = LX.Va(:,[2:end,end])';

      % Run linear simulation with matlab time stepper
      [y,T] = lsim(sysd,u-L.lin.u0L',tliss,x0-L.lin.x0L);
      Y = y' + y0;
      % Correction due to the non-zero residual at the initial step
      xcor = zeros(numel(x0),1);
      for ii = 2:numel(T)
        xcor = sysd.A*xcor + L.lin.x0dotL;
        Y(:,ii) = Y(:,ii) + sysd.C*xcor;
      end

      % Extract outputs from linear simulation
      Ostr = sys.OutputGroup; % Significantly faster to copy the structure
      outlabels = fieldnames(Ostr);
      for ii = outlabels.'
        myfield = ii{:};
        sz = meqsize(L,myfield);
        if prod(sz)==1, sz= 1; end % meqsize returns [1 1] for scalar quantities
        LYss.(myfield) = squeeze(reshape(Y(Ostr.(myfield),:),[sz,numel(T)])); % NOTE: some sizes might not be fully consistent with output of meqlpack
      end
      LYss.t = T; LYss.shot = L.P.shot;

      %% to enable comparison for multidomain cases, remove unused axis values for >nA
      testCase.assertTrue(~any(diff(LYnl.nA)),'doesn''t work if number of domains changes during run')
      for ii = fieldnames(LYss)'
        myfield = ii{:};
        if endsWith(myfield,'A')
          LYli.(myfield) = LYli.(myfield)(1:LYnl.nA(1),:);
          LYss.(myfield) = LYss.(myfield)(1:LYnl.nA(1),:);
        end
      end

      %% Define tolerance and typical values for tested fields

      rtol = 1e-3; % Relative tolerance

      % {field, typical value} pairs; absolute tolerance is rtol*typical value
      typ = {'rA',L.P.r0,'zA',L.P.r0,'Ip',L.Ip0,'bp',1,'Wk',1e-7*pi*L.P.r0*L.Ip0.^2,...
             'qA',3,'Fn',L.Fx0,'Brn',L.Fx0/L.P.r0,'Bzn',L.Fx0/L.P.r0,'Ff',L.Fx0,...
             'FA',L.Fx0,'FB',L.Fx0,'Ia',abs(L.Ia0),'Bm',L.Fx0/L.P.r0,'Ft',1e-2,...
             'rIp',L.P.r0*L.Ip0,'zIp',L.P.r0*L.Ip0,'ag',abs(L.ag0),'li',1,...
             'Iu',abs(L.Iu0),...
             'rq',L.P.r0,'zq',L.P.r0,'q95',3,'kappa',1,'epsilon',1,'delta',1,...
             'rIpD',L.P.r0*L.Ip0,'zIpD',L.P.r0*L.Ip0,'IpD',L.Ip0};

      % These have typically larger errors (see comment in method description)
      excld_fields = {'rIp','zIp','ag','li','Iu'};
      % These are not always filled
      iterq_fields = {'rq','zq','epsilon','kappa','delta','q95'};
      multi_fields = {'rIpD','zIpD','IpD'};

      %% Debugging plots
      namef = {'rA';'zA';'rIp';'zIp'}; % Add arbitrary number of fields if desired
      testCase.plot_diff(namef,{LYnl,LYli,LYss},rtol,typ)

      namef = {'Ip';'bp';'qA'}; % Add arbitrary number of fields if desired
      testCase.plot_diff(namef,{LYnl,LYli,LYss},rtol,typ)

      namef = {'Ft';'FA';'FB';'ag';'li'}; % Add arbitrary number of fields if desired
      testCase.plot_diff(namef,{LYnl,LYli,LYss},rtol,typ)

      namef = {'Ia';'Iu';'Bm';'Ff'};  % Add arbitrary number of fields if desired
      testCase.plot_diff(namef,{LYnl,LYli,LYss},rtol,typ)

      namef = {'Fn';'Brn';'Bzn'};  % Add arbitrary number of fields if desired
      testCase.plot_diff(namef,{LYnl,LYli,LYss},rtol,typ)

      % A few equally spaced snapshots (a fixed stride could select only t(1)
      % for short simulations, where all simulations trivially coincide)
      itLCFS = unique(round(linspace(1,nt,4)));
      testCase.plot_LCFS(t(itLCFS),L,{LYnl,LYli,LYss})

      if L.P.idoublet
        namef = {'IpD';'rIpD';'zIpD'};  % Add arbitrary number of fields if desired
        testCase.plot_diff(namef,{LYnl,LYli,LYss},rtol,typ)
      end

      if L.P.iterq
        namef = {'q95'; 'kappa';'epsilon';'delta'};  % Add arbitrary number of fields if desired
        testCase.plot_diff(namef,{LYnl,LYli},rtol,typ)
      end

      %% Check outputs
      LYs = {LYli,LYss};
      names = {'lin','linss'};
      for jj = 1:numel(names)
        name = names{jj};
        LY = LYs{jj};
        % Loop over {field, typical value} pairs : absolute tolerance
        for ii=1:2:numel(typ)
          field = typ{ii};
          if any(strcmp(field,excld_fields)), continue; end
          if any(strcmp(field,iterq_fields)) && (~L.P.iterq || strcmp(name,'linss')), continue; end % Skip missing iterq fields for linss
          if any(strcmp(field,multi_fields)) && ~L.P.idoublet, continue; end % Skip missing doublet fields
          st = sprintf('%s nl vs %s abs tolerance exceeded',field,name);
          testCase.verifyEqual(LY.(field),LYnl.(field),'AbsTol',rtol*typ{ii+1},st)
        end
      end
    end

    function fgel_algoNL_test(testCase,algoNL)
      % Compare state-space matrices and initial condition for different algoNL values
      % Since the linear state is identical for all algoNL values, the
      % state-space systems should be identical as well.

      % Get pre-computed L and LX structures for FGE
      L = testCase.L_;
      LX = testCase.LX_;

      if strcmp(algoNL,L.P.algoNL)
        return; % Skip comparison test between identical models
      end
      % Update algoNL and rebuild the consolidated L structure
      L2 = L;
      L2.P.algoNL = algoNL;
      L2 = fgec(L2.P,L2.G);
      % Compute linearization
      L2 = fgel(L2,meqxk(LX,1));

      % Compare state-space matrices
      [sys ,~,~,~,x0dot ,~] = fgess(L );
      [sys2,~,~,~,x0dot2,~] = fgess(L2);
      tol = 1e-6; % Relative to the largest absolute entry of each reference matrix
      msg = 'Values of %s of ss model for algoNL=%s do not match with the reference algoNL=%s';
      algoRef = L.P.algoNL;
      testCase.verifyEqual(sys2.A,sys.A,'AbsTol',tol*max(abs(sys.A(:))),sprintf(msg,'A matrix',algoNL,algoRef));
      testCase.verifyEqual(sys2.B,sys.B,'AbsTol',tol*max(abs(sys.B(:))),sprintf(msg,'B matrix',algoNL,algoRef));
      testCase.verifyEqual(sys2.C,sys.C,'AbsTol',tol*max(abs(sys.C(:))),sprintf(msg,'C matrix',algoNL,algoRef));
      testCase.verifyEqual(sys2.D,sys.D,'AbsTol',tol*max(abs(sys.D(:))),sprintf(msg,'D matrix',algoNL,algoRef));
      testCase.verifyEqual(x0dot2,x0dot,'AbsTol',tol*max(abs(x0dot(:))),sprintf(msg,'x0dot',algoNL,algoRef));

    end

    function fgeFlin_test(testCase,algoNL)
      % Compare results of fgeFlin using analytical and numerical jacobians
      %
      % The jacobians Jx, Ju and Jxdot returned by fgeFlin for the first
      % time slice are computed with L.P.anajac left as set up (analytical)
      % and with L.P.anajac = false (numerical), then compared.

      % Get pre-computed L and LX structures for FGE
      L = testCase.L_;
      LX = testCase.LX_;

      % Update algoNL
      if ~strcmp(algoNL,L.P.algoNL)
        L.P.algoNL = algoNL;
        L = fgec(L.P,L.G);
      end
      % Get single slice
      LX = meqxk(LX,1);

      % Analytic jacobian
      [~,JxA,JuA,JxdotA] = fgeFlin(L,LX);

      % Numerical jacobian
      L.P.anajac = false;
      [~,JxN,JuN,JxdotN] = fgeFlin(L,LX);

      % Compare values
      rtol = 1e-5; % Relative to the largest absolute entry of each analytic jacobian
      msg = 'Values of output %s from fgeFlin do not match for anajac=0/1';
      testCase.verifyEqual(JxN   ,JxA   ,'AbsTol',rtol*max(abs(JxA   (:))),sprintf(msg,'Jx'));
      testCase.verifyEqual(JuN   ,JuA   ,'AbsTol',rtol*max(abs(JuA   (:))),sprintf(msg,'Ju'));
      testCase.verifyEqual(JxdotN,JxdotA,'AbsTol',rtol*max(abs(JxdotA(:))),sprintf(msg,'Jxdot'));
    end
  end

  % Helpers for input perturbations.
  methods (Static, Access = private)
    function LX = apply_perturbation(L,LX,pert)
      % Apply the perturbation PERT to the FGE input structure LX
      %
      %   LX = fgelin_test.apply_perturbation(L,LX,pert)
      %
      % Supported values of PERT:
      %   'Va_step'  - step of size 10 on coil 1 voltage, 1e-3 s after LX.t(1)
      %   'Va_sin'   - staggered sinusoidal perturbations on all coil voltages
      %   'bp_step', 'Ip_step', 'qA_step' - 1% multiplicative step, 1e-3 s after LX.t(1)
      %   'all_step' - successive steps on Ip (or IpD), bp (or bpD), qA and Va(1,:)
      %   'Ini'      - step of -1e5 on LX.Ini, 3e-3 s after LX.t(1)
      % Steps start at the sample closest to LX.t(1)+offset and persist until
      % the end of the time base. An unknown PERT raises an error.
      nt = numel(LX.t);
      switch pert
        case 'Va_step'
          offset = 1e-3; % Step onset after LX.t(1) [s]
          stepsize = 10; % Step size
          icoil = 1; % Coil index
          tind = fgelin_test.step_index(LX.t,offset);
          dVa = zeros(L.G.na,nt);
          dVa(icoil,tind:end) = stepsize;
          LX.Va = LX.Va + dVa;
        case 'Va_sin'
          % Sinusoidal perturbations on all coils, stagger start of perturbations
          dVa = 1e-2*L.Va0.*sin(0.2e3*2*pi*LX.t-(1:L.G.na)');
          dVa = dVa.*(0.2e3*2*pi*LX.t>(1:L.G.na)'); % Delay perturbations
          LX.Va = LX.Va + dVa;
        case 'bp_step'
          offset = 1e-3; % Step onset after LX.t(1) [s]
          alpha = 1.01;
          tind = fgelin_test.step_index(LX.t,offset);
          LX.bp(1,tind:end) = alpha*LX.bp(1,tind:end);
        case 'Ip_step'
          offset = 1e-3; % Step onset after LX.t(1) [s]
          alpha = 1.01;
          tind = fgelin_test.step_index(LX.t,offset);
          LX.Ip(1,tind:end) = alpha*LX.Ip(1,tind:end);
        case 'qA_step'
          offset = 1e-3; % Step onset after LX.t(1) [s]
          alpha = 1.01;
          tind = fgelin_test.step_index(LX.t,offset);
          LX.qA(1,tind:end) = alpha*LX.qA(1,tind:end);
        case 'all_step'
          offset_Ip = 1e-3;  alpha_Ip = 1.01;
          offset_bp = 2e-3;  alpha_bp = 1.01;
          offset_qA = 3e-3;  alpha_qA = 1.01;
          offset_Va = 4e-3;  dVa = 10;
          % Doublet cases carry the Ip/bp constraints in IpD/bpD instead
          tind = fgelin_test.step_index(LX.t,offset_Ip);
          if ~L.P.idoublet
            LX.Ip(1,tind:end) = alpha_Ip*LX.Ip(1,tind:end);
          else
            LX.IpD(1,tind:end) = alpha_Ip*LX.IpD(1,tind:end);
          end
          tind = fgelin_test.step_index(LX.t,offset_bp);
          if ~L.P.idoublet
            LX.bp(1,tind:end) = alpha_bp*LX.bp(1,tind:end);
          else
            LX.bpD(1,tind:end) = alpha_bp*LX.bpD(1,tind:end);
          end
          tind = fgelin_test.step_index(LX.t,offset_qA);
          LX.qA(1,tind:end) = alpha_qA*LX.qA(1,tind:end);
          tind = fgelin_test.step_index(LX.t,offset_Va);
          LX.Va(1,tind:end) = LX.Va(1,tind:end) + dVa;
        case 'Ini'
          offset = 3e-3; % [s]
          dIni = -1e5;
          tind = fgelin_test.step_index(LX.t,offset);
          LX.Ini(1,tind:end) = LX.Ini(1,tind:end) + dIni;
        otherwise
          % Fail loudly instead of silently comparing unperturbed simulations
          error('fgelin_test:unknownPert','Unknown perturbation type ''%s''',pert);
      end
    end

    function tind = step_index(t,offset)
      % Index of the sample of t closest to t(1)+offset
      [~,tind] = min(abs(t-t(1)-offset));
    end
  end

  % Methods for plotting.
  methods
    function plot_diff(testCase,namef,LYs,rtol,typ)
      % Debugging plot for time traces
      %
      % Top row: time traces of each field in namef for every structure in
      % LYs. Bottom row: difference of LYs{2:end} with respect to LYs{1},
      % with dashed lines at +/- rtol*max(typical value) when the field is
      % listed in typ ({field, typical value} pairs).
      label = {'NL', 'lin fget', 'lin lsim mat'};
      markertype = {'k', 'r--', 'b:'};
      n = 2;
      m = numel(namef);
      k = numel(LYs);
      % Only make figure visible if verbose
      visible = testCase.verbosity>0;
      fig = figure('Visible',visible);
      % Always close invisible figures, even if plotting errors
      if ~visible, clean = onCleanup(@() close(fig)); end
      for ii = 1:m
        name = namef{ii};
        ax = subplot(n,m,ii,'Parent',fig);
        hold(ax,'on');
        hs = cell(k,1); % Store plot handle of each LY
        for jj = 1:k
          LY = LYs{jj};
          h = plot(ax,LY.t,LY.(name),markertype{jj});
          hs{jj} = h(1);
        end
        % Title and legend
        title(ax,name);
        legend([hs{:}],label(1:k));

        % Plot errors
        ax = subplot(n,m,ii+m,'Parent',fig);
        hold(ax,'on');
        hs = cell(k,1);
        LYref = LYs{1};
        for jj = 2:k
          LY = LYs{jj};
          h = plot(ax,LY.t,LY.(name) - LYref.(name),markertype{jj});
          hs{jj} = h(1); % Pick only the first one for the legend
        end
        % Index of quantity in typ array
        kk = find(strcmp(name,typ(1:2:end)),1,'first');
        if ~isempty(kk)
          % Add horizontal lines marking tolerance on error
          plot(ax,ax.XLim,+[1,1]*rtol*max(typ{2*kk}),'--k','XLimInclude','off')
          plot(ax,ax.XLim,-[1,1]*rtol*max(typ{2*kk}),'--k','XLimInclude','off')
        end
        % Title and legend
        title(ax,sprintf('%s - %s nl',name,name ));
        legend([hs{2:end}],label(2:k));
      end
    end

    function plot_LCFS(testCase,t,L,LYs)
      % Debugging plot for flux surfaces
      %
      % For each time in t, one subplot shows the contour Fx = FB of every
      % structure in LYs (at its sample closest to that time) together with
      % the limiter contour L.G.rl/L.G.zl.
      label = {'NL', 'lin fget', 'lin lsim mat'};
      color = {'b','r','g'};
      nt = numel(t);  m = ceil(nt/2); n = 1 + (nt>m);
      k = numel(LYs);
      % Only make figure visible if verbose
      visible = testCase.verbosity>0;
      fig = figure('Visible',visible);
      % Always close invisible figures, even if plotting errors
      if ~visible, clean = onCleanup(@() close(fig)); end
      for  tt=1:nt
        ax = subplot(n,m,tt,'Parent',fig);
        hold(ax,'on');
        hs = cell(k,1); % Store plot handle of each LY
        for  jj = 1:k
          LY = LYs{jj};
          [~,tind]= min(abs(LY.t - t(tt)));
          LYk = meqxk(LY,tind);
          if L.P.idoublet
            v = LYk.FB;
          else
            v = LYk.FB*[1 1]; % contour needs two equal levels to draw a single level
          end
          [~,h] = contour(ax,L.G.rx, L.G.zx, LYk.Fx, v, color{jj});
          hs{jj} = h;
        end
        plot(ax,L.G.rl, L.G.zl,'k');
        title(ax,sprintf('t = %.4f', t(tt)));
        legend([hs{:}],label(1:k));
        axis(ax,'equal');
      end
    end
  end
end


