classdef liutcv_tests < meq_test
  % Test class for tests of LIUQE on TCV shot,time database
  %
  % Each test case runs liut on one (shot,time) slice of the database listed
  % in liutcvselsamp.dat, for a combination of the MethodSetupParameter values
  % below, and checks the fitted Ip and Ia against the LX (experiment) values.
  % Subclasses must define the abstract parameter ISHOT, for instance with
  % liutcv_tests.get_ishot(n).
  %
  % [+MEQ MatlabEQuilibrium Toolbox+] Swiss Plasma Center EPFL Lausanne 2022. All rights reserved.
  
  properties(MethodSetupParameter)
    bfct      = {'bfabmex','bf3pmex'} % basis function options (use names only, no function handles)
    itert     = struct('no',0); % others don't work now
    idml      = struct('yes',1); % only runs with dml (idml=1) are tested
    iterh     = struct('std_30',30); % number of iterations of nonlinear initial fitting
    robust    = struct('no',0); % fit measurements using robust fitting (requires Statistics and Machine Learning Toolbox)
    wreg      = struct('std_2Em6',2e-6);
    wregadapt = struct('std_15', 15);
    elomin    = struct('std_1p1', 1.1);
  end

  properties(MethodSetupParameter,Abstract)
    ishot % use subclasses to define this value
  end
  
  properties
    G % LIUQE geometry structure (fixed)
    L % LIUQE parameter structure (updated for each shot)
    LX % LIUQE Experiment data structure
    LXdb % database of LX
    datadir % location for test data
    paramstr % string of parameters for this test (used to reproduce failures)
    
    shotg = 50000; % reference shot for machine geometry
    datafile = 'liutcvselsampLX.mat'; % cached LXdb file, stored in datadir
    tol_Ip = 0.1; % relative error allowed in Ip fit
    tol_Ia = 1000; % absolute error allowed in Ia fit
    
    verbosity = 0; % if nonzero, re-run liut with 'debug' output when it warns
  end
  
  methods (Static)
    function x = get_ishot(nshots_sel)
      % Reproducible random selection of database slices.
      % Returns a struct whose field names encode shot and time
      % ('TCVsssss_tsttt', '.' -> 's', '-' -> 'm') and whose values are the
      % row indices of the selected slices in the database.
      % At most NSHOTS_SEL slices are selected (capped at the database size).
      [shots,times] = liutcv_tests.get_shots_times();
      nslices = numel(shots);
      nshots_sel = min(nshots_sel,nslices);
      
      s = rng(1); % fix the random generator, keeping the previous state
      restore_rng = onCleanup(@() rng(s)); %#ok<NASGU> reset rng on exit, also on error
      sel = randperm(nslices); % random selection
      % This is more efficient but results in a different selection of shots and slices.
      % sel = randperm(nslices,nshots_sel);
      sel = sort(sel(1:nshots_sel)); % keep the first nshots_sel, in database order
      
      x = struct();
      for ii = sel
        fname = sprintf('TCV%05d_%5.3f',shots(ii),times(ii));
        % make a valid field name ('.' and a possible '-' are not allowed)
        fname = strrep(strrep(fname,'.','s'),'-','m');
        x.(fname) = ii;
      end
    end
    
    function [shots,times] = get_shots_times()
      % Read the (shot,time) pairs of the test database as row vectors.
      dd = readtable(liutcv_tests.get_datfile());
      shots = dd.shot.';
      times = dd.time.';
    end
    
    function excludelist = get_excludelist()
      % Cell array {shot, reason} of shots that must not be tested.
      excludelist = {
        40042,'faulty magnetic probe';
        49823,'stray shot';
        54001,'TMP: early slices do not work with DML';
        57942,'faulty magnetic probe';
        };
      rangeshots = (50000:50606).';
      rangemsg = 'Faulty magnetics sensors in range 50000->50606';
      excludelist = [excludelist;
        num2cell(rangeshots), repmat({rangemsg},numel(rangeshots),1)];
    end
    
    function [isexcluded,msg] = is_excluded(shot)
      % True (with an explanatory message) if SHOT is on the exclusion list.
      excludelist = liutcv_tests.get_excludelist();
      iexcl = (shot == cell2mat(excludelist(:,1)));
      isexcluded = any(iexcl);
      if isexcluded
        % join the reasons: a shot may appear more than once in the list
        reasons = strjoin(excludelist(iexcl,2).','; ');
        msg = sprintf('Shot %d is on exclusion list due to %s',shot,reasons);
      else
        msg = '';
      end
    end
    
    function datfile = get_datfile()
      datfile = 'liutcvselsamp.dat'; % contains shot,time for database
    end
  end
  

  
  methods (TestClassSetup)
    
    function setEnvironment(testCase)
      % Save the warning states BEFORE changing them, so teardown restores
      % exactly what the caller had.
      backtrace_state = warning('query','backtrace');
      w = warning; % store warning state
      testCase.addTeardown(@() warning(w));
      testCase.addTeardown(@() warning(backtrace_state));
      
      warning('off','backtrace');
      warning('off','MATLAB:rankDeficientMatrix'); %turn off rank warnings
    end
    
    function setDataDir(testCase)
      % load from environment variable (from gitlab settings)
      % or set default
      env = getenv('TESTDATADIR');
      if isempty(env)
        testCase.datadir = fullfile('~','testdata','meq');
      else
        testCase.datadir = env;
      end
    end
    
    function loadG(testCase)
      % Geometry of the reference shot, shared by all test cases
      LL = liuqe(testCase.shotg,0);
      testCase.G = LL.G;
    end
    
    function loadLX(testCase)
      % Load the LX database from the cached .mat file if present;
      % otherwise build it with liu (from MDS) and cache it for later runs.
      fname_testdata = fullfile(testCase.datadir,testCase.datafile);
      if exist(fname_testdata,'file')
        tmp  = load(fname_testdata,'LXdb');
        testCase.LXdb = tmp.LXdb;
        return
      end
      
      fprintf('%s not found, loading test data from MDS...\n',fname_testdata);
      
      %% generate test LX database
      % get liuqe data for the database shots and times
      [shots,times] = liutcv_tests.get_shots_times();
      if numel(shots)>50
        fprintf('Loading data for %d time slices, this may take a while...\n',numel(shots));
      end
      [~,LXdb] = liu('TCV',shots,times); %#ok<*PROP>
      
      % save to test data directory for future faster access
      if ~exist(testCase.datadir,'dir')
        s = mkdir(testCase.datadir);
        assert(s,'could not create directory %s',testCase.datadir);
      end
      save(fname_testdata,'LXdb');
      fprintf('Saved LX data in %s for future use\n',fname_testdata);
      
      testCase.LXdb = LXdb;
    end
    
  end
  
  methods(TestMethodSetup)
    function set_L(testCase,ishot,bfct,itert,idml,robust,wreg,wregadapt,elomin,iterh) 
      %% sets this shot's L
      % liuqe parameter structure for this shot.
      % This part of liu.m is re-run for every shot, since the FE parameters
      % depend on the shot. If the parameters cannot be built, the test is
      % filtered (assumption failure) rather than failed.
      shot = testCase.LXdb.shot(ishot);
      try
        % Assemble name-value pairs for parameters
        % (bfp = [] uses the default bfp settings per basis function)
        plist = {'bfct','bfp','itert','idml','robust','wreg','wregadapt','elomin','iterh'};
        vlist = {str2func(bfct), [], itert, idml, robust, wreg, wregadapt, elomin, iterh};
        if iterh > 0
          % when using nonlinear solver for initial current:
          % tolh = 2e-3; NaN box limits (zup,zlp,rip,rop) set the box
          % automatically equal to the limiter; nelem = 10 allows more
          % elements in the current fit
          plist = [plist, {'nelem','tolh','zup','zlp','rip','rop'}];
          vlist = [vlist, {   10  , 2e-3 ,  NaN,  NaN,  NaN,  NaN }];
        end
        pvpairs = [plist;vlist];
        
        % attempt to get parameters for this shot
        P = liuptcv(shot,pvpairs{:});
      catch ME
        testCase.assumeTrue(false,sprintf('Invalid L.P in shot %d: %s',...
          shot,getReport(ME,'basic')));
      end
      
      % Generate string used to show how to reproduce test
      testCase.paramstr = pvpairs2str(pvpairs);
      
      P = liup(P);
      G = liug(testCase.G,P); %#ok<*PROPLC>
      testCase.L = liuc(P,G); % Final assembly of all parameters
    end
    
    function set_LX(testCase,ishot)
      %% set this shot's LX
      testCase.LX = meqxk(testCase.LXdb,ishot);
    end
  end
  
  
  methods(Test,TestTags = {'TCV'})
    function test_liutcv_db(testCase)
      % Main function to run the test on the database:
      % run liut on this slice, classify any warning it issues, then check
      % the Ip (relative) and Ia (absolute) errors against tolerances.
      shot = testCase.LX.shot;
      time = testCase.LX.t;
  
      % check that shot is not on the exclusion list
      [isexcluded,msg] = liutcv_tests.is_excluded(shot);
      testCase.assumeFalse(isexcluded,msg); % skip the shot if it is
      
      failcmdstr = sprintf('[L,LX,LY]=liuqe(%d,%5.3f,%s)',shot,time,testCase.paramstr);
      
      try
        lastwarn('',''); % reset last warning message and identifier
        LY = liut(testCase.L,testCase.LX);
      catch ME
        % run-time error
        testCase.assertFail(...
          sprintf('liut.m run-time error. liuqe call: %s\n %s\n',failcmdstr,getReport(ME,'extended')))
      end
      
      % Capture the warning issued by liut now, before any debug re-run can
      % overwrite it.
      [wrnMsg,wrnID] = lastwarn;
      
      if ~isempty(wrnMsg)
        % liut issued a warning
        if testCase.verbosity
          LY = liut(testCase.L,testCase.LX,'debug',1);
        end
        
        % Some warnings are ok, but for others we should mark the test as
        % failed.
        if contains(wrnID,'LowIp')
          % test can be filtered in this case
          testCase.assumeFail('liut.m failed due to too low Ip')
        elseif contains(wrnID,{'iterhFail','fixOpy'})
          % ok for continuing testing
        else
          % throw an assertion failure to mark this test as not-passed.
          testCase.assertFail(...
            sprintf('liut.m returned not-allowed warning:\n %s\n liuqe call:\n %s. ',wrnMsg,failcmdstr));
        end
      end
      
      % Direct checks on results
      Ip_rel_error = abs(LY.Ip - testCase.LX.Ip)./abs(LY.Ip); % Ip relative error
      Ia_abs_error = max(abs(LY.Ia - testCase.LX.Ia)); % Ia absolute error
      
      testCase.verifyLessThan(Ip_rel_error,testCase.tol_Ip,...
        sprintf('Exceeded Ip error tolerance.\n Failed command:\n %s\n',failcmdstr));
      
      testCase.verifyLessThan(Ia_abs_error,testCase.tol_Ia,...
        sprintf('Exceeded Ia error tolerance.\n Failed command:\n %s\n',failcmdstr));
      
    end
  end
end

function str = pvpairs2str(pvpairs)
% Format a 2xN cell {name;value} as a comma-separated MATLAB argument list,
% e.g. 'bfct',@bfabmex,'bfp',[],'itert',0
args = cell(1,size(pvpairs,2));
for k = 1:size(pvpairs,2)
  args{k} = sprintf('''%s'',%s',pvpairs{1,k},value2str(pvpairs{2,k}));
end
str = strjoin(args,',');
end

function str = value2str(val)
% MATLAB source text for a single parameter value
if isa(val,'function_handle')
  str = func2str(val);
  if ~startsWith(str,'@') % named handles come back without the '@'
    str = ['@',str];
  end
elseif ischar(val)
  str = ['''',val,''''];
elseif isempty(val)
  str = '[]';
elseif isscalar(val)
  str = sprintf('%g',val);
else
  str = mat2str(val);
end
end
