function lambda = sol_channel_estMMSE(Rf, N, psd_mask, training_symbols, Ka, delays, sigma2)
%SOL_CHANNEL_ESTMMSE Estimate the channel coefficients in the frequency
%domain with a linear MMSE estimator.
%   The channel delays are known. The covariance matrix of the channel
%   amplitudes is also assumed to be known.
%
%   LAMBDA = SOL_CHANNEL_ESTMMSE(RF, N, PSD_MASK, ...
%       TRAINING_SYMBOLS, KA, DELAYS, SIGMA2)
%
%       RF: Matrix of TRAINING_SYMBOLS and DATA_SYMBOLS, as returned by
%         OFDM_RX_FRAME. Only the first column (the received training
%         block) is used; RF must have one row per active carrier, i.e.
%         sum(PSD_MASK) rows.
%       N: number of carriers per OFDM block (power of 2).
%       PSD_MASK: A {0,1}-valued vector of length N, used to turn off
%         individual carriers.
%       TRAINING_SYMBOLS: vector of symbols known to the receiver, used
%         to estimate the channel. Its length is the number of ones in
%         PSD_MASK (one training symbol per active carrier).
%       KA: the covariance matrix of the channel amplitudes; an L-by-L
%         matrix with L = numel(DELAYS).
%       DELAYS: vector containing the delays of each path in the
%         multipath channel. The delays are expressed in number of
%         samples, i.e., tau_l/T_s (they need not be integers).
%       SIGMA2: noise variance (non-negative scalar).
%
%   LAMBDA: Column vector containing the channel coefficients in the
%   frequency domain. The number of elements in LAMBDA equals the number
%   of ones in PSD_MASK.
%
%   With y = RF(:,1), S = diag(TRAINING_SYMBOLS), f_k the signed frequency
%   index of the k-th active carrier and A(k,l) = exp(-j*2*pi*f_k*DELAYS(l)/N),
%   the estimate is
%       Kd     = A*KA*A'
%       Ky     = S*Kd*S' + SIGMA2*I
%       LAMBDA = (Kd*S') * (Ky \ y)

if ~isscalar(N) || N <= 0 || mod(N,1) ~= 0
    error('channel_estMMSE:dimensionMismatch', ...
        'N must be a positive scalar integer');
end

if ~isvector(psd_mask) || numel(psd_mask) ~= N
    error('channel_estMMSE:dimensionMismatch', ...
        'PSD_MASK must be a vector of length %d', ...
        N);
end

if sum(psd_mask == 0) + sum(psd_mask == 1) ~= N
    error('channel_estMMSE:invalidMask', ...
        'PSD_MASK must be {0,1}-valued');
end

if ~isvector(delays)
    error('channel_estMMSE:dimensionMismatch', ...
        'DELAYS must be a vector');
end

num_paths = numel(delays);
if ~isequal(size(Ka), [num_paths num_paths])
    error('channel_estMMSE:dimensionMismatch', ...
        'KA must be a %d-by-%d matrix', num_paths, num_paths);
end

if ~isscalar(sigma2) || ~isreal(sigma2) || sigma2 < 0
    error('channel_estMMSE:invalidNoiseVariance', ...
        'SIGMA2 must be a non-negative real scalar');
end

num_useful_carriers = sum(psd_mask);

if ~isvector(training_symbols) || ...
        numel(training_symbols) ~= num_useful_carriers
    error('channel_estMMSE:dimensionMismatch', ...
        'TRAINING_SYMBOLS must be a vector of length %d', ...
        num_useful_carriers);
end

% Only the first column (training block) is used; it must line up with the
% active carriers, otherwise Ky\y below would fail with an opaque error.
if size(Rf,1) ~= num_useful_carriers || size(Rf,2) < 1
    error('channel_estMMSE:dimensionMismatch', ...
        'RF must have %d rows and at least one column', ...
        num_useful_carriers);
end

y = Rf(:,1);                 % channel output due to training_symbols
S = diag(training_symbols);  % notation as in the lecture notes

% Form the matrix A according to the lecture notes.
% Frequency indices in FFT order: [0, ..., (N/2 - 1), -N/2, ..., -1]
allFreqIndices = fftshift(-N/2:N/2-1);
FreqIndices = allFreqIndices(logical(psd_mask));
% Plain transpose (.') so delays are never conjugated; compute the phase
% directly rather than raising a rounded complex base to a (possibly
% fractional) power.
exponentsOfA = FreqIndices(:) * delays(:).';
A = exp(-1j*2*pi*exponentsOfA/N);

% Remaining matrices as in the lecture notes
B = Ka;
Kd = A*B*A';                               % covariance of the channel coefficients
Kdy = Kd*S';                               % cross-covariance E[d y^H]
Kz = sigma2*eye(num_useful_carriers);      % white noise covariance
Ky = S*Kd*S' + Kz;                         % covariance of the observation y

% LMMSE estimate: Kdy * Ky^{-1} * y (solve instead of forming the inverse)
lambda = Kdy*(Ky\y);
end
