classdef bf_tests < meq_test
  % Tests for basis functions
  %
  % Every basis function listed in the `bf` method-setup parameter is
  % exercised on a test equilibrium built by getCircularEquilibrium, once
  % for each (FAs,FBs) pair of axis/boundary flux values listed in the
  % class-setup parameters (pairs are combined sequentially, not as a grid).
  %
  % [+MEQ MatlabEQuilibrium Toolbox+] Swiss Plasma Center EPFL Lausanne 2022. All rights reserved.
  
  properties
    tol = 1e-2;    % tolerance used by the verifications (AbsTol or RelTol)
    verbosity = 0; % > 0 produces diagnostic plots
    n = 41; % number of FN points

    % parameters for dummy test equilibrium
    r0 = 1;
    rBt = 1.4;
    
    FN; L; LY;   % normalised flux grid and the L/LY structures of the test equilibrium
    GN; IGN;     % basis functions and primitives on FN, as returned by bf(2,...)
    bfh; bfp;    % handle of the basis function under test and its parameters
    alpg; alg;   % conversion factors from normalised to non-normalised quantities
  end
  
  properties (ClassSetupParameter)
    % Various combinations of boundary and axis flux
    FAs    = {-1   0   2  0.5 };
    FBs    = {0    1   0  0   };
  end
  
  properties (MethodSetupParameter)
    bf     = {'bffbt' 'bfabmex' 'bf3pmex' 'bf3imex' 'bfef' 'bfefmex' 'bfsp' 'bfgenD'};
  end
  
  methods(TestClassSetup,ParameterCombination='sequential')
    
    function gen_Fx(testCase,FAs,FBs)
      % Build the test equilibrium for the given axis (FAs) and boundary
      % (FBs) flux and the normalised flux grid FN (column, testCase.n points).
      r0 = testCase.r0; z0=0; rBt = testCase.rBt; %#ok<*PROPLC>
      L = liu('ana',1);
      
      [testCase.L,testCase.LY] = testCase.getCircularEquilibrium(L,r0,z0,FAs,FBs,rBt);
      testCase.FN = linspace(0,1,testCase.n)';
    end
    
  end
  
  methods(TestMethodSetup,ParameterCombination='sequential')

    function bfh_setup(testCase,bf)
      % Select the basis function under test, its parameters, evaluate it
      % on FN (mode 2) and compute the normalisation conversion factors.
      %
      % Cannot set these values as properties if calling meq functions (not
      % necessarily available in the path when initialising the class)
      switch bf
        case 'bffbt',   bfpar = [1.1 .25 1.5 2.2 .75 2.5];
        case 'bfabmex', bfpar = [3 1];
        case 'bf3pmex', bfpar = false;
        case 'bf3imex', bfpar = bfp3imex();
        case 'bfef',    bfpar = [4 4];
        case 'bfefmex', bfpar = [2 3];
        case 'bfsp',    bfpar = bfpsp_();
        case 'bfgenD',  bfpar = {@bfabmex,[3 1],1};
        otherwise,      error('Basis function %s not handled yet',bf);
      end
      testCase.bfp = bfpar;
      testCase.bfh = str2func(bf); % function handle
      LY = testCase.LY;

      % Call basic basis function evaluator
      [testCase.GN ,testCase.IGN ] = testCase.bfh(2,bfpar,testCase.FN,LY.FA,LY.FB);
      
      % Conversion factors (See chapter 2 of the MEQ Redbook for definitions of alpg (alpha'_g) and alg (alpha_g))
      % alpg divides the basis function values, alg = alpg*(FB-FA) divides
      % their primitives (see test_axis_value).
      FBA = LY.FB - LY.FA;
      switch bf
        case 'bfabmex'
          testCase.alpg = [FBA.^(1:bfpar(1)),FBA.^(1:bfpar(2))];
        case 'bf3pmex'
          testCase.alpg = [FBA,FBA,FBA.^2];
        case {'bfef','bfefmex'}
          testCase.alpg = [FBA.^(0:bfpar(1)-1),FBA.^(0:bfpar(2)-1)];
        case {'bffbt','bf3imex','bfsp'}
          % already normalised: unit conversion factors
          testCase.alpg = ones(1,size(testCase.GN,2));
        case 'bfgenD'
          testCase.alpg = [FBA.^(1:bfpar{2}(1)),FBA.^(1:bfpar{2}(2))];
        otherwise
          error('Basis function %s not handled yet',bf);
      end
      % Same relation for every basis function
      testCase.alg = testCase.alpg*FBA;
    end
  end 

  methods(Test, TestTags = {'Unit','bf'})
   
    function test_factors(testCase)
      % simple test of factors: mode 0 returns p'/TT' flags of equal size
      [FP,FT] = testCase.bfh(0,testCase.bfp);
      testCase.assertTrue(all(FP==1 | FP==0),'FP must be 1 or 0')
      testCase.assertTrue(all(FT==1 | FT==0),'FT must be 1 or 0')
      testCase.assertEqual(size(FP),size(FT),'FP,FT unequal size')
    end
    
    function test_on_rz_grid(testCase)
      L  = testCase.L ;
      LY = testCase.LY;
      % Test evaluating basis function and its integrals on rz grid
      [TYG,TPDG,ITPDG    ] = testCase.bfh(1,testCase.bfp, ...
        LY.Fx,LY.FA,LY.FB,LY.Opy,L.ry,L.iry);
      
      ny = numel(L.ry)*numel(L.zy); % number of grid points
      nbf = size(testCase.GN,2); % number of basis functions
      
      % test sizes
      testCase.verifyEqual(size(TYG)  ,[ny,nbf], 'invalid size for TYG')
      testCase.verifyEqual(size(TPDG) ,[1,nbf] , 'invalid size for TPDG')
      testCase.verifyEqual(size(ITPDG),[1,nbf] , 'invalid size for ITPDG')
      
    end
    
    function test_rz_integral_value(testCase)
      L  = testCase.L ;
      LY = testCase.LY;
      % test values of integrals of basis functions and primitives
      % for cases where we can reasonably compute this analytically.
      % R is set to r0 everywhere so the integrals reduce to volumes.
      rr = testCase.r0*ones(size(L.ry));
      [TYG,TPDG,ITPDG    ] = testCase.bfh(1,testCase.bfp, ...
        LY.Fx,LY.FA,LY.FB,LY.Opy,rr,rr);
      
      qa0 = (LY.rB-LY.rA).^2 + (LY.zB-LY.zA).^2; % Square of the plasma minor radius
      
      switch func2str(testCase.bfh)
        case 'bf3imex'
          dS = mean(diff(L.ry)).*mean(diff(L.zy));
          if testCase.bfp.fPg(1) && ~testCase.bfp.fTg(1)
            % first basis function: linear -> quadratic in R: paraboloid
            testCase.verifyEqual(TPDG(1),pi*qa0/2/dS,'RelTol',1e-2); % paraboloid volume
          end
          if testCase.bfp.fPg(2) && ~testCase.bfp.fTg(2)
            % second basis function: sqrt -> linear in R: cone
            testCase.verifyEqual(TPDG(2),pi*qa0/3/dS,'RelTol',1e-2); % cone volume
          end
          if testCase.bfp.fPg(3) && ~testCase.bfp.fTg(3)
            % third basis function: constant -> cylinder
            testCase.verifyEqual(TPDG(3),pi*qa0/dS  ,'RelTol',1e-2); % cylinder volume
          end
        otherwise
          return % no analytical value for this case
      end
    end
    
    function test_profile_fitting(testCase)
      %% Test fitting ag for given profiles
      % Test that when fitting p, T profiles using ag, 
      % we get back the same ag as originally used
      L  = testCase.L ;
      LY = testCase.LY;
      
      % Call bf(3,.. to get scaling coefficients for p', TT' profiles
      % pQ must match the FN grid on which the profiles are evaluated below
      L.pQ = sqrt(testCase.FN); %psiN grid
      L.idsx = 0.01.^2; % fictitious 1/|dS|
      L.bfct = testCase.bfh;
      L.bfp  = testCase.bfp;
      L.nD = 1;
      
      [L.fPg,L.fTg,L.TDg] = L.bfct(0,L.bfp);
      ag = 2*(L.fPg | L.fTg); % coefficients representing basis functions
      L.ng = numel(ag);
      
      % fitting
      rBt = 1; % dummy value used for T
      smalldia = 0; % no small diamagnetism approximation
      
      [PpQ,TTpQ] = meqprof(L.fPg,L.fTg,ag,testCase.FN,...
        LY.FA,LY.FB,rBt,L.bfct,L.bfp,L.idsx,smalldia);
      
      if testCase.verbosity
        subplot(211)
        plot(L.pQ.^2,PpQ)
        subplot(212)
        plot(L.pQ.^2,TTpQ)
      end
      
      % fit using meqfitprof
      agfit = meqfitprof(L,L.pQ.^2,LY.FA,LY.FB,PpQ,TTpQ);
      
      % check for same ag
      testCase.verifyEqual(ag,agfit,'AbsTol',testCase.tol)
      
    end
    %function test_chord_integral(testCase)
      % Test evaluating integral along Z of basis function
      % %% not working yet, do later
      % L  = testCase.L ;
      % LY = testCase.LY;
      %
      % rr = mean(L.ry); % r value for z integral evaluation
      % ir = find(L.ry<rr,1,'last'); KD = 1+[ir,ir+1]; % indices into rx
      % FD = (rr-L.ry(ir:ir+1))/diff(L.ry(ir:ir+1)); % linear interpolation coefficients
      %
      %[TDG,TGY         ] = testCase.bfh(4,testCase.bfp,...
      %  LY.Fx,LY.FA,LY.FB,LY.Opy,KD', FD'         );
      
      %testCase.verifyEqual(size(TDG), [1,nbf], 'invalid size for TYG')
    %end
    
    function test_axis_value(testCase)
      % Test axis value of basis function and its primitive
      LY = testCase.LY;
      
      [G1,IG1] = testCase.bfh(5,testCase.bfp,[],LY.FA,LY.FB);
      
      % Get conversion factors for normalised quantities
      GN1 = G1./testCase.alpg;
      IGN1 = IG1./testCase.alg;
      
      % Dummy test with bfh(2,..) which should give same result
      [GN ,IGN ] = testCase.bfh(2,testCase.bfp,0,LY.FA,LY.FB);
      %
      testCase.verifyEqual( GN, GN1, 'Abstol',   testCase.tol, ...
        sprintf('%s: value at the axis mismatch for [FA,FB]=%3.3f, %3.3f',...
        func2str(testCase.bfh),LY.FA,LY.FB));
      testCase.verifyEqual(IGN,IGN1, 'Abstol',   testCase.tol, ...
        sprintf('%s: integral value at the axis mismatch for [FA,FB]=%3.3f, %3.3f',...
        func2str(testCase.bfh),LY.FA,LY.FB));
    end
   
    function test_regularisation(testCase)
      L  = testCase.L ;
      LY = testCase.LY;
      % Test computation of elements for the regularisation equations
      rA = testCase.r0;
      args = {[],LY.FA,LY.FB,rA,1./rA,L.dsx}; % arguments common to all bfs
      [Qqg,Xq] = testCase.bfh(6,testCase.bfp,args{:});
      switch func2str(testCase.bfh)
        case 'bf3pmex'
          [Qqg0,Xq0] = bfabmex(6,[1,2] - testCase.bfp,args{:});
          testCase.verifyEqual(Qqg,Qqg0, 'Abstol', testCase.tol, ...
            'regularisation matrix Qqg mismatch between bf3pmex and bfabmex with bfp=[1,2]');
          testCase.verifyEqual(Xq,Xq0, 'Abstol', testCase.tol, ...
            'regularisation value Xq mismatch between bf3pmex and bfabmex with bfp=[1,2]');
        case {'bfef','bfefmex'}
          [Qqg0,Xq0] = bfabmex(6,max(testCase.bfp - 1,0),args{:});
          nP = testCase.bfp(1); nT = testCase.bfp(2); ng = nP + nT;
          % drop the first p' and first TT' basis functions before comparing
          mask = true(ng,1);
          if nP > 0, mask(1)    = false;end
          if nT > 0, mask(1+nP) = false;end
          testCase.verifyEqual(Qqg(1:end-2,mask),Qqg0, 'Abstol', testCase.tol, ...
            'regularisation matrix Qqg mismatch between bfef and bfabmex');
          testCase.verifyEqual(Xq(1:end-2),Xq0, 'Abstol', testCase.tol, ...
            'regularisation value Xq mismatch between bfef and bfabmex');
        case 'bf3imex'
          testCase.verifyEqual(Qqg,[0,0,0],sprintf('regularisation matrix Qqg should be [0,0,0] for %s',func2str(testCase.bfh)));
          testCase.verifyEqual( Xq,     0 ,sprintf('regularisation value Xq should be 0 for %s',func2str(testCase.bfh)));
        case 'bffbt'
          testCase.verifyEmpty(Qqg,sprintf('regularisation matrix Qqg should be empty for %s',func2str(testCase.bfh)));
          testCase.verifyEmpty( Xq,sprintf('regularisation value Xq should be empty for %s',func2str(testCase.bfh)));
        otherwise
          testCase.assumeFail(sprintf('No tests defined yet for mode=6 for %s',func2str(testCase.bfh)))
      end
    end
    
    function test_toroidal_field(testCase)
      % Test computation of toroidal field on y grid as computed by bf(8,...)
      % [BTY             ] = BF(8,PAR, F,FA,FB,O ,  A,RBT,IDS,IRY)
      L  = testCase.L ;
      LY = testCase.LY;
      % arguments common to all bfs
      idS = L.idsx;
      Fx  = LY.Fx;
      FA  = LY.FA;
      FB  = LY.FB;
      Opy = LY.Opy;
      rBt = 0.88*1.43;
      iry = L.iry;
      switch func2str(testCase.bfh)
        case 'bf3pmex'
          ag = 1e3*[2;4;-10];
        otherwise
          testCase.assumeFail(sprintf('No tests defined yet for mode=8 for %s',func2str(testCase.bfh)))
      end
      Bty = testCase.bfh(8,testCase.bfp,Fx,FA,FB,Opy,ag,rBt,idS,iry);
      
      % Compare with computation based on mode 2 and 3
      FN = (Fx - FA)/(FB - FA);
      [fPg,fTg] = testCase.bfh(0,testCase.bfp);
      % interior points of the x grid only (y grid excludes the border)
      [~,IGN] = testCase.bfh(2,testCase.bfp,FN(2:end-1,2:end-1),FA,FB);
      [~,~,~,AHQT] = testCase.bfh(3,testCase.bfp,ag,FA,FB,fPg,fTg,idS);
      hqTx = reshape(IGN*AHQT,size(Opy)).*logical(Opy);
      Bty_ = (rBt+hqTx/rBt).*iry.'; % Using small diamagnetism approx.
      
      testCase.verifyEqual(Bty,Bty_, 'Reltol', testCase.tol, ...
        sprintf('Toroidal field mismatch between mode=8 and mode=2/3 for %s',func2str(testCase.bfh)));
      
    end
      
    function test_bfpr_integral(testCase)
      % test integral as returned by bf(2,..) and bfprmex
      % bfpr takes maximum 3 basis functions; use fewer if the basis is smaller
      n = min(3,size(testCase.GN,2));
      GN  = testCase.GN(:,1:n);   IGN = testCase.IGN(:,1:n); %#ok<*PROP>
      
      FN = linspace(0,1,testCase.n);
      
      IGN2 = bfprmex(GN); % take the integral using bfprmex
      
      % Check vs numerical trapeze integral, primitive vanishing at FN=1
      IGN3 = cumtrapz(FN,GN); IGN3 = IGN3 - IGN3(end,:);

      testCase.verifyEqual(IGN3,IGN,   'AbsTol',  testCase.tol, ...
        sprintf('%s(2,...) does not match cumtrapz()',func2str(testCase.bfh)));
      testCase.verifyEqual(IGN3,IGN2,  'AbsTol',   testCase.tol,  ...
        'bfprmex() does not match cumtrapz()'   );
      
      if testCase.verbosity
        figure
        plot(testCase.FN,GN,'-k',testCase.FN,IGN,'-xk',testCase.FN,IGN2,'-+r',testCase.FN,IGN3,'-ob')
        title(char(testCase.bfh))
        shg
      end
    end
  end
end

%% Auxiliary functions
function bfp = bfp3imex()
% BFP3IMEX Parameter structure for the bf3imex test case.
%   Directly specifies three basis functions sampled on a uniform
%   41-point grid of normalised flux FN in [0,1]:
%     column 1: 1 - FN (linear),   column 2: 1 - sqrt(FN),   column 3: 1.
%   IgNg holds their trapezoidal primitives, taken to vanish at FN = 1.
%   fPg/fTg flag the first function as p' and the other two as TT'.
nFN   = 41;
FNcol = linspace(0,1,nFN).';

gLin  = linspace(1,0,nFN).';
gSqrt = 1 - sqrt(linspace(0,1,nFN).');
gCst  = ones(nFN,1);
gNg   = [gLin, gSqrt, gCst];

% Integrate from the boundary (FN=1) inwards, then restore ascending order
rev  = nFN:-1:1;
IgNg = cumtrapz(FNcol(rev), gNg(rev,:));
IgNg = IgNg(rev,:);

bfp = struct('gNg',gNg,'IgNg',IgNg,'fPg',[1;0;0],'fTg',[0;1;1]);
end

function bfp = bfpsp_()
% BFPSP_ Parameter structure for the bfsp (B-spline) test case.
%   Passes 11 equally spaced knots on [0,1] and the edge condition 'n'
%   (not-a-knot for p'/TT' at the edge) to bfpsp.
knots = linspace(0,1,11);
bfp = bfpsp(knots,'n');
end
