classdef (SharedTestFixtures={mexm_fixture}) ipmj_test < meq_test
  % Tests for IPM (Interior Point Method)
  %
  % The class setup builds an inequality-constrained linear least-squares
  % problem from an 'ana' equilibrium computed by liu. It solves that
  % problem with lsqlin to obtain a reference solution. Each test then
  % solves the same problem with one of the ipmj implementations listed in
  % the METHOD test parameter. It checks that the call succeeds, that the
  % inequality constraint is respected, and that the fit quality (chi) is
  % not worse than the lsqlin reference by more than CHITOL.
  %
  % [+MEQ MatlabEQuilibrium Toolbox+] Swiss Plasma Center EPFL Lausanne 2022. All rights reserved.
  
  properties
    % Problem data prepared by ipmjTestSetup and consumed by ipmjTest.
    % Ie, aj hold the lsqlin reference solution.
    L,Wrj,Wkj,Yr,Ye,Yk,Ie0,Qcj,Xc,Ie,aj;
    C,d,Tyg; % For diagnostics
    errtol = 1e-3;   % lsqlin 'StepTolerance' (only when optimoptions exists)
    chitol = 1e-3;   % allowed relative increase of chi w.r.t. the lsqlin fit
    maxiter = 150;   % iteration limit passed to lsqlin and to the ipmj method
    verbosity = 0;   % nonzero: plot residuals and current distributions
  end
  
  properties(ClassSetupParameter)
    shot = struct('limited',1,'diverted',2);
  end
  
  properties(TestParameter)
    method = {'ipmjmex','ipmjmexm'};
  end
    
  methods(TestClassSetup)
    function ipmjTestSetup(testCase,shot)
      % Build the constrained fit problem for the given shot and solve it
      % with lsqlin to obtain the reference solution stored on testCase.
            
      s = warning('off','MATLAB:nearlySingularMatrix');
      testCase.addTeardown(@() warning(s));
      
      testCase.assumeTrue(~isempty(which('lsqlin')), 'lsqlin missing, need optimization toolbox for ipmj_test')
      if ~isempty(which('optimoptions')) % optimoptions not available in Octave
        opt = optimoptions('lsqlin','algorithm','interior-point','display','off',...
                           'MaxIterations',testCase.maxiter,'StepTolerance',testCase.errtol);
      else
        opt = [];
      end
      
      [L,LX,LY] = liu('ana',shot,1); %#ok<*PROP,*PROPLC>
      
      Xd = [LX.Ff;LX.Bm;LX.Ia;LX.Iu;LX.Ft;LX.Ip]; % fmautp
      Yd = L.Wd.*Xd; % weighted measurements
      
      % Prepare inputs to ipmj - Restart from last iteration
      Ie0 = [LY.Ia;LY.Iu];
      Fx = LY.Fx;
      FA = LY.FA;
      FB = LY.FB;
      rA = LY.rA;
      Opy = LY.Opy;
      Iy = LY.Iy;
      rBt = LX.rBt;
   
      Wij = zeros(L.ni,L.nj);
      Wqj = zeros(L.nq,L.nj);
      Qcj = zeros(L.nc,L.nj);
      Yk  = zeros(L.nk,1);
      Yr  = Yd(L.kdr);
      Ye  = Yd(L.kde);
      
      Yk(L.kki) = Yd(L.kdi);
      testCase.assertTrue(~any(L.Wj),'Should not try to use individual aj constraints in ipmj_test');
   
      % Elongation estimate (height/width) of the bounding box of Opy
      b0 = bboxmex(logical(Opy),L.zy,L.ry);
      ie0 = (b0(4) - b0(2)) / (b0(3) - b0(1));
      
      % Base functions and regularisation and constraint matrix
      [~,fTg         ] = L.bfct(0,L.bfp                        );
      [Tyg,Tpgi,ITpgi] = L.bfct(1,L.bfp,Fx,FA,FB,Opy,L.ry,L.iry);
      [Qqg,Xq        ] = L.bfct(6,L.bfp,[],FA,FB,rA,1/rA,L.idsx);
      [Qcg,Xc        ] = L.bfct(7,L.bfp,[],FA,FB               );
   
      iD = 1; Tpg(:,iD) = Tpgi; ITpg(:,iD) = ITpgi; % doublet prep
   
      % Response matrices
      if L.ndz, Mrj = respmex(L.Mry,Tyg,L.Mdzry,Iy(:));
      else,     Mrj =         L.Mry*Tyg               ;
      end
      Wrj = L.Wr.*Mrj;
      Wij(L.kit,1:L.ng) = L.Wt*(2e-7*(ITpg.*fTg)'+L.Mty*Tyg*rBt); % Ft
      Wij(L.kip,1:L.ng) = L.Wp*Tpg';                              % Ip
      % adapt regularisation with elongation
      Wq = L.Wq*min(exp(-(1./ie0-L.P.elomin).*L.P.wregadapt),1);
      Wqj(:,1:L.ng) = Wq.*Qqg;
      Yq = Wq.*Xq;
      Qcj(:,1:L.ng) = Qcg;
   
      Yk(L.kkq) = Yq;
      Wkj = [Wij;Wqj;L.Wjj];
      
      % Prepare lsqlin call: minimise ||C*[Ie;aj] - d||^2 subject to
      % A*x <= b, i.e. -Qcj*aj <= -Xc  <=>  Qcj*aj >= Xc
      C = [L.Wre                      Wrj              ; ...
           L.Wee                      zeros(L.ne,L.nj) ; ...
           zeros(L.ni+L.nq+L.nj,L.ne) Wkj              ];
      d = [Yr;Ye;Yk];
      A = [zeros(1,L.ne),-Qcj];
      b = -Xc;
      
      x = lsqlin(C,d,A,b,[],[],[],[],[],opt);
      Ie = x(1:L.ne);
      aj = x(L.ne+1:end);
      
      testCase.L  = L;
      testCase.Wrj = Wrj;
      testCase.Wkj = Wkj;
      testCase.Yr = Yr;
      testCase.Ye = Ye;
      testCase.Yk  = Yk;
      testCase.Ie0  = Ie0;
      testCase.Qcj = Qcj;
      testCase.Xc = Xc;
      testCase.Ie = Ie;
      testCase.aj = aj;
      
      testCase.C   = C;
      testCase.d   = d;
      testCase.Tyg = Tyg;
    end
  end
  
  methods(Test, TestTags = {'Unit'})
    function ipmjTest(testCase,method)
      % Solve the setup problem with the ipmj implementation named by
      % METHOD and compare its fit quality against the lsqlin reference.

      L = testCase.L;
      Wrj = testCase.Wrj;
      Wkj = testCase.Wkj;
      Yr  = testCase.Yr;
      Ye  = testCase.Ye;
      Yk  = testCase.Yk;
      Ie0 = testCase.Ie0;
      Qcj = testCase.Qcj;
      Xc  = testCase.Xc;
      
      tol = 1e-2; % With norm(aj,Inf)~1e4, this amounts to reltol~1e-6
      
      % ipmj emulator: assemble the inputs expected by the ipmj methods
      uAjj = uatamex(Wrj) + uatamex(Wkj);
      Aje = Wrj'*L.Wre;
      Aej = -L.Aer*Wrj;
      Ie1 = L.Aer*Yr+L.Aee*Ye;
      aj1 = -Wrj'*Yr-Wkj'*Yk;
      % Initial guess: 1 for coefficients appearing in a constraint, else 0
      aj0 = double(any(Qcj,1)).';
      % Initial slack/dual variables, one per constraint row, set to tol
      [s,z] = deal(repmat(tol,size(Xc)));
      [Ie,aj,ST] = feval(method,uAjj,Aje,Aej,Ie1,aj1,Ie0,aj0,Qcj,Xc,s,z,tol,testCase.maxiter);
      
      Yd = [Yr;Ye;Yk(L.kki)];
      
      % NOTE(review): assumes the first L.nd rows of C line up with
      % [Yr;Ye;Yk(L.kki)], i.e. L.kki selects the leading rows of Wkj -- confirm.
      Zd = testCase.C(1:L.nd,:)*[Ie;aj];
      chi = sqrt(sum((Zd-Yd).^2)/L.nd); % chi of current fit
      
      Ie_ = testCase.Ie;
      aj_ = testCase.aj;
      
      Zd_ = testCase.C(1:L.nd,:)*[Ie_;aj_];
      chi_ = sqrt(sum((Zd_-Yd).^2)/L.nd); % chi of lsqlin reference fit
      
      
      testCase.verifyTrue(logical(ST)  , sprintf('%s procedure returned an error',method))
      testCase.verifyTrue(all(Qcj*aj>Xc), 'aj does not respect inequality constraint')
      testCase.verifyLessThan(chi./chi_ - 1, testCase.chitol, 'ipmj fit is worse than reference fit (lsqlin) by more than tolerance')
      % NOTE, not checking difference in aj, Ie since methods sometimes
      % yield different answers but with similar fit agreement.
      
      if testCase.verbosity
        figure;
        % Left: residuals of ipmj (solid) and lsqlin (dashed) per measurement
        ax = subplot(1,2,1);
        hold(ax,'on');
        plot(ax,Yd-Zd,1:L.nd);
        plot(ax,Yd-Zd_,1:L.nd,'--');
        ax.YDir = 'reverse';
        ylabel(ax,'Meas index');
        title(ax,'Residual (fmautp)');
        
        % Right: current distribution contours, ipmj (solid) vs lsqlin (dashed)
        ax = subplot(1,2,2);
        hold(ax,'on');
        [~,h]=contour(ax,L.ry,L.zy,reshape(testCase.Tyg*aj(1:L.ng),L.nzy,L.nry),'Linestyle','-');
              contour(ax,L.ry,L.zy,reshape(testCase.Tyg*aj_(1:L.ng),L.nzy,L.nry),h.LevelList,'Linestyle','--');
        title(ax,'Current distribution')
      end
    end
  end 
end
