classdef liu_tests < meq_test
  % Tests for LIUQE scanning across multiple equilibria and parameters
  %
  % The LIUQE structures (L, LX) are built once per combination of the
  % ClassSetupParameter values in setup_LLX and shared by all tests of
  % that combination. Each test copies them locally before modifying them.
  %
  % [+MEQ MatlabEQuilibrium Toolbox+]

  %    Copyright 2022-2025 Swiss Plasma Center EPFL
  %
  %   Licensed under the Apache License, Version 2.0 (the "License");
  %   you may not use this file except in compliance with the License.
  %   You may obtain a copy of the License at
  %
  %       http://www.apache.org/licenses/LICENSE-2.0
  %
  %   Unless required by applicable law or agreed to in writing, software
  %   distributed under the License is distributed on an "AS IS" BASIS,
  %   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  %   See the License for the specific language governing permissions and
  %   limitations under the License.

  properties
    tok = 'ana';   % tokamak name passed to liu in setup_LLX
    t = 0;         % time(s) passed to liu in setup_LLX
    verbose = 0;   % forwarded to liu as the 'debug' option
    tol = 1e-5;    % relative tolerance used when comparing two LIUQE runs
    L,LX;          % first and second outputs of liu(), set by setup_LLX
  end

  properties(ClassSetupParameter)
    config = {'limited','diverted','secondaryX'}; % mapped to shot 1/2/3 in setup_LLX
    iters = struct('off_0',0,'on_48',48);         % forwarded as liu 'iters'
    itert={0 1};                                  % forwarded as liu 'itert'
    ipm={false,true}                              % forwarded as liu 'ipm'
  end
  
  methods(TestClassSetup)
    function setup_LLX(testCase,config,iters,itert,ipm)
      % Build the LIUQE structures shared by all tests of this parameter set.
      %   config        - 'limited' | 'diverted' | 'secondaryX', selects shot 1 | 2 | 3
      %   iters, itert, ipm - forwarded unchanged to liu
      % Stores the outputs of liu in testCase.L and testCase.LX.
      switch config
        case 'limited'
          shot = 1;
        case 'diverted'
          shot = 2;
        case 'secondaryX'
          shot = 3;
        otherwise
          error('undefined config %s\n',config)
      end
      tok_ = testCase.tok;
      t_   = testCase.t;
      
      % LIU from FBT
      [L_,LX_] = liu(tok_,shot,t_,'selu','e','nu',20,...
        'iters',iters,'itert',itert,...
        'ipm',ipm,'debug',testCase.verbose);
      
      testCase.LX = LX_;
      testCase.L  = L_;
    end
  end
  
  methods(Test,TestTags={'Unit'})
    function test_constraint_exact(testCase)
      % Tests for constraining ag values.
      % A reference run uses the default bfp. A second run uses bfp=[1,3],
      % and the entries with wCo=Inf (equality constraints) are prescribed
      % through LX.ag: the first from the reference run, the last as 0.
      % The constrained entries must reproduce LX.ag. The remaining
      % coefficients must match the reference run within tolerance.
      
      L = testCase.L; %#ok<*PROP,*PROPLC>
      LX = testCase.LX;
      nt = numel(LX.t);

      L.P.psichco = 1e-8;
      %% Initial LIUQE run with bfp = [1,2]
      LY0 = liut(L,LX);

      %% New LIUQE run with bfp = [1,3] and constraining last basis function coefficient to 0
      P = L.P;
      G = L.G;
      P.bfp = [1,3];
      P.liuagcon = {'ag','ag','ag','ag'}; % constrain ag values
      P.wCo = [Inf;0;0;Inf]; % First and last one are equality constrained
      % Prescribed values go into the special ag field. Append one zero
      % row (one column per time slice) for the extra basis function.
      LX.ag = [LY0.ag;zeros(1,nt)];

      L = liuc(P,G);
      
      LY1 = liut(L,LX);
      
      % Verifications (index rows = basis functions, so the checks also
      % hold when LX contains several time slices)
      mask = ~L.kj(1:L.ng); % NOTE(review): assumes ~kj flags the constrained ag entries — confirm in liuc
      ng0  = size(LY0.ag,1); % number of basis functions in the reference run
      tol = testCase.tol*max(abs(LY0.ag(:)));
      testCase.verifyEqual(LY1.ag(mask,:),LX.ag(mask,:),'AbsTol',1e-10,'Constrained ag value does not match its input value');
      testCase.verifyEqual(LY1.ag(1:ng0,:),LY0.ag,'AbsTol',tol,'Results of two LIUQE runs do not match');
     
    end
    
    function test_ag_constraint_approx(testCase)
      % Tests for constraining ag values.
      % Three runs share the same target LX.ag = 0 with an increasing
      % weight on one coefficient: small, larger, and Inf (hard constraint).
      % Increasing the weight must decrease |ag| for that coefficient. The
      % hard constraint must make it equal to the target.
      
      L = testCase.L;
      LX = testCase.LX;
      nt = numel(LX.t);

      %% Runs with large increasing weights for last basis function coefficient
      P = L.P;
      G = L.G;

      P.bfp = [1,3]; % 4 basis functions
      P.liuagcon = {'ag','ag','ag','ag'}; % ag=Co
      % NOTE(review): wCo(4)=1e-4 is nonzero, so Co(4) is weakly penalized
      % rather than "free" as the original comment states — confirm intent.
      P.wCo = [0;0;0;1e-4]; % all Co are free, ag is only constrained via measurements
      L = liuc(P,G);
      LX.ag = zeros(L.ng,nt); % target
      LY0 = liut(L,LX);
      
      % NOTE(review): wCo has 3 entries here but liuagcon has 4; kept as in
      % the original, relying on liuc to accept it — confirm.
      P.wCo = [0;0;1e-2]; % Penalize Co(3)-LX.ag(3) to bring ag(3) closer to LX.ag(3)
      L = liuc(P,G);
      LY1 = liut(L,LX);
      
      P.wCo = [0;0;Inf]; % Hard-constrain ag(3) = Co(3) = 0
      L = liuc(P,G);

      LY2 = liut(L,LX);
      
      % Verifications: mask selects the rows with nonzero weight (ag(3)).
      % A logical row index shorter than ng is padded with false.
      mask = logical(P.wCo);
      testCase.verifyLessThan(abs(LY1.ag(mask,:)),abs(LY0.ag(mask,:)),...
        'Penalizing ag value does not cause it to decrease');
      % Same round-off tolerance as test_constraint_exact for Inf-weight constraints
      testCase.verifyEqual(LY2.ag(mask,:),LX.ag(mask,:),'AbsTol',1e-10,...
        'Constraining ag value does not make it equal to the target');
     
    end
   
    function test_Ie_constraint_exact(testCase)
      % Tests for constraining Ie values.
      % A reference run uses the standard coil set. A second LIUQE is built
      % with 'fastcoil',1 and na+1 entries in wIa (the last one Inf), and
      % the input current of the extra coil is set to 0. The extra coil
      % current must equal its input. All other coil currents must match
      % the reference run within tolerance.

      L = testCase.L; %#ok<*PROP,*PROPLC>
      LX = testCase.LX;
      nt = numel(LX.t);

      %% Initial LIUQE run with standard coil set
      LY0 = liut(L,LX);

      %% New LIUQE run with extra coil but set its current to 0
      L1 = liu(L.P.tok,L.P.shot,LX.t,'selu','e','nu',20,...
        'iters',L.P.iters,'itert',L.P.itert,...
        'ipm',L.P.ipm,'debug',L.P.debug,...
        'fastcoil',1,'wIa',[ones(L.G.na,1);Inf]);
      LX1 = LX; LX1.Ia = [LX.Ia;zeros(1,nt)]; % zero current row for the extra coil

      LY1 = liut(L1,LX1);

      % Verifications: index rows (coils) explicitly so the checks remain
      % valid when LX holds several time slices (Ia is then na x nt).
      tol = testCase.tol*max(abs(LY0.Ia(:)));
      testCase.verifyEqual(LY1.Ia(  end  ,:),LX1.Ia(end,:),'Constrained Ie value does not match its input value');
      testCase.verifyEqual(LY1.Ia(1:end-1,:),LY0.Ia       ,'AbsTol',tol,'Results of two LIUQE runs do not match');
    end
    
    function iterator_test(testCase)
      % test liut use as iterator
      % itert ~= 0: one liut step followed by one step restarted from the
      %   first step's outputs must match a single liut call with itert=2.
      % itert == 0: a call restarted from a converged solution must finish
      %   in exactly one iteration, with a lower residual.
      L_ = testCase.L;
      LX_ = testCase.LX;
      if L_.P.itert
        % for itert, check that doing 2 steps as iterator is the same as doing 2 steps in itert
        L_.P.itert = 1;
        LY1 = liut(L_,LX_); % one step
        % Another step as iterator, initialized from the previous outputs.
        % NOTE(review): 'rst',false presumably disables re-initialization — confirm in liut.
        LY1 = liut(L_,LX_,...
          'Iy',LY1.Iy,'Ie',[LY1.Ia;LY1.Iu],...
          'ag',LY1.ag,'Co',LX2Co(L_,LY1),'dz',LY1.dz,'rst',false);
        
        % two steps with itert
        L_.P.itert = 2;
        LY2 = liut(L_,LX_);

        % Tolerances are scaled by the per-quantity reference values stored in L_
        tol = 1e-8;
        testCase.assertEqual(LY1.Iy,LY2.Iy,'AbsTol',L_.Iy0*tol);
        testCase.assertEqual(LY1.Ia,LY2.Ia,'AbsTol',L_.Ia0*tol);
        testCase.assertEqual(LY1.ag,LY2.ag,'AbsTol',L_.ag0*tol);
        testCase.assertEqual(LY1.dz,LY2.dz,'AbsTol',L_.dzx*tol);
      else
        % for itert=0, check that additional call converges in 1 step
        tol = 1e-4;
        L_.P.psichco = tol; % same value as the comparison tolerance below
        LY1 = liut(L_,LX_);
        LY2 = liut(L_,LX_,'Iy',LY1.Iy,'Ie',[LY1.Ia;LY1.Iu],...
          'ag',LY1.ag,'Co',LX2Co(L_,LY1),'dz',LY1.dz,'rst',false);
        
        % Check that values match
        testCase.verifyEqual(LY1.Ia,LY2.Ia,'AbsTol',L_.Ia0*tol);
        testCase.verifyEqual(LY1.Iu,LY2.Iu,'AbsTol',L_.Iu0*tol);
        testCase.verifyEqual(LY1.ag,LY2.ag,'AbsTol',L_.ag0*tol);
        testCase.verifyEqual(LY1.dz,LY2.dz,'AbsTol',L_.dzx*tol);
        % check only one iteration was carried out and residual is smaller
        testCase.assertEqual(LY2.niter,1,'expected 1 iteration');
        testCase.assertLessThan(LY2.res,LY1.res,'expected lower residual');
      end
    end
  end
end
