classdef linesearch_tests < meq_test
  % Tests for linesearch (part of Newton solver)
  %
  % [+MEQ MatlabEQuilibrium Toolbox+]

  %    Copyright 2022-2025 Swiss Plasma Center EPFL
  %
  %   Licensed under the Apache License, Version 2.0 (the "License");
  %   you may not use this file except in compliance with the License.
  %   You may obtain a copy of the License at
  %
  %       http://www.apache.org/licenses/LICENSE-2.0
  %
  %   Unless required by applicable law or agreed to in writing, software
  %   distributed under the License is distributed on an "AS IS" BASIS,
  %   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  %   See the License for the specific language governing permissions and
  %   limitations under the License.

  properties(TestParameter)
    testfunction = {'linear','quadratic','cubic','nans'};
    aguess={0.99,1};
  end

  methods(Test,TestTags={'Unit'})
    function test_linesearch_behaviour(testCase,testfunction,aguess)
      % Run linesearch with the directional derivative dv0_da supplied and
      % check acceptance, iteration count, and location/value of the result.
      testCase.assume_supported_combination(testfunction,aguess);
      [fun,dfun_dx,xmin,vmin,kexpected] = testCase.function_factory(testfunction);

      nfeval = 0; % number of function evaluations

      % function value and derivative at a=0
      [v0,nfeval] = fun(0,nfeval);
      dv0_da = dfun_dx(0);

      dodisp = 1; doplot = 1; % disp and plotting
      norm2x=[]; % empty -> not used when dv0_da is given
      rho=1;
      tau=0;

      [a,v,accept,kiter,nfeval] = linesearch(fun,v0,dv0_da,norm2x,rho,tau,aguess,nfeval,dodisp,doplot);

      testCase.verifyTrue(accept,'returned accept=false')
      testCase.verifyGreaterThanOrEqual(nfeval,kiter,'unexpected number of function evaluations')
      testCase.verifyEqual(kiter,kexpected,'unexpected number of iterations')
      testCase.verifyEqual(a,xmin,'AbsTol',10*eps);
      testCase.verifyEqual(v,vmin,'AbsTol',10*eps);
    end

    function test_linesearch_behaviour_no_deriv(testCase,testfunction,aguess)
      % Run linesearch without the directional derivative (dv0_da=[]),
      % passing the squared norm norm2x instead, and check acceptance and
      % location/value of the result.
      testCase.assume_supported_combination(testfunction,aguess);
      [fun,dfun_dx,xmin,vmin,~] = testCase.function_factory(testfunction);

      nfeval = 0; % number of function evaluations

      % function value and derivative at a=0
      [v0,nfeval] = fun(0,nfeval);
      dx = dfun_dx(0);

      dodisp = 1; doplot = 1; % disp and plotting
      norm2x=dx'*dx;
      rho=0.95; % convergence of the method such that |F(x+a*dx)| < |F(0)|*(rho+tau) - c*a^2*dx'*dx
      tau=0; % we have descending direction (no need of positive tau)
      dv0_da=[]; % derivative not provided; norm2x is passed instead

      [a,v,accept,kiter,nfeval] = linesearch(fun,v0,dv0_da,norm2x,rho,tau,aguess,nfeval,dodisp,doplot);

      testCase.verifyTrue(accept,'returned accept=false')
      testCase.verifyGreaterThanOrEqual(nfeval,kiter,'unexpected number of function evaluations')
      testCase.verifyEqual(a,xmin,'AbsTol',10*eps);
      testCase.verifyEqual(v,vmin,'AbsTol',10*eps);
    end
  end

  methods(Access=private)
    function assume_supported_combination(testCase,testfunction,aguess)
      % Skip (via an assumption failure) every testfunction except 'cubic' when aguess<1.
      % For 'linear', 'quadratic' and 'nans' the first proposals are deterministic
      % (aguess, aguess/2, ...). With aguess=0.99 they do not hit the expected
      % answers (e.g. 1 and 1/2) that the tests compare against.
      testCase.assumeFalse(~strcmp(testfunction, 'cubic') && (aguess<1), ...
        sprintf('Only tests with cubic function if aguess<1: Found testfunction=%s aguess=%f<1',testfunction,aguess))
    end
  end

  methods(Static)
    function [fun,dfun_dx,xmin,vmin,kexpected] = function_factory(testfunction)
      % Build a 1D test function for linesearch.
      % Inputs:
      %   testfunction : one of 'linear','quadratic','cubic','nans'
      % Outputs:
      %   fun       : handle called as [y,nfeval,funout] = fun(x,nfeval);
      %               each call increments nfeval, and funout is an empty cell
      %   dfun_dx   : handle to the derivative of the function w.r.t. x
      %   xmin,vmin : point the tests expect linesearch to return, and the
      %               function value there
      %   kexpected : number of iterations expected by test_linesearch_behaviour
      switch testfunction
        case 'linear'
          xmin = 1; vmin = 0;
          fun = @(x) (xmin-x);
          dfun_dx = @(x) -1;
          kexpected = 1; % expect to find solution at 1st attempt
        case 'quadratic'
          xmin = 0.5; vmin = 0; % solution
          fun = @(x) 0.5*(x-xmin).^2;
          dfun_dx = @(x) (x-xmin);
          kexpected = 2; % expect to find solution after 2 attempts
        case 'cubic'
          c3 = -5.6;
          c2 = 7.5;
          c1 = -1.775;
          c0 = 1;

          fun = @(x) c3*x.^3 + c2*x.^2 + c1.*x + c0;
          dfun_dx = @(x) 3*c3*x.^2 + 2*c2*x + c1;

          % critical points: roots of the quadratic dfun_dx(x)=0
          D = (2*c2).^2-4*(3*c3)*c1; % discriminant of dfun_dx
          assert(D>0,'cubic function has no minimum');
          xsols(1) = (-(2*c2) - sqrt(D)) / (2*(3*c3));
          xsols(2) = (-(2*c2) + sqrt(D)) / (2*(3*c3));
          [vmin,imin] = min(fun(xsols)); % keep the lower critical point
          xmin = xsols(imin);
          kexpected = 3; % expect to find solution after 3 attempts
        case 'nans'
          % like quadratic, but return NaN if x>0.3
          fun = @(x) nanquad(x,0);
          dfun_dx = @(x) nanquad(x,1);
          xmin = 0.25; % expected solution following 2 backtracks
          vmin = fun(xmin); % value at solution
          kexpected = 3;
        otherwise, error('unknown testfunction %s',testfunction)
      end

      % add function evaluation counter and extra outputs
      fun = @(x,nfeval) function_with_counter(fun,x,nfeval);
    end
  end
end

function [y,nfeval,funout] = function_with_counter(fun,x,nfeval)
  % Evaluate fun at x while counting evaluations.
  %   fun    : handle taking a single input x
  %   x      : evaluation point
  %   nfeval : running number of function evaluations so far
  % Returns y = fun(x), the counter incremented by one, and an empty cell
  % funout (this wrapper provides no additional outputs).
  funout = {};
  y = fun(x);
  nfeval = 1 + nfeval;
end

function [y] = nanquad(x,der)
% Test function: quadratic 0.5*(x-0.5)^2 for x<=0.3, NaN for x>0.3.
%   x   : evaluation point(s); evaluated elementwise
%   der : derivative order, 0 (function value) or 1 (first derivative)
% Returns y (same size as x), set to NaN wherever x>0.3.
% Errors for any other value of der instead of returning a bogus result.
xmin = 0.5; % minimum of the quadratic (lies inside the NaN region)
if der==0 % 0th derivative
  y = 0.5*(x-xmin).^2;
elseif der==1 % 1st derivative
  y = x-xmin;
else
  error('nanquad:der','unsupported derivative order der=%g (expected 0 or 1)',der);
end
y(x>0.3) = NaN; % overwrite values beyond x=0.3 with NaN
end