function passed = QLKNN_test
% Entry point of the QLKNN test suite.
% Runs one_QLKNN_test for every network listed below and returns true only
% if every individual run reported success.
% Previously also exercised: 'QLKNN4Dkin' (currently disabled).
networks = {'QLKNN10D'};

results = false(1, numel(networks));
for inet = 1:numel(networks)
  results(inet) = one_QLKNN_test(networks{inet});
end

passed = all(results);

return

function passed = one_QLKNN_test(networkname)
% Test function for QLKNN: sets up a RAPTOR AUG test case, evaluates the
% QLKNN network on it, and plots the network inputs (with the network's
% training-range bounds) and selected outputs.
% F. Felici 2018
%
% Input:
%   networkname - name passed to QLKNN() to load the network and its params
% Output:
%   passed      - currently always true: the finite-difference derivative
%                 check (commented block at the end) is disabled.
%
% NOTE(review): whatever networkname is, the RAPTORtoQLKNN10D / QLKNN10D
% pipeline is used below -- confirm this is valid for non-10D networks.

%% RAPTOR initialization
nn = 1e-8; % size of perturbation
% stapnew and dx are only needed by the (disabled) derivative check below
[stapnew,stap,geop,model,dx] = RAPTOR_test_setup_AUG(nn); %#ok<ASGLU>

%% qlk initialization
[QLKNNnet,QLKNNparams] = QLKNN(networkname);
%QLKNNparams.maxOutput(:) = 1e20;
%QLKNNparams.minOutput(:) = -1e20;

%% Nominal case
%[NNinputsnew,dNNinputs_dxnew,kfacnew,dkfac_dxnew] = RAPTORtoQLKNN10D(stapnew,geop,model);
% Post-processing flags, sanity clipped variant. The not-sanity-clipped
% variant is: [true NaN false true true true false true true true true false true]
[~, QLKNNparams] = QLKNN_set_pp(QLKNNparams, [true NaN false true true true true true true true true false true]);
[NNin,dNNin_dx,kfac,dkfac_dx] = RAPTORtoQLKNN10D(stap,geop,model);
[NNout,dNNout_dNNin] = QLKNN10D(NNin,QLKNNnet,QLKNNparams);
% coeffOut / dcoeffOut_dx are only consumed by the (disabled) derivative check
[coeffOut,dcoeffOut_dx] = NNOut2Coeff(NNout,dNNout_dNNin,dNNin_dx,kfac,dkfac_dx,QLKNNparams,QLKNNnet.name); %#ok<ASGLU>

%% Plot NN inputs, one panel of a 4x2 grid per group of input columns
input_names  = {'Zeff', 'Ati', 'Ate', 'An', 'q', 'smag', 'x',  'Ti/Te', 'log10(Nustar)', 'gammaE_GB'};
input_groups = {2:3, 4, 5, 6, [1 8], 7, 9, 10}; % group k -> subplot(4,2,k)
close all
for k = 1:numel(input_groups)
  subplot(4, 2, k)
  plot_input(input_groups{k}, NNin, QLKNNnet, input_names)
end

%% Plot NN outputs, one panel per output column
figure
% Full QLKNN output list:
% 1 : efeETG_GB
% 2 : efeITG_GB
% 3 : efeTEM_GB
% 4 : efiITG_GB
% 5 : efiTEM_GB
% 6 : pfeITG_GB
% 7 : pfeTEM_GB
% 8 : dfeITG_GB
% 9 : dfeTEM_GB
% 10: vteITG_GB
% 11: vteTEM_GB
% 12: vceITG_GB
% 13: vceTEM_GB
% 14: dfiITG_GB
% 15: dfiTEM_GB
% 16: vtiITG_GB
% 17: vtiTEM_GB
% 18: vciITG_GB
% 19: vciTEM_GB
% NOTE(review): columns 1..3 of NNout are labelled q_e, q_i, Gamma_e here,
% which does not match the list above -- confirm what QLKNN10D returns.
%output_names = {'qe_{ETG}', 'qe_{ITG}', 'qe_{TEM}', 'qi_{ITG}', 'qi_{TEM}', '\Gamma e_{ITG}', '\Gamma e_{TEM}'};
output_names = {'q_e', 'q_i', '\Gamma_e'};
nout = numel(output_names);
for k = 1:nout
  subplot(nout, 1, k)
  plot_output(k, NNout, QLKNNnet, output_names)
end

passed = true;

% %% Test (disabled)
% [tcoeffnew] = QLKNN(stapnew,geop,model,QLKNNnet,QLKNNparams);
% errs = coeff_diff(tcoeffnew,tcoeff,dtcoeff_dx,dx,false);
% 
% %% Check if errors sufficiently small
% tol = 1e-4; % error tolerance
% passed = all(errs(~isnan(errs))<tol);
% 
% if ~passed
%   % plot problems
%   coeff_diff(tcoeffnew,tcoeff,dtcoeff_dx,dx,true, tol);
% end
return

function errs = coeff_diff(coeff2,coeff1,dcoeff_dx,dx,doplot, tol)
% Compare numerical and analytical changes of every field of a struct.
% For each field F of coeff1, the numerical change
%   diffnum = coeff2.F - coeff1.F          (out(x + dx) - out(x))
% is compared with the linearized change
%   diffana = dcoeff_dx.dF_dx * dx         (dout/dx * dx)
% and the relative error norm(diffnum - diffana) / norm(diffnum) is stored.
%
% Inputs:
%   coeff2, coeff1 - structs with identical field names (perturbed / nominal)
%   dcoeff_dx      - struct holding a field 'd<F>_dx' for every field F
%   dx             - perturbation, right-multiplied onto each d<F>_dx
%   doplot         - if true, plot both changes and their difference per
%                    field, then pause until a key is pressed
%   tol            - (optional) relative tolerance. If given while plotting,
%                    +/- tol*norm(diffnum) is drawn on the difference plot.
% Output:
%   errs - 1 x nfields row of relative errors, in fieldnames(coeff1) order.
%          A field with norm(diffnum) == 0 gives NaN or Inf.

myfields = fieldnames(coeff1);
showbound = nargin >= 6 && ~isempty(tol);

errs = zeros(1,numel(myfields)); % prealloc
for ifield = 1:numel(myfields)
  myfield = myfields{ifield};
  diffnum = coeff2.(myfield)-coeff1.(myfield); % out(x + dx) - out(x)
  dxstr = sprintf('d%s_dx',myfield);
  diffana = dcoeff_dx.(dxstr)*dx; % dout/dx * dx

  errs(ifield) = norm(diffnum-diffana)/norm(diffnum);
  % printed once per field (previously duplicated when doplot was true)
  fprintf('error = %4.4e \n',errs(ifield));
  if doplot
    subplot(211); plot([diffnum,diffana]); xlabel('rho'); legend('out(x + dx) - out(x)', 'dout/dx * dx', 'Location', 'best'); title(myfield)
    subplot(212); plot(diffnum-diffana); xlabel('rho'); legend('[out(x + dx) - out(x)] - dout/dx * dx', 'Location', 'best');
    if showbound
      % absolute band matching the relative tolerance used by the caller
      bound = norm(diffnum) * tol;
      xr = [1 size(diffnum, 1)];
      hold on
      plot(xr, [bound,bound], '--g', 'HandleVisibility', 'off');
      plot(xr, [-bound,-bound], '--g', 'HandleVisibility', 'off');
      hold off
    end
    disp('paused, hit key to continue')
    pause
  end
end
return

function plot_output(output_idx, NNoutput, QLKNNnet, output_names, model)
% Scatter selected columns of the NN output matrix against row index.
%   output_idx   - columns of NNoutput to draw; each is labelled with
%                  output_names{column} in the legend
%   NNoutput     - matrix with one row per grid point
%   QLKNNnet, model - not used here
% Leaves 'hold on' active on the current axes.
xgrid = 1:size(NNoutput, 1);
for col = output_idx
  scatter(xgrid, NNoutput(:, col), 'DisplayName', output_names{col})
  hold on
end
legend('Location', 'best')

return

function plot_input(input_idx, NNinputs, QLKNNnet, input_names, model)
% Scatter selected NN input columns against row index, overlaid with the
% network's training-range bounds as dashed lines of matching colour.
%   input_idx   - columns of NNinputs to draw (row or column vector)
%   NNinputs    - matrix with one row per grid point
%   QLKNNnet    - network struct. For columns inp < 10, the fields
%                 feature_min(inp) and feature_max(inp) are drawn as bounds.
%   input_names - legend label of each column, indexed by column number
%   model       - not used here
% The legend lists only the scatter series. Drawing the bounds after
% legend() previously added them as spurious 'dataN' entries when the
% legend auto-updates. Marker/colour styles cycle if more than 3 inputs
% are requested. Leaves 'hold on' active on the current axes.
markers = {'o', '+', 'x'};
colors  = {'r', 'g', 'b'};
nstyle  = numel(markers);

npts  = size(NNinputs, 1);
xgrid = 1:npts;
sel   = input_idx(:).';           % iterate element-wise for any vector shape
nsel  = numel(sel);

hsc = gobjects(1, nsel);
for k = 1:nsel
  inp = sel(k);
  sty = mod(k - 1, nstyle) + 1;
  hsc(k) = scatter(xgrid, NNinputs(:, inp), 'DisplayName', input_names{inp}, ...
                   'Marker', markers{sty}, 'MarkerEdgeColor', colors{sty});
  hold on
end

for k = 1:nsel
  inp = sel(k);
  if inp < 10 % bounds are only drawn for the first 9 inputs
    sty = mod(k - 1, nstyle) + 1;
    plot(xgrid, repmat(QLKNNnet.feature_min(inp), [1 npts]), '--', 'Color', colors{sty})
    plot(xgrid, repmat(QLKNNnet.feature_max(inp), [1 npts]), '--', 'Color', colors{sty})
  end
end

if nsel > 0
  legend(hsc, 'Location', 'best')
end
return
