Files
caveman/model.py
2026-01-10 20:35:21 +01:00

32 lines
889 B
Python

import torch.nn as nn
# Kernel size of the first (input) convolution. Keep it odd: the layer uses
# padding = INPUT_KERNEL // 2, which preserves height/width only for odd sizes.
INPUT_KERNEL = 15
# Number of channels in every hidden convolution.
SIZE = 32


class CavemanEnhancer(nn.Module):
    """Small fully convolutional net mapping a 1-channel image to a 1-channel image.

    Layer stack: one ``INPUT_KERNEL x INPUT_KERNEL`` convolution (1 -> ``SIZE``
    channels), four 3x3 convolutions (``SIZE`` -> ``SIZE``), each followed by a
    ReLU, then a final 3x3 projection back to one channel with no activation, so
    the output is unbounded and may be negative. Every layer has stride 1 and
    "same" padding, so the output height/width equal the input's and any input
    size is accepted.

    NOTE(review): the class name and the ``freq_bins=513`` default suggest the
    input is a (freq, time) spectrogram with 513 frequency bins — TODO confirm
    against the callers.
    """

    def __init__(self, freq_bins: int = 513) -> None:
        """Build the convolution stack.

        Args:
            freq_bins: Not used by any layer (the network is fully convolutional
                and size-agnostic). Kept for interface compatibility and stored
                as ``self.freq_bins`` so callers can read it back.
        """
        super().__init__()
        self.freq_bins = freq_bins
        self.net = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=SIZE,
                kernel_size=INPUT_KERNEL,
                padding=INPUT_KERNEL // 2,
            ),
            nn.ReLU(),
            nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(SIZE, 1, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape ``(batch, 1, H, W)`` or ``(batch, H, W)``. A 3-D
                input is treated as a batch of single-channel images: a channel
                axis is inserted before the convolutions and removed afterwards.
                Other ranks go straight to the convolutions, which reject them.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if x.dim() == 3:
            # Without this, Conv2d would read (batch, H, W) as one unbatched
            # (C, H, W) sample and fail for any batch size other than 1. For
            # batch size 1 both readings produce identical values.
            return self.net(x.unsqueeze(1)).squeeze(1)
        return self.net(x)