classdef fge_convergence_tests < meq_test
  % Tests for Time Evolving solver
  % Check convergence of dt of final zA position
  % Check the convergence increasing number of vessel modes
  % Check done for Euler Explicit with preconditioner
  %
  % Test methods:
  %   test_algoNL_methods    - all algoNL variants reproduce the first one
  %   dt_convergence         - final zA converges linearly in dt (per algoNL)
  %   eigenmode_convergence  - eigenmode vessel ('selu'='e') with all modes
  %                            matches the full vessel description
  %   gmres_accuracy         - all algoGMRES variants give the same solution
  %   compare_newton_methods - reports timings/iterations of jfnk vs newton
  %
  % [+MEQ MatlabEQuilibrium Toolbox+] Swiss Plasma Center EPFL Lausanne 2022. All rights reserved.
  
  properties(TestParameter)
    algoNL = {'Picard','all-nl','all-nl-Fx','Newton-GS'}; % nonlinear algorithms
    icsint = {false,true};                                % 'icsint' option values
  end
  
  properties
    tok = 'ana';         % 'tok' argument passed to fge/fgex
    dosave = 0;          % not referenced by the tests in this class
    shot = 2;            % shot number passed to fge
    tstart = 0;          % start time of the simulations
    tinterval= 5e-4;     % time interval length
    dt = [ 5e-5, 1e-5, 5e-6 2e-6]; % time steps scanned in dt_convergence
    prec = 1;            % not referenced by the tests in this class
    solvetol = 1e-8;     % 'tolF' solver tolerance
    nu = 50;             % number of vessel modes used in dt_convergence
    selu = 'e';          % vessel description used in dt_convergence
    verbose = 0;         % 'debug' level; >1 also produces plots in dt_convergence
    mkryl = 20;          % not referenced by the tests in this class
    iterq = 20;          % 'iterq' option used in test_algoNL_methods
  end
  
  % --------------------------------------------------------------------
  methods(Test,TestTags={'fge-convergence'})
    
    function test_algoNL_methods(testCase)
      % Run FGE with each nonlinear algorithm listed in algoNL and check
      % that every one reproduces the solution of the first entry.
      t = testCase.dt(1)*(0:99);
      LYref = [];
      for ii = 1:numel(testCase.algoNL)
        mymethod = testCase.algoNL{ii};
        isPicard = strcmp(mymethod,'Picard');
        % FGE parameters: preconditioner and analytical jacobian are only
        % enabled for the non-Picard methods
        PP = {'algoNL',mymethod,...
              'usepreconditioner',~isPicard,...
              'anajac',~isPicard,...
              'iterq',testCase.iterq};
        if strcmpi(mymethod,'Newton-GS')
          PP = [PP,{'algoF','newton'}]; %#ok<AGROW>
        end
        % Call FGE
        [L,~,LY] = fge(testCase.tok,testCase.shot,t,PP{:});
        if ii == 1
          % Keep only the reference result: outputs of different methods
          % are compared pairwise, never concatenated into a struct array
          LYref = LY;
          continue
        end
        Ip0  = L.Ip0;
        tolF = L.P.tolF;
        msg  = sprintf('algoNL=%s differs from algoNL=%s',mymethod,testCase.algoNL{1});
        testCase.verifyEqual(LY.Fx,LYref.Fx,'AbsTol',100*tolF,msg)
        testCase.verifyEqual(LY.Iy,LYref.Iy,'AbsTol',tolF*Ip0,msg)
        testCase.verifyEqual(LY.bp,LYref.bp,'AbsTol',100*tolF,msg)
        testCase.verifyEqual(LY.li,LYref.li,'AbsTol',100*tolF,msg)
        testCase.verifyEqual(LY.Ip,LYref.Ip,'AbsTol',10*tolF*Ip0,msg)
      end
    end
    
    function dt_convergence(testCase,algoNL)
      % Check that the final zA converges linearly in dt for the given
      % nonlinear algorithm. The run with the smallest dt is the reference.

      isPicard = strcmp(algoNL,'Picard');
      
      % Run simulation at different dt, largest first, so that the last
      % run (smallest dt) is the reference solution
      dt_list = sort(testCase.dt,'descend');
      zAlist = zeros(size(dt_list));

      PP = {'insrc','liu',...
            'tolF',testCase.solvetol,...
            'selu',testCase.selu,...
            'nu',testCase.nu,...
            'debug',testCase.verbose,...
            'algoNL',algoNL,...
            'usepreconditioner',~isPicard,...
            'anajac',~isPicard};
      if strcmpi(algoNL,'Newton-GS')
        PP = [PP,{'algoF','newton'}];
      end
      
      L = fge(testCase.tok, testCase.shot,0,PP{:});
      
      LX0 = fgex(testCase.tok,testCase.tstart,L);
      for ii = 1:numel(dt_list)
        t = testCase.tstart + (0:dt_list(ii):testCase.tinterval);
        LX = fgex(testCase.tok,t,L,LX0);
        LX.Va(2,:) = LX.Va(2,:) + 20; % some step input
        LY = fget(L,LX);
        if testCase.verbose>1
          if ii==1; clf; end
          plot(LY.t,LY.zA,'displayname',sprintf('dt=%3.2e',dt_list(ii))); hold on;
          legend('show');
          drawnow;
        end
        zAlist(ii) = LY.zA(end);
      end
      
      % Compute residual with respect to smallest dt solution
      zAresidual = abs(zAlist(1:end-1)-zAlist(end));
      
      % Check that the total residual is decaying with dt
      testCase.verifyTrue(issorted(zAresidual,'descend'),...
        'zA residual does not decrease with dt');

      % A vanishing residual would give log(0)=-Inf and an undefined fit
      testCase.assertTrue(all(zAresidual>0),...
        'zA residual is zero: cannot estimate the convergence order');
     
      % Check convergence is linear in loglog plot (slope ~ 1)
      logtt = log(dt_list(1:end-1));
      logzA = log(zAresidual);

      coeff_fit = polyfit(logtt,logzA,1);
      if testCase.verbose>1
        clf;
        plot(logtt,logzA,'x');  hold on;
        plot(logtt, coeff_fit(1)*logtt+ coeff_fit(2),'r--');
        legend('data','fit','location','best');
        ylabel('log error'); xlabel('log tt');
        title(sprintf('slope=%3.2f',coeff_fit(1)),'Interpreter','None')
        drawnow
      end
      testCase.verifyTrue( abs(coeff_fit(1)-1)<0.4,...
        sprintf('convergence is not linear in dt (fitted slope = %3.2f)',coeff_fit(1)));
    end
    
    function eigenmode_convergence(testCase)
      % Compare the full vessel description with the eigenmode description
      % using all modes, then run (without verification) with fewer modes.
      t = testCase.tstart + (0:2e-5:5e-4);
      % Test convergence for eigenmode vessel description
      RelTol = sqrt(testCase.solvetol)*10;
      %% basic case with full vessel
      % Same solver tolerance as the eigenmode runs so that RelTol, which
      % is derived from solvetol, applies consistently to both
      [L] = fge(testCase.tok,testCase.shot,t,...
        'insrc','liu',...
        'tolF',testCase.solvetol,...
        'debug',testCase.verbose);
      
      LX = fgex(testCase.tok,t,L);
      LY = fget(L,LX);
      
      %% eigenmode case with all eigenmodes
      nu_ = L.G.nv; %#ok<*PROP>
      [~,~,LYe] = fge(testCase.tok,testCase.shot,t,...
        'insrc','liu',...
        'tolF',testCase.solvetol,...
        'selu','e','nu',nu_,...
        'debug',testCase.verbose);
      
      % verify equal
      testCase.verifyEqual(LYe.Fx,LY.Fx,...
        'RelTol',RelTol,...
        'eigenmode vessel with all modes differs from full vessel')
      
      if testCase.verbose
        clf;
        plot(LYe.t,LYe.zA,'displayname',sprintf('nu = %d',nu_));
        hold on; drawnow;
      end
      
      for ieigs = {50,30,10}
        nu_ = ieigs{1}; 
        fprintf('Run case with selu = ''e'', nu = %d\n',nu_)
        % try fewer eigenvalues
        [~,~,LYe] = fge(testCase.tok,testCase.shot,t,...
          'insrc','liu',...
          'tolF',testCase.solvetol,...
          'selu','e','nu',nu_,...
          'debug',testCase.verbose); 
        if testCase.verbose
          plot(LYe.t,LYe.zA,'displayname',sprintf('nu = %d',nu_)); drawnow;
        end
      end
      
      if testCase.verbose
        title('zA for various #eigenmodes');
        xlabel('time'); ylabel('zA [m]');
        legend show; legend(gca,'location','best')
      end
    end

    function gmres_accuracy(testCase)
      % Check that every GMRES variant (L.P.algoGMRES) yields the same FGE
      % solution as the first one, and print their relative wall times.
      dt_list = testCase.dt;

      % setup FGE: use no preconditioner and large mkryl so that GMRES
      % performance matters most
      [L,LX0] = fge(testCase.tok, testCase.shot,0,...
        'insrc','liu',...
        'mkryl',300,...
        'usepreconditioner',0,...
        'tolF',testCase.solvetol,...
        'debug',0);

      dt_ = dt_list(end);  % last dt_list should be enough
      fprintf('[ Test %g ]\n\n', dt_);
      t = testCase.tstart + (0:dt_:testCase.tinterval);

      LX = fgex(testCase.tok,t,L,LX0);
      % Some step input (otherwise simulation is static)
      LX.Va(2,:) = LX.Va(2,:) + 50;

      algolist  = {'sim','giv','aya'};
      cmpfields = {'rB','zB','FB','zA','FA','Fx'}; % outputs checked against reference
      tol       = sqrt(eps);
      telaps    = zeros(numel(algolist),1); % init
      for ialgo = 1:numel(algolist) % loop over algorithms to try
        L.P.algoGMRES = algolist{ialgo};
        start_time = tic;
        LY = fget(L,LX);
        telaps(ialgo) = toc(start_time);

        if ialgo == 1
          LY0 = LY; % reference solution from the first algorithm
          continue
        end
        for ifield = 1:numel(cmpfields)
          fld = cmpfields{ifield};
          testCase.verifyEqual(LY.(fld), LY0.(fld), 'AbsTol', tol, ...
            sprintf('%s differs between algoGMRES=%s and algoGMRES=%s',...
                    fld, algolist{ialgo}, algolist{1}));
        end
      end

      % Wall times and speed-up relative to the first algorithm
      fprintf('-- Comparative times:');
      for ialgo = 1:numel(algolist)
        fprintf(' %s: %f[s] (%.2fx)', upper(algolist{ialgo}), ...
          telaps(ialgo), telaps(1)/telaps(ialgo));
      end
      fprintf('\n');
    end

    function compare_newton_methods(testCase,icsint)
      % Compare execution times with different methods (jfnk, newton with
      % full analytical jacobian, newton with analytical jacobian handle).
      % Only reports timings and iteration counts; nothing is verified.

      if icsint
        PP = {'icsint',true,'ilim',3};
      else
        PP = {'icsint',false,'ilim',1};
      end

      % setup FGE
      [L,LX0] = fge(testCase.tok, testCase.shot,0,...
        'tolF',testCase.solvetol,PP{:});

      dt_ = testCase.dt(end); % Use smallest dt
      t = testCase.tstart + (0:dt_:testCase.tinterval);

      LX = fgex(testCase.tok,t,L,LX0);
      % Some step input (otherwise simulation is static)
      LX.Va(2,:) = LX.Va(2,:) + 50;

      % {row name, fget options} for each method to compare
      methodlist = {
        'jfnk',                 {'algoF','jfnk'};
        'newton/anajac-full',   {'algoF','newton','jacobian_handle',false};
        'newton/anajac-handle', {'algoF','newton','jacobian_handle',true}};

      T = zeros(size(methodlist,1),4);
      for im = 1:size(methodlist,1)
        tStart = tic;
        LYm = fget(L,LX,methodlist{im,2}{:});
        twall = toc(tStart);
        T(im,:) = [twall,sum(LYm.niter),sum(LYm.mkryl),sum(LYm.nfeval)];
      end

      fprintf('Comparing FGE[tok=%s,shot=%d,icsint=%d]\n',testCase.tok,testCase.shot,icsint);
      disp(array2table(T,'RowNames',methodlist(:,1).',...
        'VariableNames',{'Wall_time','sum_niter','sum_mkryl','sum_nfeval'}));

    end
  end
end
