import sys
import warnings

import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
import soundfile as sf
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast
from tqdm import tqdm

from caveman_wavedataset import WaveformDataset
from misc import audio_to_logmag
from settings import N_FFT, HOP, SR
from model import CavemanEnhancer

warnings.filterwarnings("ignore")

CLEAN_DATA_DIR = "./fma_small"
LOSSY_DATA_DIR = "./fma_small_compressed_64/"

# Duration of each example selected from the dataset, in seconds.
# Since we are using the FMA dataset, its maximum value is 30 s.
# From the standpoint of model and training quality it makes no difference;
# it only affects the speed of the training process. A larger duration means
# the model sees more data per example, a smaller one means less.
#
# YOU ONLY NEED TO CHANGE THIS IF YOU ARE CPU-BOUND WHEN
# LOADING THE DATA DURING TRAINING. INCREASE TO PLACE MORE LOAD
# ON THE GPU, REDUCE TO PUT MORE LOAD ON THE CPU. DO NOT ADJUST
# THE BATCH SIZE; IT WILL MAKE NO DIFFERENCE, SINCE WE ARE ALWAYS
# FORCED TO LOAD THE ENTIRE EXAMPLE FROM DISK EVERY SINGLE TIME.
DURATION = 2
BATCH_SIZE = 4
# 100 is a bit ridiculous, but you are free to Ctrl-C anytime, since checkpoints are saved every epoch.
EPOCHS = 100
PREFETCH = 4

# Per-frequency normalization statistics. The precomputed version is disabled;
# with zero mean and unit std, normalize()/denorm() are effectively identity.
# stats = torch.load("freq_stats.pth")
# freq_mean = stats["mean"].numpy()  # (513,)
# freq_std = stats["std"].numpy()  # (513,)
freq_mean = np.zeros([N_FFT // 2 + 1])  # (513,)
freq_std = np.ones([N_FFT // 2 + 1])  # (513,)
# freq_mean_torch = stats["mean"]  # (513,)
# freq_std_torch = stats["std"]  # (513,)
freq_mean_torch = torch.from_numpy(freq_mean)
freq_std_torch = torch.from_numpy(freq_std)


def run_example(model_filename, device):
    model = CavemanEnhancer().to(device)
    model.load_state_dict(
        torch.load(model_filename, weights_only=False)["model_state_dict"]
    )
    model.eval()
    enhance_mono(
        model,
        "./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
        "./examples/mirror_mirror/mirror_mirror_decompressed_64.wav",
    )

    # Load the original (uncompressed) example. The rest of this function is a
    # pure STFT round-trip check on it, without the model.
    x, sr = librosa.load(
        "./examples/mirror_mirror/mirror_mirror.mp3",
        sr=SR,
    )

    # Convert to log-mag
    X = audio_to_logmag(x)  # (513, T)
    Y_pred = normalize(X)
    Y_pred = denorm(Y_pred)
    # Clip to valid range (log1p output ≥ 0)
    Y_pred = np.maximum(Y_pred, 0)
    stft = librosa.stft(x, n_fft=N_FFT, hop_length=HOP)

    # Invert log
    mag_pred = np.expm1(Y_pred)  # inverse of log1p
    phase_lossy = np.angle(stft)

    # Reconstruct with Griffin-Lim
    # y_reconstructed = librosa.griffinlim(
    #     mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
    # )

    # Combine: enhanced mag + original phase
    stft_enhanced = mag_pred * np.exp(1j * phase_lossy)
    y_reconstructed = librosa.istft(stft_enhanced, n_fft=N_FFT, hop_length=HOP)

    length = np.minimum(x.shape[0], y_reconstructed.shape[0])
    print(
        f"Loss from reconstruction: {nn.MSELoss()(torch.from_numpy(x[:length]), torch.from_numpy(y_reconstructed[:length]))}"
    )

    # Save
    sf.write(
        "./examples/mirror_mirror/mirror_mirror_STFT.mp3",
        y_reconstructed,
        sr,
    )
    show_spectrogram(
        "./examples/mirror_mirror/mirror_mirror_STFT.mp3",
        "./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
        "./examples/mirror_mirror/mirror_mirror_decompressed_64.wav",
    )
    return
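# ---------------------------------------------------------------------------
# Optional: computing per-frequency normalization statistics.
# The commented-out "freq_stats.pth" block above expects a file containing a
# per-bin mean and std of the log-magnitude spectrograms. The helper below is
# only an illustrative sketch of how such a file could be produced; the
# function name, its arguments, and the output path are assumptions, and
# nothing in the training flow calls it by default.
def compute_freq_stats(audio_paths, out_path="freq_stats.pth"):
    n_bins = N_FFT // 2 + 1
    sums = np.zeros(n_bins)
    sq_sums = np.zeros(n_bins)
    n_frames = 0
    for path in audio_paths:
        y, _ = librosa.load(path, sr=SR, mono=True)
        X = np.squeeze(audio_to_logmag(y))  # assumed (513, T) log1p magnitudes
        sums += X.sum(axis=1)
        sq_sums += (X**2).sum(axis=1)
        n_frames += X.shape[1]
    mean = sums / n_frames
    std = np.sqrt(np.maximum(sq_sums / n_frames - mean**2, 0.0))
    torch.save(
        {"mean": torch.from_numpy(mean), "std": torch.from_numpy(std)},
        out_path,
    )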
def main():
    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(device)
    ans = input("Do you want to use this device? [y/n] ")
    if ans != "y":
        exit(1)

    if len(sys.argv) > 1:
        run_example(sys.argv[1], device)
        exit(0)

    # Data
    dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, DURATION, sr=SR)
    dataset.mean = freq_mean
    dataset.std = freq_std

    # Separate the train and val data
    n_val = int(0.1 * len(dataset))
    n_train = len(dataset) - n_val
    train_indices = list(range(n_train))
    val_indices = list(range(n_train, len(dataset)))
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        prefetch_factor=PREFETCH,
        shuffle=True,
        pin_memory=True,
        num_workers=16,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        prefetch_factor=PREFETCH,
        shuffle=False,
        num_workers=16,
        pin_memory=True,
    )

    # Model
    model = CavemanEnhancer().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # I am actually not sure this really improves anything, but there is little reason not to keep it.
    scheduler = lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.1,
        patience=3,
        # cooldown=3,
        # threshold=1e-3,
    )

    # Weight: emphasize high frequencies
    weight = torch.linspace(1.0, 8.0, 513).to(device)  # exponents: 1 (lowest bin) .. 8 (highest bin)
    weight = torch.exp(weight)  # actual weights: e^1 .. e^8
    weight = weight.view(1, 513, 1)

    def weighted_l1_loss(pred, target):
        return torch.mean(weight * torch.abs(pred - target))

    # criterion = nn.L1Loss()
    criterion = weighted_l1_loss
    # criterion = nn.MSELoss()

    # Baseline (doing nothing)
    # if True:
    #     loss = 0.0
    #     for lossy, clean in tqdm(train_loader, desc="Baseline"):
    #         loss += criterion(clean, lossy)
    #     loss /= len(train_loader)
    #     print(f"baseline loss: {loss:.4f}")

    # Train
    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0.0
        for lossy, clean in tqdm(train_loader, desc="Training"):
            lossy, clean = lossy.to(device), clean.to(device)
            with autocast():
                enhanced = model(lossy)
                loss = criterion(clean, enhanced)
            train_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        train_loss /= len(train_loader)

        # Validate (per epoch)
        model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for lossy, clean in tqdm(val_loader, desc="Validating"):
                lossy, clean = lossy.to(device), clean.to(device)
                output = model(lossy)
                loss_ = criterion(output, clean)
                total_loss += loss_.item()
        val_loss = total_loss / len(val_loader)

        scheduler.step(val_loss)  # Update learning rate based on validation loss
        lr = optimizer.param_groups[0]["lr"]
        print(f"LR: {lr:.6f}")
        print(f"Epoch {epoch + 1}, Loss: {train_loss:.4f}, Val: {val_loss:.4f}")

        # Yes, we are saving a checkpoint for every epoch. The model is small, and disk space is cheap.
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
            },
            f"checkpoint{epoch}.pth",
        )
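# ---------------------------------------------------------------------------
# Optional: resuming from a saved checkpoint.
# Because a checkpoint is written after every epoch, an interrupted run can in
# principle be continued. This is only an illustrative sketch based on the
# fields stored by torch.save() in main(); the helper name is an assumption
# and nothing calls it by default.
def load_checkpoint(path, model, optimizer, device):
    ckpt = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return ckpt["epoch"] + 1  # first epoch to run after resuming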
def file_to_logmag(path):
    y, sr = librosa.load(path, sr=SR, mono=True)
    print(y.shape)
    return np.squeeze(audio_to_logmag(y))


def show_spectrogram(path1, path2, path3):
    spectrogram1 = file_to_logmag(path1)
    spectrogram2 = file_to_logmag(path2)
    spectrogram3 = file_to_logmag(path3)
    spectrogram1 = normalize(spectrogram1)
    spectrogram2 = normalize(spectrogram2)
    spectrogram3 = normalize(spectrogram3)

    from matplotlib import pyplot as plt

    # Create a figure with three subplots (1 row, 3 columns)
    fig, axes = plt.subplots(1, 3, figsize=(10, 5))

    # Display the first spectrogram
    axes[0].imshow(spectrogram1, aspect="auto", cmap="gray")
    axes[0].set_title("spectrogram 1")
    axes[0].axis("off")  # Hide axes

    # Display the second spectrogram
    axes[1].imshow(spectrogram2, aspect="auto", cmap="gray")
    axes[1].set_title("spectrogram 2")
    axes[1].axis("off")

    # Display the third spectrogram
    axes[2].imshow(spectrogram3, aspect="auto", cmap="gray")
    axes[2].set_title("spectrogram 3")
    axes[2].axis("off")

    plt.tight_layout()
    plt.show()


def enhance_stereo(model, lossy_path, output_path):
    # Load stereo audio (returns shape (2, T) if stereo; mono=False preserves channels)
    y, sr = librosa.load(lossy_path, sr=SR, mono=False)

    # Ensure shape is (2, T)
    if y.ndim == 1:
        raise ValueError("Input is mono! Expected stereo.")

    y_l = y[0]
    y_r = y[1]
    y_enhanced_l = enhance_audio(model, y_l, sr)
    y_enhanced_r = enhance_audio(model, y_r, sr)
    stereo_reconstructed = np.vstack((y_enhanced_l, y_enhanced_r))

    # Save (soundfile expects (T, 2) for stereo, hence the transpose)
    sf.write(output_path, stereo_reconstructed.T, sr)


# accepts shape (513, T)!!!!
def denorm_torch(spectrogram):
    return spectrogram * (freq_std_torch[:, None] + 1e-8) + freq_mean_torch[:, None]


# accepts shape (513, T)!!!!
def normalize(spectrogram):
    return (spectrogram - freq_mean[:, None]) / (freq_std[:, None] + 1e-8)


# accepts shape (513, T)!!!!
def denorm(spectrogram):
    return spectrogram * (freq_std[:, None] + 1e-8) + freq_mean[:, None]


def enhance_audio(model, audio, sr):
    # Convert to log-mag
    X = audio_to_logmag(audio)  # (513, T)
    stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    X_tensor = torch.tensor(X, dtype=torch.float32).unsqueeze(0).to(device)  # (1, 513, T)

    with torch.no_grad():
        Y_pred = model(X_tensor).cpu().numpy()  # (1, 513, T)

    Y_pred = Y_pred.squeeze(0)  # back to (513, T)
    Y_pred = denorm(Y_pred)
    # Clip to valid range (log1p output ≥ 0)
    Y_pred = np.maximum(Y_pred, 0)

    # Invert log
    mag_pred = np.expm1(Y_pred)  # inverse of log1p
    phase_lossy = np.angle(stft)

    # Combine: enhanced mag + original phase
    stft_enhanced = mag_pred * np.exp(1j * phase_lossy)
    y_reconstructed = librosa.istft(stft_enhanced, n_fft=N_FFT, hop_length=HOP)
    return y_reconstructed


def enhance_mono(model, lossy_path, output_path):
    # Load
    x, sr = librosa.load(lossy_path, sr=SR)
    y_reconstructed = enhance_audio(model, x, sr)
    # Save
    sf.write(output_path, y_reconstructed, sr)


if __name__ == "__main__":
    main()
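# Usage (inferred from the argument handling in main(); replace <this_script>.py
# with whatever filename this script is saved under):
#   python <this_script>.py                   -> confirm the device, then train from scratch
#   python <this_script>.py checkpointN.pth   -> load that checkpoint and run the demo example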