function [State,Logging,Stop,NoSolution] = fgetk_environment(State,Actions,Parameters,callmethod)
% 'Environment' implementation of fgetk for learning purposes
%
% Inputs:
%   State:      Environment state passed to next time step. The fields it,
%               xnl, psstate and Tstate are read and updated during a step.
%               Ignored (overwritten) when callmethod is 'init'.
%   Actions:    Input actions from controller/agent. Actions.Va is indexed
%               with the mapping built from Parameters.Khybrid.OutputName.
%               Not used when callmethod is 'init'.
%   Parameters: Parameters for environment. Parameters.L and
%               Parameters.Khybrid are read only at init; Parameters.LX is
%               read at init and at every step.
%   callmethod: 'init' for initialization. Any other value, or omitting the
%               argument, performs a step.
%
% Outputs:
%   State:      Updated environment state.
%   Logging:    Struct with field LY holding the previous time slice
%               (one sim-cycle delay); at init this is the initial LY.
%   Stop:       True if meqlim raised an alarm or q95 is below L.P.q95min.
%   NoSolution: True if the residual of the newest time slice exceeds 1e3.
%
% L, the latest LY, the output mapping, the preconditioner and dt are kept
% in persistent variables, so 'init' must be called before the first step.
% See test_fgetk_environment.m for an example call
%
% [+MEQ MatlabEQuilibrium Toolbox+] Swiss Plasma Center EPFL Lausanne 2022. All rights reserved.

persistent L LYt iout Prec dt

% Anything other than an explicit 'init' request (including an omitted or
% non-char callmethod such as 'step') is treated as a step. This guarantees
% that 'init' is always defined before it is tested below.
init = nargin >= 4 && ischar(callmethod) && strcmp(callmethod,'init');

if init
  % Initialize
  [State,~,L,LY,dt] = ...
    init_fgetk_environment(Parameters);
  xnl_new = State.xnl;
  LYt = LY;
  alarm = false;
  % get the input mapping from KHybrid.OutputName to L.G.dima
  iout = make_output_mapping(L.G.dima,Parameters.Khybrid.OutputName);
  % Preconditioner is constant, so only compute once.
  % NOTE(review): Prec is persistent and is NOT recomputed on a later 'init'
  % with different Parameters - confirm this is intended.
  if isempty(Prec)
    if L.P.usepreconditioner
      % Preconditioner is exported: Load if exists.
      if isfield(L, 'Prec')
        Prec = L.Prec;
      else
        % dt is meqdt(Parameters.LX.t,1) as returned by init_fgetk_environment
        Prec = fgepre(L,dt);
      end
    else
      Prec = 1;
    end
  end

else
  % Step
  if isempty(L)
    % Persistent data is only populated by an 'init' call.
    error('fgetk_environment:notInitialized', ...
          'fgetk_environment must be called with callmethod ''init'' before stepping.');
  end

  % Get non-controlled inputs for this iteration counter from LX
  it = State.it;
  max_steps = numel(Parameters.LX.t);

  % extract desired time slice (clamped to the last available slice)
  LXt1 = meqxk(Parameters.LX,min(it+1,max_steps));

  if it+1>max_steps
    fprintf(['WARNING - fgetk_environment: reached end of LX, ',...
             'using LX(%d) at iteration = %d\n'], max_steps, it+1);
    % Correct time stamp.
    LXt1.t = LYt.t + dt;
  end

  assert(abs(LXt1.t - LYt.t - dt) < 1e-12, ...
         'dt in LX does not match initial dt. dt needs to be constant.')
  LXt1.dt = dt;

  % We want all environments to 'look like the real plant', i.e. the interface as
  % exposed by DM-magcontrol.
  % This means Actions are in KHybrid order and length, and not L.G.dima order.
  % So convert to L.G.dima conventions.
  Va_lgdima = Actions.Va(iout);

  %% Power supply
  [State.psstate,Vact] = meqps(L.G,State.psstate,LXt1.dt,it,Va_lgdima);

  LXt1.Va = Vact; % Controller Voltage inputs

  [xnl_new,LYt1] = fgetk_implicit(State.xnl,L,LXt1,LYt,Prec);

  % Check coil current etc limits
  [State.Tstate,LYt1,alarm] = meqlim(L,LYt1,LXt1.dt,it,State.Tstate);

  % Va should always be stored for timing consistency
  LYt.Va = LXt1.Va;

  LY = LYt;
  LYt = LYt1; % next step

  State.it = State.it + 1; % increment iteration counter
end

% Return previous time step. Should always be valid. Equates to one sim-cycle delay.
Logging.LY = LY;

% Consider a high residual as no solution.
NoSolution = LYt.res > 1e3;

% Only update our initial guess if we found an acceptable solution. Otherwise reuse previous one.
if LYt.isconverged
   fprintf("Converged - updating initial guess\n")
   State.xnl = xnl_new;
else
   fprintf("No convergence - not updating initial guess.\n")
end

%% Stopping criteria
Stop = alarm || q95_outside_bounds(L, LYt);

end

function [State,Actions,L,LY0,dt] = ...
  init_fgetk_environment(Parameters)
% Build the initial environment state from Parameters.
%
% Inputs:
%   Parameters: struct; Parameters.L and Parameters.LX are read here.
%
% Outputs:
%   State:   struct with fields it, psstate, Tstate and xnl.
%   Actions: struct with field Va initialised to zeros.
%   L:       Parameters.L, returned unchanged.
%   LY0:     result of fget on the first LX slice, after passing through meqlim.
%   dt:      meqdt(Parameters.LX.t,1).
L = Parameters.L;

% First LX slice (original author's note: LX is already a converged initial
% state thanks to fgel), passed through one step of fget to get the right
% initial state.
LXfirst = meqxk(Parameters.LX,1);
LY0 = fget(L,LXfirst);

% A lower q95 limit can only be checked later if q95 is present in LY.
if isfield(L.P,'q95min')
  if ~isfield(LY0, 'q95')
    error("Lower limit for q95 is set but q95 is not computed. Set L.P.iterq>0.")
  end
end

% Initial state: iteration counter and zero actions.
% NOTE(review): the length 20 is hard-coded - confirm it matches what meqps expects.
State.it = 1;
Actions.Va = zeros(20,1);

dt = meqdt(Parameters.LX.t,1);

%% Power supply initial state
[State.psstate,~] = meqps(L.G,[],dt,0,Actions.Va);

%% Protection limits initial state
[State.Tstate,LY0,~] = meqlim(L,LY0,dt,0);

% State of the nonlinear iterator, assembled as a column from Iy, ag, Ia and Iu
% at the positions given by L.ind.
xnl_rows = [L.ind.iy,L.ind.ig,L.ind.ia,L.ind.iu];
xnl_vals = [LY0.Iy(:);LY0.ag(:);LY0.Ia;LY0.Iu];
State.xnl(xnl_rows,1) = xnl_vals;

end

function alarm = q95_outside_bounds(L, LYt)
% Check whether any q95 value in LYt lies below the lower limit L.P.q95min.
%
% Inputs:
%   L:   struct; the limit is read from L.P.q95min. If that field is absent
%        no check is performed.
%   LYt: struct; LYt.q95 is compared against the limit.
%
% Output:
%   alarm: scalar logical, true if at least one q95 value is below the limit.
%
% The current q95 and the limit are displayed on every call where a limit is
% set, and one message is printed per violating value.
alarm = false;
if ~isfield(L.P,'q95min')
  return
end

disp(['LYt.q95: ' num2str(LYt.q95)])
disp(['lower bound is: ' num2str(L.P.q95min)])

% Collect the violating values directly; NaN entries compare false and are
% therefore never reported.
q95_low = LYt.q95(LYt.q95 < L.P.q95min);
for k = 1:numel(q95_low)
  fprintf('q95=%2.2f is lower than limit %2.2f. Terminating\n',q95_low(k),L.P.q95min);
end
alarm = ~isempty(q95_low);

end
