classdef fgelin_test < meq_test
  % Test of linearized fge with open loop response to different perturbations
  % Test fge NL against fge linear with fget and matlab lsim
  %
  % [+MEQ MatlabEQuilibrium Toolbox+] Swiss Plasma Center EPFL Lausanne 2022. All rights reserved.
  
  properties
    verbosity = 0; % >0 makes comparison plots
    dt = 1e-4;  % Simulation dt for fge
  end
  
  properties (TestParameter)
    shot = struct('stable_CDE',1             ,'stable_fixed_IP', 1,        'stable_fixed_IP_full_pp', 1        ,'doublet_fixed_IP',82);
    cde  = struct('stable_CDE','OhmTor_rigid','stable_fixed_IP','',        'stable_fixed_IP_full_pp',''        ,'doublet_fixed_IP','');
    tok  = struct('stable_CDE','ana'         ,'stable_fixed_IP','ana',     'stable_fixed_IP_full_pp','ana'     ,'doublet_fixed_IP','ana');
    t0   = struct('stable_CDE', 0            ,'stable_fixed_IP', 0,        'stable_fixed_IP_full_pp', 0        ,'doublet_fixed_IP',0 ); % Initial time of simulation
    tint = struct('stable_CDE',0.005         ,'stable_fixed_IP',0.005,     'stable_fixed_IP_full_pp',0.005     ,'doublet_fixed_IP',0.005); % Simulation interval
    pert = struct('stable_CDE','Va_step'     ,'stable_fixed_IP','all_step','stable_fixed_IP_full_pp','all_step','doublet_fixed_IP','all_step');
    iterq= struct('stable_CDE',0             ,'stable_fixed_IP',0,         'stable_fixed_IP_full_pp',20        ,'doublet_fixed_IP',0);
    %{
    % Long slow test for long time behavior of CDE in TCV
    shot = struct( 'stable_TCV_CDE', 54653 );
    cde =  struct( 'stable_TCV_CDE','OhmTor_rigid');
    tok =  struct( 'stable_TCV_CDE','tcv');
    t0 = struct( 'stable_TCV_CDE', 0.1);
    tint = struct( 'stable_TCV_CDE',0.04);
    pert = struct( 'stable_TCV_CDE','Va_step');
    %}
  end
  
  methods(Test,TestTags={'fgelin'},ParameterCombination='sequential')
    function fge_ss_lin(testCase,shot,cde,tok, t0, tint, pert, iterq)
      % Compare the non-linear fge evolution (fget) with the linearized model,
      % simulated both with the fget linear stepper ('lin',true) and with
      % MATLAB lsim on the discrete-time system returned by fgess, for the
      % input perturbation selected by pert.
      dt = testCase.dt; %#ok<*PROPLC,*PROP>
      t = (t0:dt:t0 + tint);
      nt = numel(t);
      
      %% Interpolation points for Fn,Brn,Bzn
      % Point 1: centre of the limiter bounding box; point 2: r=min(rl), z=0
      [L] = fge(tok,shot,t);
      rn = [mean([max(L.G.rl),min(L.G.rl)]);min(L.G.rl)];
      zn = [mean([max(L.G.zl),min(L.G.zl)]);0];
      
      %% get fge structure
      % L, LX structure for linear and non linear sim
      par = {tok,shot,t,'usepreconditioner',1,...
        'debug',0,'selu','e','nu',20,'tolF',1e-11,'vac',false, 'cde', cde, 'infct', @qintmex, 'rn', rn, 'zn', zn, 'iterq', iterq};
      
      if isequal(tok,'ana') && shot>=80 && shot<90
        % doublet parameters
        par = [par,...
          {'agcon',{{'Ip','bp'},{'Ip','bp'},{'Ip','bp'}},'bfct',@bfabmex,'bfp',[1 1],'idoublet',true,'agfitp',[true,false,true]}];
      end
      
      [L,LX] = fge(par{:});
      
      %% Set inputs to exact comparison
      % Hold every constraint (L.agconc(:,3)) constant at its first-sample value
      LX0 = meqxk(LX,1);
      for ii = 1:size(L.agconc,1)
        field = L.agconc{ii,3};
        LX.(field) = repmat(LX0.(field),1,nt);
      end
      
      %% State-space linearized system A,B,C,D
      L = fgel(L, meqxk(LX,1));
      sys = fgess(L); % Continuous time
      [sysd,~,x0,yo] = fgess(L,dt); % Discrete time
      
      %% compare poles
      polec = esort(pole(sys));
      poled = dsort(pole(sysd));
      
      % info
      fprintf('tok:%s, shot=%d, growthrate = %2.2f/s\n',tok,shot,polec(1)/(2*pi));
      % compare discrete-time pole to continuous one, using z ~ 1 + p*dt
      testCase.verifyEqual(polec(1),(poled(1)-1)/dt,'RelTol',0.03)
      
      %% Perturbation to inputs
      switch pert
        case 'Va_step'
          % Voltage step on a single coil
          offset = 1e-3; % Step time offset from start [s]
          stepsize = 10; % Step size
          icoil = 1; % Coil index
          tind = step_index(LX.t,offset);
          dVa = zeros(L.G.na,nt);
          dVa(icoil,tind:end) = stepsize;
          LX.Va = LX.Va + dVa;
        case 'Va_sin'
          % Sinusoidal (0.2 kHz) perturbation on all coils with a coil-dependent
          % phase; the first time sample is left unperturbed
          dVa = 1e-2  *sin(0.2e3*2*pi*LX.t+(1:L.G.na)'); dVa(:,1) = 0;
          LX.Va = LX.Va + dVa;
        case {'bp_step','Ip_step','qA_step'}
          % 1% step on the constraint named by the prefix of pert (bp, Ip or qA)
          fld = extractBefore(pert,'_step');
          offset = 1e-3; % Step time offset from start [s]
          alpha = 1.01;
          tind = step_index(LX.t,offset);
          LX.(fld)(1,tind:end) = alpha*LX.(fld)(1,tind:end);
        case 'all_step'
          % Staggered steps on Ip, bp, qA and the first coil voltage (offsets in [s])
          offset_Ip = 1e-3;  alpha_Ip = 1.01;
          offset_bp = 2e-3;  alpha_bp = 1.01;
          offset_qA = 3e-3;  alpha_qA = 1.01;
          offset_Va = 4e-3;  dVa = 10;
          tind = step_index(LX.t,offset_Ip);
          LX.Ip(1,tind:end) = alpha_Ip*LX.Ip(1,tind:end);
          tind = step_index(LX.t,offset_bp);
          LX.bp(1,tind:end) = alpha_bp*LX.bp(1,tind:end);
          tind = step_index(LX.t,offset_qA);
          LX.qA(1,tind:end) = alpha_qA*LX.qA(1,tind:end);
          tind = step_index(LX.t,offset_Va);
          LX.Va(1,tind:end) = LX.Va(1,tind:end) + dVa;
        case 'Ini'
          offset = 3e-3; % [s]
          dIni = -1e5;
          tind = step_index(LX.t,offset);
          LX.Ini(1,tind:end) = LX.Ini(1,tind:end) + dIni;
        otherwise
          % No perturbation: open-loop response to the unperturbed inputs
      end
      
      %% NL sim
      LYnl = fget(L,LX);
      
      %% Linear sim with fget stepper
      LYli = fget(L,LX, 'lin', true);
      
      %% Linear sim with matlab stepper
      % Set inputs for linear simulation (one row per time sample)
      u = zeros(nt,numel(sys.InputName));
      
      % Get constraints
      Co = LX2Co(L,LX);
      if L.icde, Co = Co(~contains(L.agconc(:,3),'Ip'),:); end
      
      % precompute dCodt
      dCodt = [diff(Co')/dt;zeros(1,size(Co,1))]; % Add the last value at 0 since diff remove dimension
      u(:,sys.InputGroup.Co) = Co';
      u(:,sys.InputGroup.dCodt) = dCodt;
      
      if L.icde, u(:,sys.InputGroup.Ini) = LX.Ini.'; end
      
      % fget time stepper for feedforward voltage assumes
      %  x[n] = A x[n-1] + B VA[n]
      % Shift input voltages for machine precision comparison
      u(:,sys.InputGroup.Va) = LX.Va(:,[2:end,end])';
      
      % Run linear simulation with matlab time stepper
      [y,T] = lsim(sysd,u,t,x0);
      Y = y' + yo;
      
      % Extract outputs from linear simulation
      Ostr = sys.OutputGroup; % Significantly faster to copy the structure
      outlabels = fieldnames(Ostr);
      for ii = outlabels.'
        myfield = ii{:};
        sz = meqsize(L,myfield);
        if prod(sz)==1, sz= 1; end % meqsize returns [1 1] for scalar quantities
        LYss.(myfield) = squeeze(reshape(Y(Ostr.(myfield),:),[sz,numel(T)])); % NOTE: some sizes might not be fully consistent with output of meqlpack
      end
      LYss.t = T; LYss.shot = shot;
      
      %% to enable comparison for multidomain cases, remove unused axis values for >nA
      assert(~any(diff(LYnl.nA)),'doesn''t work if number of domains changes during run')
      for ii = fieldnames(LYss)'
        myfield = ii{:};
        if endsWith(myfield,'A')
          LYli.(myfield) = LYli.(myfield)(1:LYnl.nA(1),:);
          LYss.(myfield) = LYss.(myfield)(1:LYnl.nA(1),:);
        end
      end
      
      %% Debugging plots
      if testCase.verbosity>0
        namef = {'rA';'zA';'rIp';'zIp'}; % Add arbitrary number of fields if desired
        testCase.plot_diff(namef,{LYnl,LYli,LYss})
        
        namef = {'Ip';'bp';'qA'}; % Add arbitrary number of fields if desired
        testCase.plot_diff(namef,{LYnl,LYli,LYss})
        
        namef = {'Ft';'FA';'FB';'ag';'li'}; % Add arbitrary number of fields if desired
        testCase.plot_diff(namef,{LYnl,LYli,LYss})
        
        namef = {'Ia';'Iu';'Bm';'Ff'};  % Add arbitrary number of fields if desired
        testCase.plot_diff(namef,{LYnl,LYli,LYss})
        
        namef = {'Fn';'Brn';'Bzn'};  % Add arbitrary number of fields if desired
        testCase.plot_diff(namef,{LYnl,LYli,LYss})
        
        testCase.plot_LCFS(t(1:100:end),L,{LYnl,LYli,LYss})
        
        if L.P.idoublet
          namef = {'IpD';'rIpD';'zIpD'};  % Add arbitrary number of fields if desired
          testCase.plot_diff(namef,{LYnl,LYli,LYss})
        end
        
        if L.P.iterq
          % flux-surface post-processing is only available from fget outputs
          namef = {'q95'; 'kappa';'epsilon';'delta'};  % Add arbitrary number of fields if desired
          testCase.plot_diff(namef,{LYnl,LYli})
        end
        
      end
      
      %% Check outputs
      % {name, tol} pairs: absolute tolerance
      tol = {'rA',1e-4,'zA',1e-4,'Ip',15,'bp',1e-5,'qA',2e-4,'Fn',1e-4,'Brn',1e-5,'Bzn',1e-5};
      verify_fields(testCase,LYli,LYnl,tol,'AbsTol','%s nl vs lin abs tolerance exceeded')
      verify_fields(testCase,LYss,LYnl,tol,'AbsTol','%s nl vs linss abs tolerance exceeded')
      
      % {name, tol} pairs: relative tolerance
      tol = {'FA', 1e-2,'FB',1e-2,'Ia',2e-3,'Bm',1e-3,'Ff',2e-3,'Ft',1e-2};
      verify_fields(testCase,LYli,LYnl,tol,'RelTol','%s nl vs lin rel tolerance exceeded')
      verify_fields(testCase,LYss,LYnl,tol,'RelTol','%s nl vs linss rel tolerance exceeded')
      
      % Check the flux surface post-processing
      if iterq
        tol = {'rq', 1e-4,'zq',1e-4,'q95',1e-2,'kappa', 1e-2, 'epsilon', 1e-2, 'delta',1e-2};
        verify_fields(testCase,LYli,LYnl,tol,'AbsTol','%s nl vs lin with nl post-processing abs tolerance exceeded')
      end
      
      % Domain specific quantities only for doublets
      if L.P.idoublet
        tol = {'rIpD',1,'zIpD',1,'IpD',1};
        verify_fields(testCase,LYli,LYnl,tol,'AbsTol','%s nl vs lin abs tolerance exceeded')
        verify_fields(testCase,LYss,LYnl,tol,'AbsTol','%s nl vs linss abs tolerance exceeded')
      end
      
    end
  end
  
  % Static methods for plotting.
  methods (Static)
    function plot_diff(namef,LYstrarr)
      % Debugging plot for time traces.
      % Top row: each field of namef for every simulation in LYstrarr.
      % Bottom row: difference of each simulation to the reference LYstrarr{1}.
      % LYstrarr may hold 2 or 3 simulations; labels follow the order
      % {NL, lin fget, lin lsim mat}.
      label = {'NL', 'lin fget', 'lin lsim mat'};
      markertype = {'k', 'r--', 'b:'};
      nLY = numel(LYstrarr);
      LYref = LYstrarr{1};
      m = numel(namef);
      figure
      for ind = 1:m
        myfield = namef{ind};
        
        % Time traces
        subplot(2,m,ind)
        hold on
        ax = gobjects(1,nLY);
        for jj = 1:nLY
          LY = LYstrarr{jj};
          plot(LY.t,LY.(myfield),markertype{jj});
          % one dummy handle per simulation so multi-row fields get a single legend entry
          ax(jj) = plot(NaN,NaN,markertype{jj});
        end
        title(sprintf('%s',myfield))
        legend(ax, label(1:nLY))
        
        % Errors w.r.t. the reference simulation
        if nLY > 1
          subplot(2,m,ind+m)
          hold on
          ax = gobjects(1,nLY-1);
          for jj = 2:nLY
            LY = LYstrarr{jj};
            plot(LY.t,LY.(myfield) - LYref.(myfield),markertype{jj});
            ax(jj-1) = plot(NaN,NaN,markertype{jj});
          end
          title(sprintf('%s - %s nl',myfield,myfield))
          legend(ax, label(2:nLY))
        end
      end
    end
    
    function plot_LCFS(t,L,LYstrarr)
      % Debugging plot for flux surfaces: one subplot per time in t showing the
      % psi = FB contour of each simulation in LYstrarr over the limiter (L.G.rl, L.G.zl).
      label = {'NL', 'lin fget', 'lin lsim mat'};
      color = {'b','r','g'};
      nLY = numel(LYstrarr);
      nt = numel(t);
      % up to 2 rows of subplots, enough columns to fit all nt times
      ncol = ceil(nt/2); nrow = min(nt,2);
      figure
      for tt = 1:nt
        subplot(nrow,ncol,tt)
        hold on
        ax = gobjects(1,nLY);
        for jj = 1:nLY
          LY = LYstrarr{jj};
          [~,tind]= min(abs(LY.t - t(tt)));
          LYk = meqxk(LY,tind);
          if L.P.idoublet
            contour(L.G.rx, L.G.zx, LYk.Fx, LYk.FB, color{jj})
          else
            % [v v] makes contour draw the single level v rather than v levels
            contour(L.G.rx, L.G.zx, LYk.Fx, LYk.FB*[1 1], color{jj})
          end
          ax(jj) = plot(NaN,NaN,color{jj});
        end
        plot(L.G.rl, L.G.zl,'k')
        title(sprintf('t = %.4f', t(tt)))
        legend(ax, label(1:nLY))
        axis equal
      end
    end
  end
end

function tind = step_index(tvec,offset)
% Index of the sample of tvec closest to tvec(1)+offset (start time of a step).
[~,tind] = min(abs(tvec-tvec(1)-offset));
end

function verify_fields(testCase,LYtest,LYref,tol,toltype,msgfmt)
% Verify LYtest.(name) against LYref.(name) for each {name, value} pair in tol,
% using the verifyEqual tolerance option toltype ('AbsTol' or 'RelTol').
% msgfmt is a sprintf format whose single %s receives the field name.
for ii = 1:2:numel(tol)
  name = tol{ii};
  testCase.verifyEqual(LYtest.(name),LYref.(name),toltype,tol{ii+1},sprintf(msgfmt,name))
end
end


