/* [+MEQ MatlabEQuilibrium Toolbox+]
 *
 *    Copyright 2022-2025 Swiss Plasma Center EPFL
 *
 *   Licensed under the Apache License, Version 2.0 (the "License");
 *   you may not use this file except in compliance with the License.
 *   You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *   Unless required by applicable law or agreed to in writing, software
 *   distributed under the License is distributed on an "AS IS" BASIS,
 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *   See the License for the specific language governing permissions and
 *   limitations under the License. */

# include "meq.h"

/* MEX gateway for the ipmj solver (dipmj / sipmj).
 *
 * MATLAB usage:
 *   [Ie,aj,st] = ipmjmex(uAjj,Aje,Aej,Ie0,aj0,Ie,aj,Qcj,Xc,s,z,tol,iters)
 *
 * Dimensions are taken from the inputs: nj = rows(Aje), ne = cols(Aje),
 * nc = rows(Qcj). All array inputs must share the class of uAjj (double or
 * single); tol and iters may be any numeric scalar.
 *
 * A zero-initialised scratch workspace of nj*(nj+1)/2 + 4*nj + 6*nc elements
 * is allocated and handed to the class-specific solver, then freed.
 *
 * Outputs: Ie (ne x 1) and aj (nj x 1) of the input class, and st, a logical
 * scalar built from the meq_bool status flag written by the solver
 * (presumably a convergence flag -- confirm against dipmj/sipmj).
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{

 /* [Ie,aj,st] = ipmjmex(uAjj,Aje,Aej,Ie0,aj0,Ie,aj,Qcj,Xc,s,z,tol,iters) */

# define MEXNAME ipmjmex

# define IE    pout[ 0]
# define AJ    pout[ 1]
# define ST    pout[ 2]

# define UAJJ  prhs[ 0]  /* double or single, size=nj*(nj+1)/2 */
# define AJE   prhs[ 1]  /* type match UAJJ, def size=[nj,ne] */
# define AEJ   prhs[ 2]  /* type match UAJJ, size=ne*nj */
# define IE0   prhs[ 3]  /* type match UAJJ, size=ne */
# define AJ0   prhs[ 4]  /* type match UAJJ, size=nj */
# define IE1   prhs[ 5]  /* type match UAJJ, size=ne */
# define AJ1   prhs[ 6]  /* type match UAJJ, size=nj */
# define QCJ   prhs[ 7]  /* type match UAJJ, def size=[nc,*], size=[nc,nj] */
# define XC    prhs[ 8]  /* type match UAJJ, size=nc */
# define S     prhs[ 9]  /* type match UAJJ, size=nc */
# define Z     prhs[10]  /* type match UAJJ, size=nc */
# define TOL   prhs[11]  /* numeric, scalar */
# define ITERS prhs[12]  /* numeric, scalar */

 CHECK_NARGIN_EQ(13);

 int nj = mxGetM(AJE),
     ne = mxGetN(AJE),
     nc = mxGetM(QCJ);

 CHECK_REAL(UAJJ);
 CHECK_NUMEL(UAJJ,nj*(nj+1)/2);
 CHECK_TYPE_MATCH(AJE,UAJJ);
 CHECK_TYPE_MATCH(AEJ,UAJJ);
 CHECK_NUMEL(AEJ,ne*nj);
 CHECK_TYPE_MATCH(IE0,UAJJ);
 CHECK_NUMEL(IE0,ne);
 CHECK_TYPE_MATCH(AJ0,UAJJ);
 CHECK_NUMEL(AJ0,nj);
 CHECK_TYPE_MATCH(IE1,UAJJ);
 CHECK_NUMEL(IE1,ne);
 CHECK_TYPE_MATCH(AJ1,UAJJ);
 CHECK_NUMEL(AJ1,nj);
 CHECK_TYPE_MATCH(QCJ,UAJJ);
 CHECK_NCOLS(QCJ,nj);
 CHECK_TYPE_MATCH(XC,UAJJ);
 CHECK_NUMEL(XC,nc);
 CHECK_TYPE_MATCH(S,UAJJ);
 CHECK_NUMEL(S,nc);
 CHECK_TYPE_MATCH(Z,UAJJ);
 CHECK_NUMEL(Z,nc);
 CHECK_NUMERIC(TOL);
 CHECK_SCALAR(TOL);
 CHECK_NUMERIC(ITERS);
 CHECK_SCALAR(ITERS);

 CHECK_NARGOUT_LE(3);

 mxArray *pout[3] = {NULL};

 /* TOL and ITERS are only checked to be numeric scalars, not class-matched
  * to UAJJ, so read them through mxGetScalar, which converts any numeric
  * class to double. Reading TOL via a (float *) cast of mxGetData would
  * misinterpret a double tolerance in the single-precision branch. */
 double tol   = mxGetScalar(TOL);
 int    iters = (int)mxGetScalar(ITERS);

 /* Scratch workspace length expected by dipmj/sipmj; computed in size_t so
  * the product nj*(nj+1) cannot overflow int. */
 size_t nw = (size_t)nj*(size_t)(nj+1)/2 + 4*(size_t)nj + 6*(size_t)nc;

 meq_bool stat = meq_false;

 switch (mxGetClassID(UAJJ)) {
  case mxDOUBLE_CLASS: {
   double *w = (double *)mxCalloc(nw,sizeof(double));
   IE = mxCreateDoubleMatrix(ne, 1, mxREAL);
   AJ = mxCreateDoubleMatrix(nj, 1, mxREAL);
   dipmj(mxGetPr(IE),mxGetPr(AJ),&stat,mxGetPr(UAJJ),mxGetPr(AJE),mxGetPr(AEJ),mxGetPr(IE0),mxGetPr(AJ0),mxGetPr(IE1),
         mxGetPr(AJ1),mxGetPr(QCJ),mxGetPr(XC),mxGetPr(S),mxGetPr(Z),tol,iters,w,nj,ne,nc);
   mxFree(w);
   break;
  }
  case mxSINGLE_CLASS: {
   float *w = (float *)mxCalloc(nw,sizeof(float));
   IE = mxCreateNumericMatrix(ne, 1, mxSINGLE_CLASS, mxREAL);
   AJ = mxCreateNumericMatrix(nj, 1, mxSINGLE_CLASS, mxREAL);
   sipmj((float*)mxGetData(IE),(float*)mxGetData(AJ),&stat,(float*)mxGetData(UAJJ),(float*)mxGetData(AJE),(float*)mxGetData(AEJ),(float*)mxGetData(IE0),(float*)mxGetData(AJ0),(float*)mxGetData(IE1),
         (float*)mxGetData(AJ1),(float*)mxGetData(QCJ),(float*)mxGetData(XC),(float*)mxGetData(S),(float*)mxGetData(Z),(float)tol,iters,w,nj,ne,nc);
   mxFree(w);
   break;
  }
  default:
   /* Does not return: mexErrMsgIdAndTxt aborts the MEX call. */
   mexErrMsgIdAndTxt("ipmjmex:error","Class of uajj is neither double nor single but has passed the real test");
 }
 ST = mxCreateLogicalScalar(stat);

 ASSIGN_PLHS;
}
