function varargout = NNOut2Coeff(NNout,dNNout_dNNin,dNNin_dx,kfac,dkfac_dx,pp,networkname)
% [QLKNNout,dQLKNNout_dx] = NNOut2Coeff(NNout,dNNout_dNNin,dNNin_dx,kfac,dkfac_dx,pp,networkname);
% Generic function to compute QLKNN outputs in SI units from
% (GB-normalized) fluxes or diffusivities coming from neural network.
% Each returned output column is the product of one NN output column and
% one k-factor column; which columns are paired is selected by networkname
% (and, for 'QLKNN10D'/'QLKNN8D', by pp.use_effective_diffusivity).
%
% Dimensions of inputs:
%   NNout:        [nEval x nNNout]
%   dNNout_dNNin: [nEval x nNNout x nIn]   (indexed as dNNout_dNNin(:,iNNout,iIn))
%   dNNin_dx:     [nEval x nStates x nIn]
%   kfac:         [nEval x nKfac]          (columns selected by the k-factor map)
%   dkfac_dx:     [nEval x nStates x nKfac]
%   pp:           parameter struct; pp.use_effective_diffusivity is read for
%                 'QLKNN10D' and 'QLKNN8D' only
%   networkname:  'QLKNN4Dkin', 'QLKNN10D' or 'QLKNN8D' (anything else errors)
%
% Dimensions of outputs:
%   QLKNNout:      [nEval x nOut]          nOut = number of mapped outputs
%   dQLKNNout_dx:  [nEval x nStates x nOut]
%
% The derivative inputs (dNNout_dNNin, dNNin_dx, dkfac_dx) are only read
% when two outputs are requested; otherwise they may be passed as [].

% F. Felici 2018;

% flag to calculate derivatives or not
calcder = (nargout == 2);

% get mappings: iNNoutlist(iOut) / ikfaclist(iOut) give the NN output column
% and k-factor column that make up transport output iOut
switch networkname
  case 'QLKNN4Dkin'
    [iNNoutlist,ikfaclist] = map_NNout_kfac_4D;
  case {'QLKNN10D', 'QLKNN8D'}
    useDeff = pp.use_effective_diffusivity;
    [iNNoutlist,ikfaclist] = map_NNout_kfac_10D(useDeff);
  otherwise
    error('unknown QLKNN network type')
end
nOut = numel(iNNoutlist);

% sizes
nEval = size(NNout,1);

% init
out = zeros(nEval,nOut);
if calcder
  % derivative sizes are only needed (and only meaningful) when the
  % derivative inputs are actually supplied
  nIn = size(dNNin_dx,3);
  nx  = size(dNNin_dx,2);
  dout_dx = zeros(nEval,nx,nOut);
end

for iOut=1:nOut
  % get indices of specific NN output and kfac involved in this output
  iNNout = iNNoutlist(iOut);
  ikfac  = ikfaclist(iOut);

  % NN output value and k factor for this output index
  valueNNout = NNout(:,iNNout);
  kfactor    = kfac(:,ikfac);

  if calcder
    % chain rule: d(NNout)/dx = sum_iIn d(NNout)/d(NNin_iIn) * d(NNin_iIn)/dx
    dvalueNNout_dx = zeros(nEval,nx);
    for iIn=1:nIn
      dvalueNNout_dx = dvalueNNout_dx + bsxfun(@times,dNNout_dNNin(:,iNNout,iIn),dNNin_dx(:,:,iIn));
    end
    dkfactor_dx = dkfac_dx(:,:,ikfac);

    % calculate transport model output and its derivative
    [out(:,iOut),dout_dx(:,:,iOut)] = ...
      calcTransportCoeff(kfactor,dkfactor_dx,valueNNout,dvalueNNout_dx);
  else
    % value only: calcTransportCoeff does not touch the derivative
    % arguments when a single output is requested
    out(:,iOut) = calcTransportCoeff(kfactor,[],valueNNout,[]);
  end
end

varargout{1} = out;
if calcder
  varargout{2} = dout_dx;
end

return

function [value,dvalue_dx] = calcTransportCoeff(kfactor,dkfactor_dx,valueNNout,dvalueNNout_dx)
% Product rule for one transport output:
%   value     = kfactor .* valueNNout
%   dvalue_dx = valueNNout .* dkfactor_dx + kfactor .* dvalueNNout_dx
% kfactor/valueNNout are presumably [nEval x 1] columns, broadcast with
% bsxfun across the columns of the [nEval x nx] derivative arrays.
% The derivative arguments are only read when a second output is requested.

value = kfactor.*valueNNout;

if nargout < 2
  return
end

termFromK  = bsxfun(@times,dkfactor_dx,valueNNout);
termFromNN = bsxfun(@times,dvalueNNout_dx,kfactor);
dvalue_dx  = termFromK + termFromNN; % follows convention [nEval x nx]

return

function [iNNout,ikfac] = map_NNout_kfac_10D(useDeff)
% Mapping for the 10D/8D networks between each transport model output and
% the NN output column / k-factor column that produce it.
% Each row of the table below is one output iOut: [iNNout, ikfac].
% useDeff selects between the two tables.

if useDeff
  pairs = [1 1
           2 2
           3 3
           3 5
           4 4
           6 5];
else
  pairs = [1 1
           2 2
           4 4
           5 5
           6 4
           7 5];
end

% return as row vectors, one entry per output
iNNout = pairs(:,1).';
ikfac  = pairs(:,2).';

return

function [iNNout,ikfac] = map_NNout_kfac_4D
% Mapping for the 4D kinetic network: output iOut uses NN output column
% iOut and k-factor column iOut (identity for all four outputs).
iNNout = [1 2 3 4];
ikfac  = iNNout;
return