added new fixed code
This commit is contained in:
31
model.py
Normal file
31
model.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import torch.nn as nn
|
||||
|
||||
INPUT_KERNEL = 15
|
||||
SIZE = 32
|
||||
|
||||
|
||||
class CavemanEnhancer(nn.Module):
|
||||
def __init__(self, freq_bins=513):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=1,
|
||||
out_channels=SIZE,
|
||||
kernel_size=INPUT_KERNEL,
|
||||
padding=INPUT_KERNEL // 2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(SIZE, 1, kernel_size=3, padding=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# x: (batch, freq_bins)
|
||||
return self.net(x)
|
||||
Reference in New Issue
Block a user