diff --git a/caveman_wavedataset.py b/caveman_wavedataset.py
index 1518b84..9f437a6 100644
--- a/caveman_wavedataset.py
+++ b/caveman_wavedataset.py
@@ -4,29 +4,21 @@ import librosa
 from torch.utils.data import Dataset
 import numpy as np
 import random
-
-
-HOP = 512
-N_FFT = 1024
-DURATION = 2.0
-SR = 44100
-
-
-def audio_to_logmag(audio):
-    # STFT
-    stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
-    mag = np.abs(stft)
-    logmag = np.log1p(mag)  # log(1 + x) for stability
-    return logmag  # shape: (1, freq_bins, time_frames) = (1, 513, T)
+from settings import SR, N_FFT
+from misc import audio_to_logmag
 
 
 class WaveformDataset(Dataset):
-    def __init__(self, lossy_dir, clean_dir, sr=SR, segment_sec=4):
-        self.cache = dict()
+    mean = np.zeros([N_FFT // 2 + 1])
+    std = np.ones([N_FFT // 2 + 1])
+
+    # Duration is a very, very important parameter; read cavemanml.py to see how and why to adjust it!
+    # For the purposes of this file, it's the length of the audio clip being selected from the dataset.
+    def __init__(self, lossy_dir, clean_dir, segment_duration, sr=SR):
+        self.segment_duration = segment_duration
         self.sr = sr
         self.lossy_dir = lossy_dir
         self.clean_dir = clean_dir
-        self.segment_len = int(segment_sec * sr)
         self.lossy_files = sorted(os.listdir(lossy_dir))
         self.clean_files = sorted(os.listdir(clean_dir))
         self.file_pairs = [
@@ -51,9 +43,9 @@ class WaveformDataset(Dataset):
         min_len = min(len(lossy), len(clean))
         lossy, clean = lossy[:min_len], clean[:min_len]
 
-        # Random 2-second clip
+        # Random clip
 
-        clip_len = int(DURATION * SR)
+        clip_len = int(self.segment_duration * self.sr)
         if min_len < clip_len:
             # pad if too short
             lossy = np.pad(lossy, (0, clip_len - min_len))
@@ -61,14 +53,21 @@ class WaveformDataset(Dataset):
             start = 0
         else:
             start = random.randint(0, min_len - clip_len)
+            # start = 0
 
         lossy = lossy[start : start + clip_len]
         clean = clean[start : start + clip_len]
 
+        logmag_x = audio_to_logmag(lossy)
+        logmag_y = audio_to_logmag(clean)
+
+        logmag_x_norm = (logmag_x - self.mean[:, None]) / (self.std[:, None] + 1e-8)
+        logmag_y_norm = (logmag_y - self.mean[:, None]) / (self.std[:, None] + 1e-8)
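+        # Normalization sketch: self.mean/self.std have shape (513,); [:, None]
+        # makes them (513, 1), so they broadcast across all T frames of the
+        # (513, T) log-magnitude spectrogram, i.e. each frequency bin gets its
+        # own mean/std. Illustrative only, not extra work in the training path:
+        #   row i of logmag_x_norm == (logmag_x[i] - mean[i]) / (std[i] + 1e-8)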
+
         ans = (
-            torch.from_numpy(audio_to_logmag(lossy)).unsqueeze(0),
-            torch.from_numpy(audio_to_logmag(clean)).unsqueeze(0),
+            torch.from_numpy(logmag_x_norm).float().unsqueeze(0),
+            torch.from_numpy(logmag_y_norm).float().unsqueeze(0),
         )
-        self.cache[idx] = ans
+        # self.cache[idx] = ans
 
         return ans
diff --git a/cavemanml.py b/cavemanml.py
index b1bfa5c..c36b5ab 100644
--- a/cavemanml.py
+++ b/cavemanml.py
@@ -1,53 +1,120 @@
-import os
 import sys
-import random
 import warnings
 import librosa
 import numpy as np
 import torch.nn as nn
 import torch
+import soundfile as sf
 from caveman_wavedataset import WaveformDataset
 from torch.utils.data import DataLoader
-from torch.cuda.amp import autocast, GradScaler
+from torch.cuda.amp import autocast
 import torch.optim.lr_scheduler as lr_scheduler
+from tqdm import tqdm
+from misc import audio_to_logmag
+from settings import N_FFT, HOP, SR
+from model import CavemanEnhancer
 
 warnings.filterwarnings("ignore")
 
 CLEAN_DATA_DIR = "./fma_small"
 LOSSY_DATA_DIR = "./fma_small_compressed_64/"
-SR = 44100  # sample rate
-DURATION = 2.0  # seconds per clip
-N_MELS = None  # we'll use full STFT for now
-HOP = 512
-N_FFT = 1024
-
-
-def audio_to_logmag(audio):
-    # STFT
-    stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
-    mag = np.abs(stft)
-    logmag = np.log1p(mag)  # log(1 + x) for stability
-    return logmag  # shape: (freq_bins, time_frames) = (513, T)
-
-
-class CavemanEnhancer(nn.Module):
-    def __init__(self, freq_bins=513):
-        super().__init__()
-        self.net = nn.Sequential(
-            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2),
-            nn.ReLU(),
-            nn.Conv2d(32, 32, kernel_size=3, padding=1),
-            nn.ReLU(),
-            nn.Conv2d(32, 1, kernel_size=3, padding=1),
-        )
-
-    def forward(self, x):
-        # x: (batch, freq_bins)
-        return self.net(x)
-
+# Duration is the duration of each example to be selected from the dataset.
+# Since we are using the FMA dataset, its max value is 30s.
+# From the standpoint of the model and training,
+# it should make absolutely no difference to quality,
+# only to the speed of the training process. If duration is larger,
+# the model is trained on more data per example.
+# If smaller, less data per example.
+#
+# YOU ONLY NEED TO CHANGE THIS IF YOU ARE CPU-BOUND WHEN
+# LOADING THE DATA DURING TRAINING. INCREASE TO PLACE MORE LOAD
+# ON THE GPU, REDUCE TO PUT MORE LOAD ON THE CPU. DO NOT ADJUST
+# THE BATCH SIZE, IT WILL MAKE NO DIFFERENCE, SINCE WE ARE ALWAYS
+# FORCED TO LOAD THE ENTIRE EXAMPLE FROM DISK EVERY SINGLE TIME.
+DURATION = 2
 BATCH_SIZE = 4
+# 100 is a bit ridiculous, but you are free to Ctrl-C anytime, since the checkpoints are always saved.
 EPOCHS = 100
+PREFETCH = 4
+
+# stats = torch.load("freq_stats.pth")
+# freq_mean = stats["mean"].numpy()  # (513,)
+# freq_std = stats["std"].numpy()  # (513,)
+
+freq_mean = np.zeros([N_FFT // 2 + 1])  # (513,)
+freq_std = np.ones([N_FFT // 2 + 1])  # (513,)
+
+
+# freq_mean_torch = stats["mean"]  # (513,)
+# freq_std_torch = stats["std"]  # (513,)
+freq_mean_torch = torch.from_numpy(freq_mean)
+freq_std_torch = torch.from_numpy(freq_std)
+
+
+def run_example(model_filename, device):
+    model = CavemanEnhancer().to(device)
+    model.load_state_dict(
+        torch.load(model_filename, weights_only=False)["model_state_dict"]
+    )
+
+    model.eval()
+    enhance_mono(
+        model,
+        "./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
+        "./examples/mirror_mirror/mirror_mirror_decompressed_64.wav",
+    )
+
+    # Load
+    x, sr = librosa.load(
+        "./examples/mirror_mirror/mirror_mirror.mp3",
+        sr=SR,
+    )
+
+    # Convert to log-mag
+    X = audio_to_logmag(x)  # (513, T)
+
+    Y_pred = normalize(X)
+    Y_pred = denorm(Y_pred)
+
+    # Clip to valid range (log1p output ≥ 0)
+    Y_pred = np.maximum(Y_pred, 0)
+
+    stft = librosa.stft(x, n_fft=N_FFT, hop_length=HOP)
+    # Invert log
+    mag_pred = np.expm1(Y_pred)  # inverse of log1p
+    phase_lossy = np.angle(stft)
+
+    # Reconstruct with Griffin-Lim
+    # y_reconstructed = librosa.griffinlim(
+    #     mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
+    # )
+
+    # Combine: enhanced mag + original phase
+    stft_enhanced = mag_pred * np.exp(1j * phase_lossy)
+    y_reconstructed = librosa.istft(stft_enhanced, n_fft=N_FFT, hop_length=HOP)
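+    # Why reuse the lossy phase: the model only predicts magnitudes, so the
+    # phase of the input STFT is reused instead of iterative phase retrieval
+    # (the commented-out Griffin-Lim above). The identity being exploited,
+    # with M = |Z| and phi = angle(Z), is Z == M * exp(1j * phi) exactly, so
+    # pairing an enhanced M with the lossy phi is a cheap reconstruction.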
+
+    time = np.minimum(x.shape[0], y_reconstructed.shape[0])
+
+    print(
+        f"Loss from reconstruction: {nn.MSELoss()(torch.from_numpy(x[:time]), torch.from_numpy(y_reconstructed[:time]))}"
+    )
+
+    # Save
+    sf.write(
+        "./examples/mirror_mirror/mirror_mirror_STFT.mp3",
+        y_reconstructed,
+        sr,
+    )
+
+    show_spectrogram(
+        "./examples/mirror_mirror/mirror_mirror_STFT.mp3",
+        "./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
+        "./examples/mirror_mirror/mirror_mirror_decompressed_64.wav",
+    )
+
+    return
 
 
 def main():
@@ -58,94 +125,96 @@ def main():
     if ans != "y":
         exit(1)
 
-    model = CavemanEnhancer().to(device)
     if len(sys.argv) > 1:
-        model.load_state_dict(
-            torch.load(sys.argv[1], weights_only=False)["model_state_dict"]
-        )
-        model.eval()
-        enhance_audio(
-            model,
-            "./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
-            "./examples/mirror_mirror/mirror_mirror_decompressed_64_mse.wav",
-        )
-
-        # Load
-        x, sr = librosa.load("./examples/mirror_mirror/mirror_mirror_compressed_64.mp3", sr=SR)
-
-        # Convert to log-mag
-        X = audio_to_logmag(x)  # (513, T)
-
-        # Clip to valid range (log1p output ≥ 0)
-        Y_pred = np.maximum(X, 0)
-
-        # Invert log
-        mag_pred = np.expm1(Y_pred)  # inverse of log1p
-
-        # Reconstruct with Griffin-Lim
-        y_reconstructed = librosa.griffinlim(
-            mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
-        )
-
-        import soundfile as sf
-
-        # Save
-        sf.write("./examples/mirror_mirror/mirror_mirror_compressed_64_STFT.mp3", y_reconstructed, sr)
-        return
+        run_example(sys.argv[1], device)
+        exit(0)
 
     # Data
-    dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, sr=SR)
+    dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, DURATION, sr=SR)
+    dataset.mean = freq_mean
+    dataset.std = freq_std
 
+    # separate the train and val data
     n_val = int(0.1 * len(dataset))
     n_train = len(dataset) - n_val
+
     train_indices = list(range(n_train))
     val_indices = list(range(n_train, len(dataset)))
+
     train_dataset = torch.utils.data.Subset(dataset, train_indices)
     val_dataset = torch.utils.data.Subset(dataset, val_indices)
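+    # Note: this split is deterministic (the last 10% of the sorted file list
+    # becomes validation), not a random split. That keeps train/val stable
+    # across runs; shuffling the indices first would be the alternative if the
+    # directory ordering correlates with anything meaningful.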
     train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
+        prefetch_factor=PREFETCH,
        shuffle=True,
        pin_memory=True,
        num_workers=16,
    )
    val_loader = DataLoader(
        val_dataset,
-        batch_size=4,
+        batch_size=BATCH_SIZE,
+        prefetch_factor=PREFETCH,
        shuffle=False,
-        num_workers=10,
+        num_workers=16,
        pin_memory=True,
    )
 
-    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+    # model
+    model = CavemanEnhancer().to(device)
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+    # I am actually not sure it really improves anything, but there is little reason not to keep this, I guess.
     scheduler = lr_scheduler.ReduceLROnPlateau(
         optimizer,
         mode="min",
         factor=0.1,
-        patience=5,
-        cooldown=3,
-        threshold=1e-3,
+        patience=3,
+        # cooldown=3,
+        # threshold=1e-3,
     )
 
-    from tqdm import tqdm
+    # Weight: emphasize high frequencies
+    weight = torch.linspace(1.0, 8.0, 513).to(device)  # exponents: 1.0 (lowest bin) to 8.0 (highest)
+    weight = torch.exp(weight)  # actual weights: e^1 to e^8
+    weight = weight.view(1, 513, 1)
 
-    criterion = nn.L1Loss()
+    def weighted_l1_loss(pred, target):
+        return torch.mean(weight * torch.abs(pred - target))
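+    # Scale check for the weights above: exp(linspace(1, 8, 513)) runs from
+    # e^1 (about 2.7) to e^8 (about 2981), so the top bin is weighted
+    # e^7 (about 1100x) more than the bottom one. That is a very aggressive
+    # tilt; if the highs end up dominating training, a flatter range such as
+    # linspace(0, 3, 513) would be the knob to try (untested suggestion).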
+
+    # criterion = nn.L1Loss()
+    criterion = weighted_l1_loss
+    # criterion = nn.MSELoss()
+
+    # baseline (doing nothing)
+    # if True:
+    #     loss = 0.0
+    #     for lossy, clean in tqdm(train_loader, desc="Baseline"):
+    #         loss += criterion(clean, lossy)
+    #     loss /= len(train_loader)
+    #     print(f"baseline loss: {loss:.4f}")
 
     # Train
     for epoch in range(EPOCHS):
         model.train()
+        train_loss = 0.0
         for lossy, clean in tqdm(train_loader, desc="Training"):
             lossy, clean = lossy.to(device), clean.to(device)
             with autocast():
                 enhanced = model(lossy)
                 loss = criterion(clean, enhanced)
+            train_loss += loss.item()
 
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
 
+        train_loss /= len(train_loader)
+
+        # Validate (per epoch)
         model.eval()
         total_loss = 0.0
         val_loss = 0
@@ -154,17 +223,17 @@ def main():
                 lossy, clean = lossy.to(device), clean.to(device)
                 output = model(lossy)
                 loss_ = criterion(output, clean)
+                total_loss += loss_.item()
 
-        val_loss = total_loss / len(train_loader)
+        val_loss = total_loss / len(val_loader)
         scheduler.step(val_loss)  # Update learning rate based on validation loss
 
-        if (epoch + 1) % 10 == 0:
-            lr = optimizer.param_groups[0]["lr"]
-            print(f"LR: {lr:.6f}")
-
-        print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}, Val: {val_loss:.4f}")
+        lr = optimizer.param_groups[0]["lr"]
+        print(f"LR: {lr:.6f}")
+        print(f"Epoch {epoch + 1}, Loss: {train_loss:.4f}, Val: {val_loss:.4f}")
 
+        # Yes, we are saving checkpoints for every epoch. The model is small, and disk space is cheap.
         torch.save(
             {
                 "epoch": epoch,
@@ -176,12 +245,86 @@ def main():
     )
 
 
-def enhance_audio(model, lossy_path, output_path):
-    # Load
-    x, sr = librosa.load(lossy_path, sr=SR)
+def file_to_logmag(path):
+    y, sr = librosa.load(path, sr=SR, mono=True)
+    return np.squeeze(audio_to_logmag(y))
+
+
+def show_spectrogram(path1, path2, path3):
+    spectrogram1 = file_to_logmag(path1)
+    spectrogram2 = file_to_logmag(path2)
+    spectrogram3 = file_to_logmag(path3)
+
+    spectrogram1 = normalize(spectrogram1)
+    spectrogram2 = normalize(spectrogram2)
+    spectrogram3 = normalize(spectrogram3)
+
+    from matplotlib import pyplot as plt
+
+    # Create a figure with three subplots (1 row, 3 columns)
+    fig, axes = plt.subplots(1, 3, figsize=(10, 5))
+
+    # Display the first image
+    axes[0].imshow(spectrogram1, aspect="auto", cmap="gray")
+    axes[0].set_title("spectrogram 1")
+    axes[0].axis("off")  # Hide axes
+
+    # Display the second image
+    axes[1].imshow(spectrogram2, aspect="auto", cmap="gray")
+    axes[1].set_title("spectrogram 2")
+    axes[1].axis("off")
+
+    # Display the third image
+    axes[2].imshow(spectrogram3, aspect="auto", cmap="gray")
+    axes[2].set_title("spectrogram 3")
+    axes[2].axis("off")
+
+    plt.tight_layout()
+    plt.show()
+
+
+def enhance_stereo(model, lossy_path, output_path):
+    # Load stereo audio (returns shape: (2, T) if stereo)
+    y, sr = librosa.load(lossy_path, sr=SR, mono=False)  # mono=False preserves channels
+
+    # Ensure shape is (2, T)
+    if y.ndim == 1:
+        raise ValueError("Input is mono! Expected stereo.")
+
+    y_l = y[0]
+    y_r = y[1]
+
+    y_enhanced_l = enhance_audio(model, y_l, sr)
+    y_enhanced_r = enhance_audio(model, y_r, sr)
+
+    stereo_reconstructed = np.vstack((y_enhanced_l, y_enhanced_r))
+
+    # Save: soundfile expects (frames, channels), so transpose (2, T) -> (T, 2)
+    sf.write(output_path, stereo_reconstructed.T, sr)
+
+
+# accepts shape (513, T)!!!!
+def denorm_torch(spectrogram):
+    return spectrogram * (freq_std_torch[:, None] + 1e-8) + freq_mean_torch[:, None]
+
+
+# accepts shape (513, T)!!!!
+def normalize(spectrogram):
+    return (spectrogram - freq_mean[:, None]) / (freq_std[:, None] + 1e-8)
+
+
+# accepts shape (513, T)!!!!
+def denorm(spectrogram):
+    return spectrogram * (freq_std[:, None] + 1e-8) + freq_mean[:, None]
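+
+# Round-trip sanity check: denorm(normalize(S)) == S exactly, since the same
+# (std + 1e-8) factor is divided out and then multiplied back, e.g.:
+#   S = np.random.rand(513, 10)
+#   assert np.allclose(denorm(normalize(S)), S)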
+
+
+def enhance_audio(model, audio, sr):
     # Convert to log-mag
-    X = audio_to_logmag(x)  # (513, T)
+    X = audio_to_logmag(audio)  # (513, T)
+    X = normalize(X)  # the model was trained on normalized spectrograms, so normalize here too
+    stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
     X_tensor = torch.tensor(X, dtype=torch.float32).unsqueeze(0).to(device)  # (1, 513, T)
@@ -189,16 +332,26 @@ def enhance_audio(model, lossy_path, output_path):
     Y_pred = model(X_tensor).cpu().numpy()  # (1, 513, T)
     Y_pred = Y_pred.squeeze(0)  # back to (513, T)
+    Y_pred = denorm(Y_pred)
+
     # Clip to valid range (log1p output ≥ 0)
     Y_pred = np.maximum(Y_pred, 0)
 
     # Invert log
     mag_pred = np.expm1(Y_pred)  # inverse of log1p
+    phase_lossy = np.angle(stft)
 
-    # Reconstruct with Griffin-Lim
-    y_reconstructed = librosa.griffinlim(
-        mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
-    )
+    # Combine: enhanced mag + original phase
+    stft_enhanced = mag_pred * np.exp(1j * phase_lossy)
+    y_reconstructed = librosa.istft(stft_enhanced, n_fft=N_FFT, hop_length=HOP)
+    return y_reconstructed
+
+
+def enhance_mono(model, lossy_path, output_path):
+    # Load
+    x, sr = librosa.load(lossy_path, sr=SR)
+
+    y_reconstructed = enhance_audio(model, x, sr)
 
     import soundfile as sf
diff --git a/make_stats.py b/make_stats.py
new file mode 100644
index 0000000..410d663
--- /dev/null
+++ b/make_stats.py
@@ -0,0 +1,89 @@
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from caveman_wavedataset import WaveformDataset
+
+
+CLEAN_DATA_DIR = "./fma_small"
+LOSSY_DATA_DIR = "./fma_small_compressed_64/"
+SR = 44100  # sample rate
+DURATION = 2.0  # seconds per clip
+BATCH_SIZE = 16
+PREFETCH = 4
+
+# Data. Leave dataset.mean/std at their defaults (zeros/ones) so the stats
+# are computed on raw, unnormalized log-magnitudes.
+dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, DURATION, sr=SR)
+
+loader = DataLoader(
+    dataset,
+    batch_size=BATCH_SIZE,
+    prefetch_factor=PREFETCH,
+    shuffle=False,
+    pin_memory=True,
+    num_workers=16,
+)
+
+
+def compute_per_freq_stats_vectorized(dataloader, freq_bins=513, device="cuda"):
+    mean = torch.zeros(freq_bins, device=device)
+    M2 = torch.zeros(freq_bins, device=device)
+    total_frames = 0
+
+    with torch.no_grad():
+        for lossy, _ in tqdm(dataloader, desc="Stats"):
+            x = lossy.to(device)  # (B, 1, F, T)
+            B, _, F, T = x.shape
+            x_flat = x.squeeze(1).permute(1, 0, 2).reshape(F, -1)  # (F, N)
+            N = x_flat.shape[1]
+
+            if N == 0:
+                continue
+
+            # Current mean (broadcast)
+            mean_old = mean.clone()
+
+            # Update mean: mean = mean + (sum(x) - N*mean) / (total + N)
+            mean = mean_old + (x_flat.sum(dim=1) - N * mean_old) / (total_frames + N)
+
+            # Update M2: M2 += (x - mean_old) * (x - new_mean)
+            delta_old = x_flat - mean_old.unsqueeze(1)  # (F, N)
+            delta_new = x_flat - mean.unsqueeze(1)  # (F, N)
+            M2 += (delta_old * delta_new).sum(dim=1)  # (F,)
+
+            total_frames += N
+
+    std = (
+        torch.sqrt(M2 / (total_frames - 1))
+        if total_frames > 1
+        else torch.ones_like(mean)
+    )
+    return mean.cpu(), std.cpu()
+
+
+if __name__ == "__main__":
+    freq_mean, freq_std = compute_per_freq_stats_vectorized(loader, device="cuda")
+    torch.save({"mean": freq_mean, "std": freq_std}, "freq_stats.pth")
+    print(freq_mean, freq_std)
diff --git a/misc.py b/misc.py
new file mode 100644
index 0000000..1a56916
--- /dev/null
+++ b/misc.py
@@ -0,0 +1,11 @@
+import numpy as np
+import librosa
+from settings import N_FFT, HOP
+
+
+def audio_to_logmag(audio):
+    # STFT
+    stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
+    mag = np.abs(stft)
+    logmag = np.log1p(mag)  # log(1 + x) for stability
+    return logmag  # shape: (freq_bins, time_frames) = (513, T)
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..16a6768
--- /dev/null
+++ b/model.py
@@ -0,0 +1,31 @@
+import torch.nn as nn
+
+INPUT_KERNEL = 15
+SIZE = 32
+
+
+class CavemanEnhancer(nn.Module):
+    def __init__(self, freq_bins=513):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Conv2d(
+                in_channels=1,
+                out_channels=SIZE,
+                kernel_size=INPUT_KERNEL,
+                padding=INPUT_KERNEL // 2,
+            ),
+            nn.ReLU(),
+            nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(SIZE, 1, kernel_size=3, padding=1),
+        )
+
+    def forward(self, x):
+        # x: (batch, 1, freq_bins, time_frames) log-magnitude spectrogram
+        return self.net(x)
diff --git a/settings.py b/settings.py
new file mode 100644
index 0000000..501a275
--- /dev/null
+++ b/settings.py
@@ -0,0 +1,3 @@
+SR = 44100  # sample rate
+HOP = 512
+N_FFT = 1024
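+
+# Derived values, for reference:
+#   freq_bins = N_FFT // 2 + 1 = 513
+#   one STFT frame spans N_FFT / SR ~ 23.2 ms; frames advance by HOP / SR ~ 11.6 ms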