74 lines
2.4 KiB
Python
74 lines
2.4 KiB
Python
import os
|
|
import torch
|
|
import librosa
|
|
from torch.utils.data import Dataset
|
|
import numpy as np
|
|
import random
|
|
from settings import SR, N_FFT
|
|
from misc import audio_to_logmag
|
|
|
|
|
|
class WaveformDataset(Dataset):
    """Paired (lossy, clean) audio dataset yielding normalized log-magnitude spectrograms.

    Files are paired by identical filename in ``lossy_dir`` and ``clean_dir``;
    files that exist in only one of the two directories are ignored. Each item
    is a crop of ``segment_duration`` seconds taken at the same random offset
    from both files of a pair.
    """

    # Normalization statistics, one value per row of the log-magnitude array
    # (N_FFT // 2 + 1 rows). They are broadcast over the second axis via
    # ``[:, None]``, which assumes audio_to_logmag returns (bins, frames) —
    # TODO confirm. The defaults (zeros / ones) make normalization an identity;
    # presumably callers overwrite them with dataset statistics — verify.
    mean = np.zeros([N_FFT // 2 + 1])
    std = np.ones([N_FFT // 2 + 1])

    # Duration is a very, very important parameter; read cavemanml.py to see how
    # and why to adjust it. In this file it is the length, in seconds, of the
    # audio clip selected from each file of a pair.
    def __init__(self, lossy_dir, clean_dir, segment_duration, sr=SR):
        """Index the files shared by ``lossy_dir`` and ``clean_dir``.

        Args:
            lossy_dir: Directory containing the degraded audio files.
            clean_dir: Directory containing the reference audio files.
            segment_duration: Crop length in seconds.
            sr: Sample rate audio is loaded/resampled at; also used to convert
                ``segment_duration`` to samples.
        """
        self.segment_duration = segment_duration
        self.sr = sr
        self.lossy_dir = lossy_dir
        self.clean_dir = clean_dir
        self.lossy_files = sorted(os.listdir(lossy_dir))
        self.clean_files = sorted(os.listdir(clean_dir))
        # Build the membership set once rather than once per lossy file.
        clean_set = set(self.clean_files)
        self.file_pairs = [(f, f) for f in self.lossy_files if f in clean_set]
        # Optional item cache consulted by __getitem__. It must exist or the
        # lookup raises AttributeError; it is intentionally never filled here,
        # since caching would freeze each item's random crop.
        self.cache = {}

    def __len__(self):
        """Return the number of (lossy, clean) file pairs."""
        return len(self.file_pairs)

    def __getitem__(self, idx):
        """Return ``(lossy, clean)`` normalized log-magnitude tensors for pair ``idx``.

        A new random crop is drawn on every call. Both returned float tensors
        get a leading dimension of size 1 via ``unsqueeze(0)``.
        """
        if idx in self.cache:
            return self.cache[idx]

        # Index through file_pairs so the pair always matches __len__ and the
        # two files share a name, even when the directories' contents differ.
        lossy_name, clean_name = self.file_pairs[idx]
        lossy_path = os.path.join(self.lossy_dir, lossy_name)
        clean_path = os.path.join(self.clean_dir, clean_name)

        lossy, _ = librosa.load(lossy_path, sr=self.sr, mono=True)
        clean, _ = librosa.load(clean_path, sr=self.sr, mono=True)

        # Trim both signals to the shorter one so crop offsets line up.
        min_len = min(len(lossy), len(clean))
        lossy, clean = lossy[:min_len], clean[:min_len]

        # Convert seconds to samples at the rate the audio was actually loaded
        # at (self.sr), not the module-level default SR.
        clip_len = int(self.segment_duration * self.sr)
        if min_len < clip_len:
            # Too short: zero-pad at the end and crop from the beginning.
            lossy = np.pad(lossy, (0, clip_len - min_len))
            clean = np.pad(clean, (0, clip_len - min_len))
            start = 0
        else:
            # randint is inclusive, so start + clip_len <= min_len always holds.
            start = random.randint(0, min_len - clip_len)
        lossy = lossy[start:start + clip_len]
        clean = clean[start:start + clip_len]

        return (
            self._to_tensor(audio_to_logmag(lossy)),
            self._to_tensor(audio_to_logmag(clean)),
        )

    def _to_tensor(self, logmag):
        """Normalize ``logmag`` with the class statistics and wrap it as a float tensor."""
        # 1e-8 guards against division by zero for bins with zero std.
        normalized = (logmag - self.mean[:, None]) / (self.std[:, None] + 1e-8)
        return torch.from_numpy(normalized).float().unsqueeze(0)
|