import torch.nn as nn

# Kernel size of the first (input) convolution.
INPUT_KERNEL = 15
# Number of feature channels carried by every hidden convolution.
SIZE = 32


class CavemanEnhancer(nn.Module):
    """Plain stack of 2-D convolutions mapping one channel to one channel.

    The network is ``Conv2d(1 -> SIZE, INPUT_KERNEL)`` followed by four
    ``Conv2d(SIZE -> SIZE, 3)`` layers and a final ``Conv2d(SIZE -> 1, 3)``,
    with a ReLU after every convolution except the last.  Every convolution
    uses ``padding = kernel_size // 2`` with the default stride, so the two
    trailing (spatial) dimensions of the input are preserved.
    """

    def __init__(self, freq_bins: int = 513) -> None:
        """Build the convolution stack.

        Args:
            freq_bins: Accepted for interface compatibility.  It is not used
                when building the layers; none of them depends on the input
                size.
        """
        super().__init__()

        # Layer order (and therefore the ``net.<index>`` parameter names in
        # the state dict) must stay fixed: conv, relu, 4 x (conv, relu), conv.
        layers = [
            nn.Conv2d(
                in_channels=1,
                out_channels=SIZE,
                kernel_size=INPUT_KERNEL,
                padding=INPUT_KERNEL // 2,
            ),
            nn.ReLU(),
        ]
        for _ in range(4):
            layers.append(nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1))
            layers.append(nn.ReLU())
        layers.append(nn.Conv2d(SIZE, 1, kernel_size=3, padding=1))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the convolution stack on ``x``.

        Args:
            x: Either a 4-D tensor ``(batch, 1, H, W)`` or a 3-D tensor
                ``(batch, H, W)`` without a channel axis.  A 3-D tensor of
                shape ``(1, H, W)`` gives the same result whether it is read
                as one unbatched single-channel image or as a batch of one.
                Tensors of any other rank are handed to ``nn.Conv2d``
                unchanged, which does not accept them (a plain
                ``(batch, freq_bins)`` tensor is therefore not valid input).

        Returns:
            A tensor with the same shape as ``x``.
        """
        if x.dim() == 3:
            # Insert the single channel axis Conv2d needs, then drop it again
            # so the output layout mirrors the input layout.
            return self.net(x.unsqueeze(1)).squeeze(1)
        return self.net(x)