function [f,df_dxdot,df_dx,df_du,df_dd,df_dw,df_dp,stap,geop,trap] = state_equation(x,xk,g,v,u,dk,xdot,gdot,vdot,it,model,params,inewt,stapk,geopk,trapk)%#codegen
% [f,df_dxdot,df_dx,df_du,df_dd,df_dw,df_dp,stap,geop,trap] = ...
%   state_equation(x,xk,g,v,u,dk,xdot,gdot,vdot,it,model,params,inewt,stapk,geopk,trapk)
%
% Returns the state equation residual and its jacobians (continuous time).
% Also returns the stap,geop,trap profiles corresponding to the inputs.
%
% x  is the x of the nonlinear equation f(xdot,x,u).
% xk is the current state value that is treated as explicit variable
%    (in this function it is only forwarded to the optional debug self-check).
%
% Outputs:
%   f        stacked residuals of the psi, te, ti, ne and ntm equations,
%            each block multiplied by its model.<channel>.eqscal
%   df_dxdot, df_dx, df_du, df_dd, df_dw
%            jacobians of f, stacked and scaled in the same way as f
%   df_dp    legacy output: zeros(numel(f),model.np)
%   stap, geop, trap
%            profile structures returned by state_profiles,
%            geometry_profiles and transport_profiles
%
% When called with a single output only f is evaluated: the individual
% equations are called without jacobian outputs and nothing is assembled.
% With two or more outputs all jacobians are computed and assembled.

coder.extrinsic('state_equation_check'); % declare as extrinsic, so no code generated for it
coder.extrinsic('check_jacobians'); % declare as extrinsic, so no code generated for it
coder.extrinsic('all_causes'); % declare as extrinsic, so no code generated for it

%% Physics - general profile quantities on gauss grid that go into the various equations
geop = geometry_profiles(g,gdot,model,geopk);
stap = state_profiles(x,xdot,g,gdot,v,vdot,model,stapk);
trap = transport_profiles(stapk,stap,geop,dk,u,it,model,params,inewt,trapk);

BCs  = boundary_conditions(x,g,v,u,model);

%% PSI part
if nargout == 1
  Psires = psi_equation(x,xdot,u,geop,stap,trap,BCs,model);
else
  [Psires,dPsires_dxdot,dPsires_dx,dPsires_du,dPsires_dd,dPsires_dw] ...
    = psi_equation(x,xdot,u,geop,stap,trap,BCs,model);
end

%% Te part
if nargout == 1
  Teres = te_equation(x,xdot,u,geop,stap,trap,BCs,model,'e');
else
  [Teres,dTeres_dxdot,dTeres_dx,dTeres_du,dTeres_dd,dTeres_dw] ...
    = te_equation(x,xdot,u,geop,stap,trap,BCs,model,'e');
end

%% Ti part (same equation function as Te, selected with the species flag 'i')
if nargout == 1
  Tires = te_equation(x,xdot,u,geop,stap,trap,BCs,model,'i');
else
  [Tires,dTires_dxdot,dTires_dx,dTires_du,dTires_dd,dTires_dw] ...
    = te_equation(x,xdot,u,geop,stap,trap,BCs,model,'i');
end

%% ne part
if nargout == 1
  neres = n_equation(x,xdot,u,geop,stap,trap,BCs,model,'e');
else
  [neres,dneres_dxdot,dneres_dx,dneres_du,dneres_dd,dneres_dw] ...
    = n_equation(x,xdot,u,geop,stap,trap,BCs,model,'e');
end

%% NTM part
if nargout == 1
  ntmres = ntm_equation(stap,trap,geop,u,model,params);
else
  [ntmres,dntmres_dxdot,dntmres_dx,dntmres_du,dntmres_dd,dntmres_dw] = ...
    ntm_equation(stap,trap,geop,u,model,params);
end

%% Stack the scaled residuals of all equations
f = [model.psi.eqscal * Psires;
  model.te.eqscal * Teres;
  model.ti.eqscal * Tires;
  model.ne.eqscal * neres;
  model.ntm.eqscal * ntmres];

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Jacobian
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% The per-equation jacobians above are computed whenever more than one
% output is requested, so assemble them under the same condition. Using
% nargout > 2 here would leave df_dxdot unassigned for a two-output call.
if nargout > 1
  %% put all jacobians together
  df_dxdot = ([ model.psi.eqscal * dPsires_dxdot;
    model.te.eqscal  * dTeres_dxdot;
    model.ti.eqscal  * dTires_dxdot;
    model.ne.eqscal  * dneres_dxdot;
    model.ntm.eqscal * dntmres_dxdot]);
  
  df_dx    = ([ model.psi.eqscal * dPsires_dx;
    model.te.eqscal  * dTeres_dx;
    model.ti.eqscal  * dTires_dx;
    model.ne.eqscal  * dneres_dx;
    model.ntm.eqscal  * dntmres_dx]);
  
  df_du    = ([ model.psi.eqscal * dPsires_du;
    model.te.eqscal  * dTeres_du;
    model.ti.eqscal  * dTires_du;
    model.ne.eqscal  * dneres_du;
    model.ntm.eqscal * dntmres_du]);
  
  df_dd    = ([ model.psi.eqscal * dPsires_dd;
    model.te.eqscal  * dTeres_dd;
    model.ti.eqscal  * dTires_dd;
    model.ne.eqscal  * dneres_dd;
    model.ntm.eqscal * dntmres_dd]);
  
  df_dw    = ([ model.psi.eqscal * dPsires_dw;
    model.te.eqscal  * dTeres_dw;
    model.ti.eqscal  * dTires_dw;
    model.ne.eqscal  * dneres_dw;
    model.ntm.eqscal * dntmres_dw]);
  
  
  % legacy df_dp, remove at some point
  df_dp = zeros(numel(f),model.np);
  
  % Optional debug self-checks. state_equation_check calls this function
  % again with a single output, so this branch is not re-entered from there.
  if params.debug.checkstateequation
    state_equation_check(x,xk,g,v,u,dk,xdot,gdot,vdot,it,model,params,inewt,stapk,geopk,trapk,f,df_dx)
  end
  
  if params.debug.checkprofilejacobians
    check_jacobians(x,xdot,g,gdot,v,vdot,dk,u,inewt,it,model,params)
  end
end

return

function check_jacobians(x,xdot,g,gdot,v,vdot,dk,u,inewt,it,model,params)
% Debug helper: evaluates jacobian errors with calc_jacobians and passes
% them to check_jacobian_error_results. Any error raised by that check is
% reported as a warning instead of being propagated to the caller.
%
% Fixed settings used here:
%   perturbation size passed to calc_jacobians        : 1e-8
%   tolerance passed to check_jacobian_error_results  : 1e-4
% Plotting is controlled by params.debug.plotprofilejacobians.

perturbation = 1e-8;
tolerance    = 1e-4;
plotflag     = params.debug.plotprofilejacobians;

jacerr = calc_jacobians(x,xdot,g,gdot,v,vdot,dk,u,inewt,it,model,params,perturbation,plotflag);

try %#ok<EMTC>
  check_jacobian_error_results(jacerr,tolerance,1)
catch exc
  % downgrade a failed check to a warning so the caller keeps running
  warning(all_causes(exc))
end
return

function state_equation_check(x,xk,g,v,u,dk,xdot,gdot,vdot,it,model,params,inewt,stapk,geopk,trapk,f,df_dx)
% Self-check of the analytical jacobian df_dx, per transport channel and
% per equation.
%
% For every channel among psi, te, ti, ne whose model.<chan>.method is
% 'state', the entries x(model.<chan>.xind) are perturbed by small random
% values and state_equation is re-evaluated (single output, residual only).
% The resulting change in f is compared with the linear prediction df_dx*dx,
% block by block. Each comparison is plotted (analytical: blue solid,
% numerical: red dashed) and its relative error norm is printed to the
% command line.
%
% NOTE(review): the rows of f belonging to a channel are selected with the
% same index vector model.<chan>.xind used for its states - this assumes
% equation rows and state entries share indices; confirm against the model.

persistent hax

names = {'psi','te','ti','ne'};

% indices (into names) of the channels solved as states
iactive = zeros(1,0);
for k = 1:numel(names)
  if strcmp(model.(names{k}).method,'state')
    iactive(end+1) = k; %#ok<AGROW>
  end
end
nactive = numel(iactive);

% dfana{i,j}: predicted change of equation block i for a perturbation of channel j
% dfnum{i,j}: actual change of equation block i for the same perturbation
dfana = cell(nactive,nactive);
dfnum = cell(nactive,nactive);

for jcol = 1:nactive
  xind = model.(names{iactive(jcol)}).xind;
  
  % perturb only the states of this channel
  dx = zeros(size(x));
  dx(xind) = 1e-9*rand(size(xind));
  
  % recursive call with perturbed x, requesting a single output (nargout==1)
  fpert = state_equation(x+dx,xk,g,v,u,dk,xdot,gdot,vdot,it,model,params,inewt,stapk,geopk,trapk);
  df = fpert - f;
  
  for irow = 1:nactive
    f_ind = model.(names{iactive(irow)}).xind;
    dfana{irow,jcol} = df_dx(f_ind,:)*dx;
    dfnum{irow,jcol} = df(f_ind);
  end
end

%% plots and command line table
hax = get_persistent_hax(hax, 'state equation check: dx', nactive, nactive, true);

% command line display header (only when there is at least one row to show)
if nactive > 0
  fprintf('*** Jacobian check residuals, it=%d,inewt=%d: ***\n',it,inewt);
  fprintf('%10s','');
  for k = iactive
    fprintf('d%-10s ',names{k});
  end
end

for irow = 1:nactive
  rowlabel = sprintf('df_{%s}',names{iactive(irow)});
  for jcol = 1:nactive
    ax   = hax(irow,jcol);
    yana = dfana{irow,jcol};
    ynum = dfnum{irow,jcol};
    
    plot(ax,1:numel(yana),yana,'b',...
      1:numel(ynum),ynum,'r--')
    axis(ax,'tight')
    
    % relative error; eps guards against division by zero
    relerr = norm(ynum-yana)/(norm(ynum)+eps);
    
    if jcol==1
      % first column: start a new table row and label the axes row
      fprintf('\n%8s',rowlabel)
      ylabel(ax,rowlabel)
    end
    if irow==1
      % first row: column header of the plot grid
      title(ax,sprintf('dx {%s}',names{iactive(jcol)}))
    end
    
    % command line output
    fprintf('%10.3e  ',relerr);
  end
end
fprintf('\n\n')
drawnow
return