added new fixed code
This commit is contained in:
89
make_stats.py
Normal file
89
make_stats.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import os
|
||||
import sys
|
||||
import random
|
||||
import warnings
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from caveman_wavedataset import WaveformDataset
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
import torch.optim.lr_scheduler as lr_scheduler
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
CLEAN_DATA_DIR = "./fma_small"
|
||||
LOSSY_DATA_DIR = "./fma_small_compressed_64/"
|
||||
SR = 44100 # sample rate
|
||||
DURATION = 2.0 # seconds per clip
|
||||
N_MELS = None # we'll use full STFT for now
|
||||
HOP = 512
|
||||
N_FFT = 1024
|
||||
BATCH_SIZE = 16
|
||||
EPOCHS = 100
|
||||
PREFETCH = 4
|
||||
|
||||
# Data
|
||||
dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, sr=SR)
|
||||
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=BATCH_SIZE,
|
||||
prefetch_factor=PREFETCH,
|
||||
shuffle=False,
|
||||
pin_memory=True,
|
||||
num_workers=16,
|
||||
)
|
||||
|
||||
stats = torch.load("freq_stats.pth")
|
||||
freq_mean = stats["mean"].numpy() # (513,)
|
||||
freq_std = stats["std"].numpy() # (513,)
|
||||
|
||||
freq_mean_torch = stats["mean"] # (513,)
|
||||
freq_std_torch = stats["std"] # (513,)
|
||||
|
||||
dataset.mean = freq_mean
|
||||
dataset.std = freq_std
|
||||
|
||||
|
||||
def compute_per_freq_stats_vectorized(dataloader, freq_bins=513, device="cuda"):
|
||||
mean = torch.zeros(freq_bins, device=device)
|
||||
M2 = torch.zeros(freq_bins, device=device)
|
||||
total_frames = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for _, lossy in tqdm(dataloader, desc="Stats"):
|
||||
x = lossy.to(device) # (B, 1, F, T)
|
||||
B, _, F, T = x.shape
|
||||
x_flat = x.squeeze(1).permute(1, 0, 2).reshape(F, -1) # (F, N)
|
||||
N = x_flat.shape[1]
|
||||
|
||||
if N == 0:
|
||||
continue
|
||||
|
||||
# Current mean (broadcast)
|
||||
mean_old = mean.clone()
|
||||
|
||||
# Update mean: mean = mean + (sum(x) - N*mean) / (total + N)
|
||||
mean = mean_old + (x_flat.sum(dim=1) - N * mean_old) / (total_frames + N)
|
||||
|
||||
# Update M2: M2 += (x - mean_old) * (x - new_mean)
|
||||
delta_old = x_flat - mean_old.unsqueeze(1) # (F, N)
|
||||
delta_new = x_flat - mean.unsqueeze(1) # (F, N)
|
||||
M2 += (delta_old * delta_new).sum(dim=1) # (F,)
|
||||
|
||||
total_frames += N
|
||||
|
||||
std = (
|
||||
torch.sqrt(M2 / (total_frames - 1))
|
||||
if total_frames > 1
|
||||
else torch.ones_like(mean)
|
||||
)
|
||||
return mean.cpu(), std.cpu()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# freq_mean, freq_std = compute_per_freq_stats_vectorized(loader, device="cuda")
|
||||
# torch.save({"mean": freq_mean, "std": freq_std}, "freq_stats.pth")
|
||||
print(compute_per_freq_stats_vectorized(loader))
|
||||
Reference in New Issue
Block a user