Files
caveman/make_stats.py
2026-01-10 20:35:21 +01:00

90 lines
2.4 KiB
Python

import os
import sys
import random
import warnings
import librosa
import numpy as np
import torch.nn as nn
import torch
from caveman_wavedataset import WaveformDataset
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm
# --- data locations -------------------------------------------------------
CLEAN_DATA_DIR = "./fma_small"
LOSSY_DATA_DIR = "./fma_small_compressed_64/"

# --- audio / spectrogram settings -----------------------------------------
SR = 44_100      # sample rate (Hz)
DURATION = 2.0   # seconds per clip
N_MELS = None    # no mel projection: the full linear-frequency STFT is used for now
HOP = 512        # hop size; not referenced in the code shown in this file
N_FFT = 1024     # FFT size; N_FFT // 2 + 1 == 513, matching the 513 bins assumed below

# --- loader / training settings -------------------------------------------
BATCH_SIZE = 16
EPOCHS = 100
PREFETCH = 4     # passed to DataLoader as prefetch_factor (batches loaded ahead per worker)
# Data: one dataset over the lossy/clean directory pair, streamed in order.
dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, sr=SR)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # a single ordered pass is enough; the statistics do not depend on order
    num_workers=16,
    prefetch_factor=PREFETCH,  # batches each worker loads ahead of time
    pin_memory=True,  # page-locked host buffers for faster host-to-GPU copies
)
# Per-frequency statistics saved by a previous run. When the file exists its
# tensors are exposed as numpy / torch globals and attached to the dataset.
# On a first run the file does not exist yet, so warn instead of crashing at
# import time; the globals are then defined as None.
STATS_PATH = "freq_stats.pth"
if os.path.isfile(STATS_PATH):
    # map_location="cpu": .numpy() below requires CPU tensors.
    stats = torch.load(STATS_PATH, map_location="cpu")
    freq_mean = stats["mean"].numpy()  # (513,)
    freq_std = stats["std"].numpy()  # (513,)
    freq_mean_torch = stats["mean"]  # (513,)
    freq_std_torch = stats["std"]  # (513,)
    # NOTE(review): presumably WaveformDataset uses these to normalise what it
    # returns — confirm. If so, statistics computed from `loader` afterwards
    # describe the normalised data, not the raw spectra.
    dataset.mean = freq_mean
    dataset.std = freq_std
else:
    warnings.warn(
        f"{STATS_PATH} not found; dataset.mean/dataset.std are left unset "
        "(expected on the first run, before the stats file has been generated)."
    )
    stats = None
    freq_mean = freq_std = None
    freq_mean_torch = freq_std_torch = None
def compute_per_freq_stats_vectorized(
    dataloader, freq_bins=513, device="cuda", *, progress=True
):
    """Compute the per-frequency-bin mean and sample std over a whole loader.

    Uses a batched form of Welford's online algorithm, so the dataset never has
    to be held in memory: every batch is folded into running ``mean`` / ``M2``
    accumulators with one entry per frequency bin.

    Args:
        dataloader: Iterable of 2-element batches; only the SECOND element is
            used. It must be a 4-D tensor ``(B, C, F, T)`` with
            ``F == freq_bins``. Batch items, channels and time frames are all
            pooled as samples of the same per-bin distribution.
        freq_bins: Expected size of the frequency axis (dim 2).
        device: Device on which the accumulation runs.
        progress: Keyword-only. Wrap the loader in a ``tqdm`` progress bar.

    Returns:
        ``(mean, std)`` CPU tensors of shape ``(freq_bins,)``. ``std`` uses the
        unbiased ``N - 1`` denominator. With fewer than two frames in total it
        is all ones. Bins with zero variance get ``std == 0``.

    Raises:
        ValueError: If a batch is not 4-D or its frequency axis does not match
            ``freq_bins``. This prevents silent broadcasting against the
            accumulators, e.g. when ``F == 1``.
    """
    mean = torch.zeros(freq_bins, device=device)
    m2 = torch.zeros(freq_bins, device=device)
    total_frames = 0
    batches = tqdm(dataloader, desc="Stats") if progress else dataloader
    with torch.no_grad():
        # NOTE(review): the second element is called "lossy", but the dataset
        # is built as WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, ...);
        # confirm which member of the pair this really is.
        for _, lossy in batches:
            x = lossy.to(device, non_blocking=True)
            if x.dim() != 4 or x.shape[2] != freq_bins:
                raise ValueError(
                    f"expected batches shaped (B, C, {freq_bins}, T), "
                    f"got {tuple(x.shape)}"
                )
            # (B, C, F, T) -> (F, B*C*T): one row of samples per frequency bin.
            x_flat = x.permute(2, 0, 1, 3).reshape(freq_bins, -1)
            n = x_flat.shape[1]
            if n == 0:
                continue
            new_total = total_frames + n
            # Centre on the old mean *before* summing: sum(x - mean_old) is
            # numerically safer than sum(x) - n * mean_old when |mean| >> std.
            delta_old = x_flat - mean.unsqueeze(1)
            new_mean = mean + delta_old.sum(dim=1) / new_total
            delta_new = x_flat - new_mean.unsqueeze(1)
            # Batched Welford step: sum((x - mean_old) * (x - mean_new)) equals
            # the batch's own M2 plus the mean-shift correction of the
            # parallel (Chan et al.) update.
            m2 += (delta_old * delta_new).sum(dim=1)
            mean = new_mean
            total_frames = new_total
    if total_frames > 1:
        std = torch.sqrt(m2 / (total_frames - 1))
    else:
        std = torch.ones_like(mean)
    return mean.cpu(), std.cpu()
if __name__ == "__main__":
    # Default run: compute the per-bin statistics over `loader` and print them.
    # To (re)generate the stats file instead, replace the print with:
    #     freq_mean, freq_std = compute_per_freq_stats_vectorized(loader, device="cuda")
    #     torch.save({"mean": freq_mean, "std": freq_std}, "freq_stats.pth")
    # NOTE(review): dataset.mean / dataset.std are assigned at import time from
    # an existing freq_stats.pth. If WaveformDataset normalises with them, the
    # numbers computed here describe normalised data — confirm before saving.
    mean_and_std = compute_per_freq_stats_vectorized(loader)
    print(mean_and_std)