function [err] = calc_jacobians(x,xdot,g,gdot,v,vdot,dk,u,inewt,it,model,params,pertgrid,doplot)
% CALC_JACOBIANS  Check analytical Jacobians of state and transport profiles
% against finite-difference perturbations w.r.t. x, xdot and u.
%
% For each active state channel (model.(chan).method == 'state') and each
% perturbation size in pertgrid, x (at indices model.(chan).xind) and xdot
% are randomly perturbed, the profiles are recomputed, and the numerical
% change is compared (via calc_error) with the change predicted by the
% analytical Jacobian fields, e.g. 'dchie_dx' * dx.
% Transport profiles are additionally checked w.r.t. a perturbation of u.
%
% INPUTS:
%  x,xdot,g,gdot,v,vdot,dk,u,inewt,it,model,params: passed through to
%       state_profiles / geometry_profiles / transport_profiles
%  pertgrid: vector of perturbation sizes to try (e.g. logspace(-6,-2,10))
%  doplot:   0 = no plots; nonzero = plot each comparison;
%            -1 = plot and stop in the debugger (keyboard) after each plot.
%            Optional, defaults to 0.
%  params.debug.trapProfilesToCheck: transport profiles to check: 'all',
%       'none', a single profile name, or a cell array of names
%  params.debug.stapProfilesToCheck: same for state profiles; 'all' checks
%       every state-profile field whose name does not contain '_d'
%  params.tgrid: time grid, used to scale the xdot perturbation (dx/dt)
%
% OUTPUTS:
%  err: structure of errors indexed by profile name:
%       err.(prof).x.(chan)(in), err.(prof).xdot.(chan)(in) for every
%       checked profile, active channel chan and perturbation pertgrid(in);
%       err.(prof).u(in) for transport profiles only.

if nargin < 14 || isempty(doplot)
    doplot = 0; % no plotting by default
end

% unperturbed profiles
stap  = state_profiles(x,xdot,g,gdot,v,vdot,model,init_stap(model));
% stapk would be evaluated at the same x,xdot as stap, so reuse it
% instead of calling state_profiles a second time
stapk = stap;
geop  = geometry_profiles(g,gdot,model,init_geop(model));
trap  = transport_profiles(stapk,stap,geop,dk,u,it,model,params,inewt,init_trap(model));

% which profiles to check
allTrapProfiles = {'chie','chii','vne','dne','sne','jbsB','signeo','jphioR','petot','pitot','jauxB'}; % more to add
trapProfilesToCheck = parse_profile_list(params.debug.trapProfilesToCheck,allTrapProfiles);
stapProfilesToCheck = parse_profile_list(params.debug.stapProfilesToCheck,eliminateDerivatives(fieldnames(stap)));

% which channels are active (part of the state)
channels = {'psi','te','ti','ne','ni','n1','n2','n3'};
iusechan = false(numel(channels),1);
for ichan = 1:numel(channels)
    if strcmp(model.(channels{ichan}).method,'state')
        iusechan(ichan) = true;
    end
end
usedchan = channels(iusechan);
nchan = numel(usedchan); % number of active state channels

err = struct; % init
% loop over the channels that are in the state
for ichan = 1:nchan
    chan = usedchan{ichan};
    % for each perturbation size
    for in = 1:numel(pertgrid)
        nn = pertgrid(in);

        % perturbed profile structures
        [dx_stap,dx_trap,dxdot_stapdot,dxdot_trap,du_trap,dx,dxdot,du] = ...
            calc_perturbed_profiles_for_this_channel(nn,x,xdot,g,gdot,v,vdot,u,...
            stapk,stap,geop,dk,inewt,it,model,params,chan);

        % state profiles: errors w.r.t. x and xdot (state profiles do not depend on u)
        for istate = 1:numel(stapProfilesToCheck)
            prof = stapProfilesToCheck{istate};
            err.(prof).x.(chan)(in)    = calc_error_and_plot(stap, dx_stap,       prof, sprintf('d%s_dx',prof),    dx,    chan, doplot);
            err.(prof).xdot.(chan)(in) = calc_error_and_plot(stap, dxdot_stapdot, prof, sprintf('d%s_dxdot',prof), dxdot, chan, doplot);
        end

        % transport profiles: errors w.r.t. x and xdot
        for iprof = 1:numel(trapProfilesToCheck)
            prof = trapProfilesToCheck{iprof};
            err.(prof).x.(chan)(in)    = calc_error_and_plot(trap, dx_trap,    prof, sprintf('d%s_dx',prof),    dx,    chan, doplot);
            err.(prof).xdot.(chan)(in) = calc_error_and_plot(trap, dxdot_trap, prof, sprintf('d%s_dxdot',prof), dxdot, chan, doplot);
        end

        % transport profiles: error w.r.t. u; done only once because the
        % u perturbation does not depend on the channel
        if ichan == 1
            for iprof = 1:numel(trapProfilesToCheck)
                prof = trapProfilesToCheck{iprof};
                err.(prof).u(in) = calc_error_and_plot(trap, du_trap, prof, sprintf('d%s_du',prof), du, '', doplot);
            end
        end
    end
end

return

function list = parse_profile_list(setting,allList)
% Convert a profile-selection setting into a cell array of profile names:
% 'all' -> allList, 'none' -> {}, a cell array -> itself, anything else
% (a single name) -> wrapped in a cell.
if strcmp(setting,'all')
    list = allList;
elseif strcmp(setting,'none')
    list = {};
elseif iscell(setting)
    list = setting;
else
    list = {setting}; % put in cell if necessary
end
return

function [dx_stap,dx_trap,dxdot_stapdot,dxdot_trap,du_trap,dx,dxdot,du] = ...
    calc_perturbed_profiles_for_this_channel(nn,x,xdot,g,gdot,v,vdot,u,...
    stapk,stap,geop,dk,inewt,it,model,params,chan)
% Recompute profiles after random perturbations of size nn applied to
%  (a) x at this channel's indices model.(chan).xind,
%  (b) xdot, perturbed by dx divided by the local time step,
%  (c) all entries of u (transport profiles only).
% Returns the perturbed profile structures and the perturbation vectors.

% perturbation vectors; perturb draws random numbers, so dx is drawn
% before du (same order as the profile evaluations below rely on)
dx    = perturb(x,nn,model.(chan).xind);
dt    = get_dt(params.tgrid,it);
dxdot = dx./dt;
du    = perturb(u,nn,1:numel(u));

% evaluators with all fixed arguments bound
newStap = @(xx,xxdot) state_profiles(xx,xxdot,g,gdot,v,vdot,model,init_stap(model));
newTrap = @(sp,uu) transport_profiles(stapk,sp,geop,dk,uu,it,model,params,inewt,init_trap(model));

% x perturbation
dx_stap       = newStap(x+dx, xdot);
dx_trap       = newTrap(dx_stap, u);
% xdot perturbation
dxdot_stapdot = newStap(x, xdot+dxdot);
dxdot_trap    = newTrap(dxdot_stapdot, u);
% u perturbation (state profiles unchanged)
du_trap       = newTrap(stap, u+du);

return

function err = calc_error_and_plot(profStruct,dprofStruct,profName,dprofName,dvec,chanName,verbosity)
% Compute the Jacobian-check error for one profile. If verbosity is nonzero
% the numerical and analytical changes are plotted as well; verbosity == -1
% additionally stops in the debugger (keyboard) after plotting.

[numDiff,anaDiff,baseProf] = calc_perturbations(profStruct,dprofStruct,profName,dprofName,dvec);
err = calc_error(anaDiff,numDiff,baseProf);

doPlot = abs(verbosity);
if doPlot
    plot_error(numDiff,anaDiff,dprofName,chanName)
    stopHere = (verbosity == -1);
    if stopHere
        keyboard;
    end
end

return

function [dprofn,dprofa,prof0] = calc_perturbations(profStruct,dprofStruct,profName,dprofName,dvec)
% Return the numerical change of profile profName (perturbed minus base),
% the analytical change (Jacobian field dprofName times dvec), and the base
% profile. An empty base profile yields an empty analytical change; a
% missing Jacobian field yields NaN.

prof0  = profStruct.(profName);
dprofn = dprofStruct.(profName) - prof0; % numerical difference

hasJacobian = isfield(profStruct,dprofName);
if isempty(prof0)
    dprofa = zeros(size(prof0)); % empty profile: nothing to compare
elseif ~hasJacobian
    dprofa = NaN; % no analytical Jacobian available
else
    dprofa = profStruct.(dprofName)*dvec; % analytical difference
end

return

function d = perturb(y,nn,ind)
% Return a perturbation d with the same size as y. d is zero except at
% indices ind, where it is uniform random in [0, nn*scale). scale is the
% 2-norm of y(ind). If y(ind) has zero norm (all zeros, or ind empty),
% scale = 1 is used. This way an all-zero input (e.g. u = 0) is perturbed
% by an absolute amount nn, instead of aborting the whole Jacobian check.
d = zeros(size(y)); % init
scale = norm(y(ind));
if scale == 0
    scale = 1; % fall back to an absolute perturbation of amplitude nn
end
d(ind) = nn .* rand(size(y(ind))) * scale;
return

function withoutDerivatives = eliminateDerivatives(fields)
% Drop every field name that contains '_d' anywhere (derivative fields
% such as 'dne_dx' or 'dchie_du'), keeping the remaining names in order.
hasDerivTag = ~cellfun(@isempty, strfind(fields,'_d'));
withoutDerivatives = fields(~hasDerivTag);
return

function plot_error(dprofn,dprofa,dprofName,chanName)
% In the current figure, plot numerical (solid) vs analytical (dashed)
% profile changes in the top panel and their difference in the bottom one.

idx = 1:numel(dprofn);

topAx = subplot(2,1,1);
plot(topAx, idx,dprofn,'-', idx,dprofa,'--');
ylabel(topAx,'numerical(-),analytical (--)')
hTitle = title(topAx,sprintf('%s(%s)',dprofName,chanName));
set(hTitle,'Interpreter','none');

bottomAx = subplot(2,1,2);
plot(bottomAx,dprofn-dprofa)
ylabel(bottomAx,'error')

drawnow;
return


function plot_residuals(err,pertgrid,chan,jchan,nchan)
% PLOT_RESIDUALS  Log-log plot of Jacobian-check errors versus perturbation
% size for one state channel. The plot goes in row jchan of an nchan-by-3
% axes grid (columns: x, xdot, u).
%
% INPUTS:
%  err      : error structure as returned by calc_jacobians
%  pertgrid : vector of perturbation sizes used to build err
%  chan     : name of the state channel to plot (e.g. 'te')
%  jchan    : row of this channel in the axes grid
%  nchan    : number of rows (channels) in the axes grid
%
% NOTE(review): the previous version used chan, jchan and nchan without
% defining them, so it could not run; they are now explicit inputs.
% The axes grid is kept in a persistent variable and rebuilt when its
% handles are no longer valid (e.g. the figure was closed).

persistent hax

if nargin < 5
    error('plot_residuals:nargin', ...
        'plot_residuals requires err, pertgrid, chan, jchan and nchan');
end

if numel(pertgrid) == 1
    warning('can not plot for single grid point')
    return
end

% presumably multiaxes returns an nrows-by-ncols array of axes handles - confirm
if isempty(hax) || ~all(isgraphics(hax(:)))
    hf = gcf; clf(hf);
    hax = multiaxes(hf,nchan,3);
end

ProfilesToPlot = fieldnames(err);

%% errors w.r.t. x
ax = hax(jchan,1);
for iprof = 1:numel(ProfilesToPlot)
    loglog(ax,pertgrid,err.(ProfilesToPlot{iprof}).x.(chan),'o'); hold(ax,'on');
end
title(ax,sprintf('||\\delta %s||',chan))
hold(ax,'off');
if jchan == 1
    ylabel(ax,'||\deltay - dy/dx*\deltax||')
end

%% errors w.r.t. xdot
ax = hax(jchan,2);
for iprof = 1:numel(ProfilesToPlot)
    loglog(ax,pertgrid,err.(ProfilesToPlot{iprof}).xdot.(chan)); hold(ax,'on');
end
hold(ax,'off');
if jchan == 1
    ylabel(ax,'||\deltay - dy/dxdot*\deltaxdot||')
end

%% errors w.r.t. u (not per channel; only transport profiles have them)
ax = hax(jchan,3);
for iprof = 1:numel(ProfilesToPlot)
    myprof = ProfilesToPlot{iprof};
    if isfield(err.(myprof),'u') % state profiles do not depend on u
        loglog(ax,pertgrid,err.(myprof).u); hold(ax,'on');
    end
end
hold(ax,'off');
if jchan == 1
    ylabel(ax,'||\deltay - dy/du*\deltau||')
end

if jchan == nchan
    drawnow
end
return

function dt = get_dt(tgrid,it)
% Time step at index it of tgrid: the forward interval tgrid(it+1)-tgrid(it),
% except at the last grid point, where the preceding interval is used.
if it == numel(tgrid)
    it = it - 1; % last point: shift back onto the final interval
end
dt = tgrid(it+1) - tgrid(it);
return