# Source code for selfclean_audio.datasets.off_topic_dataset

# Copyright (c) Lucerne University of Applied Sciences and Arts.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import math
from typing import Literal

import numpy as np
import torch
from loguru import logger
from torch.utils.data import Dataset

from selfclean_audio.utils.sample import extract_sample


[docs] class OffTopicDataset(Dataset): """ Dataset that creates off-topic samples by: 1. Adding samples from an unrelated dataset 2. Adding pure noise samples 3. Adding heavily corrupted versions of original samples This follows the same approach as the image domain implementation. """ def __init__( self, dataset: Dataset, contamination_dataset: Dataset | None = None, frac_error: float = 0.1, n_errors: int | None = None, contamination_strategy: Literal[ "external", "noise", "corrupted", "combined" ] = "combined", noise_level: float = 0.5, # For noise and corrupted strategies random_state: int = 42, name: str | None = None, ): """ Args: dataset: Original clean dataset contamination_dataset: External dataset for contamination (e.g., MUSAN for ESC50) frac_error: Fraction of samples to contaminate n_errors: Exact number of contaminated samples (overrides frac_error) contamination_strategy: How to create off-topic samples noise_level: Level of noise/corruption (0-1) random_state: Random seed for reproducibility name: Dataset name for logging """ self.name = ( name if name is not None else f"OffTopicDataset_{contamination_strategy}" ) self.original_dataset = dataset self.contamination_dataset = contamination_dataset self.contamination_strategy = contamination_strategy self.noise_level = noise_level logger.info(f"Initializing {self.name}") # Forward sample_rate attribute for downstream components (e.g., LoRA adaptation) # Prefer the original dataset's attribute if present, otherwise fallback to common default if hasattr(self.original_dataset, "sample_rate"): try: self.sample_rate = int(getattr(self.original_dataset, "sample_rate")) except Exception: # pragma: no cover - defensive self.sample_rate = 16000 else: self.sample_rate = 16000 # Set random seed for reproducibility np.random.seed(random_state) torch.manual_seed(random_state) # Determine how many samples to contaminate if n_errors is not None: n_contaminated = n_errors else: n_contaminated = math.ceil(frac_error * 
len(self.original_dataset)) # Randomly select indices to contaminate idx_range = np.arange(0, len(self.original_dataset)) self.contaminated_indices = set( np.random.choice(idx_range, size=n_contaminated, replace=False) ) # Pre-generate contaminated samples for consistency self.contaminated_samples = {} for idx in self.contaminated_indices: audio, path, label, _ = extract_sample(self.original_dataset[idx]) # Create off-topic sample contaminated_audio = self._create_off_topic_sample(audio) contaminated_path = f"{path}_off_topic" # Store contaminated version (off-topic audio, modified path, original label, noisy_label=1) self.contaminated_samples[idx] = ( contaminated_audio, contaminated_path, label, torch.tensor(1), ) def _create_off_topic_sample(self, original_audio: torch.Tensor) -> torch.Tensor: """Create an off-topic sample based on the contamination strategy""" if self.contamination_strategy == "external": # Use sample from external dataset if self.contamination_dataset is not None: ext_idx = np.random.randint(0, len(self.contamination_dataset)) if hasattr(self.contamination_dataset, "__getitem__"): try: ext_data = self.contamination_dataset[ext_idx] if isinstance(ext_data, tuple) and len(ext_data) > 0: ext_audio = ext_data[0] # First element should be audio # Ensure same shape as original if ext_audio.shape != original_audio.shape: # Simple resize/pad/crop to match shape ext_audio = self._match_audio_shape( ext_audio, original_audio.shape ) return ext_audio except Exception as e: logger.error(f"Failed to get item from cont. 
dataset: {e}") # Fallback to noise if external dataset not available/working return self._create_noise_sample(original_audio) elif self.contamination_strategy == "noise": # Pure noise return self._create_noise_sample(original_audio) elif self.contamination_strategy == "corrupted": # Heavily corrupted version of original corrupted = original_audio.clone() noise = torch.normal(0, self.noise_level, size=corrupted.shape) # High corruption level corrupted = corrupted + noise return torch.clamp(corrupted, -1.0, 1.0) elif self.contamination_strategy == "combined": # Randomly choose between strategies strategy = np.random.choice(["external", "noise", "corrupted"]) return self._create_off_topic_sample_strategy(original_audio, strategy) else: raise ValueError( f"Unknown contamination strategy: {self.contamination_strategy}" ) def _create_off_topic_sample_strategy( self, audio: torch.Tensor, strategy: str ) -> torch.Tensor: """Helper to create off-topic sample with specific strategy""" if strategy == "external" and self.contamination_dataset is not None: try: ext_idx = np.random.randint(0, len(self.contamination_dataset)) ext_data = self.contamination_dataset[ext_idx] ext_audio = ext_data[0] ext_audio = self._match_audio_shape(ext_audio, audio.shape) return ext_audio except Exception as e: logger.error(f"Failed to get item from cont. 
dataset: {e}") if strategy == "corrupted": corrupted = audio.clone() noise = torch.normal(0, self.noise_level, size=corrupted.shape) corrupted = corrupted + noise return torch.clamp(corrupted, -1.0, 1.0) # Default to noise return self._create_noise_sample(audio) def _create_noise_sample(self, reference_audio: torch.Tensor) -> torch.Tensor: """Create pure noise with same shape as reference""" if np.random.random() < 0.5: # Gaussian noise return torch.normal(0, self.noise_level, size=reference_audio.shape) else: # Uniform noise return (torch.rand_like(reference_audio) - 0.5) * 2 * self.noise_level def _match_audio_shape( self, source_audio: torch.Tensor, target_shape: torch.Size ) -> torch.Tensor: """Resize/pad/crop source audio to match target shape""" if source_audio.shape == target_shape: return source_audio # Handle different tensor dimensions if len(target_shape) == 1: # [length] target_length = target_shape[0] if len(source_audio.shape) == 2: # [channels, length] source_audio = source_audio.mean(dim=0) # Convert to mono current_length = source_audio.shape[0] if current_length > target_length: # Crop start = np.random.randint(0, current_length - target_length) return source_audio[start : start + target_length] elif current_length < target_length: # Pad with repetition repetitions = (target_length // current_length) + 1 repeated = source_audio.repeat(repetitions) return repeated[:target_length] else: return source_audio elif len(target_shape) == 2: # [channels, length] target_channels, target_length = target_shape # Ensure correct channel dimension if len(source_audio.shape) == 1: # [length] -> [channels, length] source_audio = source_audio.unsqueeze(0).repeat(target_channels, 1) elif source_audio.shape[0] != target_channels: if source_audio.shape[0] > target_channels: source_audio = source_audio[:target_channels] else: source_audio = source_audio.repeat( target_channels // source_audio.shape[0] + 1, 1 ) source_audio = source_audio[:target_channels] # Handle 
length dimension current_length = source_audio.shape[1] if current_length > target_length: start = np.random.randint(0, current_length - target_length) return source_audio[:, start : start + target_length] elif current_length < target_length: repetitions = (target_length // current_length) + 1 repeated = source_audio.repeat(1, repetitions) return repeated[:, :target_length] else: return source_audio return source_audio # Fallback def __len__(self) -> int: """Same size as original dataset""" return len(self.original_dataset) def __getitem__(self, idx: int): """Get item - either original or contaminated version""" if idx in self.contaminated_indices: return self.contaminated_samples[idx] audio, path, label, _ = extract_sample(self.original_dataset[idx]) return audio, path, label, torch.tensor(0)
[docs] def get_errors(self) -> list[int]: """ Return ground truth off-topic indicators. This matches the interface from the image domain. Returns: List of 0/1 indicating whether each sample is off-topic """ return [ 1 if idx in self.contaminated_indices else 0 for idx in range(len(self)) ]
[docs] def info(self): """Print dataset information""" print(f"Name: {self.name}") print(f"Dataset size: {len(self)}") print(f"Number of off-topic samples: {len(self.contaminated_indices)}") print(f"Contamination strategy: {self.contamination_strategy}") print(f"Has external dataset: {self.contamination_dataset is not None}") print(f"Noise level: {self.noise_level}")