function [J_, b_, dx, Pinv_] = reduce_system(J, b, Pinv, usemask, mask, Jdiag)
% [J_, b_, dx, Pinv_] = reduce_system(J, b, Pinv, usemask, mask, Jdiag)
% Helper function reducing a linear system when some rows of the Jacobian
% are known to possess only diagonal elements.
% We use the information from mask and Jdiag: after grouping the rows and
% columns of J corresponding to mask==true at the end, J has the form
% [J1,J2;0,D], so that the solution of J*dx=b is, after separating dx into
% [dx1;dx2] and b into [b1;b2]:
%   dx2 = b2./D, dx1 = J1\(b1 - J2*dx2)
% When J is a function handle, the product J2*dx2 is obtained matrix-free
% by applying J to the partially filled solution vector.
%
% inputs:
%   J:        Function handle or (sparse) matrix, linear operator of the
%             full system
%   b:        Vector or matrix (one column per right-hand side), system rhs
%   Pinv:     Function handle computing v->inv(M)*v for some
%             preconditioning matrix M; alternatively a numeric matrix
%             (reduced to Pinv(~mask,~mask)) or a scalar/empty value, which
%             is passed through unchanged
%   usemask:  bool, flag to indicate if reduction should be done at all
%   mask:     Vector of bool (numeric 0/1 is accepted too), mask indicating
%             rows of J where only diagonal elements reside
%   Jdiag:    Vector (row or column), diagonal elements of J where
%             mask==true, otherwise zeros
% returns:
%   J_:       Function handle or (sparse) matrix, Jacobian of the reduced
%             system J(~mask,~mask)
%   b_:       Vector or matrix, rhs of the reduced system
%             b(~mask,:) - J(~mask,mask)*dx(mask,:)
%   dx:       Vector or matrix, partially filled solution. Contains zeros
%             where mask==0 and the correct values b(mask,:)./Jdiag(mask)
%             where mask==1
%   Pinv_:    Function handle or matrix applying the (~mask,~mask) block of
%             the preconditioner; this equals inv(M(~mask,~mask)) when M
%             shares the block-triangular structure of J
%
% If usemask is false or mask has no true entry, J, b and Pinv are returned
% unchanged and dx is all zeros.
%
% [+MEQ MatlabEQuilibrium Toolbox+]

%    Copyright 2022-2025 Swiss Plasma Center EPFL
%
%   Licensed under the Apache License, Version 2.0 (the "License");
%   you may not use this file except in compliance with the License.
%   You may obtain a copy of the License at
%
%       http://www.apache.org/licenses/LICENSE-2.0
%
%   Unless required by applicable law or agreed to in writing, software
%   distributed under the License is distributed on an "AS IS" BASIS,
%   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%   See the License for the specific language governing permissions and
%   limitations under the License.

% initializing dx output
[n,nb] = size(b);
dx = zeros(n,nb);

% if system shouldn't or couldn't be reduced, just return inputs
if ~usemask || ~any(mask)
  J_ = J; b_ = b; Pinv_ = Pinv;
  return;
end

% Normalize the mask to a logical column: a numeric 0/1 mask would otherwise
% be interpreted as (invalid, zero-containing) subscripts by dx(mask,:).
mask = logical(mask(:));
keep = ~mask;

% Filling reduced DOF entries in dx with correct values. Jdiag(mask) keeps
% the orientation of Jdiag, so it is forced to a column: a row would make
% the division broadcast to a k-by-k result instead of dividing row-wise.
D = Jdiag(mask);
dx(mask,:) = b(mask,:)./D(:);

if isa(J,'function_handle')
  % J is given as a function handle: since dx is zero on the kept DOFs,
  % J(dx) restricted to the kept rows equals J(~mask,mask)*dx(mask,:)
  J_of_dx = J(dx);
  b_ = b(keep,:) - J_of_dx(keep,:);

  % get a function handle computing only the reduced direction
  J_ = @(v_) reduced_function_handle(J, v_, mask);

else
  % J is given as a matrix
  b_ = b(keep,:) - J(keep, mask) * dx(mask,:);

  % reduce the matrix
  J_ = J(keep, keep);
end

% reduce preconditioner: function handles are wrapped, matrices are
% sliced, scalars/empty values are passed through unchanged
if isa(Pinv,'function_handle')
  Pinv_ = @(v_) reduced_function_handle(Pinv, v_, mask);
elseif isnumeric(Pinv) && numel(Pinv)>1
  Pinv_ = Pinv(keep,keep);
else
  Pinv_ = Pinv;
end
end

function result = reduced_function_handle(func, v, mask)
% result = reduced_function_handle(func, v, mask)
% Evaluates the restriction of func to the DOFs where mask is false.
% func maps a size-n vector (or n-by-k block of column vectors) to an
% output of the same size. The reduced input v is scattered into a
% zero-filled full-size array, passed through func, and only the rows with
% mask==false are kept.
%
% Example: if func represents a matrix multiplication v->M*v, then
% reduced_function_handle(func, w, mask) returns M(~mask,~mask)*w.
%
% inputs:
%   func:       function handle, dim-n -> dim-n function to be reduced
%   v:          Vector (or matrix with one column per vector) of reduced
%               size; size(v,1) must equal sum(~mask)
%   mask:       Vector of bool, marks the DOFs that are implicitly
%               assumed to be zero in the input and dropped from the output
% returns:
%   result:     Vector (or matrix), rows ~mask of func applied to the
%               zero-extended input

keep = ~mask;

% scatter the reduced input into a zero-filled full-size array
full_input = zeros(numel(mask), size(v,2));
full_input(keep,:) = v;

% apply the full operator and gather only the kept rows
full_output = func(full_input);
result = full_output(keep,:);
end
