# Source code for selfclean_audio.datasets.base
# Copyright (c) Lucerne University of Applied Sciences and Arts.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Base audio dataset class providing common functionality for audio loading and preprocessing."""
from abc import ABC, abstractmethod
from pathlib import Path
import torch
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import Dataset
# [docs]
class BaseAudioDataset(Dataset, ABC):
"""
Base class for audio datasets with common preprocessing functionality.
Provides standardized audio loading, mono conversion, resampling, and duration handling.
"""
def __init__(
self,
root: str | None = None,
convert_mono: bool = True,
sample_rate: int = 44100,
target_duration_sec: float | None = None,
):
"""
Initialize base audio dataset.
Args:
root: Root directory path for the dataset
convert_mono: Convert stereo audio to mono if True
sample_rate: Target sample rate for audio (will resample if needed)
target_duration_sec: Target duration in seconds (will pad/trim if specified)
"""
self.root = Path(root) if root else Path("./")
self.convert_mono = convert_mono
self.sample_rate = sample_rate
self.target_duration_sec = target_duration_sec
self.target_length = (
int(target_duration_sec * sample_rate) if target_duration_sec else None
)
def _load_and_preprocess_audio(
self, audio_path: str | Path
) -> tuple[torch.Tensor, int]:
"""
Load and preprocess audio file with standardized transformations.
Args:
audio_path: Path to audio file
Returns:
Tuple of (waveform, sample_rate) where waveform is preprocessed
"""
audio_path = Path(audio_path)
# Load audio file
waveform, file_sample_rate = torchaudio.load(str(audio_path))
# Convert to mono if requested
if self.convert_mono and waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
# Resample if needed
if file_sample_rate != self.sample_rate:
waveform = T.Resample(
orig_freq=file_sample_rate, new_freq=self.sample_rate
)(waveform)
# Handle target duration (pad or trim)
if self.target_length is not None:
waveform = self._resize_waveform(waveform, self.target_length)
return waveform, self.sample_rate
def _resize_waveform(
self, waveform: torch.Tensor, target_length: int
) -> torch.Tensor:
"""
Resize waveform to target length by padding or trimming.
Args:
waveform: Input waveform tensor of shape (C, T)
target_length: Target length in samples
Returns:
Resized waveform tensor
"""
current_length = waveform.size(1)
if current_length < target_length:
# Pad with zeros
padding = target_length - current_length
waveform = torch.nn.functional.pad(waveform, (0, padding))
elif current_length > target_length:
# Trim to target length
waveform = waveform[:, :target_length]
return waveform
@abstractmethod
def __len__(self) -> int:
"""Return the total number of samples in the dataset."""
pass
@abstractmethod
def __getitem__(self, idx: int) -> tuple:
"""Get a sample from the dataset."""
pass