function [x, xdot, u, out_data, res_vec] = initial_state_solver(model, params, init, g, v, u, ...
                                                                xdot_profiles, verbosity)
% INITIAL_STATE_SOLVER  Solve the state equations for an initial state.
%
% Thin wrapper around the general parametrized solver (solve_parametrized).
% It supplies an initial-state parametrization: the free parameter vector z
% holds the psi values (all but the last grid point) and the kinetic states
% (te, ti, ne, ni). Its last entry scales a prescribed dTe/dt profile. The
% input u is held fixed.
%
% xdot_profiles (struct):
%   .Upl    - loop-voltage profile, evaluated on a uniform rho grid in [0,1]
%             with as many points as model.psi.xind
%   .dTe_dt - dTe/dt profile shape; intended to be a scalar (flat profile)
%             or a profile versus rho
%
% All outputs are passed straight through from solve_parametrized.

% Each handle captures its extra arguments at creation time.
state_of_z     = @(z) x_z_initdTe_dt(z, model);
statedot_of_z  = @(z) xdot_z_initdTe_dt(z, model, xdot_profiles);
input_of_z     = @(z) u_z_initdTe_dt(z, model, u);

% Bundle the three handles into one that returns (value, Jacobian) pairs.
parametrization = @(z) compact_args(z, state_of_z, statedot_of_z, input_of_z);

[x, xdot, u, out_data, res_vec] = ...
    solve_parametrized(model, params, init, g, v, parametrization, verbosity);
end

function [x, dx_dz, xdot, dxdot_dz, u, du_dz] = compact_args(z, x_z, xdot_z, u_z)
% COMPACT_ARGS  Evaluate the state, state-derivative and input
% parametrizations at z.
%
% Each handle must return (value, Jacobian w.r.t. z). The six results are
% returned as one flat output list: x, dx/dz, xdot, dxdot/dz, u, du/dz.
% The handles are called in the order x, xdot, u.

[x, dx_dz]       = x_z(z);
[xdot, dxdot_dz] = xdot_z(z);
[u, du_dz]       = u_z(z);
end

%% initial state parametrization

function [x, dx_dz] = x_z_initdTe_dt(z, model)
% X_Z_INITDTE_DT  Map the initial-state parameter vector z onto the state x.
%
% Layout of z:
%   z(1:npsi-1)      -> psi states model.psi.xind(1:end-1) (the last psi
%                       point is not parametrized and stays 0 in x)
%   z(npsi:end-1)    -> kinetic states [te, ti, ne, ni] (in that order)
%   z(end)           -> not used here (dTe/dt scaling, see the xdot map)
% All other entries of x are 0.
%
% Outputs:
%   x     - nx-by-1 state vector
%   dx_dz - nx-by-numel(z) Jacobian, identity blocks on the mapped entries

n_psi = numel(model.psi.xind);
n_z   = numel(z);
n_x   = model.dims.nx;

psi_free = model.psi.xind(1:end-1);
kin_idx  = [model.te.xind, model.ti.xind, model.ne.xind, model.ni.xind];
z_psi    = 1:n_psi-1;
z_kin    = n_psi:n_z-1;

x = zeros(n_x, 1);
x(psi_free) = z(z_psi);
x(kin_idx)  = z(z_kin);

dx_dz = zeros(n_x, n_z);
dx_dz(psi_free, z_psi) = eye(n_psi-1);
dx_dz(kin_idx, z_kin)  = eye(n_z-n_psi);
end

function [xdot, dxdot_dz] = xdot_z_initdTe_dt(z, model, xdot_profiles)
% XDOT_Z_INITDTE_DT  Prescribed state derivative for the initial-state
% parametrization.
%
% The psi entries of xdot are fixed to the loop-voltage profile Upl(rho).
% The Te entries are a fixed dTe/dt profile shape scaled by z(end), which is
% the only entry of z that xdot depends on. All other entries are 0.
%
% Inputs:
%   z             - parameter vector; z(end) scales the dTe/dt profile
%   model         - model struct (uses psi.xind, te.xind, dims.nx)
%   xdot_profiles - struct with fields
%       Upl    : function handle of rho, or numeric scalar (flat profile)
%       dTe_dt : (optional) function handle of rho, or numeric scalar
%                (flat profile); defaults to a flat profile of 1e2
%   rho is a uniform column grid on [0,1] with numel(model.psi.xind) points.
%
% Outputs:
%   xdot     - nx-by-1 state derivative
%   dxdot_dz - nx-by-numel(z) Jacobian of xdot with respect to z

rho  = linspace(0, 1, numel(model.psi.xind))';
zsize = numel(z);
xsize = model.dims.nx;

Upl_rho = eval_profile_initdTe_dt(xdot_profiles.Upl, rho, numel(model.psi.xind));

% Check field presence BEFORE touching it; the default must be reachable.
if isfield(xdot_profiles, 'dTe_dt')
    dTe_dt_rho = eval_profile_initdTe_dt(xdot_profiles.dTe_dt, rho, numel(model.te.xind));
else
    dTe_dt_rho = 1e2*ones(numel(model.te.xind), 1);
end

% xdot lives in state space: size nx (consistent with dxdot_dz), not numel(z).
xdot = zeros(xsize, 1);
xdot(model.psi.xind) = Upl_rho;
xdot(model.te.xind)  = z(zsize)*dTe_dt_rho;

dxdot_dz = zeros(xsize, zsize);
dxdot_dz(model.te.xind, zsize) = dTe_dt_rho;
end

function prof = eval_profile_initdTe_dt(spec, rho, npts)
% Evaluate a profile given either as a function handle of rho or as a
% numeric scalar (flat profile of npts points). Returns a column vector.
% A function handle must be tested with isa: numel() of a handle is 1, so a
% size check cannot tell a handle apart from a scalar.
if isa(spec, 'function_handle')
    prof = spec(rho);
elseif isnumeric(spec) && isscalar(spec)
    prof = spec*ones(npts, 1);
else
    error('initial_state_solver:badProfile', ...
          'Profile must be a function handle of rho or a numeric scalar.');
end
prof = prof(:);
end

function [u, du_dz] = u_z_initdTe_dt(z, model, u)
% U_Z_INITDTE_DT  Input parametrization for the initial state: u is fixed.
%
% The output u is the same variable as the third input, so it is returned
% unchanged. It does not depend on z, so its Jacobian is an all-zero
% model.dims.nu-by-numel(z) matrix.
n_u = model.dims.nu;
n_z = numel(z);
du_dz = zeros(n_u, n_z);
end
