classdef ipm_test < meq_test
  % Tests for IPM
  %
  % Solves small quadratic programs (adapted from the quadprog help
  % examples) with MATLAB's quadprog and with meq's ipm, for every ipm
  % mode listed for the problem, and checks that both solutions agree.
  %
  % [+MEQ MatlabEQuilibrium Toolbox+] Swiss Plasma Center EPFL Lausanne 2022. All rights reserved.
  
  properties
    % H,f,A,b,Aeq,beq,lb,ub : problem data, in quadprog's argument order
    % osel                  : quadprog options (also supply ipm's niter/tol)
    % modes                 : ipm modes exercised for the current problem
    H,f,A,b,Aeq,beq,lb,ub,osel,modes
  end
  
  properties(ClassSetupParameter)
    problem = {'I','P','Sp','Sn','Sp1','Sn1','S1p','S1n','S1p1','S1n1','E','EB'};
  end
  
  methods(TestClassSetup)
    function set_options(testCase)
      % quadprog options with display turned off
      testCase.osel = optimoptions('quadprog','Display','None');
    end
    
    function set_problem(testCase,problem)
      % Build the problem data for the given problem name and store it,
      % together with the list of ipm modes to test, on testCase.
      [H,f,A,b,Aeq,beq,lb,ub] = deal([]); %#ok<*PROPLC>
      % Examples taken from help of quadprog:
      switch problem
        case 'I'
          % "Quadratic Program with Linear Constraints"
          H = [1 -1; -1 2];
          f = [-2; -6];
          A = [1 1; -1 2; 2 1];
          b = [2; 2; 3];
          modes = [0,1,2];
        case 'P'
          % Adapted from I for positive solutions
          H = [1 -1; -1 2];
          f = [-2; -6];
          lb = zeros(2,1);
          ub = Inf(2,1);
          modes = [0,1,2,3];
        case {'Sp','Sn','Sp1','Sn1'}
          % Adapted from I for imposed signs (diagonal A, zero RHS)
          H = [1 -1; -1 2];
          f = [-2; -6];
          switch problem
            case 'Sp',  A =  2.5;
            case 'Sn',  A = -2.5;
            case 'Sp1', A =  1.0;
            case 'Sn1', A = -1.0;
          end
          A = A*eye(2,2);
          b = zeros(2,1);
          modes = [0,1,2,4];
        case {'S1p','S1n','S1p1','S1n1'}
          % Adapted from I for one imposed sign (on the second variable)
          H = [1 -1; -1 2];
          f = [-2; -6];
          switch problem
            case 'S1p',  A =  2.5;
            case 'S1n',  A = -2.5;
            case 'S1p1', A =  1.0;
            case 'S1n1', A = -1.0;
          end
          A = A*[0,1];
          b = 0;
          modes = [0,1,2,5];
        case 'E'
          % "Quadratic Program with Linear Equality Constraint"
          H = [1 -1; -1 2];
          f = [-2; -6];
          Aeq = [1 1];
          beq = 0;
          modes = [0,1];
        case 'EB'
          % "Quadratic Program with Linear Constraints and Bounds"
          H = [ 1,-1, 1
               -1, 2,-2
                1,-2, 4];
          f = [2;-3;1];
          lb = zeros(3,1);
          ub = ones(size(lb));
          Aeq = ones(1,3);
          beq = 1/2;
          modes = [0,1];
        otherwise
          % Fail loudly instead of later on an undefined 'modes'
          error('ipm_test:unknownProblem','Unknown problem "%s"',problem);
      end
      testCase.H     = H;
      testCase.f     = f;
      testCase.A     = A;
      testCase.b     = b;
      testCase.Aeq   = Aeq;
      testCase.beq   = beq;
      testCase.lb    = lb;
      testCase.ub    = ub;
      testCase.modes = modes;
    end
  end
  
  methods(Test, TestTags = {'Unit'})
    function ipmTest(testCase)
      % Solve the stored problem with quadprog (reference) and with ipm
      % in each listed mode; each ipm run must converge and match the
      % reference solution within an absolute tolerance of 1e-6.
      
      H     = testCase.H; %#ok<*PROP>
      f     = testCase.f;
      A     = testCase.A;
      b     = testCase.b;
      Aeq   = testCase.Aeq;
      beq   = testCase.beq;
      lb    = testCase.lb;
      ub    = testCase.ub;
      
      osel  = testCase.osel;
      
      modes = testCase.modes;
      
      % Reference solution
      [x,~,stat,~,~] = quadprog(H,f,A,b,Aeq,beq,lb,ub,[],osel);
      testCase.assertGreaterThan(stat,0,'Output status of quadprog must be greater than 0');
      
      % Use meq's ipm method, with the iteration limit and step tolerance
      % taken from the quadprog options
      niter = osel.MaxIterations;
      tol = osel.StepTolerance;
      for mode = modes
        [xipm,stat] = testCase.run_ipm(H,f,A,b,Aeq,beq,lb,ub,mode,niter,tol);
        
        testCase.verifyTrue(stat,sprintf('ipm(mode=%d) did not converge',mode));
        testCase.verifyEqual(x,xipm,'AbsTol',1e-6,sprintf('Solution of quadprog and ipm(mode=%d) do not match within tolerance',mode));
      end
      
    end
  end
  
  methods
    function [x,stat] = run_ipm(testCase,H,f,A,b,Aeq,beq,lb,ub,mode,niter,tol)
      % Convert quadprog-style inputs to the form passed to ipm and call it.
      %
      % Inequalities A*x <= b and bounds lb <= x <= ub are stacked into a
      % single system Aipm*x >= bipm (rows -A, +I, -I; RHS -b, lb, -ub),
      % dropping rows with infinite RHS. For modes 3-5 the stacked system
      % must have the special structure that mode expects (asserted below).
      %
      % Returns:
      %   x    - solution returned by ipm
      %   stat - true if ipm's fifth output (kt, presumably the iteration
      %          count - confirm against ipm) is below niter
      
      if mode > 1
        % Precondition: stop here rather than call ipm with unsupported inputs
        testCase.assertTrue(isempty(Aeq) && isempty(beq),'For modes other than 0 and 1, ipm does not support equality constraints')
      end
      
      nvar = size(H,1);
      % Group bounds with constraints
      Aipm = [-A;eye(nvar*~isempty(lb));-eye(nvar*~isempty(ub))];
      bipm = [-b;lb;-ub];
      % ipm does not support infinite bounds
      mask = isfinite(bipm);
      Aipm = Aipm(mask,:);
      bipm = bipm(mask);
      
      switch mode
        case 3
          test = isequal(Aipm,eye(nvar)) && isequal(bipm,zeros(nvar,1));
          testCase.assertTrue(test,'For ipm mode 3, Aipm,bipm must be equivalent to x>0');
        case 4
          % Guard against an empty Aipm before reading Aipm(1,1)
          test = ~isempty(Aipm) && ...
                 isequal(Aipm/Aipm(1,1),eye(nvar)) && isequal(bipm,zeros(nvar,1));
          testCase.assertTrue(test,'For ipm mode 4, Aipm,bipm must be equivalent to a*x>0');
        case 5
          % Single nonzero coefficient s in row i / column j, with zero RHS.
          % The RHS must be read from bipm (row-aligned with Aipm), not b.
          [i,j,s] = find(Aipm);
          test = numel(i) == 1 && bipm(i) == 0;
          testCase.assertTrue(test,'For ipm mode 5, Aipm,bipm must be equivalent to a*x(b)>0');
          % Mode 5 passes the coefficient in place of A and the variable
          % index in place of b
          Aipm = s; bipm = j;
      end

      [x,~,~,~,kt] = ipm(H,f,Aipm,bipm,Aeq,beq,[],[],[],[],mode,niter,tol);
      stat = kt < niter;
    end
  end
end
