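%IPM  Interior point method for quadratic minimisation
% [X,Y,S,Z,KT] = IPM(H,C,A,B,AE,BE,X,Y,S,Z,MODE[,NIT,TOL])
% finds X such that min X'*H*X/2+C'*X, A*X>B, AE*X=BE using the interior
% point method. X,Y,S,Z are initial guesses (pass [] for the default) for
% the primal variables, the equality multipliers, the inequality slacks and
% the inequality multipliers respectively (MODE 2 ignores S and Z).
% NIT (default 20) is the maximum number of iterations and TOL (default
% 1e-6) the convergence tolerance. KT returns the number of iterations
% performed. MODE selects the variant:
% 0 brute force implementation (full KKT system)
% 1 optimised implementation (reduced system, Cholesky factorisations)
% 2 no equality constraints: min X'*H*X/2+C'*X, A*X>B
% 3 min X'*H*X/2+C'*X, X>0
% 4 min X'*H*X/2+C'*X, A(1)*X>0
% 5 min X'*H*X/2+C'*X, A(1)*X(B(1))>0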
%
% For details, see: [MEQ-redbook] 
%
% [+MEQ MatlabEQuilibrium Toolbox+] Swiss Plasma Center EPFL Lausanne 2022. All rights reserved.
function [x,y,s,z,kt] = ipm(H,c,A,b,Ae,be,x,y,s,z,mode,niter,tol)
 
 if nargin < 12, niter = 20; end % default maximum number of iterations
 if nargin < 13, tol = 1e-6; end % default convergence tolerance
 n  = size(H ,1);
 
 % Every mode performs a predictor-corrector step: an affine (predictor)
 % direction gives the trial mean complementarity tau1, the centring
 % parameter is sigma = (tau1/tau)^3, a corrected direction including the
 % second-order term ds.*dz is computed with the same factorisation, and a
 % step of 0.99 times the largest step keeping slacks and multipliers
 % positive is taken. tau is the mean complementarity product.
 switch mode
  case 0 % min x'*H*x/2+c'*x, A*x>b, Ae*x=be, full KKT system
   if isempty(A ), A  = zeros(0,n); b  = zeros(0,1); end
   if isempty(Ae), Ae = zeros(0,n); be = zeros(0,1); end
   ni = size(A ,1);
   ne = size(Ae,1);
   if isempty(x), x = ones(n ,1);                  end
   if isempty(y), y = ones(ne,1);                  end
   if isempty(s), s = ones(ni,1);                  end
   if isempty(z), z = ones(ni,1);                  end
   % KKT matrix acting on [dx;dy;ds;dz]. The complementarity rows are
   % divided by s; their diag(z./s) block is refreshed every iteration.
   D = [
    H          Ae'          zeros(n,ni) -A'      ;
    Ae         zeros(ne,ne+ ni+         ni)      ;
    zeros(ni,n +ne)         diag(z./s)  eye(ni)  ;
    -A         zeros(ni,ne) eye(ni)     zeros(ni)];
   for kt = 1:niter
    rd = H*x + c - A'*z + Ae'*y; % stationarity residual
    re = Ae*x - be;              % equality residual
    rc = z;                      % complementarity residual z.*s divided by s
    rp = s - A*x + b;            % inequality residual
    D(n+ne+1:n+ne+ni,n+ne+1:n+ne+ni) = diag(z./s);
    % Predictor (affine) direction
    d = D \ [-rd ; -re ; -rc ; -rp];
    ds = d(   ne+n+1:n+ne+ni);
    dz = d(ni+ne+n+1:end    );
    k = find(ds < 0); alphas = min([1 ; -s(k)./ds(k)]);
    k = find(dz < 0); alphaz = min([1 ; -z(k)./dz(k)]);
    alpha = min(alphas, alphaz);
    tau = s'*z/ni;
    tau1 = (s + alpha*ds)'*(z + alpha*dz)/ni;
    sigma = (tau1/tau).^3;
    
    % Corrector direction (second-order term and centring)
    d = D \ [-rd ; -re ; -rc-(ds.*dz-sigma*tau)./s ; -rp];
    dx = d(        1:n      );
    dy = d(      n+1:n+ne   );
    ds = d(   ne+n+1:n+ne+ni);
    dz = d(ni+ne+n+1:end    );
    k = find(ds < 0); alphas = min([1 ; -s(k)./ds(k)]);
    k = find(dz < 0); alphaz = min([1 ; -z(k)./dz(k)]);
    alpha = 0.99*min([alphas alphaz]);
    x = x + alpha*dx;
    y = y + alpha*dy;
    s = s + alpha*ds;
    z = z + alpha*dz;
    
    % The equality residual is part of the test so Ae*x=be is enforced too
    if norm([rd;re;rp],Inf) < tol, break, end
   end
   
  case 1 % min x'*H*x/2+c'*x, A*x>b, Ae*x=be, optimised solution
   if isempty(A ), A  = zeros(0,n); b  = zeros(0,1); end
   if isempty(Ae), Ae = zeros(0,n); be = zeros(0,1); end
   ni = size(A ,1);
   ne = size(Ae,1);
   if isempty(x), x = ones(n ,1);                  end
   if isempty(y), y = ones(ne,1);                  end
   if isempty(s), s = ones(ni,1);                  end
   if isempty(z), z = ones(ni,1);                  end
   for kt = 1:niter
    rd = H*x + c - A'*z + Ae'*y; % stationarity residual
    re = Ae*x - be;              % equality residual
    rc = z.*s;                   % complementarity residual
    rp = s - A*x + b;            % inequality residual
    d = z./s;
    % ds and dz are eliminated, leaving
    %   (H + A'*diag(d)*A)*dx + Ae'*dy = xd,  Ae*dx = -re
    % D and De are the Cholesky factors of the two reduced matrices.
    % bsxfun(@times,d,A) equals diag(d)*A without forming the ni-by-ni diag.
    D = chol(H + A'*bsxfun(@times,d,A));
    De = chol(Ae*(D\(D'\Ae'))); % See eq. (5.16) in MEQ-Redbook
    
    % Predictor (affine) direction
    xc = -rc./z;
    xp = d.*(rp+xc);
    xd = A'*xp-rd;
    xe = re+Ae*(D\(D'\xd));
    dy = De\(De'\xe);
    dx = D\(D'\(xd-Ae'*dy));
    dz = xp - d.*(A*dx);
    ds = xc-s./z.*dz;
    k = find(ds < 0); alphas = min([1 ; -s(k)./ds(k)]);
    k = find(dz < 0); alphaz = min([1 ; -z(k)./dz(k)]);
    alpha = min(alphas, alphaz);
    tau = s'*z/ni;
    tau1 = (s + alpha*ds)'*(z + alpha*dz)/ni;
    sigma = (tau1/tau).^3;
    
    % Corrector direction (second-order term and centring)
    rc = z.*s+(ds.*dz-sigma*tau);
    xc = -rc./z;
    xp = d.*(rp+xc);
    xd = A'*xp-rd;
    xe = re+Ae*(D\(D'\xd));
    dy = De\(De'\xe);
    dx = D\(D'\(xd-Ae'*dy));
    dz = xp - d.*(A*dx);
    ds = xc-s./z.*dz;
    k = find(ds < 0); alphas = min([1 ; -s(k)./ds(k)]);
    k = find(dz < 0); alphaz = min([1 ; -z(k)./dz(k)]);
    alpha = 0.99*min([alphas alphaz]);
    x = x + alpha*dx;
    y = y + alpha*dy;
    s = s + alpha*ds;
    z = z + alpha*dz;
    % The equality residual is part of the test so Ae*x=be is enforced too
    if norm([rd;re;rp],Inf) < tol, break, end
   end
   
  case 2 % min x'*H*x/2+c'*x, A*x>b
   if isempty(A ), A  = zeros(0,n); b  = zeros(0,1); end
   ni = size(A ,1);
   if isempty(x), x = ones(n ,1);                  end
   % Supplied S and Z are not used: both are reset from X so that the
   % slacks (and the multipliers) start at least at 1.
   s = max(1,A*x-b); z = s;
   for kt = 1:niter
    rd = H*x + c - A'*z; % stationarity residual
    rc = z.*s;           % complementarity residual
    rp = s - A*x + b;    % inequality residual
    d = z./s;
    % Reduced matrix H + A'*diag(d)*A without forming the ni-by-ni diag(d)
    D = chol(H + A'*bsxfun(@times,d,A));
    
    % Predictor (affine) direction
    xc = -rc./z;
    xp = d.*(rp+xc);
    xd = A'*xp-rd;
    dx = D\(D'\xd);
    dz = xp - d.*(A*dx);
    ds = xc-s./z.*dz;
    k = find(ds < 0); alphas = min([1 ; -s(k)./ds(k)]);
    k = find(dz < 0); alphaz = min([1 ; -z(k)./dz(k)]);
    alpha = min(alphas, alphaz);
    tau = s'*z/ni;
    tau1 = (s + alpha*ds)'*(z + alpha*dz)/ni;
    sigma = (tau1/tau).^3;
    
    % Corrector direction (second-order term and centring)
    rc = z.*s+(ds.*dz-sigma*tau);
    xc = -rc./z;
    xp = d.*(rp+xc);
    xd = A'*xp-rd;
    dx = D\(D'\xd);
    dz = xp - d.*(A*dx);
    ds = xc-s./z.*dz;
    k = find(ds < 0); alphas = min([1 ; -s(k)./ds(k)]);
    k = find(dz < 0); alphaz = min([1 ; -z(k)./dz(k)]);
    alpha = 0.99*min([alphas alphaz]);
    x = x + alpha*dx;
    s = s + alpha*ds;
    z = z + alpha*dz;
    if norm([rd;rp],Inf) < tol, break, end
   end
   
  case 3 % min x'*H*x/2+c'*x, x>0 (x plays the role of the slack)
   if isempty(x), x = ones(n,1);             end
   if isempty(z), z = ones(n,1);             end
   for kt = 1:niter
    rd = H*x + c - z; % stationarity residual
    rc = z.*x;        % complementarity residual
    d = z./x;
    D = chol(H + diag(d));
    
    % Predictor (affine) direction
    xc = -rc./x;
    xd = xc-rd;
    dx = D\(D'\xd);
    dz = xc - d.*dx;
    k = find(dx < 0); alphas = min([1 ; -x(k)./dx(k)]);
    k = find(dz < 0); alphaz = min([1 ; -z(k)./dz(k)]);
    alpha = min(alphas, alphaz);
    tau = x'*z/n;
    tau1 = (x + alpha*dx)'*(z + alpha*dz)/n;
    sigma = (tau1/tau).^3;
    
    % Corrector direction (second-order term and centring)
    rc = z.*x+(dx.*dz-sigma*tau);
    xc = -rc./x;
    xd = xc-rd;
    dx = D\(D'\xd);
    dz = xc - d.*dx;
    k = find(dx < 0); alphas = min([1 ; -x(k)./dx(k)]);
    k = find(dz < 0); alphaz = min([1 ; -z(k)./dz(k)]);
    alpha = 0.99*min([alphas alphaz]);
    x = x + alpha*dx;
    z = z + alpha*dz;
    if norm(rd,Inf) < tol, break, end
    
   end

  case 4 % min x'*H*x/2+c'*x, a*x>0 with a = sign(A(1))
   % Solved in the variable s = a*x (a = +-1), so x'*H*x = s'*H*s and
   % c'*x = (a*c)'*s; x is recovered as a*s at the end.
   A = sign(A(1));
   if isempty(x), s = ones(n,1); else, s = A*x; end
   if isempty(z), z = ones(n,1);                end
   for kt = 1:niter
    rd = H*s + A*c - z; % stationarity residual in s
    rc = z.*s;          % complementarity residual
    d = z./s;
    D = chol(H + diag(d));
    
    % Predictor (affine) direction
    xc = -rc./s;
    xd = xc-rd;
    ds = D\(D'\xd);
    dz = xc - d.*ds;
    k = find(ds < 0); alphas = min([1 ; -s(k)./ds(k)]);
    k = find(dz < 0); alphaz = min([1 ; -z(k)./dz(k)]);
    alpha = min(alphas, alphaz);
    tau = s'*z/n;
    tau1 = (s + alpha*ds)'*(z + alpha*dz)/n;
    sigma = (tau1/tau).^3;
    
    % Corrector direction (second-order term and centring)
    rc = z.*s+(ds.*dz-sigma*tau);
    xc = -rc./s;
    xd = xc-rd;
    ds = D\(D'\xd);
    dz = xc - d.*ds;
    k = find(ds < 0); alphas = min([1 ; -s(k)./ds(k)]);
    k = find(dz < 0); alphaz = min([1 ; -z(k)./dz(k)]);
    alpha = 0.99*min([alphas alphaz]);
    s = s + alpha*ds;
    z = z + alpha*dz;
    if norm(ds,Inf) < tol, break, end
    
   end
   x = A*s;
  
  case 5 % min x'*H*x/2+c'*x, a*x(b)>0 with a = sign(A(1))
   % One inequality on component b: s = a*x(b) is a scalar slack and z its
   % scalar multiplier.
   A = sign(A(1));
   if isempty(x), x = ones(n,1); end
   if isempty(z), z = 1;         end
   s = max(1,A*x(b)); x(b) = s/A; % start strictly inside the constraint
   for kt = 1:niter
    rd = H*x + c; rd(b) = rd(b) - A*z; % stationarity residual
    rc = z.*s;                         % complementarity residual
    d = z./s;
    D = H; D(b,b) = D(b,b) + d;
    D = chol(D);
    
    % Predictor (affine) direction
    xc = -rc./s;
    xd = -rd; xd(b) = xd(b) + A*xc;
    dx = D\(D'\xd);
    ds = A*dx(b);
    dz = xc - d.*ds;
    k = find(ds < 0); alphas = min([1 ; -s(k)./ds(k)]);
    k = find(dz < 0); alphaz = min([1 ; -z(k)./dz(k)]);
    alpha = min(alphas, alphaz);
    % There is a single complementarity pair, so the mean complementarity
    % is s*z itself (dividing by n would shrink the centring target).
    tau = s*z;
    tau1 = (s + alpha*ds)*(z + alpha*dz);
    sigma = (tau1/tau).^3;
    
    % Corrector direction (second-order term and centring)
    rc = z.*s+(ds.*dz-sigma*tau);
    xc = -rc./s;
    xd = -rd; xd(b) = xd(b) + A*xc;
    dx = D\(D'\xd);
    ds = A*dx(b);
    dz = xc - d.*ds;
    k = find(ds < 0); alphas = min([1 ; -s(k)./ds(k)]);
    k = find(dz < 0); alphaz = min([1 ; -z(k)./dz(k)]);
    alpha = 0.99*min([alphas alphaz]);
    z = z + alpha*dz;
    % NOTE(review): only component b is damped by alpha (A*A = 1, so this
    % is alpha*dx(b)); the other components take the full step dx, unlike
    % the other modes -- confirm this is intended.
    dx(b) = A*alpha*ds;
    x = x + dx;
    s = A*x(b);
    if norm(ds,Inf) < tol, break, end
    
   end
  
  otherwise
   error('ipm:mode','Unknown MODE %g',mode)
 end
