function [a,v,accept,kiter,nfeval,varargout] = linesearch(fun,v0,dv0_da,norm2x,rho,tau,astart,nfeval,dodisp,doplot)
% LINESEARCH - Interpolation-based line search for newton solver
%
% [a,v,accept,kiter,nfeval,varargout] = linesearch(fun,v0,dv0_da,norm2x,rho,tau,astart,nfeval,dodisp,doplot)
%
% Finds a step length a and value v that satisfy an Armijo-like inequality.
% If dv0_da is empty, a derivative-free method is employed:
%     the method uses the constraint:
%           v(a) < 0.5*((rho+tau_k)*sqrt(2*v(0)) - c*a^2*dx'*dx)^2
%     for a sequence of tau_k such that sum_k tau_k<inf and
%     v(a)=0.5*|F(x+a*dx)|^2.
%     (This is equivalent to |F(x+a*dx)| < (rho+tau_k)*|F(x)| - c*a^2*dx'*dx)
% Otherwise
%     the method uses the modified Armijo condition:
%           v(a) < (rho+tau_k)*v(0) + c*dv_da*a
%     for a descending direction dv_da<0 and for a sequence of tau_k such that sum_k tau_k<inf
%     As k->inf, this condition converges to the usual Armijo condition (if rho=1)
%
% Constant c is presently hard-coded as c=1e-4
%
% New function evaluations v(a) are selected using the interpolation method described in
% Nocedal & Wright (2nd Edition, 2006) Section 3.5. This approximates
% the function by a quadratic/cubic function and chooses the step size candidate as
% minimizer of the surrogate function. The fit differs between 2 methods
% 1)  For the derivative-free method, unconstrained quadratic/cubic interpolation
% 2)  Otherwise, the cubic interpolation is constrained to match dv/da at a=0
%
% Ref:
%   Dong-Hui Li & Masao Fukushima (2000) A derivative-free line search and
%   global convergence of Broyden-like method for nonlinear equations,
%   Optimization Methods and Software, 13:3, 181-201, DOI: 10.1080/10556780008805782
%
% Inputs:
%    fun is function handle called as [v,nfeval,out1,out2,...] = fun(a,nfeval) (see below)
%    v0: value of fun at a=0
%    dv0_da: value of dv/da at a=0 (if empty, the derivative-free method is used)
%    norm2x: squared norm of the full Newton step (dx'*dx) (only used with the derivative-free method)
%    rho: convergence rate
%    tau: controlled increase (sum tau_k <inf)
%    astart: value of a to start from if a=1 attempt is rejected
%    nfeval: function evaluation counter
%    dodisp: if true, display debugging information.
%    doplot: if true, plot line search space on a grid with its approximant
%            (the grid tabulation itself calls fun and updates nfeval)
%
% Outputs:
%    a: accepted step size
%    v: function value at accepted step size
%    accept: true if acceptable step size found
%    kiter: number of iterations taken
%    nfeval: function evaluation counter update
%    varargout: other outputs of 'fun' (see below)
%
%    If accept is false, v and varargout belong to the last evaluated step
%    size, whereas a holds the next candidate that has not been evaluated.
%
% details of 'fun'
%   function handle called as [v,nfeval,out1,out2,...] = fun(a,nfeval)
%   with inputs: a (step size), nfeval (function evaluation counter)
%   and outputs: v (value), nfeval (function evaluation counter update),
%                followed by as many extra outputs as are requested from
%                linesearch beyond its fifth output; these are returned
%                unchanged in varargout.
%
% [+MEQ MatlabEQuilibrium Toolbox+] Swiss Plasma Center EPFL Lausanne 2022. All rights reserved.

% parameters
c       = 1e-4;    % minimal descent fraction w.r.t. expected descent
niter   = 10;      % maximum iterations
mina    = 0.01;    % minimum absolute size of a
tolrela = 0.05;    % minimum relative size of a w.r.t. previous trial. 0 means this check is disabled
tolda   = 0.05;    % minimum relative change in a w.r.t. previous trial.  0 means this check is disabled

% init
accept = false; % init
a0 = 0;
aval = zeros(niter+1,1); % trial step sizes, aval(1) is the origin
fval = zeros(niter+1,1); % function values matching aval
aval(1) = a0;
fval(1) = v0;
varargout = cell(max(nargout-5,0),1); % one slot per extra output requested by the caller

agrid = [];
vgrid = [];
derivativefree = isempty(dv0_da);

if derivativefree % Not necessarily a descending direction
  % Modified Armijo rule based on
  % |F(x_k+1)|<|F(x_k)|-c*|a*dx|^2
  % In addition, we allow small increase of the function values
  % |F(x_k+1)|<(rho+tau_k)|F(x_k)|-c*|a*dx|^2
  % for a sequence of tau_k such that sum |tau_k|<inf
  % In practice, we use the condition:
  % 0.5|F(x_k+1)|^2<0.5((rho+tau_k)|F(x_k)|-c*|a*dx|^2).^2
  normf = sqrt(2*v0);
  upperbd = @(a) 0.5*((rho+tau) * normf-c * norm2x * a.^2).^2;
else % Armijo condition, Eq(3.56), but allowing small function increase
  % Small increase might be useful at the beginning of the optimization
  % process, where some "walls", that need to be jumped over, might appear.

  % Check descending direction
  assert(dv0_da<0,'dv0_da must be a descent direction if `derivativefree=false`')
  upperbd = @(a) (rho+tau)*v0 + c * a * dv0_da;
end

if doplot
  % tabulate function values to later plot it
  agrid = unique([linspace(0,astart,21),1]);
  vgrid = zeros(size(agrid));
  for ia = 1:numel(agrid)
    [vgrid(ia),nfeval] = fun(agrid(ia),nfeval);
  end
end

if dodisp
  linesearchdisp(); % display header
end

%% first try full step if astart is not 1
if astart < 1
  a = 1; kiter = 1;
  [v,nfeval,varargout{:}] = fun(a,nfeval);
  v_accept = upperbd(a);
  if ~isnan(v) && v<v_accept
    accept = true;
    if dodisp, linesearchdisp(kiter,niter,a,v,v_accept,'accept'); end
    return % accept this
  else
    if dodisp, linesearchdisp(kiter,niter,a,v,v_accept,'reject'); end
  end
  kstart = 2;
  if isnan(v)
    % A NaN sample must not enter the interpolation data: it would turn
    % every fitted coefficient, and hence every predicted step, into NaN.
    kvalid = 1;
  else
    fval(2) = v;
    aval(2) = a;
    kvalid = 2;
  end
else
  % astart is already 1 so this will be the first attempt of the for loop
  kstart = 1;
  kvalid = 1;
end

%% do line search starting from astart

a = astart; % starting guess for step size a if full step fails
% kvalid: number of valid stored samples including the origin (nan are ignored)
for kiter=kstart:niter

  % Evaluate function value including additional outputs
  [v,nfeval,varargout{:}] = fun(a,nfeval);
  fval(kvalid+1) = v;
  aval(kvalid+1) = a;

  % acceptance value
  v_accept = upperbd(a);

  if isnan(v)
    % check for NaNs, reset to new a and continue if this happens
    % (kvalid is not incremented, so this slot is overwritten by the next trial)
    if dodisp, linesearchdisp(kiter,niter,a,v,v_accept,'nans in v, reset'); end
    a = a/2;
    continue
  elseif v < v_accept
    % acceptable descent direction, return
    accept = true;
    if doplot
      clf;
      plot(agrid,vgrid,'.-'); hold on; % plot function tabulation
      plot(aval(1:kvalid+1),fval(1:kvalid+1),'o');
      plot(agrid,upperbd(agrid),'--'); % plot function upperbound
      plot(a,v,'s'); % solution
      plot([a0,a],[v0,v],'x') % plot constraining points
      if kvalid>1, plot(aval(kvalid),fval(kvalid),'x'); end % previous point
      title(sprintf('converged ntrial=%d,a=%3.5f',kvalid,a));
      drawnow;
    end

    if dodisp
      linesearchdisp(kiter,niter,a,v,v_accept,'accept')
    end
    return
  else
    if dodisp
      linesearchdisp(kiter,niter,a,v,v_accept,'reject')
    end
    % attempt to find new candidate
    a_next = predict_next(aval(1:kvalid+1),fval(1:kvalid+1),dv0_da,doplot,agrid,vgrid,upperbd,astart,mina);
    a = a_next; % new candidate
    ap = aval(kvalid+1);

    % checks on new candidate - as proposed in [Nocedal & Wright bottom of Page 58]

    if ~isfinite(a) % interpolation broke down (NaN/Inf); comparisons below would all be false
      if dodisp, fprintf('%14s ** new stepsize=%f is not finite, reset\n','',a); end
      a = ap/2;
    elseif (abs(a-ap))/ap < tolda % a too close to previous value
      if dodisp, fprintf('%14s ** new stepsize=%f too close to previous value, reset\n','',a); end
      a = ap/2;
    elseif a < mina % a too small
      if dodisp, fprintf('%14s ** new stepsize=%f absolute size too small, reset\n','',a); end
      a = ap/2;
    elseif a/ap < tolrela % a too small
      if dodisp, fprintf('%14s ** new stepsize=%f relative size too small, reset\n','',a); end
      a = ap/2;
    end
  end
  kvalid=kvalid+1; % one more valid sample stored
end

% reached kiter>niter without finding solution
if dodisp
  fprintf('line stepping did not find a reasonable step size after %d iterations\n',kiter);
  fprintf('Newton direction may not be a descent direction\n')
end

end

function a_next=predict_next(aval,fval,dv0_da,doplot,agrid,vgrid,upperbd,astart,amin)
% PREDICT_NEXT - Propose the next step size for the line search.
%
% Uses the past trials aval (column vector, aval(1) is the origin) and their
% values fval to fit a quadratic/cubic polynomial (depending on the number
% of datapoints and on whether dv0_da is given) and returns the minimizer
% of that surrogate on [0,astart]. With only two datapoints the last step
% size is simply halved.
%
% Inputs:
%   aval, fval : column vectors of trial step sizes and function values
%   dv0_da     : derivative at a=0; empty selects the derivative-free fits
%   doplot     : if true, plot the surrogate (cubic fits only)
%   agrid, vgrid, upperbd : tabulated function and acceptance bound, only
%                used for plotting
%   astart     : upper bound of the search interval for the minimizer
%   amin       : currently unused, kept for interface compatibility
%
% Output:
%   a_next     : proposed step size
neval = size(aval,1);
derivativefree = isempty(dv0_da);
if ~derivativefree
  % Using Armijo condition, Eq(3.56). The function derivative at a=0 must
  % be provided to evaluate the condition. In addition, the surrogate function
  % derivative at a=0 is constrained to equal to dv0_da
  if neval == 2 % this is the second attempt
    a_next = aval(end)/2; % always choose a/2 as second attempt
  else
    % further attempt (kiter>1)
    % predict next stepsize (decreasing) via cubic interpolation with fixed
    % derivative at a=0
    % vcubic(a) = c3*a^3 + c2*a^2 +  dv0_da*a + v0
    % with coefficients constrained by
    % dfun_da(a0) = dv0_da, fun(a0)=v0, fun(ap)=vp: fun(a)=v:
    % [Nocedal & Wright bottom of Page 58]
    atmp=[aval(1); aval(end-1:end)];
    ftmp=[fval(1); fval(end-1:end)];
    [approx_f,coef]=fit_surrogate(atmp,ftmp,'cubic_deriv',dv0_da);
    a_next=min_func(approx_f,coef,'cubic',0,astart);

    if doplot
      % plot function approximation, all points, and current iterate
      vcube = approx_f(agrid');
      vmin  = approx_f(a_next); % min value

      clf;
      plot(agrid,vgrid,'.-'); hold on; % plot function tabulation
      plot(agrid,upperbd(agrid),'--'); % plot function upperbound
      plot(agrid,vcube,'o'); % plot approximation
      plot(atmp,ftmp,'x') % plot constraining points
      plot(a_next,vmin,'s'); % local minimum
      % report the proposed candidate (previously the constant amin was shown)
      title(sprintf('candidate ntrial=%d,a=%3.5f',neval,a_next));
      drawnow;
    end
  end
else
  % Derivative-free condition. Does not necessarily expect a descending
  % direction, but only controlled function increases are accepted (monitored by upperbound)
  if neval==2
    % need at least 3 datapoints for the regression
    a_next = aval(end)/2;
  elseif neval==3
    % predict next stepsize (decreasing) via quadratic interpolation
    [approx_f,coef]=fit_surrogate(aval,fval,'quadratic');

    a_next=min_func(approx_f,coef,'quadratic',0,astart);
  else
    % predict next stepsize (decreasing) via cubic interpolation
    % y~a+x*b+x^2*c+x^3*d
    % use last 3 data points (+origin)
    atmp=[aval(1); aval(end-2:end)];
    ftmp=[fval(1); fval(end-2:end)];
    [approx_f,coef]=fit_surrogate(atmp,ftmp,'cubic');

    a_next=min_func(approx_f,coef,'cubic',0,astart);
    if doplot
      % plot function approximation, all points, and current iterate
      vcube = approx_f(agrid');
      vmin  = approx_f(a_next); % min value

      clf;
      plot(agrid,vgrid,'.-'); hold on; % plot function tabulation
      plot(agrid,upperbd(agrid),'--'); % plot function upperbound
      plot(agrid,vcube,'o'); % plot approximation
      plot(a_next,vmin,'s'); % local minimum
      plot(aval,fval,'x') % plot constraining points
      drawnow;
    end
  end
end
end

function [f,coef]=fit_surrogate(aval,fval,type,dvda)
% FIT_SURROGATE - Fit a polynomial surrogate to the samples [aval,fval].
%
% type selects the model:
%   'quadratic'   : least-squares quadratic through the samples
%   'cubic'       : least-squares cubic through the samples
%   'cubic_deriv' : cubic through the first and the last two samples whose
%                   slope at a=0 equals dvda [Nocedal & Wright, Page 58]
%
% Returns the evaluation handle f (expects a column vector argument) and the
% coefficient vector coef, ordered from the constant term upwards.
want_quadratic = strcmp(type,'quadratic');
want_hermite   = strcmp(type,'cubic_deriv');
want_cubic     = strcmp(type,'cubic');

if ~(want_quadratic || want_hermite || want_cubic)
  error('Unknown type of function')
end

if want_hermite
  % vcubic(a) = c3*a^3 + c2*a^2 + dvda*a + v0, determined by the value and
  % slope at the origin plus the values at the two most recent trials.
  f_origin = fval(1);
  a_prev = aval(end-1);
  f_prev = fval(end-1);
  a_last = aval(end);
  f_last = fval(end);

  scale = 1/(a_prev^2*a_last^2*(a_last-a_prev));
  M     = [a_prev^2, -a_last^2; -a_prev^3, a_last^3];
  rhs   = [ f_last - f_origin - dvda*a_last;
            f_prev - f_origin - dvda*a_prev];
  high  = scale*M*rhs; % high(1) is the cubic, high(2) the quadratic coefficient

  coef = [f_origin; dvda; high(2); high(1)];
  f = @(x) [ones(numel(x),1), x, x.^2, x.^3] * coef;
  return
end

% Unconstrained polynomial regression (quadratic or cubic)
npts = numel(aval);
X = [ones(npts,1), aval, aval.^2];
if want_cubic
  X = [X, aval.^3];
end
coef = surrogate_lsq(X,fval,aval);

if want_quadratic
  f = @(x) [ones(numel(x),1), x, x.^2] * coef;
else
  f = @(x) [ones(numel(x),1), x, x.^2, x.^3] * coef;
end
end

function coef = surrogate_lsq(X,y,aval)
% Solve X*coef = y; switch to regularized normal equations when two
% abscissae lie closer than 1e-3 to each other.
if min(diff(sort(aval)))<1e-3
  coef = (X'*X+eye(size(X,2))*1e-8)\(X'*y);
else
  coef = X\y;
end
end


function minimizer=min_func(f,coef,type,lower,upper)
% MIN_FUNC - Minimizer of a polynomial surrogate on the interval [lower,upper].
%
% Inputs:
%   f     : handle evaluating the surrogate (used to compare the interval ends)
%   coef  : polynomial coefficients ordered from the constant term upwards
%   type  : 'quadratic' (3 coefficients) or 'cubic' (4 coefficients)
%   lower, upper : bounds of the search interval
%
% Output:
%   minimizer : stationary minimum of the surrogate if it lies inside the
%               interval, otherwise the interval end selected as below
%
% A cubic whose leading coefficient is exactly zero is treated as a
% quadratic, since the cubic root formula would otherwise evaluate 0/0.
% An error is raised for any other type.
if strcmp(type,'quadratic')
  % Minimum or maximum (depend on sign of coef(3))
  minimizer = -coef(2)/(2*coef(3));
  if coef(3) > 0
    % convex: clamp the vertex to the interval
    if minimizer < lower
      minimizer = lower;
    elseif minimizer > upper
      minimizer = upper;
    end
  else
    % stationary point is a maximum (or the model is linear) -> eval f on boundary
    minimizer = min_func_boundary(f,lower,upper);
  end
elseif strcmp(type,'cubic')
  if coef(4) == 0
    % degenerate cubic: same polynomial as the quadratic given by coef(1:3)
    minimizer = min_func(f,coef(1:3),'quadratic',lower,upper);
    return
  end
  % Roots of the derivative c1 + 2*c2*x + 3*c3*x^2; the '+' root is the
  % local minimum because the second derivative there equals 2*delta >= 0.
  delta=sqrt(coef(3).^2-3*coef(4)*coef(2));
  % Root of the gradient should exist, otherwise minimizer should be on the
  % border of the interval
  if isreal(delta)
    minimizer = (-coef(3)+delta)/(3*coef(4));
    % minimum outside range, or not a usable number
    if ~isfinite(minimizer) || (minimizer < lower) || (minimizer > upper)
      minimizer = min_func_boundary(f,lower,upper);
    end
  else
    % no real stationary point: check boundary conditions
    minimizer = min_func_boundary(f,lower,upper);
  end
else
  error('Unknown type of function')
end
end

function x = min_func_boundary(f,lower,upper)
% Return the interval end with the smaller surrogate value (upper on ties).
if f(upper) > f(lower)
  x = lower;
else
  x = upper;
end
end

function linesearchdisp(kiter,niter,a,v,vref,msg)
% LINESEARCHDISP - Debug display helper of the line search.
%
% Called without arguments it prints the table header; otherwise it prints
% one table row with the iteration counter kiter out of niter, the step
% size a, the function value v, the reference value vref and the text msg.
if nargin > 0
  % one row per trial
  fprintf('%12s |  %2d/%2d | %10.5f | %10.5e | %10.5e | %s\n','',kiter,niter,a,v,vref,msg)
  return
end

% no arguments: header line
fprintf(' Linesearch: | it/max |   stepsize |    value    |  reference  | result \n')
end

