/* [+MEQ MatlabEQuilibrium Toolbox+] Swiss Plasma Center EPFL Lausanne 2022. All rights reserved. */
# include "meq.h"
/*
 * Precision-dependent names.
 *
 * When SINGLE is defined the s-prefixed entry points and BLAS/LAPACK
 * routines are selected, otherwise the d-prefixed ones.
 *
 * INV(x) is a guarded reciprocal: 1/x when x is non-zero, 0 otherwise.
 * The argument is parenthesized so that compound expressions such as
 * INV(a + b) expand as intended. It is still evaluated twice, so it must
 * not have side effects.
 */
# ifdef SINGLE
# define FCT   siata
# define FCT1  siata1
# define FCT2  siata2
# define FCT3  siata3
# define FCT4  siata4
# define FCTN  siataN
# define INV(x) ( (x) ? 1.0f / (x) : 0.0f )
# define DOT  cblas_sdot
# define SYRK cblas_ssyrk
# define SYSV LAPACKE_ssysv
# else
# define FCT   diata
# define FCT1  diata1
# define FCT2  diata2
# define FCT3  diata3
# define FCT4  diata4
# define FCTN  diataN
# define INV(x) ( (x) ? 1.0 / (x) : 0.0 )
# define DOT  cblas_ddot
# define SYRK cblas_dsyrk
# define SYSV LAPACKE_dsysv
# endif

/*
 * Entry point: picks the routine matching the size argument n.
 *
 *   n == 0      nothing is done, b is left untouched
 *   n == 1..4   the dedicated fixed-size routine FCT1..FCT4 is called
 *   otherwise   the general routine FCTN is called
 *
 * b, a1, m1, a2 and m2 are forwarded unchanged. The status returned by
 * FCTN is not propagated because this function returns void.
 */
void FCT(FLT *b, FLT *a1, int m1, FLT *a2, int m2, int n) {
 if (n == 0) {
  return;
 }

 if (n == 1) {
  FCT1(b, a1, m1, a2, m2);
 } else if (n == 2) {
  FCT2(b, a1, m1, a2, m2);
 } else if (n == 3) {
  FCT3(b, a1, m1, a2, m2);
 } else if (n == 4) {
  FCT4(b, a1, m1, a2, m2);
 } else {
  FCTN(b, a1, m1, a2, m2, n);
 }
}

/* function b = iata1(a1,a2)
 *
 * One-column case. a1 is read as n1 contiguous values and a2 as n2
 * contiguous values; b[0] receives INV(a1'*a1 + a2'*a2), i.e. the
 * reciprocal of the summed squares, or 0 when that sum is zero.
 */
void FCT1(FLT *b, FLT *a1, int n1, FLT *a2, int n2) {
 /* a1'*a1 + a2'*a2 */
 const FLT gram = DOT(n1, a1, 1, a1, 1)
                + DOT(n2, a2, 1, a2, 1);

 *b = INV(gram);
}

/* function b = iata2(a1,a2)
 *
 * Two-column case. a1 is read as two consecutive columns of n1 values and
 * a2 as two consecutive columns of n2 values. The symmetric 2x2 matrix
 * G = a1'*a1 + a2'*a2 is factored as L*D*L' and its inverse is written to
 * b[0..3] (b[1] and b[2] hold the same off-diagonal value).
 * Reciprocals of the pivots go through INV, so a zero pivot yields 0
 * instead of a division by zero.
 */
void FCT2(FLT *b, FLT *a1, int n1, FLT *a2, int n2) {
 /* start of the second column of each input */
 FLT *c1 = a1 + n1;
 FLT *c2 = a2 + n2;

 /* lower triangle of G = a1'*a1 + a2'*a2 */
 const FLT g00 = DOT(n1, a1, 1, a1, 1) + DOT(n2, a2, 1, a2, 1);
 const FLT g10 = DOT(n1, c1, 1, a1, 1) + DOT(n2, c2, 1, a2, 1);
 const FLT g11 = DOT(n1, c1, 1, c1, 1) + DOT(n2, c2, 1, c2, 1);

 /* G = L*D*L' with unit lower-triangular L */
 const FLT rd0  = INV(g00);
 const FLT m10  = g10 * rd0;
 const FLT piv1 = g11 - m10*g10;
 const FLT rd1  = INV(piv1);

 /* inverse, filled from the last diagonal entry upwards */
 const FLT x10  = - m10 * rd1;

 b[3] = rd1;
 b[2] = x10;
 b[1] = x10;
 b[0] = rd0 - m10*x10;
}

/* function b = iata3(a1,a2)
 *
 * Three-column case. a1 is read as three consecutive columns of n1 values
 * and a2 as three consecutive columns of n2 values. The symmetric 3x3
 * matrix G = a1'*a1 + a2'*a2 is factored as L*D*L' and its inverse is
 * written to b[0..8]; both halves of the symmetric result are stored.
 * Reciprocals of the pivots go through INV, so a zero pivot yields 0
 * instead of a division by zero.
 */
void FCT3(FLT *b, FLT *a1, int n1, FLT *a2, int n2) {
 FLT g[3][3];        /* g[i][j], i >= j: lower triangle of G (rest unused) */
 FLT piv1, piv2;     /* second and third pivot of D (the first is g[0][0]) */
 FLT rd0, rd1, rd2;  /* INV() of the three pivots                          */
 FLT m10, m20, m21;  /* strictly lower entries of the unit factor L        */
 FLT x10, x20, x21;  /* strictly lower entries of the inverse              */
 int i, j;

 /* a1'*a1 + a2'*a2 */
 for (j = 0; j < 3; j++) {
  for (i = j; i < 3; i++) {
   g[i][j] = DOT(n1, a1 + i*n1, 1, a1 + j*n1, 1)
           + DOT(n2, a2 + i*n2, 1, a2 + j*n2, 1);
  }
 }

 /* LDL' decomposition */
 rd0  = INV(g[0][0]);
 m10  = g[1][0] * rd0;
 piv1 = g[1][1] - m10*g[1][0];
 rd1  = INV(piv1);
 m20  = g[2][0] * rd0;
 m21  = ( g[2][1] - m20*g[1][0] ) * rd1;
 piv2 = g[2][2] - m20*g[2][0] - m21*m21*piv1;
 rd2  = INV(piv2);

 /* inversion by back substitution: last column first, each entry only
    depends on entries computed before it */
 x21 = - m21 * rd2;
 x20 = ( m21*m10 - m20 ) * rd2;
 x10 = - m10 * rd1 - m21*x20;

 b[8] = rd2;
 b[7] = x21;
 b[5] = x21;
 b[6] = x20;
 b[2] = x20;
 b[4] = rd1 - m21*x21;
 b[3] = x10;
 b[1] = x10;
 b[0] = rd0 - m10*x10 - m20*x20;
}

/* function b = iata4(a1,a2)
 *
 * Four-column case. a1 is read as four consecutive columns of n1 values
 * and a2 as four consecutive columns of n2 values. The symmetric 4x4
 * matrix G = a1'*a1 + a2'*a2 is factored as L*D*L' and its inverse is
 * written to b[0..15]; both halves of the symmetric result are stored.
 * Reciprocals of the pivots go through INV, so a zero pivot yields 0
 * instead of a division by zero.
 */
void FCT4(FLT *b, FLT *a1, int n1, FLT *a2, int n2) {
 FLT g[4][4];             /* g[i][j], i >= j: lower triangle of G (rest unused) */
 FLT piv1, piv2, piv3;    /* pivots of D after the first (which is g[0][0])     */
 FLT rd0, rd1, rd2, rd3;  /* INV() of the four pivots                           */
 FLT m10, m20, m30,       /* strictly lower entries of the unit factor L        */
     m21, m31, m32;
 FLT x10, x20, x30,       /* strictly lower entries of the inverse              */
     x21, x31, x32;
 int i, j;

 /* a1'*a1 + a2'*a2 */
 for (j = 0; j < 4; j++) {
  for (i = j; i < 4; i++) {
   g[i][j] = DOT(n1, a1 + i*n1, 1, a1 + j*n1, 1)
           + DOT(n2, a2 + i*n2, 1, a2 + j*n2, 1);
  }
 }

 /* LDL' decomposition */
 rd0  = INV(g[0][0]);
 m10  = g[1][0] * rd0;
 piv1 = g[1][1] - m10*g[1][0];
 rd1  = INV(piv1);
 m20  = g[2][0] * rd0;
 m21  = ( g[2][1] - m20*g[1][0] ) * rd1;
 piv2 = g[2][2] - m20*g[2][0] - m21*m21*piv1;
 rd2  = INV(piv2);
 m30  = g[3][0] * rd0;
 m31  = ( g[3][1] - m30*g[1][0] ) * rd1;
 m32  = ( g[3][2] - m30*g[2][0] - m31*m21*piv1 ) * rd2;
 piv3 = g[3][3] - m30*g[3][0] - m31*m31*piv1 - m32*m32*piv2;
 rd3  = INV(piv3);

 /* inversion by back substitution: last column first, each entry only
    depends on entries computed before it */
 x32 = - m32 * rd3;
 x31 = ( m21*m32 - m31 ) * rd3;
 x30 = ( m20*m32 - m30 ) * rd3 - m10*x31;
 x21 = - m21 * rd2 - m32*x31;
 x20 = ( m10*m21 - m20 ) * rd2 - m32*x30;
 x10 = - m10 * rd1 - m21*x20 - m31*x30;

 /* last row / column */
 b[15] = rd3;
 b[11] = x32;
 b[14] = x32;
 b[ 7] = x31;
 b[13] = x31;
 b[ 3] = x30;
 b[12] = x30;

 /* third row / column */
 b[10] = rd2 - m32*x32;
 b[ 6] = x21;
 b[ 9] = x21;
 b[ 2] = x20;
 b[ 8] = x20;

 /* second row / column */
 b[ 5] = rd1 - m21*x21 - m31*x31;
 b[ 1] = x10;
 b[ 4] = x10;

 /* first diagonal entry */
 b[ 0] = rd0 - m10*x10 - m20*x20 - m30*x30;
}

/* For general case use LAPACK for matrix inversion */

/* function b = iatan(a1,a2)
 *
 * General-size case. a1 is passed to BLAS as a column-major n1-by-n matrix
 * (leading dimension n1) and a2 as a column-major n2-by-n matrix (leading
 * dimension n2). The upper triangle of G = a1'*a1 + a2'*a2 is accumulated
 * with SYRK, then SYSV solves G*X = I so that b (n-by-n, column-major)
 * receives the inverse of G.
 *
 * An input with zero rows (n1 <= 0 or n2 <= 0) is skipped and contributes
 * nothing to G, matching the fixed-size routines where an empty dot
 * product contributes zero.
 *
 * Returns
 *   0      success (also for n == 0, where there is nothing to do)
 *   n      when n < 0 (invalid size; b is not touched)
 *   -1010  when the work arrays cannot be allocated (b is not touched);
 *          this is the value LAPACKE uses for workspace allocation failure
 *   other  the status returned by SYSV; when it is non-zero b must not be
 *          used as the inverse
 */
int FCTN(FLT *b, FLT *a1, int n1, FLT *a2, int n2, int n) {
 enum { ALLOC_FAILED = -1010 };
 int    *ipiv;
 FLT    *a;
 size_t  i, nn;
 int     s;

 /* Empty or invalid problem: never hand such a size to BLAS/LAPACK */
 if (n <= 0) {
  return n;
 }

 /* Element count in size_t so that n*n cannot overflow int */
 nn = (size_t)n * (size_t)n;

 /* Zero-filled so that G is well defined even when a1 has no rows */
 a    = calloc(nn, sizeof *a);
 ipiv = malloc((size_t)n * sizeof *ipiv);
 if (!a || !ipiv) {
  s = ALLOC_FAILED;
  goto cleanup;
 }

 /* Compute AT*A: the first call overwrites (beta = 0), the second one
    accumulates (beta = 1) */
 if (n1 > 0) SYRK(CblasColMajor, CblasUpper, CblasTrans, n, n1, 1, a1, n1, 0, a, n);
 if (n2 > 0) SYRK(CblasColMajor, CblasUpper, CblasTrans, n, n2, 1, a2, n2, 1, a, n);

 /* Construct eye matrix for rhs: diagonal entries sit at multiples of n+1 */
 for (i = 0; i < nn; i++)
   b[i] = (i % ((size_t)n + 1)) ? 0 : 1;

 /* Invert AT*A */
 s = SYSV(LAPACK_COL_MAJOR, 'U', n, n, a, n, ipiv, b, n);

cleanup:
 free(ipiv);
 free(a);

 return s;
}

