32 lines
889 B
Python
32 lines
889 B
Python
import torch.nn as nn
|
|
|
|
# Kernel size (height and width) of the first, wide convolution.
INPUT_KERNEL = 15

# Channel width of every hidden convolution.
SIZE = 32


class CavemanEnhancer(nn.Module):
    """Small fully-convolutional net mapping a 1-channel 2-D input to a
    1-channel output of the same height and width.

    Architecture: one ``input_kernel`` x ``input_kernel`` convolution
    (1 -> ``hidden_channels``), then ``num_hidden_layers`` 3x3 convolutions
    (``hidden_channels`` -> ``hidden_channels``), each followed by ReLU, and a
    final 3x3 convolution back to 1 channel with no activation. Every
    convolution uses stride 1 and ``kernel_size // 2`` zero padding, so (with
    odd kernels) spatial size is preserved.

    The layer order is fixed, so with the default arguments the
    ``state_dict`` keys are ``net.0``, ``net.2``, ..., ``net.10``.

    Args:
        freq_bins: Stored as ``self.freq_bins`` for reference only. The
            network is fully convolutional, so this does not affect any layer
            shape and inputs of any spatial size are accepted.
        hidden_channels: Channel width of the hidden convolutions.
        input_kernel: Kernel size of the first convolution; must be a
            positive odd integer so ``input_kernel // 2`` padding preserves
            spatial size.
        num_hidden_layers: Number of 3x3 hidden convolution + ReLU stages.

    Raises:
        ValueError: If ``input_kernel`` is not a positive odd integer,
            ``hidden_channels`` < 1, or ``num_hidden_layers`` < 0.
    """

    def __init__(
        self,
        freq_bins: int = 513,
        *,
        hidden_channels: int = SIZE,
        input_kernel: int = INPUT_KERNEL,
        num_hidden_layers: int = 4,
    ):
        super().__init__()
        # An even kernel with k // 2 padding would grow the output by one
        # pixel per side pair, breaking the same-size contract.
        if input_kernel < 1 or input_kernel % 2 == 0:
            raise ValueError(
                f"input_kernel must be a positive odd integer, got {input_kernel}"
            )
        if hidden_channels < 1:
            raise ValueError(f"hidden_channels must be >= 1, got {hidden_channels}")
        if num_hidden_layers < 0:
            raise ValueError(
                f"num_hidden_layers must be >= 0, got {num_hidden_layers}"
            )

        # Kept for reference only; not used to size any layer.
        self.freq_bins = freq_bins

        layers = [
            nn.Conv2d(
                in_channels=1,
                out_channels=hidden_channels,
                kernel_size=input_kernel,
                padding=input_kernel // 2,
            ),
            nn.ReLU(),
        ]
        for _ in range(num_hidden_layers):
            layers += [
                nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
                nn.ReLU(),
            ]
        # Final projection back to one channel; no activation, so outputs
        # may be negative.
        layers.append(nn.Conv2d(hidden_channels, 1, kernel_size=3, padding=1))

        # Layer order is significant: it fixes the Sequential indices and
        # therefore the state_dict keys that checkpoints rely on.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x``.

        Accepted input shapes:
            * 4-D ``(batch, 1, H, W)`` -> ``(batch, 1, H, W)``.
            * 3-D ``(batch, H, W)``: a channel axis is inserted before the
              network and removed afterwards -> ``(batch, H, W)``.
            * 2-D ``(batch, freq_bins)``: each row is treated as a
              ``(freq_bins, 1)`` single-column image -> ``(batch, freq_bins)``.

        Any other rank is passed to the network unchanged, where
        ``nn.Conv2d`` rejects it.
        """
        if x.dim() == 2:
            # nn.Conv2d cannot take (batch, freq_bins) frames directly.
            # NOTE(review): the frequency axis is placed on the height
            # dimension here — confirm this matches the intended layout.
            return self.net(x[:, None, :, None])[:, 0, :, 0]
        if x.dim() == 3:
            # (batch, H, W) -> (batch, 1, H, W); for batch == 1 this matches
            # Conv2d's own unbatched (1, H, W) handling exactly.
            return self.net(x.unsqueeze(1)).squeeze(1)
        return self.net(x)
|