# Source code for selfclean_audio.datasets.gtzan

# Copyright (c) Lucerne University of Applied Sciences and Arts.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


from __future__ import annotations

import itertools
import os
from dataclasses import dataclass
from pathlib import Path

import torch

from selfclean_audio.datasets.base import BaseAudioDataset


@dataclass
class _GTZANGroundTruth:
    duplicate_pairs: set[tuple[int, int]]
    mislabeled_ids: set[str]


class GTZANKnownIssuesDataset(BaseAudioDataset):
    """
    GTZAN dataset wrapper with built-in access to known data quality issues.

    - Exposes audio samples from a local GTZAN folder (``genres/<class>/*.wav``)
    - Parses ground truth CSVs with known issues from external_code
    - Provides `get_errors()` compatible with SelfClean evaluation:
        - For ISSUE_TYPE "duplicates": returns (set of (idx_i, idx_j) pairs, [..labels..])
        - For ISSUE_TYPE "label_errors": returns a list[int] of 0/1 per sample

    Note:
        The scan lists the immediate sub-directories of ``self.root`` as
        classes, so ``root`` should point at the directory that directly
        contains the class folders (assuming the base class stores ``root``
        unchanged as ``self.root`` -- TODO confirm).
    """

    def __init__(
        self,
        root: str | Path,
        issue_type: str = "duplicates",
        gt_duplicates_file: str | Path | None = None,
        gt_prep_file: str | Path | None = None,
        convert_mono: bool = True,
        sample_rate: int = 16000,
        target_duration_sec: float | None = 30.0,
        extensions: tuple[str, ...] = (".wav", ".mp3", ".flac"),
    ) -> None:
        """
        Scan the dataset folder and load the optional ground-truth files.

        Args:
            root: Directory whose sub-directories are the class folders.
            issue_type: Selects what `get_errors()` returns; must contain
                "duplicates" or "label_errors".
            gt_duplicates_file: Optional ``Id;Issue;Value`` CSV listing exact
                repetitions.
            gt_prep_file: Optional CSV whose sixth column flags mislabelings.
            convert_mono: Forwarded to the base class.
            sample_rate: Forwarded to the base class.
            target_duration_sec: Forwarded to the base class.
            extensions: File suffixes (matched case-insensitively) that count
                as audio files.

        Raises:
            RuntimeError: If no audio file is found under ``root``.
        """
        super().__init__(
            root=str(root),
            convert_mono=convert_mono,
            sample_rate=sample_rate,
            target_duration_sec=target_duration_sec,
        )
        self.issue_type = issue_type
        self.extensions = extensions

        # Scan files by class
        self.class_names: list[str] = []
        self.class_to_idx: dict[str, int] = {}
        self.samples: list[tuple[str, int]] = []  # (path, label_idx)
        self._scan_dataset()

        # Map from GTZAN id (e.g., "disco.00050") to dataset index.
        # If two files share a stem, the later sample wins.
        self.id_to_index: dict[str, int] = {
            self._path_to_gtzan_id(path): i for i, (path, _) in enumerate(self.samples)
        }

        # Parse ground truth CSVs if provided
        self._gt = self._parse_ground_truth(
            gt_duplicates_file=gt_duplicates_file, gt_prep_file=gt_prep_file
        )

    # ------------------------------------------------------------------
    # Dataset API
    # ------------------------------------------------------------------
    def __len__(self) -> int:
        """Return the number of audio files found by the scan."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, str, int, torch.Tensor]:
        """
        Load one sample.

        Returns:
            ``(waveform, path, label_idx, noisy)`` where ``noisy`` is a tensor
            holding 1 if the file is a known mislabeled item and 0 otherwise.
        """
        path, label_idx = self.samples[idx]
        # The second value returned by the loader is not needed here.
        waveform, _ = self._load_and_preprocess_audio(path)
        # Optional: mark known mislabeled items as noisy=1 for visualization
        gtzan_id = self._path_to_gtzan_id(path)
        noisy = 1 if gtzan_id in self._gt.mislabeled_ids else 0
        return waveform, path, label_idx, torch.tensor(noisy)

    # ------------------------------------------------------------------
    # Ground truth interfacing used by SelfCleanAudio.calculate_scores
    # ------------------------------------------------------------------
    def get_errors(self) -> tuple[set[tuple[int, int]], list[str]] | list[int]:
        """
        Return evaluation ground truth in the format expected by SelfClean.

        - If evaluating duplicates: returns (set[(i, j)], [..labels..])
        - If evaluating label errors: returns list[int] of length len(self)

        Raises:
            ValueError: If ``issue_type`` contains neither "duplicates" nor
                "label_errors".
        """
        if "duplicates" in self.issue_type:
            return self._gt.duplicate_pairs, ["gtzan_original", "gtzan_duplicate"]

        if "label_errors" in self.issue_type:
            indicators = [0] * len(self)
            for gtzan_id in self._gt.mislabeled_ids:
                idx = self.id_to_index.get(gtzan_id)
                if idx is not None:
                    indicators[idx] = 1
            return indicators

        raise ValueError(
            f"Unsupported issue_type for GTZAN: {self.issue_type}. Use 'duplicates' or 'label_errors'."
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _scan_dataset(self) -> None:
        """
        Populate ``class_names``, ``class_to_idx`` and ``samples``.

        Every sub-directory of ``self.root`` is treated as a class (in sorted
        order); files inside it whose lower-cased name ends with one of
        ``self.extensions`` become samples, also in sorted order.

        Raises:
            RuntimeError: If no matching audio file is found.
        """
        root = self.root
        class_names = sorted(
            entry
            for entry in os.listdir(root)
            if os.path.isdir(os.path.join(root, entry))
        )
        self.class_names = class_names
        self.class_to_idx = {name: i for i, name in enumerate(class_names)}

        # File names are lower-cased before matching, so the suffixes must be
        # lower-cased as well or an extension given as ".WAV" never matches.
        suffixes = tuple(ext.lower() for ext in self.extensions)

        files: list[tuple[str, int]] = []
        for name in class_names:
            cls_path = os.path.join(root, name)
            label = self.class_to_idx[name]
            files.extend(
                (os.path.join(cls_path, fname), label)
                for fname in sorted(os.listdir(cls_path))
                if fname.lower().endswith(suffixes)
            )

        if not files:
            raise RuntimeError(
                f"No audio files found under {self.root}. Expected GTZAN layout '.../genres/<class>/*.wav'"
            )
        self.samples = files

    @staticmethod
    def _path_to_gtzan_id(path: str | Path) -> str:
        """Return the file stem, e.g. ``genres/disco/disco.00050.wav -> disco.00050``."""
        return Path(path).stem

    @staticmethod
    def _warn(message: str, *args: object) -> None:
        """Log a warning on this module's logger (best-effort parsing must not raise)."""
        import logging

        logging.getLogger(__name__).warning(message, *args)

    @staticmethod
    def _iter_gt_rows(path: str | Path, min_fields: int):
        """
        Yield the stripped ``;``-separated fields of each data row in ``path``.

        The first line is skipped as a header; blank lines and rows with fewer
        than ``min_fields`` fields are ignored.
        """
        with open(path, "r", encoding="utf-8") as f:
            next(f, None)  # skip the header line
            for raw in f:
                line = raw.strip()
                if not line:
                    continue
                fields = [field.strip() for field in line.split(";")]
                if len(fields) >= min_fields:
                    yield fields

    def _parse_duplicate_pairs(
        self, gt_duplicates_file: str | Path | None
    ) -> set[tuple[int, int]]:
        """
        Build all index pairs of samples that share an "exact repetition" group.

        Expects the header ``Id;Issue;Value``. Rows with an empty group value
        and ids that are not present locally are skipped. Returns an empty set
        if the file is missing or cannot be read.
        """
        pairs: set[tuple[int, int]] = set()
        if gt_duplicates_file is None:
            return pairs
        if not Path(gt_duplicates_file).exists():
            self._warn(
                "GTZAN duplicates ground truth not found: %s", gt_duplicates_file
            )
            return pairs

        groups: dict[str, list[str]] = {}
        try:
            for gtzan_id, issue, value, *_ in self._iter_gt_rows(gt_duplicates_file, 3):
                if not issue.lower().startswith("exact repetition"):
                    continue
                if not value:
                    # Some entries have empty group ids; ignore those
                    continue
                groups.setdefault(value, []).append(gtzan_id)
        except (OSError, UnicodeError) as err:
            # Keep empty set if parsing fails, but make the failure visible.
            self._warn(
                "Could not parse GTZAN duplicates ground truth %s: %s",
                gt_duplicates_file,
                err,
            )
            return pairs

        # Build all unique pairs per group
        for members in groups.values():
            # Map to indices; skip ids not present locally
            idxs = sorted(
                {self.id_to_index[m] for m in members if m in self.id_to_index}
            )
            pairs.update(itertools.combinations(idxs, 2))
        return pairs

    def _parse_mislabeled_ids(self, gt_prep_file: str | Path | None) -> set[str]:
        """
        Collect the locally present GTZAN ids flagged as mislabeled.

        Expects the header
        ``Id;Artist Repetition;Distortions;Exact Repetition;Exact Rep Value;Mislabelings;...``
        and treats a sixth column equal to "yes" (case-insensitive) as a
        mislabeling. If reading fails part-way, the ids collected so far are
        kept.
        """
        mislabeled_ids: set[str] = set()
        if gt_prep_file is None:
            return mislabeled_ids
        if not Path(gt_prep_file).exists():
            self._warn("GTZAN mislabeling ground truth not found: %s", gt_prep_file)
            return mislabeled_ids

        try:
            for fields in self._iter_gt_rows(gt_prep_file, 6):
                gtzan_id = fields[0]
                if fields[5].lower() == "yes" and gtzan_id in self.id_to_index:
                    mislabeled_ids.add(gtzan_id)
        except (OSError, UnicodeError) as err:
            # Treat as unknown beyond this point, but make the failure visible.
            self._warn(
                "Could not parse GTZAN mislabeling ground truth %s: %s",
                gt_prep_file,
                err,
            )
        return mislabeled_ids

    def _parse_ground_truth(
        self,
        gt_duplicates_file: str | Path | None,
        gt_prep_file: str | Path | None,
    ) -> _GTZANGroundTruth:
        """
        Parse both optional ground-truth CSVs.

        Parsing is best-effort: a missing or unreadable file yields empty (or
        partial) ground truth and a logged warning instead of an exception.
        """
        return _GTZANGroundTruth(
            duplicate_pairs=self._parse_duplicate_pairs(gt_duplicates_file),
            mislabeled_ids=self._parse_mislabeled_ids(gt_prep_file),
        )