function passed = QLKNN_test
% QLKNN_TEST  Entry point for the QLKNN derivative tests.
%
% Runs one_QLKNN_test for the network(s) enabled below and reports
% success only if every individual result is true.
%
% Output:
%   passed - logical scalar, true when all enabled network tests passed.

% Disabled variants kept for reference (one result per network):
%   results(1) = one_QLKNN_test('QLKNN4Dkin');
%   results(2) = one_QLKNN_test('QLKNN10D');
results = one_QLKNN_test('QLKNN10D');

% Reduce the per-network results to a single pass/fail flag.
passed = all(results);

return

function passed = one_QLKNN_test(networkname)
% Test function for QLKNN - checks derivatives
% F. Felici 2018
%
% Compares the derivatives returned as the second output of QLKNN with
% central finite differences of its first output, obtained by perturbing
% each element of the state vector x in turn. The per-column mismatch is
% measured by calc_error and compared against rtol.
%
% Input:
%   networkname - network identifier forwarded to QLKNN(networkname); also
%                 used to key the finite-difference cache file.
% Output:
%   passed      - true when no state index exceeds rtol for any field of
%                 the QLKNN output structure.
%
% Side effects: prints the loop index while computing finite differences
% and reads/writes a .mat cache file in the current directory.

%% RAPTOR initialization
nn = 1e-8; % size of perturbation
rtol = 1e-4; % maximum accepted value of calc_error
% stapnew and dx are returned by the setup helper but not needed here.
[stapnew,stap,geop,model,dx,g,v,x] = RAPTOR_test_setup_AUG(nn); %#ok<ASGLU>

%% qlk initialization
[QLKNNnet,QLKNNparams] = QLKNN(networkname);
% Legend of the flag vector passed to QLKNN_set_pp:
% pp.use_effective_diffusivity    = flags(1);
% pp.use_ion_diffusivity_networks = flags(2);
% pp.apply_victor_rule            = flags(3);
% pp.calc_heat_transport          = flags(4);
% pp.calc_part_transport          = flags(5);
% pp.merge_modes                  = flags(6);
% pp.apply_sanity_clipping        = flags(7);
% pp.apply_stability_clipping     = flags(8);
% pp.useETG                       = flags(9);
% pp.useITG                       = flags(10);
% pp.useTEM                       = flags(11);
% pp.constrainInputs(:)           = flags(12);
% NOTE(review): the vector below has 13 entries while the legend lists 12;
% confirm the mapping against QLKNN_set_pp.

[name, QLKNNparams] = QLKNN_set_pp(QLKNNparams, [true NaN false true true true false true true true true false true]); %#ok<ASGLU> %not sanity clipped
%[name, QLKNNparams] = QLKNN_set_pp(QLKNNparams, [true NaN false true true true true true true true true false true]); %sanity clipped

%QLKNNparams.maxOutput(:) = 1e20;
%QLKNNparams.minOutput(:) = -1e20;

%% Nominal case
[coeff,dcoeff_dx] = QLKNN(stap,geop,model,QLKNNnet,QLKNNparams);

%% Numerical derivatives by central finite differences
myfields = fieldnames(coeff);

% The cache is keyed by network name so that derivatives computed for one
% network are never reused for another.
filename = sprintf('dRAPTOR_din_num_%s.mat', networkname);
if exist(filename, 'file') == 2
    % Load only the expected variable instead of poofing the whole file
    % content into this workspace.
    cached = load(filename, 'dcoeff_dx_num');
    dcoeff_dx_num = cached.dcoeff_dx_num;
else
    dcoeff_dx_num = struct();
    pert = nn; % same perturbation size for every state element
    % Loop backwards so the first assignment allocates the full width.
    for ii = numel(x):-1:1
        disp(ii)
        xleft = x;
        xright = x;
        xleft(ii) = x(ii) - pert;
        xright(ii) = x(ii) + pert;
        stap_left  = state_profiles(xleft,0*g,g,0*g,v,0*v,model);
        stap_right = state_profiles(xright,0*g,g,0*g,v,0*v,model);
        coeff_left = QLKNN(stap_left,geop,model,QLKNNnet,QLKNNparams);
        coeff_right = QLKNN(stap_right,geop,model,QLKNNnet,QLKNNparams);
        for ifield = 1:numel(myfields)
            myfield = myfields{ifield};
            dxstr = sprintf('d%s_dx',myfield);
            dcoeff_dx_num.(dxstr)(:, ii) = (coeff_right.(myfield) - coeff_left.(myfield)) / (2 .* pert);
        end
    end
    save(filename, 'dcoeff_dx_num')
end

%% Compare analytical and numerical derivatives column by column
failing = [];
for ifield = 1:numel(myfields)
    myfield = myfields{ifield};
    dxstr = sprintf('d%s_dx',myfield);
    for ii = numel(x):-1:1
        %err.(dxstr)(ii) = norm(dcoeff_dx.(dxstr)(:, ii) - dcoeff_dx_num.(dxstr)(:, ii))/ norm(dcoeff_dx.(dxstr)(:, ii));
        err.(dxstr)(ii) = calc_error(dcoeff_dx.(dxstr)(:, ii), ...
            dcoeff_dx_num.(dxstr)(:, ii), coeff.(myfield));
        if err.(dxstr)(ii) > rtol
            failing = [failing ii]; %#ok<AGROW>
        end
    end
end

% Report the offending state indices (if any) and derive the verdict.
failing = unique(failing);
if ~isempty(failing)
    fprintf('QLKNN derivative check (%s): failing state indices: %s\n', ...
        networkname, mat2str(failing));
end
passed = isempty(failing);

% %% Test 
% [tcoeffplus] = QLKNN(stapnewplus,geop,model,QLKNNnet,QLKNNparams);
% [tcoeffmin] = QLKNN(stapnewmin,geop,model,QLKNNnet,QLKNNparams);
% errs = coeff_diff(tcoeffmin, tcoeffplus,tcoeff,dtcoeff_dx,dx,false);
% 
% %% Check if errors sufficiently small
% tol = 1e-4; % error tolerance
% passed = all(errs(~isnan(errs))<tol);
%  
% if ~passed
%   % plot problems
%   coeff_diff(tcoeffmin, tcoeffplus,tcoeff,dtcoeff_dx,dx,true, tol);
% end
return

function errs = coeff_diff(coeffmin, coeffplus,coeffzero,dcoeff_dx,dx,doplot, tol) %#ok<INUSL,INUSD>
% compute errors for all structure fields
%
% For every field F of coeffzero (except 've') the forward difference
% coeffplus.F - coeffzero.F is compared with the linear prediction
% dcoeff_dx.dF_dx * dx. The reported error is
%   norm(diffnum - diffana) / norm(diffnum).
%
% Inputs:
%   coeffmin  - output structure at the backward-perturbed state
%               (accepted for interface compatibility, not used)
%   coeffplus - output structure at the forward-perturbed state
%   coeffzero - output structure at the nominal state; its field names
%               drive the loop
%   dcoeff_dx - structure with fields named d<field>_dx
%   dx        - perturbation multiplied with each d<field>_dx
%   doplot    - when true, plot each field and drop into the debugger
%   tol       - accepted for interface compatibility, not used
% Output:
%   errs      - 1-by-numel(fieldnames(coeffzero)) relative errors; the
%               entry for 've' is left at 0. An entry is NaN when the
%               forward difference is identically zero.
myfields = fieldnames(coeffzero);

errs = zeros(1,numel(myfields)); % prealloc
for ifield = 1:numel(myfields)
  myfield = myfields{ifield};
  if strcmp(myfield, 've')
      % ve not predicted for now: skip only this field. Using 'break'
      % here would silently leave every later field unchecked.
      continue
  end
  dxstr = sprintf('d%s_dx',myfield);
  diffnumplus = coeffplus.(myfield) - coeffzero.(myfield); % out(x + dx) - out(x)

  diffana = dcoeff_dx.(dxstr)*dx; % dout/dx * dx
  
  %err = norm(diffnum-diffana)/norm(diffnum)
  rel_err_plus = (diffnumplus - diffana).^2 ./ sum(diffnumplus.^2);
  errs(ifield) = sqrt(sum(rel_err_plus));
  %fprintf('error = %4.4e \n',errs(ifield));
  if doplot
    subplot(221); plot([diffnumplus,diffana]); xlabel('rho'); legend('out(x + dx) - out(x)', 'dout/dx * dx', 'Location', 'best'); title(myfield)
    subplot(222); plot(diffnumplus-diffana); xlabel('rho'); legend('[out(x + dx) - out(x)] - dout/dx * dx', 'Location', 'best');
    subplot(223); plot(rel_err_plus); xlabel('rho'); legend('{[out(x + dx) - out(x)] - dout/dx * dx}/[out(x + dx) - out(x)]', 'Location', 'best');
    %bound = norm(diffnum) * tol;
    %hold on; plot([1 size(diffnum, 1)], [bound,bound], '--g'); plot([1 size(diffnum, 1)], [-bound,-bound], '--g'); hold off;
    fprintf('error = %4.4e \n',errs(ifield));
    disp('paused, hit f5 to continue')
    keyboard % interactive pause for inspection of the plots
  end
end
return