import sys
import warnings

import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from tqdm import tqdm

from caveman_wavedataset import WaveformDataset

warnings.filterwarnings("ignore")

CLEAN_DATA_DIR = "./fma_small"
LOSSY_DATA_DIR = "./fma_small_compressed_64/"
SR = 44100  # sample rate
DURATION = 2.0  # seconds per clip
N_MELS = None  # we'll use the full STFT for now
HOP = 512
N_FFT = 1024

BATCH_SIZE = 4
EPOCHS = 100


def audio_to_logmag(audio):
    # STFT -> magnitude -> log1p (log(1 + x) avoids log(0))
    stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
    mag = np.abs(stft)
    logmag = np.log1p(mag)
    return logmag  # shape: (freq_bins, time_frames) = (513, T)


class CavemanEnhancer(nn.Module):
    """Small fully convolutional net that maps a lossy log-magnitude
    spectrogram to an enhanced one of the same shape."""

    def __init__(self, freq_bins=513):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=3, padding=1),
        )

    def forward(self, x):
        # x: (batch, freq_bins, time_frames) log-magnitude spectrograms.
        # Conv2d wants an explicit channel dimension, so add it on the way in
        # and drop it on the way out.
        return self.net(x.unsqueeze(1)).squeeze(1)


def main():
    # Model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(device)
    ans = input("Do you want to use this device? (y/n) ")
    if ans != "y":
        sys.exit(1)
    model = CavemanEnhancer().to(device)

    # Inference mode: pass a checkpoint path as the first CLI argument.
    if len(sys.argv) > 1:
        model.load_state_dict(
            torch.load(sys.argv[1], weights_only=False)["model_state_dict"]
        )
        model.eval()
        enhance_audio(
            model,
            "./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
            "./examples/mirror_mirror/mirror_mirror_decompressed_64_mse.wav",
        )

        # Baseline for comparison: round-trip the compressed file through the
        # same STFT -> Griffin-Lim pipeline without the model.
        x, sr = librosa.load(
            "./examples/mirror_mirror/mirror_mirror_compressed_64.mp3", sr=SR
        )
        X = audio_to_logmag(x)  # (513, T)
        Y_pred = np.maximum(X, 0)  # clip to valid range (log1p output >= 0)
        mag_pred = np.expm1(Y_pred)  # inverse of log1p
        y_reconstructed = librosa.griffinlim(
            mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
        )
        sf.write(
            "./examples/mirror_mirror/mirror_mirror_compressed_64_STFT.mp3",
            y_reconstructed,
            sr,
        )
        return

    # Data: hold out the last 10% of the dataset for validation.
    dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, sr=SR)
    n_val = int(0.1 * len(dataset))
    n_train = len(dataset) - n_val
    train_dataset = torch.utils.data.Subset(dataset, list(range(n_train)))
    val_dataset = torch.utils.data.Subset(dataset, list(range(n_train, len(dataset))))
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        num_workers=16,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=10,
        pin_memory=True,
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.1,
        patience=5,
        cooldown=3,
        threshold=1e-3,
    )
    criterion = nn.L1Loss()
    # GradScaler pairs with autocast for mixed-precision training; both are
    # disabled when running on CPU.
    scaler = GradScaler(enabled=(device == "cuda"))

    # Train
    for epoch in range(EPOCHS):
        model.train()
        for lossy, clean in tqdm(train_loader, desc="Training"):
            lossy, clean = lossy.to(device), clean.to(device)
            optimizer.zero_grad()
            with autocast(enabled=(device == "cuda")):
                enhanced = model(lossy)
                loss = criterion(enhanced, clean)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        # Validate
        model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for lossy, clean in tqdm(val_loader, desc="Validating"):
                lossy, clean = lossy.to(device), clean.to(device)
                output = model(lossy)
                total_loss += criterion(output, clean).item()
        val_loss = total_loss / len(val_loader)  # average over validation batches
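        # ReduceLROnPlateau (configured above) multiplies the LR by `factor`
        # once val_loss has failed to improve by `threshold` for `patience`
        # epochs, then waits `cooldown` epochs before counting again.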
        scheduler.step(val_loss)  # update learning rate based on validation loss
        if (epoch + 1) % 10 == 0:
            lr = optimizer.param_groups[0]["lr"]
            print(f"LR: {lr:.6f}")
        print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}, Val: {val_loss:.4f}")
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss.item(),
            },
            f"checkpoint{epoch}.pth",
        )


def enhance_audio(model, lossy_path, output_path):
    # Load and convert to a log-magnitude spectrogram
    x, sr = librosa.load(lossy_path, sr=SR)
    X = audio_to_logmag(x)  # (513, T)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    X_tensor = torch.tensor(X, dtype=torch.float32).unsqueeze(0).to(device)  # (1, 513, T)
    with torch.no_grad():
        Y_pred = model(X_tensor).cpu().numpy()  # (1, 513, T)
    Y_pred = Y_pred.squeeze(0)  # back to (513, T)
    # Clip to valid range (log1p output >= 0)
    Y_pred = np.maximum(Y_pred, 0)
    # Invert the log
    mag_pred = np.expm1(Y_pred)  # inverse of log1p
    # Reconstruct a waveform with Griffin-Lim; the model predicts magnitudes
    # only, so phase has to be estimated
    y_reconstructed = librosa.griffinlim(
        mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
    )
    # Save
    sf.write(output_path, y_reconstructed, sr)


if __name__ == "__main__":
    main()
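
# Usage, as implied by the sys.argv check in main() (script name assumed):
#   python caveman_train.py                  -> train; writes checkpoint{epoch}.pth each epoch
#   python caveman_train.py checkpoint9.pth  -> load that checkpoint, enhance the example
#                                               clip, and write the Griffin-Lim baseline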