classdef (SharedTestFixtures={mexm_fixture}) ipm2_test < meq_test
  % Tests for IPM2
  %
  % [+MEQ MatlabEQuilibrium Toolbox+]

  %    Copyright 2022-2025 Swiss Plasma Center EPFL
  %
  %   Licensed under the Apache License, Version 2.0 (the "License");
  %   you may not use this file except in compliance with the License.
  %   You may obtain a copy of the License at
  %
  %       http://www.apache.org/licenses/LICENSE-2.0
  %
  %   Unless required by applicable law or agreed to in writing, software
  %   distributed under the License is distributed on an "AS IS" BASIS,
  %   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  %   See the License for the specific language governing permissions and
  %   limitations under the License.

  % For each value of 'constraints', the class setup builds a random 2-D
  % quadratic program
  %     min_x  0.5*x'*H*x + f'*x   subject to   A*x >= b
  % and solves it with quadprog as the reference. Each test then solves the
  % same problem with one IPM2 implementation (see 'method') and compares
  % its result/status with the reference.

  properties
    % x    : quadprog reference solution
    % H, f : quadratic and linear objective terms
    % UH   : upper-triangular part of H (column-major), the form passed to IPM2
    % A, b : inequality constraints A*x >= b
    % xi, zi, si : initial values passed to IPM2 (presumably primal x,
    %              dual z and slack s - TODO confirm against ipm2mex)
    % Nx   : number of unknowns
    % flag : quadprog exit flag (1 means the reference converged)
    x,H,f,UH,A,b,xi,zi,si,Nx,flag
    itertol = 1e-12; % tolerance passed to IPM2
    errtol = 1e-2;   % max allowed norm(x_quadprog - x_ipm2)
    maxiter = 150;   % iteration limit passed to IPM2
    verbosity = 0;   % >0: pass verbose flag to IPM2 and plot the problem
  end
  
  properties(ClassSetupParameter)
    % none       : trivially satisfied constraint 0*x >= 0
    % inactive   : unconstrained minimiser strictly feasible
    % intersect  : unconstrained minimiser on the constraint boundary
    % active     : unconstrained minimiser excluded by the constraint
    % infeasible : contradictory constraints x(1) >= 0.1 and -x(1) >= 0.1
    % no_minimum : objective made concave by negating H
    constraints = {'none','inactive','intersect','active','infeasible','no_minimum'};
  end
  
  properties(TestParameter)
    % IPM2 implementations under test, called through feval
    method = {'ipm2mex','ipm2mexm'};
  end
    
  methods(TestClassSetup)
    function ipm2TestSetup(testCase,constraints)
      % Build the QP for the requested constraint scenario and store the
      % quadprog reference solution in testCase.x / testCase.flag.
            
      testCase.assumeTrue(~isempty(which('quadprog')), 'quadprog missing, need optimization toolbox for ipm2_test')
      if ~isempty(which('optimoptions')) % optimoptions not available in Octave
        opt = optimoptions('quadprog','algorithm','interior-point-convex','display','none');
      else
        opt = [];
      end
      
      testCase.Nx = 2;
      xy0 = rand(testCase.Nx,1); % determines the unconstrained minimiser
      xy1 = rand(testCase.Nx,1); % constraint normal
      t   = 360*rand;            % rotation angle [deg]
      ax1 = 1.5 + rand;          % scales of the quadratic along the rotated axes
      ax2 = 1.5 + rand;
      
      % Convex objective 0.5*x'*H*x + f'*x. Its unconstrained minimiser is
      % -H\f = R.'*xy0, up to the 3-decimal rounding of H and f.
      R  = [cosd(t) sind(t); -sind(t) cosd(t)];
      HH = R' * diag([1/ax1^2   1/ax2^2]) * R;
      testCase.H = 2 * round((HH+HH')/2, 3); % symmetrised before rounding
      testCase.f = round((-[2*xy0(1)/ax1^2    2*xy0(2)/ax2^2] * R)', 3);
      
      % Single constraint A*x >= b. The offset places the boundary relative
      % to the unconstrained minimiser (A*(R.'*xy0) = dot(xy1,R.'*xy0)):
      %   inactive  (-0.4): minimiser strictly feasible
      %   intersect ( 0.0): minimiser on the boundary
      %   active    (+0.4): minimiser excluded, constraint active at optimum
      testCase.A = xy1.';
      switch constraints
        case  'inactive', offset = -0.4;
        case 'intersect', offset =  0.0;
        case    'active', offset = +0.4;
        otherwise,        offset =  0.0;
      end
      testCase.b = dot(xy1,R.'*xy0) + offset;
      
      % Initial guess: [1;1] if it satisfies A*x > b, otherwise [-1;-1]
      testCase.xi = sign(testCase.A*[1;1] - testCase.b)*[1;1];
      
      % Special scenarios override parts of the problem built above
      switch constraints
        case 'none'        % constraint 0*x >= 0 always holds
          testCase.A  = zeros(1,2);
          testCase.b  = zeros(1,1);
          testCase.xi = [1;1];
        case 'infeasible'  % x(1) >= 0.1 and -x(1) >= 0.1 cannot both hold
          testCase.A = [1,0;-1,0];
          testCase.b = [0.1;0.1];
        case 'no_minimum'  % negated H: concave objective
          testCase.H = -testCase.H;
      end
      
      % Derived quantities, computed after all overrides so they stay
      % consistent with the final H and A.
      testCase.UH = testCase.H(triu(true(size(testCase.H)))); % upper triangular part of H
      nA = size(testCase.A,1);
      testCase.zi = ones(nA,1);
      testCase.si = ones(nA,1);

      % Disable some warnings when problem is infeasible
      if strcmp(constraints,'infeasible')
        s = warning('off','MATLAB:nearlySingularMatrix');
        testCase.addTeardown(@() warning(s));
      end
      
      % quadprog expects A_ineq*x <= b_ineq, hence the sign flip of A and b
      [testCase.x,~,testCase.flag] = quadprog(testCase.H,testCase.f,-testCase.A,-testCase.b,[],[],[],[],[],opt);
    end
  end
  
  methods(Test, TestTags = {'Unit'})
    function ipm2Test(testCase,method)
      % Solve the class QP with the given IPM2 implementation.
      % If quadprog converged (flag == 1), IPM2 must succeed and match the
      % reference within errtol. Otherwise IPM2 must report failure.

      % ST (5th output) is treated as the solver success status
      [x2,~,~,~,ST] = feval(method,testCase.UH,testCase.f,testCase.A,testCase.b,...
        testCase.xi,testCase.si,testCase.zi,testCase.itertol,testCase.maxiter,logical(testCase.verbosity));
      
      if testCase.flag == 1
        testCase.verifyTrue(logical(ST), sprintf('%s procedure returned an error',method))
        testCase.verifyLessThan(norm(testCase.x-x2),testCase.errtol,'error out of the prescribed bounds, test result failure')
      else
        testCase.verifyFalse(logical(ST), sprintf('%s procedure did not return an error when quadprog did not converge',method));
      end
      
      if testCase.verbosity
        % Plot objective contours, the zero level set of each constraint,
        % and (when available) the quadprog (red x) and IPM2 (blue ^) solutions.
        nA = size(testCase.A,1);
        Np = 100;
        xg = linspace(-3,3,Np)';
        [xg1,xg2] = meshgrid(xg,xg);
        F = zeros(Np);
        C = zeros(Np,Np,nA);
        for r = 1:length(xg)
          for c = 1:length(xg)
            xx = [xg1(r,c);xg2(r,c)];
            F(r,c) = xx'*testCase.H*xx/2 + testCase.f'*xx;
            C(r,c,:) = testCase.A*xx - testCase.b;
          end
        end
        figure
        contour(xg1,xg2,F), grid on, hold on
        for ii = 1:nA, contour(xg1,xg2,C(:,:,ii),[0 0],'-k');end
        if testCase.flag == 1
          plot(testCase.x(1),testCase.x(2),'xr',x2(1),x2(2),'^b')
        end
      end
    end
  end 
end
