function varargout = vpdn_MS(varargin)
% VPDN_MS Electron particle diffusivity and pinch velocity.
% Uses the analytical transport model from Kim, Merle, Sauter, Goodman,
% PPCF 2016 (58) 055002.
%
%   Dne     = dnochie*chi_e
%   Vpe/Dne = -f((rho_inv-rho)/wrho_inv) * ...
%             [ mu_ne/(n_e*rho_edge)*f((rho_ped-rho)/wrho_ped) + ...
%               lambda_ne/rho_edge*f((rho-rho_ped)/wrho_ped) ]
%   f(x)    = 1/(1+exp(x))
% With cp.jac_rhoinv the inversion-radius factor is instead computed from q
% as 1 - f((q-1)/wrhoinv), and its derivative w.r.t. x enters dVpe_dx.
%
% Calling conventions:
%   vpdn_MS()              -> {'vpdn_e', []}            (module type, empty config)
%   vpdn_MS(a,b)           -> {struct('name',mfilename), default parameter struct}
%   vpdn_MS(stap,geop,trap,it,model,cp)
%                          -> Dne, Vpe and, if nargout>2, dDne_dx, dDne_dxdot,
%                             dDne_du, dVpe_dx, dVpe_dxdot, dVpe_du
% All profile outputs are on the grid model.rgrid.rhogauss.
%
% mune is (re)computed only on the first Newton iteration of a time step and
% is kept in persistent variables. When cp.prescribed is false it follows
% a PI feedback (+ optional feedforward) law that tracks the reference
% line-averaged density cp.nel.

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

if nargin==0
  varargout{1} = 'vpdn_e'; % Module type
  varargout{2} = []; % Configuration
  return
elseif nargin == 2
    %% DEFAULT PARAMETERS
    % Optional parameters that are not in the default structure but can be
    % defined by the user:
    % -- save_mutrace      -- save the effective mune trace, incl. feedback
    % -- fname_mutrace     -- filename where the trace is stored
    % -- plot_hh_fgr       -- plot nel divided by its reference after every
    %                         time step
    % -- startup_gp_factor -- multiplier applied to the feedback gains while
    %                         it < startup_gp_it (both fields must be set)
    module_params=struct(...
        'mune'    , 1.e+19   ,... % typical mune (gradient in the pedestal region)
        'prescribed', 0 ,... % 0 - use controlled mune, 1 - use prescribed mune
        'lambdane', [1.5 1.0]   ,... % typical lambdane values for L-, H-modes (inverse scalelength in the core region)
        'rhoped'  , [0.8 0.9]   ,... % typical normalised rho pedestal, defines core/edge region (for L-, H-modes)
        'wrhoped' , 0.02  ,... % width of the transition between core and edge regions
        'rhoinv'  , 0.15   ,... % typical normalised sawtooth inversion radius, defines center/core region
        'wrhoinv' , 0.02  ,... % width of the transition between center and core regions
        'nel'     , 1.e+19*[1. 4.]   ,... % reference line averaged density (on rho_tor) [m^-3]: L/H-mode
        'gp'      , 1e+19 ,... % default value for proportional gain
        'gi'      , 1e+19 ,... % default value for integration gain
        'fb_on'   , true ,...    % feedback: true - on, false - off
        'ff_on'   , false ,...    % feedforward: true - on, false - off
        'muneff_pred', -1, ... % predictive feed-forward; -1 - calculated
        'dnochie' , 0.2, ... % Dne = dnochie*chi_e: Dne linearly proportional to chi_e
        'sawtooth', false, ... % enable internal chiMS sawtooth model
        's1crit'  , 1,... % critical shear for sawtooth mode
        'implicit', true  ,... % implicit method
        'check'   , false, ... % option to check gradients etc
        'jac_rhoinv', false ... % compute the inversion-radius factor from q (and include it in dVpe_dx)
        );

      mm.name = mfilename;
      varargout{1} = mm;
      varargout{2} = module_params;
      return %empty call, probably to get default structures
elseif nargin==6 % change this depending on number of inputs to module
    stap = varargin{1};
    geop = varargin{2};
    trap = varargin{3};
    it = varargin{4};
    model = varargin{5};
    cp = varargin{6};
else
    error('must call with 0, 2 or 6 inputs');
end

% Controller / trace state carried between calls (and between time steps)
persistent mune mune_km1 nel_km1 nel_km2 nelref_km1 nelref_km2 u_ff_km1 rhoinv munetrace;

% NOTE(review): if stap.upl is a scalar, diff() is empty and all([]) is true,
% so every call would be treated as the stationary-state solver - confirm
% stap.upl is always a vector here.
in_statstate_solver = all(abs(diff(stap.upl))<1e-12);

% The persistent counter in isFirstTime cannot detect the first Newton
% iteration inside the stationary-state solver (it stays at the same it), so
% every call there is treated as a first iteration (note: only mu prescribed
% mode is supported for now in the stationary-state solver).
firstNewtonIteration = isFirstTime(it) || in_statstate_solver;

%% Module inputs
% Electron density and its derivative w.r.t. the state
ne = stap.ne;
dne_dx = stap.dne_dx;
% Safety factor
q = stap.q;
% Gauss radial grid, normalised
rhogauss = model.rgrid.rhogauss;
% chi_e and its derivatives
chie = trap.chie;
dchie_dx = trap.dchie_dx;
dchie_dxdot = trap.dchie_dxdot;
dchie_du = trap.dchie_du;

%% Density transport model parameters
% Current lambdane: L/H interpolation, or a time trace indexed by it
if numel(cp.lambdane)==2
    lambdane = (1.-stap.lamH)*cp.lambdane(1) + stap.lamH*cp.lambdane(2);
else
    lambdane = cp.lambdane(it);
end
% Current rho_ped (L/H interpolation)
rhoped = (1.-stap.lamH)*cp.rhoped(1) + stap.lamH*cp.rhoped(2);

% Line averaged density: over rho_gauss (!)
% nelavrho = [int ne(rho_gauss) drho_gauss]/[int 1. drho_gauss]
neli = int_CompSimpsonRule(ne,model);
nel = neli(end);

% Determine mune for the current time step; it is kept fixed during the
% Newton iterations of that step.
if firstNewtonIteration
    % rho_inv: position of the q=1 surface, or the default if q>=1 everywhere
    if ~any(q<1)
        rhoinv = cp.rhoinv;
    else
        rhoinv = find_rhoinv(q,rhogauss,stap.shear,cp);
    end

    if cp.prescribed
        % use prescribed mune
        mune = cp.mune(it);
    else
        % mu(ii) is calculated from mu(ii-1) and nel(ii-1)
        if it==1
            if cp.muneff_pred==-1
                mune = cp.mune(1);
            else
                mune = cp.muneff_pred(1);
            end
            mune_km1 = mune;
            u_ff_km1 = mune;
            nel_km1 = nel;
            nel_km2 = 0.;
            % Same reference rule as later steps (L/H interpolation when
            % cp.nel has two entries), so the first feedback error is
            % consistent with the current confinement mode.
            nelref_km1 = nel_reference(cp, stap.lamH, it);
            nelref_km2 = 0.;
        else
            % Feedforward control based on mune dependence on Ip and input power.
            % + Feedback PI controller
            % mune(it) = mune(it-1) + gi*(nel_ref(it-1) - nel_eff(it-1)) +
            % gp*(nel_ref(it-1) - nel_eff(it-1) - (nel_ref(it-2) - nel_eff(it-2)))
            % + mune_ref(it) - mune_ref(it-1)

            % Feedback controller (errors are relative to the measured nel)
            if cp.fb_on
                err_km1 = (nelref_km1 - nel_km1)./(nel_km1+1.e-06);
                err_km2 = (nelref_km2 - nel_km2)./(nel_km2+1.e-06);
                du_fb = cp.gi*err_km1 + cp.gp*(err_km1 - err_km2);
                % Optional reduced/increased gains during start-up; requires
                % both startup_gp_factor and startup_gp_it to be defined.
                if isfield(cp,'startup_gp_factor') && isfield(cp,'startup_gp_it') ...
                        && (it<cp.startup_gp_it)
                    du_fb = cp.startup_gp_factor*du_fb;
                end
            else
                du_fb = 0.;
            end

            % Feedforward controller
            if cp.ff_on
                if cp.muneff_pred==-1
                    error('no done yet');
%                     % calculated feedforward
%                     if strcmp(model.equi.tokamak,'TCV')
%                         coeff = 5.*1.5e+03;
%                     elseif strcmp(model.equi.tokamak,'AUG')
%                         coeff = 3.3*1.e+03;
%                     end
%                     Te0 = coeff*((1.e-6*Ip(end)).^0.93).*((1.e-6*Ptotin).^0.3).*((1.e-19*ne(1)./1.5).^(-0.6));
%                     u_ff_k = (Te0*exp(-lambdane*(rhoped -rho_inv))- model.te.BCval)./(1.0-rhoped);
                else
                    % prescribed feedforward
                    u_ff_k = cp.muneff_pred(it);
                end
                du_ff = u_ff_k - u_ff_km1;
                % save for the following iteration
                u_ff_km1 = u_ff_k;
            else
                du_ff = 0.;
            end

            % Calculate mune
            mune = mune_km1 + du_ff + du_fb;
            % Avoid too low mune values
            if mune<=10. %50
                mune = 10.;%50.;
            end

            % Save for the following time step
            mune_km1 = mune;
            nel_km2 = nel_km1;
            nel_km1 = nel;
            nelref_km2 = nelref_km1;
            nelref_km1 = nel_reference(cp, stap.lamH, it);
        end
    end

    % Optional trace of mune, written to disk at it == numel(cp.nel).
    % NOTE(review): when cp.nel is a 2-element L/H reference this saves at
    % it==2 only - confirm the intended trace length.
    if isfield(cp,'save_mutrace') && cp.save_mutrace
      if it == 1
        munetrace = zeros(1,numel(cp.nel));
      end
      munetrace(it) = mune;
      if it ~= 1 && it == numel(cp.nel)
        save(cp.fname_mutrace,'munetrace')
      end
    end
end
if isfield(cp,'plot_hh_fgr') && cp.plot_hh_fgr && firstNewtonIteration
  figure(101);
  subplot(212)
  plot(it, nel./nelref_km1,'r*');hold on
  xlabel('iteration')
  ylabel('nel/nelref')
end

% Transition functions f(x) = 1/(1+exp(x))
f1 = 1./(1.+exp((rhoped - rhogauss)./cp.wrhoped)); % ~1 outside rho_ped (edge)
f2 = 1./(1.+exp((rhogauss - rhoped)./cp.wrhoped)); % ~1 inside rho_ped (core)
if cp.jac_rhoinv
  f3 = 1-1./(1+exp((stap.q-1)./cp.wrhoinv));         % ~1 where q > 1
else
  f3 = 1./(1.+exp((rhoinv - rhogauss)./cp.wrhoinv)); % ~1 outside rho_inv
end

% rho_edge
rho_b = sqrt(geop.Phib./(pi*geop.B0));
% Expression in the brackets
expr_br = mune*f1./(ne.*rho_b) + lambdane*f2/rho_b;

% Formula Vpe/Dne
vpodn = -expr_br.*f3;
% Dne
Dne = cp.dnochie*chie;
% Vpe
Vpe = vpodn.*Dne;

varargout{1} = Dne;
varargout{2} = Vpe;

if (nargout>2) % then compute derivatives
    %%% Derivatives for Dne
    dDne_dx    = cp.dnochie*dchie_dx;
    dDne_dxdot = cp.dnochie*dchie_dxdot;
    dDne_du    = cp.dnochie*dchie_du;

    %%% Derivative of the bracket expression: only the mune/ne term depends on x
    dexprbr_dx = -bsxfun(@rdivide,mune*bsxfun(@times,f1,dne_dx)/rho_b,ne.^2);

    %%% Derivatives for Vpe = -Dne*expr_br*f3 (product rule)
    if cp.jac_rhoinv
      % df3/dq = exp(z)/(1+exp(z))^2 / wrhoinv with z = (q-1)/wrhoinv;
      % chain rule with dq/dx (bsxfun broadcasts the profile over columns)
      ez = exp((stap.q-1)./cp.wrhoinv);
      df3_dq = ez./(1+ez).^2./cp.wrhoinv;
      df3_dx = bsxfun(@times,df3_dq,stap.dq_dx);
      dVpe_dx    = -bsxfun(@times,dDne_dx,expr_br.*f3) -bsxfun(@times,dexprbr_dx,Dne.*f3) -bsxfun(@times,df3_dx,Dne.*expr_br);
    else
      dVpe_dx    = -bsxfun(@times,dDne_dx,expr_br.*f3) -bsxfun(@times,dexprbr_dx,Dne.*f3);
    end
    dVpe_dxdot = -bsxfun(@times,dDne_dxdot,expr_br.*f3);
    dVpe_du    = -bsxfun(@times,dDne_du,expr_br.*f3);

    varargout{3} = dDne_dx;
    varargout{4} = dDne_dxdot;
    varargout{5} = dDne_du;
    varargout{6} = dVpe_dx;
    varargout{7} = dVpe_dxdot;
    varargout{8} = dVpe_du;

%     % optional: test gradients
%     if cp.check == 1
%         check_gradients(x,xdot,g,v,vdot,u,it,model,params,chie,dchie_dx,dchie_dxdot,dchie_du)
%     end
end

end

function nelref = nel_reference(cp, lamH, it)
% Reference line-averaged density for time index it: L/H interpolation with
% lamH when cp.nel has two entries, otherwise the time trace value cp.nel(it).
if numel(cp.nel)==2
    nelref = (1.-lamH)*cp.nel(1) + lamH*cp.nel(2);
else
    nelref = cp.nel(it);
end
end

% function check_gradients(x,xdot,g,v,vdot,u,it,model,params,chie,dchie_dx,dchie_dxdot,dchie_du)
% % Perturbation level
% alpha = [1e-04 1e-03 1e-02];
% alpha = [alpha -alpha];
% na = numel(alpha);
% % Initialize
% ratio_dchiedte = zeros(model.rgrid.nrhogauss,na);
% ratio_dchiedpsi = zeros(model.rgrid.nrhogauss,na);
% ratio_dchiedtedot = zeros(model.rgrid.nrhogauss,na);
% ratio_dchiedpsidot = zeros(model.rgrid.nrhogauss,na);
% ratio_dchiedIp = zeros(model.rgrid.nrhogauss,na);
% ratio_dchiedu = zeros(model.rgrid.nrhogauss,na);
% 
% for ii=1:na
%     % dpsi_drho perturbation (in jtor dpsi_drho)
%     dpsip = alpha(ii)*ones(size(model.rgrid.rhogauss));    
%     % Spline coefficients for perturbed psi
%     psihat_pert = model.psi.Lampgauss\dpsip;
%     % Perturbed state vector
%     dx = [psihat_pert; zeros(size(psihat_pert))];
%     % Get perturbed part of chie
%     dchie_num = chi_MS(x + dx,xdot,g,v,vdot,u,it,model,params) - chie;
%     % Get analytical dchie from dchie_dx
%     dchie_anl = dchie_dx*dx;
%     % Comparison of anlytical and numerical derivatives
%     dchiedp_num = dchie_num./dpsip;
%     dchiedp_anl = dchie_anl./dpsip;
%     % Ratio
%     ratio_dchiedpsi(:,ii) = dchiedp_anl./(1e-06*dchiedp_num(end) + dchiedp_num);
%     
%     % Electron temperature perturbation
%     dte = alpha(ii)*ones(size(model.rgrid.rhogauss));    
%     % Spline coefficients for perturbed te
%     tehat_pert = model.te.Lamgauss\dte;
%     % Perturbed state vector
%     dx = [zeros(size(tehat_pert));tehat_pert];
%     % Get perturbed part of chie
%     dchie_num = chi_MS(x + dx,xdot,g,v,vdot,u,it,model,params) - chie;
%     % Get analytical dchie from dchie_dx
%     dchie_anl = dchie_dx*dx;
%     % Comparison of anlytical and numerical derivatives
%     dchiedp_num = dchie_num./dte;
%     dchiedp_anl = dchie_anl./dte;  
%     % Ratio
%     ratio_dchiedte(:,ii) = dchiedp_anl./(1e-06*dchiedp_num(end) + dchiedp_num);
% 
%     % Psi dot perturbation
%     dpsidot = alpha(ii)*ones(size(model.rgrid.rhogauss));  
%     % Spline coefficients for perturbed te
%     psihatdot_pert = model.psi.Lamgauss\dpsidot;
%     % Perturbed state vector
%     dxdot = [psihatdot_pert; zeros(size(psihatdot_pert))];
%     % Get perturbed part of chie
%     dchie_num = chi_MS(x,xdot+dxdot,g,v,vdot,u,it,model,params) - chie;
%     % Get analytical dchie from dchie_dxdot
%     dchie_anl = dchie_dxdot*dxdot;
%     % Comparison of anlytical and numerical derivatives
%     dchiedp_num = dchie_num./dpsidot;
%     dchiedp_anl = dchie_anl./dpsidot;
%     % Ratio
%     ratio_dchiedpsidot(:,ii) = dchiedp_anl./(1e-06*dchiedp_num(end) + dchiedp_num);
%     
%     % Electron temperature time derivative perturbation
%     dtedot = alpha(ii)*ones(size(model.rgrid.rhogauss));    
%     % Spline coefficients for perturbed te
%     tehatdot_pert = model.te.Lamgauss\dtedot;
%     % Perturbed state vector
%     dxdot = [zeros(size(tehatdot_pert));tehatdot_pert];
%     % Get perturbed part of chie
%     dchie_num = chi_MS(x,xdot+dxdot,g,v,vdot,u,it,model,params) - chie;
%     % Get analytical dchie from dchie_dxdot
%     dchie_anl = dchie_dxdot*dxdot;
%     % Comparison of anlytical and numerical derivatives
%     dchiedp_num = dchie_num./dtedot;
%     dchiedp_anl = dchie_anl./dtedot;
%     % Ratio
%     ratio_dchiedtedot(:,ii) = dchiedp_anl./(1e-06 + dchiedp_num);
%     
%     % Actuators perturbation
%     du = alpha(ii)*ones(size(u));  
%     % Get perturbed part of chie
%     dchie_num = chi_MS(x,xdot,g,v,vdot,u+du,it,model,params) - chie;
%     % Get analytical dchie from dchie_dxdot
%     dchie_anl = dchie_du*du;
%     % Comparison of anlytical and numerical derivatives
%     % Plasma current
%     dchiedp_num = dchie_num./du(1);  
%     dchiedp_anl = dchie_anl./du(1);
%     % Ratio
%     ratio_dchiedIp(:,ii) = dchiedp_anl./(1e-06*dchiedp_num(end) + dchiedp_num);
%     % Anothe actuator
%     if du>1
%     dchiedp_num = dchie_num./du(2);  
%     dchiedp_anl = dchie_anl./du(2);
%     ratio_dchiedu(:,ii) = dchiedp_anl./(1e-06*dchiedp_num(end) + dchiedp_num);
%     end
% end
% 
% figure;
% subplot(2,3,1)
%  plot(model.rgrid.rhogauss,ratio_dchiedpsi,'o');
%  legend('alpha=1e-04','alpha=1e-03','alpha=1e-02','alpha=-1e-04','alpha=-1e-03','alpha=-1e-02');
%  str = sprintf('dchie/dpsi anl/num: it=%2.0f',it);title(str);
% subplot(2,3,2)
%  plot(model.rgrid.rhogauss,ratio_dchiedte,'o');
%  str = sprintf('dchie/dte anl/num: it=%2.0f',it);title(str);
% subplot(2,3,3)
%  plot(model.rgrid.rhogauss,ratio_dchiedpsidot,'o');
%  str = sprintf('dchie/dpsidot anl/num: it=%2.0f',it);title(str);
% subplot(2,3,4)
%  plot(model.rgrid.rhogauss,ratio_dchiedtedot,'o');
%  str = sprintf('dchie/dtedot anl/num: it=%2.0f',it);title(str);
% subplot(2,3,5)
%  plot(model.rgrid.rhogauss,ratio_dchiedIp,'o');
%  str = sprintf('dchie/dIp anl/num: it=%2.0f',it);title(str);
% subplot(2,3,6)
%  plot(model.rgrid.rhogauss,ratio_dchiedu,'o');
%  str = sprintf('dchie/du anl/num: it=%2.0f',it);title(str);
% return


function firstTime = isFirstTime(it)
% ISFIRSTTIME True when 'it' differs from the value passed on the previous
% call. The last seen value is held in a persistent variable that starts
% at 0, so the very first call returns false only when it == 0.
persistent lastSeen
if isempty(lastSeen)
  lastSeen = 0; % initial value before any call
end

previous = lastSeen;
lastSeen = it; % remember for the next call
firstTime = ~(previous == it);
end