function [x,u,xdot,df_dxdot,df_dx,df_du,f,stap,geop,trap] = RAPTOR_solve_state(x0,u0,xdot0,g,v,model,params,varargin)
%  [x,u,xdot,df_dxdot,df_dx,df_du,f,stap,geop,trap] = RAPTOR_solve_state(x0,u0,xdot0,g,v,model,params)
%  [x,u,xdot,df_dxdot,df_dx,df_du,f,stap,geop,trap] = RAPTOR_solve_state(x0,u0,xdot0,g,v,model,params,fixmask)
%
% Find solutions of RAPTOR state equation in continuous time:
%     0 = f(xdot,x,u)       (FF thesis eq. 7.41)
%
% The solution is found with a damped Newton iteration (adaptive step size
% tau) on the free, i.e. non-fixed, entries of x, u and xdot.
%
% By default, u and xdot are kept fixed (as given by the user), as is the
% last element of x; the remaining elements of x are solved for. This is
% useful for finding steady-state or stationary solutions.
%
% Optionally, one can pass a 'fixmask' parameter to decide manually what is
% fixed and what is not.
%
% INPUTS
%   x0,u0,xdot0: initial guesses (column vectors), some of which are kept
%                fixed when searching for a solution, depending on fixmask
%   g,v:         passed unchanged to state_equation (and to RAPTOR_out when
%                plotting)
%   model,params: RAPTOR model and params structure
%   fixmask: Optional parameter to determine which quantities are kept fixed
%             and which ones we solve for.
%             fixmask = {xfix,ufix,xdotfix}
%             is a cell array with logical (or 0/1) vectors of same sizes as
%             x0,u0,xdot0, each vector containing 'true' if that particular
%             element is fixed.
%             Default: xfix is false everywhere except for its last element,
%                      ufix = true(size(u0)), xdotfix = true(size(xdot0)).
%             This solves for x (except its last element), with given u0
%             and xdot0.
%             The number of free entries must make the jacobian of f w.r.t.
%             the free entries square (this is asserted).
%
% OUTPUTS:
%   x,u,xdot: solution vectors (fixed entries are the same as the input)
%   df_dxdot,df_dx,df_du: jacobians, as returned by state_equation
%   f: residual (should be small)
%   stap,geop,trap: as returned by state_equation at the last evaluation
%
% If the iteration does not converge within maxiter evaluations, a warning
% is issued and the last evaluated point is returned.

% F. Felici, B. v.d. Schepop TU/e 2012-2017

if ~isempty(varargin)
    fixmask = varargin{1};
else
    % default: u and xdot fixed, last element of x fixed, rest of x free
    x0mask = false(size(x0)); x0mask(end) = true;
    fixmask = {x0mask,true(size(u0)),true(size(xdot0))};
end

% checks
if ~iscell(fixmask) || numel(fixmask) ~= 3
    error('fixmask must be a cell array {xfix,ufix,xdotfix}')
end
% masks are used for logical indexing below, so accept numeric 0/1 masks
% too by converting them to logical
fixmask = cellfun(@logical,fixmask,'UniformOutput',false);
if ~isequal(size(fixmask{1}),size(x0))
    error('wrong size for fixmask{1} or x0')
end
if ~isequal(size(fixmask{2}),size(u0))
    error('wrong size for fixmask{2} or u0')
end
if ~isequal(size(fixmask{3}),size(xdot0))
    error('wrong size for fixmask{3} or xdot0')
end

it = 1; % passed on to state_equation / RAPTOR_out

% some numerics parameters
restol  = 1e-12; % convergence tolerance on norm(f)
maxiter = 300;   % maximum number of residual evaluations
tau     = 0.2;   % start slow, then accelerate (adaptive step)
doplot  = 0;     % set nonzero to plot every iterate

% scalings: unscaled = scal.(field) .* scaled
scal.x    = ones(size(x0));
scal.u    = 1e6*ones(size(u0));
scal.xdot = ones(size(xdot0));

% z is scaled such that values are closer to unity
x0scal = x0./scal.x; u0scal = u0./scal.u; xdot0scal = xdot0./scal.xdot;

% split into free entries (solved for) and fixed entries
z     = [x0scal(~fixmask{1}); u0scal(~fixmask{2}); xdot0scal(~fixmask{3})];
zfix  = [x0scal( fixmask{1}); u0scal( fixmask{2}); xdot0scal( fixmask{3})];
resfold = inf; % to always accept first step

% loop invariant: set once instead of on every iteration
params.debug.checkstateequation = 0;

fprintf('\n\niter: residual, tau\n');

converged = false;
for inewt = 1:maxiter
    % evaluate equation and jacobian w.r.t. the free (scaled) variables
    [f,df_dz,x,u,xdot,df_dxdot,df_dx,df_du,stap,geop,trap] = generalized_equation(z,zfix,g,v,it,model,params,fixmask,scal);
    assert(diff(size(df_dz))==0,'df_dz must be square to have a unique solution')
    resf = norm(f);
    % display
    fprintf('%3d: %3.3e, %3.3f\n',[inewt resf,tau]);

    % checks
    if any(isnan(f)); error('f is NaN'); end

    if resf<restol
        fprintf('Newton iterations converged\n\n')
        converged = true;
        break
    end

    if resf>resfold
        % residual increased: reject this step and retry the previous
        % direction from the last accepted point with a smaller tau
        tau = tau-(tau*tau); % deceleration for tau
        upd = -tau*dold;
    else
        % residual decreased (or first step): accept, new Newton direction
        d = df_dz\f;
        upd = -tau*d;

        % store last accepted point
        zold = z; dold = d; resfold = resf;
        tau = tau+((1-tau)*tau); % acceleration for tau (towards full step 1)
    end
    z = zold + upd;

    if doplot
        % plot the point that was just evaluated
        timeslice = RAPTOR_out(x,g,v,xdot,u,it,stap,geop,trap,model,params);
        RAPTOR_plot_step(timeslice);
    end
end

if ~converged
    warning('Newton iterations exceeded');
end

return

function [f,df_dzfree,x,u,xdot,df_dxdot,df_dx,df_du,stap,geop,trap] = generalized_equation(zfree,zfix,g,v,it,model,params,fixmask,scal)
% Evaluate the state equation residual at the point assembled from the
% free (zfree) and fixed (zfix) scaled variables, and return the jacobian of
% the residual with respect to the free variables.
%
% INPUTS
%   zfree: scaled free entries, stacked as [x(~xfix); u(~ufix); xdot(~xdotfix)]
%   zfix:  scaled fixed entries, stacked as [x(xfix); u(ufix); xdot(xdotfix)]
%   g,v,it,model,params: passed on to state_equation
%   fixmask: {xfix,ufix,xdotfix} logical masks ('true' = fixed)
%   scal: struct with fields x,u,xdot; unscaled = scal.(field) .* scaled
%
% OUTPUTS
%   f: state equation residual
%   df_dzfree: df/dzfree, i.e. the jacobians of state_equation restricted to
%              the free columns and multiplied by the scaling (chain rule)
%   x,u,xdot: full unscaled vectors
%   df_dxdot,df_dx,df_du,stap,geop,trap: as returned by state_equation

% preallocate full-length column vectors, so that their size does not
% depend on which entries happen to be fixed or free
xscal    = zeros(numel(fixmask{1}),1);
uscal    = zeros(numel(fixmask{2}),1);
xdotscal = zeros(numel(fixmask{3}),1);

% assign what is fixed
nfix = [sum(fixmask{1}),sum(fixmask{2}),sum(fixmask{3})];
xscal(fixmask{1},:)    = zfix(1:nfix(1));
uscal(fixmask{2},:)    = zfix(nfix(1)+(1:nfix(2)));
xdotscal(fixmask{3},:) = zfix(nfix(1)+nfix(2)+(1:nfix(3)));

% assign what is free
nfree = [sum(~fixmask{1}),sum(~fixmask{2}),sum(~fixmask{3})];
xscal(~fixmask{1},:)    = zfree(1:nfree(1));
uscal(~fixmask{2},:)    = zfree(nfree(1)+(1:nfree(2)));
xdotscal(~fixmask{3},:) = zfree(nfree(1)+nfree(2)+(1:nfree(3)));

% unscale; xdot must use its own scaling, consistent with how the caller
% scaled xdot0 and with the chain rule applied to df_dxdot below
x    = scal.x    .* xscal;
u    = scal.u    .* uscal;
xdot = scal.xdot .* xdotscal;

% call state function
stapk = init_stap(model);
geopk = init_geop(model);
trapk = init_trap(model);
gdot = zeros(size(g));
vdot = zeros(size(v));
xk = x;
dk = zeros(model.dims.nd,1);
inewt = 1;

[f,df_dxdot,df_dx,df_du,~,~,~,stap,geop,trap] = ...
state_equation(x,xk,g,v,u,dk,xdot,gdot,vdot,it,model,params,inewt,stapk,geopk,trapk);

% chain rule: d(unscaled)/d(scaled) = scal, applied column-wise to the
% free columns of each jacobian
df_dzfree = [bsxfun(@times,df_dx(:,~fixmask{1}),scal.x(~fixmask{1})'),...
    bsxfun(@times,df_du(:,~fixmask{2}),scal.u(~fixmask{2})'),...
    bsxfun(@times,df_dxdot(:,~fixmask{3}),scal.xdot(~fixmask{3})')];

return
