# Source code for selfclean_audio.datasets.duplicate_dataset

# Copyright (c) Lucerne University of Applied Sciences and Arts.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import math
import shutil
from typing import Literal

import numpy as np
import torch
from loguru import logger
from torch.utils.data import Dataset

from selfclean_audio.utils.sample import extract_sample


[docs] class DuplicateDataset(Dataset): """ Dataset that creates near-duplicates by appending modified versions to the original dataset. This follows the same approach as the image domain implementation. """ def __init__( self, dataset: Dataset, frac_error: float = 0.1, n_errors: int | None = None, duplicate_strategy: Literal[ "exact", "noisy", "cropped", "mixed", "combined" ] = "exact", noise_level: float = 0.05, crop_ratio_range: tuple[float, float] = (0.1, 0.25), random_state: int = 42, name: str | None = None, save_to_temp: bool = True, temp_dir: str | None = None, ): """ Args: dataset: Original clean dataset frac_error: Fraction of samples to duplicate n_errors: Exact number of duplicates (overrides frac_error) duplicate_strategy: Type of duplicate to create noise_level: Noise level for noisy duplicates crop_ratio_range: Range of crop ratios for cropped duplicates random_state: Random seed for reproducibility name: Dataset name for logging """ self.name = ( name if name is not None else f"DuplicateDataset_{duplicate_strategy}" ) self.original_dataset = dataset self.duplicate_strategy = duplicate_strategy self.noise_level = noise_level self.crop_ratio_range = crop_ratio_range # Forward sample_rate attribute from original dataset if available if hasattr(dataset, "sample_rate"): self.sample_rate = dataset.sample_rate else: # Default sample rate (commonly used in audio processing) self.sample_rate = 16000 logger.warning( f"Original dataset has no sample_rate attribute, using default: {self.sample_rate}" ) logger.info(f"Initializing {self.name}") self.save_to_temp = save_to_temp self.temp_dir = None if self.save_to_temp: import tempfile from pathlib import Path self.temp_dir = ( Path(temp_dir) if temp_dir is not None else Path(tempfile.mkdtemp(prefix="synthetic_audio_")) ) self.temp_dir.mkdir(parents=True, exist_ok=True) # Set random seed for reproducibility np.random.seed(random_state) torch.manual_seed(random_state) # Determine how many samples to duplicate if 
n_errors is not None: n_duplicates = n_errors else: n_duplicates = math.ceil(frac_error * len(self.original_dataset)) # Randomly select indices to duplicate idx_range = np.arange(0, len(self.original_dataset)) self.duplicate_indices = np.random.choice( idx_range, size=n_duplicates, replace=False ) # Pre-generate duplicates for consistency self.duplicates = [] self.duplicate_pairs = [] # (original_idx, duplicate_idx) tuples for i, orig_idx in enumerate(self.duplicate_indices): duplicate_idx = len(self.original_dataset) + i self.duplicate_pairs.append((orig_idx, duplicate_idx)) # Get original sample - handle variable number of return values original_data = self.original_dataset[orig_idx] original_audio, audio_path, label, _noisy = extract_sample(original_data) # Create duplicate based on strategy duplicate_audio = self._create_duplicate(original_audio) # Store duplicate with same format as original + noisy_label if audio_path is not None: # Build a real path for the duplicate so audio-hash/Dejavu can read it if self.save_to_temp and self.temp_dir is not None: from pathlib import Path import torchaudio orig_stem = Path(audio_path).stem duplicate_path = self.temp_dir / f"{orig_stem}_duplicate_{i}.wav" # Ensure waveform is 2D (C, T) wav = duplicate_audio if wav.ndim == 1: wav = wav.unsqueeze(0) # Use original dataset sample rate if available sample_rate = getattr(self.original_dataset, "sample_rate", 16000) try: torchaudio.save(str(duplicate_path), wav, sample_rate) except Exception: # Fallback: clamp dtype/values and retry wav = wav.to(torch.float32) torchaudio.save(str(duplicate_path), wav, sample_rate) duplicate_path_str = str(duplicate_path) else: duplicate_path_str = f"{audio_path}_duplicate_{i}" # Duplicates are considered "noisy" (1), originals are clean (0) self.duplicates.append( (duplicate_audio, duplicate_path_str, label, torch.tensor(1)) ) else: self.duplicates.append((duplicate_audio, label, torch.tensor(1))) def _create_duplicate(self, audio: torch.Tensor) 
-> torch.Tensor: """Create a duplicate of the audio based on the specified strategy.""" if self.duplicate_strategy == "exact": # Exact copy return audio.clone() elif self.duplicate_strategy == "noisy": # Add Gaussian noise duplicate = audio.clone() noise = torch.normal(0, self.noise_level, size=duplicate.shape) duplicate = duplicate + noise return torch.clamp(duplicate, -1.0, 1.0) elif self.duplicate_strategy == "cropped": # Randomly crop part of the audio duplicate = audio.clone() if len(duplicate.shape) == 2: # [channels, length] audio_length = duplicate.shape[1] else: # [length] audio_length = duplicate.shape[0] crop_ratio = np.random.uniform(*self.crop_ratio_range) crop_length = int(audio_length * crop_ratio) # Random crop position start_pos = np.random.randint(0, max(1, audio_length - crop_length)) end_pos = start_pos + crop_length # Zero out the cropped section if len(duplicate.shape) == 2: duplicate[:, start_pos:end_pos] = 0 else: duplicate[start_pos:end_pos] = 0 return duplicate elif self.duplicate_strategy == "mixed": # Mix with small amount of noise duplicate = audio.clone() noise = torch.rand_like(duplicate) * self.noise_level mix_ratio = np.random.uniform(*self.crop_ratio_range) duplicate = (1 - mix_ratio) * duplicate + mix_ratio * noise return torch.clamp(duplicate, -1.0, 1.0) elif self.duplicate_strategy == "combined": # Randomly select one of the other strategies for each duplicate strategies = ["exact", "noisy", "cropped", "mixed"] selected_strategy = np.random.choice(strategies) # Temporarily change strategy and create duplicate original_strategy = self.duplicate_strategy self.duplicate_strategy = selected_strategy duplicate = self._create_duplicate(audio) self.duplicate_strategy = original_strategy return duplicate else: raise ValueError(f"Unknown duplicate strategy: {self.duplicate_strategy}") def __len__(self) -> int: """Total dataset size: original + duplicates""" return len(self.original_dataset) + len(self.duplicates) def __getitem__(self, 
idx: int): """Get item from either original dataset or duplicates""" if idx < len(self.original_dataset): # Return original sample with noisy_label=0 (clean) original_data = self.original_dataset[idx] if len(original_data) == 3: # ESC50 format: (audio, path, label) -> (audio, path, label, noisy_label) audio, path, label = original_data return audio, path, label, torch.tensor(0) # 0 = clean/original else: # Other formats - extend as needed return original_data + (torch.tensor(0),) else: # Return duplicate sample (already has noisy_label=1) duplicate_idx = idx - len(self.original_dataset) return self.duplicates[duplicate_idx]
[docs] def cleanup_temp_dir(self) -> None: """Remove the temporary directory containing synthetic duplicate files.""" if getattr(self, "save_to_temp", False) and getattr(self, "temp_dir", None): try: shutil.rmtree(str(self.temp_dir), ignore_errors=True) except Exception: pass finally: self.temp_dir = None
[docs] def get_errors(self) -> tuple[set[tuple[int, int]], list[str]]: """ Return ground truth duplicate pairs. This matches the interface from the image domain. Returns: Set of (original_idx, duplicate_idx) tuples List of error type names """ return set(self.duplicate_pairs), ["original", "duplicate"]
[docs] def info(self): """Print dataset information""" print(f"Name: {self.name}") print(f"Original dataset size: {len(self.original_dataset)}") print(f"Number of duplicates: {len(self.duplicates)}") print(f"Total dataset size: {len(self)}") print(f"Duplicate strategy: {self.duplicate_strategy}") print(f"Duplicate pairs: {len(self.duplicate_pairs)}")