Compare commits

...

4 Commits

Author SHA1 Message Date
7a04a3ac67 updated README 2026-01-10 20:49:59 +01:00
d6df888387 removed old checkpoint 2026-01-10 20:43:56 +01:00
fb97f6c3a8 added model and new examples; added pyproject toml 2026-01-10 20:40:46 +01:00
8379ac8e12 added new fixed code 2026-01-10 20:35:21 +01:00
15 changed files with 3201 additions and 111 deletions

View File

@@ -1,3 +1,31 @@
# caveman
Audio reconstruction model, that "uncompresses" lossy mp3s.
Audio reconstruction model, that "uncompresses" lossy mp3s.
## Installing dependencies
There is the pyproject file, so you can use that.
Or, if you prefer to keep your sanity, use uv:
```
uv sync
```
## Training
Adjust the locations of your data directories via the variables in the `cavemanml.py` file. Any compression type is accepted, as long as your audio is sampled at 44.1 kHz (44100 Hz).
```
uv run cavemanml.py
```
Obviously the original dataset (FMA) is not provided with this repo.
## Inference
For now, the only inference this code can do is run the bundled example. You can adjust the input and output files in the code, though. In theory, the model should handle any input. To use the provided checkpoint:
```
uv run cavemanml.py ./checkpoint70.pth
```

View File

@@ -4,29 +4,21 @@ import librosa
from torch.utils.data import Dataset
import numpy as np
import random
HOP = 512
N_FFT = 1024
DURATION = 2.0
SR = 44100
def audio_to_logmag(audio):
# STFT
stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
mag = np.abs(stft)
logmag = np.log1p(mag) # log(1 + x) for stability
return logmag # shape: (1, freq_bins, time_frames) = (1, 513, T)
from settings import SR, N_FFT
from misc import audio_to_logmag
class WaveformDataset(Dataset):
def __init__(self, lossy_dir, clean_dir, sr=SR, segment_sec=4):
self.cache = dict()
mean = np.zeros([N_FFT // 2 + 1])
std = np.ones([N_FFT // 2 + 1])
# Duration is a very important parameter; read cavemanml.py to see how and why to adjust it!
# For the purposes of this file, it is the length (in seconds) of the audio clip selected from each dataset item.
def __init__(self, lossy_dir, clean_dir, segment_duration, sr=SR):
self.segment_duration = segment_duration
self.sr = sr
self.lossy_dir = lossy_dir
self.clean_dir = clean_dir
self.segment_len = int(segment_sec * sr)
self.lossy_files = sorted(os.listdir(lossy_dir))
self.clean_files = sorted(os.listdir(clean_dir))
self.file_pairs = [
@@ -51,9 +43,9 @@ class WaveformDataset(Dataset):
min_len = min(len(lossy), len(clean))
lossy, clean = lossy[:min_len], clean[:min_len]
# Random 2-second clip
# Random clip
clip_len = int(DURATION * SR)
clip_len = int(self.segment_duration * SR)
if min_len < clip_len:
# pad if too short
lossy = np.pad(lossy, (0, clip_len - min_len))
@@ -61,14 +53,21 @@ class WaveformDataset(Dataset):
start = 0
else:
start = random.randint(0, min_len - clip_len)
# start = 0
lossy = lossy[start : start + clip_len]
clean = clean[start : start + clip_len]
logmag_x = audio_to_logmag(lossy)
logmag_y = audio_to_logmag(clean)
logmag_x_norm = (logmag_x - self.mean[:, None]) / (self.std[:, None] + 1e-8)
logmag_y_norm = (logmag_y - self.mean[:, None]) / (self.std[:, None] + 1e-8)
ans = (
torch.from_numpy(audio_to_logmag(lossy)).unsqueeze(0),
torch.from_numpy(audio_to_logmag(clean)).unsqueeze(0),
torch.from_numpy(logmag_x_norm).float().unsqueeze(0),
torch.from_numpy(logmag_y_norm).float().unsqueeze(0),
)
self.cache[idx] = ans
# self.cache[idx] = ans
return ans

View File

@@ -1,53 +1,120 @@
import os
import sys
import random
import warnings
import librosa
import numpy as np
import torch.nn as nn
import torch
import soundfile as sf
from caveman_wavedataset import WaveformDataset
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from torch.cuda.amp import autocast
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm
from misc import audio_to_logmag
from settings import N_FFT, HOP, SR
from model import CavemanEnhancer
warnings.filterwarnings("ignore")
CLEAN_DATA_DIR = "./fma_small"
LOSSY_DATA_DIR = "./fma_small_compressed_64/"
SR = 44100 # sample rate
DURATION = 2.0 # seconds per clip
N_MELS = None # we'll use full STFT for now
HOP = 512
N_FFT = 1024
def audio_to_logmag(audio):
# STFT
stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
mag = np.abs(stft)
logmag = np.log1p(mag) # log(1 + x) for stability
return logmag # shape: (freq_bins, time_frames) = (513, T)
class CavemanEnhancer(nn.Module):
def __init__(self, freq_bins=513):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2),
nn.ReLU(),
nn.Conv2d(32, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(32, 1, kernel_size=3, padding=1),
)
def forward(self, x):
# x: (batch, freq_bins)
return self.net(x)
# Duration is the duration of each example to be selected from the dataset.
# Since we are using the FMA dataset, its max value is 30s.
# From the standpoint of the model and training,
# it should make absolutely no difference to quality,
# only to the speed of the training process. If duration is larger,
# the model is going to be trained on more data per example.
# If smaller, less data per example.
#
# YOU ONLY NEED TO CHANGE THIS IF YOU ARE CPU-BOUND WHEN
# LOADING THE DATA DURING TRAINING. INCREASE TO PLACE MORE LOAD
# ON THE GPU, REDUCE TO PUT MORE LOAD ON THE CPU. DO NOT ADJUST
# THE BATCH SIZE, IT WILL MAKE NO DIFFERENCE, SINCE WE ARE ALWAYS
# FORCED TO LOAD THE ENTIRE EXAMPLE FROM DISK EVERY SINGLE TIME.
DURATION = 2
BATCH_SIZE = 4
# 100 is a bit ridiculous, but you are free to Ctrl-C at any time, since a checkpoint is saved after every epoch.
EPOCHS = 100
PREFETCH = 4
# stats = torch.load("freq_stats.pth")
# freq_mean = stats["mean"].numpy() # (513,)
# freq_std = stats["std"].numpy() # (513,)
freq_mean = np.zeros([N_FFT // 2 + 1])
# (513,)
freq_std = np.ones([N_FFT // 2 + 1]) # (513,)
# freq_mean_torch = stats["mean"] # (513,)
# freq_std_torch = stats["std"] # (513,)
freq_mean_torch = torch.from_numpy(freq_mean)
freq_std_torch = torch.from_numpy(freq_std)
def run_example(model_filename, device):
    """Run the bundled "mirror_mirror" example end to end.

    Steps:
      1. Load the checkpoint at ``model_filename`` into a ``CavemanEnhancer``
         and enhance the compressed example via ``enhance_mono``.
      2. As a baseline, push the uncompressed example through the
         log-magnitude -> normalize -> denormalize -> inverse-STFT round trip
         *without* the model, print the MSE against the input waveform and
         write the result to disk.
      3. Show the spectrograms of the three resulting files side by side.

    Args:
        model_filename: Path to a checkpoint containing a
            ``"model_state_dict"`` entry.
        device: Device the model is moved to; also used as ``map_location``
            so a checkpoint saved on a GPU can be loaded on a CPU-only host.
    """
    model = CavemanEnhancer().to(device)
    # NOTE(security): weights_only=False unpickles arbitrary objects. Only
    # load checkpoints you trust.
    checkpoint = torch.load(model_filename, weights_only=False, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    enhance_mono(
        model,
        "./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
        "./examples/mirror_mirror/mirror_mirror_decompressed_64.wav",
    )

    # Load the reference recording for the model-free round-trip baseline.
    x, sr = librosa.load(
        "./examples/mirror_mirror/mirror_mirror.mp3",
        sr=SR,
    )

    # Convert to log-mag
    X = audio_to_logmag(x)  # (513, T)

    # Round-trip through the normalization used for training.
    Y_pred = normalize(X)
    Y_pred = denorm(Y_pred)

    # Clip to valid range (log1p output >= 0). This clips the round-tripped
    # spectrogram; the previous code clipped X, silently discarding the
    # normalize/denorm result.
    Y_pred = np.maximum(Y_pred, 0)

    stft = librosa.stft(x, n_fft=N_FFT, hop_length=HOP)

    # Invert log
    mag_pred = np.expm1(Y_pred)  # inverse of log1p
    phase = np.angle(stft)

    # Combine the magnitude with the original phase (Griffin-Lim is no
    # longer used for the reconstruction).
    stft_enhanced = mag_pred * np.exp(1j * phase)
    y_reconstructed = librosa.istft(stft_enhanced, n_fft=N_FFT, hop_length=HOP)

    # The inverse STFT may differ in length from the input; compare the overlap.
    n_samples = min(x.shape[0], y_reconstructed.shape[0])
    loss = nn.MSELoss()(
        torch.from_numpy(x[:n_samples]),
        torch.from_numpy(y_reconstructed[:n_samples]),
    )
    print(f"Loss from reconstruction: {loss}")

    # Save
    sf.write(
        "./examples/mirror_mirror/mirror_mirror_STFT.mp3",
        y_reconstructed,
        sr,
    )

    show_spectrogram(
        "./examples/mirror_mirror/mirror_mirror_STFT.mp3",
        "./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
        "./examples/mirror_mirror/mirror_mirror_decompressed_64.wav",
    )
    return
def main():
@@ -58,94 +125,96 @@ def main():
if ans != "y":
exit(1)
model = CavemanEnhancer().to(device)
if len(sys.argv) > 1:
model.load_state_dict(
torch.load(sys.argv[1], weights_only=False)["model_state_dict"]
)
model.eval()
enhance_audio(
model,
"./examples/mirror_mirror/mirror_mirror_compressed_64.mp3",
"./examples/mirror_mirror/mirror_mirror_decompressed_64_mse.wav",
)
# Load
x, sr = librosa.load("./examples/mirror_mirror/mirror_mirror_compressed_64.mp3", sr=SR)
# Convert to log-mag
X = audio_to_logmag(x) # (513, T)
# Clip to valid range (log1p output ≥ 0)
Y_pred = np.maximum(X, 0)
# Invert log
mag_pred = np.expm1(Y_pred) # inverse of log1p
# Reconstruct with Griffin-Lim
y_reconstructed = librosa.griffinlim(
mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
)
import soundfile as sf
# Save
sf.write("./examples/mirror_mirror/mirror_mirror_compressed_64_STFT.mp3", y_reconstructed, sr)
return
run_example(sys.argv[1], device)
exit(0)
# Data
dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, sr=SR)
dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, DURATION, sr=SR)
dataset.mean = freq_mean
dataset.std = freq_std
# separate the train and val data
n_val = int(0.1 * len(dataset))
n_train = len(dataset) - n_val
train_indices = list(range(n_train))
val_indices = list(range(n_train, len(dataset)))
train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)
train_loader = DataLoader(
train_dataset,
batch_size=BATCH_SIZE,
prefetch_factor=PREFETCH,
shuffle=True,
pin_memory=True,
num_workers=16,
)
val_loader = DataLoader(
val_dataset,
batch_size=4,
batch_size=BATCH_SIZE,
prefetch_factor=PREFETCH,
shuffle=False,
num_workers=10,
num_workers=16,
pin_memory=True,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# model
model = CavemanEnhancer().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# I am actually not sure it really improves anything, but there is little reason not to keep this, I guess.
scheduler = lr_scheduler.ReduceLROnPlateau(
optimizer,
mode="min",
factor=0.1,
patience=5,
cooldown=3,
threshold=1e-3,
patience=3,
# cooldown=3,
# threshold=1e-3,
)
from tqdm import tqdm
# Weight: emphasize high frequencies
weight = torch.linspace(1.0, 8.0, 513).to(device) # low=1x, high=8x
weight = torch.exp(weight)
weight = weight.view(1, 513, 1)
criterion = nn.L1Loss()
def weighted_l1_loss(pred, target):
return torch.mean(weight * torch.abs(pred - target))
# criterion = nn.L1Loss()
criterion = weighted_l1_loss
# criterion = nn.MSELoss()
# baseline (doing nothing)
# if True:
# loss = 0.0
# for lossy, clean in tqdm(train_loader, desc="Baseline"):
# loss += criterion(clean, lossy)
# loss /= len(train_loader)
# print(f"baseling loss: {loss:.4f}")
# Train
for epoch in range(EPOCHS):
model.train()
train_loss = 0.0
for lossy, clean in tqdm(train_loader, desc="Training"):
lossy, clean = lossy.to(device), clean.to(device)
with autocast():
enhanced = model(lossy)
loss = criterion(clean, enhanced)
train_loss += loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss /= len(train_loader)
# Validate (per epoch)
model.eval()
total_loss = 0.0
val_loss = 0
@@ -154,17 +223,17 @@ def main():
lossy, clean = lossy.to(device), clean.to(device)
output = model(lossy)
loss_ = criterion(output, clean)
total_loss += loss_.item()
val_loss = total_loss / len(train_loader)
scheduler.step(val_loss) # Update learning rate based on validation loss
if (epoch + 1) % 10 == 0:
lr = optimizer.param_groups[0]["lr"]
print(f"LR: {lr:.6f}")
print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}, Val: {val_loss:.4f}")
lr = optimizer.param_groups[0]["lr"]
print(f"LR: {lr:.6f}")
print(f"Epoch {epoch + 1}, Loss: {train_loss:.4f}, Val: {val_loss:.4f}")
# Yes, we are saving checkpoints for every epoch. The model is small, and disk space cheap.
torch.save(
{
"epoch": epoch,
@@ -176,12 +245,86 @@ def main():
)
def enhance_audio(model, lossy_path, output_path):
# Load
x, sr = librosa.load(lossy_path, sr=SR)
def file_to_logmag(path):
    """Load ``path`` as mono audio at ``SR`` and return its log-magnitude spectrogram.

    The shape of the loaded waveform is printed. The spectrogram comes from
    ``audio_to_logmag`` with any singleton dimensions squeezed away.
    """
    samples, _ = librosa.load(path, sr=SR, mono=True)
    print(samples.shape)
    logmag = audio_to_logmag(samples)
    return np.squeeze(logmag)
def show_spectrogram(path1, path2, path3):
    """Plot the normalized log-magnitude spectrograms of three audio files side by side.

    Each file is loaded with ``file_to_logmag`` and passed through
    ``normalize``. The three images are then drawn in one row of a single
    matplotlib figure, titled "spectrogram 1" to "spectrogram 3", and the
    figure is shown.
    """
    # Load all three files first, then normalize.
    loaded = [file_to_logmag(path) for path in (path1, path2, path3)]
    images = [normalize(spec) for spec in loaded]

    from matplotlib import pyplot as plt

    # One row with three panels.
    _fig, axes = plt.subplots(1, 3, figsize=(10, 5))

    for number, (axis, image) in enumerate(zip(axes, images), start=1):
        axis.imshow(image, aspect="auto", cmap="gray")
        axis.set_title(f"spectrogram {number}")
        axis.axis("off")  # Hide axes

    plt.tight_layout()
    plt.show()
def enhance_stereo(model, lossy_path, output_path):
    """Enhance a stereo file channel by channel and write the result.

    The file is loaded at ``SR`` with its channels preserved. The first two
    channels are each passed through ``enhance_audio``, stacked and written
    to ``output_path``.

    Raises:
        ValueError: If the loaded audio is one-dimensional (mono).
    """
    # mono=False keeps the channel axis.
    channels, sr = librosa.load(lossy_path, sr=SR, mono=False)

    if channels.ndim == 1:
        raise ValueError("Input is mono! Expected stereo.")

    left = channels[0]
    right = channels[1]

    enhanced_left = enhance_audio(model, left, sr)
    enhanced_right = enhance_audio(model, right, sr)
    stacked = np.vstack((enhanced_left, enhanced_right))

    import soundfile as sf

    # Transposed so that samples run along the first axis when writing.
    sf.write(output_path, stacked.T, sr)
# accepts shape (513, T)!!!!
def denorm_torch(spectrogram):
    """Undo the per-frequency normalization on a torch tensor.

    Computes ``spectrogram * (std + 1e-8) + mean``, with the per-frequency
    statistics broadcast as columns.

    The module-level statistics tensors are created on the CPU, so they are
    moved to the input's device first. Without that, a CUDA input fails with
    a device-mismatch error. For CPU inputs the behavior is unchanged.
    """
    std = freq_std_torch.to(spectrogram.device)
    mean = freq_mean_torch.to(spectrogram.device)
    return spectrogram * (std[:, None] + 1e-8) + mean[:, None]
# accepts shape (513, T)!!!!
def normalize(spectrogram):
    """Standardize a spectrogram per frequency bin using ``freq_mean`` / ``freq_std``.

    The statistics are broadcast as columns. A small epsilon is added to the
    standard deviation to avoid division by zero.
    """
    mean_column = freq_mean[:, None]
    std_column = freq_std[:, None] + 1e-8
    centered = spectrogram - mean_column
    return centered / std_column
# accepts shape (513, T)!!!!
def denorm(spectrogram):
    """Invert ``normalize``: rescale by ``freq_std`` (plus epsilon) and add ``freq_mean``."""
    std_column = freq_std[:, None] + 1e-8
    mean_column = freq_mean[:, None]
    rescaled = spectrogram * std_column
    return rescaled + mean_column
def enhance_audio(model, audio, sr):
# Convert to log-mag
X = audio_to_logmag(x) # (513, T)
X = audio_to_logmag(audio) # (513, T)
stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
device = "cuda" if torch.cuda.is_available() else "cpu"
X_tensor = torch.tensor(X, dtype=torch.float32).unsqueeze(0).to(device) # (T, 513)
@@ -189,16 +332,26 @@ def enhance_audio(model, lossy_path, output_path):
Y_pred = model(X_tensor).cpu().numpy() # (1, T, 513)
Y_pred = Y_pred.squeeze(0) # back to (T, 513)
Y_pred = denorm(Y_pred)
# Clip to valid range (log1p output ≥ 0)
Y_pred = np.maximum(Y_pred, 0)
# Invert log
mag_pred = np.expm1(Y_pred) # inverse of log1p
phase_lossy = np.angle(stft)
# Reconstruct with Griffin-Lim
y_reconstructed = librosa.griffinlim(
mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
)
# Combine: enhanced mag + original phase
stft_enhanced = mag_pred * np.exp(1j * phase_lossy)
y_reconstructed = librosa.istft(stft_enhanced, n_fft=N_FFT, hop_length=HOP)
return y_reconstructed
def enhance_mono(model, lossy_path, output_path):
# Load
x, sr = librosa.load(lossy_path, sr=SR)
y_reconstructed = enhance_audio(model, x, sr)
import soundfile as sf

BIN
checkpoint70.pth Normal file

Binary file not shown.

Binary file not shown.

0
examples/mirror_mirror/mirror_mirror.mp3 Normal file → Executable file
View File

Binary file not shown.

89
make_stats.py Normal file
View File

@@ -0,0 +1,89 @@
import os
import sys
import random
import warnings
import librosa
import numpy as np
import torch.nn as nn
import torch
from caveman_wavedataset import WaveformDataset
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm
CLEAN_DATA_DIR = "./fma_small"
LOSSY_DATA_DIR = "./fma_small_compressed_64/"
SR = 44100  # sample rate
DURATION = 2.0  # seconds per clip
N_MELS = None  # we'll use full STFT for now
HOP = 512
N_FFT = 1024
BATCH_SIZE = 16
EPOCHS = 100
PREFETCH = 4

# Data
# WaveformDataset requires the clip length (segment_duration) as its third
# positional argument. Omitting it raises TypeError.
dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, DURATION, sr=SR)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    prefetch_factor=PREFETCH,
    shuffle=False,
    pin_memory=True,
    num_workers=16,
)

# Previously saved per-frequency statistics, in numpy and torch form.
stats = torch.load("freq_stats.pth")
freq_mean = stats["mean"].numpy()  # (513,)
freq_std = stats["std"].numpy()  # (513,)
freq_mean_torch = stats["mean"]  # (513,)
freq_std_torch = stats["std"]  # (513,)

# The dataset normalizes its spectrograms with these values, so statistics
# computed from `loader` afterwards describe the already-normalized data.
dataset.mean = freq_mean
dataset.std = freq_std
def compute_per_freq_stats_vectorized(dataloader, freq_bins=513, device="cuda"):
    """Compute the per-frequency-bin mean and standard deviation over a dataloader.

    Statistics are accumulated batch by batch with a running mean and a
    running sum of squared deviations, so the whole dataset never has to be
    held in memory.

    Args:
        dataloader: Iterable yielding pairs of tensors. Only the second
            element of each pair is used, and it is unpacked as 4-D
            ``(B, 1, F, T)``. NOTE(review): the dataset's item tuple appears
            to be ``(lossy, clean)``, which would make this the clean
            spectrogram. Confirm that is intended.
        freq_bins: Number of frequency bins (size of the accumulators).
        device: Device on which the accumulation runs.

    Returns:
        ``(mean, std)`` as CPU tensors of length ``freq_bins``. ``std`` uses
        the ``n - 1`` denominator and is all ones when fewer than two frames
        were seen.
    """
    running_mean = torch.zeros(freq_bins, device=device)
    running_m2 = torch.zeros(freq_bins, device=device)
    frames_seen = 0

    with torch.no_grad():
        for _, second in tqdm(dataloader, desc="Stats"):
            batch = second.to(device)  # (B, 1, F, T)
            _, _, n_freq, _ = batch.shape

            # Put frequency first and flatten batch and time together: (F, N)
            per_freq = batch.squeeze(1).permute(1, 0, 2).reshape(n_freq, -1)
            n_new = per_freq.shape[1]
            if n_new == 0:
                continue

            previous_mean = running_mean.clone()

            # mean = mean + (sum(x) - N*mean) / (total + N)
            batch_sum = per_freq.sum(dim=1)
            running_mean = previous_mean + (batch_sum - n_new * previous_mean) / (
                frames_seen + n_new
            )

            # M2 += (x - old_mean) * (x - new_mean), summed over the new frames
            dev_from_old = per_freq - previous_mean.unsqueeze(1)  # (F, N)
            dev_from_new = per_freq - running_mean.unsqueeze(1)  # (F, N)
            running_m2 += (dev_from_old * dev_from_new).sum(dim=1)  # (F,)

            frames_seen += n_new

    if frames_seen > 1:
        std = torch.sqrt(running_m2 / (frames_seen - 1))
    else:
        std = torch.ones_like(running_mean)
    return running_mean.cpu(), std.cpu()
if __name__ == "__main__":
    # To save freshly computed statistics instead of just printing them:
    # freq_mean, freq_std = compute_per_freq_stats_vectorized(loader, device="cuda")
    # torch.save({"mean": freq_mean, "std": freq_std}, "freq_stats.pth")
    mean_and_std = compute_per_freq_stats_vectorized(loader)
    print(mean_and_std)

11
misc.py Normal file
View File

@@ -0,0 +1,11 @@
import numpy as np
import librosa
from settings import N_FFT, HOP
def audio_to_logmag(audio):
    """Return the log-magnitude STFT of ``audio``.

    The STFT uses ``N_FFT`` and ``HOP`` from settings. The magnitude is
    compressed with ``log1p`` (log(1 + x)) for numerical stability.
    For a 1-D waveform the result has shape
    (freq_bins, time_frames) = (513, T).
    """
    spectrum = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
    magnitude = np.abs(spectrum)
    return np.log1p(magnitude)

31
model.py Normal file
View File

@@ -0,0 +1,31 @@
import torch.nn as nn
INPUT_KERNEL = 15
SIZE = 32


class CavemanEnhancer(nn.Module):
    """Small fully convolutional network mapping a spectrogram to a spectrogram.

    The network is a stack of 2-D convolutions with ReLU between them: one
    wide input convolution, ``num_hidden_layers`` 3x3 convolutions, and a
    final 3x3 convolution back to a single channel. Every convolution is
    padded by half its kernel size, so with odd kernel sizes the output has
    the same frequency/time extent as the input.

    With the default arguments the layer layout, and therefore the
    ``state_dict`` keys ``net.0`` to ``net.10``, is identical to the
    original hard-coded architecture. Existing checkpoints keep loading.

    Args:
        freq_bins: Unused. Kept for backward compatibility, since the network
            is fully convolutional and does not depend on the number of bins.
        hidden_channels: Number of feature channels in the hidden layers.
        input_kernel: Kernel size of the first convolution. It should be odd
            so that ``input_kernel // 2`` padding preserves the input size.
        num_hidden_layers: Number of 3x3 hidden convolutions between the
            input convolution and the output convolution.
    """

    def __init__(
        self,
        freq_bins=513,
        hidden_channels=SIZE,
        input_kernel=INPUT_KERNEL,
        num_hidden_layers=4,
    ):
        super().__init__()
        layers = [
            nn.Conv2d(
                in_channels=1,
                out_channels=hidden_channels,
                kernel_size=input_kernel,
                padding=input_kernel // 2,
            ),
            nn.ReLU(),
        ]
        for _ in range(num_hidden_layers):
            layers.append(
                nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1)
            )
            layers.append(nn.ReLU())
        layers.append(nn.Conv2d(hidden_channels, 1, kernel_size=3, padding=1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x`` of shape (batch, 1, freq_bins, time_frames)."""
        return self.net(x)

22
pyproject.toml Normal file
View File

@@ -0,0 +1,22 @@
[project]
name = "ml-project"
version = "0.1.0"
description = "Audio reconstruction model that restores detail lost to lossy MP3 compression"
readme = "README.md"
requires-python = ">=3.14"
dependencies = [
"accelerate>=1.12.0",
"datasets>=4.4.2",
"diffusers>=0.36.0",
"ftfy>=6.3.1",
"ipython>=9.9.0",
"ipywidgets>=8.1.8",
"jupyterlab>=4.5.1",
"librosa>=0.11.0",
"matplotlib>=3.10.8",
"scikit-learn>=1.8.0",
"torchaudio>=2.9.1",
"torchcodec>=0.9.1",
"torchvision>=0.24.1",
"transformers>=4.57.3",
]

3
settings.py Normal file
View File

@@ -0,0 +1,3 @@
# Shared audio/STFT settings imported by the dataset, training and helper modules.
SR = 44100 # sample rate in Hz
HOP = 512 # STFT hop length in samples (passed as hop_length to librosa.stft)
N_FFT = 1024 # FFT size (passed as n_fft); gives N_FFT // 2 + 1 = 513 frequency bins

2754
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff