classdef fbt_timedependent_tests < meq_test
  % Tests for FBT including passive structures
  %
  % Exercises time-dependent FBT runs ('circuit',true) on anamak ('ana')
  % cases: equivalence with the static algorithm, circuit-equation
  % consistency, passive/voltage cost scans, voltage constraints/limits,
  % time-derivative constraints and debug/display paths.
  %
  % [+MEQ MatlabEQuilibrium Toolbox+] Swiss Plasma Center EPFL Lausanne 2022. All rights reserved.

  properties
    verbosity = 0; % >0: plots, >1: movies/extra plots, >2: more movies
  end

  properties(TestParameter)
    solvecircuit = {false,true};
    shot = {2,101};
    algoNL = {'all-nl','all-nl-Fx','Newton-GS'};
    useSQP = {true,false};
    timeder = {'centered','backward'};
  end

  methods(Test,TestTags = {'fbt'})

    function test_fbt_passive_equivalence(testCase,shot)
      % Check that solving for unconstrained cases with time-dependent
      % or static algorithm gives same answer
      [L,LX,LY1] = fbt('ana',shot,[]);
      [~,~,LY2] = fbt('ana',shot,[],'circuit',true);
      testCase.verifyTrue(meq_test.check_convergence(L,LX.t,LY1),'FBT without circuit equations did not converge');
      testCase.verifyTrue(meq_test.check_convergence(L,LX.t,LY2),'FBT with circuit equations did not converge');
      testCase.verifyEqual(LY1.Fx  , LY2.Fx  , 'AbsTol', L.P.tol*10,'default vs GSpulse run gave different Fx');
      testCase.verifyEqual(LY1.Ia  , LY2.Ia  , 'AbsTol', L.Ia0*L.P.tol*10,'default vs GSpulse run gave different Ia');
      testCase.verifyEqual(LY1.werr, LY2.werr, 'AbsTol', 1e2*L.P.tol*max(abs(LY1.werr)),'default vs GSpulse run gave different werr');
    end

    function test_circuit_equation(testCase)
      % Check that the FBT solution with circuit equations satisfies the
      % implicit-Euler discretized circuit equation, both by direct
      % re-integration and by evaluating the equation residual.
      debugval = max(0,testCase.verbosity-1);
      [L,LX] = fbt('ana',101,[],...
        'selu','e','nu',30,'izgrid',true,...
        'circuit',true,... % activate passive structure
        'itert',0,'tol',1e-10,...
        'debug',debugval);

      nt = numel(LX.t);

      % solve FBT
      LY = fbtt(L,LX);

      % Check convergence
      testCase.assertTrue(meq_test.check_convergence(L,LX.t,LY),'FBT did not converge');

      if testCase.verbosity>1
        clf;
        meqmovie(L,LY);
      end

      %% Check satisfaction of circuit equation vs direct numerical integration
      Mee = [L.G.Maa,L.G.Mau;L.G.Mau',L.G.Muu];
      Ree = diag([L.G.Ra;L.G.Ru]);
      Ba  = eye(L.ne,L.G.na); % voltages Va act only on the first na (active) rows of Ie

      % Mee*Iedot + Re*Ie + Mey*Iydot = Va

      Ie = zeros(L.ne,nt);
      Ie(:,1) = [LY.Ia(:,1);LY.Iu(:,1)];
      dIy = diff(LY.Iy,[],3);
      Va = LY.Va;

      for it=2:nt
        dt = LX.t(it)-LX.t(it-1);
        % Implicit Euler method discretization
        % Mee*(Iek-Iep) + dt*Re*Iek + dt*Mey*Iydot = dt*Va
        % (Mee+dt*Re)*Iek = Mee*Iep -    Mey*dIy + dt*Va
        % Iek =  (Mee+dt*Re) \ (Mee*Iep - Mey*dIy+ dt*Va)
        MM = (Mee + dt*Ree);
        Ie(:,it) = MM\(Mee*Ie(:,it-1) + dt*Ba*Va(:,it) - L.Mey*reshape(dIy(:,:,it-1),L.ny,1));
      end
      Ia = Ie(1:L.G.na,:); Iu = Ie(L.G.na + (1:L.G.nu),:);
      tol = 1e-4;
      testCase.verifyEqual([LY.Ia;LY.Iu],Ie,'AbsTol',tol,'Circuit equation direct integration gave different Ie'); % essentially the same calculation so error only due to FBT tolerances

      %% Check equation residual
      Ie = [LY.Ia;LY.Iu];
      res = zeros(L.ne,nt-1); % column k holds the residual of time step k+1
      for it = 2:nt
        dt = LX.t(it)-LX.t(it-1);
        res(:,it-1) = Mee*(Ie(:,it)-Ie(:,it-1))/dt + Ree*Ie(:,it) + L.Mey*reshape(dIy(:,:,it-1),L.ny,1)/dt - Ba*Va(:,it);
      end
      % Report the time step with the largest residual (not the last loop index)
      [~,iworst] = max(max(abs(res),[],1));
      testCase.verifyLessThan(abs(res),tol,sprintf('Circuit equation residual is too high at time step %d',iworst+1))

      if testCase.verbosity>0
        clf;
        subplot(121)
        plot(LY.t,Ia,'b-x'); hold on;
        plot(LY.t,LY.Ia,'k--o'); hold on;
        title('Ia')

        subplot(122)
        plot(LY.t,Iu,'b-x'); hold on;
        plot(LY.t,LY.Iu,'k--o'); hold on;
        title('Iu')
      end
    end

    function test_vessel_penalization(testCase)
      % Scan passive current dissipation penalty (dispas) and check that
      % the expected cost function terms increase/decrease

      debugval = max(0,testCase.verbosity-1);
      dt = 1e-3; t = 0:dt:0.01; % time step and vector
      [L,LX] = fbt('ana',2,t,...
        'selu','e','nu',30,'izgrid',true,'debug',debugval,...
        'itert',10,'circuit',true); % activate circuit equation
      LX.t = t;

      gpuds = 1.3e4*sqrt([0.1,1,10]); % scan passive current penalty
      for ii=1:numel(gpuds)
        LX.gpud = gpuds(ii)*LX.Ip/L.Ip0; % Assign passive current penalty
        LX = fbtx(L,LX);

        LY(ii) = fbtt(L,LX);

        % check that vessel current penalty increases
        if ii>1
          zerrIu = errIu; zerrIa = errIa; % store previous
        end
        % get unweighted error norm per equation type
        IDpassives = LY(ii).ID==L.dimID.Passives;
        IDcoils    = LY(ii).ID==L.dimID.Coils   ;
        IDflux     = LY(ii).ID==L.dimID.Flux    ;

        errIu = sqrt(sum((LY(ii).uerr .*IDpassives).^2,1,'omitnan')./sum(IDpassives,1));
        errIa = sqrt(sum((LY(ii).uerr .*IDcoils   ).^2,1,'omitnan')./sum(IDcoils   ,1));
        errF  = sqrt(sum((LY(ii).uerr .*IDflux    ).^2,1,'omitnan')./sum(IDflux    ,1));

        if ii>1
          derrIu = errIu-zerrIu; derrIu = sum(derrIu,'omitnan');
          testCase.verifyGreaterThanOrEqual(derrIu,0,'Passive error should be increasing with gpud')
          testCase.verifyLessThanOrEqual(sum(errIa),sum(zerrIa),'Coil    error should be decreasing with gpud')
        end

        if testCase.verbosity>0
          figure(1); if ii==1, clf; end
          subplot(411)
          plot(LY(ii).t,sum(LY(ii).Ia)); hold on;
          title('sum of active curents');
          subplot(412)
          plot(LY(ii).t,sum(LY(ii).Iv),'displayname',sprintf('gpud=%g',gpuds(ii))); hold on;
          title('Total vessel current')
          legend('show');
          subplot(413)
          plot(LY(ii).t,errIu); title('Passive current cost term'); hold on;
          subplot(414)
          plot(LY(ii).t,errF); title('Flux error cost term'); hold on;
        end

        if testCase.verbosity>1
          figure(2);
          for kt=1:numel(LY(ii).t)
            clf;
            fbtplot(L,meqxk(LX,kt),meqxk(LY(ii),kt));
            drawnow;
          end
        end
      end

    end

    function test_voltage_cost(testCase)
      % Scan the circuit voltage penalty (gpad) and check that voltage and
      % flux error totals move in opposite directions between consecutive
      % converged runs.

      % Run time-dependent anamak case
      [L,LX] = fbt('ana',101,[],'selu','e','nu',50,'circuit',true);

      gpad = 150*logspace(2,0,3); % scan circuit voltage penalty

      % Error totals of the last converged run; empty until one converges,
      % so a non-converged first run cannot cause an undefined-variable error.
      zserrF = []; zserrVa = [];

      for ii=1:numel(gpad)
        if ii==1, sty='r'; elseif ii==2, sty='k--'; else, sty = 'b:'; end

        LX.gpad = gpad(ii)*LX.Ip/L.Ip0; % Assign voltage penalty
        LX = fbtx(L,LX);

        LY = fbtt(L,LX);
        conv = meq_test.check_convergence(L,LX.t,LY);
        testCase.verifyTrue(conv,sprintf('some time slices did not converge for gpad=%f',gpad(ii)));
        if ~conv, continue; end
        errF  = sum((LY.uerr.*(LY.ID==L.dimID.Flux)).^2,'omitnan');
        errVa = sum((LY.Va(:,2:end) - LX.gpaa(:,2:end)).^2);

        if ~isempty(zserrVa)
          testCase.verifyLessThan   (sum(errVa),zserrVa,'expected increasing Voltage error with gpad');
          testCase.verifyGreaterThan(sum(errF) ,zserrF ,'expected decreasing Flux error with gpad');
        end
        zserrF = sum(errF); zserrVa = sum(errVa); % store for next converged run

        if testCase.verbosity>1
          figure(1);
          subplot(411);
          plot(LY.t,LY.Va,sty); hold on;
          subplot(412);
          plot(LY.t,LY.Ia,sty); hold on;
          subplot(413);
          plot(LY.t(2:end),errVa,sty); hold on;
          subplot(414);
          plot(LY.t,errF,sty); hold on;

          if testCase.verbosity>2
            figure(2); clf;
            meqmovie(L,LY);
          end
        end
      end
    end

    function test_voltage_constraint(testCase)
      % Check that we can use fbt parameters to set an exact constraint on a given circuit voltage
      [L,LX] = fbt('ana',101,[],'selu','e','nu',30,'circuit',true);
      LX.gpad = 1.5e5*LX.Ip/L.Ip0;
      iconstrain = 1; % coil index to constrain
      tol = 1e-5; % constraint satisfaction tolerance
      for constraint_val = [-50,-40]
        LX.gpae(iconstrain,:) = 0; % exact constraint on this circuit;
        LX.gpaa(iconstrain,:) = constraint_val; % value of constraint
        LX.gpae(:,1) = Inf; % Disable constraint on first slice
        LX = fbtx(L,LX); % Necessary after manual LX update
        LY = fbtt(L,LX);
        testCase.verifyNotEmpty(LY,sprintf('FBT did not converge for Va=%f',constraint_val));
        if isempty(LY), continue; end
        Va1 = L.G.Ra.*LY.Ia; % first voltage
        expected_val = [Va1(iconstrain,1), constraint_val * ones(1,numel(LY.t)-1)];
        testCase.verifyEqual(LY.Va(iconstrain,:),expected_val,'AbsTol',tol,'Va equality constraint not satisfied');
      end
    end

    function test_voltage_limits(testCase)
      % Check voltage calculation and constraint imposition, with and without passives
      PP = {'selu','e','nu',30,'circuit',true};
      [L,LX,LY] = fbt('ana',101,[],PP{:},'voltlim',false);

      % add voltage limits that would limit the present solution
      L.G.Vamax(:) =  0; L.G.Vamin(:) = -65;
      testCase.verifyFalse(all(all( (LY.Va <=L.G.Vamax) & (LY.Va >=L.G.Vamin) )),...
        'expected some unconstrained eq. to violate Va limits')

      % Solving again with voltlim = false should give the same result
      [LYn] = fbtt(L,LX);
      testCase.verifyEqual(LYn.Va,LY.Va,'solving with voltlim=false should have given unconstrained result');

      % constrain voltage limits
      L.P.voltlim = true;

      % re-solve with these voltage limits and check limits are satisfied
      LY2 = fbtt(L,LX);
      etol = 1e-5; % tolerance on constraint satisfaction
      testCase.verifyLessThanOrEqual   (LY2.Va(:,2:end),L.G.Vamax+etol,'expected no constrained eq. to violate Va maximum limits')
      testCase.verifyGreaterThanOrEqual(LY2.Va(:,2:end),L.G.Vamin-etol,'expected no constrained eq. to violate Va minimum limits')

      % check that sum of the flux errors is higher now
      err1 = sum((LY .uerr.*(LY .ID==L.dimID.Flux)).^2,'omitnan');
      err2 = sum((LY2.uerr.*(LY2.ID==L.dimID.Flux)).^2,'omitnan');
      testCase.verifyGreaterThanOrEqual(sum(err2),sum(err1),'expected larger flux error for constrained case');

      if testCase.verbosity > 1
        %%
        figure(1); clf;
        subplot(311); plot(LY.t,LY.Va,'-',LY2.t,LY2.Va,'k--'); title('Va');
        subplot(312); plot(LY.t,LY.Ia,'-',LY2.t,LY2.Ia,'k--'); title('Ia');

        subplot(313); plot(LY.t,err1,'-',LY2.t,err2,'k--'); title('flux error');
        figure(2); clf; it=5;
        meqgplot(L.G,gca,'vl');
        contour(L.rx,L.zx, LY.Fx(:,:,it)- LY.FB(:,it),[0,0],'r'); hold on;
        contour(L.rx,L.zx,LY2.Fx(:,:,it)-LY2.FB(:,it),[0 0],'k');
        axis equal;
      end
    end

    function test_slices_without_controlpoints(testCase)
      % Test with an intermediate time slice that has no equilibrium constraints.
      [L,LX] = fbt('ana',101,[],'selu','e','nu',20,'circuit',true);

      nt = 5; % total number of time steps
      tt = linspace(LX.t(1),LX.t(2),nt);

      interpmethod = 'linear'; cpts = false; % linear interpolation, no control points
      LX = fbtxinterp(LX,tt,interpmethod,cpts);
      LX = fbtx(L,LX); % consolidate and check LX

      % run FBT
      LY = fbtt(L,LX);

      testCase.assertTrue(meq_test.check_convergence(L,LX.t,LY),'some slices did not converge')
      testCase.assertTrue(all(isnan(LY.Fb(2:end-1))),'expected Fb=nan at slices without control points')
    end

    function test_timeder_constraints(testCase,timeder)
      % Test shot 201 with exact time derivative constraint for Fx at (r0,0)
      [L,LX] = fbt('ana',201,[],'icsint',true,'ilim',3,...
        'circuit',true,'selu','e','nu',20,'timeder',timeder,...
        'tol',5e-4,'niter',20);

      % Run FBT
      LY = fbtt(L,LX,'useSQP',false); % This particular case does not converge with SQP method

      testCase.assertTrue(meq_test.check_convergence(L,LX.t,LY),'some slices did not converge')

      % Get time derivative stencils
      d1dt = fbttimeder(LX.t,timeder);
      % Get flux values at constraint location
      rn = LX.g1r(:,2);
      zn = LX.g1z(:,2);
      if L.P.MHyinterp
        inM = L.inMfbt;
      else
        inM = qintc(L.P.inp,L.drx,L.dzx);
      end
      Fn = qintmex(L.G.rx,L.G.zx,LY.Fx,rn,zn,inM);
      dFndt = d1dt*reshape(Fn,numel(LX.t),1);

      % Main difference will come from quality of GS solution
      testCase.verifyEqual(dFndt(2:end-1).',LX.g1fa(1,2:end-1),'AbsTol',100*L.P.tol);
    end

    function test_algoNL(testCase,algoNL,useSQP)
      % Test time-dependent FBT with different algoNL for SQP and Picard methods
      [L,LX,LY] = fbt('ana',101,[],...
        'selu','e','nu',20,'circuit',true,...
        'algoNL',algoNL,'useSQP',useSQP);

      testCase.assertTrue(meq_test.check_convergence(L,LX.t,LY),...
        sprintf('some slices did not converge for algoNL=%s and useSQP=%d',algoNL,useSQP))
    end

    function test_verbosity(testCase) %#ok<MANU>
      % Exercise different text/plot debugging levels (passes if no error)
      for level = 1:3
        [~,~,~] = fbt('ana',101,[],...
          'selu','e','nu',20,'circuit',true,...
          'debug',level,'debugplot',level);
      end
    end

    function test_fbtxdisp(testCase) %#ok<MANU>
      % test fbtxdisp with time derivative cases (passes if no error)
      [L,LX] = fbt('ana',101,[],...
          'selu','e','nu',5,'circuit',true);
      % activate all equations
      LX.gpud(:) = 1; LX.gpad(:) = 1;
      L.G.Vamax(:) =  1e3;
      L.G.Vamin(:) = -1e3;
      L.P.limu(:) =  10e4;
      L.P.liml(:) = -10e4;
      L.P.voltlim = true;

      % Some nonsense constraints to stress-test the display
      r = LX.gpr(1,:); z = LX.gpz(2,:);
      %    fbtgp( X,r,z,b,fa, fb,fe,br,bz,ba,be,cr,cz,ca,ce,vrr,vrz,vzz,ve,timeder)
      LX = fbtgp(LX,r,z,0, 1,0, 1, 0, 0,[], 1, 0, 0,[], 1,  0,  0, [], 1,1);
      LX = fbtgp(LX,r,z,0, 1,0, 1, 0, 0,[], 1, 0, 0,[], 1,  0,  0, [], 1,2);
      LX = fbtx(L,LX);

      LX.g1ie(:) = 1; LX.g1id(:) = 1; LX.g1ia(:) = 0;
      LX.g1ue(:) = 1; LX.g1ud(:) = 1; LX.g1ua(:) = 0;
      LX.g1ae(:) = 1; LX.g1ad(:) = 1; LX.g1aa(:) = 0;

      fbtxdisp(L,meqxk(LX,1));
    end
  end
end
