import numpy as np
import matplotlib.pyplot as plt

def show_RF(W, patch_size=16):
    """Plot one weight vector (e.g. a receptive field) as a square image.

    ``W`` is reshaped to ``(patch_size, patch_size)`` and drawn with the
    diverging ``'PiYG'`` colormap. The colour limits are symmetric,
    ``[-m, m]`` with ``m = max(|W|)``, so zero maps to the centre of the
    colormap. A colorbar is attached and the figure is shown immediately.

    Args:
        W: Array with a ``reshape`` method holding exactly
            ``patch_size ** 2`` elements.
        patch_size: Side length of the square patch.
    """
    patch = W.reshape(patch_size, patch_size)
    # Symmetric limits keep zero at the neutral midpoint of the diverging map.
    bound = np.max(np.abs(W))
    imshow_kwargs = dict(cmap='PiYG', vmin=-bound, vmax=bound)
    plt.imshow(patch, **imshow_kwargs)
    plt.colorbar()
    plt.show()

def show_RFs(Ws, n_col=5, title='', patch_size=16):
    """Plot several weight vectors in a grid that shares one colour scale.

    Each entry of ``Ws`` is reshaped to ``(patch_size, patch_size)`` and
    drawn with the diverging ``'PiYG'`` colormap. Every panel uses the same
    symmetric limits ``[-m, m]``, where ``m`` is the largest absolute weight
    across all entries, so colours are comparable between panels. One
    horizontal colorbar is attached to the whole grid, and the figure is
    shown immediately.

    Args:
        Ws: Non-empty sequence (or array indexed along its first axis) of
            weight arrays. Each entry must have a ``reshape`` method and
            ``patch_size ** 2`` elements.
        n_col: Number of grid columns. Rows are added as needed. Cells
            past ``len(Ws)`` in the last row are hidden.
        title: Figure title.
        patch_size: Side length of each square patch.

    Raises:
        ValueError: If ``Ws`` is empty.
    """
    n = len(Ws)
    if n == 0:
        raise ValueError("show_RFs needs at least one receptive field to plot")
    n_row = int(np.ceil(n / n_col))
    # squeeze=False keeps `axs` a 2-D array of Axes even for a 1x1 grid.
    # With the default squeeze=True a single Axes is returned, and it has no
    # `.flat`.
    fig, axs = plt.subplots(nrows=n_row, ncols=n_col, sharex=True, sharey=True,
                            figsize=(2 * n_col, 2 * n_row), squeeze=False)

    # Take the max per entry so entries of differing shapes (same element
    # count) do not form a ragged array. np.max still propagates NaN.
    max_value = np.max([np.max(np.abs(w)) for w in Ws])

    first_img = None
    for i, ax in enumerate(axs.flat):
        ax.axis('off')
        if i >= n:
            # Padding cell in the last row: nothing to draw, so hide it.
            ax.set_visible(False)
            continue
        W = Ws[i].reshape(patch_size, patch_size)
        img = ax.imshow(W, cmap='PiYG', vmin=-max_value, vmax=max_value)
        if first_img is None:
            first_img = img

    # All images share the same cmap and limits, so any real one can define
    # the colorbar.
    fig.colorbar(first_img, ax=axs, orientation='horizontal')
    fig.suptitle(title)
    plt.show()