function varargout = QLKNN(varargin)
% QLKNN  Evaluate a QLKNN neural-network transport model for RAPTOR, or
% return the default network/parameter structures for a network type.
%
% Default-parameter call:
%   [QLKNN_net,QLKNN_params] = QLKNN(networktype);
%     networktype : char, one of 'QLKNN4Dkin', 'QLKNN8D', 'QLKNN10D'
%
% Evaluation call:
%   [tcoeff,dtcoeff_dx] = QLKNN(stap,geop,trap,model,QLKNN_net,QLKNN_params);
%   [tcoeff,dtcoeff_dx] = QLKNN(stap,geop,trap,model,QLKNN_net,QLKNN_params,debug);
%     stap, geop, trap : RAPTOR physics containers
%     model            : RAPTOR model
%     QLKNN_net        : network structure; QLKNN_net.name selects the evaluator
%     QLKNN_params     : parameter structure for this QLKNN model
%     debug            : accepted for backward compatibility, currently unused
%
% Outputs of the evaluation call:
%   tcoeff     : struct with fields chie, chii, de, ve
%                (columns 1..4 of the post-processed coefficient array)
%   dtcoeff_dx : struct with fields dchie_dx, dchii_dx, dde_dx, dve_dx
%                (pages 1..4 of the coefficient derivative array)
%
% For QLKNN8D/QLKNN10D with model.hmode.modeltype == 'imposed', grid points
% with model.rgrid.rhogauss > model.hmode.rhoped are not passed to the
% network; their coefficients and derivatives are returned as zeros.

% #codegen

if nargin==1
  QLKNNtype = varargin{1};
  assert(ischar(QLKNNtype),'first argument must be string with network type');
  
  % return default network and parameter structure for this QLKNN type
  switch QLKNNtype
    case 'QLKNN4Dkin'
      [QLKNNnet,QLKNNparams] = QLKNN4Dkin;
    case 'QLKNN8D'
      [QLKNNnet,QLKNNparams] = QLKNN8D;
    case 'QLKNN10D'
      [QLKNNnet,QLKNNparams] = QLKNN10D;
    otherwise
      error('unknown QLKNN model type')
  end
  
  varargout{1} = QLKNNnet;
  varargout{2} = QLKNNparams;
  return
elseif (nargin==7) || (nargin==6)
  stap = varargin{1};
  geop = varargin{2};
  trap = varargin{3};
  model = varargin{4};
  QLKNNnet = varargin{5};
  QLKNNparams = varargin{6};
  % an optional 7th (debug) argument is accepted but not used
else
  error('invalid number of input arguments')
end

% Mask of radial points excluded from the NN evaluation (pedestal region).
% Start with "nothing excluded" so the mask is defined on every code path:
% it is used below for the geometry, the zero padding and the plotting grid.
iped = false(size(model.rgrid.rhogauss));

%% Actual call to get values
switch QLKNNnet.name
  case 'QLKNN4Dkin'
    [NNin,dNNin_dx,kfac,dkfac_dx] = RAPTORtoQLKNN4Dkin(stap,geop,model);
    [NNout,dNNout_dNNin] = QLKNN4Dkin(NNin,QLKNNnet,QLKNNparams);
  case {'QLKNN8D','QLKNN10D'}
    [NNin,dNNin_dx,kfac,dkfac_dx,QLKNNparams] = RAPTORtoQLKNN10D(stap,geop,trap,model,QLKNNparams);
    
    if strcmp(model.hmode.modeltype, 'imposed')
      % Strip pedestal from input
      iped = model.rgrid.rhogauss > model.hmode.rhoped;
      NNin = NNin(~iped, :);
      dNNin_dx = dNNin_dx(~iped, :, :);
      kfac = kfac(~iped, :);
      dkfac_dx = dkfac_dx(~iped, :, :);
    end
    
    % Evaluate NN (the enclosing case guarantees one of these two names)
    if strcmp(QLKNNnet.name, 'QLKNN8D')
      [NNout,dNNout_dNNin] = QLKNN8D(NNin,QLKNNnet,QLKNNparams);
    else
      [NNout,dNNout_dNNin] = QLKNN10D(NNin,QLKNNnet,QLKNNparams);
    end
  otherwise
    error('unknown QLKNN network type')
end

%% post-processing specific to implementation in RAPTOR transport equation
[coeffOut,dcoeffOut_dx] = NNOut2Coeff(NNout,dNNout_dNNin,dNNin_dx,kfac,dkfac_dx,QLKNNparams,QLKNNnet.name);

noVmask = isfield(QLKNNparams, 'noVmask') && QLKNNparams.noVmask;
if isfield(QLKNNparams, 'use_effective_diffusivity') && QLKNNparams.use_effective_diffusivity && ~noVmask
  if any(iped)
    % geometry restricted to the points that were actually evaluated
    g0 = geop.g0(~iped); g1 = geop.g1(~iped); Vp = geop.Vp(~iped);
  else
    g0 = geop.g0; g1 = geop.g1; Vp = geop.Vp;
  end
  [coeffOut,dcoeffOut_dx] = DVeffMask10D(coeffOut,dcoeffOut_dx,NNout(:,3),NNin(:,4),g0,g1,Vp);
end

[coeffOut,dcoeffOut_dx] = post_processing(coeffOut,dcoeffOut_dx,model,QLKNNnet,QLKNNparams);
if any(iped)
  % Scatter evaluated rows back to their grid positions; the non-evaluated
  % pedestal points stay zero. Indexing by the mask avoids assuming that
  % the excluded points are the trailing rows of the grid.
  nrho = numel(iped);
  coeffFull = zeros(nrho, size(coeffOut,2));
  coeffFull(~iped,:) = coeffOut;
  dcoeffFull = zeros(nrho, size(dcoeffOut_dx,2), size(dcoeffOut_dx,3));
  dcoeffFull(~iped,:,:) = dcoeffOut_dx;
  coeffOut = coeffFull;
  dcoeffOut_dx = dcoeffFull;
end
tcoeff.chie = coeffOut(:,1);
tcoeff.chii = coeffOut(:,2);
tcoeff.de   = coeffOut(:,3);
tcoeff.ve   = coeffOut(:,4);

dtcoeff_dx.dchie_dx = squeeze(dcoeffOut_dx(:,:,1));
dtcoeff_dx.dchii_dx = squeeze(dcoeffOut_dx(:,:,2));
dtcoeff_dx.dde_dx   = squeeze(dcoeffOut_dx(:,:,3));
dtcoeff_dx.dve_dx   = squeeze(dcoeffOut_dx(:,:,4));

varargout{1} = tcoeff;
varargout{2} = dtcoeff_dx;

if isfield(QLKNNparams, 'doplot') && QLKNNparams.doplot
  % NNin/NNout only hold the evaluated points, so plot against the same subset
  QLKNN_plot_IO(model.rgrid.rhogauss(~iped), NNin, NNout, QLKNNnet, QLKNNparams)
end


return


function [coeffOut,dcoeffOut_dx] = DVeffMask10D(coeffOut,dcoeffOut_dx,Gam_e,rlne,g0,g1,Vp)
% DVeffMask10D  Per radial point, keep either the D coefficient (column 3)
% or the V coefficient (column 4) and set the other one to zero.
% (For networks that have D/V outputs only.)
%
% A point keeps column 3 when rlne.*Gam_e >= 0 and abs(rlne) > 0.1; all
% other points keep column 4. The same selection is applied to pages
% (:,:,3) and (:,:,4) of dcoeffOut_dx.
%
% g0, g1 and Vp are currently unused. They remain in the signature because
% callers pass them; a former rescaling by Vp./g0 and Vp.^2./g1 is disabled.

keepD = (rlne.*Gam_e >= 0) & (abs(rlne) > 0.1);
keepV = ~keepD;

% points using V: drop D and its derivative
coeffOut(keepV,3) = 0;
dcoeffOut_dx(keepV,:,3) = 0;

% points using D: drop V and its derivative
coeffOut(keepD,4) = 0;
dcoeffOut_dx(keepD,:,4) = 0;
return
