classdef (SharedTestFixtures={meq_simulink_fixture}) liusim_tests < meq_test
  % Tests for LIUQE Simulink version
  %
  % [+MEQ MatlabEQuilibrium Toolbox+]

  %    Copyright 2022-2025 Swiss Plasma Center EPFL
  %
  %   Licensed under the Apache License, Version 2.0 (the "License");
  %   you may not use this file except in compliance with the License.
  %   You may obtain a copy of the License at
  %
  %       http://www.apache.org/licenses/LICENSE-2.0
  %
  %   Unless required by applicable law or agreed to in writing, software
  %   distributed under the License is distributed on an "AS IS" BASIS,
  %   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  %   See the License for the specific language governing permissions and
  %   limitations under the License.

  % Notes: the TCV bfef and bf3i cases are tuned to resemble a bf3p run.
  % The lower bound for flux-surface contours is increased and iqR is set to
  % 1/2 to increase the chances of passing the tests. itert/iterq are also
  % increased for bf3i/bfef.
  
  properties
    verbosity = 0; % set nonzero to plot each compared field in slx_vs_m
  end
  
  % Test cases combined with ParameterCombination='sequential': every struct
  % below must define the same set of case names (TCV_bf3p, TCV_bfef, ...).
  properties (TestParameter)
%     tok  = struct('CREATE','CREATE','TCV','TCV');
%     t1   = struct('CREATE',625.0,'TCV',0.5);
%     t2   = struct('CREATE',625.1,'TCV',0.6);
%     dt   = struct('CREATE',1e-3,'TCV',1e-3);
%     shot = struct('CREATE',{{'~/testdata/meq/Case001_bellshape_curdriv_t625d000_t629d140_CNL4LIUQE.mat'}},...
%       'TCV',61400);

   % TCV_pdom is a case where domain identification used to fail for RTLIUQE
    tok  = struct('TCV_bf3p','TCV'    ,'TCV_bfef','TCV'     ,'TCV_bf3i','TCV'    ,'ana_bf3p','ana'    ,'ana_it2','ana'    ,'ana_iq2','ana'    ,'TCV_pdom','tcv'    );
    t1   = struct('TCV_bf3p',0.5      ,'TCV_bfef',0.5       ,'TCV_bf3i',0.5      ,'ana_bf3p',0        ,'ana_it2',0        ,'ana_iq2',0        ,'TCV_pdom',0.90     );
    t2   = struct('TCV_bf3p',0.6      ,'TCV_bfef',0.6       ,'TCV_bf3i',0.6      ,'ana_bf3p',0.1      ,'ana_it2',0.1      ,'ana_iq2',0.1      ,'TCV_pdom',0.92     );
    dt   = struct('TCV_bf3p',1e-3     ,'TCV_bfef',1e-3      ,'TCV_bf3i',1e-3     ,'ana_bf3p',1e-3     ,'ana_it2',1e-3     ,'ana_iq2',1e-3     ,'TCV_pdom',0.001    );
    shot = struct('TCV_bf3p',61400    ,'TCV_bfef',61400     ,'TCV_bf3i',61400    ,'ana_bf3p',2        ,'ana_it2',2        ,'ana_iq2',2        ,'TCV_pdom',78698    );
    bfct = struct('TCV_bf3p','bf3pmex','TCV_bfef','bfefmex' ,'TCV_bf3i','bf3imex','ana_bf3p','bf3pmex','ana_it2','bf3pmex','ana_iq2','bf3pmex','TCV_pdom','bf3pmex');
    bfp  = struct('TCV_bf3p',false    ,'TCV_bfef',[2 3]     ,'TCV_bf3i',bf3ip()  ,'ana_bf3p',false    ,'ana_it2',false    ,'ana_iq2',false    ,'TCV_pdom',false    );
    wreg = struct('TCV_bf3p',2e-6     ,'TCV_bfef',[2e-6;1;1],'TCV_bf3i',0        ,'ana_bf3p',5e-5     ,'ana_it2',5e-5     ,'ana_iq2',5e-5     ,'TCV_pdom',2e-6     );
   itert = struct('TCV_bf3p',1        ,'TCV_bfef',2         ,'TCV_bf3i',2        ,'ana_bf3p',1        ,'ana_it2',2        ,'ana_iq2',1        ,'TCV_pdom',1        );
   iterq = struct('TCV_bf3p',1        ,'TCV_bfef',2         ,'TCV_bf3i',2        ,'ana_bf3p',1        ,'ana_it2',2        ,'ana_iq2',2        ,'TCV_pdom',0        );
  end
  
  % Model variants exercised by the 'variant' test (pairwise combination).
  % The variable names must match the entries checked by
  % check_simulink_settings, which looks them up via inputname.
  properties(TestParameter)
    % Variants
    BFCTMODE  = struct('bf3p',1,'bf3i',2,'bfef',3);
    HFCTMODE  = struct('LS',1,'IPMsingle',2,'IPMdouble',3);
    IPMODE    = struct('MEAS',1,'TRAPEZE',2);
    BSLVMODE  = struct('single',1,'double',2);
    INFCTMODE = struct('none',1,'qintmex',2);
    AWMODE    = struct('OFF',0,'ON',1);
    LOCSMODE  = struct('OFF',0,'ON',1);
    LOCRMODE  = struct('OFF',0,'ON',1);
    ITERQMODE = struct('OFF',0,'ON',1);
  end
  
  methods (Test, TestTags={'Simulink'},ParameterCombination='sequential')
    
    function slx_vs_m(testCase,tok,shot,t1,t2,dt,bfct,bfp,wreg,itert,iterq)
      % Run LIUQE once in MATLAB (liut with liurtemu=true) and once in
      % Simulink (liutsim) on the same L/LX. Then compare the fields listed in
      % uptodate_fields: sizes must match, and numerical values must agree
      % within a relative OR absolute tolerance.
      
      [ok,msg] = meq_test.check_tok(tok,shot);
      testCase.assumeTrue(ok,msg);
      
      time = t1:dt:t2; nt = numel(time);
      
      % provide variability in size of aW, raR
      switch tok
        case 'ana', nFW = 2; naR = 2;
        otherwise,  nFW = 1; naR = 1;
      end
      %% Get L,LX
      [L,LX] = testCase.get_liu_L_LX(tok,shot,time,'itert',itert,'iterq',iterq,'slx',true,...
                                   'bfct',str2func(bfct),'bfp',bfp,...
                                   'iters',48,'noq',50,'bslvdprec',true,'wreg',wreg,...
                                   'wregadapt',15,'npq',18,'pq',[],'pql',0.4,...
                                   'iqR',1./[3/2,2],'raS',linspace(0,1,3),...
                                   'naR',naR,'nFW',nFW,...
                                   'infct',@qintmex,'rn',1,'zn',0.1);

      % run Matlab
      L.liurtemu = true;
      LYm = liut(L,LX,'rst',1,'slx',false);
      % cast so niter has the same class as the Simulink output
      % (presumably int32 on the Simulink side - TODO confirm)
      LYm.niter = int32(LYm.niter);
      
      % run Simulink
      L.liurtemu = false;
      LYs = liutsim(L,LX);
      
      %% Check outputs
      ignoredFields = {'shot','IpD','cycle','dr2FA','dz2FA','drzFA','PpQ',... % partial derivatives are too sensitive
                       'zIp','rIp','zIpD','rIpD','rYD','zYD','rY','zY',...    % TODO: Check why PpQ is needed...
                       'Bzrn','FR','FW','qA'};                                % TODO: FW dimensions [1 1 N] vs [1 N]
                                                                              % qA method is different in meqpost and meqpostq

      % The point where the loop breaks is different in MATLAB and Simulink,
      % so only these fields are compared.
      uptodate_fields = {'Iy','ag','dz','dzg','Bm','Ff','Ia','Is','Iu',...
        'rst','niter','err'};

      % A field passes if either tolerance is met
      relTol = 2e-2;
      absTol = 1e-3;

      fields  = fieldnames(LYs);
      mfields = fieldnames(LYm);

      for ifield = 1:numel(fields)
        myfield = fields{ifield};
        % skip ignored fields, fields absent from the MATLAB output,
        % and fields that are not expected to be up to date
        if ismember(myfield,ignoredFields) || ...
           ~ismember(myfield,mfields)      || ...
           ~ismember(myfield,uptodate_fields)
          continue
        end
        
        LYsval = LYs.(myfield);
        LYmval = LYm.(myfield);
        
        testCase.verifyEqual(size(LYsval),size(LYmval),sprintf('dimensions not equal: %s\n',myfield));
        dimssame = isequal(size(LYsval),size(LYmval));
        % NOTE(review): fields with more than two dimensions fail ismatrix
        % and are therefore only size-checked, not compared numerically.
        if dimssame && ismatrix(LYsval)
          % check numerical error (one column per time slice)
          ds = reshape(LYsval,[],nt);
          dm = reshape(LYmval,[],nt);
          absErr = norm(double(ds-dm));
          relErr = absErr/norm(double(ds));
          if testCase.verbosity
            %% plot
            
            clf;
            subplot(2,1,1)
            plot(LYs.t,ds,'b',LYm.t,dm,'r--')
            title(sprintf('%s: abs;%4.3e rel:%4.3e: .slx(b), .m(r--) \n',myfield,absErr,relErr));
            
            subplot(2,1,2)
            plot(LYs.t,sum(ds-dm,1)/size(ds,1));
            title('error');
            pause(0.1);
          end
        
          %% compare
          ok = relErr < relTol || absErr < absTol;
          testCase.verifyTrue(ok,sprintf('%s values did not match within tolerance',myfield))
        end
      end
    end
    
    function test_liusim_itert(testCase,tok,shot,t1,dt,bfct,bfp)
      % Check that one Simulink run with itert=2 gives the same ag as a run
      % with itert=1 where each measurement sample is repeated twice
      % (taking every second output of the repeated run).
      
      [ok,msg] = meq_test.check_tok(tok,shot);
      testCase.assumeTrue(ok,msg);
      
      time = [t1,t1+dt];
      
      %% Get L,LX
      [L,LX] = testCase.get_liu_L_LX(tok,shot,time,'itert',2,'iterq',1,...
                                     'bfct',str2func(bfct),'bfp',bfp,...
                                     'iters',20,'noq',50,'bslvdprec',true,'wreg',2e-6,...
                                     'wregadapt',15,'npq',18,'pq',[],...
                                     'iqR',1,'raS',linspace(0,1,3),...
                                     'slx',true);
                               
      % run Simulink with itert=2
      LY2 = liutsim(L,LX);
      
      % run Simulink with itert=1 and repeated indices
      L.P.itert = 1;
      LX = meqxk(LX,[1,1,2,2]);
      LY1 = liutsim(L,LX);
        
      %% compare
      testCase.verifyEqual(LY2.ag,LY1.ag(:,2:2:end),'RelTol',10*eps('single'),...
        'Simulation with itert=2 and itert=1 with repeated measurements did not match within tolerance')
      
    end
  end
  
  methods (Test, TestTags={'Simulink'}, ParameterCombination='pairwise')
    
    function variant(testCase,BFCTMODE,HFCTMODE,IPMODE,BSLVMODE,INFCTMODE,AWMODE,LOCSMODE,LOCRMODE,ITERQMODE)
      % Translate each *MODE parameter into the corresponding LIUQE option,
      % check that the Simulink simulation runs, then check that the model
      % settings match the requested variant.
      
      tok = 'ana'; %#ok<*PROPLC>
      t1 = 0.;
      t2 = 0.01;
      dt = 1e-3;
      shot = 2;
      wreg = 5e-5;
      itert = 1;
      
      tolh = 2e-3;
          
      [ok,msg] = meq_test.check_tok(tok,shot);
      testCase.assumeTrue(ok,msg);
      
      time = t1:dt:t2;
      
      switch BFCTMODE
        case 1, bfct = 'bf3pmex'; bfp = false;
        case 2, bfct = 'bf3imex'; bfp = bf3ip();
        case 3, bfct = 'bfefmex'; bfp = [1,2];
      end
      
      switch HFCTMODE
        case 1, iterh= 0; ipmhdprec = false;
        case 2, iterh=30; ipmhdprec = false;
        case 3, iterh=30; ipmhdprec = true;
      end
      
      switch IPMODE
        case 1, Ipmeas=1;
        case 2, Ipmeas=0;
      end
      
      switch BSLVMODE
        case 1, bslvdprec = false;
        case 2, bslvdprec = true;
      end
      
      switch INFCTMODE
        case 1, infct = [];
        case 2, infct = @qintmex;
      end
      
      switch AWMODE
        case 0, nFW = 0;
        case 1, nFW = 1;
      end
      
      switch LOCSMODE
        case 0, raS = [];
        case 1, raS = linspace(0,1,3);
      end
      
      switch LOCRMODE
        case 0, iqR = [];
        case 1, iqR = 1./[3/2,2];
      end
      
      switch ITERQMODE
        case 0, iterq = 0;
        case 1, iterq = 1;
      end
      
      %% Check that simulation runs
      [L,LX] = testCase.get_liu_L_LX(tok,shot,time,'itert',itert,'iterq',iterq,'slx',true,...
                                     'bfct',str2func(bfct),'bfp',bfp,...
                                     'iters',48,'noq',50,'bslvdprec',bslvdprec,'wreg',wreg,...
                                     'wregadapt',15,'npq',18,'pq',[],'iqR',iqR,'raS',raS,'pql',0.4,...
                                     'infct',infct,'rn',1,'zn',0.1,'nFW',nFW,...
                                     'iterh',iterh,'tolh',tolh,'ipmhdprec',ipmhdprec,...
                                     'Ipmeas',Ipmeas);

      % Run simulation
      LY = liutsim(L,LX);
       
      testCase.verifyTrue(~isempty(LY),'Simulation yielded empty LY');
      
      % Check settings (arguments must be passed as named variables,
      % since check_simulink_settings uses inputname to find each entry)
      check_simulink_settings(testCase,BFCTMODE,HFCTMODE,IPMODE,BSLVMODE,...
                                       INFCTMODE,AWMODE,LOCSMODE,LOCRMODE,...
                                       ITERQMODE);
    end

    function liusim_movie_test(~)
      % Test integration with meqmovie

      tok_ = 'ana';
      shot_ = 1;
      time = 0:1e-3:0.1;

      %% Get L,LX
      [L,LX] = liu(tok_,shot_,time,'slx',true,'iterq',1);

      % run Simulink
      LY = liutsim(L,LX);

      %% Plot using meqmovie
      meqmovie(L,LY,'decimate',10);
    end
  end
  
  
  methods
    function check_simulink_settings(testCase,varargin)
      % Checks that values in input correspond to stored value for model.
      %
      % Each argument after testCase must be passed as a named variable.
      % inputname(ii) supplies the entry name that is looked up either in
      % the 'Design Data' section of liutslx.sldd (outside GitLab CI, or on
      % MATLAB older than 9.6.0) or in the base workspace otherwise.
      % The retrieved value is then verified equal to the argument value.

      use_sldd = isempty(getenv('GITLAB_CI')) || verLessThan('matlab','9.6.0');
      
      if use_sldd
        sldd = Simulink.data.dictionary.open('liutslx.sldd');
        data = sldd.getSection('Design Data');
      end
      % nargin includes testCase, so argument ii corresponds to varargin{ii-1}
      for ii = 2:nargin
        name = inputname(ii);
        if use_sldd
          entry = data.getEntry(name);
          value = entry.getValue();
        else
          value = evalin('base',name);
        end
        testCase.verifyEqual(value,varargin{ii-1},sprintf('Mismatch for sldd entry %s',name));
      end
        
    end
  end

  methods(Static)
    function [L,LX] = get_liu_L_LX(tok,shot,time,varargin)
      % Build the LIUQE structure L (liu with the given options) and the
      % input data LX for the requested times, then pass LX through liux.
      % For 'tcv', LX comes from liuxtcv. For 'ana', it is derived from an
      % fbt run interpolated (nearest) to time. Any other tok is an error.

      % Get L
      L = liu(tok,shot,time,varargin{:});

      % Get LX
      switch lower(tok)
        case 'tcv'
          LX = liuxtcv(shot,time,L);
        case 'ana'
          [Lfbt,~,LYfbt] = fbt(tok,shot,[]);
          LYfbt = meqinterp(LYfbt,time,'nearest');
          LYfbt.t = time;
          LX = meqxconvert(Lfbt,LYfbt,L,true); % Return minimal LX
        otherwise
          error('Expected tok=''tcv'' or ''ana'' but got ''%s'' instead',tok);
      end
      LX = liux(L,LX);
    end
  end
end

%% Auxiliary functions
function bfp = bf3ip()
% BF3IP  Basis-function parameters for bf3imex that replicate bf3pmex.
%   BFP = BF3IP() returns a struct with fields gNg, IgNg, fPg and fTg.
%   The three basis functions are sampled on 41 equally spaced
%   normalized-flux points in [0,1]:
%     g1 = g2 = FN - 1,   g3 = (FN - 1).*FN
%   IgNg holds their primitives as computed by bfprmex. fPg/fTg assign
%   function 1 to the pressure term and functions 2 and 3 to the
%   toroidal-field term.
nPts   = 41;
psiN   = linspace(0,1,nPts)';
gLin   = psiN - 1;            % shared linear profile (FN-1)
gShape = [gLin, gLin, gLin.*psiN];

bfp.gNg  = gShape;
bfp.IgNg = bfprmex(gShape);
bfp.fPg  = [1;0;0];
bfp.fTg  = [0;1;1];
end
