added new fixed code

2026-01-10 20:35:21 +01:00
parent 79f93e5c29
commit 8379ac8e12
6 changed files with 396 additions and 110 deletions

caveman_wavedataset.py
View File

@@ -4,29 +4,21 @@ import librosa
 from torch.utils.data import Dataset
 import numpy as np
 import random
-HOP = 512
-N_FFT = 1024
-DURATION = 2.0
-SR = 44100
-def audio_to_logmag(audio):
-    # STFT
-    stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
-    mag = np.abs(stft)
-    logmag = np.log1p(mag)  # log(1 + x) for stability
-    return logmag  # shape: (1, freq_bins, time_frames) = (1, 513, T)
+from settings import SR, N_FFT
+from misc import audio_to_logmag
 class WaveformDataset(Dataset):
-    def __init__(self, lossy_dir, clean_dir, sr=SR, segment_sec=4):
-        self.cache = dict()
+    mean = np.zeros([N_FFT // 2 + 1])
+    std = np.ones([N_FFT // 2 + 1])
+
+    # Duration is a very important parameter; read cavemanml.py to see how and why to adjust it!
+    # For the purposes of this file, it is the length of the audio clip selected from the dataset.
+    def __init__(self, lossy_dir, clean_dir, segment_duration, sr=SR):
+        self.segment_duration = segment_duration
         self.sr = sr
         self.lossy_dir = lossy_dir
         self.clean_dir = clean_dir
-        self.segment_len = int(segment_sec * sr)
         self.lossy_files = sorted(os.listdir(lossy_dir))
         self.clean_files = sorted(os.listdir(clean_dir))
         self.file_pairs = [
@@ -51,9 +43,9 @@ class WaveformDataset(Dataset):
         min_len = min(len(lossy), len(clean))
         lossy, clean = lossy[:min_len], clean[:min_len]
-        # Random 2-second clip
-        clip_len = int(DURATION * SR)
+        # Random clip
+        clip_len = int(self.segment_duration * SR)
         if min_len < clip_len:
             # pad if too short
             lossy = np.pad(lossy, (0, clip_len - min_len))
@@ -61,14 +53,21 @@ class WaveformDataset(Dataset):
             start = 0
         else:
             start = random.randint(0, min_len - clip_len)
+            # start = 0
         lossy = lossy[start : start + clip_len]
         clean = clean[start : start + clip_len]
+        logmag_x = audio_to_logmag(lossy)
+        logmag_y = audio_to_logmag(clean)
+        logmag_x_norm = (logmag_x - self.mean[:, None]) / (self.std[:, None] + 1e-8)
+        logmag_y_norm = (logmag_y - self.mean[:, None]) / (self.std[:, None] + 1e-8)
         ans = (
-            torch.from_numpy(audio_to_logmag(lossy)).unsqueeze(0),
-            torch.from_numpy(audio_to_logmag(clean)).unsqueeze(0),
+            torch.from_numpy(logmag_x_norm).float().unsqueeze(0),
+            torch.from_numpy(logmag_y_norm).float().unsqueeze(0),
         )
-        self.cache[idx] = ans
+        # self.cache[idx] = ans
         return ans
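For intuition, here is a minimal standalone sketch of the per-frequency normalization the dataset now applies. Array names and sizes are illustrative; the shapes follow the diff: (513, T) log-magnitude spectrograms and (513,) per-bin stats.

    import numpy as np

    N_FFT = 1024
    freq_bins = N_FFT // 2 + 1  # 513

    logmag = np.random.rand(freq_bins, 100)  # (513, T) log-magnitude spectrogram
    mean = np.zeros(freq_bins)               # (513,) per-frequency mean
    std = np.ones(freq_bins)                 # (513,) per-frequency std

    # [:, None] reshapes (513,) to (513, 1) so the stats broadcast across all
    # T time frames; the 1e-8 guards against division by zero in silent bins.
    logmag_norm = (logmag - mean[:, None]) / (std[:, None] + 1e-8)
    assert logmag_norm.shape == (freq_bins, 100)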

cavemanml.py
View File

@@ -1,53 +1,120 @@
+import os
 import sys
+import random
 import warnings
 import librosa
 import numpy as np
 import torch.nn as nn
 import torch
+import soundfile as sf
 from caveman_wavedataset import WaveformDataset
 from torch.utils.data import DataLoader
-from torch.cuda.amp import autocast, GradScaler
+from torch.cuda.amp import autocast
 import torch.optim.lr_scheduler as lr_scheduler
+from tqdm import tqdm
+from misc import audio_to_logmag
+from settings import N_FFT, HOP, SR
+from model import CavemanEnhancer
 warnings.filterwarnings("ignore")
 CLEAN_DATA_DIR = "./fma_small"
 LOSSY_DATA_DIR = "./fma_small_compressed_64/"
-SR = 44100  # sample rate
-DURATION = 2.0  # seconds per clip
-N_MELS = None  # we'll use full STFT for now
-HOP = 512
-N_FFT = 1024
-def audio_to_logmag(audio):
-    # STFT
-    stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
-    mag = np.abs(stft)
-    logmag = np.log1p(mag)  # log(1 + x) for stability
-    return logmag  # shape: (freq_bins, time_frames) = (513, T)
-class CavemanEnhancer(nn.Module):
-    def __init__(self, freq_bins=513):
-        super().__init__()
-        self.net = nn.Sequential(
-            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2),
-            nn.ReLU(),
-            nn.Conv2d(32, 32, kernel_size=3, padding=1),
-            nn.ReLU(),
-            nn.Conv2d(32, 1, kernel_size=3, padding=1),
-        )
-    def forward(self, x):
-        # x: (batch, freq_bins)
-        return self.net(x)
+# DURATION is the duration, in seconds, of each example selected from the dataset.
+# Since we are using the FMA dataset, its max value is 30 s.
+# From the standpoint of the model and training it should make absolutely no
+# difference to quality, only to the speed of the training process: if the
+# duration is larger, the model is trained on more data per example; if
+# smaller, less data per example.
+#
+# YOU ONLY NEED TO CHANGE THIS IF YOU ARE CPU-BOUND WHEN LOADING THE DATA
+# DURING TRAINING. INCREASE IT TO PLACE MORE LOAD ON THE GPU, REDUCE IT TO
+# PUT MORE LOAD ON THE CPU. DO NOT ADJUST THE BATCH SIZE; IT WILL MAKE NO
+# DIFFERENCE, SINCE WE ARE ALWAYS FORCED TO LOAD THE ENTIRE EXAMPLE FROM
+# DISK EVERY SINGLE TIME.
+DURATION = 2
 BATCH_SIZE = 4
+# 100 is a bit ridiculous, but you are free to Ctrl-C anytime, since the checkpoints are always saved.
 EPOCHS = 100
+PREFETCH = 4
+# stats = torch.load("freq_stats.pth")
+# freq_mean = stats["mean"].numpy()  # (513,)
+# freq_std = stats["std"].numpy()  # (513,)
+freq_mean = np.zeros([N_FFT // 2 + 1])  # (513,)
+freq_std = np.ones([N_FFT // 2 + 1])  # (513,)
+# freq_mean_torch = stats["mean"]  # (513,)
+# freq_std_torch = stats["std"]  # (513,)
+freq_mean_torch = torch.from_numpy(freq_mean)
+freq_std_torch = torch.from_numpy(freq_std)
+def run_example(model_filename, device):
+    model = CavemanEnhancer().to(device)
+    model.load_state_dict(
+        torch.load(model_filename, weights_only=False)["model_state_dict"]
+    )
+    model.eval()
+    enhance_mono(
+        model,
+        "./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
+        "./examples/mirror_mirror/mirror_mirror_decompressed_64.wav",
+    )
+    # Load
+    x, sr = librosa.load(
+        "./examples/mirror_mirror/mirror_mirror.mp3",
+        sr=SR,
+    )
+    # Convert to log-mag
+    X = audio_to_logmag(x)  # (513, T)
+    Y_pred = normalize(X)
+    Y_pred = denorm(Y_pred)
+    # Clip to valid range (log1p output ≥ 0)
+    Y_pred = np.maximum(X, 0)
+    stft = librosa.stft(x, n_fft=N_FFT, hop_length=HOP)
+    # Invert log
+    mag_pred = np.expm1(Y_pred)  # inverse of log1p
+    phase_lossy = np.angle(stft)
+    # Reconstruct with Griffin-Lim
+    # y_reconstructed = librosa.griffinlim(
+    #     mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
+    # )
+    # Combine: enhanced mag + original phase
+    stft_enhanced = mag_pred * np.exp(1j * phase_lossy)
+    y_reconstructed = librosa.istft(stft_enhanced, n_fft=N_FFT, hop_length=HOP)
+    time = np.minimum(x.shape[0], y_reconstructed.shape[0])
+    print(
+        f"Loss from reconstruction: {nn.MSELoss()(torch.from_numpy(x[:time]), torch.from_numpy(y_reconstructed[:time]))}"
+    )
+    # Save
+    sf.write(
+        "./examples/mirror_mirror/mirror_mirror_STFT.mp3",
+        y_reconstructed,
+        sr,
+    )
+    show_spectrogram(
+        "./examples/mirror_mirror/mirror_mirror_STFT.mp3",
+        "./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
+        "./examples/mirror_mirror/mirror_mirror_decompressed_64.wav",
+    )
+    return
 def main():
@@ -58,94 +125,96 @@ def main():
     if ans != "y":
         exit(1)
-    model = CavemanEnhancer().to(device)
     if len(sys.argv) > 1:
-        model.load_state_dict(
-            torch.load(sys.argv[1], weights_only=False)["model_state_dict"]
-        )
-        model.eval()
-        enhance_audio(
-            model,
-            "./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
-            "./examples/mirror_mirror/mirror_mirror_decompressed_64_mse.wav",
-        )
-        # Load
-        x, sr = librosa.load("./examples/mirror_mirror/mirror_mirror_compressed_64.mp3", sr=SR)
-        # Convert to log-mag
-        X = audio_to_logmag(x)  # (513, T)
-        # Clip to valid range (log1p output ≥ 0)
-        Y_pred = np.maximum(X, 0)
-        # Invert log
-        mag_pred = np.expm1(Y_pred)  # inverse of log1p
-        # Reconstruct with Griffin-Lim
-        y_reconstructed = librosa.griffinlim(
-            mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
-        )
-        import soundfile as sf
-        # Save
-        sf.write("./examples/mirror_mirror/mirror_mirror_compressed_64_STFT.mp3", y_reconstructed, sr)
-        return
+        run_example(sys.argv[1], device)
+        exit(0)
     # Data
-    dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, sr=SR)
+    dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, DURATION, sr=SR)
+    dataset.mean = freq_mean
+    dataset.std = freq_std
+    # separate the train and val data
     n_val = int(0.1 * len(dataset))
     n_train = len(dataset) - n_val
     train_indices = list(range(n_train))
     val_indices = list(range(n_train, len(dataset)))
     train_dataset = torch.utils.data.Subset(dataset, train_indices)
     val_dataset = torch.utils.data.Subset(dataset, val_indices)
     train_loader = DataLoader(
         train_dataset,
         batch_size=BATCH_SIZE,
+        prefetch_factor=PREFETCH,
         shuffle=True,
         pin_memory=True,
         num_workers=16,
     )
     val_loader = DataLoader(
         val_dataset,
-        batch_size=4,
+        batch_size=BATCH_SIZE,
+        prefetch_factor=PREFETCH,
         shuffle=False,
-        num_workers=10,
+        num_workers=16,
         pin_memory=True,
     )
-    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+    # model
+    model = CavemanEnhancer().to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+    # I am actually not sure it really improves anything, but there is little reason not to keep this, I guess.
     scheduler = lr_scheduler.ReduceLROnPlateau(
         optimizer,
         mode="min",
         factor=0.1,
-        patience=5,
-        cooldown=3,
-        threshold=1e-3,
+        patience=3,
+        # cooldown=3,
+        # threshold=1e-3,
     )
-    from tqdm import tqdm
-    criterion = nn.L1Loss()
+    # Weight: emphasize high frequencies
+    weight = torch.linspace(1.0, 8.0, 513).to(device)
+    weight = torch.exp(weight)  # exponential ramp: ~e at the lowest bin up to ~e^8 (~2981x) at the highest
+    weight = weight.view(1, 513, 1)
+    def weighted_l1_loss(pred, target):
+        return torch.mean(weight * torch.abs(pred - target))
+    # criterion = nn.L1Loss()
+    criterion = weighted_l1_loss
+    # criterion = nn.MSELoss()
+    # baseline (doing nothing)
+    # if True:
+    #     loss = 0.0
+    #     for lossy, clean in tqdm(train_loader, desc="Baseline"):
+    #         loss += criterion(clean, lossy)
+    #     loss /= len(train_loader)
+    #     print(f"baseline loss: {loss:.4f}")
     # Train
     for epoch in range(EPOCHS):
         model.train()
+        train_loss = 0.0
         for lossy, clean in tqdm(train_loader, desc="Training"):
             lossy, clean = lossy.to(device), clean.to(device)
             with autocast():
                 enhanced = model(lossy)
                 loss = criterion(clean, enhanced)
+            train_loss += loss.item()
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
+        train_loss /= len(train_loader)
+        # Validate (per epoch)
         model.eval()
         total_loss = 0.0
         val_loss = 0
@@ -154,17 +223,17 @@ def main():
             lossy, clean = lossy.to(device), clean.to(device)
             output = model(lossy)
             loss_ = criterion(output, clean)
             total_loss += loss_.item()
         val_loss = total_loss / len(val_loader)
         scheduler.step(val_loss)  # Update learning rate based on validation loss
-        if (epoch + 1) % 10 == 0:
-            lr = optimizer.param_groups[0]["lr"]
-            print(f"LR: {lr:.6f}")
-        print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}, Val: {val_loss:.4f}")
+        lr = optimizer.param_groups[0]["lr"]
+        print(f"LR: {lr:.6f}")
+        print(f"Epoch {epoch + 1}, Loss: {train_loss:.4f}, Val: {val_loss:.4f}")
+        # Yes, we are saving checkpoints every epoch. The model is small, and disk space is cheap.
         torch.save(
             {
                 "epoch": epoch,
@@ -176,12 +245,86 @@ def main():
     )
-def enhance_audio(model, lossy_path, output_path):
-    # Load
-    x, sr = librosa.load(lossy_path, sr=SR)
+def file_to_logmag(path):
+    y, sr = librosa.load(path, sr=SR, mono=True)
+    print(y.shape)
+    return np.squeeze(audio_to_logmag(y))
+def show_spectrogram(path1, path2, path3):
+    spectrogram1 = file_to_logmag(path1)
+    spectrogram2 = file_to_logmag(path2)
+    spectrogram3 = file_to_logmag(path3)
+    spectrogram1 = normalize(spectrogram1)
+    spectrogram2 = normalize(spectrogram2)
+    spectrogram3 = normalize(spectrogram3)
+    from matplotlib import pyplot as plt
+    # Create a figure with three subplots (1 row, 3 columns)
+    fig, axes = plt.subplots(1, 3, figsize=(10, 5))
+    # Display the first spectrogram
+    axes[0].imshow(spectrogram1, aspect="auto", cmap="gray")
+    axes[0].set_title("spectrogram 1")
+    axes[0].axis("off")  # Hide axes
+    # Display the second spectrogram
+    axes[1].imshow(spectrogram2, aspect="auto", cmap="gray")
+    axes[1].set_title("spectrogram 2")
+    axes[1].axis("off")
+    # Display the third spectrogram
+    axes[2].imshow(spectrogram3, aspect="auto", cmap="gray")
+    axes[2].set_title("spectrogram 3")
+    axes[2].axis("off")
+    plt.tight_layout()
+    plt.show()
+def enhance_stereo(model, lossy_path, output_path):
+    # Load stereo audio (returns shape (2, T) if stereo)
+    y, sr = librosa.load(lossy_path, sr=SR, mono=False)  # mono=False preserves channels
+    # Ensure shape is (2, T)
+    if y.ndim == 1:
+        raise ValueError("Input is mono! Expected stereo.")
+    y_l = y[0]
+    y_r = y[1]
+    y_enhanced_l = enhance_audio(model, y_l, sr)
+    y_enhanced_r = enhance_audio(model, y_r, sr)
+    stereo_reconstructed = np.vstack((y_enhanced_l, y_enhanced_r))
+    import soundfile as sf
+    # Save (soundfile expects (T, 2) for stereo, hence the transpose)
+    sf.write(output_path, stereo_reconstructed.T, sr)
+# accepts shape (513, T)!
+def denorm_torch(spectrogram):
+    return spectrogram * (freq_std_torch[:, None] + 1e-8) + freq_mean_torch[:, None]
+# accepts shape (513, T)!
+def normalize(spectrogram):
+    return (spectrogram - freq_mean[:, None]) / (freq_std[:, None] + 1e-8)
+# accepts shape (513, T)!
+def denorm(spectrogram):
+    return spectrogram * (freq_std[:, None] + 1e-8) + freq_mean[:, None]
+def enhance_audio(model, audio, sr):
     # Convert to log-mag
-    X = audio_to_logmag(x)  # (513, T)
+    X = audio_to_logmag(audio)  # (513, T)
+    stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
     device = "cuda" if torch.cuda.is_available() else "cpu"
     X_tensor = torch.tensor(X, dtype=torch.float32).unsqueeze(0).to(device)  # (1, 513, T)
@@ -189,16 +332,26 @@ def enhance_audio(model, lossy_path, output_path):
     Y_pred = model(X_tensor).cpu().numpy()  # (1, 513, T)
     Y_pred = Y_pred.squeeze(0)  # back to (513, T)
+    Y_pred = denorm(Y_pred)
     # Clip to valid range (log1p output ≥ 0)
     Y_pred = np.maximum(Y_pred, 0)
     # Invert log
     mag_pred = np.expm1(Y_pred)  # inverse of log1p
+    phase_lossy = np.angle(stft)
-    # Reconstruct with Griffin-Lim
-    y_reconstructed = librosa.griffinlim(
-        mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
-    )
+    # Combine: enhanced mag + original phase
+    stft_enhanced = mag_pred * np.exp(1j * phase_lossy)
+    y_reconstructed = librosa.istft(stft_enhanced, n_fft=N_FFT, hop_length=HOP)
+    return y_reconstructed
+def enhance_mono(model, lossy_path, output_path):
+    # Load
+    x, sr = librosa.load(lossy_path, sr=SR)
+    y_reconstructed = enhance_audio(model, x, sr)
     import soundfile as sf
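One subtlety in the new weighted_l1_loss above: after torch.exp, the per-bin weights run from about e ≈ 2.7 to e^8 ≈ 2981, so the top bins dominate the loss by roughly three orders of magnitude rather than the 8x a linear ramp would give. A small self-contained sketch (tensor sizes are illustrative) of how the (1, 513, 1) weight broadcasts against the model's (batch, 1, 513, frames) output:

    import torch

    freq_bins = 513
    # Exponential ramp: weight e^1 at bin 0 up to e^8 at bin 512.
    weight = torch.exp(torch.linspace(1.0, 8.0, freq_bins)).view(1, freq_bins, 1)

    def weighted_l1_loss(pred, target):
        # weight (1, 513, 1) broadcasts against (batch, 1, 513, frames):
        # trailing dims align, so every frequency row gets its own weight.
        return torch.mean(weight * torch.abs(pred - target))

    pred = torch.zeros(2, 1, freq_bins, 10)
    target = torch.ones(2, 1, freq_bins, 10)
    # |pred - target| == 1 everywhere, so the loss equals the mean weight.
    print(torch.isclose(weighted_l1_loss(pred, target), weight.mean()))  # tensor(True)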

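The switch from Griffin-Lim to reusing the lossy signal's phase is the key change in enhance_audio. A minimal round-trip sketch of the idea, with a random signal standing in for real audio:

    import numpy as np
    import librosa

    SR, N_FFT, HOP = 44100, 1024, 512

    # Stand-in for a lossy input signal.
    x = np.random.randn(SR).astype(np.float32)

    stft = librosa.stft(x, n_fft=N_FFT, hop_length=HOP)  # complex, (513, T)
    mag, phase = np.abs(stft), np.angle(stft)

    # Pretend mag_enhanced came from the model; here it is just the input magnitude.
    mag_enhanced = mag

    # Recombine the (enhanced) magnitude with the original phase and invert.
    stft_enhanced = mag_enhanced * np.exp(1j * phase)
    y = librosa.istft(stft_enhanced, n_fft=N_FFT, hop_length=HOP)

    # With the magnitude unchanged this round-trips almost exactly,
    # unlike Griffin-Lim, which has to guess the phase iteratively.
    n = min(len(x), len(y))
    print(np.max(np.abs(x[:n] - y[:n])))  # ~1e-7, up to edge effects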
make_stats.py Normal file (+89 lines)
View File

@@ -0,0 +1,89 @@
+import os
+import sys
+import random
+import warnings
+import librosa
+import numpy as np
+import torch.nn as nn
+import torch
+from caveman_wavedataset import WaveformDataset
+from torch.utils.data import DataLoader
+from torch.cuda.amp import autocast, GradScaler
+import torch.optim.lr_scheduler as lr_scheduler
+from tqdm import tqdm
+
+CLEAN_DATA_DIR = "./fma_small"
+LOSSY_DATA_DIR = "./fma_small_compressed_64/"
+SR = 44100  # sample rate
+DURATION = 2.0  # seconds per clip
+N_MELS = None  # we'll use full STFT for now
+HOP = 512
+N_FFT = 1024
+BATCH_SIZE = 16
+EPOCHS = 100
+PREFETCH = 4
+
+# Data
+dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, DURATION, sr=SR)
+loader = DataLoader(
+    dataset,
+    batch_size=BATCH_SIZE,
+    prefetch_factor=PREFETCH,
+    shuffle=False,
+    pin_memory=True,
+    num_workers=16,
+)
+stats = torch.load("freq_stats.pth")
+freq_mean = stats["mean"].numpy()  # (513,)
+freq_std = stats["std"].numpy()  # (513,)
+freq_mean_torch = stats["mean"]  # (513,)
+freq_std_torch = stats["std"]  # (513,)
+dataset.mean = freq_mean
+dataset.std = freq_std
+
+def compute_per_freq_stats_vectorized(dataloader, freq_bins=513, device="cuda"):
+    mean = torch.zeros(freq_bins, device=device)
+    M2 = torch.zeros(freq_bins, device=device)
+    total_frames = 0
+    with torch.no_grad():
+        for _, lossy in tqdm(dataloader, desc="Stats"):
+            x = lossy.to(device)  # (B, 1, F, T)
+            B, _, F, T = x.shape
+            x_flat = x.squeeze(1).permute(1, 0, 2).reshape(F, -1)  # (F, N)
+            N = x_flat.shape[1]
+            if N == 0:
+                continue
+            # Current mean (broadcast)
+            mean_old = mean.clone()
+            # Update mean: mean = mean + (sum(x) - N*mean) / (total + N)
+            mean = mean_old + (x_flat.sum(dim=1) - N * mean_old) / (total_frames + N)
+            # Update M2: M2 += (x - mean_old) * (x - new_mean)
+            delta_old = x_flat - mean_old.unsqueeze(1)  # (F, N)
+            delta_new = x_flat - mean.unsqueeze(1)  # (F, N)
+            M2 += (delta_old * delta_new).sum(dim=1)  # (F,)
+            total_frames += N
+    std = (
+        torch.sqrt(M2 / (total_frames - 1))
+        if total_frames > 1
+        else torch.ones_like(mean)
+    )
+    return mean.cpu(), std.cpu()
+
+if __name__ == "__main__":
+    # freq_mean, freq_std = compute_per_freq_stats_vectorized(loader, device="cuda")
+    # torch.save({"mean": freq_mean, "std": freq_std}, "freq_stats.pth")
+    print(compute_per_freq_stats_vectorized(loader))
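compute_per_freq_stats_vectorized is a chunked Welford update: each batch of frames updates a running per-frequency mean and M2 (sum of squared deviations) in one shot. Here is a self-contained sketch of the same update on plain tensors, checked against a direct computation; all names and sizes are illustrative.

    import torch

    torch.manual_seed(0)
    freq_bins, chunks = 513, 5
    data = [torch.randn(freq_bins, 1000) for _ in range(chunks)]

    mean = torch.zeros(freq_bins)
    M2 = torch.zeros(freq_bins)
    total = 0
    for x in data:  # x: (F, N), one batch worth of frames
        N = x.shape[1]
        mean_old = mean.clone()
        # Same two update rules as in make_stats.py
        mean = mean_old + (x.sum(dim=1) - N * mean_old) / (total + N)
        M2 += ((x - mean_old[:, None]) * (x - mean[:, None])).sum(dim=1)
        total += N

    std = torch.sqrt(M2 / (total - 1))

    # Direct computation over all frames agrees to numerical precision.
    all_x = torch.cat(data, dim=1)
    assert torch.allclose(mean, all_x.mean(dim=1), atol=1e-4)
    assert torch.allclose(std, all_x.std(dim=1), atol=1e-4)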

misc.py Normal file (+11 lines)
View File

@@ -0,0 +1,11 @@
+import numpy as np
+import librosa
+from settings import N_FFT, HOP
+
+def audio_to_logmag(audio):
+    # STFT
+    stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
+    mag = np.abs(stft)
+    logmag = np.log1p(mag)  # log(1 + x) for stability
+    return logmag  # shape: (freq_bins, time_frames) = (513, T)
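A quick shape check for this helper, assuming the settings introduced in this commit (SR=44100, HOP=512, N_FFT=1024):

    import numpy as np
    from misc import audio_to_logmag
    from settings import SR, HOP, N_FFT

    audio = np.zeros(int(2.0 * SR), dtype=np.float32)  # two seconds of silence
    logmag = audio_to_logmag(audio)
    # n_fft = 1024 gives 1024 // 2 + 1 = 513 frequency bins; with librosa's
    # default center=True there are 1 + len(audio) // HOP = 173 time frames.
    assert logmag.shape == (N_FFT // 2 + 1, 1 + len(audio) // HOP)  # (513, 173)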

model.py Normal file (+31 lines)
View File

@@ -0,0 +1,31 @@
+import torch.nn as nn
+
+INPUT_KERNEL = 15
+SIZE = 32
+
+class CavemanEnhancer(nn.Module):
+    def __init__(self, freq_bins=513):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Conv2d(
+                in_channels=1,
+                out_channels=SIZE,
+                kernel_size=INPUT_KERNEL,
+                padding=INPUT_KERNEL // 2,
+            ),
+            nn.ReLU(),
+            nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(SIZE, 1, kernel_size=3, padding=1),
+        )
+
+    def forward(self, x):
+        # x: (batch, 1, freq_bins, time_frames)
+        return self.net(x)
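A minimal smoke test for the new module (batch size and frame count are arbitrary): every convolution pads by kernel_size // 2, so the (freq_bins, time_frames) dimensions pass through unchanged.

    import torch
    from model import CavemanEnhancer

    model = CavemanEnhancer()
    x = torch.randn(4, 1, 513, 173)  # (batch, channels, freq_bins, time_frames)
    y = model(x)
    print(y.shape)  # torch.Size([4, 1, 513, 173])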

settings.py Normal file (+3 lines)
View File

@@ -0,0 +1,3 @@
+SR = 44100  # sample rate
+HOP = 512
+N_FFT = 1024
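For reference, the quantities the rest of the code derives from these three constants:

    from settings import SR, HOP, N_FFT

    print(N_FFT // 2 + 1)     # 513 frequency bins per STFT frame
    print(1000 * HOP / SR)    # ~11.6 ms hop between frames
    print(1000 * N_FFT / SR)  # ~23.2 ms analysis window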