/* [+MEQ MatlabEQuilibrium Toolbox+] Swiss Plasma Center EPFL Lausanne 2022. All rights reserved. */
# include "meq.h"

/* Precision selection: when SINGLE is defined the routine below is compiled as
   sipm1 and bound to the single-precision CBLAS/LAPACKE entry points; otherwise
   it is compiled as dipm1 and bound to the double-precision ones. */
# ifdef SINGLE
# define FCT sipm1
# define FCOPY   cblas_scopy
# define FPTRF LAPACKE_spptrf
# define FPTRS LAPACKE_spptrs
# define FSPMV   cblas_sspmv
# define ABS fabsf
# else
# define FCT dipm1
# define FCOPY   cblas_dcopy
# define FPTRF LAPACKE_dpptrf
# define FPTRS LAPACKE_dpptrs
# define FSPMV   cblas_dspmv
# define ABS fabs
# endif

/* Thin wrappers fixing the layout used throughout this file: unit strides,
   column-major storage and upper-triangular packed matrices ('U'). Note that
   COPY takes its destination first, i.e. COPY(n,dst,src). */
# define COPY(n,y,x)   FCOPY(n,x,1,y,1)                                    /* y = x (n elements) */
# define PTRF(n,u)     FPTRF(LAPACK_COL_MAJOR,'U',n,u)                     /* U = chol(U), in place; returns the LAPACKE status */
# define PTRS(n,u,x)   FPTRS(LAPACK_COL_MAJOR,'U',n,1,u,x,n)               /* x = U\(U'\x), in place, single right-hand side */
# define SPMV(n,u,x,y) FSPMV(CblasColMajor,CblasUpper,n,1.0,u,x,1,1.0,y,1) /* y = U*x+y, U symmetric packed */ 

/*
 * FCT (sipm1 / dipm1): predictor-corrector interior-point iteration on the
 * pair (x,z), using the symmetric packed matrix H and the vector c.
 *
 * Each iteration factorises U = chol(H + diag(z./x)), computes a first step
 * from H*x + c, derives a centering term tau from the complementarity
 * products sum(x.*z) before and after that trial step, adds a correction
 * step, and finally updates x and z with a step length limited by ratio
 * tests and damped by 0.99.
 *
 * Outputs
 *   x    [n]  iterate, initialised from x0 and updated in place.
 *   z    [n]  companion iterate, initialised from z0 and updated in place.
 *   stat      set to 1 when max(abs(dx)) < tol was reached; set to 0 when
 *             some x[k] is exactly zero, when the Cholesky factorisation
 *             fails, when nit iterations elapse without meeting tol, or
 *             when nit <= 0 (no iteration performed).
 * Inputs
 *   h    [n*(n+1)/2]  matrix H in upper packed column-major storage
 *                     (diagonal element k is at index k*(k+3)/2); not modified.
 *   c    [n]  linear term; not modified.
 *   a         scalar multiplying the step components in the ratio tests;
 *             presumably +1/-1 selecting the sign kept by x and z - TODO
 *             confirm against the caller.
 *   x0   [n]  initial x.
 *   z0   [n]  initial z.
 *   tol       convergence threshold on max(abs(dx)).
 *   nit       maximum number of iterations.
 *   w    [n*(n+1)/2 + 4*n]  caller-provided work space.
 *   n         problem size.
 */
void FCT(FLT *x, FLT *z, bool *stat, FLT *h, FLT *c, FLT a, FLT *x0, FLT *z0, FLT tol, int nit, FLT *w, int n)

{
	/* [x,z,stat] = ipm1mex(H,c,a,x0,z0,tol,nit) */
 int nh=n*(n+1)/2, kit, k;
 FLT alpha, r0, r1;
 /* Work space layout: packed factor followed by four length-n vectors */
 FLT *u  =  w     , /* size nh+4*n */
     *d  =  u + nh,
     *e  =  d + n ,
     *dx =  e + n ,
     *dz = dx + n ;
 /* Not converged until proven otherwise. This also gives stat a defined
    value when nit <= 0 and the loop below never runs. */
 *stat = 0;
 /* x = x0 */
 COPY(n,x,x0);
 /* z = z0 */
 COPY(n,z,z0);
 /* for kit=nit:-1:1 (a non-positive nit performs no iteration) */
 for (kit=nit; kit>0; kit--) {
  /* if any(x==0) return; (d = z./x below would divide by zero) */
  for (k=n;k--;) {
   if (x[k] == FLTC(0.0)) { *stat = 0; return; }
  }
  /* U = H */
  COPY(nh,u,h);
  /* d = z./x; */
  /* U = U + diag(d) */
  for (k=n;k--;) {
   d[k] = z[k]/x[k];
   u[k*(k+3)/2] += d[k]; /* packed index of diagonal element (k,k) */
  }
  /* U = chol(U); a non-zero status means the factorisation failed */
  if (PTRF(n,u) != 0) { *stat = 0; return; }
  /* dx = c */
  COPY(n,dx,c);
  /* dx = H*x + dx */
  SPMV(n,h,x,dx);
  /* dx = U\(U'\dx) */
  PTRS(n,u,dx);
  /* dz = z - d.*dx */
  /* alpha = min(1,x/dx) */
  /* alpha = min(1,z/dz) */
  alpha = FLTC(1.0);
  for (k=n;k--;) {
   dz[k] = z[k] - d[k]*dx[k];
   if (a*dx[k] > FLTC(0.0)) { r0 = x[k]/dx[k]; if (r0 < alpha) alpha = r0; }
   if (a*dz[k] > FLTC(0.0)) { r0 = z[k]/dz[k]; if (r0 < alpha) alpha = r0; }
  }
  /* r0 = sum((x-alpha*dx).*(z-alpha*dz)) */
  /* r1 = sum(x.*z) */
  r0 = r1 = FLTC(0.0);
  for (k=n;k--;) {
   r0 += (x[k] - alpha*dx[k]) * (z[k] - alpha*dz[k]);
   r1 +=  x[k] * z[k];
  }
  /* tau = r0^3/(r1^2*n), left in r0 (same form as Mehrotra's centering heuristic) */
  r1 = r0/r1;               /* Since r0 and r1 can be both very small,     */
  r0 = (r1*r1)*(r0/(FLT)n); /* this is more precise and less prone to NaNs */
  /* e = (dx.*dz-tau)./x */
  /* dz = U\(U'\e)  (dz temporarily holds the correction to dx) */
  for (k=n;k--;) dz[k] = e[k] = (dx[k]*dz[k]-r0)/x[k];
  PTRS(n,u,dz);
  /* dx = dx + dz */
  /* dz = e + z - d.*dx */
  /* alpha = min(1,x/dx) */
  /* alpha = min(1,z/dz) */
  alpha = FLTC(1.0);
  for (k=n;k--;) {
   dx[k] += dz[k];
   dz[k]  = e[k] + z[k] - d[k]*dx[k];
   if (a*dx[k] > FLTC(0.0)) { r0 = x[k]/dx[k]; if (r0 < alpha) alpha = r0; }
   if (a*dz[k] > FLTC(0.0)) { r0 = z[k]/dz[k]; if (r0 < alpha) alpha = r0; }
  }
  /* alpha = 0.99 * alpha (stay strictly short of the limiting step) */
  alpha *= FLTC(0.99);
  /* x = x - alpha*dx */
  /* z = z - alpha*dz */
  /* r0 = max(abs(dx)) */
  /* NOTE(review): this comment previously read max(abs(dz)) while the code
     measured dx; the code is kept as is - confirm which one is intended. */
  r0 = FLTC(0.0);
  for (k=n;k--;) {
   x[k] -= alpha*dx[k];
   z[k] -= alpha*dz[k];
   r1 = ABS(dx[k]);
   if (r1 > r0) r0 = r1;
  }
  /* stat = r0<tol; if stat, break, end */
  if ((*stat = (r0 < tol))) break;
 }
}
