Compare commits
4 Commits
79f93e5c29
...
7a04a3ac67
| Author | SHA1 | Date | |
|---|---|---|---|
| 7a04a3ac67 | |||
| d6df888387 | |||
| fb97f6c3a8 | |||
| 8379ac8e12 |
28
README.md
28
README.md
@@ -1,3 +1,31 @@
|
||||
# caveman
|
||||
|
||||
An audio reconstruction model that "uncompresses" lossy MP3s.
|
||||
|
||||
## Installing dependencies
|
||||
|
||||
Dependencies are declared in the pyproject.toml file, so you can install them from that.
|
||||
|
||||
Or, if you prefer to keep your sanity, use uv:
|
||||
|
||||
```
|
||||
uv sync
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
Adjust the location of your data directories via the `CLEAN_DATA_DIR` and `LOSSY_DATA_DIR` variables in the cavemanml.py file. Any compression type is accepted, as long as your audio is sampled at 44100 Hz (44.1 kHz).
|
||||
|
||||
```
|
||||
uv run cavemanml.py
|
||||
```
|
||||
|
||||
Obviously the original dataset (FMA) is not provided with this repo.
|
||||
|
||||
## Inference
|
||||
|
||||
For now, the only inference this code can do is the example. You can adjust the input and output files in the code, though. Theoretically, the model should handle any input. To use the provided checkpoint:
|
||||
|
||||
```
|
||||
uv run cavemanml.py ./checkpoint70.pth
|
||||
```
|
||||
|
||||
@@ -4,29 +4,21 @@ import librosa
|
||||
from torch.utils.data import Dataset
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
|
||||
HOP = 512
|
||||
N_FFT = 1024
|
||||
DURATION = 2.0
|
||||
SR = 44100
|
||||
|
||||
|
||||
def audio_to_logmag(audio):
|
||||
# STFT
|
||||
stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
|
||||
mag = np.abs(stft)
|
||||
logmag = np.log1p(mag) # log(1 + x) for stability
|
||||
return logmag # shape: (1, freq_bins, time_frames) = (1, 513, T)
|
||||
from settings import SR, N_FFT
|
||||
from misc import audio_to_logmag
|
||||
|
||||
|
||||
class WaveformDataset(Dataset):
|
||||
def __init__(self, lossy_dir, clean_dir, sr=SR, segment_sec=4):
|
||||
self.cache = dict()
|
||||
mean = np.zeros([N_FFT // 2 + 1])
|
||||
std = np.ones([N_FFT // 2 + 1])
|
||||
|
||||
# Duration is a very important parameter; read cavemanml.py to see how and why to adjust it!
|
||||
# For the purposes of this file, it's the length of the audio clip being selected from the dataset.
|
||||
def __init__(self, lossy_dir, clean_dir, segment_duration, sr=SR):
|
||||
self.segment_duration = segment_duration
|
||||
self.sr = sr
|
||||
self.lossy_dir = lossy_dir
|
||||
self.clean_dir = clean_dir
|
||||
self.segment_len = int(segment_sec * sr)
|
||||
self.lossy_files = sorted(os.listdir(lossy_dir))
|
||||
self.clean_files = sorted(os.listdir(clean_dir))
|
||||
self.file_pairs = [
|
||||
@@ -51,9 +43,9 @@ class WaveformDataset(Dataset):
|
||||
min_len = min(len(lossy), len(clean))
|
||||
lossy, clean = lossy[:min_len], clean[:min_len]
|
||||
|
||||
# Random 2-second clip
|
||||
# Random clip
|
||||
|
||||
clip_len = int(DURATION * SR)
|
||||
clip_len = int(self.segment_duration * SR)
|
||||
if min_len < clip_len:
|
||||
# pad if too short
|
||||
lossy = np.pad(lossy, (0, clip_len - min_len))
|
||||
@@ -61,14 +53,21 @@ class WaveformDataset(Dataset):
|
||||
start = 0
|
||||
else:
|
||||
start = random.randint(0, min_len - clip_len)
|
||||
# start = 0
|
||||
lossy = lossy[start : start + clip_len]
|
||||
clean = clean[start : start + clip_len]
|
||||
|
||||
logmag_x = audio_to_logmag(lossy)
|
||||
logmag_y = audio_to_logmag(clean)
|
||||
|
||||
logmag_x_norm = (logmag_x - self.mean[:, None]) / (self.std[:, None] + 1e-8)
|
||||
logmag_y_norm = (logmag_y - self.mean[:, None]) / (self.std[:, None] + 1e-8)
|
||||
|
||||
ans = (
|
||||
torch.from_numpy(audio_to_logmag(lossy)).unsqueeze(0),
|
||||
torch.from_numpy(audio_to_logmag(clean)).unsqueeze(0),
|
||||
torch.from_numpy(logmag_x_norm).float().unsqueeze(0),
|
||||
torch.from_numpy(logmag_y_norm).float().unsqueeze(0),
|
||||
)
|
||||
|
||||
self.cache[idx] = ans
|
||||
# self.cache[idx] = ans
|
||||
|
||||
return ans
|
||||
|
||||
315
cavemanml.py
315
cavemanml.py
@@ -1,53 +1,120 @@
|
||||
import os
|
||||
import sys
|
||||
import random
|
||||
import warnings
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import soundfile as sf
|
||||
from caveman_wavedataset import WaveformDataset
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from torch.cuda.amp import autocast
|
||||
import torch.optim.lr_scheduler as lr_scheduler
|
||||
from tqdm import tqdm
|
||||
from misc import audio_to_logmag
|
||||
from settings import N_FFT, HOP, SR
|
||||
from model import CavemanEnhancer
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
CLEAN_DATA_DIR = "./fma_small"
|
||||
LOSSY_DATA_DIR = "./fma_small_compressed_64/"
|
||||
SR = 44100 # sample rate
|
||||
DURATION = 2.0 # seconds per clip
|
||||
N_MELS = None # we'll use full STFT for now
|
||||
HOP = 512
|
||||
N_FFT = 1024
|
||||
|
||||
# Duration is the duration of each example to be selected from the dataset.
|
||||
# Since we are using the FMA dataset, its maximum value is 30 s.
|
||||
# From the standpoint of the model and training,
|
||||
# it should make absolutely no difference on quality,
|
||||
# only on the speed of the training process. If duration is larger,
|
||||
# model is going to be trained on more data per example.
|
||||
# If smaller, less data per example.
|
||||
#
|
||||
# YOU ONLY NEED TO CHANGE THIS IF YOU ARE CPU-BOUND WHEN
|
||||
# LOADING THE DATA DURING TRAINING. INCREASE TO PLACE MORE LOAD
|
||||
# ON THE GPU, REDUCE TO PUT MORE LOAD ON THE CPU. DO NOT ADJUST
|
||||
# THE BATCH SIZE, IT WILL MAKE NO DIFFERENCE, SINCE WE ARE ALWAYS
|
||||
# FORCED TO LOAD THE ENTIRE EXAMPLE FROM DISK EVERY SINGLE TIME.
|
||||
DURATION = 2
|
||||
BATCH_SIZE = 4
|
||||
# 100 is a bit ridiculous, but you are free to Ctrl-C at any time, since a checkpoint is saved after every epoch.
|
||||
EPOCHS = 100
|
||||
PREFETCH = 4
|
||||
|
||||
# stats = torch.load("freq_stats.pth")
|
||||
# freq_mean = stats["mean"].numpy() # (513,)
|
||||
# freq_std = stats["std"].numpy() # (513,)
|
||||
|
||||
freq_mean = np.zeros([N_FFT // 2 + 1])
|
||||
# (513,)
|
||||
freq_std = np.ones([N_FFT // 2 + 1]) # (513,)
|
||||
|
||||
|
||||
def audio_to_logmag(audio):
|
||||
# STFT
|
||||
stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
|
||||
mag = np.abs(stft)
|
||||
logmag = np.log1p(mag) # log(1 + x) for stability
|
||||
return logmag # shape: (freq_bins, time_frames) = (513, T)
|
||||
# freq_mean_torch = stats["mean"] # (513,)
|
||||
# freq_std_torch = stats["std"] # (513,)
|
||||
freq_mean_torch = torch.from_numpy(freq_mean)
|
||||
freq_std_torch = torch.from_numpy(freq_std)
|
||||
|
||||
|
||||
class CavemanEnhancer(nn.Module):
|
||||
def __init__(self, freq_bins=513):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(32, 32, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(32, 1, kernel_size=3, padding=1),
|
||||
def run_example(model_filename, device):
|
||||
model = CavemanEnhancer().to(device)
|
||||
model.load_state_dict(
|
||||
torch.load(model_filename, weights_only=False)["model_state_dict"]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# x: (batch, freq_bins)
|
||||
return self.net(x)
|
||||
model.eval()
|
||||
enhance_mono(
|
||||
model,
|
||||
"./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
|
||||
"./examples/mirror_mirror/mirror_mirror_decompressed_64.wav",
|
||||
)
|
||||
|
||||
# Load
|
||||
x, sr = librosa.load(
|
||||
"./examples/mirror_mirror/mirror_mirror.mp3",
|
||||
sr=SR,
|
||||
)
|
||||
|
||||
BATCH_SIZE = 4
|
||||
EPOCHS = 100
|
||||
# Convert to log-mag
|
||||
X = audio_to_logmag(x) # (513, T)
|
||||
|
||||
Y_pred = normalize(X)
|
||||
Y_pred = denorm(Y_pred)
|
||||
|
||||
# Clip to valid range (log1p output ≥ 0)
|
||||
Y_pred = np.maximum(X, 0)
|
||||
|
||||
stft = librosa.stft(x, n_fft=N_FFT, hop_length=HOP)
|
||||
# Invert log
|
||||
mag_pred = np.expm1(Y_pred) # inverse of log1p
|
||||
phase_lossy = np.angle(stft)
|
||||
|
||||
# Reconstruct with Griffin-Lim
|
||||
# y_reconstructed = librosa.griffinlim(
|
||||
# mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
|
||||
# )
|
||||
|
||||
# Combine: enhanced mag + original phase
|
||||
stft_enhanced = mag_pred * np.exp(1j * phase_lossy)
|
||||
y_reconstructed = librosa.istft(stft_enhanced, n_fft=N_FFT, hop_length=HOP)
|
||||
|
||||
time = np.minimum(x.shape[0], y_reconstructed.shape[0])
|
||||
|
||||
print(
|
||||
f"Loss from reconstruction: {nn.MSELoss()(torch.from_numpy(x[:time]), torch.from_numpy(y_reconstructed[:time]))}"
|
||||
)
|
||||
|
||||
# Save
|
||||
sf.write(
|
||||
"./examples/mirror_mirror/mirror_mirror_STFT.mp3",
|
||||
y_reconstructed,
|
||||
sr,
|
||||
)
|
||||
|
||||
show_spectrogram(
|
||||
"./examples/mirror_mirror/mirror_mirror_STFT.mp3",
|
||||
"./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
|
||||
"./examples/mirror_mirror/mirror_mirror_decompressed_64.wav",
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def main():
|
||||
@@ -58,94 +125,96 @@ def main():
|
||||
if ans != "y":
|
||||
exit(1)
|
||||
|
||||
model = CavemanEnhancer().to(device)
|
||||
if len(sys.argv) > 1:
|
||||
model.load_state_dict(
|
||||
torch.load(sys.argv[1], weights_only=False)["model_state_dict"]
|
||||
)
|
||||
model.eval()
|
||||
enhance_audio(
|
||||
model,
|
||||
"./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
|
||||
"./examples/mirror_mirror/mirror_mirror_decompressed_64_mse.wav",
|
||||
)
|
||||
|
||||
# Load
|
||||
x, sr = librosa.load("./examples/mirror_mirror/mirror_mirror_compressed_64.mp3", sr=SR)
|
||||
|
||||
# Convert to log-mag
|
||||
X = audio_to_logmag(x) # (513, T)
|
||||
|
||||
# Clip to valid range (log1p output ≥ 0)
|
||||
Y_pred = np.maximum(X, 0)
|
||||
|
||||
# Invert log
|
||||
mag_pred = np.expm1(Y_pred) # inverse of log1p
|
||||
|
||||
# Reconstruct with Griffin-Lim
|
||||
y_reconstructed = librosa.griffinlim(
|
||||
mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
|
||||
)
|
||||
|
||||
import soundfile as sf
|
||||
|
||||
# Save
|
||||
sf.write("./examples/mirror_mirror/mirror_mirror_compressed_64_STFT.mp3", y_reconstructed, sr)
|
||||
return
|
||||
run_example(sys.argv[1], device)
|
||||
exit(0)
|
||||
|
||||
# Data
|
||||
dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, sr=SR)
|
||||
dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, DURATION, sr=SR)
|
||||
dataset.mean = freq_mean
|
||||
dataset.std = freq_std
|
||||
|
||||
# separate the test and val data
|
||||
n_val = int(0.1 * len(dataset))
|
||||
n_train = len(dataset) - n_val
|
||||
|
||||
train_indices = list(range(n_train))
|
||||
val_indices = list(range(n_train, len(dataset)))
|
||||
|
||||
train_dataset = torch.utils.data.Subset(dataset, train_indices)
|
||||
val_dataset = torch.utils.data.Subset(dataset, val_indices)
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=BATCH_SIZE,
|
||||
prefetch_factor=PREFETCH,
|
||||
shuffle=True,
|
||||
pin_memory=True,
|
||||
num_workers=16,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=4,
|
||||
batch_size=BATCH_SIZE,
|
||||
prefetch_factor=PREFETCH,
|
||||
shuffle=False,
|
||||
num_workers=10,
|
||||
num_workers=16,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
||||
# model
|
||||
model = CavemanEnhancer().to(device)
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
|
||||
# I am actually not sure it really improves anything, but there is little reason not to keep this, I guess.
|
||||
scheduler = lr_scheduler.ReduceLROnPlateau(
|
||||
optimizer,
|
||||
mode="min",
|
||||
factor=0.1,
|
||||
patience=5,
|
||||
cooldown=3,
|
||||
threshold=1e-3,
|
||||
patience=3,
|
||||
# cooldown=3,
|
||||
# threshold=1e-3,
|
||||
)
|
||||
|
||||
from tqdm import tqdm
|
||||
# Weight: emphasize high frequencies
|
||||
weight = torch.linspace(1.0, 8.0, 513).to(device) # low=1x, high=8x
|
||||
weight = torch.exp(weight)
|
||||
weight = weight.view(1, 513, 1)
|
||||
|
||||
criterion = nn.L1Loss()
|
||||
def weighted_l1_loss(pred, target):
|
||||
return torch.mean(weight * torch.abs(pred - target))
|
||||
|
||||
# criterion = nn.L1Loss()
|
||||
criterion = weighted_l1_loss
|
||||
# criterion = nn.MSELoss()
|
||||
|
||||
# baseline (doing nothing)
|
||||
# if True:
|
||||
# loss = 0.0
|
||||
# for lossy, clean in tqdm(train_loader, desc="Baseline"):
|
||||
# loss += criterion(clean, lossy)
|
||||
# loss /= len(train_loader)
|
||||
# print(f"baseling loss: {loss:.4f}")
|
||||
|
||||
# Train
|
||||
for epoch in range(EPOCHS):
|
||||
model.train()
|
||||
train_loss = 0.0
|
||||
for lossy, clean in tqdm(train_loader, desc="Training"):
|
||||
lossy, clean = lossy.to(device), clean.to(device)
|
||||
|
||||
with autocast():
|
||||
enhanced = model(lossy)
|
||||
loss = criterion(clean, enhanced)
|
||||
train_loss += loss
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
train_loss /= len(train_loader)
|
||||
|
||||
# Validate (per epoch)
|
||||
model.eval()
|
||||
total_loss = 0.0
|
||||
val_loss = 0
|
||||
@@ -154,17 +223,17 @@ def main():
|
||||
lossy, clean = lossy.to(device), clean.to(device)
|
||||
output = model(lossy)
|
||||
loss_ = criterion(output, clean)
|
||||
|
||||
total_loss += loss_.item()
|
||||
val_loss = total_loss / len(train_loader)
|
||||
|
||||
scheduler.step(val_loss) # Update learning rate based on validation loss
|
||||
|
||||
if (epoch + 1) % 10 == 0:
|
||||
lr = optimizer.param_groups[0]["lr"]
|
||||
print(f"LR: {lr:.6f}")
|
||||
print(f"Epoch {epoch + 1}, Loss: {train_loss:.4f}, Val: {val_loss:.4f}")
|
||||
|
||||
print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}, Val: {val_loss:.4f}")
|
||||
|
||||
# Yes, we are saving checkpoints for every epoch. The model is small, and disk space cheap.
|
||||
torch.save(
|
||||
{
|
||||
"epoch": epoch,
|
||||
@@ -176,12 +245,86 @@ def main():
|
||||
)
|
||||
|
||||
|
||||
def enhance_audio(model, lossy_path, output_path):
|
||||
# Load
|
||||
x, sr = librosa.load(lossy_path, sr=SR)
|
||||
def file_to_logmag(path):
|
||||
y, sr = librosa.load(path, sr=SR, mono=True)
|
||||
print(y.shape)
|
||||
return np.squeeze(audio_to_logmag(y))
|
||||
|
||||
|
||||
def show_spectrogram(path1, path2, path3):
|
||||
spectrogram1 = file_to_logmag(path1)
|
||||
spectrogram2 = file_to_logmag(path2)
|
||||
spectrogram3 = file_to_logmag(path3)
|
||||
|
||||
spectrogram1 = normalize(spectrogram1)
|
||||
spectrogram2 = normalize(spectrogram2)
|
||||
spectrogram3 = normalize(spectrogram3)
|
||||
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
# Create a figure with three subplots (1 row, 3 columns)
|
||||
fig, axes = plt.subplots(1, 3, figsize=(10, 5))
|
||||
|
||||
# Display the first image
|
||||
axes[0].imshow(spectrogram1, aspect="auto", cmap="gray")
|
||||
axes[0].set_title("spectrogram 1")
|
||||
axes[0].axis("off") # Hide axes
|
||||
|
||||
# Display the second image
|
||||
axes[1].imshow(spectrogram2, aspect="auto", cmap="gray")
|
||||
axes[1].set_title("spectrogram 2")
|
||||
axes[1].axis("off")
|
||||
|
||||
# Display the third image
|
||||
axes[2].imshow(spectrogram3, aspect="auto", cmap="gray")
|
||||
axes[2].set_title("spectrogram 3")
|
||||
axes[2].axis("off")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
def enhance_stereo(model, lossy_path, output_path):
|
||||
# Load stereo audio (returns shape: (2, T) if stereo)
|
||||
y, sr = librosa.load(lossy_path, sr=SR, mono=False) # mono=False preserves channels
|
||||
|
||||
# Ensure shape is (2, T)
|
||||
if y.ndim == 1:
|
||||
raise ValueError("Input is mono! Expected stereo.")
|
||||
|
||||
y_l = y[0]
|
||||
y_r = y[1]
|
||||
|
||||
y_enhanced_l = enhance_audio(model, y_l, sr)
|
||||
y_enhanced_r = enhance_audio(model, y_r, sr)
|
||||
|
||||
stereo_reconstructed = np.vstack((y_enhanced_l, y_enhanced_r))
|
||||
|
||||
import soundfile as sf
|
||||
|
||||
# Save (soundfile expects (frames, channels), hence the transpose from (2, T) to (T, 2) below)
|
||||
sf.write(output_path, stereo_reconstructed.T, sr) # Note: .T to (T, 2) if required
|
||||
|
||||
|
||||
# accepts shape (513, T)!!!!
|
||||
def denorm_torch(spectrogram):
|
||||
return spectrogram * (freq_std_torch[:, None] + 1e-8) + freq_mean_torch[:, None]
|
||||
|
||||
|
||||
# accepts shape (513, T)!!!!
|
||||
def normalize(spectrogram):
|
||||
return (spectrogram - freq_mean[:, None]) / (freq_std[:, None] + 1e-8)
|
||||
|
||||
|
||||
# accepts shape (513, T)!!!!
|
||||
def denorm(spectrogram):
|
||||
return spectrogram * (freq_std[:, None] + 1e-8) + freq_mean[:, None]
|
||||
|
||||
|
||||
def enhance_audio(model, audio, sr):
|
||||
# Convert to log-mag
|
||||
X = audio_to_logmag(x) # (513, T)
|
||||
X = audio_to_logmag(audio) # (513, T)
|
||||
stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
X_tensor = torch.tensor(X, dtype=torch.float32).unsqueeze(0).to(device) # (T, 513)
|
||||
@@ -189,16 +332,26 @@ def enhance_audio(model, lossy_path, output_path):
|
||||
Y_pred = model(X_tensor).cpu().numpy() # (1, T, 513)
|
||||
Y_pred = Y_pred.squeeze(0) # back to (T, 513)
|
||||
|
||||
Y_pred = denorm(Y_pred)
|
||||
|
||||
# Clip to valid range (log1p output ≥ 0)
|
||||
Y_pred = np.maximum(Y_pred, 0)
|
||||
|
||||
# Invert log
|
||||
mag_pred = np.expm1(Y_pred) # inverse of log1p
|
||||
phase_lossy = np.angle(stft)
|
||||
|
||||
# Reconstruct with Griffin-Lim
|
||||
y_reconstructed = librosa.griffinlim(
|
||||
mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
|
||||
)
|
||||
# Combine: enhanced mag + original phase
|
||||
stft_enhanced = mag_pred * np.exp(1j * phase_lossy)
|
||||
y_reconstructed = librosa.istft(stft_enhanced, n_fft=N_FFT, hop_length=HOP)
|
||||
return y_reconstructed
|
||||
|
||||
|
||||
def enhance_mono(model, lossy_path, output_path):
|
||||
# Load
|
||||
x, sr = librosa.load(lossy_path, sr=SR)
|
||||
|
||||
y_reconstructed = enhance_audio(model, x, sr)
|
||||
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
BIN
checkpoint70.pth
Normal file
BIN
checkpoint70.pth
Normal file
Binary file not shown.
Binary file not shown.
0
examples/mirror_mirror/mirror_mirror.mp3
Normal file → Executable file
0
examples/mirror_mirror/mirror_mirror.mp3
Normal file → Executable file
BIN
examples/mirror_mirror/mirror_mirror_STFT.mp3
Normal file
BIN
examples/mirror_mirror/mirror_mirror_STFT.mp3
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
89
make_stats.py
Normal file
89
make_stats.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import os
|
||||
import sys
|
||||
import random
|
||||
import warnings
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from caveman_wavedataset import WaveformDataset
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
import torch.optim.lr_scheduler as lr_scheduler
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
CLEAN_DATA_DIR = "./fma_small"
|
||||
LOSSY_DATA_DIR = "./fma_small_compressed_64/"
|
||||
SR = 44100 # sample rate
|
||||
DURATION = 2.0 # seconds per clip
|
||||
N_MELS = None # we'll use full STFT for now
|
||||
HOP = 512
|
||||
N_FFT = 1024
|
||||
BATCH_SIZE = 16
|
||||
EPOCHS = 100
|
||||
PREFETCH = 4
|
||||
|
||||
# Data
|
||||
dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, sr=SR)
|
||||
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=BATCH_SIZE,
|
||||
prefetch_factor=PREFETCH,
|
||||
shuffle=False,
|
||||
pin_memory=True,
|
||||
num_workers=16,
|
||||
)
|
||||
|
||||
stats = torch.load("freq_stats.pth")
|
||||
freq_mean = stats["mean"].numpy() # (513,)
|
||||
freq_std = stats["std"].numpy() # (513,)
|
||||
|
||||
freq_mean_torch = stats["mean"] # (513,)
|
||||
freq_std_torch = stats["std"] # (513,)
|
||||
|
||||
dataset.mean = freq_mean
|
||||
dataset.std = freq_std
|
||||
|
||||
|
||||
def compute_per_freq_stats_vectorized(dataloader, freq_bins=513, device="cuda"):
|
||||
mean = torch.zeros(freq_bins, device=device)
|
||||
M2 = torch.zeros(freq_bins, device=device)
|
||||
total_frames = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for _, lossy in tqdm(dataloader, desc="Stats"):
|
||||
x = lossy.to(device) # (B, 1, F, T)
|
||||
B, _, F, T = x.shape
|
||||
x_flat = x.squeeze(1).permute(1, 0, 2).reshape(F, -1) # (F, N)
|
||||
N = x_flat.shape[1]
|
||||
|
||||
if N == 0:
|
||||
continue
|
||||
|
||||
# Current mean (broadcast)
|
||||
mean_old = mean.clone()
|
||||
|
||||
# Update mean: mean = mean + (sum(x) - N*mean) / (total + N)
|
||||
mean = mean_old + (x_flat.sum(dim=1) - N * mean_old) / (total_frames + N)
|
||||
|
||||
# Update M2: M2 += (x - mean_old) * (x - new_mean)
|
||||
delta_old = x_flat - mean_old.unsqueeze(1) # (F, N)
|
||||
delta_new = x_flat - mean.unsqueeze(1) # (F, N)
|
||||
M2 += (delta_old * delta_new).sum(dim=1) # (F,)
|
||||
|
||||
total_frames += N
|
||||
|
||||
std = (
|
||||
torch.sqrt(M2 / (total_frames - 1))
|
||||
if total_frames > 1
|
||||
else torch.ones_like(mean)
|
||||
)
|
||||
return mean.cpu(), std.cpu()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# freq_mean, freq_std = compute_per_freq_stats_vectorized(loader, device="cuda")
|
||||
# torch.save({"mean": freq_mean, "std": freq_std}, "freq_stats.pth")
|
||||
print(compute_per_freq_stats_vectorized(loader))
|
||||
11
misc.py
Normal file
11
misc.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import numpy as np
|
||||
import librosa
|
||||
from settings import N_FFT, HOP
|
||||
|
||||
|
||||
def audio_to_logmag(audio):
|
||||
# STFT
|
||||
stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
|
||||
mag = np.abs(stft)
|
||||
logmag = np.log1p(mag) # log(1 + x) for stability
|
||||
return logmag # shape: (freq_bins, time_frames) = (513, T)
|
||||
31
model.py
Normal file
31
model.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import torch.nn as nn
|
||||
|
||||
INPUT_KERNEL = 15
|
||||
SIZE = 32
|
||||
|
||||
|
||||
class CavemanEnhancer(nn.Module):
|
||||
def __init__(self, freq_bins=513):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=1,
|
||||
out_channels=SIZE,
|
||||
kernel_size=INPUT_KERNEL,
|
||||
padding=INPUT_KERNEL // 2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(SIZE, SIZE, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(SIZE, 1, kernel_size=3, padding=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# x: (batch, 1, freq_bins, time_frames) -- single-channel log-magnitude spectrograms
|
||||
return self.net(x)
|
||||
22
pyproject.toml
Normal file
22
pyproject.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[project]
|
||||
name = "ml-project"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.14"
|
||||
dependencies = [
|
||||
"accelerate>=1.12.0",
|
||||
"datasets>=4.4.2",
|
||||
"diffusers>=0.36.0",
|
||||
"ftfy>=6.3.1",
|
||||
"ipython>=9.9.0",
|
||||
"ipywidgets>=8.1.8",
|
||||
"jupyterlab>=4.5.1",
|
||||
"librosa>=0.11.0",
|
||||
"matplotlib>=3.10.8",
|
||||
"scikit-learn>=1.8.0",
|
||||
"torchaudio>=2.9.1",
|
||||
"torchcodec>=0.9.1",
|
||||
"torchvision>=0.24.1",
|
||||
"transformers>=4.57.3",
|
||||
]
|
||||
3
settings.py
Normal file
3
settings.py
Normal file
@@ -0,0 +1,3 @@
|
||||
SR = 44100 # sample rate
|
||||
HOP = 512
|
||||
N_FFT = 1024
|
||||
Reference in New Issue
Block a user