added new fixed code
This commit is contained in:
@@ -4,29 +4,21 @@ import librosa
|
|||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import random
|
import random
|
||||||
|
from settings import SR, N_FFT
|
||||||
|
from misc import audio_to_logmag
|
||||||
HOP = 512
|
|
||||||
N_FFT = 1024
|
|
||||||
DURATION = 2.0
|
|
||||||
SR = 44100
|
|
||||||
|
|
||||||
|
|
||||||
def audio_to_logmag(audio):
|
|
||||||
# STFT
|
|
||||||
stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
|
|
||||||
mag = np.abs(stft)
|
|
||||||
logmag = np.log1p(mag) # log(1 + x) for stability
|
|
||||||
return logmag # shape: (1, freq_bins, time_frames) = (1, 513, T)
|
|
||||||
|
|
||||||
|
|
||||||
class WaveformDataset(Dataset):
|
class WaveformDataset(Dataset):
|
||||||
def __init__(self, lossy_dir, clean_dir, sr=SR, segment_sec=4):
|
mean = np.zeros([N_FFT // 2 + 1])
|
||||||
self.cache = dict()
|
std = np.ones([N_FFT // 2 + 1])
|
||||||
|
|
||||||
|
# Duration is a very important parameter; read cavemanml.py to see how and why to adjust it.
|
||||||
|
# For the purposes of this file, it's the length of the audio clip being selected from the dataset.
|
||||||
|
def __init__(self, lossy_dir, clean_dir, segment_duration, sr=SR):
|
||||||
|
self.segment_duration = segment_duration
|
||||||
self.sr = sr
|
self.sr = sr
|
||||||
self.lossy_dir = lossy_dir
|
self.lossy_dir = lossy_dir
|
||||||
self.clean_dir = clean_dir
|
self.clean_dir = clean_dir
|
||||||
self.segment_len = int(segment_sec * sr)
|
|
||||||
self.lossy_files = sorted(os.listdir(lossy_dir))
|
self.lossy_files = sorted(os.listdir(lossy_dir))
|
||||||
self.clean_files = sorted(os.listdir(clean_dir))
|
self.clean_files = sorted(os.listdir(clean_dir))
|
||||||
self.file_pairs = [
|
self.file_pairs = [
|
||||||
@@ -51,9 +43,9 @@ class WaveformDataset(Dataset):
|
|||||||
min_len = min(len(lossy), len(clean))
|
min_len = min(len(lossy), len(clean))
|
||||||
lossy, clean = lossy[:min_len], clean[:min_len]
|
lossy, clean = lossy[:min_len], clean[:min_len]
|
||||||
|
|
||||||
# Random 2-second clip
|
# Random clip
|
||||||
|
|
||||||
clip_len = int(DURATION * SR)
|
clip_len = int(self.segment_duration * SR)
|
||||||
if min_len < clip_len:
|
if min_len < clip_len:
|
||||||
# pad if too short
|
# pad if too short
|
||||||
lossy = np.pad(lossy, (0, clip_len - min_len))
|
lossy = np.pad(lossy, (0, clip_len - min_len))
|
||||||
@@ -61,14 +53,21 @@ class WaveformDataset(Dataset):
|
|||||||
start = 0
|
start = 0
|
||||||
else:
|
else:
|
||||||
start = random.randint(0, min_len - clip_len)
|
start = random.randint(0, min_len - clip_len)
|
||||||
|
# start = 0
|
||||||
lossy = lossy[start : start + clip_len]
|
lossy = lossy[start : start + clip_len]
|
||||||
clean = clean[start : start + clip_len]
|
clean = clean[start : start + clip_len]
|
||||||
|
|
||||||
|
logmag_x = audio_to_logmag(lossy)
|
||||||
|
logmag_y = audio_to_logmag(clean)
|
||||||
|
|
||||||
|
logmag_x_norm = (logmag_x - self.mean[:, None]) / (self.std[:, None] + 1e-8)
|
||||||
|
logmag_y_norm = (logmag_y - self.mean[:, None]) / (self.std[:, None] + 1e-8)
|
||||||
|
|
||||||
ans = (
|
ans = (
|
||||||
torch.from_numpy(audio_to_logmag(lossy)).unsqueeze(0),
|
torch.from_numpy(logmag_x_norm).float().unsqueeze(0),
|
||||||
torch.from_numpy(audio_to_logmag(clean)).unsqueeze(0),
|
torch.from_numpy(logmag_y_norm).float().unsqueeze(0),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.cache[idx] = ans
|
# self.cache[idx] = ans
|
||||||
|
|
||||||
return ans
|
return ans
|
||||||
|
|||||||
315
cavemanml.py
315
cavemanml.py
@@ -1,53 +1,120 @@
|
|||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
import random
|
|
||||||
import warnings
|
import warnings
|
||||||
import librosa
|
import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch
|
import torch
|
||||||
|
import soundfile as sf
|
||||||
from caveman_wavedataset import WaveformDataset
|
from caveman_wavedataset import WaveformDataset
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.cuda.amp import autocast, GradScaler
|
from torch.cuda.amp import autocast
|
||||||
import torch.optim.lr_scheduler as lr_scheduler
|
import torch.optim.lr_scheduler as lr_scheduler
|
||||||
|
from tqdm import tqdm
|
||||||
|
from misc import audio_to_logmag
|
||||||
|
from settings import N_FFT, HOP, SR
|
||||||
|
from model import CavemanEnhancer
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
CLEAN_DATA_DIR = "./fma_small"
|
CLEAN_DATA_DIR = "./fma_small"
|
||||||
LOSSY_DATA_DIR = "./fma_small_compressed_64/"
|
LOSSY_DATA_DIR = "./fma_small_compressed_64/"
|
||||||
SR = 44100 # sample rate
|
|
||||||
DURATION = 2.0 # seconds per clip
|
# Duration is the duration of each example to be selected from the dataset.
|
||||||
N_MELS = None # we'll use full STFT for now
|
# Since we are using the FMA dataset, its max value is 30s.
|
||||||
HOP = 512
|
# From the standpoint of the model and training,
|
||||||
N_FFT = 1024
|
# it should make absolutely no difference on quality,
|
||||||
|
# only on the speed of the training process. If duration is larger,
|
||||||
|
# model is going to be trained on more data per example.
|
||||||
|
# If smaller, less data per example.
|
||||||
|
#
|
||||||
|
# YOU ONLY NEED TO CHANGE THIS IF YOU ARE CPU-BOUND WHEN
|
||||||
|
# LOADING THE DATA DURING TRAINING. INCREASE TO PLACE MORE LOAD
|
||||||
|
# ON THE GPU, REDUCE TO PUT MORE LOAD ON THE CPU. DO NOT ADJUST
|
||||||
|
# THE BATCH SIZE, IT WILL MAKE NO DIFFERENCE, SINCE WE ARE ALWAYS
|
||||||
|
# FORCED TO LOAD THE ENTIRE EXAMPLE FROM DISK EVERY SINGLE TIME.
|
||||||
|
DURATION = 2
|
||||||
|
BATCH_SIZE = 4
|
||||||
|
# 100 is a bit ridiculous, but you are free to Ctrl-C anytime, since the checkpoints are always saved.
|
||||||
|
EPOCHS = 100
|
||||||
|
PREFETCH = 4
|
||||||
|
|
||||||
|
# stats = torch.load("freq_stats.pth")
|
||||||
|
# freq_mean = stats["mean"].numpy() # (513,)
|
||||||
|
# freq_std = stats["std"].numpy() # (513,)
|
||||||
|
|
||||||
|
freq_mean = np.zeros([N_FFT // 2 + 1])
|
||||||
|
# (513,)
|
||||||
|
freq_std = np.ones([N_FFT // 2 + 1]) # (513,)
|
||||||
|
|
||||||
|
|
||||||
def audio_to_logmag(audio):
|
# freq_mean_torch = stats["mean"] # (513,)
|
||||||
# STFT
|
# freq_std_torch = stats["std"] # (513,)
|
||||||
stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
|
freq_mean_torch = torch.from_numpy(freq_mean)
|
||||||
mag = np.abs(stft)
|
freq_std_torch = torch.from_numpy(freq_std)
|
||||||
logmag = np.log1p(mag) # log(1 + x) for stability
|
|
||||||
return logmag # shape: (freq_bins, time_frames) = (513, T)
|
|
||||||
|
|
||||||
|
|
||||||
class CavemanEnhancer(nn.Module):
|
def run_example(model_filename, device):
|
||||||
def __init__(self, freq_bins=513):
|
model = CavemanEnhancer().to(device)
|
||||||
super().__init__()
|
model.load_state_dict(
|
||||||
self.net = nn.Sequential(
|
torch.load(model_filename, weights_only=False)["model_state_dict"]
|
||||||
nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Conv2d(32, 32, kernel_size=3, padding=1),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Conv2d(32, 1, kernel_size=3, padding=1),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
model.eval()
|
||||||
# x: (batch, freq_bins)
|
enhance_mono(
|
||||||
return self.net(x)
|
model,
|
||||||
|
"./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
|
||||||
|
"./examples/mirror_mirror/mirror_mirror_decompressed_64.wav",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load
|
||||||
|
x, sr = librosa.load(
|
||||||
|
"./examples/mirror_mirror/mirror_mirror.mp3",
|
||||||
|
sr=SR,
|
||||||
|
)
|
||||||
|
|
||||||
BATCH_SIZE = 4
|
# Convert to log-mag
|
||||||
EPOCHS = 100
|
X = audio_to_logmag(x) # (513, T)
|
||||||
|
|
||||||
|
Y_pred = normalize(X)
|
||||||
|
Y_pred = denorm(Y_pred)
|
||||||
|
|
||||||
|
# Clip to valid range (log1p output ≥ 0)
|
||||||
|
Y_pred = np.maximum(X, 0)
|
||||||
|
|
||||||
|
stft = librosa.stft(x, n_fft=N_FFT, hop_length=HOP)
|
||||||
|
# Invert log
|
||||||
|
mag_pred = np.expm1(Y_pred) # inverse of log1p
|
||||||
|
phase_lossy = np.angle(stft)
|
||||||
|
|
||||||
|
# Reconstruct with Griffin-Lim
|
||||||
|
# y_reconstructed = librosa.griffinlim(
|
||||||
|
# mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
|
||||||
|
# )
|
||||||
|
|
||||||
|
# Combine: enhanced mag + original phase
|
||||||
|
stft_enhanced = mag_pred * np.exp(1j * phase_lossy)
|
||||||
|
y_reconstructed = librosa.istft(stft_enhanced, n_fft=N_FFT, hop_length=HOP)
|
||||||
|
|
||||||
|
time = np.minimum(x.shape[0], y_reconstructed.shape[0])
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Loss from reconstruction: {nn.MSELoss()(torch.from_numpy(x[:time]), torch.from_numpy(y_reconstructed[:time]))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save
|
||||||
|
sf.write(
|
||||||
|
"./examples/mirror_mirror/mirror_mirror_STFT.mp3",
|
||||||
|
y_reconstructed,
|
||||||
|
sr,
|
||||||
|
)
|
||||||
|
|
||||||
|
show_spectrogram(
|
||||||
|
"./examples/mirror_mirror/mirror_mirror_STFT.mp3",
|
||||||
|
"./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
|
||||||
|
"./examples/mirror_mirror/mirror_mirror_decompressed_64.wav",
|
||||||
|
)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -58,94 +125,96 @@ def main():
|
|||||||
if ans != "y":
|
if ans != "y":
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
model = CavemanEnhancer().to(device)
|
|
||||||
if len(sys.argv) > 1:
|
if len(sys.argv) > 1:
|
||||||
model.load_state_dict(
|
run_example(sys.argv[1], device)
|
||||||
torch.load(sys.argv[1], weights_only=False)["model_state_dict"]
|
exit(0)
|
||||||
)
|
|
||||||
model.eval()
|
|
||||||
enhance_audio(
|
|
||||||
model,
|
|
||||||
"./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
|
|
||||||
"./examples/mirror_mirror/mirror_mirror_decompressed_64_mse.wav",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load
|
|
||||||
x, sr = librosa.load("./examples/mirror_mirror/mirror_mirror_compressed_64.mp3", sr=SR)
|
|
||||||
|
|
||||||
# Convert to log-mag
|
|
||||||
X = audio_to_logmag(x) # (513, T)
|
|
||||||
|
|
||||||
# Clip to valid range (log1p output ≥ 0)
|
|
||||||
Y_pred = np.maximum(X, 0)
|
|
||||||
|
|
||||||
# Invert log
|
|
||||||
mag_pred = np.expm1(Y_pred) # inverse of log1p
|
|
||||||
|
|
||||||
# Reconstruct with Griffin-Lim
|
|
||||||
y_reconstructed = librosa.griffinlim(
|
|
||||||
mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
|
|
||||||
)
|
|
||||||
|
|
||||||
import soundfile as sf
|
|
||||||
|
|
||||||
# Save
|
|
||||||
sf.write("./examples/mirror_mirror/mirror_mirror_compressed_64_STFT.mp3", y_reconstructed, sr)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Data
|
# Data
|
||||||
dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, sr=SR)
|
dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, DURATION, sr=SR)
|
||||||
|
dataset.mean = freq_mean
|
||||||
|
dataset.std = freq_std
|
||||||
|
|
||||||
|
# separate the train and val data
|
||||||
n_val = int(0.1 * len(dataset))
|
n_val = int(0.1 * len(dataset))
|
||||||
n_train = len(dataset) - n_val
|
n_train = len(dataset) - n_val
|
||||||
|
|
||||||
train_indices = list(range(n_train))
|
train_indices = list(range(n_train))
|
||||||
val_indices = list(range(n_train, len(dataset)))
|
val_indices = list(range(n_train, len(dataset)))
|
||||||
|
|
||||||
train_dataset = torch.utils.data.Subset(dataset, train_indices)
|
train_dataset = torch.utils.data.Subset(dataset, train_indices)
|
||||||
val_dataset = torch.utils.data.Subset(dataset, val_indices)
|
val_dataset = torch.utils.data.Subset(dataset, val_indices)
|
||||||
|
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
|
prefetch_factor=PREFETCH,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
num_workers=16,
|
num_workers=16,
|
||||||
)
|
)
|
||||||
val_loader = DataLoader(
|
val_loader = DataLoader(
|
||||||
val_dataset,
|
val_dataset,
|
||||||
batch_size=4,
|
batch_size=BATCH_SIZE,
|
||||||
|
prefetch_factor=PREFETCH,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
num_workers=10,
|
num_workers=16,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
# model
|
||||||
|
model = CavemanEnhancer().to(device)
|
||||||
|
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||||
|
|
||||||
|
# I am actually not sure it really improves anything, but there is little reason not to keep this, I guess.
|
||||||
scheduler = lr_scheduler.ReduceLROnPlateau(
|
scheduler = lr_scheduler.ReduceLROnPlateau(
|
||||||
optimizer,
|
optimizer,
|
||||||
mode="min",
|
mode="min",
|
||||||
factor=0.1,
|
factor=0.1,
|
||||||
patience=5,
|
patience=3,
|
||||||
cooldown=3,
|
# cooldown=3,
|
||||||
threshold=1e-3,
|
# threshold=1e-3,
|
||||||
)
|
)
|
||||||
|
|
||||||
from tqdm import tqdm
|
# Weight: emphasize high frequencies
|
||||||
|
weight = torch.linspace(1.0, 8.0, 513).to(device) # low=1x, high=8x
|
||||||
|
weight = torch.exp(weight)
|
||||||
|
weight = weight.view(1, 513, 1)
|
||||||
|
|
||||||
criterion = nn.L1Loss()
|
def weighted_l1_loss(pred, target):
|
||||||
|
return torch.mean(weight * torch.abs(pred - target))
|
||||||
|
|
||||||
|
# criterion = nn.L1Loss()
|
||||||
|
criterion = weighted_l1_loss
|
||||||
|
# criterion = nn.MSELoss()
|
||||||
|
|
||||||
|
# baseline (doing nothing)
|
||||||
|
# if True:
|
||||||
|
# loss = 0.0
|
||||||
|
# for lossy, clean in tqdm(train_loader, desc="Baseline"):
|
||||||
|
# loss += criterion(clean, lossy)
|
||||||
|
# loss /= len(train_loader)
|
||||||
|
# print(f"baseling loss: {loss:.4f}")
|
||||||
|
|
||||||
# Train
|
# Train
|
||||||
for epoch in range(EPOCHS):
|
for epoch in range(EPOCHS):
|
||||||
model.train()
|
model.train()
|
||||||
|
train_loss = 0.0
|
||||||
for lossy, clean in tqdm(train_loader, desc="Training"):
|
for lossy, clean in tqdm(train_loader, desc="Training"):
|
||||||
lossy, clean = lossy.to(device), clean.to(device)
|
lossy, clean = lossy.to(device), clean.to(device)
|
||||||
|
|
||||||
with autocast():
|
with autocast():
|
||||||
enhanced = model(lossy)
|
enhanced = model(lossy)
|
||||||
loss = criterion(clean, enhanced)
|
loss = criterion(clean, enhanced)
|
||||||
|
train_loss += loss
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
train_loss /= len(train_loader)
|
||||||
|
|
||||||
|
# Validate (per epoch)
|
||||||
model.eval()
|
model.eval()
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
val_loss = 0
|
val_loss = 0
|
||||||
@@ -154,17 +223,17 @@ def main():
|
|||||||
lossy, clean = lossy.to(device), clean.to(device)
|
lossy, clean = lossy.to(device), clean.to(device)
|
||||||
output = model(lossy)
|
output = model(lossy)
|
||||||
loss_ = criterion(output, clean)
|
loss_ = criterion(output, clean)
|
||||||
|
|
||||||
total_loss += loss_.item()
|
total_loss += loss_.item()
|
||||||
val_loss = total_loss / len(train_loader)
|
val_loss = total_loss / len(train_loader)
|
||||||
|
|
||||||
scheduler.step(val_loss) # Update learning rate based on validation loss
|
scheduler.step(val_loss) # Update learning rate based on validation loss
|
||||||
|
|
||||||
if (epoch + 1) % 10 == 0:
|
|
||||||
lr = optimizer.param_groups[0]["lr"]
|
lr = optimizer.param_groups[0]["lr"]
|
||||||
print(f"LR: {lr:.6f}")
|
print(f"LR: {lr:.6f}")
|
||||||
|
print(f"Epoch {epoch + 1}, Loss: {train_loss:.4f}, Val: {val_loss:.4f}")
|
||||||
|
|
||||||
print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}, Val: {val_loss:.4f}")
|
# Yes, we are saving checkpoints for every epoch. The model is small, and disk space cheap.
|
||||||
|
|
||||||
torch.save(
|
torch.save(
|
||||||
{
|
{
|
||||||
"epoch": epoch,
|
"epoch": epoch,
|
||||||
@@ -176,12 +245,86 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def enhance_audio(model, lossy_path, output_path):
|
def file_to_logmag(path):
|
||||||
# Load
|
y, sr = librosa.load(path, sr=SR, mono=True)
|
||||||
x, sr = librosa.load(lossy_path, sr=SR)
|
print(y.shape)
|
||||||
|
return np.squeeze(audio_to_logmag(y))
|
||||||
|
|
||||||
|
|
||||||
|
def show_spectrogram(path1, path2, path3):
    """Show the normalized log-magnitude spectrograms of three files side by side.

    Each path is turned into a spectrogram with ``file_to_logmag`` and then
    passed through ``normalize``. The three results are drawn as grayscale
    images in a single row of three panels titled "spectrogram 1" to
    "spectrogram 3", in argument order. The call blocks in ``plt.show()``.

    Args:
        path1: Path of the audio file drawn in the left panel.
        path2: Path of the audio file drawn in the middle panel.
        path3: Path of the audio file drawn in the right panel.
    """
    # Load all three files first, then normalize, matching the original order
    # of side effects (file_to_logmag prints the shape of each loaded signal).
    raw_specs = [file_to_logmag(p) for p in (path1, path2, path3)]
    panels = [normalize(spec) for spec in raw_specs]

    # Imported lazily so matplotlib is only required when plotting.
    from matplotlib import pyplot as plt

    # One row, three columns.
    _, axes = plt.subplots(1, 3, figsize=(10, 5))
    titles = ("spectrogram 1", "spectrogram 2", "spectrogram 3")

    for axis, panel, title in zip(axes, panels, titles):
        axis.imshow(panel, aspect="auto", cmap="gray")
        axis.set_title(title)
        axis.axis("off")  # hide ticks and frame

    plt.tight_layout()
    plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def enhance_stereo(model, lossy_path, output_path):
    """Enhance a stereo file channel by channel and write the result.

    The file is loaded at ``SR`` with its channels preserved. The first two
    channels are each passed through ``enhance_audio`` independently, stacked
    back into a two-row array and written to ``output_path``.

    Args:
        model: Model forwarded unchanged to ``enhance_audio``.
        lossy_path: Path of the stereo input file.
        output_path: Destination path for the enhanced audio.

    Raises:
        ValueError: If the loaded signal is one-dimensional (mono input).
    """
    # mono=False keeps the channel axis; a mono file comes back as a 1-D array.
    channels, sr = librosa.load(lossy_path, sr=SR, mono=False)
    if channels.ndim == 1:
        raise ValueError("Input is mono! Expected stereo.")

    # Only channels 0 (left) and 1 (right) are processed, left first.
    left, right = channels[0], channels[1]
    enhanced_left = enhance_audio(model, left, sr)
    enhanced_right = enhance_audio(model, right, sr)

    stereo = np.vstack((enhanced_left, enhanced_right))  # (2, T)

    import soundfile as sf

    # Transposed to (T, 2) before writing.
    sf.write(output_path, stereo.T, sr)
|
||||||
|
|
||||||
|
|
||||||
|
def denorm_torch(spectrogram):
    """Map a normalized spectrogram tensor back to the log-magnitude scale.

    Multiplies by ``freq_std_torch + 1e-8`` and adds ``freq_mean_torch``, both
    broadcast as column vectors, so the frequency axis of ``spectrogram`` must
    be second-to-last, e.g. shape (513, T).
    """
    scale = freq_std_torch[:, None] + 1e-8
    offset = freq_mean_torch[:, None]
    return spectrogram * scale + offset
|
||||||
|
|
||||||
|
|
||||||
|
def normalize(spectrogram):
    """Standardize a log-magnitude spectrogram per frequency bin.

    Subtracts ``freq_mean`` and divides by ``freq_std + 1e-8``, both broadcast
    as column vectors, so the frequency axis of ``spectrogram`` must be
    second-to-last, e.g. shape (513, T).
    """
    centered = spectrogram - freq_mean[:, None]
    scale = freq_std[:, None] + 1e-8  # epsilon guards against a zero std
    return centered / scale
|
||||||
|
|
||||||
|
|
||||||
|
def denorm(spectrogram):
    """Map a normalized spectrogram array back to the log-magnitude scale.

    Multiplies by ``freq_std + 1e-8`` and adds ``freq_mean``, both broadcast
    as column vectors, so the frequency axis of ``spectrogram`` must be
    second-to-last, e.g. shape (513, T).
    """
    scale = freq_std[:, None] + 1e-8
    offset = freq_mean[:, None]
    return spectrogram * scale + offset
|
||||||
|
|
||||||
|
|
||||||
|
def enhance_audio(model, audio, sr):
|
||||||
# Convert to log-mag
|
# Convert to log-mag
|
||||||
X = audio_to_logmag(x) # (513, T)
|
X = audio_to_logmag(audio) # (513, T)
|
||||||
|
stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
|
||||||
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
X_tensor = torch.tensor(X, dtype=torch.float32).unsqueeze(0).to(device) # (T, 513)
|
X_tensor = torch.tensor(X, dtype=torch.float32).unsqueeze(0).to(device) # (T, 513)
|
||||||
@@ -189,16 +332,26 @@ def enhance_audio(model, lossy_path, output_path):
|
|||||||
Y_pred = model(X_tensor).cpu().numpy() # (1, T, 513)
|
Y_pred = model(X_tensor).cpu().numpy() # (1, T, 513)
|
||||||
Y_pred = Y_pred.squeeze(0) # back to (T, 513)
|
Y_pred = Y_pred.squeeze(0) # back to (T, 513)
|
||||||
|
|
||||||
|
Y_pred = denorm(Y_pred)
|
||||||
|
|
||||||
# Clip to valid range (log1p output ≥ 0)
|
# Clip to valid range (log1p output ≥ 0)
|
||||||
Y_pred = np.maximum(Y_pred, 0)
|
Y_pred = np.maximum(Y_pred, 0)
|
||||||
|
|
||||||
# Invert log
|
# Invert log
|
||||||
mag_pred = np.expm1(Y_pred) # inverse of log1p
|
mag_pred = np.expm1(Y_pred) # inverse of log1p
|
||||||
|
phase_lossy = np.angle(stft)
|
||||||
|
|
||||||
# Reconstruct with Griffin-Lim
|
# Combine: enhanced mag + original phase
|
||||||
y_reconstructed = librosa.griffinlim(
|
stft_enhanced = mag_pred * np.exp(1j * phase_lossy)
|
||||||
mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
|
y_reconstructed = librosa.istft(stft_enhanced, n_fft=N_FFT, hop_length=HOP)
|
||||||
)
|
return y_reconstructed
|
||||||
|
|
||||||
|
|
||||||
|
def enhance_mono(model, lossy_path, output_path):
|
||||||
|
# Load
|
||||||
|
x, sr = librosa.load(lossy_path, sr=SR)
|
||||||
|
|
||||||
|
y_reconstructed = enhance_audio(model, x, sr)
|
||||||
|
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
|
|||||||
89
make_stats.py
Normal file
89
make_stats.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import random
|
||||||
|
import warnings
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch
|
||||||
|
from caveman_wavedataset import WaveformDataset
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.cuda.amp import autocast, GradScaler
|
||||||
|
import torch.optim.lr_scheduler as lr_scheduler
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
CLEAN_DATA_DIR = "./fma_small"
|
||||||
|
LOSSY_DATA_DIR = "./fma_small_compressed_64/"
|
||||||
|
SR = 44100 # sample rate
|
||||||
|
DURATION = 2.0 # seconds per clip
|
||||||
|
N_MELS = None # we'll use full STFT for now
|
||||||
|
HOP = 512
|
||||||
|
N_FFT = 1024
|
||||||
|
BATCH_SIZE = 16
|
||||||
|
EPOCHS = 100
|
||||||
|
PREFETCH = 4
|
||||||
|
|
||||||
|
# Data
|
||||||
|
# WaveformDataset.__init__ takes segment_duration as its third positional
# argument and gives it no default, so the clip length (DURATION, in seconds)
# has to be passed explicitly; omitting it raises TypeError at import time.
dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, DURATION, sr=SR)
|
||||||
|
|
||||||
|
loader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=BATCH_SIZE,
|
||||||
|
prefetch_factor=PREFETCH,
|
||||||
|
shuffle=False,
|
||||||
|
pin_memory=True,
|
||||||
|
num_workers=16,
|
||||||
|
)
|
||||||
|
|
||||||
|
stats = torch.load("freq_stats.pth")
|
||||||
|
freq_mean = stats["mean"].numpy() # (513,)
|
||||||
|
freq_std = stats["std"].numpy() # (513,)
|
||||||
|
|
||||||
|
freq_mean_torch = stats["mean"] # (513,)
|
||||||
|
freq_std_torch = stats["std"] # (513,)
|
||||||
|
|
||||||
|
dataset.mean = freq_mean
|
||||||
|
dataset.std = freq_std
|
||||||
|
|
||||||
|
|
||||||
|
def compute_per_freq_stats_vectorized(dataloader, freq_bins=513, device="cuda"):
    """Compute the per-frequency-bin mean and standard deviation in one pass.

    Batches are folded into a running mean and a running sum of squared
    deviations, so the whole dataset never has to be held in memory. Every
    time frame of every example in a batch counts as one sample for each
    frequency bin.

    Args:
        dataloader: Iterable yielding pairs of tensors. Only the second
            element of each pair is used, and it must be 4-D with a size-1
            second axis, i.e. (B, 1, F, T).
            NOTE(review): the original loop named this element ``lossy``, but
            the dataset's ``__getitem__`` appears to return (lossy, clean);
            confirm which signal the stats are meant to describe.
        freq_bins: Number of frequency bins F; sizes the accumulators.
        device: Device on which the accumulation runs.

    Returns:
        Tuple ``(mean, std)`` of CPU tensors with shape ``(freq_bins,)``.
        ``std`` uses the unbiased (n - 1) estimator and is all ones when
        fewer than two frames were seen.
    """
    running_mean = torch.zeros(freq_bins, device=device)
    sum_sq_dev = torch.zeros(freq_bins, device=device)
    frames_seen = 0

    with torch.no_grad():
        for _, batch in tqdm(dataloader, desc="Stats"):
            spec = batch.to(device)  # (B, 1, F, T)
            _, _, n_freq, _ = spec.shape

            # Bring the frequency axis to the front and merge batch and time
            # into one sample axis: (B, 1, F, T) -> (F, B * T).
            columns = spec.squeeze(1).permute(1, 0, 2).reshape(n_freq, -1)
            n_new = columns.shape[1]
            if n_new == 0:
                continue

            prev_mean = running_mean.clone()

            # mean <- mean + (sum(x) - n * mean) / (seen + n)
            batch_sum = columns.sum(dim=1)
            running_mean = prev_mean + (batch_sum - n_new * prev_mean) / (
                frames_seen + n_new
            )

            # M2 <- M2 + sum((x - old_mean) * (x - new_mean))
            dev_from_prev = columns - prev_mean.unsqueeze(1)  # (F, N)
            dev_from_curr = columns - running_mean.unsqueeze(1)  # (F, N)
            sum_sq_dev += (dev_from_prev * dev_from_curr).sum(dim=1)  # (F,)

            frames_seen += n_new

    if frames_seen > 1:
        std = torch.sqrt(sum_sq_dev / (frames_seen - 1))
    else:
        std = torch.ones_like(running_mean)

    return running_mean.cpu(), std.cpu()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# freq_mean, freq_std = compute_per_freq_stats_vectorized(loader, device="cuda")
|
||||||
|
# torch.save({"mean": freq_mean, "std": freq_std}, "freq_stats.pth")
|
||||||
|
print(compute_per_freq_stats_vectorized(loader))
|
||||||
11
misc.py
Normal file
11
misc.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
import numpy as np
|
||||||
|
import librosa
|
||||||
|
from settings import N_FFT, HOP
|
||||||
|
|
||||||
|
|
||||||
|
def audio_to_logmag(audio):
    """Return the log-magnitude spectrogram ``log1p(|STFT(audio)|)``.

    The STFT uses ``N_FFT`` and ``HOP`` from ``settings``. For a 1-D signal
    the result has shape (freq_bins, time_frames), i.e. (513, T) with
    ``N_FFT = 1024``.
    """
    spectrum = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
    magnitude = np.abs(spectrum)
    # log(1 + x) keeps silent bins finite (log1p(0) == 0).
    return np.log1p(magnitude)
|
||||||
31
model.py
Normal file
31
model.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
INPUT_KERNEL = 15  # side length of the first (square) convolution kernel
SIZE = 32  # number of feature channels in every hidden layer


class CavemanEnhancer(nn.Module):
    """Plain stack of six 2-D convolutions with ReLU between them.

    Every convolution is padded by half its kernel size, so the two trailing
    (spatial) axes of the input are preserved. The stack takes one input
    channel and produces one output channel.
    """

    def __init__(self, freq_bins=513):
        """Build the layer stack.

        Args:
            freq_bins: Accepted for interface compatibility; none of the
                layers built here depend on it.
        """
        super().__init__()

        # Wide input layer: 1 -> SIZE channels.
        layers = [
            nn.Conv2d(
                in_channels=1,
                out_channels=SIZE,
                kernel_size=INPUT_KERNEL,
                padding=INPUT_KERNEL // 2,
            ),
            nn.ReLU(),
        ]

        # Four identical hidden 3x3 layers: SIZE -> SIZE channels.
        for _ in range(4):
            layers.append(nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1))
            layers.append(nn.ReLU())

        # Output projection back to a single channel, with no activation.
        layers.append(nn.Conv2d(SIZE, 1, kernel_size=3, padding=1))

        # Module order (and therefore the state_dict keys net.0 .. net.10)
        # matches the original hand-written Sequential.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stack to ``x``, a batch of single-channel 2-D maps.

        ``x`` is expected as (batch, 1, H, W); the output has the same shape.
        """
        return self.net(x)
|
||||||
3
settings.py
Normal file
3
settings.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
SR = 44100 # sample rate
|
||||||
|
HOP = 512
|
||||||
|
N_FFT = 1024
|
||||||
Reference in New Issue
Block a user