/* [+MEQ MatlabEQuilibrium Toolbox+] Swiss Plasma Center EPFL Lausanne 2022. All rights reserved. */
# include "meq.h"

# ifdef SINGLE
# define FCT   sipm2
# define FCOPY cblas_scopy
# define FPTRF LAPACKE_spptrf
# define FPTRS LAPACKE_spptrs
# define FSPMV cblas_sspmv
# define FGEMV cblas_sgemv
# define FAXPY cblas_saxpy
# define FVADD vsAdd
# define FVSUB vsSub
# define FVMUL vsMul
# define FVDIV vsDiv
# define ABS fabsf
# else
# define FCT   dipm2
# define FCOPY cblas_dcopy
# define FPTRF LAPACKE_dpptrf
# define FPTRS LAPACKE_dpptrs
# define FSPMV cblas_dspmv
# define FGEMV cblas_dgemv
# define FAXPY cblas_daxpy
# define FVADD vdAdd
# define FVSUB vdSub
# define FVMUL vdMul
# define FVDIV vdDiv
# define ABS fabs
# endif

# define COPY(n,y,x)         FCOPY(n,x,1,y,1)                                      /* y = x          */
# define PTRF(n,u)           FPTRF(LAPACK_COL_MAJOR,'U',n,u)                       /* U = chol(U)    */
# define PTRS(n,u,x)         FPTRS(LAPACK_COL_MAJOR,'U',n,1,u,x,n)                 /* x = U\(U'\x)   */
# define SPMV(n,u,x,y)       FSPMV(CblasColMajor,CblasUpper,n,1.0,u,x,1,1.0,y,1)   /* y = U*x+y      */ 
# define GNMV(m,n,a,A,x,b,y) FGEMV(CblasColMajor,CblasNoTrans,m,n,a,A,m,x,1,b,y,1) /* y = a*A*x+b*y  */ 
# define GTMV(m,n,a,A,x,b,y) FGEMV(CblasColMajor,CblasTrans  ,m,n,a,A,m,x,1,b,y,1) /* y = a*A'*x+b*y */
# define VADD(n,z,x,y)       FVADD(n,x,y,z)                                        /* z = x+y        */
# define VSUB(n,z,x,y)       FVSUB(n,x,y,z)                                        /* z = x-y        */
# define VMUL(n,z,x,y)       FVMUL(n,x,y,z)                                        /* z = x.*y       */
# define VDIV(n,z,x,y)       FVDIV(n,x,y,z)                                        /* z = x./y       */
# define AXPY(n,y,a,x)       FAXPY(n,a,x,1,y,1)                                    /* y = a*x + y    */

/*
 * Primal-dual predictor-corrector interior point iteration, C port of
 *   [x,s,z,stat] = ipm2(uH,c,A,b,x0,s0,z0,tol,nit)
 * compiled as sipm2 (SINGLE defined) or dipm2 (double precision).
 *
 * At every iteration the residuals
 *   rd = H*x + c - A'*z
 *   rp = b - A*x
 * are formed, and the Newton systems are reduced to (H + A'*diag(z./s)*A)*dx = rhs,
 * which is solved with a packed Cholesky factorisation (one affine predictor
 * solve, one corrector solve). NOTE(review): from these residuals the problem
 * appears to be  min 0.5*x'*H*x + c'*x  s.t.  A*x - s = b, s >= 0  -- confirm
 * against ipm2.m.
 *
 * Outputs:
 *   x    [n]   primal iterate
 *   s    [ni]  slack iterate
 *   z    [ni]  dual iterate
 *   stat       true when norm([rd; s+rp],Inf) < tol. False when:
 *              - the iteration limit is reached (nit <= 0 runs no iteration);
 *              - some s is zero;
 *              - z./s is not finite;
 *              - the Cholesky factorisation fails;
 *              - the convergence residual contains a NaN.
 *              On failure x, s, z hold the last iterate.
 * Inputs:
 *   h    [n*(n+1)/2]  H, upper triangle packed column-major (LAPACK 'U')
 *   c    [n]
 *   a    [ni x n]     A, column-major with leading dimension ni
 *   b    [ni]
 *   x0, s0, z0        initial iterates ([n], [ni], [ni])
 *   tol               tolerance on the infinity norm of [rd; s+rp]
 *   nit               maximum number of iterations
 *   w    [n*(n+1)/2 + 2*n + 6*ni]  workspace
 *   n, ni             length of x / number of rows of A (length of b, s, z)
 */
void FCT(FLT *x, FLT *s, FLT *z, bool *stat, FLT *h, FLT *c, FLT *a, FLT *b,
         FLT *x0, FLT *s0, FLT *z0, FLT tol, int nit, FLT *w, int n, int ni)

{
 int nh=n*(n+1)/2, kit, k, l, m;
 FLT r0, r1, r2;
 FLT *u  =  w     , /* nh */
     *d  =  u + nh, /* ni */
     *rd =  d + ni, /* n  */
     *rp = rd + n , /* ni */
     *xp = rp + ni, /* ni */
     *dx = xp + ni, /* n  */
     *ds = dx + n , /* ni */
     *dz = ds + ni, /* ni */
     *ec = dz + ni  /* ni */
     ;              /* nh + 2*n + 6*ni */
 /* Give stat a defined value even if no iteration is performed (nit <= 0). */
 *stat = 0;
 /* x = x0 */
 COPY(n,x,x0);
 /* s = s0 */
 COPY(ni,s,s0);
 /* z = z0 */
 COPY(ni,z,z0);
 /* for kit=nit:-1:1 -- tested with kit > 0 so that a negative nit performs no
    iteration instead of decrementing a signed int until it overflows */
 for (kit=nit; kit>0; kit--) {
  /* if any(s==0), stat = false; return; end */
  *stat = 1;
  for (k=ni;k--;) *stat &= s[k] != FLTC(0.0);
  if (!*stat) return;
  /* d = z./s; */
  VDIV(ni,d,z,s);
  /* if any(~isfinite(d)), stat = false; return, end */
  *stat = 1;
  for (k=ni;k--;) *stat &= !(isinf(d[k]) || isnan(d[k]));
  if (!*stat) return;
  /* U = H */
  COPY(nh,u,h);
  /* U = U + A'*diag(d)*A, accumulated directly into the packed upper triangle:
     column j of U holds rows 0..j, entry (i,j) = sum_m A(m,j)*d(m)*A(m,i).
     ap1 walks column j of A, ap2 walks columns 0..j of A. */
  { FLT *up = u, *ap0, *ap1, *ap2, *dp;
  for (k=n , ap1=a ; k-- ; )
   for (l=n-k, ap0=ap1, ap2=a ; l-- ; up++)
    for (m=ni, ap1=ap0, dp=d ; m-- ; )
     *up += *ap1++ * *dp++ * *ap2++;
  }
  /* U = chol(U) */
  if (PTRF(n,u) != 0) { *stat = 0; return; }
  /* rd = H*x + c - A'*z; */
  COPY(n,rd,c);               /* rd = c;         */
  SPMV(n,h,x,rd);             /* rd = H*x + rd;  */
  GTMV(ni,n,-1.0,a,z,1.0,rd); /* rd = -A'*z + rd */
  /* rp = b - A*x; */
  COPY(ni,rp,b);              /* rp = b;         */
  GNMV(ni,n,-1.0,a,x,1.0,rp); /* rp = -A*x + rp; */
  /* search direction (affine predictor) */
  /* xp = d.*rp; */
  VMUL(ni,xp,d,rp);
  /* dx = A'*xp - rd; */
  COPY(n,dx,rd);               /* dx = rd;         */
  GTMV(ni,n,1.0,a,xp,-1.0,dx); /* dx = A'*xp - dx; */
  /* dx = U\(U'\dx) */
  PTRS(n,u,dx);
  /* dz = A*dx - rp; */
  COPY(ni,dz,rp);              /* dz = rp;        */
  GNMV(ni,n,1.0,a,dx,-1.0,dz); /* dz = A*dx - dz; */
  /* ds = s - dz; */
  VSUB(ni,ds,s,dz);
  /* k = (ds > 0); alphas = min([1 ; s(k)./ds(k)]); */
  /* k = (dz > 0); alphaz = min([1 ; s(k)./dz(k)]); */
  /* alpha = min(alphas,alphaz);                        */
  /* dz is not yet scaled by d here: z - alpha*d.*dz >= 0 with d = z./s
     reduces to alpha <= s./dz, hence s (not z) in the second ratio. */
  r2 = FLTC(1.0);
  for (k=ni;k--;) {
   if (ds[k] > FLTC(0.0)) { r0 = s[k]/ds[k]; if (r0 < r2) r2 = r0; }
   if (dz[k] > FLTC(0.0)) { r0 = s[k]/dz[k]; if (r0 < r2) r2 = r0; }
  }
  /* dz = d.*dz; */
  VMUL(ni,dz,d,dz);
  /* r0 = (s-alpha*ds)'*(z-alpha*dz) */
  /* r1 = s'*z */
  r0 = r1 = FLTC(0.0);
  for (k=ni;k--;) {
   r0 += (s[k] - r2*ds[k]) * (z[k] - r2*dz[k]);
   r1 +=  s[k] * z[k];
  }
  /* tau = r0^3/(r1^2*ni) */
  r1 = r0/r1;                /* Since r0 and r1 can be both very small,     */
  r0 = (r1*r1)*(r0/(FLT)ni); /* this is more precise and less prone to NaNs */
  /* corrector: ec = (tau - ds.*dz)./z; */
  VMUL(ni,ec,ds,dz);
  for (k=ni;k--;) ec[k] = r0 - ec[k];
  VDIV(ni,ec,ec,z);
  /* xp = xp + d.*ec;  (dz used as scratch) */
  VMUL(ni,dz,d,ec);
  VADD(ni,xp,xp,dz);
  /* dx = A'*xp - rd; */
  COPY(n,dx,rd);
  GTMV(ni,n,1.0,a,xp,-1.0,dx);
  /* dx = U\(U'\dx); */
  PTRS(n,u,dx);
  /* dz = A*dx - rp - ec; */
  VADD(ni,dz,rp,ec);
  GNMV(ni,n,1.0,a,dx,-1.0,dz);
  /* ds = s - dz - ec; */
  VADD(ni,ds,dz,ec);
  VSUB(ni,ds,s,ds);
  /* k = (ds > 0); alphas = min([1 ; s(k)./ds(k)]); */
  /* k = (dz > 0); alphaz = min([1 ; s(k)./dz(k)]); */
  /* alpha = 0.99*min(alphas,alphaz);                        */
  r2 = FLTC(1.0);
  for (k=ni;k--;) {
   if (ds[k] > FLTC(0.0)) { r0 = s[k]/ds[k]; if (r0 < r2) r2 = r0; }
   if (dz[k] > FLTC(0.0)) { r0 = s[k]/dz[k]; if (r0 < r2) r2 = r0; }
  }
  r2 *= FLTC(0.99); /* stay strictly inside s > 0, z > 0 */
  /* dz = d.*dz; */
  VMUL(ni,dz,d,dz);
  /* x = x + alpha*dx; */
  AXPY(n,x,r2,dx);
  /* s = s - alpha*ds; */
  AXPY(ni,s,-r2,ds);
  /* z = z - alpha*dz; */
  AXPY(ni,z,-r2,dz);
  /* stat = norm([rd;s+rp],Inf) < tol;
     rd and rp are those computed at the top of this iteration, s is the updated one.
     A NaN entry must make the norm NaN (as MATLAB norm does) so that the test
     fails: a plain "r1 > r0" maximum would silently skip it. Once r0 is NaN it
     stays NaN because both comparisons below are then false for finite r1. */
  r0 = FLTC(0.0);
  for (k=n  ; k-- ; ) { r1 = ABS(     rd[k]); if (r1 > r0 || isnan(r1)) r0 = r1; }
  for (k=ni ; k-- ; ) { r1 = ABS(s[k]+rp[k]); if (r1 > r0 || isnan(r1)) r0 = r1; }
  *stat = (r0 < tol);
  if (*stat) break;
 }
}
