function [xk,xdotk,dftilde_dx,dftilde_dxdot,dftilde_dxk,dftilde_dxkm1,dftilde_du,dftilde_dd,dftilde_dw,dftilde_dp,...
  stapk,geopk,trapk,res,inewt,exitflag] = ...
  RAPTOR_predictive_step(it,dtk,xkm1,xdotkm1,gk,gdotk,vk,vdotk,uk,dk,stapkm1,geopkm1,trapkm1,model,params)
% RAPTOR_PREDICTIVE_STEP  Advance the state one time step with a damped Newton solve.
%
% Iterates on the new state xk until norm(ftilde) < params.numerics.restol,
% where ftilde is returned by state_equation evaluated at the theta-method
% discretisation
%   x    = thp*xk + (1-thp)*xkm1,   thp = params.numerics.theta
%   xdot = (xk - xkm1)/dtk
% Newton directions are obtained from an LU solve with JAC = d(ftilde)/d(xk).
% A trial iterate is rejected if check_state_validity flags it, if
% numerics_check flags ftilde/JAC, or if its residual fails an Armijo-type
% test. On rejection the step length tau is multiplied by 0.5 and a new
% trial xk_prev + tau*d is formed from the last accepted iterate.
%
% Before iterating, xkm1 may be modified: when model.hmode.modeltype is
% 'imposed' it is remapped by map_state_L_to_H / map_state_H_to_L depending on
% the sign of vdotk(model.hmode.vind.activation), and it is then passed
% through add_seed_island. The modified xkm1 is what the discretisation uses.
%
% Outputs:
%   xk, xdotk        - last iterate of the state and its discrete time derivative
%   dftilde_dx, dftilde_dxdot, dftilde_du, dftilde_dd, dftilde_dw, dftilde_dp
%                    - derivatives of ftilde from the last state_equation call
%                      (zero-initialised if state_equation was never called)
%   dftilde_dxk      - Jacobian of ftilde w.r.t. xk from the last evaluation (JAC)
%   dftilde_dxkm1    - Jacobian of ftilde w.r.t. xkm1, via the chain rule
%                      through x and xdot
%   stapk,geopk,trapk- values from the last state_equation call, or the
%                      previous-step inputs if the last iterate was rejected
%   res              - norm(ftilde) at the last evaluation
%   inewt            - number of Newton iterations performed
%   exitflag         - 0 when res < params.numerics.restol; 1 when the limit
%                      params.numerics.nmax is reached; 9 when the user quits
%                      from the interactive pause prompt; otherwise the last
%                      value set in the loop (e.g. by check_state_validity)

coder.extrinsic('plotEachNewtonStep'); % declare as extrinsic so no code is generated for it
coder.extrinsic('warning');
coder.extrinsic('fprintf');

% temporal discretization (theta method). Here xk is the NEW time step and
% xkm1 the previous one:
%   x    = thp*xk + (1-thp)*xkm1
%   xdot = (xk - xkm1)/dtk
% thp = 1: fully implicit, 0: fully explicit, 0.5: Crank-Nicolson
thp = params.numerics.theta; % theta weighting of the new step
% partial derivatives of xdot and x w.r.t. the new (xk) and old (xkm1) states
% (these are constants of the discretisation, used for the chain rule)
dxdot_dxk = 1/dtk; dxdot_dxkm1 = -1/dtk;
dx_dxk = thp; dx_dxkm1 = (1-thp);

% imposed H-mode: remap the previous state when the activation signal is
% rising (L->H) or falling (H->L)
if strcmp(model.hmode.modeltype,'imposed') 
    if vdotk(model.hmode.vind.activation) > 0
        xkm1 = map_state_L_to_H(xkm1, model);
    elseif vdotk(model.hmode.vind.activation) < 0
        xkm1 = map_state_H_to_L(xkm1, model, vk, vdotk, dtk);
    end
end

xkm1 = add_seed_island(xkm1,vk,model,params);

% extrapolation factor for the initial guess: grows linearly with it and is
% capped at params.numerics.extrap once it-1 >= nextrap-1
extrapfactor = params.numerics.extrap*(min(1,(it-1)/(params.numerics.nextrap-1)));
[xk,xdotk] = first_guess(xkm1,xdotkm1,dtk,model,extrapfactor);
d = xk-xkm1; % first descent direction guess

exitflag = -1;
inewt = 0;

% initialize others
dftilde_dx     = zeros(numel(xk));
dftilde_dxdot  = zeros(numel(xk));
dftilde_du     = zeros(numel(xk),numel(uk));
dftilde_dd     = zeros(numel(xk),numel(dk));
dftilde_dw     = zeros(numel(xk),model.dims.nw);
dftilde_dp     = zeros(numel(xk),model.np);

%% Setup & init Newton solver
c_armijo = 0.1; c_backtrack = 0.5; % Armijo relaxation; step-size reduction factor
res_prev = 1e9; res=res_prev; % init
tau = params.numerics.tau;
xk_prev = xkm1;
fprev = inf(model.dims.nx,1); ftilde = fprev;
JAC_prev = zeros(model.dims.nx); JAC=JAC_prev; % init

% residual history (full history offline, only the latest value in real time)
% NOTE(review): residues is written below but not returned by this function
if ~model.realtime
  residues = zeros(params.numerics.nmax+1,1);
else
  residues = 0;
end

reject_checks = false; reject_residual = false; res_accept = 9e99; % init

% stap/trap/geop default to the previous-step values
stapx = stapkm1;
trapx = trapkm1;
geopx = geopkm1;

while inewt<=params.numerics.nmax
  plotThisNewtonStep = false; % init

  exitflag = check_state_validity(xk,gk,vk,model);
  if exitflag~=-1
    % reject this iterate immediately without even calling state equation
    reject_validity = true;
  else
    reject_validity = false;

    %% Evaluate state function value
    % temporal discretization
    x = dx_dxk*(xk) + dx_dxkm1*xkm1; % theta-weighted state
    % call state equation
    [ftilde,dftilde_dxdot,dftilde_dx,dftilde_du,dftilde_dd,dftilde_dw,dftilde_dp,stapx,geopx,trapx] = ...
      state_equation(x,xkm1,gk,vk,uk,dk,xdotk,gdotk,vdotk,it,model,params,inewt,stapkm1,geopkm1,trapkm1);
    % Jacobian w.r.t. the new state xk (chain rule through xdot and x)
    JAC = dftilde_dxdot*dxdot_dxk + dftilde_dx*dx_dxk;

    res = norm(ftilde); % residual

    if ~model.realtime
      residues(inewt+1) = res; % store
    else
      residues = res; % store only most recent one
    end
    
    if (res<params.numerics.restol)
      exitflag = 0; break;  % tolerance reached, exit newton loop now
    elseif inewt >= params.numerics.nmax
      if ~model.realtime && params.debug.verbosity>1 % warn only if not RT, else ok
        warning('RAPTOR:ExceedMaxNewtonIterations','Exceeded maximum newton iterations');
      end
      exitflag = 1; break
    end

    % check this iterate output
    exitflag = numerics_check(ftilde,JAC); % -1 if numerics are ok
    reject_checks = (exitflag~=-1);

    % Check residual using Armijo rule (Nocedal & Wright Ch 3): compare the
    % new residual with the linearised residual from the last accepted
    % iterate, with the step tau*d scaled by c_armijo
    if inewt>0
      res_accept = norm(fprev + c_armijo*JAC_prev*(tau*d)); % maximum accepted residual on evaluation
      reject_residual = (res>res_accept);
    end
  end

  if reject_validity || reject_checks || reject_residual % reject this iterate
    if params.debug.plotFailedNewtonStep
      plotThisNewtonStep = true;
    end
    if params.debug.dispFailedNewtonStep
      if reject_validity
        fprintf('   Rejected step due to unphysical state, exitflag=%d! it=%d, inewt=%d \n',exitflag,it,inewt);
      elseif reject_checks
        fprintf('   Rejected step due to exitflag=%d! it=%d, inewt=%d \n',exitflag,it,inewt);
      elseif reject_residual
        fprintf('   Rejected step due to high residual! it=%d, inewt=%d \n',it,inewt);
        fprintf('   Residual norm prev:%3.3e new: %3.3e, accepted: %3.3e\n',res_prev,res,res_accept);
      end
    end

    % attempt a smaller step in the present Newton direction
    % (backtracking method), starting again from the last accepted iterate

    tau = c_backtrack*tau; % reduce step size
    if params.debug.dispFailedNewtonStep
      fprintf('    decrease step to tau=%4.4f\n',tau)
    end

    xk = xk_prev + tau*d;
    xdotk = dxdot_dxk*xk + dxdot_dxkm1*xkm1;

    % On rejection, revert stap/trap/geop to the previous-step values
    stapx = stapkm1;
    trapx = trapkm1;
    geopx = geopkm1;

  else % accept this iterate
    % New Newton descent direction
    % solve 0 = ftilde + JAC*d via LU factorisation with row pivoting
    [L,U,P] = lu(JAC);
    d = - U\(L\(P*ftilde)); % newton descent direction

    % Step in Newton direction
    tau = params.numerics.tau; % step size (1=full step)
    xk_prev = xk; res_prev = res;% save previous
    fprev = ftilde; JAC_prev = JAC;

    xk = xk_prev + tau*d;
    xdotk = dxdot_dxk*xk + dxdot_dxkm1*xkm1;
  end
  inewt = inewt+1;

  %% Display and plot of salient information of each Newton step
  if ~model.realtime
    if params.debug.plotEachNewtonDirection
      plot_newton_step_direction(xk,xk_prev,gk,vk,model);
    end
    if params.debug.pauseEachNewtonStep
      pauseEachIteration = true;
      status = plotEachNewtonStep(inewt,res,d,xk,gk,vk,xdotk,uk,it,tau,stapx,geopx,trapx,model,params,pauseEachIteration);
      if status == -1
        exitflag = 9; break
      end% user quit
    elseif params.debug.plotEachNewtonStep || plotThisNewtonStep
      pauseEachIteration = false;
      plotEachNewtonStep(inewt,res,d,xk,gk,vk,xdotk,uk,it,tau,stapx,geopx,trapx,model,params,pauseEachIteration);
    end
  end
  
end % of Newton iteration

% jacobians: w.r.t. xk is the last JAC; w.r.t. xkm1 by the chain rule
dftilde_dxk = JAC;
dftilde_dxkm1  = dftilde_dxdot*dxdot_dxkm1  + dftilde_dx*dx_dxkm1;

% outputs
stapk = stapx;
trapk = trapx;
geopk = geopx;

end

function xkm1 = map_state_L_to_H(xkm1, model)
% MAP_STATE_L_TO_H  Remap the previous-step state for an L->H transition.
% The T_e state entries are always replaced by model.te.T_LH times
% themselves; n_e and T_i entries are remapped the same way with their own
% T_LH matrices, but only when their method is 'state'.
xkm1 = apply_state_map_LH(xkm1, model.te.xind, model.te.T_LH);

if strcmp(model.ne.method,'state')
  xkm1 = apply_state_map_LH(xkm1, model.ne.xind, model.ne.T_LH);
end

if strcmp(model.ti.method,'state')
  xkm1 = apply_state_map_LH(xkm1, model.ti.xind, model.ti.T_LH);
end
end

function x = apply_state_map_LH(x, idx, T)
% Replace the entries x(idx) by the linear map T*x(idx).
x(idx) = T * x(idx);
end

function xkm1 = map_state_H_to_L(xkm1, model, vk, vdotk, dtk)
% MAP_STATE_H_TO_L  Remap the previous-step state for an H->L transition.
%
% For each remapped profile, the rhoedge value read from vk is stepped back
% one time step (vk - dtk*vdotk), divided by that profile's scal, and
% appended to the profile's state entries. The profile's own T_HL matrix
% then maps this augmented vector back onto the state entries.
% T_e is always remapped; n_e and T_i only when their method is 'state'.
%
% Inputs:
%   xkm1       - state vector to remap
%   model      - model struct (te/ne/ti xind, scal, T_HL, method; hmode.vind)
%   vk, vdotk  - vectors holding the rhoedge values and their time derivatives
%   dtk        - time step used to step the rhoedge values back

% T_e
te_rhoedge = vk(model.hmode.vind.te_rhoedge) - dtk*vdotk(model.hmode.vind.te_rhoedge);
xxkm1 = [xkm1(model.te.xind); te_rhoedge./model.te.scal];
xkm1(model.te.xind) = model.te.T_HL * xxkm1;

% n_e: must use its own H->L map (model.ne.T_HL), consistent with T_i here
% and with map_state_L_to_H
if strcmp(model.ne.method,'state')
  ne_rhoedge = vk(model.hmode.vind.ne_rhoedge) - dtk*vdotk(model.hmode.vind.ne_rhoedge);
  xxkm1 = [xkm1(model.ne.xind); ne_rhoedge./model.ne.scal];
  xkm1(model.ne.xind) = model.ne.T_HL * xxkm1;
end

% T_i
if strcmp(model.ti.method,'state')
  ti_rhoedge = vk(model.hmode.vind.ti_rhoedge) - dtk*vdotk(model.hmode.vind.ti_rhoedge);
  xxkm1 = [xkm1(model.ti.xind); ti_rhoedge./model.ti.scal];
  xkm1(model.ti.xind) = model.ti.T_HL * xxkm1;
end
end

function status = plotEachNewtonStep(inewt,res,d,xk,gk,vk,xdot,uk,it,tau,stap,geop,trap,model,params,pauseEachIteration)
% PLOTEACHNEWTONSTEP  Print and plot one Newton iteration; optionally pause.
%
% Prints a one-line summary (time step, Newton iteration, iteration limit,
% residual, tolerance, |step|, tau), plots the state via RAPTOR_out and
% RAPTOR_plot_step, and, when pauseEachIteration is true and no skip mode
% is active, prompts the user for how to continue. The user's answer is
% kept in the persistent variable rr; rr_check_return decides whether a
% call is skipped entirely.
%
% Returns status = -1 if the user answered 'q' (quit), 0 otherwise.
persistent rr
status = 0;

[rr,skipThisCall] = rr_check_return(rr,it,inewt);
if skipThisCall
  return
end

% --- textual summary
if inewt==1
  fprintf('\nNewton iterations for it=%d\n',it)
end
headerFmt = '%4s%7s%5s%9s%9s%9s%9s\n';
rowFmt    = '%4d%7d%5d%9.1e%9.1e%9.1e%9.1e \n';
fprintf(headerFmt,'it','niter','nmax','res','resmax','|step|','tau');
stepNorm = norm(d);
fprintf(rowFmt,it,inewt,params.numerics.nmax,res,params.numerics.restol,stepNorm,tau)

% --- graphical summary
timeslice = RAPTOR_out(xk,gk,vk,xdot,uk,it,stap,geop,trap,model,params);
RAPTOR_plot_step(timeslice)
drawnow;

% --- interactive prompt, only while the user has not chosen to skip ahead
if ~pauseEachIteration || ~isempty(rr)
  return
end
fprintf('\n  choose one. [enter]:next newton step, ''n'':next time, ##:go to it=##, ''c'':continue,''k'':keyboard ''q'':quit\r')
rr = input('','s');
if strcmp(rr,'q')
  status = -1;
elseif strcmp(rr,'k')
  keyboard;
elseif strcmp(rr,'c')
  disp('continuing simulation');
elseif strcmp(rr,'n')
  disp('continuing to next time step')
end
end


function [rr,rr_return] = rr_check_return(rr,it,inewt)
% RR_CHECK_RETURN  Decide whether plotEachNewtonStep should skip this call.
%
% rr holds the user's last answer at the interactive prompt:
%   ''   : stop at every Newton step
%   'c'  : never stop again (continue the simulation)
%   'n'  : skip the remaining Newton steps of the current time step
%   '##' : skip until time step it = ## is reached
% Any other answer (e.g. 'k') is consumed and reset to ''.
%
% Outputs:
%   rr        - updated prompt state (the caller stores it persistently)
%   rr_return - true if the caller should return without displaying anything

rr_return = false;

% uninitialised state, or the very first Newton step of the run: reset
if isempty(rr) || (it == 1 && inewt == 1)
  rr = '';
end

if strcmp(rr,'c') % continue without pausing
  rr_return = true;
  return
end

if strcmp(rr,'n')
  if inewt > 1 % still inside the same time step: skip
    rr_return = true;
    return
  end
  rr = ''; % first Newton step of a new time step: pause again
end

% str2double always returns a double (NaN for non-numeric text), so an
% isnumeric() test on its result is always true; test for NaN explicitly.
target_it = str2double(rr);
if ~isnan(target_it) && it < target_it
  % skip until time step number target_it is reached
  rr_return = true;
  return
end
rr = ''; % target reached, or non-numeric answer: reset
end

function [xk1,xdotk1] = first_guess(xk,xdotk,dt,model,extrapfactor)
% FIRST_GUESS  Initial Newton iterate for the next time step.
%
% Every state is extrapolated linearly over dt using the damped rate
% extrapfactor*xdotk, except the psi entries (model.psi.xind), which always
% use the full previous rate xdotk. The returned xdotk1 is the rate used
% for each entry.
ipsi = model.psi.xind;

% damped extrapolation for all entries
xdotk1 = extrapfactor*xdotk;
xk1    = xk + (extrapfactor*dt)*xdotk;

% psi: full (undamped) extrapolation
xdotk1(ipsi) = xdotk(ipsi);
xk1(ipsi)    = xk(ipsi) + dt*xdotk(ipsi);
end

function exitflag = numerics_check(ftilde,JAC)
% NUMERICS_CHECK  Screen a Newton iterate for numerical breakdown.
%
% Returns
%    2  if norm(ftilde) > 1e22 (divergence),
%    4  if JAC has any non-finite entry or ftilde has any NaN,
%   -1  otherwise (iterate is numerically acceptable).
% A warning is issued in both failure cases.
coder.extrinsic('warning');
exitflag = -1;

if norm(ftilde) > 1e22
  warning('Divergence in newton iterations')
  exitflag = 2;
  return
end

if ~all(isfinite(JAC(:))) || any(isnan(ftilde))
  warning('non-finite values in ftilde or Jacobian');
  exitflag = 4;
end
end

function xk1 = add_seed_island(xk1,vk,model,params)
% ADD_SEED_ISLAND  Impose seed island widths and clip negative widths.
%
% Active only when params.ntm.active is true and model.ntm.method is
% 'state'. Then every state entry xk1(model.ntm.xind) that is below its seed
% value vk(model.ntm.vind) is raised to that seed, and any entry that is
% still negative afterwards is set to zero. Otherwise xk1 is returned as-is.
if ~(params.ntm.active && strcmp(model.ntm.method,'state'))
  return
end

iw    = model.ntm.xind;
seed  = vk(model.ntm.vind);

below = xk1(iw) < seed;        % widths smaller than their seed
xk1(iw(below)) = seed(below);

negative = xk1(iw) < 0;         % widths must not be negative
xk1(iw(negative)) = 0;
end


function plot_newton_step_direction(xk,xkprev,gk,vk,model)
% PLOT_NEWTON_STEP_DIRECTION  Compare profiles before and after a Newton step.
%
% For each channel (psi, iota, te, ne, ti, ni) calls eval_<channel> at the
% previous iterate xkprev and the new iterate xk, multiplies by 1/scal of the
% channel's model entry, and plots against model.rgrid.rho in a persistent
% 2x3 figure. Channels whose method is 'state' show both profiles (blue
% dashed: previous, red: new); other channels show only the new profile
% (black). For iota, the method and scal of psi are used.

persistent hax
figname = 'RAPTOR Newton Step Direction';
nrow = 2; ncol = 3;
hax = get_persistent_hax(hax, figname, nrow, ncol, true);
channels = {'psi','te','ti';
            'iota','ne','ni'};

gauss = false; % passed to every eval_<channel> call
for ichan = 1:numel(channels)
  mychan = channels{ichan};

  switch mychan
    case 'iota'
      isstate = strcmp(model.psi.method,'state');
      scal = 1./(model.psi.scal);
    otherwise
      isstate = strcmp(model.(mychan).method,'state');
      scal = 1./(model.(mychan).scal);
  end

  % call eval_<channel> directly with feval instead of eval() on a built
  % string: same call, but the arguments are visible to the code analyzer
  evalfun = ['eval_' mychan];
  valkprev = scal*feval(evalfun,xkprev,gk,vk,model,gauss);
  valk     = scal*feval(evalfun,xk,gk,vk,model,gauss);

  if isstate
    plot(hax(ichan),model.rgrid.rho,valkprev,'b--',...
      model.rgrid.rho,valk,'r');
    title(hax(ichan),sprintf('%s, state',mychan));
  else
    plot(hax(ichan),model.rgrid.rho,valk,'k');
    title(hax(ichan),sprintf('%s, fixed',mychan));
  end
end
drawnow
end
