classdef rzp_fast_test < meq_test
  %% Test class for real time calculations
  % Test the working flow of the fast rzp routines (rzpfastbase, rzpfastm,
  % rzpfastA, rzpfastgr, rzpfasteig, rzpfastdzmax) with different input
  % parameters, test their mexification, and compare the eigenvalues and
  % the ideal instability margin calculated with these fast versions
  % against those calculated with the slow ones.
  %
  % [+MEQ MatlabEQuilibrium Toolbox+] Swiss Plasma Center EPFL Lausanne 2022. All rights reserved.
  
  properties
    verbosity = 0; % 1 to make verbose output, 2 to make plots

    selu     = 'e';
    iterq    = 20;
    agcon    = {'Ip','bp','li'};
    
    rtol     = 2.e-2;  %tolerance on the elements relative difference 
    atol     = 1.5e-8; %tolerance on the elements absolute difference
    n_eigen  = 5;      %number of eigenvalues for the fast-slow comparison
    
    gritermax  = 100;    %maximum number of iterations for the fast growth rate calculation
    grtol      = 1.e-4;  %tolerance for the fast growth rate calculation
    eigitermax = 100;    %maximum number of iterations for the fast eigenvector calculation
    eigtol     = 1.e-4;  %tolerance for the fast eigenvector calculation
    eigvinit   = [];     %initial guess of the eigenvector for the fast eigenvector calculation
    
    % Filled in by setup_parameters / setup_nu from the test parameters
    tok,shot,t,nu,insrc;
    % Parameter summary for error messages, rzp outputs (L, LX) and FPS data
    P_error,L,LX,FPS;
  end
  
  %% Properties
  properties (ClassSetupParameter)
    tok_            = {'ana','tcv'};  %tokamak type
    tnum            = {1,10};         %number of time points
    shot_time_index = {1,2};          %index for the shot number (different lists are used in anamak and tcv case)
  end
  
  properties (MethodSetupParameter)
    vessel_nu  = struct('nu_max', 'max', 'nu_30', 30   , 'fbt', 30   );
    insrc_type = struct('nu_max', 'liu', 'nu_30', 'liu', 'fbt', 'fbt');
  end
  
  %% Static methods used to shorten notation
  methods (Static)
    
    function [K,b,alpha,beta,c] = base_fast_compacted_input(L,LX)
      %Call rzpfastbase with the fields of L and LX it needs and return its first five outputs
      [K,b,alpha,beta,c] = rzpfastbase(LX.Ia,LX.Iu,LX.Iy,LX.Ip,LX.rIp,...
        L.N,L.F,L.Trd,L.Tzd,L.ny,L.np,L.ne,L.RBrye,L.RBzye,L.RBrye_fast,L.RBzye_fast); 
    end
    
    function m = m_fast_compacted_input(L,LX)
      %Compute the fast instability index: rzpfastbase followed by rzpfastm
      [K,b] = rzpfastbase(LX.Ia,LX.Iu,LX.Iy,LX.Ip,LX.rIp,...
        L.N,L.F,L.Trd,L.Tzd,L.ny,L.np,L.ne,L.RBrye,L.RBzye,L.RBrye_fast,L.RBzye_fast); 
      [m] = rzpfastm(K,b,L.N);
    end
    
    function [A,Ainv] = A_fast_compacted_input(L,LX)
      %Compute the fast A matrix and its inverse: rzpfastbase followed by rzpfastA
      [~,~,alpha,beta,~,Kinv] = rzpfastbase(LX.Ia,LX.Iu,LX.Iy,LX.Ip,LX.rIp,...
        L.N,L.F,L.Trd,L.Tzd,L.ny,L.np,L.ne,L.RBrye,L.RBzye,L.RBrye_fast,L.RBzye_fast);  
      [A,Ainv] = rzpfastA(alpha,beta,L.DD,Kinv,L.DDinv);
    end
    
    function [gamma,res,it] = gr_fast_compacted_input(L,LX,iter_max,tol)
      %Compute the fast growth rate: rzpfastbase followed by rzpfastgr
      % iter_max and tol are forwarded to rzpfastgr; res and it are returned as given by rzpfastgr.
      [~,~,alpha,~,c] = rzpfastbase(LX.Ia,LX.Iu,LX.Iy,LX.Ip,LX.rIp,...
        L.N,L.F,L.Trd,L.Tzd,L.ny,L.np,L.ne,L.RBrye,L.RBzye,L.RBrye_fast,L.RBzye_fast); 
      [gamma,res,it] = rzpfastgr(alpha,c,L.dd,L.dmax,iter_max,tol);
    end
    
    function [v,w,res,it] = eig_fast_compacted_input(L,LX,iter_max,tol,v_init)
      %Compute the fast eigenvectors: rzpfastbase, rzpfastgr, rzpfastA, then rzpfasteig
      % The growth rate uses L.P.gritermax / L.P.grtol; iter_max, tol and v_init are
      % forwarded to rzpfasteig.
      [~,~,alpha,beta,c] = rzpfastbase(LX.Ia,LX.Iu,LX.Iy,LX.Ip,LX.rIp,...
        L.N,L.F,L.Trd,L.Tzd,L.ny,L.np,L.ne,L.RBrye,L.RBzye,L.RBrye_fast,L.RBzye_fast);
      gamma = rzpfastgr(alpha,c,L.dd,L.dmax,L.P.gritermax,L.P.grtol);
      A = rzpfastA(alpha,beta,L.DD);
      [v,w,res,it] = rzpfasteig(gamma,A,iter_max,tol,v_init);
    end
    
    function [dzmax] = dzmax_fast_compacted_input(L,LX,FPS)
      %Compute the fast maximum controllable displacement with the full rzpfast* chain
      % FPS is the struct built in setup_nu (fields bea, Vsat, Iamax, Tps, na_sel are used);
      % iteration limits/tolerances and the eigenvector guess are read from L.P.
      [~,b,alpha,beta,c,Kinv] = rzpfastbase(LX.Ia,LX.Iu,LX.Iy,LX.Ip,LX.rIp,...
        L.N,L.F,L.Trd,L.Tzd,L.ny,L.np,L.ne,L.RBrye,L.RBzye,L.RBrye_fast,L.RBzye_fast);
      gamma = rzpfastgr(alpha,c,L.dd,L.dmax,L.P.gritermax,L.P.grtol);
      [A,Ainv] = rzpfastA(alpha,beta,L.DD,Kinv,L.DDinv);
      [v,w] = rzpfasteig(gamma,A,L.P.eigitermax,L.P.eigtol,L.P.eigvinit);
      [dzmax] = rzpfastdzmax(Kinv,b,alpha,beta,gamma,v,w,Ainv,LX.Va,LX.Ia,L.Le_fast,L.ne,FPS.bea,FPS.Vsat,FPS.Iamax,FPS.Tps,FPS.na_sel);
    end
  end
  
  %% Setup
  methods (TestClassSetup)
    function setup_parameters(testCase,tok_,tnum,shot_time_index)
      %% Setup the shot and time for the test according to the parameter combination
      % Skips (assumption failure) when the tokamak is not available, as reported by meq_test.check_tok.
      [ok,msg] = meq_test.check_tok(tok_);
      testCase.assumeTrue(ok,msg);

      testCase.tok = tok_;
      % Candidate shots and time window for each tokamak (anamak or TCV)
      if strcmp(tok_,'ana')
        shot_list = [1,2];
        t_range   = [0.005,0.015];
      else
        shot_list = [61400,63783];
        t_range   = [0.99,1];
      end
      testCase.shot = shot_list(shot_time_index);
      testCase.t    = linspace(t_range(1),t_range(2),tnum);
    end
  end
  
  methods (TestMethodSetup, ParameterCombination='sequential')
    function setup_nu(testCase,vessel_nu,insrc_type)
      %% Setup nu, the rzp structures (L, LX) and the FPS data for the test
      testCase.insrc = insrc_type;
      if strcmp(vessel_nu,'max')
        % 'max' selects nu = 200 for ana and nu = 256 otherwise
        if strcmp(testCase.tok,'ana')
          testCase.nu = 200;
        else
          testCase.nu = 256;
        end
      else
        testCase.nu = vessel_nu;
      end
      
      P_args = {'insrc',testCase.insrc,'selu',testCase.selu,...
        'agcon',testCase.agcon,'debug',testCase.verbosity,'nu',testCase.nu};
      % Summary "insrc: <insrc>, selu: <selu>, nu: <nu>" used in error messages
      testCase.P_error = sprintf('%s: %s, %s: %s, %s: %d',P_args{1:4},P_args{9:10});
      
      [testCase.L,testCase.LX] = rzp(testCase.tok,testCase.shot,testCase.t,P_args{:},...
        'gritermax',testCase.gritermax,'grtol',testCase.grtol,...
        'eigitermax',testCase.eigitermax,'eigtol',testCase.eigtol,...
        'eigvinit',testCase.eigvinit,'fastcoil',true);
      
      % Circuits driven by the FPS are selected by name (L.G.dima): those
      % containing 'VS' for ana, 'G' otherwise. ana uses the same supply
      % parameters as TCV's FPS.
      if strcmp(testCase.tok,'ana')
        coil_pattern = 'VS';
      else
        coil_pattern = 'G';
      end
      is_fps_coil = contains(testCase.L.G.dima,coil_pattern);
      
      FPS = struct();
      FPS.bea    = testCase.L.Tee_fast'*[is_fps_coil; zeros(testCase.L.G.nu,1)];
      FPS.Tps    = 1.e-4; %[s]
      FPS.Vsat   = 280;   %[V] 566 if in series, but 280 if in anti-series
      FPS.Iamax  = 1600;  %[A] 2000 if in series, but 1600 if in anti-series
      FPS.na_sel = find(is_fps_coil);
      FPS.La     = testCase.L.Mee(FPS.na_sel,FPS.na_sel);
      testCase.FPS = FPS;
    end
  end
  
  %% Tests
  methods (Test, TestTags = {'rzp'}, ParameterCombination='sequential')   
    function code_test_parameter(testCase)    
      %% Test that all rzpfast* functions run without error at every time slice
      % Each row: tested function name (for error messages), description of the
      % calculation, and a handle evaluating it on one time slice of LX.
      L   = testCase.L;
      FPS = testCase.FPS;
      checks = {
        'rzpfastbase',  'fast base change calculation',                           @(LXk)testCase.base_fast_compacted_input(L,LXk)
        'rzpfastm',     'fast instability index calculation',                     @(LXk)testCase.m_fast_compacted_input(L,LXk)
        'rzpfastA',     'fast A calculation',                                     @(LXk)testCase.A_fast_compacted_input(L,LXk)
        'rzpfastgr',    'fast growth-rate calculation',                           @(LXk)testCase.gr_fast_compacted_input(L,LXk,testCase.gritermax,testCase.grtol)
        'rzpfasteig',   'fast eigenvalue calculation',                            @(LXk)testCase.eig_fast_compacted_input(L,LXk,testCase.eigitermax,testCase.eigtol,testCase.eigvinit)
        'rzpfastdzmax', 'fast maximum controllable displacement calculation',     @(LXk)testCase.dzmax_fast_compacted_input(L,LXk,FPS)
        };
      
      for ic = 1:size(checks,1)
        [fun_name,description,fun] = checks{ic,:};
        case_string = sprintf('%s with %s_%u_%.5g',description,testCase.tok,testCase.shot,testCase.t(1));
        try
          for ii = 1:numel(testCase.t)
            fun(meqxk(testCase.LX,ii));
          end
        catch ME
          % Report which function and parameters failed, then propagate the original error
          fprintf('\n %s error caused in %s in %s by parameters %s \n',...
            ME.message,case_string,fun_name,testCase.P_error);
          rethrow(ME);
        end
      end
    end
    
    function mexify_test(testCase)    
      %% Test that rzpfastmexify (which mexifies all rzpfast* files) runs without error
      % Only run for the single time slice, nu = 30 parameter combinations.
      if numel(testCase.t) == 1 && testCase.nu == 30
        temp_dir  = tempname;
        store_dir = fullfile(temp_dir,'meq','codegen');
        out_dir   = store_dir;
        % Create the temporary directory up front and register its removal as a
        % teardown, so it is cleaned up even when the mexification fails.
        mkdir(temp_dir);
        testCase.addTeardown(@rmdir,temp_dir,'s');
        
        case_string = sprintf('fast growth rate mexification with %s_%u_%.5g',testCase.tok,testCase.shot,testCase.t(1));
        try
          rzpfastmexify(testCase.L,testCase.FPS,store_dir,out_dir);  
        catch ME
          fprintf('\n %s error caused in %s in rzpfastmexify by parameters %s \n',...
            ME.message,case_string,testCase.P_error);
          rethrow(ME);
        end
      end
    end
    
    function margin_test(testCase)
      %% Test fast vs slow ideal instability margin calculation
      % Only run for the single time slice, nu = 30 parameter combinations.
      if testCase.nu == 30 && numel(testCase.t) == 1
        case_string = sprintf('fast ideal instability calculation with %s_%u_%.5g',testCase.tok,testCase.shot,testCase.t(1));
        
        % Reference eigenvalue deciding whether the equilibrium is unstable
        eigen = fgeeig(testCase.L,'LX',testCase.LX); 

        m_fast = testCase.m_fast_compacted_input(testCase.L,testCase.LX);
        
        % Slow margin: largest real part of the eigenvalues of -S/Mee,
        % with S = Mee - b*inv(K)*b' built from the rzpFlinfast derivatives
        [~,dFprzdag,~,dFprzdIe] = rzpFlinfast(testCase.L,testCase.LX);
        K = dFprzdag(2:3,2:3); % K = [dFrdRc dFrdZc; dFzdRc dFzdZc];
        b = dFprzdIe(2:3,:)';  % b = [dFrdIe; dFzdIe]';
        S = testCase.L.Mee-b*(K\b');
        Am = -S/testCase.L.Mee;
        m_slow = max(real(eig(Am)));
        
        if eigen > 0
          testCase.verifyTrue( m_fast > 0  && m_slow > 0 ,sprintf('Ideal instability margin less than 0 in %s by paramters %s',case_string,testCase.P_error));
          testCase.verifyEqual( m_slow, m_fast,...
            'RelTol', testCase.rtol,sprintf('Eigen fast vs slow tolerance exceeded in %s by paramters %s',case_string,testCase.P_error));
        else
          testCase.verifyTrue( m_fast <= 0 && m_slow <= 0 ,sprintf('Ideal instability margin more than 0 in %s by paramters %s',case_string,testCase.P_error));
        end

      end
    end

    function eigen_test(testCase)
      %% Test fast vs slow eigenvalue calculation
      % Compares the first n_eigen eigenvalues (fast A vs fgeeig) and the
      % leading growth rate from rzpfastgr, only when the slow leading
      % eigenvalue is positive. Only run for the single time slice, nu = 30
      % parameter combinations.
      if testCase.nu == 30 && numel(testCase.t) == 1
        case_string = sprintf('fast eig calculation with %s_%u_%.5g',testCase.tok,testCase.shot,testCase.t(1));

        %Fast eigenvalue computation through fast A calculation
        A = testCase.A_fast_compacted_input(testCase.L,testCase.LX);
        eig_fast = esort(eig(A));
        
        %Ultrafast eigenvalue computation without A calculation
        eig_uf = testCase.gr_fast_compacted_input(testCase.L,testCase.LX,testCase.gritermax,testCase.grtol);   
         
        %Slow eigenvalue computation through fgeeig (as many modes as compared below)
        eig_slow = fgeeig(testCase.L,'LX',testCase.LX,'nmodes',testCase.n_eigen);
        
        %Eigenvalue comparison
        if eig_slow(1) > 0 
          testCase.verifyEqual( eig_fast(1:testCase.n_eigen)', eig_slow(1:testCase.n_eigen),...
            'RelTol', testCase.rtol,sprintf('Eigen fast vs slow tolerance exceeded in %s by paramters %s',case_string,testCase.P_error));
          % eig_uf is cast to double before comparing with the double-valued references
          testCase.verifyEqual( double(eig_uf), eig_slow(1),...
            'RelTol', testCase.rtol,sprintf('Eigen ufast vs slow tolerance exceeded in %s by paramters %s',case_string,testCase.P_error));
          testCase.verifyEqual( double(eig_uf), eig_fast(1),...
            'RelTol', testCase.rtol,sprintf('Eigen ufast vs fast tolerance exceeded in %s by paramters %s',case_string,testCase.P_error));
        end
      end
    end
    
  end
end
