diff --git a/caveman_wavedataset.py b/caveman_wavedataset.py
new file mode 100644
index 0000000..1518b84
--- /dev/null
+++ b/caveman_wavedataset.py
@@ -0,0 +1,80 @@
+import os
+import torch
+import librosa
+from torch.utils.data import Dataset
+import numpy as np
+import random
+
+
+HOP = 512
+N_FFT = 1024
+DURATION = 2.0
+SR = 44100
+
+
+def audio_to_logmag(audio):
+    # STFT
+    stft = librosa.stft(audio, n_fft=N_FFT, hop_length=HOP)
+    mag = np.abs(stft)
+    logmag = np.log1p(mag)  # log(1 + x) for stability
+    return logmag  # shape: (freq_bins, time_frames) = (513, T)
+
+
+class WaveformDataset(Dataset):
+    def __init__(self, lossy_dir, clean_dir, sr=SR, segment_sec=DURATION):
+        self.cache = dict()
+        self.sr = sr
+        self.lossy_dir = lossy_dir
+        self.clean_dir = clean_dir
+        self.segment_len = int(segment_sec * sr)
+        self.lossy_files = sorted(os.listdir(lossy_dir))
+        self.clean_files = sorted(os.listdir(clean_dir))
+        # Pair lossy/clean recordings that share a file name
+        self.file_pairs = [
+            (f, f) for f in self.lossy_files if f in set(self.clean_files)
+        ]
+
+    def __len__(self):
+        return len(self.file_pairs)
+
+    def __getitem__(self, idx):
+        if idx in self.cache:
+            return self.cache[idx]
+
+        # Index file_pairs, not the raw listings, so lossy and clean stay
+        # aligned even if one directory contains extra files.
+        lossy_name, clean_name = self.file_pairs[idx]
+        lossy_path = os.path.join(self.lossy_dir, lossy_name)
+        clean_path = os.path.join(self.clean_dir, clean_name)
+
+        # Load
+        lossy, _ = librosa.load(lossy_path, sr=self.sr, mono=True)
+        clean, _ = librosa.load(clean_path, sr=self.sr, mono=True)
+
+        # Match length
+        min_len = min(len(lossy), len(clean))
+        lossy, clean = lossy[:min_len], clean[:min_len]
+
+        # Random clip of segment_len samples (2 seconds by default)
+        clip_len = self.segment_len
+        if min_len < clip_len:
+            # pad if too short
+            lossy = np.pad(lossy, (0, clip_len - min_len))
+            clean = np.pad(clean, (0, clip_len - min_len))
+            start = 0
+        else:
+            start = random.randint(0, min_len - clip_len)
+        lossy = lossy[start : start + clip_len]
+        clean = clean[start : start + clip_len]
+
+        # unsqueeze adds the channel dim: (1, 513, T)
+        ans = (
+            torch.from_numpy(audio_to_logmag(lossy)).unsqueeze(0),
+            torch.from_numpy(audio_to_logmag(clean)).unsqueeze(0),
+        )
+
+        # Note: caching freezes the random crop after its first draw, and each
+        # DataLoader worker holds its own copy of the cache in memory.
+        self.cache[idx] = ans
+
+        return ans
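Note: with the defaults above (SR = 44100, HOP = 512, 2 s clips), each dataset item is a pair of (1, 513, 173) tensors: 513 = N_FFT / 2 + 1 frequency bins and 173 = 1 + 88200 // 512 frames. A quick shape sanity check (a sketch; assumes the fma_small directories used in cavemanml.py below exist):

    from caveman_wavedataset import WaveformDataset

    ds = WaveformDataset("./fma_small_compressed_64/", "./fma_small")
    lossy, clean = ds[0]
    print(lossy.shape, clean.shape)  # expected: torch.Size([1, 513, 173]) twice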
device?") + if ans != "y": + exit(1) + + model = CavemanEnhancer().to(device) + if len(sys.argv) > 1: + model.load_state_dict( + torch.load(sys.argv[1], weights_only=False)["model_state_dict"] + ) + model.eval() + enhance_audio( + model, + "./mirror_mirror_compressed_64.mp3", + "./mirror_mirror_decompressed_64_mse.wav", + ) + + # Load + x, sr = librosa.load("./mirror_mirror_compressed_64.mp3", sr=SR) + + # Convert to log-mag + X = audio_to_logmag(x) # (513, T) + + # Clip to valid range (log1p output ≥ 0) + Y_pred = np.maximum(X, 0) + + # Invert log + mag_pred = np.expm1(Y_pred) # inverse of log1p + + # Reconstruct with Griffin-Lim + y_reconstructed = librosa.griffinlim( + mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT + ) + + import soundfile as sf + + # Save + sf.write("./mirror_mirror_compressed_64_STFT.mp3", y_reconstructed, sr) + return + + # Data + dataset = WaveformDataset(LOSSY_DATA_DIR, CLEAN_DATA_DIR, sr=SR) + + n_val = int(0.1 * len(dataset)) + n_train = len(dataset) - n_val + train_indices = list(range(n_train)) + val_indices = list(range(n_train, len(dataset))) + train_dataset = torch.utils.data.Subset(dataset, train_indices) + val_dataset = torch.utils.data.Subset(dataset, val_indices) + + train_loader = DataLoader( + train_dataset, + batch_size=BATCH_SIZE, + shuffle=True, + pin_memory=True, + num_workers=16, + ) + val_loader = DataLoader( + val_dataset, + batch_size=4, + shuffle=False, + num_workers=10, + pin_memory=True, + ) + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + scheduler = lr_scheduler.ReduceLROnPlateau( + optimizer, + mode="min", + factor=0.1, + patience=5, + cooldown=3, + threshold=1e-3, + ) + + from tqdm import tqdm + + criterion = nn.L1Loss() + + # Train + for epoch in range(EPOCHS): + model.train() + for lossy, clean in tqdm(train_loader, desc="Training"): + lossy, clean = lossy.to(device), clean.to(device) + + with autocast(): + enhanced = model(lossy) + loss = criterion(clean, enhanced) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + model.eval() + total_loss = 0.0 + val_loss = 0 + with torch.no_grad(): + for lossy, clean in tqdm(val_loader, desc="Validating"): + lossy, clean = lossy.to(device), clean.to(device) + output = model(lossy) + loss_ = criterion(output, clean) + total_loss += loss_.item() + val_loss = total_loss / len(train_loader) + + scheduler.step(val_loss) # Update learning rate based on validation loss + + if (epoch + 1) % 10 == 0: + lr = optimizer.param_groups[0]["lr"] + print(f"LR: {lr:.6f}") + + print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}, Val: {val_loss:.4f}") + + torch.save( + { + "epoch": epoch, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "loss": loss, + }, + f"checkpoint{epoch}.pth", + ) + + +def enhance_audio(model, lossy_path, output_path): + # Load + x, sr = librosa.load(lossy_path, sr=SR) + + # Convert to log-mag + X = audio_to_logmag(x) # (513, T) + + device = "cuda" if torch.cuda.is_available() else "cpu" + X_tensor = torch.tensor(X, dtype=torch.float32).unsqueeze(0).to(device) # (T, 513) + with torch.no_grad(): + Y_pred = model(X_tensor).cpu().numpy() # (1, T, 513) + Y_pred = Y_pred.squeeze(0) # back to (T, 513) + + # Clip to valid range (log1p output ≥ 0) + Y_pred = np.maximum(Y_pred, 0) + + # Invert log + mag_pred = np.expm1(Y_pred) # inverse of log1p + + # Reconstruct with Griffin-Lim + y_reconstructed = librosa.griffinlim( + mag_pred, n_iter=30, hop_length=HOP, win_length=N_FFT, n_fft=N_FFT + ) + + import soundfile 
diff --git a/checkpoint_64k.pth b/checkpoint_64k.pth
new file mode 100644
index 0000000..f84deaf
Binary files /dev/null and b/checkpoint_64k.pth differ
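Note: checkpoint_64k.pth is assumed to follow the dict layout written by the training loop above, so it can be passed straight to the inference path (python cavemanml.py checkpoint_64k.pth) or loaded manually:

    import torch
    from cavemanml import CavemanEnhancer

    ckpt = torch.load("checkpoint_64k.pth", weights_only=False)
    model = CavemanEnhancer()
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()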
-d "$SOURCE_DIR" ]]; then + echo "Error: Source directory does not exist: $SOURCE_DIR" + echo "Usage: $0 " + exit 1 +fi + +mkdir -p "$OUTPUT_DIR" +log "Starting compression of '$SOURCE_DIR' -> '$OUTPUT_DIR' at 96 kbps AAC" +log "Using $NUM_JOBS parallel jobs" + +# ------------------------------- +# Detect hardware acceleration +# ------------------------------- +HW_ACCEL="$(detect_hw_accel)" +log "Hardware acceleration detected: $HW_ACCEL" + +# ------------------------------- +# Find all audio files and process them via parallel +# ------------------------------- +# Export functions and variables for GNU parallel +export -f get_hw_args +export SOURCE_DIR +export OUTPUT_DIR +export HW_ACCEL + +# Build find command for all audio extensions +find_cmd="find \"$SOURCE_DIR\" -type f \\( " +for ext in "${AUDIO_EXTS[@]}"; do + find_cmd+=" -iname \"*.${ext}\" -o" +done +# Replace trailing "-o" with "\\)" +find_cmd="${find_cmd% -o} \\)" + +# Use eval to execute the dynamic find command and pipe to parallel +eval "$find_cmd" | sort | parallel -j"$NUM_JOBS" --progress --bar --joblog parallel_jobs.log --eta ' + input_file="{}" + rel_path="${SOURCE_DIR:+${input_file#"$SOURCE_DIR"/}}" + output_file="'$OUTPUT_DIR'/${rel_path%.*}.mp3" + + # Create output directory if needed + mkdir -p "$(dirname "$output_file")" + + if [[ -f "$output_file" ]]; then + echo "Skipped (exists): $rel_path" + exit 0 + fi + + hw_args=$(get_hw_args "'$HW_ACCEL'" "$input_file") + ffmpeg -v warning -stats \ + -i "$input_file" \ + $hw_args \ + -y "$output_file" \ + && echo "Converted: $rel_path" +' + +log "Compression complete. Output saved to: $OUTPUT_DIR"