function [ output_args ] = QLKNN_plot_IO(rho, NNin, NNout, QLKNNnet, QLKNNparams )
%QLKNN_PLOT_IO Plot QLKNN network inputs and outputs versus rho.
%   output_args = QLKNN_PLOT_IO(rho, NNin, NNout, QLKNNnet, QLKNNparams)
%   scatters the columns of NNin into the 'QLKNN in' figure (with dashed
%   lines at QLKNNparams.min_input / max_input) and the columns of NNout
%   into an output figure whose name depends on QLKNNparams.channel_tag.
%
%   Inputs:
%     rho         - radial coordinate; truncated to the number of rows of
%                   NNin / NNout when plotting
%     NNin        - network inputs, one column per input quantity
%     NNout       - network outputs, one column per output quantity
%     QLKNNnet    - struct; QLKNNnet.name selects the plot layout and must
%                   be 'QLKNN10D' or 'QLKNN4Dkin'
%     QLKNNparams - struct with min_input / max_input and, for QLKNN10D,
%                   use_effective_diffusivity / use_ion_diffusivity_networks.
%                   Optional field channel_tag ('chi_e' or 'vpdn_e') selects
%                   the output figure name.
%
%   Output:
%     output_args - struct with fields hax_in and hax_out holding the axes
%                   handles of the input and output figures.
%
%   Errors with id 'QLKNN_plot_IO:unknownNetwork' for any other network name.
if strcmp(QLKNNnet.name, 'QLKNN10D')
    input_names = {'Zeff', 'Ati', 'Ate', 'An', 'q', 'smag', 'x', 'Ti_Te', 'log10(Nustar)'};
    input_names{end+1} = 'gammaE';

    % Each cell entry lists the input columns drawn together on one axes.
    input_plot_idx = {2:3, 4, 5, 6, [1 8], 7, 9, 10};
    % order of outputs: [qe, qi, Gamma_e, De, Ve, Di, Vi], in GB units
    output_names = {'q_e', 'q_i', '\Gamma_e', 'D_e', 'V_e', 'D_i', 'V_i'};
    output_plot_idx = {1, 2};

    % The third output panel depends on which particle-transport
    % representation the network configuration uses.
    if QLKNNparams.use_effective_diffusivity
        output_plot_idx{end+1} = 3;
    elseif QLKNNparams.use_ion_diffusivity_networks
        output_plot_idx{end+1} = [6, 7];
    else
        output_plot_idx{end+1} = [4, 5];
    end
elseif strcmp(QLKNNnet.name, 'QLKNN4Dkin')
    input_names = {'q', 'smag', 'Ti_{T_e}', 'Ati'};
    input_plot_idx = {1, 2, 3, 4};
    output_names = {'\chi_e', '\chi_i', 'D_e', 'V_e'};
    output_plot_idx = {1, 2, [3, 4]};
else
    % Without this branch the layout variables stay undefined and the
    % failure surfaces later as an obscure "undefined variable" error.
    error('QLKNN_plot_IO:unknownNetwork', ...
        'Unsupported QLKNN network name ''%s''; expected ''QLKNN10D'' or ''QLKNN4Dkin''.', ...
        QLKNNnet.name);
end

hax_in = create_or_get_axes('QLKNN in');
for ii = 1:numel(input_plot_idx)
    plot_input(hax_in(ii), rho, input_plot_idx{ii}, NNin, QLKNNnet, QLKNNparams, input_names)
end

% Separate output figures per transport channel so heat and particle
% evaluations do not overwrite each other.
if isfield(QLKNNparams, 'channel_tag') && strcmp(QLKNNparams.channel_tag, 'chi_e')
    outname = 'QLKNN heat out';
elseif isfield(QLKNNparams, 'channel_tag') && strcmp(QLKNNparams.channel_tag, 'vpdn_e')
    outname = 'QLKNN part out';
else
    outname = 'QLKNN out';
end
hax_out = create_or_get_axes(outname);
for ii = 1:numel(output_plot_idx)
    plot_output(hax_out(ii), rho, output_plot_idx{ii}, NNout, QLKNNnet, output_names)
end

% Previously never assigned, which made any call requesting an output fail.
output_args = struct('hax_in', hax_in, 'hax_out', hax_out);

return

function plot_output(ax, rho, output_idx, NNoutput, QLKNNnet, output_names)
% PLOT_OUTPUT Scatter selected NN output columns against rho on axes AX.
%   output_idx   - indices of the NNoutput columns (and output_names
%                  entries) to plot; at most three, each gets its own
%                  marker ('o', '+', 'x').
%   NNoutput     - network outputs, one column per quantity; rho is
%                  truncated to its number of rows.
%   QLKNNnet     - unused; kept for call-signature compatibility.
%   output_names - display names indexed by output column.
%   The combined label, e.g. 'q_e[o] q_i[+] ', is passed to labelff with
%   'y' (presumably the y-axis label).
markers = {'o', '+', 'x'};
labels = '';
rho = rho(1:size(NNoutput, 1));
for k = 1:numel(output_idx)
    outp = output_idx(k);
    name = output_names{outp};

    scatter(ax, rho, NNoutput(:, outp), 'DisplayName', name, 'Marker', markers{k})
    % Bracket concatenation rather than strcat: strcat strips trailing
    % whitespace from char inputs, which dropped the separator space.
    labels = [labels, name, '[', markers{k}, '] ']; %#ok<AGROW>
    hold(ax, 'on');
end
labelff(ax, labels, 'y');
return

function plot_input(ax, rho, input_idx, NNinputs, QLKNNnet, QLKNNparams, input_names)
% PLOT_INPUT Scatter selected NN input columns against rho on axes AX and
% overlay dashed lines at their QLKNNparams.min_input / max_input values.
%   input_idx   - indices of the NNinputs columns (and input_names entries)
%                 to plot; at most three, each gets its own marker/colour.
%   NNinputs    - network inputs, one column per quantity; rho is truncated
%                 to its number of rows.
%   QLKNNnet    - unused; kept for call-signature compatibility.
%   QLKNNparams - struct providing min_input and max_input, indexed by
%                 input column.
%   input_names - display names indexed by input column.
%   The combined label, e.g. 'Ati[o] Ate[+] ', is passed to labelff with
%   'y' (presumably the y-axis label).
markers = {'o', '+', 'x'};
colors  = {'r', 'g', 'b'};

nrows = size(NNinputs, 1);
rho = rho(1:nrows);
labels = '';
for k = 1:numel(input_idx)
    inp = input_idx(k);
    name = input_names{inp};
    scatter(ax, rho, NNinputs(:, inp), 'DisplayName', name, ...
        'Marker', markers{k}, 'MarkerEdgeColor', colors{k})
    % Bracket concatenation rather than strcat: strcat strips trailing
    % whitespace from char inputs, which dropped the separator space.
    labels = [labels, name, '[', markers{k}, '] ']; %#ok<AGROW>
    hold(ax, 'on');
end
labelff(ax, labels, 'y');

% Bound lines are drawn after all scatters so the plotting order matches
% the scatter-then-bounds sequence.
for k = 1:numel(input_idx)
    inp = input_idx(k);
    % Index 10 is gammaE in the QLKNN10D layout; no bounds are drawn for it.
    % NOTE(review): presumably min_input/max_input cover only the first 9
    % inputs; confirm.
    if inp < 10
        plot(ax, rho, repmat(QLKNNparams.min_input(inp), [1 nrows]), '--', 'Color', colors{k})
        plot(ax, rho, repmat(QLKNNparams.max_input(inp), [1 nrows]), '--', 'Color', colors{k})
    end
end
return

function [hax] = create_or_get_axes(figname)
% CREATE_OR_GET_AXES Return the cleared axes of the figure named FIGNAME.
%   Looks the figure up by its 'Name' property. If none exists, creates it
%   with NumberTitle 'off' and HandleVisibility 'callback'. If the figure
%   has no axes yet, they are built by create_hax. Every returned axes is
%   cleared with cla.
hf = findall(0, 'Name', figname, 'Type', 'figure');
if numel(hf) > 1
    % Several figures share this name: reuse only the first. With a handle
    % array the scalar || test below errors and get(hf,'children') would
    % return a cell instead of a handle array.
    hf = hf(1);
end
if isempty(hf) || ~ishandle(hf)
    % HandleVisibility 'callback' hides the figure from gcf/findobj, which is
    % why the lookup above has to use findall.
    hf = figure('Name', figname, 'NumberTitle', 'off', 'handlevisibility', 'callback');
end

hax = sort(findobj(get(hf, 'children'), 'type', 'axes'));
if isempty(hax)
    hax = create_hax(hf, figname);
end

for ii = 1:numel(hax)
    cla(hax(ii));
end

return

function hax = create_hax(hf, figname)
% CREATE_HAX Build the axes layout for a newly created QLKNN figure HF.
%   The 'QLKNN in' figure uses multiaxes(hf, 4, 2); every other figure
%   name (the output figures) uses multiaxes(hf, 3, 1).
switch figname
    case 'QLKNN in'
        layout = [4, 2];
    otherwise
        layout = [3, 1];
end
hax = multiaxes(hf, layout(1), layout(2));
return