Upload files to "/"
cavemanml.py (new file, 210 lines added)
@@ -0,0 +1,210 @@
import os
import random
import sys
import warnings

import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from tqdm import tqdm

from caveman_wavedataset import WaveformDataset

warnings.filterwarnings("ignore")

CLEAN_DATA_DIR = "./fma_small"
LOSSY_DATA_DIR = "./fma_small_compressed_64/"
SR = 44100       # sample rate
DURATION = 2.0   # seconds per clip
N_MELS = None    # full STFT for now, no mel projection
HOP = 512        # STFT hop length
N_FFT = 1024     # STFT window size -> N_FFT // 2 + 1 = 513 frequency bins


def audio_to_logmag(audio):
    # STFT magnitude, log-compressed
    stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
    mag = np.abs(stft)
    logmag = np.log1p(mag)  # log(1 + x) for stability
    return logmag  # shape: (freq_bins, time_frames) = (513, T)
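
# Shape sketch (illustrative only, not used by the pipeline): for a 2 s clip at SR,
#   x = np.zeros(int(SR * DURATION))   # (88200,) samples
#   audio_to_logmag(x).shape           # (N_FFT // 2 + 1, 1 + 88200 // HOP) = (513, 173)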


class CavemanEnhancer(nn.Module):
    def __init__(self, freq_bins=513):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=3, padding=1),
        )

    def forward(self, x):
        # x: (batch, 1, freq_bins, time_frames); padding keeps the output the same shape
        return self.net(x)
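
# Quick shape check (a sketch, assuming the intended 4-D layout; not part of training):
#   model = CavemanEnhancer()
#   model(torch.zeros(2, 1, 513, 173)).shape   # -> torch.Size([2, 1, 513, 173])
# Recent PyTorch versions also accept an unbatched (1, 513, T) tensor, treating it as
# (channels, height, width); enhance_audio() below relies on that.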


BATCH_SIZE = 4
EPOCHS = 100


def main():
    # Device selection
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(device)
    ans = input("Do you want to use this device? (y/n) ")
    if ans != "y":
        sys.exit(1)

    # Model
    model = CavemanEnhancer().to(device)

    # If a checkpoint path is given, load it and enhance the sample clip
    if len(sys.argv) > 1:
        model.load_state_dict(
            torch.load(sys.argv[1], weights_only=False)["model_state_dict"]
        )
        model.eval()
        enhance_audio(
            model,
            "./mirror_mirror_compressed_64.mp3",
            "./mirror_mirror_decompressed_64_mse.wav",
        )

    # Baseline: round-trip the compressed clip through the STFT and Griffin-Lim without
    # the model. The unconditional return at the end of this block means the training
    # code below is skipped in this version of the script.

    # Load
    x, sr = librosa.load("./mirror_mirror_compressed_64.mp3", sr=SR)

    # Convert to log-magnitude
    X = audio_to_logmag(x)  # (513, T)

    # Clip to valid range (log1p output ≥ 0)
    Y_pred = np.maximum(X, 0)

    # Invert the log
    mag_pred = np.expm1(Y_pred)  # inverse of log1p

    # Reconstruct the waveform with Griffin-Lim (iterative phase estimation)
    y_reconstructed = librosa.griffinlim(
        mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
    )

    # Save
    sf.write("./mirror_mirror_compressed_64_STFT.mp3", y_reconstructed, sr)
    return

    # Data
    dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, sr=SR)

    # Deterministic 90/10 split: the last 10% of indices are held out for validation
    n_val = int(0.1 * len(dataset))
    n_train = len(dataset) - n_val
    train_indices = list(range(n_train))
    val_indices = list(range(n_train, len(dataset)))
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        num_workers=16,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=4,
        shuffle=False,
        num_workers=10,
        pin_memory=True,
    )
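
    # NOTE: caveman_wavedataset is not included in this commit. The loop below assumes
    # WaveformDataset yields (lossy, clean) pairs as float tensors already shaped for
    # the Conv2d enhancer, i.e. (1, freq_bins, time_frames) log-magnitude spectrograms.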

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scaler = GradScaler()  # loss scaling for the mixed-precision (autocast) training below
    scheduler = lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.1,
        patience=5,
        cooldown=3,
        threshold=1e-3,
    )
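
    # With these settings the learning rate is multiplied by 0.1 once the validation
    # loss has gone 5 epochs (patience) without improving beyond the 1e-3 relative
    # threshold, and is then left alone for a 3-epoch cooldown.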

    criterion = nn.L1Loss()

    # Train
    for epoch in range(EPOCHS):
        model.train()
        for lossy, clean in tqdm(train_loader, desc="Training"):
            lossy, clean = lossy.to(device), clean.to(device)

            # Mixed-precision forward pass
            with autocast():
                enhanced = model(lossy)
                loss = criterion(enhanced, clean)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        # Validate
        model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for lossy, clean in tqdm(val_loader, desc="Validating"):
                lossy, clean = lossy.to(device), clean.to(device)
                output = model(lossy)
                loss_ = criterion(output, clean)
                total_loss += loss_.item()
        val_loss = total_loss / len(val_loader)

        scheduler.step(val_loss)  # update learning rate based on validation loss

        if (epoch + 1) % 10 == 0:
            lr = optimizer.param_groups[0]["lr"]
            print(f"LR: {lr:.6f}")

        print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}, Val: {val_loss:.4f}")

        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
            },
            f"checkpoint{epoch}.pth",
        )
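
    # The checkpoint also stores the optimizer state, but the resume path at the top of
    # main() only restores the model weights; restoring optimizer_state_dict as well
    # would be needed to continue training exactly where it left off.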


def enhance_audio(model, lossy_path, output_path):
    # Load
    x, sr = librosa.load(lossy_path, sr=SR)

    # Convert to log-magnitude
    X = audio_to_logmag(x)  # (513, T)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    X_tensor = torch.tensor(X, dtype=torch.float32).unsqueeze(0).to(device)  # (1, 513, T)
    with torch.no_grad():
        Y_pred = model(X_tensor).cpu().numpy()  # (1, 513, T)
    Y_pred = Y_pred.squeeze(0)  # back to (513, T)

    # Clip to valid range (log1p output ≥ 0)
    Y_pred = np.maximum(Y_pred, 0)

    # Invert the log
    mag_pred = np.expm1(Y_pred)  # inverse of log1p

    # Reconstruct the waveform with Griffin-Lim (iterative phase estimation from magnitude)
    y_reconstructed = librosa.griffinlim(
        mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT
    )

    # Save
    sf.write(output_path, y_reconstructed, sr)


if __name__ == "__main__":
    main()
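
# Usage (inferred from the argument handling in main()):
#   python cavemanml.py                   -> run the no-model Griffin-Lim baseline on the sample clip
#   python cavemanml.py <checkpoint.pth>  -> also load the checkpoint and enhance the sample clip with the model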