Source code for selfclean_audio.validation
# Copyright (c) Lucerne University of Applied Sciences and Arts.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Centralized validation functions for configuration and parameters."""
from omegaconf import DictConfig
from .constants import (
DUPLICATE_STRATEGY_MAP,
OFF_TOPIC_STRATEGY_MAP,
REQUIRED_BASE_PARAMS,
REQUIRED_DATALOADER_PARAMS,
REQUIRED_DATASET_PATHS,
REQUIRED_GTZAN_PARAMS,
)
[docs]
class ValidationError(ValueError):
    """Raised when a configuration value or parameter fails validation.

    Derives from ``ValueError``, so existing ``except ValueError`` handlers
    also catch validation failures raised by this module.
    """
[docs]
def validate_required_attributes(
    obj, required_attrs: list[str], context: str = ""
) -> None:
    """
    Ensure ``obj`` exposes every attribute named in ``required_attrs``.

    Presence is tested with ``hasattr``; all absent names are collected
    before raising, so a single error reports every missing attribute.

    Args:
        obj: Object whose attributes are checked.
        required_attrs: Attribute names that must be present on ``obj``.
        context: Optional label inserted into the error message as
            " for <context>"; omitted when empty.

    Raises:
        ValidationError: If one or more attributes are absent. The message
            lists the missing names comma-separated, in the order given.
    """
    absent = list(filter(lambda name: not hasattr(obj, name), required_attrs))
    if not absent:
        return

    where = f" for {context}" if context else ""
    listing = ", ".join(absent)
    raise ValidationError(f"Missing required attributes{where}: {listing}")
[docs]
def validate_base_config(cfg: DictConfig) -> None:
    """
    Check that ``cfg`` defines every attribute listed in ``REQUIRED_BASE_PARAMS``.

    Args:
        cfg: Configuration object to check.

    Raises:
        ValidationError: If any required base parameter is missing.
    """
    validate_required_attributes(
        obj=cfg,
        required_attrs=REQUIRED_BASE_PARAMS,
        context="base configuration",
    )
[docs]
def validate_dataset_paths(cfg: DictConfig) -> None:
    """
    Check that ``cfg`` defines every attribute listed in ``REQUIRED_DATASET_PATHS``.

    Args:
        cfg: Configuration object to check.

    Raises:
        ValidationError: If any required dataset-path parameter is missing.
    """
    validate_required_attributes(
        obj=cfg,
        required_attrs=REQUIRED_DATASET_PATHS,
        context="dataset paths",
    )
[docs]
def validate_gtzan_config(cfg: DictConfig) -> None:
    """
    Check that ``cfg`` defines every attribute listed in ``REQUIRED_GTZAN_PARAMS``.

    Args:
        cfg: Configuration object to check.

    Raises:
        ValidationError: If any required GTZAN parameter is missing.
    """
    validate_required_attributes(
        obj=cfg,
        required_attrs=REQUIRED_GTZAN_PARAMS,
        context="GTZAN configuration",
    )
[docs]
def validate_dataloader_config(cfg: DictConfig) -> None:
    """
    Check that ``cfg`` has a ``dataloader`` section with all required keys.

    Args:
        cfg: Configuration object to check.

    Raises:
        ValidationError: If ``cfg`` has no ``dataloader`` attribute, or if the
            section lacks any attribute listed in ``REQUIRED_DATALOADER_PARAMS``.
    """
    if hasattr(cfg, "dataloader"):
        validate_required_attributes(
            obj=cfg.dataloader,
            required_attrs=REQUIRED_DATALOADER_PARAMS,
            context="dataloader configuration",
        )
        return
    raise ValidationError("Config must define dataloader parameters")
[docs]
def validate_issue_type(issue_type: str) -> None:
    """
    Confirm that ``issue_type`` names a supported issue category.

    Supported values are the keys of ``DUPLICATE_STRATEGY_MAP``, the literal
    ``"label_errors"``, and the keys of ``OFF_TOPIC_STRATEGY_MAP``, in that
    order (the order is reflected in the error message).

    Args:
        issue_type: Issue type to check.

    Raises:
        ValidationError: If ``issue_type`` is not one of the supported values.
    """
    supported = [*DUPLICATE_STRATEGY_MAP, "label_errors", *OFF_TOPIC_STRATEGY_MAP]
    if issue_type in supported:
        return
    raise ValidationError(
        f"Unknown ISSUE_TYPE: {issue_type}. Must be one of: {supported}"
    )
[docs]
def validate_duplicate_strategy(issue_type: str) -> str:
    """
    Map a duplicate-detection issue type to its strategy name.

    Args:
        issue_type: Key to look up in ``DUPLICATE_STRATEGY_MAP``.

    Returns:
        str: The strategy registered for ``issue_type``.

    Raises:
        ValidationError: If ``issue_type`` is not a key of
            ``DUPLICATE_STRATEGY_MAP``.
    """
    if issue_type in DUPLICATE_STRATEGY_MAP:
        return DUPLICATE_STRATEGY_MAP[issue_type]
    choices = list(DUPLICATE_STRATEGY_MAP.keys())
    raise ValidationError(
        f"Unknown duplicate ISSUE_TYPE: {issue_type}. Must be one of: {choices}"
    )
[docs]
def validate_off_topic_strategy(issue_type: str) -> str:
    """
    Map an off-topic-detection issue type to its strategy name.

    Args:
        issue_type: Key to look up in ``OFF_TOPIC_STRATEGY_MAP``.

    Returns:
        str: The strategy registered for ``issue_type``.

    Raises:
        ValidationError: If ``issue_type`` is not a key of
            ``OFF_TOPIC_STRATEGY_MAP``.
    """
    if issue_type in OFF_TOPIC_STRATEGY_MAP:
        return OFF_TOPIC_STRATEGY_MAP[issue_type]
    choices = list(OFF_TOPIC_STRATEGY_MAP.keys())
    raise ValidationError(
        f"Unknown off-topic ISSUE_TYPE: {issue_type}. Must be one of: {choices}"
    )
[docs]
def get_seed_from_config(cfg: DictConfig) -> int:
    """
    Extract the random seed from the first config location that provides one.

    Locations are tried in order: ``SEED``, ``params.seed``,
    ``selfclean_audio.random_seed``. A location is skipped when any segment
    of its dotted path is absent, or when its value is of a type ``int()``
    rejects outright (e.g. ``None``), so an explicitly null seed defers to
    the next location. The value is converted with ``int()``, so numeric
    strings are accepted and floats are truncated toward zero.

    Args:
        cfg: Configuration object.

    Returns:
        int: Seed value.

    Raises:
        ValidationError: If a seed is present but cannot be converted to an
            integer (e.g. a non-numeric string), or if no location defines one.
    """
    seed_paths = ("SEED", "params.seed", "selfclean_audio.random_seed")
    for path in seed_paths:
        value = cfg
        try:
            # Walk the dotted path one attribute at a time; a missing segment
            # means this location is not configured.
            for key in path.split("."):
                value = getattr(value, key)
        except (AttributeError, TypeError):
            continue

        try:
            return int(value)
        except TypeError:
            # Wrong type entirely (e.g. None): try the next location.
            continue
        except (ValueError, OverflowError) as err:
            # A seed is configured but malformed: report it instead of
            # leaking a bare int() error with no hint of which key is wrong.
            raise ValidationError(
                f"Seed at '{path}' must be an integer, got {value!r}"
            ) from err

    raise ValidationError(
        "Config must define seed parameter as SEED, params.seed, or "
        "selfclean_audio.random_seed"
    )
[docs]
def validate_full_config(cfg: DictConfig) -> None:
    """
    Perform comprehensive configuration validation.

    Checks run in this order and the first failure propagates:
    base parameters, ``ISSUE_TYPE``, seed, then either the dataset-path
    parameters or (when ``EVAL_DATASET`` is ``"gtzan"``, case-insensitive)
    the GTZAN parameters, and finally the dataloader section.

    Args:
        cfg: Configuration object to validate.

    Raises:
        ValidationError: If any validation fails.
    """
    # Basic validation
    validate_base_config(cfg)
    # Issue type validation.
    # NOTE(review): assumes REQUIRED_BASE_PARAMS includes ISSUE_TYPE so this
    # attribute access cannot fail — confirm in constants.
    validate_issue_type(cfg.ISSUE_TYPE)
    # Seed validation: called only for its raise-on-missing check; the value
    # itself is not needed here.
    get_seed_from_config(cfg)
    # EVAL_DATASET may be absent or explicitly null (or a non-string); none
    # of those select GTZAN, and none should crash on .lower().
    eval_dataset = getattr(cfg, "EVAL_DATASET", None)
    if eval_dataset is not None and str(eval_dataset).lower() == "gtzan":
        validate_gtzan_config(cfg)
    else:
        # Path validation (only for non-GTZAN datasets)
        validate_dataset_paths(cfg)
    # Dataloader validation
    validate_dataloader_config(cfg)