classdef meqplottraces < handle
  % MEQPLOTTRACES Interactive plot of LY time traces (one axes per trace).
  % See the constructor help for usage and supported parameters.

  properties (SetAccess = protected)
    % read-accessible to the user
    traceinfo  cell = {} % cell array with traces to plot
    scaleinfo  cell = {}; % scale per trace (default=1)
    ylabelinfo cell = {}; % y label info
    yliminfo   cell = {}; % [min,max] Y axes limits per trace (default: auto)
    titleinfo  cell = {}; % optional axes title info
    hax % axes handles
    parent % parent figure handle
  end
  
  properties (Access = protected)
    % derived from data
    timetraces cell % cell with arrays of time traces to plot (per input LY)
    time_grid  double % time grid from LY.t
    ylabels cell % cell array with labels per time trace
    scales  cell % cell array with scale info per time trace
    ylims   cell % y limit info
    titles  cell % title info
  end

  properties
    % gettable/settable by user
    time_index % current time index selected
  end
  
  events
    select_time_slice
  end

  methods

    function obj = meqplottraces(L,LY,varargin)
      % MEQPLOTTRACES Plot LY time traces
      %
      % obj = meqplottraces(L,LY); % plot and return object, plots just Ip by default
      % meqplottraces(L,{LY1,LY2}); % multiple LY structures overlaid
      % meqplottraces(L,LY,'hax',hax); % re-use axes
      % meqplottraces(L,LY,'parent',gcf); % pass parent figure
      %
      % for use in combination with GUIs: 
      %  notifies event 'select_time_slice' when user selects a time point.
      %  has settable property obj.time_index to get/set the currently
      %  selected time index (clamped to the valid range of the time grid).
      %
      % properties, can be set via constructor parameter-value arguments:
      %   hax       : axes handles - default: autogenerated if not passed
      %   parent    : parent figure - default: gcf()
      %   traceinfo : cell array indicating which traces to plot
      %               (default: {'Ip'} when not passed or empty)
      %               cells can contain: 
      %                    * strings (matching LY fields)
      %                    * function handles of the signature d=fh(LY) that
      %                      returns an array of time traces 
      %                      e.g. @(LY) LY.Ia returns Ia.
      %               NB: a single traceinfo entry can result in multiple
      %               time traces being plotted, e.g. 'Ia' will plot L.G.na traces
      %   ylabelinfo: cell array with y label information, if string,
      %               attempts to find L.G.(str) or uses it directly. If cell, uses it
      %               directly.
      %   scaleinfo : cell array with scaling factor per trace
      %   yliminfo  : cell array with entries of form [ymin,ymax] setting axes Y limits
      %   titleinfo : cell array with title for this (group of) time traces
      %
      % [+MEQ MatlabEQuilibrium Toolbox+]

      %    Copyright 2022-2025 Swiss Plasma Center EPFL
      %
      %   Licensed under the Apache License, Version 2.0 (the "License");
      %   you may not use this file except in compliance with the License.
      %   You may obtain a copy of the License at
      %
      %       http://www.apache.org/licenses/LICENSE-2.0
      %
      %   Unless required by applicable law or agreed to in writing, software
      %   distributed under the License is distributed on an "AS IS" BASIS,
      %   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      %   See the License for the specific language governing permissions and
      %   limitations under the License.

      % parse inputs: parameter-value pairs mapped directly onto properties
      assert(mod(numel(varargin),2)==0,...
        'optional inputs must be passed as parameter-value pairs')
      for ii = 1:2:numel(varargin)
        % set properties as passed by user
        mypar = varargin{ii}; myval = varargin{ii+1};
        if isprop(obj,mypar)
          obj.(mypar) = myval;
        else
          error('unsupported input parameter %s',mypar)
        end
      end

      % defaults: applied whenever no traces were requested, not only when
      % no optional arguments at all were passed
      if isempty(obj.traceinfo)
        obj.traceinfo = {'Ip'}; % basic default
      end

      % evaluate time traces and other object info
      [obj.timetraces,obj.time_grid,ntrs] = obj.eval_traces(LY);

      % check and set scale and limit defaults
      [obj.ylabels,obj.titles,obj.scales,obj.ylims] = ...
        eval_timetrace_info(obj,L,ntrs);

      % total number of time trace plots
      np = size(obj.timetraces{1},1);

      nc = min(np,2); % number of plot columns
      nr = ceil(np/nc); % number of rows

      % create axes handles if haven't been passed from the outside
      if isempty(obj.hax) || ~all(isgraphics(obj.hax(:),'axes'))
        % defaults
        ru = 0.05; rl = 0.05; % row upper/lower gap
        ci = 0.08; co = 0.08; % column inner/outer gap
        fr = 0.01;
        fc = 0.08; % gap

        % honour the user-supplied parent, fall back to current figure
        hf = obj.parent;
        if isempty(hf) || ~isscalar(hf) || ~isgraphics(hf)
          hf = gcf;
        end
        obj.parent = hf;
        obj.hax = obj.create_axes(hf,nr,nc,ru,rl,ci,co,fr,fc);
        if isgraphics(hf,'figure')
          % only figures have a menu bar (parent may be e.g. a panel)
          set(hf,'MenuBar','none');
        end
      else
        assert(numel(obj.hax)==np,...
          'expected %d axes handles, found %d',np,numel(obj.hax))
        if isempty(obj.parent)
          obj.parent = ancestor(obj.hax(1),'figure');
        end
      end

      obj.plot_traces();

      % a time index passed via the constructor could not be clamped or
      % drawn before data and axes existed: re-apply it now
      if ~isempty(obj.time_index)
        obj.time_index = obj.time_index;
      end

      % set axes callbacks

      % callback when clicking on axes
      for iax = 1:numel(obj.hax)
        ax = obj.hax(iax);
        ax.ButtonDownFcn = @(src,event) click_on_axes(obj,src,event);

        % disable interactivity and toolbar
        if ~verLessThan('matlab','9.5')
          disableDefaultInteractivity(ax); % turn off the little axes menu
        end
      end
    end

    function set.time_index(obj,it)
      % set a time index; the value is clamped to [1, numel(time_grid)]
      % once the time grid is known, then the time indicators are redrawn
      if ~isempty(it) && ~isempty(obj.time_grid)
        it = obj.clamp_time_index(it);
      end
      obj.time_index = it;
      obj.draw_time_indicators;
    end
  end


  methods(Hidden)
    function it = clamp_time_index(obj,it)
      % clamp time index between 1 and the number of time points
      it = min(max(it,1),numel(obj.time_grid));
    end

    function draw_time_indicators(obj)
      % draw vertical lines at the selected time index
      % (creates the indicator line in each axes on first call, then only
      % updates its X position)

      % nothing to draw until index, time grid and axes are all available
      if isempty(obj.time_index) || isempty(obj.time_grid) || isempty(obj.hax)
        return
      end

      t = obj.time_grid(obj.time_index);

      for ii=1:numel(obj.hax(:))
        ax = obj.hax(ii);
        lim = ax.YLim;

        hpp = findobj(ax,'Tag','Time-indicator');

        if ~isempty(hpp)
          % update
          set(hpp,'XData',t*[1 1]);
        else
          % vertical dashed line
          plot(ax,t*[1 1],lim,...
            'color','k','linestyle','--',...
            'Tag','Time-indicator');
        end
      end
    end

    function click_on_axes(obj,src,~)
      % called when user clicks on axes
      % finds the corresponding time index and plots vertical lines at the
      % closest selected time.
      cpt = src.CurrentPoint;
      t = cpt(1,1);
      % setting time_index also updates the vertical line indicators
      obj.time_index = time2ind(obj,t); % set to closest index
      notify(obj,'select_time_slice'); % notify listeners
    end

    function it = time2ind(obj,time)
      % find time index closest to a given time (in second)
      it = iround(obj.time_grid,time); % closest time index
    end

    function [timetraces,time_grid,ntrs] = eval_traces(obj,LYs)
      % wrapper around eval_traces_one_LY for multiple LYs
      % returns:
      %   * timetraces: a cell array containing time trace data for each LY
      %   * time_grid: array of times
      %   * ntrs: array with number of time traces per group

      if isstruct(LYs) % single LY
        time_grid = LYs.t;
        timetraces = cell(1,1);
        [timetraces{1},ntrs] = eval_traces_one_LY(obj,LYs);
      elseif iscell(LYs) && ~isempty(LYs) % multiple LY
        timetraces = cell(numel(LYs),1);
        time_grid = LYs{1}.t;
        ntrs = [];
        for ii=1:numel(LYs)
          LY = LYs{ii};
          assert(isequal(LY.t,time_grid),'can only plot with equal time grid for now')
          [timetraces{ii},ntrs_ii] = eval_traces_one_LY(obj,LY);
          if ii==1
            ntrs = ntrs_ii;
          else
            % all LYs are overlaid on the same axes, so they must yield
            % the same number of traces per group
            assert(isequal(ntrs_ii,ntrs),...
              'LY number %d yields a different number of time traces than the first LY',ii)
          end
        end
      else
        error('LY must be an LY structure or a non-empty cell array of LY structures')
      end
    end

    function [timetraces,ntrs] = eval_traces_one_LY(obj,LY)
      % extracts desired time traces to be plotted from LY
      % returns:
      %     timetraces: numeric array with time trace data for one LY
      %     ntrs: array with number of time traces per group

      assert(isstruct(LY) && isfield(LY,'t'),...
        'unexpected input, is this an LY structure?');

      nt = numel(LY.t);
      timetraces = zeros(0,nt); 
      ntrs = [];

      for ii=1:numel(obj.traceinfo)
        mytr = obj.traceinfo{ii};
        % parentheses are required: && binds tighter than ||
        if (ischar(mytr) || isstring(mytr)) && isfield(LY,mytr)
          mytimetrace = LY.(mytr);
          assert(ismatrix(mytimetrace) && size(mytimetrace,2) == nt,...
            'LY.%s must be a 2D array with numel(LY.t) columns',char(mytr))
        elseif isa(mytr,'function_handle')
          mytimetrace = feval(mytr,LY); % evaluate function handle
          assert(size(mytimetrace,2) == nt,...
            'function handle must evaluate to array with numel(LY.t) columns')
        else
          disp(mytr)
          error('could not parse traces cell index %d',ii)
        end
        ntrs = [ntrs,size(mytimetrace,1)]; %#ok<AGROW> 
        timetraces = [timetraces; mytimetrace]; %#ok<AGROW> 
      end

    end

    function [ylabels,titles,scales,ylimits] = eval_timetrace_info(obj,L,ntrs)
      % collect per-trace plot details (y label, title, scale, y limits)
      % inputs:
      %   L: structure, L.G is searched for label fields (see get_ylabel)
      %   ntrs: number of time traces for each traceinfo entry
      % returns: column cell arrays with one entry per plotted time trace

      nti = numel(obj.traceinfo);

      ylabels = cell(nti,1);
      titles = ylabels; scales = ylabels; ylimits = ylabels;

      for ii=1:nti
        % time trace info
        mytr = obj.traceinfo{ii};
        ntr = ntrs(ii);

        % Get plot details
        ylabels{ii} = obj.get_ylabel(L,mytr,ii,ntr);
        titles{ii}  = obj.get_title(mytr,ii,ntr);
        scales{ii}  = obj.get_scale(ii,ntr);
        ylimits{ii} = obj.get_ylim(ii,ntr);
      end

      % Concatenate them into a column cell array
      ylabels = vertcat(ylabels{:});
      titles  = vertcat(titles{:});
      scales  = vertcat(scales{:});
      ylimits = vertcat(ylimits{:});
    end

    function my_ylabel = get_ylabel(obj,L,mytr,ii,ntr)
      % get label strings for given mytr
      % returns a column cell array with ntr labels
      my_ylabelinfo = meqplottraces.info_entry(obj.ylabelinfo,ii);
      if ~isempty(my_ylabelinfo)
        % user passed dim information
        if (ischar(my_ylabelinfo) || isstring(my_ylabelinfo))
          if isfield(L.G,my_ylabelinfo)
            % exists as L.G.(mydim)
            my_ylabel = L.G.(my_ylabelinfo);
            if ~iscell(my_ylabel), my_ylabel = cellstr(my_ylabel); end
          else % assigned directly
            my_ylabel = repmat({char(my_ylabelinfo)},ntr,1);
          end
        elseif iscell(my_ylabelinfo)
          % passed directly
          assert(numel(my_ylabelinfo)==ntr,...
            'expected %d labels for time trace %d, found %d',...
            ntr,ii,numel(my_ylabelinfo))
          my_ylabel = my_ylabelinfo;
        else
          error('not handled case for ylabelinfo of class %s',class(my_ylabelinfo))
        end
      else % attempt to generate directly from mytr
        if isa(mytr,'function_handle')
          % convert function handle to string
          my_ylabel = repmat({char(func2str(mytr))},ntr,1);
        elseif ischar(mytr) || isstring(mytr)
          % put string
          my_ylabel = repmat({char(mytr)},ntr,1);
        else
          error('not handled case for mytr of class %s',class(mytr))
        end
      end

      % column shape is needed for the later vertical concatenation
      my_ylabel = my_ylabel(:);

      % final check
      assert(numel(my_ylabel)==ntr,...
        'expected %d y labels, found %d',...
        ntr,numel(my_ylabel))
    end

    function my_title = get_title(obj,mytr,ii,ntr)
      % get title strings for given mytr
      % returns a column cell array with ntr entries (empty = no title)
      mytitleinfo = meqplottraces.info_entry(obj.titleinfo,ii);
      if ~isempty(mytitleinfo)
        % user passed dim information
        if ischar(mytitleinfo) || isstring(mytitleinfo)
          % passed directly
          % assign only to first one in case of multiple traces
          my_title = cell(ntr,1);
          my_title{1} = char(mytitleinfo);
        elseif iscell(mytitleinfo)
          % passed directly
          assert(numel(mytitleinfo)<=ntr,...
            'expected at most %d titles for time trace %d, found %d',...
            ntr,ii,numel(mytitleinfo))
          % pad with empty titles so later groups stay aligned
          my_title = cell(ntr,1);
          my_title(1:numel(mytitleinfo)) = mytitleinfo(:);
        else
          error('unhandled case for mytitleinfo')
        end
      else
        my_title = cell(ntr,1);
        if isstring(mytr) || ischar(mytr)
          % determine directly from mytr
          my_title{1} = char(mytr);
        else
          % no title
        end
      end
    end

    function myscale = get_scale(obj,ii,ntr)
      % get scale factor for each of the ntr traces of traceinfo entry ii
      % returns a column cell array with ntr scalars (default 1)
      myscaleinfo = meqplottraces.info_entry(obj.scaleinfo,ii);
      if ~isempty(myscaleinfo)
        % user passed scale information
        assert(isscalar(myscaleinfo) && isfinite(myscaleinfo) && myscaleinfo>0,...
          'expected finite, scalar, positive scaleinfo')
        myscale = repmat({myscaleinfo},ntr,1); % fill all with this value
      else
        myscale = repmat({1},ntr,1); % default scale 1
      end
    end

    function my_ylim = get_ylim(obj,ii,ntr)
      % get Y limits for each of the ntr traces of traceinfo entry ii
      % returns a column cell array, empty entries mean automatic limits
      my_ylim = cell(ntr,1); % init & default
      myliminfo = meqplottraces.info_entry(obj.yliminfo,ii);
      if ~isempty(myliminfo)
        % user passed y axes limit information
        assert(size(myliminfo,2)==2,...
          'expected liminfo like [Ymin,Ymax]')
        if size(myliminfo,1) == 1
          my_ylim(:) = {myliminfo}; % fill all with this value
        else
          error('unexpected/untreated number of rows for ylim')
        end
      end
    end

    function plot_traces(obj)
      % plot all the traces, one axes per trace, overlaying all LYs

      ntr = size(obj.timetraces{1},1); % total nr of traces to plot
      nax = numel(obj.hax);
      assert(nax >= ntr,'found %d axes, but need to plot %d traces',nax,ntr);

      t = obj.time_grid;
      for kk = 1:ntr
        ax = obj.hax(kk);
        hold(ax,'on')

        % data
        scal = obj.scales{kk};
        for dd = 1:numel(obj.timetraces)
          % loop over timetraces from each LY structure
          y = obj.timetraces{dd}(kk,:)*scal;
          plot(ax,t,y,'.-'); 
        end

        % Title
        titstr = obj.titles{kk};
        if ~isempty(titstr)
          title(ax,titstr);
        end
        % Y limits (scaled like the data)
        if ~isempty(obj.ylims{kk})
          ax.YLim = obj.ylims{kk}*scal;
        end

        % Y label
        ystr = obj.ylabels{kk};
        if ~isempty(ystr)
          ylabel(ax,ystr,'interpreter','none');
        end
      end
    end

  end % Methods

  methods(Static)
    function hax = create_axes(hf,nr,nc,ru,rl,ci,co,fr,fc)
      % auxiliary function to create axes
      % inputs:
      %   hf: parent figure handle
      %   nr, nc: number of rows/columns
      %   ru,rl: upper/lower gap for rows [unit: normalized]
      %   ci,co: inner/outer gap for columns [unit: normalized]
      %   fr, fc: gap between each row/column [unit: normalized]
      % returns:
      %   hax: (nr*nc)x1 array of axes handles, ordered column by column,
      %        from top row to bottom row; X limits of all axes are linked

      % height/width of each plot
      h = (1-(ru+rl)-fr*(nr-1))/nr;
      w = (1-(ci+co)-fc*(nc-1))/nc;

      % init hax
      hax = gobjects(nr*nc,1);
      kk = 0;
      for ic = 1:nc
        for ir = 1:nr
          % x,y offset of each axis
          dr = (nr - ir) * (h + fr);
          dc = (ic - 1 ) * (w + fc);

          kk = kk+1;
          % normalized units are assumed for subplot position
          % position is [left,bottom,width,height]: the bottom row starts
          % at the lower gap rl, leaving ru above the top row
          pos = [ci+dc,rl+dr,w,h];
          ax = subplot('Position',pos,'Parent',hf);
          % X tick label only at 'bottom' axes, remove from others
          if ir<nr
            ax.XTickLabel = '';
          end
          % Store handle
          hax(kk) = ax;
        end
      end

      % Link X-axis (time) limits
      linkaxes(hax,'x');
    end
  end

  methods(Static, Access = protected)
    function entry = info_entry(info,ii)
      % return info{ii}, or [] when info has fewer than ii entries
      if numel(info) >= ii
        entry = info{ii};
      else
        entry = [];
      end
    end
  end
end
