/* [+GenLib General Purpose Library+] Swiss Plasma Center EPFL Lausanne 2022. All rights reserved. */
# include "genlib.h"
# include <string.h>

void mexFunction(int nlhs,       mxArray *plhs[],
                 int nrhs, const mxArray *prhs[])
{

# define MEXNAME bspsum

 /* v = bspsummex(t,c,x,d,p,npar);
  *
  * Evaluates, through the bspsum kernel (or a specialised bspsum4x kernel
  * when compiled with BSPSUM4), the B-spline(s) defined by knots T and
  * coefficients C at the sites X.
  *
  *   T    double, nT knots.
  *   C    double, mC-by-... ; the first dimension holds the coefficients, each
  *        remaining "column" (nC = numel(C)/mC of them) is an independent set.
  *        The spline order is k = nT - mC and must be >= 1.
  *   X    double, evaluation sites (vector or N-D array).
  *   D    optional scalar, default 0. Passed to the kernel; when D >= k the
  *        output is returned all zeros (presumably D is a derivative order -
  *        confirm against bspsum).
  *   P    optional scalar, default 0.
  *          P == 0: every column of C is evaluated at every site of X;
  *                  size(V) = [size(X) (numel if X is a vector), size(C)(2:end)].
  *          P != 0: "pointwise" mode; size(C)(2:1+ndims(X)) must equal size(X)
  *                  and column i of C is evaluated at X(mod(i,numel(X)));
  *                  size(V) = size(C)(2:end).
  *   NPAR optional scalar, default 1, must be > 0. Number of OpenMP threads
  *        used for the column loop when compiled with BSPPAR.
  */

# define MXV  pout[0]

# define MXT  prhs[0] /* double */
# define MXC  prhs[1] /* double */
# define MXX  prhs[2] /* double */
# define MXD  prhs[3] /* numeric, scalar */
# define MXP  prhs[4] /* numeric, scalar */
# define NPAR prhs[5] /* numeric, scalar */

 CHECK_NARGIN_GE(3);
 CHECK_NARGIN_LE(6);

 CHECK_DOUBLE(MXT);
 CHECK_DOUBLE(MXC);
 CHECK_DOUBLE(MXX);
 if (nrhs>3) { CHECK_NUMERIC(MXD);
               CHECK_SCALAR(MXD); }
 if (nrhs>4) { CHECK_NUMERIC(MXP);
               CHECK_SCALAR(MXP); }
 if (nrhs>5) { CHECK_NUMERIC(NPAR);
               CHECK_SCALAR(NPAR);}

 CHECK_NARGOUT_LE(1);

 mxArray *pout[1] = {NULL};

 int nT, mC, nC, nX, k, p, d, i, Vinc, Xinc, npar;
 int ndC, ndX, ndV;
 const mwSize *dimC, *dimX;
 const mwSize *dimV;       /* output dimensions; may point into dimC       */
 mwSize *dimVbuf = NULL;   /* heap storage behind dimV when not aliasing C */
 mwSize *dX;               /* X dimensions, a vector collapsed to 1-D      */
 int check;
 void (*f)(double *, double *, int, double *, int, double *, int, int);
 double *T, *C, *X, *V;

 /* default values */
 d    = (nrhs>3) ? mxGetScalar(MXD ) : 0;
 p    = (nrhs>4) ? mxGetScalar(MXP ) : 0;
 npar = (nrhs>5) ? mxGetScalar(NPAR) : 1;
 if (npar<=0)
   mexErrMsgIdAndTxt("bspsummex:npar","NPAR argument to bspsummex must be a strictly positive integer");

 /* size management */
 nT   = mxGetNumberOfElements(MXT);
 mC   = mxGetM(MXC);
 nC   = mxGetN(MXC);   /* product of all dimensions of C except the first */
 ndC  = mxGetNumberOfDimensions(MXC);
 dimC = mxGetDimensions(MXC);
 nX   = mxGetNumberOfElements(MXX);
 ndX  = mxGetNumberOfDimensions(MXX);
 dimX = mxGetDimensions(MXX);

 /* Case where X is a vector: treat it as 1-D so row and column vectors
    produce the same output layout */
 if (ndX == 2 && (nX == dimX[0] || nX == dimX[1]))
 {
   ndX = 1;
   dX = (mwSize *) mxMalloc(sizeof *dX);
   dX[0] = nX;
 }
 else
 {
   dX = (mwSize *) mxMalloc(sizeof *dX * ndX);
   memcpy(dX,dimX,sizeof *dX * ndX);
 }

 /* Determine order of B-spline */
 k = nT - mC;
 if (k < 1) mexErrMsgIdAndTxt("bspsummex:loworder","Too many coefficients for specified knot sequence");

 /* Check for size compatibility, determine output size */
 if (p)
 {
   /* pointwise: the leading trailing-dimensions of C must match X */
   check = ndC > ndX;
   if (check)
     for (i = 0; i < ndX; i++)
       check = check && (dimC[i+1] == dX[i]);
   if (!check) mexErrMsgIdAndTxt("bspsummex:inputSize","Incompatible C and X size");
   ndV  = ndC - 1;
   dimV = dimC + 1;   /* read-only view of C's trailing dims, no copy needed */
 }
 else
 {
   /* output is [size(X), size(C)(2:end)] */
   ndV = ndX + ndC - 1;
   dimVbuf = (mwSize *) mxMalloc(sizeof *dimVbuf * ndV);
   memcpy( dimVbuf     , dX     ,sizeof *dimVbuf *  ndX   );
   memcpy(&dimVbuf[ndX],&dimC[1],sizeof *dimVbuf * (ndC-1));
   dimV = dimVbuf;
 }
 mxFree(dX);   /* X dimensions are not needed past this point */

 /* output memory allocation (zero-initialised by MATLAB) */
 MXV = mxCreateNumericArray(ndV, dimV, mxDOUBLE_CLASS, mxREAL);
 mxFree(dimVbuf);   /* the created array keeps its own copy; mxFree(NULL) is a no-op */

 /* do the work: nothing to compute for an empty X or when d >= k */
 if (d >= k || nX == 0) {ASSIGN_PLHS;return;}

 /* pointwise: one site per column, outputs contiguous;
    otherwise: all nX sites per column, outputs strided by nX */
 Vinc = p ? 1 : nX;
 Xinc = p ? 1 : 0;
 f = &bspsum;

# ifdef BSPSUM4
 /* specialised kernels for cubic order (k == 4) and d = 0, 1, 2 */
 if (k == 4 && d <= 2) f = d ? (d == 2 ? &bspsum42 : &bspsum41) : &bspsum40;
# endif
 /* Get pointer addresses of MATLAB arrays */
 T = mxGetPr(MXT);
 C = mxGetPr(MXC);
 X = mxGetPr(MXX);
 V = mxGetPr(MXV);
# ifdef BSPPAR
# pragma omp parallel for num_threads(npar)
# endif
 for (i = 0; i < nC; i++)
  /* offsets computed in mwSize so large arrays do not overflow int */
  f(V + (mwSize) i*Vinc, T, nT, C + (mwSize) i*mC, mC,
    X + ((mwSize) i*Xinc)%nX, Vinc, d);

 ASSIGN_PLHS;
}
