function varargout = jacfd(F, x0, varargin)
% JACFD Generic finite difference jacobian evaluation
%
%   varargout = jacfd(F, x0[, varargin])
%
% Required inputs
%   F: handle to the function to be differentiated.
%   x0: cell array containing all arguments to F at point where to evaluate
%     the jacobian.
%
% Optional inputs in name-value pairs:
%   F0: cell array containing the values of all outputs of F (up to
%     nargout) at the point of evaluation (default: F0=F(x0{:})). It must
%     have at least as many elements as requested outputs; extra elements
%     are ignored.
%   szF0: cell array containing the size of all outputs of F (up to
%     nargout). For functions with scalar or column outputs, specifying a
%     scalar size avoids jacfd returning 3D arrays. (default: sizes of F0).
%   iargin: index of the argument to perturb, between 1 and numel(x0)
%     (default: 1).
%   epsval: amplitude of the perturbation used for the finite difference
%     scheme, either a scalar or one value per perturbation group
%     (default: sqrt(eps)).
%   pertgroups: cell array of array of indices (see below)
%   fdtype: centered/forward/backward, type of finite difference scheme
%     (default: centered).
%   workers: if larger than 1 sets the size of a parallel pool for parallel
%     evaluations of F (default: 1).
%
% Outputs:
%   varargout: F will be called with as many output arguments as were
%   requested in the call to JACFD (at least one, so that a call without
%   output arguments still returns the first jacobian in ans). With the call:
%     [dF1dx,dF2dx] = jacfd(fun,x0)
%   it is expected that fun can produce at least 2 output arguments and
%   dF1dx/dF2dx will contain the approximate jacobians for each. The
%   jacobian of output i has size [szF0{i},ngroups], where ngroups is the
%   number of perturbation groups.
%
% Perturbation groups:
%   if no groups are specified, then each element of the
%   specified input argument is varied independently yielding a jacobian of
%   size [nF,nx]. However if many elements of this jacobian are zero, we
%   can try to reduce the number of function evaluations by grouping
%   together indices of x (the independent variable) which affect different
%   indices of F. For example if F=@sin and x0={linspace(0,2*pi,n)}, the
%   jacobian of F will be diagonal as evaluations at different values of x
%   are independent, we can then do
%     dFdx = jacfd(@sin,{linspace(0,2*pi,n)},'pertgroups',{1:n});
%   which will return only the diagonal elements of the jacobian and use
%   only 2 calls to F instead of 2*n.
%
% [+MEQ MatlabEQuilibrium Toolbox+]

%    Copyright 2022-2025 Swiss Plasma Center EPFL
%
%   Licensed under the Apache License, Version 2.0 (the "License");
%   you may not use this file except in compliance with the License.
%   You may obtain a copy of the License at
%
%       http://www.apache.org/licenses/LICENSE-2.0
%
%   Unless required by applicable law or agreed to in writing, software
%   distributed under the License is distributed on an "AS IS" BASIS,
%   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%   See the License for the specific language governing permissions and
%   limitations under the License.

% Setup input parser
p = inputParser;
p.addRequired('F',@(x) isa(x,'function_handle'));
p.addRequired('x0',@iscell);
p.addParameter('F0',{},@iscell);
p.addParameter('szF0',{},@iscell);
p.addParameter('iargin',1,@(x) isscalar(x) && (x == round(x)) && (x >= 1));
p.addParameter('epsval',sqrt(eps),@(x) isnumeric(x) && all(x(:)>0));
p.addParameter('pertgroups',{},@iscell);
p.addParameter('fdtype','centered',@(x) ismember(x,{'centered','forward','backward'}));
p.addParameter('workers',1,@(x) isscalar(x) && (x == round(x)));

% Parse inputs
p.parse(F, x0, varargin{:});
P = p.Results;

% Extract some parameters
iargin = P.iargin;
epsval = P.epsval;
fdtype = P.fdtype;
if iargin > numel(x0)
  error('jacfd:iarginRange','iargin (%d) exceeds the number of arguments in x0 (%d)',iargin,numel(x0));
end

% Assume F has as many outputs as the number expected from this one. A call
% without output arguments (nargout == 0) still computes the first jacobian
% so that it is returned in ans instead of evaluating F for nothing.
nargoutF = max(nargout,1);

% Get initial F values ...
F0 = P.F0;
if isempty(F0)
  F0 = cell(1,nargoutF);
  [F0{:}] = F(x0{:});
elseif numel(F0) < nargoutF
  error('jacfd:F0Size','F0 should have at least %d elements',nargoutF);
else
  % Only request from F the outputs that are actually differentiated
  F0 = F0(1:nargoutF);
end
% ... and sizes
szF0 = P.szF0;
if isempty(szF0)
  szF0 = cellfun(@size,F0,'UniformOutput',false);
elseif numel(szF0) < nargoutF
  error('jacfd:szF0Size','szF0 should have at least %d elements',nargoutF);
end

% Define perturbation groups (default: one group per element of the
% perturbed argument)
groups = P.pertgroups;
if isempty(groups)
  groups = num2cell(1:numel(x0{iargin}));
end
ngroups = numel(groups);

% Initialize jacobians, one column per perturbation group
varargout = cell(1,nargoutF);
for iout = 1:nargoutF
  varargout{iout} = zeros([prod(szF0{iout}),ngroups]);
end

% Expand scalar epsval
if isscalar(epsval), epsval = repmat(epsval,1,ngroups);
elseif numel(epsval) ~= ngroups, error('jacfd:epsvalSize','epsval should be a scalar or have %d elements',ngroups);
end
% Inverse of the distance between the two evaluation points: 2*epsval for
% centered differences, epsval for one-sided ones
idelta = 1./((1+strcmp(fdtype,'centered'))*epsval);

% Compute jacobian
if P.workers>1 % Parallel version
  % Create parallel pool if needed
  if isempty(gcp('nocreate'))
    parpool([2,P.workers]);
  end
  workers = P.workers;

  % Evaluate all perturbed points first (cell of cells), then difference
  Fp = cell(1,ngroups);
  Fm = cell(1,ngroups);
  parfor (igroup = 1:ngroups, workers)
    [Fp{igroup},Fm{igroup}] = perturbed_outputs(F,x0,iargin,groups{igroup},epsval(igroup),fdtype,F0);
  end
  % Compute derivatives
  for iout = 1:nargoutF
    for igroup = 1:ngroups
      varargout{iout}(:,igroup) = (Fp{igroup}{iout}(:) - Fm{igroup}{iout}(:))*idelta(igroup);
    end
  end
else % Serial version
  for igroup = 1:ngroups
    [Fp,Fm] = perturbed_outputs(F,x0,iargin,groups{igroup},epsval(igroup),fdtype,F0);
    % Compute derivatives
    for iout = 1:nargoutF
      varargout{iout}(:,igroup) = (Fp{iout}(:) - Fm{iout}(:))*idelta(igroup);
    end
  end
end

% Reshape jacobians
for iout = 1:nargoutF
  varargout{iout} = reshape(varargout{iout},[szF0{iout},ngroups]);
end

end

function [Fp, Fm] = perturbed_outputs(F, x0, iargin, ix, epsi, fdtype, F0)
% PERTURBED_OUTPUTS Evaluate F with x0{iargin}(ix) shifted by +/-epsi.
%   Fp holds the outputs at the upper point and Fm at the lower point. The
%   side not used by the scheme ('forward' has no lower point, 'backward'
%   has no upper point) is left equal to F0, so that Fp - Fm is always the
%   difference to be scaled. The number of outputs requested from F is
%   numel(F0).
Fp = F0;
Fm = F0;
if ~strcmp(fdtype,'backward')
  xp = x0; xp{iargin}(ix) = xp{iargin}(ix) + epsi;
  [Fp{:}] = F(xp{:});
end
if ~strcmp(fdtype,'forward')
  xm = x0; xm{iargin}(ix) = xm{iargin}(ix) - epsi;
  [Fm{:}] = F(xm{:});
end
end
