function [varargout] = RAPTOR_predictive(x0,Gin,Vin,Uin,model,params,varargin)
% RAPTOR_predictive  Time-stepping RAPTOR simulation over params.tgrid.
%
% function [simres] = RAPTOR_predictive(x0,g0,v0,U_ff,model,params) % returns structure
% function [X,U]    = RAPTOR_predictive(x0,g0,v0,U_ff,model,params)
% ...               = RAPTOR_predictive(...,'verbosity',v) % overrides params.debug.verbosity
%
% Inputs
%   x0      initial state (column vector)
%   Gin     geometry profiles: a single column (used at every time step) or
%           one column per entry of params.tgrid
%   Vin     kinetic profiles: same convention as Gin
%   Uin     inputs passed to RAPTOR_controllers; when model.controllers is not
%           a struct, its number of rows is the number of inputs
%   model, params   RAPTOR model and parameter structures
%
% Outputs (selected by nargout)
%   1 output : simres structure with fields X, G, V, Xdot, U, Dt, staps, geops,
%              traps, exitflags, newtres, niter and the per-step Jacobians
%              df_dx, df_dxdot, df_dxkm1, df_dxk, df_du, df_dd. When
%              params.numerics.calc_sens is set, it also contains the forward
%              sensitivities dxk_dp. With model.equi.eq_call it also contains
%              subiter_info and eqrecon.
%   2 outputs: [X,U]
%   The former 7-output form [X,U,df_dxk1,df_dx,df_du,df_dp,dx_dp] is not
%   supported. It referenced quantities that are never computed here, so it
%   always failed; it now fails immediately with an explicit message.
%
% The time loop stops at the first step whose exitflag is not 0 or 1.

[verbosity] = parse_inputs(params,varargin{:});
params = set_verbosity(params,verbosity);

input_checks(x0,Gin,Vin,Uin,model,params)

% Fail fast instead of running the full simulation and then erroring on
% undefined outputs.
if nargout > 2
  error('RAPTOR_predictive:nargout', ...
    ['RAPTOR_predictive returns either one output (simres structure) or two ([X,U]); ', ...
    '%d outputs were requested. Use the structure output for Jacobians and sensitivities.'], nargout);
end

% call mode: forward sensitivities are only computed for the structure output
% and only when requested in params
calcfwsens = (nargout == 1) && params.numerics.calc_sens;

%% Allocate
nu = count_nu(model,Uin);
nt = numel(params.tgrid);
nx = size(x0,1);

U         = NaN(nu,nt);
Xdot      = NaN(nx,nt);
X         = NaN(nx,nt);
Dt        = NaN(1,nt);
niter     = NaN(nt,1);
newtres   = NaN(1,nt);
exitflags = NaN(1,nt);

traps = cell(1,nt);
staps = cell(1,nt);
geops = cell(1,nt);

% The per-step Jacobians are stored at every step and returned in simres, so
% they are allocated regardless of calcfwsens.
df_dxkm1 = cell(1,nt);
df_dxk   = cell(1,nt);
df_dx    = cell(1,nt);
df_dxdot = cell(1,nt);
df_du    = cell(1,nt);
df_dd    = cell(1,nt);
if calcfwsens
  dxk_dp = cell(1,nt+1);
end

exit_time_loop = false; % bool used to exit time loop

% sawtooth (MHD) module switch; it does not change inside the time loop
do_mhd = ~isempty(params.saw) && isfield(params.saw, 'active') && params.saw.active;

%% assign G,V
[G,V] = assign_GV(Gin,Vin,model,nt);

% Measure computational time
tic;
for it=1:nt % TIME ITERATIONS
  %% Initialize inputs and state
  if it==1
    % flying start, useful to get ramp-ups to converge
    xdotkm1 = initialize_xdot(model,params);
    dtk = diff(params.tgrid(1:2));
    vdotk = zeros(size(V(:,1)));
    gdotk = zeros(size(G(:,1)));

    xkm1 = x0;
    gk = G(:,1);
    vk = V(:,1);
    gkm1 = gk;
    u0 = RAPTOR_controllers(xkm1,gk,vk,Uin,it,model,params);
    dk = zeros(model.dims.nd,1);

    % Evaluate state profiles, geometrical profiles, kinetic profiles on gauss grid for FE equation
    stapi = init_stap(model); geopi=init_geop(model); trapi=init_trap(model);
    stapkm1 = state_profiles(x0,xdotkm1,gk,gdotk,vk,vdotk,model,stapi);
    geopkm1 = geometry_profiles(gk,gdotk,model,geopi);
    trapkm1 = transport_profiles(stapkm1,stapkm1,geopkm1,dk,u0,it,model,params,0,trapi);
  else
    xdotkm1 = xdotk;
    vk = V(:,it); % kinetic profiles

    % Update the time step BEFORE using it. The rates below are backward
    % differences over [tgrid(it-1), tgrid(it)]; dividing by the previous
    % step's dt was wrong on non-uniform time grids.
    dtk = params.tgrid(it)-params.tgrid(it-1); % dt used
    vdotk = (V(:,it)-V(:,it-1))/dtk;
    gdotk = (G(:,it)-G(:,it-1))/dtk;

    xkm1 = X(:,it-1);
    gkm1 = G(:,it-1);
    stapkm1 = staps{it-1};
    geopkm1 = geops{it-1};
    trapkm1 = traps{it-1};
  end

  %% Input sensitivity
  duk_dp = input_sensitivity(calcfwsens,it,model,params);

  if model.equi.eq_call
    %% Perform equilibrium and transport iterations
    [xk, xdotk, gk, uk, stapk, geopk, trapk,...
      exitflag, res, inewt, ...
      dftilde_dx, dftilde_dxdot, dftilde_dxk, dftilde_dxkm1, dftilde_du, dftilde_dd, ~, ...
      subiter_info{it}] = run_eq_tran_iterations(gkm1, xkm1, xdotkm1, vk, vdotk, dk, Uin, it, dtk, stapkm1, geopkm1, trapkm1, model, params);
  else
    %% Call RAPTOR_predictive_step directly
    gk = G(:,it); % geometry profiles
    uk = RAPTOR_controllers(xkm1,gk,vk,Uin,it,model,params);
    % report the size of uk (the quantity actually being checked)
    assert(size(uk,1) == model.dims.nu,'Size of uk (=%d) does not match model.dims.nu (=%d)',size(uk,1),model.dims.nu);

    if params.numerics.usemex
      error('Todo: fix mex call')
    else
      [xk,xdotk,dftilde_dx,dftilde_dxdot,dftilde_dxk,dftilde_dxkm1,dftilde_du,dftilde_dd,dftilde_dw,~,...
        stapk,geopk,trapk,res,inewt,exitflag] = ...
        RAPTOR_predictive_step(it,dtk,xkm1, xdotkm1,gk,gdotk,vk,vdotk,uk,dk,stapkm1,geopkm1,trapkm1,model,params);
    end
  end

  %% forward sensitivity analysis
  % store jacobians
  df_dxkm1{it} = dftilde_dxkm1;
  df_dx{it} = dftilde_dx;
  df_dxdot{it} = dftilde_dxdot;
  df_dxk{it} = dftilde_dxk;
  df_du{it} = dftilde_du;
  df_dd{it} = dftilde_dd;

  % forward sens equation if necessary
  if calcfwsens
    % init: the initial state does not depend on the parameters
    if it==1
      dxkm1_dp = sparse(numel(xkm1),size(duk_dp,2));
    end
    dxk_dp{it} = calc_fwsens(dxkm1_dp,duk_dp,dftilde_dxk,dftilde_dxkm1,dftilde_du);
    dxkm1_dp = dxk_dp{it}; % next
  end

  % sawtooth module: simply substitute resulting state
  if do_mhd
    xdotk_preMHD = xdotk; % store old values
    xk = saw(xk,gk,vk,uk,it,model,params.saw);
    xdotk = (xk-xkm1)/dtk;
  end

  %% Store results in matrices
  U(:,it) = uk;
  newtres(it) = res;  % newton iteration residual
  exitflags(it) = exitflag;
  niter(it) = inewt;
  X(:,it) = xk;
  Xdot(:,it) = xdotk;
  G(:,it) = gk;
  V(:,it) = vk;
  Dt(:,it) = dtk;

  staps{it} = stapk;
  geops{it} = geopk;
  traps{it} = trapk;

  %% Error handling
  if ~any(exitflag==[0,1])
    % handle errors in RAPTOR_predictive
    exitflags(it) = RAPTOR_errorhandler(exitflag);
    exit_time_loop = true;
  end

  %% PLOTS & DISPLAYS

  % flag whether to display
  displaynow = ~mod(it-1,params.debug.iterdisp) || exit_time_loop;

  if params.debug.iterplot && displaynow
    timeslice = RAPTOR_out(xk,gk,vk,xdotk,uk,it,stapk,geopk,trapk,model,params);
    RAPTOR_plot_step(timeslice)
  end

  % optional summary of macroscopic quantities
  if params.debug.iterdisp && displaynow
    iterdisp(xk,gk,vk,xdotk,res,dtk,it,inewt,trapk,model,params)
  end

  % optional break at given iteration
  if isfield(params.debug,'keyboard_at') && ~isempty(params.debug.keyboard_at)
    if any(it == params.debug.keyboard_at)
      disp(['keyboard at it= ',num2str(it),' as requested']);
      dbstack;
      keyboard;
    end
  end

  % check exit flag and exit if necessary
  if exit_time_loop
    break
  end

  if do_mhd
    xdotk = xdotk_preMHD;
    % use the pre-MHD xdot1 as initialization for the next step
    % the one modified after MHD can be too large
    % so that the extrapolation in the first iteration of the
    % next step may be too large...
  end

end % END of time loop

% outputs, depending on call varargout
%% If one out argument, return all as structure
if nargout == 1
  % assign outputs

  simres.X = X;
  simres.G = G;
  simres.V = V;
  simres.Xdot = Xdot;
  simres.U = U;
  simres.Dt = Dt;
  simres.staps = staps;
  simres.geops = geops;
  simres.traps = traps;

  simres.exitflags = exitflags;
  simres.newtres = newtres;
  simres.niter = niter;

  simres.df_dx = df_dx;
  simres.df_dxdot = df_dxdot;
  simres.df_dxkm1 = df_dxkm1;
  simres.df_dxk = df_dxk;
  simres.df_du = df_du;
  simres.df_dd = df_dd;

  if calcfwsens % assign sensitivities if calculated
    simres.dxk_dp = dxk_dp;
  end

  if model.equi.eq_call
    eqrecon = rap_liu_post_processing(params, subiter_info);
    simres.subiter_info = subiter_info;
    simres.eqrecon = eqrecon;
  end

  varargout{1} = simres;
elseif nargout == 2
  varargout = {X,U};
else
  % no outputs requested (nargout > 2 was rejected above)
  varargout = {};
end

end

function [verbosity] = parse_inputs(params,varargin)
% Parse the optional name/value arguments of RAPTOR_predictive.
% Only 'verbosity' is recognised. It must be a numeric scalar that is either
% >= 0 or exactly -1. It defaults to params.debug.verbosity.
isValidVerbosity = @(v) isnumeric(v) && isscalar(v) && (v>=0 || v==-1);

parser = inputParser;
parser.addParameter('verbosity', params.debug.verbosity, isValidVerbosity);
parser.parse(varargin{:});

verbosity = parser.Results.verbosity;
end

function nu = count_nu(model,Uin)
% Number of inputs for the simulation.
% Without controllers (model.controllers is not a struct), this is the number
% of rows of Uin. With controllers, it is the sum of numel of the first
% element of every entry of model.controllers.mapping.
if ~isstruct(model.controllers)
  nu = size(Uin,1); % standard case: one input per row of Uin
  return
end

mapping = model.controllers.mapping;
nu = 0;
for k = 1:numel(mapping)
  actuators = mapping{k}{1}; % actuators defined for this mapping entry
  nu = nu + numel(actuators);
end
end

function [G,V] = assign_GV(Gin,Vin,model,nt)
% Expand the geometry (Gin) and kinetic (Vin) profile inputs to nt columns.
% Each input may have a single column, which is replicated to nt columns, or
% exactly nt columns, which are used as given. Any other width is an error.
% The model argument is not used.
G = expand_to_nt_columns(Gin, nt, 'Gin has invalid size');
V = expand_to_nt_columns(Vin, nt, 'Vin has invalid size');
end

function out = expand_to_nt_columns(in, nt, errmsg)
% Pass an nt-column input through, replicate a 1-column input, else error(errmsg).
ncols = size(in,2);
if ncols == nt
  out = in;
elseif ncols == 1
  out = repmat(in,[1,nt]);
else
  error(errmsg)
end
end

function duk_dp = input_sensitivity(calcfwsens,it,model,params)
% Sensitivity of the inputs u_k at time index it with respect to parameters p.
% Returns the scalar 0.0 when forward sensitivities are disabled.
% Otherwise the result has model.dims.nu rows. Each input with
% params.cvp.uind(ii)==1 receives the next consecutive row of the control
% input vector sensitivity; all other rows stay zero.
if ~calcfwsens
  duk_dp = 0.0;
  return
end

dciv_dp = RAPTOR_cvp(it); % civ - control input vector
transp = params.cvp.transp_matrix;
if isempty(transp)
  dciv_dpfull = dciv_dp;
else
  % Get dciv_dp for reduced p which doesn't contain fixed values.
  dciv_dpfull = dciv_dp*transp;
end

nu = model.dims.nu;
duk_dp = zeros(nu,size(dciv_dpfull,2));
row = 0;
for iu = 1:nu
  if params.cvp.uind(iu) ~= 1
    continue % this input is not optimized
  end
  row = row + 1;
  duk_dp(iu,:) = dciv_dpfull(row,:);
end
end

function [dxk_dp] = calc_fwsens(dxkm1_dp,duk_dp,dftilde_dxk,dftilde_dxkm1,dftilde_du)
% One step of the forward sensitivity recursion. Solves
%   dftilde_dxk*dxk_dp + dftilde_dxkm1*dxkm1_dp + dftilde_du*duk_dp = 0
% for dxk_dp.
% TODO (efficiency): factorise dftilde_dxk once (see "doc lu") instead of
% using backslash at every call.
rhs = dftilde_dxkm1*dxkm1_dp + dftilde_du*duk_dp;
dxk_dp = -dftilde_dxk \ rhs;
end


