added new fixed code

2026-01-10 20:35:21 +01:00
parent 79f93e5c29
commit 8379ac8e12
6 changed files with 396 additions and 110 deletions


@@ -1,53 +1,120 @@
import os
import sys
import random
import warnings
import librosa
import numpy as np
import torch.nn as nn
import torch
import soundfile as sf
from caveman_wavedataset import WaveformDataset
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm
from misc import audio_to_logmag
from settings import N_FFT, HOP, SR
from model import CavemanEnhancer
warnings.filterwarnings("ignore")
CLEAN_DATA_DIR = "./fma_small"
LOSSY_DATA_DIR = "./fma_small_compressed_64/"
# DURATION is the length, in seconds, of each example drawn from the dataset.
# Since we are using the FMA dataset, its maximum value is 30 s.
# From the standpoint of the model and training it should make no difference
# to quality, only to the speed of the training process: with a larger
# duration the model is trained on more data per example, with a smaller
# one on less.
#
# YOU ONLY NEED TO CHANGE THIS IF YOU ARE CPU-BOUND WHEN
# LOADING THE DATA DURING TRAINING. INCREASE IT TO PLACE MORE LOAD
# ON THE GPU, REDUCE IT TO PUT MORE LOAD ON THE CPU. DO NOT ADJUST
# THE BATCH SIZE INSTEAD; IT WILL MAKE NO DIFFERENCE, SINCE WE ARE
# ALWAYS FORCED TO LOAD THE ENTIRE EXAMPLE FROM DISK EVERY SINGLE TIME.
DURATION = 2
BATCH_SIZE = 4
# 100 is a bit ridiculous, but you are free to Ctrl-C anytime, since the checkpoints are always saved.
EPOCHS = 100
PREFETCH = 4
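# Sizing example (assuming SR=44100 and HOP=512 from settings): DURATION=2 gives
# 2 * 44100 = 88200 samples per example, which librosa's centered STFT turns
# into 1 + 88200 // 512 = 173 frames, i.e. one (513, 173) log-mag spectrogram.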
# stats = torch.load("freq_stats.pth")
# freq_mean = stats["mean"].numpy() # (513,)
# freq_std = stats["std"].numpy() # (513,)
# Placeholder stats: zero mean and unit std make normalize() and denorm()
# below effective no-ops until real per-frequency stats are computed.
freq_mean = np.zeros([N_FFT // 2 + 1])  # (513,)
freq_std = np.ones([N_FFT // 2 + 1])  # (513,)
# freq_mean_torch = stats["mean"] # (513,)
# freq_std_torch = stats["std"] # (513,)
freq_mean_torch = torch.from_numpy(freq_mean)
freq_std_torch = torch.from_numpy(freq_std)
def run_example(model_filename, device):
model = CavemanEnhancer().to(device)
model.load_state_dict(
torch.load(model_filename, weights_only=False)["model_state_dict"]
)
model.eval()
enhance_mono(
model,
"./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
"./examples/mirror_mirror/mirror_mirror_decompressed_64.wav",
)
    # Baseline: round-trip the original (uncompressed) file through the same
    # STFT pipeline, without the model, to measure pure reconstruction loss.
    # Load
x, sr = librosa.load(
"./examples/mirror_mirror/mirror_mirror.mp3",
sr=SR,
)
# Convert to log-mag
X = audio_to_logmag(x) # (513, T)
    Y_pred = normalize(X)
    Y_pred = denorm(Y_pred)
    # Clip to valid range (log1p output ≥ 0)
    Y_pred = np.maximum(Y_pred, 0)
stft = librosa.stft(x, n_fft=N_FFT, hop_length=HOP)
# Invert log
mag_pred = np.expm1(Y_pred) # inverse of log1p
phase_lossy = np.angle(stft)
    # Alternative: reconstruct with Griffin-Lim (kept commented for reference)
# y_reconstructed = librosa.griffinlim(
# mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
# )
# Combine: enhanced mag + original phase
stft_enhanced = mag_pred * np.exp(1j * phase_lossy)
y_reconstructed = librosa.istft(stft_enhanced, n_fft=N_FFT, hop_length=HOP)
time = np.minimum(x.shape[0], y_reconstructed.shape[0])
print(
f"Loss from reconstruction: {nn.MSELoss()(torch.from_numpy(x[:time]), torch.from_numpy(y_reconstructed[:time]))}"
)
# Save
sf.write(
"./examples/mirror_mirror/mirror_mirror_STFT.mp3",
y_reconstructed,
sr,
)
show_spectrogram(
"./examples/mirror_mirror/mirror_mirror_STFT.mp3",
"./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
"./examples/mirror_mirror/mirror_mirror_decompressed_64.wav",
)
return
def main():
@@ -58,94 +125,96 @@ def main():
if ans != "y":
exit(1)
if len(sys.argv) > 1:
run_example(sys.argv[1], device)
exit(0)
# Data
    dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, DURATION, sr=SR)
dataset.mean = freq_mean
dataset.std = freq_std
    # split off the validation data (deterministic tail split, no shuffling)
n_val = int(0.1 * len(dataset))
n_train = len(dataset) - n_val
train_indices = list(range(n_train))
val_indices = list(range(n_train, len(dataset)))
train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)
train_loader = DataLoader(
train_dataset,
batch_size=BATCH_SIZE,
prefetch_factor=PREFETCH,
shuffle=True,
pin_memory=True,
num_workers=16,
)
val_loader = DataLoader(
val_dataset,
        batch_size=BATCH_SIZE,
        prefetch_factor=PREFETCH,
        shuffle=False,
        num_workers=16,
pin_memory=True,
)
# model
model = CavemanEnhancer().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# I am actually not sure it really improves anything, but there is little reason not to keep this, I guess.
scheduler = lr_scheduler.ReduceLROnPlateau(
optimizer,
mode="min",
factor=0.1,
patience=3,
# cooldown=3,
# threshold=1e-3,
)
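    # With these settings the LR is cut to a tenth after 3 consecutive epochs
    # without an improvement in validation loss.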
    # Weight: emphasize high frequencies. The exp() turns the linear 1..8 ramp
    # into an exponential one, so bin 0 is weighted e^1 ≈ 2.7 and bin 512
    # e^8 ≈ 2981, i.e. roughly a 1000x high-frequency emphasis.
    weight = torch.linspace(1.0, 8.0, 513).to(device)  # low=1, high=8 (pre-exp)
    weight = torch.exp(weight)
    weight = weight.view(1, 513, 1)
def weighted_l1_loss(pred, target):
return torch.mean(weight * torch.abs(pred - target))
# criterion = nn.L1Loss()
criterion = weighted_l1_loss
# criterion = nn.MSELoss()
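    # Shape note (a sketch, assuming the model keeps its (B, 1, 513, T) layout):
    # pred and target are (B, 1, 513, T); the (1, 513, 1) weight broadcasts
    # across batch, channel, and time, scaling each frequency bin's error.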
# baseline (doing nothing)
# if True:
# loss = 0.0
# for lossy, clean in tqdm(train_loader, desc="Baseline"):
# loss += criterion(clean, lossy)
# loss /= len(train_loader)
# print(f"baseling loss: {loss:.4f}")
# Train
for epoch in range(EPOCHS):
model.train()
train_loss = 0.0
for lossy, clean in tqdm(train_loader, desc="Training"):
lossy, clean = lossy.to(device), clean.to(device)
with autocast():
enhanced = model(lossy)
loss = criterion(clean, enhanced)
            train_loss += loss.item()  # .item() detaches, so the autograd graph is freed each step
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss /= len(train_loader)
# Validate (per epoch)
model.eval()
total_loss = 0.0
val_loss = 0
@@ -154,17 +223,17 @@ def main():
lossy, clean = lossy.to(device), clean.to(device)
output = model(lossy)
loss_ = criterion(output, clean)
total_loss += loss_.item()
        val_loss = total_loss / len(val_loader)
scheduler.step(val_loss) # Update learning rate based on validation loss
lr = optimizer.param_groups[0]["lr"]
print(f"LR: {lr:.6f}")
print(f"Epoch {epoch + 1}, Loss: {train_loss:.4f}, Val: {val_loss:.4f}")
        # Yes, we save a checkpoint every epoch. The model is small, and disk space is cheap.
torch.save(
{
"epoch": epoch,
@@ -176,12 +245,86 @@ def main():
)
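        # Resuming sketch: "epoch" and "model_state_dict" are confirmed keys of
        # the dict saved above; the checkpoint path below is hypothetical.
        # ckpt = torch.load("checkpoint.pth", weights_only=False)
        # model.load_state_dict(ckpt["model_state_dict"])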
def file_to_logmag(path):
y, sr = librosa.load(path, sr=SR, mono=True)
print(y.shape)
return np.squeeze(audio_to_logmag(y))
def show_spectrogram(path1, path2, path3):
spectrogram1 = file_to_logmag(path1)
spectrogram2 = file_to_logmag(path2)
spectrogram3 = file_to_logmag(path3)
spectrogram1 = normalize(spectrogram1)
spectrogram2 = normalize(spectrogram2)
spectrogram3 = normalize(spectrogram3)
from matplotlib import pyplot as plt
    # Create a figure with three subplots (1 row, 3 columns)
fig, axes = plt.subplots(1, 3, figsize=(10, 5))
# Display the first image
axes[0].imshow(spectrogram1, aspect="auto", cmap="gray")
axes[0].set_title("spectrogram 1")
axes[0].axis("off") # Hide axes
# Display the second image
axes[1].imshow(spectrogram2, aspect="auto", cmap="gray")
axes[1].set_title("spectrogram 2")
axes[1].axis("off")
    # Display the third image
axes[2].imshow(spectrogram3, aspect="auto", cmap="gray")
axes[2].set_title("spectrogram 3")
axes[2].axis("off")
plt.tight_layout()
plt.show()
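    # Note: imshow's default origin draws frequency bin 0 (the lowest band) at
    # the top of each panel; pass origin="lower" above for the usual orientation.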
def enhance_stereo(model, lossy_path, output_path):
# Load stereo audio (returns shape: (2, T) if stereo)
y, sr = librosa.load(lossy_path, sr=SR, mono=False) # mono=False preserves channels
# Ensure shape is (2, T)
if y.ndim == 1:
raise ValueError("Input is mono! Expected stereo.")
y_l = y[0]
y_r = y[1]
y_enhanced_l = enhance_audio(model, y_l, sr)
y_enhanced_r = enhance_audio(model, y_r, sr)
stereo_reconstructed = np.vstack((y_enhanced_l, y_enhanced_r))
    # Save: soundfile expects (frames, channels), hence the transpose to (T, 2)
    sf.write(output_path, stereo_reconstructed.T, sr)
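# Usage sketch (hypothetical file names):
# enhance_stereo(model, "track_64kbps.mp3", "track_enhanced.wav")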
# accepts shape (513, T)!!!!
def denorm_torch(spectrogram):
return spectrogram * (freq_std_torch[:, None] + 1e-8) + freq_mean_torch[:, None]
# accepts shape (513, T)!!!!
def normalize(spectrogram):
return (spectrogram - freq_mean[:, None]) / (freq_std[:, None] + 1e-8)
# accepts shape (513, T)!!!!
def denorm(spectrogram):
return spectrogram * (freq_std[:, None] + 1e-8) + freq_mean[:, None]
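# A minimal sanity check for the helpers above (the function name is ours and
# not part of the training pipeline):
def _check_norm_roundtrip():
    # denorm(normalize(S)) should reconstruct any (513, T) spectrogram S
    S = np.random.rand(N_FFT // 2 + 1, 8)
    assert np.allclose(denorm(normalize(S)), S, atol=1e-6)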
def enhance_audio(model, audio, sr):
    # Convert to log-mag
    X = audio_to_logmag(audio)  # (513, T)
    stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    X_tensor = torch.tensor(X, dtype=torch.float32).unsqueeze(0).to(device)  # (1, 513, T)
@@ -189,16 +332,26 @@ def enhance_audio(model, lossy_path, output_path):
    Y_pred = model(X_tensor).cpu().numpy()  # (1, 513, T)
    Y_pred = Y_pred.squeeze(0)  # back to (513, T)
Y_pred = denorm(Y_pred)
# Clip to valid range (log1p output ≥ 0)
Y_pred = np.maximum(Y_pred, 0)
# Invert log
mag_pred = np.expm1(Y_pred) # inverse of log1p
phase_lossy = np.angle(stft)
    # Combine: enhanced magnitude + the lossy input's phase. Reusing the input
    # phase is much cheaper than Griffin-Lim's iterative phase estimation.
    stft_enhanced = mag_pred * np.exp(1j * phase_lossy)
    y_reconstructed = librosa.istft(stft_enhanced, n_fft=N_FFT, hop_length=HOP)
return y_reconstructed
def enhance_mono(model, lossy_path, output_path):
# Load
x, sr = librosa.load(lossy_path, sr=SR)
y_reconstructed = enhance_audio(model, x, sr)