# File info (from source viewer): 90 lines, 2.4 KiB, Python
import os
|
|
import sys
|
|
import random
|
|
import warnings
|
|
import librosa
|
|
import numpy as np
|
|
import torch.nn as nn
|
|
import torch
|
|
from caveman_wavedataset import WaveformDataset
|
|
from torch.utils.data import DataLoader
|
|
from torch.cuda.amp import autocast, GradScaler
|
|
import torch.optim.lr_scheduler as lr_scheduler
|
|
from tqdm import tqdm
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Dataset roots, passed to WaveformDataset as (lossy_dir, clean_dir).
CLEAN_DATA_DIR: str = "./fma_small"
LOSSY_DATA_DIR: str = "./fma_small_compressed_64/"

# Audio / spectrogram parameters.
SR: int = 44100          # sample rate
DURATION: float = 2.0    # seconds per clip
N_MELS = None            # None -> use the full STFT rather than a mel basis
HOP: int = 512
# NOTE(review): a one-sided STFT with N_FFT = 1024 has N_FFT // 2 + 1 = 513
# bins, which matches the 513-bin statistics used elsewhere in this file —
# confirm WaveformDataset actually uses these values.
N_FFT: int = 1024

# Training / loading parameters.
BATCH_SIZE: int = 16
EPOCHS: int = 100
PREFETCH: int = 4        # DataLoader prefetch_factor (batches per worker)
|
|
|
|
# Data
dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, sr=SR)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    prefetch_factor=PREFETCH,
    shuffle=False,
    pin_memory=True,
    num_workers=16,
)

# Per-frequency normalisation statistics: a dict with "mean" and "std"
# tensors, in the format written by
#     torch.save({"mean": ..., "std": ...}, "freq_stats.pth")
FREQ_STATS_PATH = "freq_stats.pth"

if os.path.exists(FREQ_STATS_PATH):
    # map_location="cpu" so .numpy() below works even if the tensors were
    # saved while still on a GPU.
    stats = torch.load(FREQ_STATS_PATH, map_location="cpu")

    freq_mean = stats["mean"].numpy()  # (513,)
    freq_std = stats["std"].numpy()  # (513,)

    freq_mean_torch = stats["mean"]  # (513,)
    freq_std_torch = stats["std"]  # (513,)

    # NOTE(review): if WaveformDataset normalises its outputs with
    # mean/std, any statistics computed from `loader` after this point are
    # statistics of already-normalised data — confirm this is intended.
    dataset.mean = freq_mean
    dataset.std = freq_std
else:
    # Loading unconditionally made the module un-importable before the stats
    # file existed, so the stats computation in __main__ could never run on a
    # first pass. Fall back to leaving dataset.mean/std untouched.
    warnings.warn(
        f"{FREQ_STATS_PATH!r} not found; dataset.mean/std were not set. "
        "Compute the statistics with compute_per_freq_stats_vectorized() "
        "and save them with torch.save().",
        RuntimeWarning,
    )
    stats = None
    freq_mean = freq_std = None
    freq_mean_torch = freq_std_torch = None
|
|
|
|
|
|
def compute_per_freq_stats_vectorized(dataloader, freq_bins=513, device="cuda"):
    """Compute per-frequency-bin mean and standard deviation over a dataloader.

    Statistics are accumulated batch by batch using the parallel (Chan et al.)
    form of Welford's algorithm, so the dataset never has to fit in memory.
    Every time frame of every item contributes one sample per frequency bin.

    Args:
        dataloader: Iterable of 2-tuples. Only the second element is used. It
            must be a tensor of shape (B, 1, freq_bins, T).
        freq_bins: Expected number of frequency bins.
        device: Device on which the accumulation runs.

    Returns:
        A ``(mean, std)`` pair of CPU tensors, each of shape ``(freq_bins,)``.
        ``std`` is the sample standard deviation (divisor ``N - 1``). If at
        most one frame was seen in total, ``std`` is all ones.

    Raises:
        ValueError: If a batch is not shaped (B, 1, freq_bins, T).
    """
    mean = torch.zeros(freq_bins, device=device)
    m2 = torch.zeros(freq_bins, device=device)  # running sum of squared deviations
    total_frames = 0

    with torch.no_grad():
        # NOTE(review): the second tuple element is named "lossy", but the
        # dataset is built as WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR).
        # Confirm which element of the pair really is the lossy spectrogram.
        for _, lossy in tqdm(dataloader, desc="Stats"):
            # non_blocking pairs with a pin_memory=True loader; harmless on CPU.
            x = lossy.to(device, non_blocking=True)
            if x.dim() != 4 or x.shape[1] != 1 or x.shape[2] != freq_bins:
                raise ValueError(
                    f"expected batches of shape (B, 1, {freq_bins}, T), "
                    f"got {tuple(x.shape)}"
                )

            # (B, 1, F, T) -> (F, B*T): one column per time frame.
            x_flat = x.squeeze(1).permute(1, 0, 2).reshape(freq_bins, -1)
            n = x_flat.shape[1]
            if n == 0:
                continue

            new_total = total_frames + n
            new_mean = mean + (x_flat.sum(dim=1) - n * mean) / new_total

            # sum((x - old_mean) * (x - new_mean)) equals the batch's own M2
            # plus Chan's cross term n_old * n / new_total * (batch_mean - old_mean)**2,
            # i.e. exactly the increment the parallel variance update needs.
            m2 += (
                (x_flat - mean.unsqueeze(1)) * (x_flat - new_mean.unsqueeze(1))
            ).sum(dim=1)

            # `mean` is rebound rather than mutated, so no defensive clone is needed.
            mean = new_mean
            total_frames = new_total

    if total_frames > 1:
        # Rounding can leave a near-zero M2 slightly negative; clamp so sqrt
        # never produces NaN for (almost) constant frequency bins.
        std = torch.sqrt(m2.clamp_min(0) / (total_frames - 1))
    else:
        std = torch.ones_like(mean)

    return mean.cpu(), std.cpu()
|
|
|
|
|
|
if __name__ == "__main__":
    # Compute the per-frequency statistics over `loader` and print the
    # (mean, std) pair. To persist them for later runs instead, compute on
    # device="cuda" and save with:
    #     torch.save({"mean": freq_mean, "std": freq_std}, "freq_stats.pth")
    per_freq_stats = compute_per_freq_stats_vectorized(loader)
    print(per_freq_stats)
|