/* [+MEQ MatlabEQuilibrium Toolbox+] Swiss Plasma Center EPFL Lausanne 2022. All rights reserved. */
# include "meq.h"

/*
 * ipm4 -- primal-dual predictor-corrector interior-point iteration
 * (C kernel behind [x,z,kt,stat] = ipm4mex(H,c,a,x0,z0,tol,nit,verb)).
 *
 * Each iteration:
 *   1. forms U = chol(H + diag(z./x)) in packed storage;
 *   2. takes an affine (predictor) direction dx = U\(U'\(H*x + c)),
 *      dz = z - (z./x).*dx, and measures the complementarity it would reach;
 *   3. builds a centering target tau = r0^3/(r1^2*n) (this matches
 *      Mehrotra's sigma = (mu_aff/mu)^3 heuristic, with r0 = sum of
 *      predicted x.*z and r1 = current sum of x.*z);
 *   4. adds a corrector step and updates x,z with a step length limited so
 *      that x and z do not cross zero, scaled by 0.99 and capped at 1.
 * It stops when the residual max(|H*x+c-z|, |(e+z).*x|) or the step size
 * max|dx| drops below tol.
 * NOTE(review): this reads like the KKT system of
 *   min 1/2 x'Hx + c'x  s.t.  a*x >= 0   (z = multipliers)  -- confirm.
 *
 * x     out, size n: final primal iterate.
 * z     out, size n: final dual iterate.
 * kt    out: 1-based index of the iteration at which the loop stopped
 *       (convergence or failure); nit+1 when all nit iterations ran
 *       without meeting the tolerance (1 when nit<=0).
 * stat  out: meq_true if res<tol or max|dx|<tol was reached; meq_false if
 *       some x(k)==0, the Cholesky factorization failed, nit<=0, or the
 *       iteration budget was exhausted.
 * h     in, size n*(n+1)/2: H in packed storage. Diagonal entry k sits at
 *       h[k*(k+3)/2], i.e. upper-triangular column-major packing
 *       (presumably the LAPACK 'U' layout expected by PTRF/PTRS -- confirm).
 * c     in, size n: linear term.
 * a     in: only its sign matters; components with a*dx>0 (resp. a*dz>0)
 *       limit the step length.
 * x0,z0 in, size n: starting iterates (copied into x and z).
 * tol   in: tolerance on both the residual and the step size.
 * nit   in: maximum number of iterations.
 * verb  in: print one progress line per iteration when true.
 * w     workspace, size n*(n+1)/2 + 4*n.
 * n     problem size.
 */
void FLT_NAME(ipm4)(FLT *x, FLT *z, int *kt, meq_bool *stat, FLT *h, FLT *c,
                    FLT a, FLT *x0, FLT *z0, FLT tol, int nit, meq_bool verb, FLT *w, int n)
{
 /* [x,z,kt,stat] = ipm4mex(H,c,a,x0,z0,tol,nit,verb) */
 int nh=n*(n+1)/2, kit, k;
 FLT alpha, r0, r1, res;
 FLT *u  =  w     , /* size nh+4*n: packed Cholesky factor */
     *d  =  u + nh, /* z./x */
     *e  =  d + n , /* corrector right-hand side */
     *dx =  e + n ,
     *dz = dx + n ;
 /* x = x0 */
 COPY(n,x,x0);
 /* z = z0 */
 COPY(n,z,z0);
 /* Define the status output even when the loop body never runs (nit<=0). */
 *stat = meq_false;
 /* for kit=nit:-1:1 */
 for (kit=0; kit<nit; kit++) {
  /* if any(x==0) fail: d = z./x below would divide by zero. */
  *stat = meq_true;
  for (k=n;k--;) *stat &= (x[k] != FLTC(0.0));
  /* break rather than return so that *kt is still written below. */
  if (!*stat) break;
  /* U = H */
  COPY(nh,u,h);
  /* d = z./x; */
  /* U = U + diag(d)   (k*(k+3)/2 is the packed index of diagonal entry k) */
  for (k=n;k--;) {
   d[k] = z[k]/x[k];
   u[k*(k+3)/2] += d[k];
  }
  /* U = chol(U); on failure stop with stat=false (and *kt still set). */
  if (PTRF(n,u) != 0) { *stat = meq_false; break; }
  /* dx = c */
  COPY(n,dx,c);
  /* dx = H*x + dx */
  SPMV(n,h,x,dx);
  /* res = max(abs(H*x+c-z)) */
  res = FLTC(0.0);
  for (k=n;k--;) {
   r0 = ABS(dx[k]-z[k]); if (r0 > res) res = r0;
  }
  /* Predictor: dx = U\(U'\dx) */
  PTRS(n,u,dx);
  /* dz = z - d.*dx */
  /* alpha = min(1,x/dx) */
  /* alpha = min(1,z/dz) */
  /* Only components moving towards zero (a*dx>0, a*dz>0) limit the step. */
  alpha = FLTC(1.0);
  for (k=n;k--;) {
   dz[k] = z[k] - d[k]*dx[k];
   if (a*dx[k] > FLTC(0.0)) { r0 = x[k]/dx[k]; if (r0 < alpha) alpha = r0; }
   if (a*dz[k] > FLTC(0.0)) { r0 = z[k]/dz[k]; if (r0 < alpha) alpha = r0; }
  }
  /* r0 = sum((x-alpha*dx).*(z-alpha*dz))   predicted complementarity */
  /* r1 = sum(x.*z)                         current complementarity   */
  r0 = r1 = FLTC(0.0);
  for (k=n;k--;) {
   r0 += (x[k] - alpha*dx[k]) * (z[k] - alpha*dz[k]);
   r1 +=  x[k] * z[k];
  }
  /* tau = r0^3/(r1^2*n)   centering target */
  r1 = r0/r1;               /* Since r0 and r1 can be both very small,     */
  r0 = (r1*r1)*(r0/(FLT)n); /* this is more precise and less prone to NaNs */
  /* e = (dx.*dz-tau)./x */
  for (k=n;k--;) dz[k] = e[k] = (dx[k]*dz[k]-r0)/x[k];
  /* res = max(res,(e+z).*x) */
  for (k=n;k--;) {
   r0 = ABS((e[k]+z[k])*x[k]); if (r0 > res) res = r0;
  }
  /* Corrector: dz = U\(U'\e) */
  PTRS(n,u,dz);
  /* dx = dx + dz */
  /* dz = e + z - d.*dx */
  /* alpha = min(1,x/dx) */
  /* alpha = min(1,z/dz) */
  alpha = MEQ_HUGE;
  for (k=n;k--;) {
   dx[k] += dz[k];
   dz[k]  = e[k] + z[k] - d[k]*dx[k];
   if (a*dx[k] > FLTC(0.0)) { r0 = x[k]/dx[k]; if (r0 < alpha) alpha = r0; }
   if (a*dz[k] > FLTC(0.0)) { r0 = z[k]/dz[k]; if (r0 < alpha) alpha = r0; }
  }
  /* Stay strictly inside the boundary (factor 0.99), never exceed a full step. */
  alpha *= FLTC(0.99);
  if (FLTC(1.0)<alpha) {alpha = FLTC(1.0);}
  /* x = x - alpha*dx */
  /* z = z - alpha*dz */
  /* r0 = max(abs(dx)) */
  r0 = FLTC(0.0);
  for (k=n;k--;) {
   x[k] -= alpha*dx[k];
   z[k] -= alpha*dz[k];
   r1 = ABS(dx[k]);
   if (r1 > r0) r0 = r1;
  }
  if (verb)
    printf("ipm-4 it=%3d res=%5.2e step=%5.2e tol=%5.2e\n",kit,res,r0,tol);
  /* stat = res<tol; if stat, break, end */
  if ((*stat = (res < tol))) break;
  /* stat = r0<tol; if stat, break, end */
  if ((*stat = (r0 < tol))) break;
 }
 /* Single exit: reached on convergence, failure (break) or exhaustion
    (kit==nit, so kt==nit+1). */
 *kt = kit+1;
}
