# Source code for selfclean_audio.utils.types

# Copyright (c) Lucerne University of Applied Sciences and Arts.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import random

import numpy as np
import torch
from omegaconf import DictConfig


def _resolve_seed(_C, default):
    """Return the first seed in ``_C`` that converts to ``int``, else ``default``.

    Locations are probed in order: ``_C.params.seed``, ``_C.SEED``,
    ``_C.selfclean_audio.random_seed``. A location that is missing *or* whose
    value cannot be converted with ``int()`` is skipped, so the later
    locations are still consulted.
    """
    locations = (
        ("params", "seed"),
        ("SEED",),
        ("selfclean_audio", "random_seed"),
    )
    for path in locations:
        node = _C
        try:
            for key in path:
                node = getattr(node, key)
            return int(node)
        except Exception:
            # Deliberately broad but scoped to a single location: besides
            # AttributeError/TypeError/ValueError, config objects may raise
            # library-specific errors on access (NOTE(review): e.g. OmegaConf's
            # MissingMandatoryValue for "???" values, which is not an
            # AttributeError). Move on to the next location instead of
            # abandoning the whole lookup.
            continue
    return default


def _read_cudnn_flag(_C, name):
    """Return ``bool(_C.params.<name>)``; ``False`` if it is absent or unreadable."""
    try:
        return bool(getattr(_C.params, name, False))
    except Exception:
        # Best-effort like the seed lookup: a broken flag falls back to False
        # without preventing the other flag from being read.
        return False


def set_seeds(_C: "DictConfig", *, default_seed: int = 42) -> int:
    """
    Set the random seed for reproducibility across various libraries and frameworks.

    Seeds Python's ``random`` module, NumPy's global RNG (``np.random.seed``)
    and PyTorch (``torch.manual_seed``). It then sets CuDNN's deterministic
    and benchmark flags from the configuration.

    The seed is taken from the first of these locations that exists and
    converts to ``int``:

    1. ``_C.params.seed``
    2. ``_C.SEED``
    3. ``_C.selfclean_audio.random_seed``

    If none qualifies, ``default_seed`` is used. No warning or log message is
    emitted in that case; use the return value to find out which seed was
    actually applied.

    Args:
        _C (DictConfig): Configuration holding the seed and, optionally,
            ``params.cudnn_deterministic`` / ``params.cudnn_benchmark``.
        default_seed (int): Seed used when no configured seed is usable.
            Defaults to 42.

    Returns:
        int: The seed that was applied.

    Notes:
        - ``_C.params.cudnn_deterministic`` (bool, default ``False``) enforces
          deterministic behavior in CuDNN.
        - ``_C.params.cudnn_benchmark`` (bool, default ``False``) enables
          CuDNN's autotuner.
        - Each CuDNN flag is read independently. A flag that is missing or
          cannot be read falls back to ``False``.
    """
    seed = _resolve_seed(_C, default_seed)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    torch.backends.cudnn.deterministic = _read_cudnn_flag(_C, "cudnn_deterministic")
    torch.backends.cudnn.benchmark = _read_cudnn_flag(_C, "cudnn_benchmark")
    return seed
def ensure_dictconfig(_C):
    """
    Validate that ``_C`` is an OmegaConf ``DictConfig`` and hand it back.

    Args:
        _C (object): Candidate configuration object.

    Returns:
        DictConfig: ``_C`` itself, unchanged, when the check passes.

    Raises:
        TypeError: When ``_C`` is not a ``DictConfig`` instance. The message
            names the type that was received.
    """
    if isinstance(_C, DictConfig):
        return _C
    received_type = type(_C)
    raise TypeError(f"Expected DictConfig, got {received_type}")