# Source code for selfclean_audio.utils.sample

# Copyright (c) Lucerne University of Applied Sciences and Arts.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


from __future__ import annotations

from typing import Optional, Tuple

import torch


def _as_int_label(label) -> int:
    """Convert a class label to a plain Python ``int``.

    Tensors are unwrapped with ``.item()`` (so they must hold exactly one
    element); anything else goes through ``int()``.
    """
    if isinstance(label, torch.Tensor):
        return int(label.item())
    return int(label)


def _as_noisy_flag(noisy) -> torch.Tensor:
    """Return the noisy-label flag as a ``torch.long`` tensor.

    Non-tensor values are converted with ``int()`` into a 0-d long tensor.
    Tensor values keep their shape but are cast to ``torch.long`` so every
    branch of ``extract_sample`` yields the same dtype.
    """
    if isinstance(noisy, torch.Tensor):
        # Previously tensors were returned as-is, so e.g. a bool/float flag
        # leaked out with a dtype that differed from the other branches.
        return noisy if noisy.dtype == torch.long else noisy.to(torch.long)
    return torch.tensor(int(noisy), dtype=torch.long)


def extract_sample(
    sample_tuple: tuple,
) -> Tuple[torch.Tensor, Optional[str], int, torch.Tensor]:
    """Normalize dataset samples to a common shape.

    Accept tuples returned by various dataset implementations and return a
    tuple of (waveform, path, label, noisy_label) where:

    - waveform: passed through unchanged; expected to be a torch.Tensor of
      shape (C, T) or (T,) (not validated here)
    - path: Optional[str] (file path, or None for 2-element samples)
    - label: int (class index), converted from a number or 1-element tensor
    - noisy_label: torch.Tensor of dtype long, 0 for clean, 1 for noisy;
      defaults to 0 when the sample does not carry a noisy flag

    Supported input layouts:

    - 4 elements: (audio, path, label, noisy_label)
    - 3 elements: (audio, path, label)
    - 2 elements: (audio, label)

    This helper avoids repeated ad-hoc tuple length checks across the codebase.

    Raises:
        ValueError: if the sample has any other number of elements.
    """
    n = len(sample_tuple)
    if n == 4:
        # (audio, path, label, noisy_label)
        audio, path, label, noisy = sample_tuple
        return audio, path, _as_int_label(label), _as_noisy_flag(noisy)
    if n == 3:
        # (audio, path, label) -> add noisy_label=0
        audio, path, label = sample_tuple
        return (
            audio,
            path,
            _as_int_label(label),
            torch.tensor(0, dtype=torch.long),
        )
    if n == 2:
        # (audio, label) -> no path information
        audio, label = sample_tuple
        return (
            audio,
            None,
            _as_int_label(label),
            torch.tensor(0, dtype=torch.long),
        )
    raise ValueError(
        f"Unsupported sample tuple format of length {n}; expected 2, 3 or 4 elements"
    )