classdef meqsqp_test < meq_test
  % Test for SQP procedure in meqopt
  % see help meqopt for problem description
  %
  % Each parametrized case builds a small model equation F(x,u)=0, a
  % least-squares cost and optional linear equality/inequality constraints,
  % then repeatedly calls meqlsqeval, meqlsqprep and meqlsqsolve, updating
  % (x,u,z) until norm(F) drops below the tolerance. The test verifies that
  % the model equation and the constraints are satisfied at the last
  % evaluated iterate, that the linear case terminates at the second
  % iteration, and that nonlinear cases reach the expected convergence rate.
  %
  % [+MEQ MatlabEQuilibrium Toolbox+]

  %    Copyright 2022-2025 Swiss Plasma Center EPFL
  %
  %   Licensed under the Apache License, Version 2.0 (the "License");
  %   you may not use this file except in compliance with the License.
  %   You may obtain a copy of the License at
  %
  %       http://www.apache.org/licenses/LICENSE-2.0
  %
  %   Unless required by applicable law or agreed to in writing, software
  %   distributed under the License is distributed on an "AS IS" BASIS,
  %   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  %   See the License for the specific language governing permissions and
  %   limitations under the License.

  properties
    tol = 1e-8;  % tolerance on norm(F): convergence criterion and final check
    ctol = 1e-8; % constraint tolerance (passed to meqlsqsolve, used in constraint checks)
    verbosity = 0; % nonzero: print per-iteration diagnostics and report skipped cases
    nmax = 20; % maximum iterations
    qpsolve, qpopts % QP solver handle and its options, set in setup_qpsolve
    qpdebug = false; % forwarded to the QP solver options as 'debug'
  end
  
  properties(TestParameter)
    functiontype = {'linear','nonlinear-u','nonlinear-x'}
    multivariable = {'false','true'}
    costfunction = {'x','u','z','xu'};
    equalityconstraint = {'none','linear-x','linear-u'};
    inequalityconstraint = {'none','linear-x','linear-u'};

  end
  
  methods(TestClassSetup)
    
    function setup_qpsolve(testCase)
      % Store the QP solver handle and options used by all test cases.
      %% QP solver settings
      qpmaxiter = 50;
      qpsolver = 'ipm';
      % only 'ipm-optimized' enables the optimized variant; false for 'ipm'
      useopt = strcmp(qpsolver,'ipm-optimized');
      testCase.qpopts = ipmopts('useoptimized',useopt,...
        'niter',qpmaxiter,...
        'debug',testCase.qpdebug,...  % QP verbosity (for debugging)
        'presolve',true); % use presolver
      testCase.qpsolve = @ipmwrapper; % wrapper for ipm (not defined in this file)
    end
  end
  
  methods(Test, TestTags = {'Unit'})      

    function test_meq_sqp(testCase,functiontype,multivariable,costfunction,equalityconstraint,inequalityconstraint)
      % Run the SQP iterations for one combination of test parameters.
      %
      % functiontype         : 'linear', 'nonlinear-u' or 'nonlinear-x' model
      % multivariable        : 'true' / 'false' (char), selects 3- or 1-dimensional x,u
      % costfunction         : char containing any of 'x','u','z'
      % equalityconstraint   : 'none', 'linear-x' or 'linear-u'
      % inequalityconstraint : 'none', 'linear-x' or 'linear-u'

      % Test parameters are chars; convert without evaluating arbitrary text
      multivariable = strcmp(multivariable,'true');

      % skip scalar cases with multiple constraints
      if ~multivariable && ~any(ismember({equalityconstraint,inequalityconstraint},'none'))
        if testCase.verbosity
          testCase.assumeFail('Skip combining equality & inequality constraints for single-variable case')
        else
          return % silently skip
        end
      end
      %% Model function
      [F,Jx,Ju,nx,nu,nz] = testCase.model_function_factory(functiontype,multivariable);
      
      %% Cost function definition
      [Raxc{1},Rauc{1},Razc{1},Ra0c{1}] = ...
        testCase.cost_function_factory(costfunction,nx,nu,nz);

      %% Constraints
      [Rexc{1},Reuc{1},Rezc{1},Re0c{1},neq] = ...
        testCase.equality_constraint_factory(equalityconstraint,nx,nu,nz);
    
      [Rixc{1},Riuc{1},Rizc{1},Ri0c{1},niq] = ...
        testCase.inequality_constraint_factory(inequalityconstraint,nx,nu,nz);

      %% Cost function handles (cost, equality, inequality)
      W  = @(x,u,z) 1/2*sum((Raxc{1}*x + Rauc{1}*u + Razc{1}*z - Ra0c{1}).^2);
      We = @(x,u,z) Rexc{1}*x + Reuc{1}*u + Rezc{1}*z - Re0c{1};
      Wi = @(x,u,z) Rixc{1}*x + Riuc{1}*u + Rizc{1}*z - Ri0c{1};
      
      %% dummy L
      L.isEvolutive = false;
      L.nN = nx;
      L.P.userowmask = false;
      
      %% Other variables needed
      d1dt = []; d2dt = [];
      
      TusU{1}  = [eye(nu,nu),zeros(nu,nz)];
      TzsU{1} = [zeros(nz,nu),eye(nz)];
      dxdxps = {}; mask = false; Uscale = 1;
      
      % initial guess
      xs{1} = 2*ones(nx,1); us{1} = 5*ones(nu,1); zs{1} = zeros(nz,1);

      % fprintf silently drops empty arguments, which would shift all
      % following values into the wrong columns: print NaN for empties.
      orNaN = @(v) [v(:); NaN(double(isempty(v)),1)];
     
      if testCase.verbosity
        fprintf('\nRunning %s nx=%d nu=%d nz=%d neq=%d niq=%d costfun=%s, eq=%s ineq=%s\n',...
          functiontype,nx,nu,nz,neq,niq,costfunction,equalityconstraint,inequalityconstraint);
        % norms are shown for vector-valued cases, raw values for scalar ones
        if multivariable
          labels = {'|x|','|u|','|z|'};
        else
          labels = {'x','u','z'};
        end
        fprintf('%10s %10s %10s %10s %10s %10s %10s %10s %10s %10s\n',...
          'iter','|F|','W','We','Wi',labels{:},'smin(dF/dx)','convrate');
      end
      
      Fval_prev = Inf; % init: first error_ratio is -Inf
      error_ratio_prev = -Inf; % init: first conv_rate is NaN
      %% iterations
      for ii=1:testCase.nmax
        cost = W (xs{1},us{1},zs{1});
        ceq  = We(xs{1},us{1},zs{1});
        ciq  = Wi(xs{1},us{1},zs{1});
        
        Fval = F(xs{1},us{1});
        
        % log10 of the residual reduction; the ratio of successive values
        % estimates the order of convergence
        error_ratio = log10(norm(Fval)/(norm(Fval_prev)+eps*testCase.tol));
        conv_rate = error_ratio/error_ratio_prev;
        
        % Jacobians
        dFdxval  =  Jx(xs{1},us{1});
        dFduval  =  Ju(xs{1},us{1});
        dx0s{1}  = -dFdxval\Fval;    % Newton step in x at fixed u
        dxdus{1} = -dFdxval\dFduval; % sensitivity of x w.r.t. u along F=0
                
        smin = min(svd(dFdxval)); % smallest singular value of Jacobian
        
        % debug
        if testCase.verbosity
          if multivariable
            xuz = [norm(xs{1}),norm(us{1}),norm(zs{1})];
          else
            xuz = [orNaN(xs{1}),orNaN(us{1}),orNaN(zs{1})];
          end
          fprintf('%10d %10.3e %10.3e %10.3e %10.3e %10.3e %10.3e %10.3e %10.3e %10e\n',...
            ii,norm(Fval),cost,orNaN(ceq),orNaN(ciq),xuz,smin,conv_rate);
        end
        
        % convergence check
        if norm(Fval)<testCase.tol, break; end
        
        % Evaluate residuals
        [Ra,Re,Ri] = meqlsqeval(L,...
          Raxc,Rauc,Razc,Ra0c,Rexc,Reuc,Rezc,Re0c,Rixc,Riuc,Rizc,Ri0c,...
          d1dt,d2dt,xs,us,zs);
        
        % Non-linear step
        
        % Reduced cost function
        [Rat,RaUt,Ret,ReUt,Rit,RiUt] = meqlsqprep(L,...
          Raxc,Rauc,Razc,Ra,Rexc,Reuc,Rezc,Re,Rixc,Riuc,Rizc,Ri,...
          TusU,TzsU,d1dt,d2dt,xs,us,zs,dx0s,dxdus,dxdxps,mask,false);
        
        % lsq solve
        [dxs,dus,dzs,~,~,s] = meqlsqsolve(L,...
          testCase.qpsolve,testCase.qpopts,...
          Rat,RaUt,Ret,ReUt,Rit,RiUt,...
          dx0s,dxdus,dxdxps,TusU,TzsU,...
          Uscale,false,testCase.ctol);
        
        % a non-empty status structure signals a solver failure
        if ~isempty(s)
          testCase.assertFail(sprintf('meqlsqsolve failed\n msg   = %s\n msgID = %s',s.msg,s.msgid))
        end
        % solution update
        xs{1} = xs{1}+dxs{1};
        us{1} = us{1}+dus{1};
        zs{1} = zs{1}+dzs{1};
        
        % update states
        error_ratio_prev = error_ratio;
        Fval_prev = Fval;
      end
      
      % Checks (all quantities are those evaluated at the start of the last iteration)
      testCase.verifyLessThan(norm(Fval),testCase.tol,'model function not satisfied')
      
      if neq
        testCase.verifyLessThan(abs(ceq),testCase.ctol,'equality constraint violation')
      end
      if niq
        % inequality residual must be non-positive up to the tolerance
        testCase.verifyLessThan(ciq,testCase.ctol,'inequality constraint violation')
      end
      
      if strcmp(functiontype,'linear')
        % expect immediate convergence: one step, detected at the start of iteration 2
        testCase.verifyEqual(ii,2,'must converge in 1 iterations for linear case')
      else
        % check convergence rate
        if ii>3
          if strcmp(costfunction,'z') && nz==0 % special case where cost function is null
            expected_convergence = 1;
          else
            expected_convergence = 2;
          end
          margin = 0.5; % margin for convergence rate
          testCase.verifyGreaterThan(conv_rate,expected_convergence-margin,'lower convergence rate than expected');
        end
      end
    end
  end
  
  methods(Static)
    function [F,Jx,Ju,nx,nu,nz] = model_function_factory(functiontype,multivariable)
      % Build the model equation F(x,u)=0 and its Jacobians.
      %
      % functiontype  : 'linear', 'nonlinear-u' or 'nonlinear-x'
      % multivariable : logical, true for nx=nu=3, false for nx=nu=1
      %
      % F,Jx,Ju  : handles @(x,u) returning F, dF/dx and dF/du
      % nx,nu,nz : sizes of x, u and z (nz is always 0)
      if ~multivariable
        nx=1;nu=1;nz=0;
        switch functiontype
          case 'linear'
            F = @(x,u) 2*x + u; % =0
            Jx = @(x,u) 2;
            Ju = @(x,u) 1;
          case 'nonlinear-u'
            F = @(x,u)  x + u^2 - 1;
            Jx = @(x,u) 1;
            Ju = @(x,u) 2*u;
          case 'nonlinear-x'
            F = @(x,u) (x+4)^2 - 2*u - 4;
            Jx = @(x,u) 2*(x+4);
            Ju = @(x,u) -2;
          otherwise
            error('undefined functiontype %s',functiontype)
        end
      else % multivariable test cases
        nx = 3; nu=3; nz=0;
        switch functiontype
          case 'linear'
            F = @(x,u) 2*x + u; % 2*x+u=0
            Jx = @(x,u) 2*eye(nx);
            Ju = @(x,u) eye(nx);
          case 'nonlinear-u'
            F = @(x,u) x + u.^2;
            Jx = @(x,u) eye(nx);
            Ju = @(x,u) 2*u.*eye(nu); % column u scales the rows of eye: diagonal Jacobian
          case 'nonlinear-x'
            F = @(x,u) (x+4).^2 - 2*u - 4;
            Jx = @(x,u) 2*(x+4).*eye(nx); % diagonal Jacobian, as above
            Ju = @(x,u) -2*eye(nu);
          otherwise
            error('undefined functiontype %s',functiontype)
        end
      end
    end
    
    function  [Raxc,Rauc,Razc,Ra0c] = cost_function_factory(costfunctiontype,nx,nu,nz)
      % Build the matrices of the cost 1/2*sum((Raxc*x+Rauc*u+Razc*z-Ra0c).^2).
      %
      % costfunctiontype : char; each of the letters 'x','u','z' it contains
      %                    activates the corresponding term
      % nx,nu,nz         : sizes of x, u and z
      
      % defaults
      % NOTE(review): Rauc defaults to nu rows while the other defaults have a
      % single row; the terms are only compatible through scalar/implicit
      % expansion. Confirm this is intended before changing it.
      Raxc = zeros(1,nx); Rauc = zeros(nu,nu); Razc = zeros(1,nz); Ra0c = -ones(1,1);
      
      if contains(costfunctiontype,'x')
        Raxc = eye(nx,nx);
      end
      if contains(costfunctiontype,'u')
        Rauc = eye(nu,nu);
      end
      if contains(costfunctiontype,'z')
        Razc = ones(1,nz);
      end
    end
    
    function [Rexc,Reuc,Rezc,Re0c,neq] = equality_constraint_factory(constrainttype,nx,nu,nz)
      % Build the matrices of the equality constraint Rexc*x+Reuc*u+Rezc*z-Re0c = 0.
      %
      % constrainttype : 'none', 'linear-x' or 'linear-u'
      % nx,nu,nz       : sizes of x, u and z
      % neq            : number of equality constraints (0 or 1)
      
      switch constrainttype
        case 'none'
          Rexc = zeros(0,nx); Reuc = zeros(0,nu); Rezc = zeros(0,nz);  Re0c = [];
          neq=0;
        case 'linear-x' % x(1) + 0.1 = 0
          neq = 1;
          Rexc = eye(neq,nx); Reuc = zeros(neq,nu); Rezc = zeros(neq,nz);  Re0c = -0.1*ones(neq,1);
        case 'linear-u' % u(1) + 0.5 = 0
          neq = 1;
          Rexc = zeros(neq,nx); Reuc = eye(neq,nu); Rezc = zeros(neq,nz);  Re0c = -0.5*ones(neq,1);
        otherwise
          error('undefined constrainttype %s',constrainttype)
      end
    end
    
    function [Rixc,Riuc,Rizc,Ri0c,niq] = inequality_constraint_factory(constrainttype,nx,nu,nz)
      % Build the matrices of the inequality residual Rixc*x+Riuc*u+Rizc*z-Ri0c,
      % which test_meq_sqp requires to be non-positive (within tolerance).
      %
      % constrainttype : 'none', 'linear-x' or 'linear-u'
      % nx,nu,nz       : sizes of x, u and z
      % niq            : number of inequality constraints (0 or 1)
      switch constrainttype
        case 'none'
          Rixc = zeros(0,nx); Riuc = zeros(0,nu); Rizc = zeros(0,nz);  Ri0c = [];
          niq=0;
        case 'linear-x' % residual x(end) + 0.2 (fliplr selects the last element)
          niq = 1;
          Rixc = fliplr(eye(niq,nx)); Riuc = zeros(niq,nu); Rizc = zeros(niq,nz);  Ri0c = -0.2*ones(niq,1);
        case 'linear-u' % residual u(end) + 1
          niq = 1;
          Rixc = zeros(niq,nx); Riuc = fliplr(eye(niq,nu)); Rizc = zeros(niq,nz);  Ri0c = -1*ones(niq,1);
        otherwise
          error('undefined constrainttype %s',constrainttype)
      end
    end
  end
end
