"""Compute per-frequency-bin mean/std statistics over the lossy spectrogram set.

Importing this module builds the dataset and loader, loads the previously
saved statistics from ``freq_stats.pth`` and attaches them to the dataset.
Running it as a script prints freshly computed statistics.
"""

import os
import sys
import random
import warnings

import librosa
import numpy as np
import torch.nn as nn
import torch
from caveman_wavedataset import WaveformDataset
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm

CLEAN_DATA_DIR = "./fma_small"
LOSSY_DATA_DIR = "./fma_small_compressed_64/"
SR = 44100  # sample rate
DURATION = 2.0  # seconds per clip
N_MELS = None  # we'll use full STFT for now
HOP = 512
N_FFT = 1024
BATCH_SIZE = 16
EPOCHS = 100
PREFETCH = 4

# Data
dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, sr=SR)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    prefetch_factor=PREFETCH,
    shuffle=False,
    pin_memory=True,
    num_workers=16,
)

# NOTE(review): this load runs at import time, so "freq_stats.pth" must already
# exist before the __main__ section below can recompute the statistics.
stats = torch.load("freq_stats.pth")
freq_mean = stats["mean"].numpy()  # (513,) per the original note -- TODO confirm
freq_std = stats["std"].numpy()  # (513,) per the original note -- TODO confirm
freq_mean_torch = stats["mean"]
freq_std_torch = stats["std"]

# NOTE(review): the stats are attached to the dataset before it is iterated in
# __main__. If WaveformDataset normalises with them, the printed statistics
# describe already-normalised data -- confirm against the dataset class.
dataset.mean = freq_mean
dataset.std = freq_std


def compute_per_freq_stats_vectorized(dataloader, freq_bins=513, device="cuda"):
    """Return the per-frequency-bin mean and sample std over a whole loader.

    Statistics are accumulated one batch at a time with a batched Welford
    update, so the full dataset never has to be held in memory.

    Args:
        dataloader: Iterable yielding 2-tuples. Only the second element is
            used; it must be a 4-D tensor whose dim 1 has size 1 and whose
            dim 2 is the frequency axis.
            NOTE(review): the second element is assumed to be the lossy
            spectrogram -- confirm against WaveformDataset.__getitem__.
        freq_bins: Expected size of the frequency axis.
        device: Device used for accumulation. If a CUDA device is requested
            but CUDA is unavailable, a RuntimeWarning is emitted and the CPU
            is used instead.

    Returns:
        ``(mean, std)`` as CPU tensors of shape ``(freq_bins,)``. ``std`` uses
        the ``n - 1`` denominator and is all ones when fewer than two frames
        were seen.

    Raises:
        ValueError: If a batch's frequency axis does not match ``freq_bins``.
    """
    if torch.device(device).type == "cuda" and not torch.cuda.is_available():
        warnings.warn(
            "CUDA is not available; computing frequency statistics on the CPU.",
            RuntimeWarning,
            stacklevel=2,
        )
        device = "cpu"

    mean = torch.zeros(freq_bins, device=device)
    m2 = torch.zeros(freq_bins, device=device)  # running sum of squared deviations
    total_frames = 0

    with torch.no_grad():
        for _, lossy in tqdm(dataloader, desc="Stats"):
            x = lossy.to(device)
            _, _, n_bins, _ = x.shape
            if n_bins != freq_bins:
                # Without this check a mismatch surfaces as an opaque broadcast
                # error, or broadcasts silently when either size is 1.
                raise ValueError(
                    f"expected {freq_bins} frequency bins, got {n_bins}"
                )

            # Fold batch and time together: one row per bin, one column per frame.
            x_flat = x.squeeze(1).permute(1, 0, 2).reshape(n_bins, -1)
            n_new = x_flat.shape[1]
            if n_new == 0:
                continue

            # `mean` is rebound below, never mutated in place, so no copy is needed.
            mean_old = mean
            mean = mean_old + (x_flat.sum(dim=1) - n_new * mean_old) / (
                total_frames + n_new
            )

            # M2 += sum((x - mean_old) * (x - mean_new)) over the new frames.
            delta_old = x_flat - mean_old.unsqueeze(1)
            delta_new = x_flat - mean.unsqueeze(1)
            m2 += (delta_old * delta_new).sum(dim=1)

            total_frames += n_new

    if total_frames > 1:
        std = torch.sqrt(m2 / (total_frames - 1))
    else:
        std = torch.ones_like(mean)

    return mean.cpu(), std.cpu()


if __name__ == "__main__":
    # To regenerate the stats file instead of just printing:
    # freq_mean, freq_std = compute_per_freq_stats_vectorized(loader, device="cuda")
    # torch.save({"mean": freq_mean, "std": freq_std}, "freq_stats.pth")
    print(compute_per_freq_stats_vectorized(loader))