# Copyright (c) Lucerne University of Applied Sciences and Arts.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
from dataclasses import dataclass
import torch
import torch.nn as nn
import torchaudio
from loguru import logger
[docs]
@dataclass
class LoraAdaptConfig:
    """Settings for unsupervised LoRA adaptation of an audio backbone.

    Groups three concerns: LoRA adapter injection (r, alpha, dropout,
    target_modules), the contrastive training loop (objective, epochs, lr,
    temperature, projection head), and the waveform augmentation pipeline
    used to create the two SimCLR views.
    """

    enable: bool = False  # master switch; adapt_model_with_lora is a no-op when False
    r: int = 8  # LoRA rank
    alpha: int = 16  # LoRA scaling factor (passed as lora_alpha)
    dropout: float = 0.05  # dropout inside LoRA layers
    # Linear submodule names PEFT should wrap (attention projections + MLP).
    target_modules: tuple[str, ...] = (
        "q_proj",
        "k_proj",
        "v_proj",
        "out_proj",
        "fc1",
        "fc2",
    )
    epochs: int = 1
    lr: float = 1e-4
    weight_decay: float = 0.0
    temperature: float = 0.2  # InfoNCE softmax temperature
    projection_dim: int = 256  # output dim of the (discarded) projection head
    max_steps: int | None = None  # limit updates for quick runs
    # objective: "infonce" or "vicreg"
    objective: str = "infonce"
    # VICReg coefficients
    vicreg_sim_coeff: float = 25.0
    vicreg_var_coeff: float = 25.0
    vicreg_cov_coeff: float = 1.0
    # Augmentation controls
    sample_rate: int = 16000
    strong_aug: bool = True  # use the sox/biquad pipeline instead of basic jitter
    time_shift_max: float = 0.1  # max circular shift, fraction of clip length
    add_noise_snr_db: float = 15.0  # SNR (dB) of the additive white noise
    tempo_min: float = 0.9
    tempo_max: float = 1.1
    pitch_semitones: float = 2.0  # max absolute pitch shift
    reverb_prob: float = 0.3
    eq_prob: float = 0.4
    time_mask_prob: float = 0.5
    time_mask_max_ratio: float = 0.2
    # Gradient accumulation for memory efficiency
    gradient_accumulation_steps: int = 1
def _unwrap_quant_noise_modules(model: nn.Module) -> None:
"""Best-effort unwrap quant_noise wrappers so PEFT can see nn.Linear modules.
Some backbones (e.g., BEATs) wrap linear projections with a quant_noise
module that stores the actual nn.Linear under attribute `module`. PEFT's
LoRA injection targets nn.Linear; unwrapping ensures adapters attach.
"""
attrs = ("q_proj", "k_proj", "v_proj", "out_proj")
for module in model.modules():
for attr in attrs:
if hasattr(module, attr):
sub = getattr(module, attr)
# Unwrap single level `module` indirection if present
try:
if hasattr(sub, "module") and isinstance(sub.module, nn.Linear):
setattr(module, attr, sub.module)
except Exception:
pass
def _try_enable_lora(model: nn.Module, cfg: LoraAdaptConfig) -> nn.Module | None:
    """Attach LoRA adapters to the model using PEFT when available.

    Returns the PEFT-wrapped model if successful; otherwise None (the caller
    then proceeds without LoRA, training only the projection head).
    """
    try:
        from peft import LoraConfig, get_peft_model
        # Unwrap quant_noise so adapters can target nn.Linear modules
        _unwrap_quant_noise_modules(model)
        lcfg = LoraConfig(
            r=cfg.r,
            lora_alpha=cfg.alpha,
            lora_dropout=cfg.dropout,
            target_modules=list(cfg.target_modules),
            inference_mode=False,
            bias="none",
            task_type="FEATURE_EXTRACTION",
        )
        peft_model = get_peft_model(model, lcfg)
        peft_model.print_trainable_parameters()
        # Best-effort: select the fresh adapter and activate its layers
        # (API availability differs across PEFT versions, hence hasattr).
        try:
            if hasattr(peft_model, "set_adapter"):
                peft_model.set_adapter("default")
            if hasattr(peft_model, "enable_adapter_layers"):
                peft_model.enable_adapter_layers()
        except Exception:
            pass
        # Introspect injected modules for debugging
        try:
            injected = []
            for name, mod in peft_model.named_modules():
                if hasattr(mod, "lora_A") or hasattr(mod, "lora_B"):
                    injected.append((name, mod.__class__.__name__))
            if injected:
                preview = ", ".join([f"{n}:{c}" for n, c in injected[:6]])
                logger.info(
                    f"LoRA injection summary: {len(injected)} modules with lora_* params. Examples: {preview}"
                )
                # Debug: analyze which types of modules got LoRA
                module_types = {}
                for name, cls in injected:
                    # Extract module type (e.g., 'k_proj', 'v_proj', 'fc1')
                    module_type = name.split(".")[-1] if "." in name else name
                    module_types[module_type] = module_types.get(module_type, 0) + 1
                logger.info(f"LoRA module distribution: {dict(module_types)}")
            else:
                logger.warning(
                    "PEFT did not expose any modules with lora_* attributes."
                )
        except Exception:
            pass
        return peft_model
    except Exception as e:
        # Covers both a missing peft package and any injection failure.
        logger.warning(
            f"PEFT not available or failed to attach LoRA adapters: {e}. "
            "Proceeding without LoRA (base model frozen)."
        )
        return None
def _freeze_all(model: nn.Module):
for p in model.parameters():
p.requires_grad = False
class ProjectionHead(nn.Module):
    """Two-layer MLP head mapping backbone embeddings to contrastive space."""

    def __init__(self, in_dim: int, proj_dim: int = 256):
        super().__init__()
        hidden = nn.Linear(in_dim, in_dim)
        activation = nn.ReLU(inplace=True)
        output = nn.Linear(in_dim, proj_dim)
        self.net = nn.Sequential(hidden, activation, output)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project [..., in_dim] -> [..., proj_dim]."""
        return self.net(x)
def _info_nce_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.2):
"""Compute SimCLR-style InfoNCE loss over a batch.
z1, z2: shape [B, D]
"""
z1 = nn.functional.normalize(z1, dim=-1)
z2 = nn.functional.normalize(z2, dim=-1)
B = z1.size(0)
reps = torch.cat([z1, z2], dim=0) # [2B, D]
sim = torch.matmul(reps, reps.t()) / temperature # [2B, 2B]
# Mask self-similarity
mask = torch.eye(2 * B, device=sim.device, dtype=torch.bool)
sim = sim.masked_fill(mask, float("-inf"))
# Positives: i <-> i+B and i+B <-> i
targets = torch.arange(B, device=sim.device)
targets = torch.cat([targets + B, targets], dim=0)
loss = nn.functional.cross_entropy(sim, targets)
return loss
def _off_diagonal(x: torch.Tensor) -> torch.Tensor:
n, m = x.shape
return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
def _vicreg_loss(
z1: torch.Tensor,
z2: torch.Tensor,
sim_coeff: float = 25.0,
var_coeff: float = 25.0,
cov_coeff: float = 1.0,
eps: float = 1e-4,
):
"""Loss VICReg: invariance + variance + covariance regularization.
z1, z2: [B, D]
"""
# invariance
sim_loss = nn.functional.mse_loss(z1, z2)
# variance (per-dimension std close to 1)
def _var_loss(z: torch.Tensor):
std = torch.sqrt(z.var(dim=0) + eps)
return torch.mean(nn.functional.relu(1 - std))
var_loss = _var_loss(z1) + _var_loss(z2)
# covariance (decorrelate features)
def _cov_loss(z: torch.Tensor):
z = z - z.mean(dim=0)
n, d = z.shape
cov = (z.T @ z) / (n - 1)
off = _off_diagonal(cov)
return (off.pow(2).sum()) / d
cov_loss = _cov_loss(z1) + _cov_loss(z2)
loss = sim_coeff * sim_loss + var_coeff * var_loss + cov_coeff * cov_loss
return loss
def _random_time_shift(x: torch.Tensor, max_shift: float = 0.1) -> torch.Tensor:
"""Circular time shift by up to max_shift fraction of length (per-sample)."""
B, T = x.shape[-2], x.shape[-1]
shift = (torch.rand(B, device=x.device) * 2 * max_shift - max_shift) * T
shift = shift.long()
out = torch.zeros_like(x)
for i in range(B):
s = shift[i].item()
out[i] = torch.roll(x[i], shifts=s, dims=-1)
return out
def _additive_noise(x: torch.Tensor, snr_db: float = 10.0) -> torch.Tensor:
rms = (x.pow(2).mean(dim=-1, keepdim=True)).sqrt().clamp_min(1e-8)
noise = torch.randn_like(x)
rms_n = (noise.pow(2).mean(dim=-1, keepdim=True)).sqrt().clamp_min(1e-8)
scale = rms / (10 ** (snr_db / 20)) / rms_n
return x + scale * noise
def _random_crop(x: torch.Tensor, crop_ratio: float = 0.9) -> torch.Tensor:
"""Randomly crop to crop_ratio length and pad back to original size."""
B, T = x.shape[-2], x.shape[-1]
new_T = max(1, int(T * crop_ratio))
start = (torch.rand(B, device=x.device) * (T - new_T + 1)).long()
out = torch.zeros_like(x)
for i in range(B):
seg = x[i, :, start[i] : start[i] + new_T]
out[i, :, :new_T] = seg
return out
def _apply_time_mask(x: torch.Tensor, max_ratio: float) -> torch.Tensor:
"""Zero-out a random temporal segment up to max_ratio of length (per-sample)."""
B, C, T = x.shape
out = x.clone()
for i in range(B):
mask_len = max(1, int(T * random.uniform(0.0, max_ratio)))
start = random.randint(0, max(0, T - mask_len))
out[i, :, start : start + mask_len] = 0
return out
def _maybe_trim_pad(x: torch.Tensor, T: int) -> torch.Tensor:
if x.size(-1) == T:
return x
if x.size(-1) > T:
return x[..., :T]
pad = T - x.size(-1)
return nn.functional.pad(x, (0, pad), mode="constant", value=0)
def _apply_strong_augs(x: torch.Tensor, cfg: LoraAdaptConfig) -> torch.Tensor:
    """Apply a sequence of stronger waveform-level augs using sox effects and biquads.

    Improved to handle batch processing more efficiently and consistently.

    x: waveform batch [B, C, T]; returns a tensor of the same shape.
    Sox-based effects (tempo/pitch/reverb) are best-effort: RuntimeErrors are
    swallowed and the sample passes through unchanged for that effect.
    """
    B, C, T = x.shape
    out = x.clone()
    # Pre-generate random decisions for the batch for consistency
    eq_mask = torch.rand(B) < cfg.eq_prob
    tempo_mask = torch.rand(B) < 0.5
    pitch_mask = torch.rand(B) < 0.5
    reverb_mask = torch.rand(B) < cfg.reverb_prob
    for i in range(B):
        xi = out[i]
        # Random EQ (bandreject / lowpass / highpass)
        if eq_mask[i]:
            choice = random.choice(["lowpass", "highpass", "bandreject"])  # noqa: S311
            if choice == "lowpass":
                cutoff = random.uniform(2000, 8000)
                xi = torchaudio.functional.lowpass_biquad(xi, cfg.sample_rate, cutoff)
            elif choice == "highpass":
                cutoff = random.uniform(50, 500)
                xi = torchaudio.functional.highpass_biquad(xi, cfg.sample_rate, cutoff)
            else:
                center = random.uniform(300, 4000)
                xi = torchaudio.functional.bandreject_biquad(
                    xi, cfg.sample_rate, center, Q=0.707
                )
        # Random tempo / pitch using sox (may change length)
        if tempo_mask[i]:
            tempo = random.uniform(cfg.tempo_min, cfg.tempo_max)
            try:
                xi, _ = torchaudio.sox_effects.apply_effects_tensor(
                    xi, cfg.sample_rate, [["tempo", f"{tempo:.3f}"]]
                )
                # Tempo change alters length; restore exactly T samples.
                xi = _maybe_trim_pad(xi, T)
            except RuntimeError:
                # Handle Sox effects failures gracefully
                pass
        if pitch_mask[i] and cfg.pitch_semitones > 0:
            semis = random.uniform(-cfg.pitch_semitones, cfg.pitch_semitones)
            try:
                xi, _ = torchaudio.sox_effects.apply_effects_tensor(
                    xi, cfg.sample_rate, [["pitch", f"{semis:.3f}"]]
                )
                xi = _maybe_trim_pad(xi, T)
            except RuntimeError:
                # Handle Sox effects failures gracefully
                pass
        if reverb_mask[i]:
            try:
                xi, _ = torchaudio.sox_effects.apply_effects_tensor(
                    xi, cfg.sample_rate, [["reverb", "50", "50", "100"]]
                )
                xi = _maybe_trim_pad(xi, T)
            except RuntimeError:
                # Handle Sox effects failures gracefully
                pass
        out[i] = xi
    # Time masking - apply batch-wise with consistent probability
    if random.random() < cfg.time_mask_prob:
        out = _apply_time_mask(out, cfg.time_mask_max_ratio)
    # Always apply some jitter: crop -> shift -> noise
    # Use consistent crop ratio across batch for better contrastive learning
    crop_ratio = random.uniform(0.85, 0.98)
    out = _random_crop(out, crop_ratio=crop_ratio)
    out = _random_time_shift(out, max_shift=cfg.time_shift_max)
    out = _additive_noise(out, snr_db=cfg.add_noise_snr_db)
    return out
def _two_views(
    waveform: torch.Tensor, cfg: LoraAdaptConfig
) -> tuple[torch.Tensor, torch.Tensor]:
    """Create two augmented views of a batch of waveforms.

    waveform: [B, C, T] or [B, T]; a 2D input gains a channel dim -> [B, 1, T].
    Returns two independently augmented copies of the batch, same shape.
    """
    if waveform.ndim == 2:
        waveform = waveform.unsqueeze(1)
    B, C, T = waveform.shape
    # Ensure minimum length for feature extraction when strong augmentations are used
    # NOTE(review): 400 presumably matches the backbone's minimum window — confirm.
    min_length = 400
    if cfg.strong_aug and T < min_length:
        pad_length = min_length - T
        waveform = torch.nn.functional.pad(
            waveform, (0, pad_length), mode="constant", value=0
        )
        T = waveform.size(-1)
        logger.debug(
            f"Padded batch from {T - pad_length} to {T} samples for augmentation"
        )
    x = waveform.reshape(B, C, T)
    if cfg.strong_aug:
        # Two independent passes through the strong pipeline -> two views.
        x1 = _apply_strong_augs(x, cfg)
        x2 = _apply_strong_augs(x, cfg)
    else:
        # Light jitter fallback: crop -> circular shift -> additive noise.
        x1 = _additive_noise(_random_time_shift(_random_crop(x)), snr_db=15.0)
        x2 = _additive_noise(_random_time_shift(_random_crop(x)), snr_db=15.0)
    return x1, x2
def _pool_embedding(emb: torch.Tensor) -> torch.Tensor:
"""Mean-pool temporal embeddings if needed: [B, T, D] -> [B, D]."""
if emb.ndim == 3:
return emb.mean(dim=1)
return emb.squeeze()
[docs]
def adapt_model_with_lora(
    model: nn.Module,
    dataloader,
    device: torch.device | str,
    cfg: LoraAdaptConfig,
):
    """Adapt the model on the target dataset using SimCLR and LoRA adapters.

    The base model is frozen; only LoRA adapters (if enabled) and a small projection
    head are trained. After adaptation, the projection head is discarded and the
    model remains with the adapted LoRA weights active.

    Args:
        model: backbone exposing ``extract_features`` (e.g. BEATs).
        dataloader: yields batches whose first element is a waveform tensor;
            ``dataloader.dataset[0][0]`` must give a single sample.
        device: device to run the adaptation on.
        cfg: hyper-parameters; the call is a no-op when ``cfg.enable`` is False.

    Returns:
        The (possibly PEFT-wrapped) model, set to eval mode.
    """
    if not cfg.enable:
        return model
    logger.info("Starting LoRA adaptation on target dataset (unsupervised)")
    # --- Setup: freeze backbone, attach LoRA, build projection head ---
    model.to(device)
    model.train()
    _freeze_all(model)
    peft_model = _try_enable_lora(model, cfg)
    used_lora = peft_model is not None
    if used_lora:
        model = peft_model  # use LoRA-wrapped model
    # Infer embedding dimension with a small forward
    sample = dataloader.dataset[0][0]
    sample = sample.to(device)
    with torch.no_grad():
        emb, _ = model.extract_features(sample)  # type: ignore[attr-defined]
        emb_dim = _pool_embedding(emb).shape[-1]
    proj = ProjectionHead(emb_dim, cfg.projection_dim).to(device)
    # Collect trainable params (LoRA + projection)
    if used_lora:
        # Get LoRA parameters
        lora_params = [p for p in model.parameters() if p.requires_grad]
        logger.debug(f"Collected {len(lora_params)} LoRA parameters")
        # Get projection parameters
        proj_params = list(proj.parameters())
        logger.debug(f"Collected {len(proj_params)} projection parameters")
        # Combine all parameters
        params = lora_params + proj_params
    else:
        params = list(proj.parameters())
    logger.debug(f"Total parameters before validation: {len(params)}")
    logger.debug(f"Parameter types: {[type(p).__name__ for p in params[:5]]}")
    # Filter to only valid parameters
    params = [p for p in params if isinstance(p, nn.Parameter)]
    logger.debug(f"Valid parameters after filtering: {len(params)}")
    if len(params) == 0:
        logger.error("CRITICAL: No valid parameters for optimization!")
        return model
    opt = torch.optim.AdamW(params, lr=cfg.lr, weight_decay=cfg.weight_decay)
    # Debug parameter collection
    logger.debug(f"Collected params type: {type(params)}")
    logger.debug(
        f"Params length: {len(params) if hasattr(params, '__len__') else 'No length'}"
    )
    if len(params) > 0:
        logger.debug(f"First param type: {type(params[0])}")
        logger.debug(f"First few param types: {[type(p) for p in params[:5]]}")
    # Validation: ensure we have trainable parameters
    try:
        trainable_count = sum(
            1 for p in params if hasattr(p, "requires_grad") and p.requires_grad
        )
        total_params = sum(
            p.numel() for p in params if hasattr(p, "requires_grad") and p.requires_grad
        )
        logger.info(
            f"Adaptation setup: {trainable_count} trainable param groups, {total_params:,} total parameters"
        )
        if trainable_count == 0:
            logger.error(
                "CRITICAL: No trainable parameters found! Adaptation will fail."
            )
            return model
    except Exception as e:
        logger.error(f"Failed to validate parameters: {e}")
        logger.error(f"Params debug: {params[:3] if len(params) > 0 else 'Empty'}")
        return model
    # Verify LoRA integration if enabled
    if used_lora:
        lora_param_count = sum(
            1 for n, p in model.named_parameters() if "lora_" in n and p.requires_grad
        )
        if lora_param_count == 0:
            logger.error(
                "CRITICAL: LoRA enabled but no lora_* parameters found! Check PEFT integration."
            )
            return model
        else:
            logger.info(
                f"LoRA integration verified: {lora_param_count} trainable LoRA parameters"
            )
        # Try to ensure all LoRA modules are properly enabled
        try:
            # Force enable all adapters
            if hasattr(model, "enable_adapter_layers"):
                model.enable_adapter_layers()
            if hasattr(model, "set_adapter"):
                model.set_adapter("default")
            # Additional check: ensure LoRA modules are in train mode
            lora_modules_found = 0
            for name, module in model.named_modules():
                if (
                    "lora_" in name.lower()
                    or hasattr(module, "lora_A")
                    or hasattr(module, "lora_B")
                ):
                    if hasattr(module, "train"):
                        module.train()
                        lora_modules_found += 1
            if lora_modules_found > 0:
                logger.info(
                    f"Ensured {lora_modules_found} LoRA modules are in training mode"
                )
        except Exception as e:
            logger.warning(f"Failed to fully initialize LoRA state: {e}")
    # Sanity-check snapshot: LoRA param norms and pre-training embeddings
    def _lora_named_params(m: nn.Module):
        # All trainable (i.e. LoRA-injected) parameters, with their names.
        return [(n, p) for n, p in m.named_parameters() if p.requires_grad]

    lora_before = None
    if used_lora:
        try:
            lora_before = {n: p.detach().clone() for n, p in _lora_named_params(model)}
            if len(lora_before) == 0:
                logger.warning("No trainable LoRA params found after PEFT wrapping.")
        except Exception:
            lora_before = None
    # Helper: robust feature extraction with optional gradients
    def _extract_features_with_optional_grad(x: torch.Tensor, want_grad: bool):
        # Normalizes input to the [C, T] layout, pads very short clips, then
        # calls extract_features (with or without the with_grad kwarg).
        # BEATs expects [C, T] format
        original_shape = x.shape
        logger.debug(f"Input to feature extraction: {original_shape}")
        # Convert to proper [C, T] format for BEATs
        if x.ndim == 1:
            # [T] -> [1, T] (add channel dimension)
            x = x.unsqueeze(0)
        elif x.ndim == 3:
            # [B, C, T] -> [C, T] (this shouldn't happen in per-sample processing, but handle it)
            if x.size(0) == 1:
                x = x.squeeze(0)  # [1, C, T] -> [C, T]
            else:
                logger.error(
                    f"Unexpected batch size {x.size(0)} in single-sample processing"
                )
                x = x[0]  # Take first sample
        # x should now be [C, T]
        if x.ndim != 2:
            raise ValueError(
                f"After processing, expected 2D tensor [C, T], got {x.ndim}D: {x.shape}"
            )
        # Check minimum length for BEATs (needs at least 400 samples for window_size)
        min_length = 400
        if x.size(-1) < min_length:
            pad_length = min_length - x.size(-1)
            x = torch.nn.functional.pad(x, (0, pad_length), mode="constant", value=0)
            logger.debug(
                f"Padded from {x.size(-1) - pad_length} to {x.size(-1)} samples"
            )
        logger.debug(f"Final shape to BEATs: {x.shape}")
        try:
            return model.extract_features(x, with_grad=want_grad)  # type: ignore[attr-defined]
        except TypeError:
            # BEATs doesn't support with_grad parameter
            try:
                result = model.extract_features(x)  # type: ignore[attr-defined]
                if (
                    want_grad
                    and hasattr(result[0], "requires_grad")
                    and not result[0].requires_grad
                ):
                    logger.warning(
                        "Gradients requested but result has requires_grad=False"
                    )
                return result
            except Exception as e:
                logger.error(
                    f"BEATs failed: input {original_shape} -> {x.shape}, error: {e}"
                )
                logger.error(
                    f"Tensor info: dtype={x.dtype}, device={x.device}, min={x.min():.4f}, max={x.max():.4f}"
                )
                raise
    # Grab a small batch to compare embeddings before/after adaptation
    pre_emb: torch.Tensor | None = None
    try:
        first_batch = next(iter(dataloader))[0].to(device)  # [B, C, T] or [B, T]
        B = min(first_batch.shape[0], 4)
        # Ensure model is in eval mode for consistent baseline embeddings
        was_training = model.training
        model.eval()
        with torch.no_grad():
            # BEATs requires per-sample processing
            zs = []
            for i in range(B):
                # Extract single sample: first_batch[i] should be [C, T] after indexing
                e, _ = _extract_features_with_optional_grad(
                    first_batch[i], want_grad=False
                )
                zs.append(_pool_embedding(e))
            pre_emb = torch.stack(zs, dim=0)
        # Restore original training state
        model.train(was_training)
    except Exception as e:
        logger.warning(f"Failed to compute baseline embeddings: {e}")
        pre_emb = None
    # --- Training loop state ---
    step = 0
    running_loss = 0.0
    loss_history = []
    # Gradient accumulation tracking
    accumulation_step = 0
    # LoRA gradient tracking over time
    lora_gradient_history = {
        "steps": [],
        "lora_A_active": [],
        "lora_B_active": [],
        "lora_A_avg_grad": [],
        "lora_B_avg_grad": [],
        "total_active": [],
        "gradient_ratio": [],  # lora_B / lora_A gradient ratio
    }
    # Ensure model is in training mode for adaptation
    model.train()
    for epoch in range(cfg.epochs):
        epoch_loss = 0.0
        epoch_steps = 0
        for batch in dataloader:
            waveforms = batch[0].to(device)
            x1, x2 = _two_views(waveforms, cfg)
            # Forward two views through base model in batched manner
            B = x1.size(0)
            # BEATs requires per-sample processing: [C, T] format per call
            # Cannot process batches directly, so iterate through samples
            zs1 = []
            zs2 = []
            for i in range(B):
                # Extract single sample: x1[i] should be [C, T] after indexing
                emb1_i, _ = _extract_features_with_optional_grad(x1[i], want_grad=True)
                emb2_i, _ = _extract_features_with_optional_grad(x2[i], want_grad=True)
                z1_i = proj(_pool_embedding(emb1_i))
                z2_i = proj(_pool_embedding(emb2_i))
                zs1.append(z1_i)
                zs2.append(z2_i)
            z1 = torch.stack(zs1, dim=0)
            z2 = torch.stack(zs2, dim=0)
            if cfg.objective.lower() == "vicreg":
                loss = _vicreg_loss(
                    z1,
                    z2,
                    sim_coeff=cfg.vicreg_sim_coeff,
                    var_coeff=cfg.vicreg_var_coeff,
                    cov_coeff=cfg.vicreg_cov_coeff,
                )
            else:
                loss = _info_nce_loss(z1, z2, cfg.temperature)
            # Scale loss for gradient accumulation
            loss = loss / cfg.gradient_accumulation_steps
            # Zero gradients at start of accumulation cycle
            if accumulation_step == 0:
                opt.zero_grad(set_to_none=True)
            loss.backward()
            # Create a fresh parameter list for gradient clipping to avoid any corruption
            clip_params = []
            if used_lora:
                clip_params.extend([p for p in model.parameters() if p.requires_grad])
            clip_params.extend(proj.parameters())
            # Validate all are nn.Parameter
            clip_params = [p for p in clip_params if isinstance(p, nn.Parameter)]
            try:
                nn.utils.clip_grad_norm_(clip_params, max_norm=1.0)
            except Exception as e:
                logger.error(f"Gradient clipping failed even with fresh params: {e}")
                logger.error(
                    f"Fresh clip_params types: {[type(p).__name__ for p in clip_params[:5]]}"
                )
                logger.warning("Skipping gradient clipping - continuing without it")
            # Enhanced diagnostics: grad norms and gradient flow validation
            try:
                if used_lora:
                    # Get only LoRA parameters from the model
                    current_lora_params = [
                        p for p in model.parameters() if p.requires_grad
                    ]
                    lora_grads = [
                        p.grad.detach().norm().item()
                        for p in current_lora_params
                        if p.grad is not None
                    ]
                    mean_lora_grad = (
                        sum(lora_grads) / len(lora_grads) if lora_grads else 0.0
                    )
                    max_lora_grad = max(lora_grads) if lora_grads else 0.0
                    # Detailed gradient flow analysis with time tracking
                    zero_grad_params = []
                    nonzero_grad_params = []
                    null_grad_params = []
                    # Separate tracking for lora_A and lora_B
                    lora_A_active = []
                    lora_B_active = []
                    lora_A_grads = []
                    lora_B_grads = []
                    for n, p in model.named_parameters():
                        if "lora_" in n and p.requires_grad:
                            if p.grad is None:
                                null_grad_params.append(n)
                            elif p.grad.abs().sum() == 0:
                                zero_grad_params.append(n)
                            else:
                                grad_norm = p.grad.abs().sum().item()
                                nonzero_grad_params.append((n, grad_norm))
                                # Track lora_A vs lora_B separately
                                if "lora_A" in n:
                                    lora_A_active.append(n)
                                    lora_A_grads.append(grad_norm)
                                elif "lora_B" in n:
                                    lora_B_active.append(n)
                                    lora_B_grads.append(grad_norm)
                    zero_grad_count = len(zero_grad_params) + len(null_grad_params)
                    total_lora_params = len(current_lora_params)
                    # Update gradient tracking history
                    lora_gradient_history["steps"].append(step)
                    lora_gradient_history["lora_A_active"].append(len(lora_A_active))
                    lora_gradient_history["lora_B_active"].append(len(lora_B_active))
                    lora_gradient_history["lora_A_avg_grad"].append(
                        sum(lora_A_grads) / len(lora_A_grads) if lora_A_grads else 0.0
                    )
                    lora_gradient_history["lora_B_avg_grad"].append(
                        sum(lora_B_grads) / len(lora_B_grads) if lora_B_grads else 0.0
                    )
                    lora_gradient_history["total_active"].append(
                        len(lora_A_active) + len(lora_B_active)
                    )
                    # Calculate gradient ratio (B/A) for tracking balance
                    a_avg = (
                        sum(lora_A_grads) / len(lora_A_grads) if lora_A_grads else 1e-8
                    )
                    b_avg = (
                        sum(lora_B_grads) / len(lora_B_grads) if lora_B_grads else 1e-8
                    )
                    gradient_ratio = b_avg / a_avg if a_avg > 1e-8 else float("inf")
                    lora_gradient_history["gradient_ratio"].append(gradient_ratio)
                    # Report gradient flow at key intervals
                    should_report = step == 0 or step % 50 == 0 or step % 100 == 0
                    if should_report:
                        logger.info(
                            f"=== LoRA Gradient Flow Analysis (Step {step}) ==="
                        )
                        logger.info(f"Total LoRA params: {total_lora_params}")
                        # Calculate expected counts dynamically (should be total_lora_params / 2)
                        expected_A_B_count = total_lora_params // 2
                        logger.info(
                            f"LoRA A active: {len(lora_A_active)}/{expected_A_B_count} (avg_grad={sum(lora_A_grads)/len(lora_A_grads) if lora_A_grads else 0:.6f})"
                        )
                        logger.info(
                            f"LoRA B active: {len(lora_B_active)}/{expected_A_B_count} (avg_grad={sum(lora_B_grads)/len(lora_B_grads) if lora_B_grads else 0:.6f})"
                        )
                        logger.info(f"Gradient ratio (B/A): {gradient_ratio:.2f}")
                        logger.info(
                            f"Total active: {len(lora_A_active) + len(lora_B_active)}/{total_lora_params} ({100*(len(lora_A_active) + len(lora_B_active))/total_lora_params:.1f}%)"
                        )
                        if nonzero_grad_params:
                            logger.info("Modules receiving gradients:")
                            gradient_summary = {}
                            for name, grad_norm in nonzero_grad_params[:10]:
                                module_type = (
                                    name.split(".")[-2] if "." in name else "unknown"
                                )
                                if module_type not in gradient_summary:
                                    gradient_summary[module_type] = []
                                gradient_summary[module_type].append((name, grad_norm))
                            for module_type, params in gradient_summary.items():
                                logger.info(
                                    f"  {module_type}: {len(params)} params, avg_grad={sum(p[1] for p in params)/len(params):.6f}"
                                )
                        if null_grad_params:
                            logger.warning(
                                "Modules with NULL gradients (not in forward path):"
                            )
                            module_types = {}
                            for name in null_grad_params[:10]:
                                module_type = (
                                    name.split(".")[-2] if "." in name else "unknown"
                                )
                                module_types[module_type] = (
                                    module_types.get(module_type, 0) + 1
                                )
                            for module_type, count in module_types.items():
                                logger.warning(
                                    f"  {module_type}: {count} params with null gradients"
                                )
                        if zero_grad_params:
                            logger.warning(
                                "Modules with ZERO gradients (in path but no gradient):"
                            )
                            module_types = {}
                            lora_types = {}
                            for name in zero_grad_params:
                                module_type = (
                                    name.split(".")[-2] if "." in name else "unknown"
                                )
                                module_types[module_type] = (
                                    module_types.get(module_type, 0) + 1
                                )
                                # Check if it's lora_A or lora_B
                                if "lora_A" in name:
                                    lora_types["lora_A"] = (
                                        lora_types.get("lora_A", 0) + 1
                                    )
                                elif "lora_B" in name:
                                    lora_types["lora_B"] = (
                                        lora_types.get("lora_B", 0) + 1
                                    )
                            for module_type, count in module_types.items():
                                logger.warning(
                                    f"  {module_type}: {count} params with zero gradients"
                                )
                            if lora_types:
                                logger.warning(
                                    f"LoRA matrix breakdown (zero grads): {lora_types}"
                                )
                            # Show some specific examples
                            logger.warning(
                                f"Zero gradient examples: {zero_grad_params[:5]}"
                            )
                        if zero_grad_count == total_lora_params:
                            logger.error(
                                "CRITICAL: No gradients flowing to ANY LoRA params!"
                            )
                        elif zero_grad_count > total_lora_params * 0.5:
                            logger.warning(
                                f"CONCERNING: {zero_grad_count}/{total_lora_params} ({100*zero_grad_count/total_lora_params:.1f}%) LoRA params have no gradients"
                            )
                        else:
                            logger.info(
                                f"Acceptable gradient flow: {total_lora_params - zero_grad_count}/{total_lora_params} LoRA params receiving gradients"
                            )
                else:
                    mean_lora_grad = float("nan")
                    max_lora_grad = float("nan")
                head_grads = [
                    p.grad.detach().norm().item()
                    for p in proj.parameters()
                    if p.grad is not None
                ]
                mean_head_grad = (
                    sum(head_grads) / len(head_grads) if head_grads else 0.0
                )
                max_head_grad = max(head_grads) if head_grads else 0.0
                if step == 0:
                    logger.info(
                        f"Grad norms (first step): lora_mean={mean_lora_grad:.6f} lora_max={max_lora_grad:.6f} proj_mean={mean_head_grad:.6f} proj_max={max_head_grad:.6f}"
                    )
                # Periodic gradient monitoring
                if step % 100 == 0 and step > 0:
                    logger.debug(
                        f"Step {step} grad norms: lora_mean={mean_lora_grad:.6f} proj_mean={mean_head_grad:.6f}"
                    )
                # Check for vanishing/exploding gradients
                if used_lora and mean_lora_grad < 1e-8:
                    logger.warning(
                        f"Step {step}: Very small LoRA gradients ({mean_lora_grad:.2e}), possible vanishing gradient issue"
                    )
                elif used_lora and mean_lora_grad > 10.0:
                    logger.warning(
                        f"Step {step}: Large LoRA gradients ({mean_lora_grad:.2e}), possible exploding gradient issue"
                    )
            except Exception as e:
                logger.warning(f"Failed to compute gradient diagnostics: {e}")
            # Increment accumulation step
            accumulation_step += 1
            # Only step optimizer at end of accumulation cycle
            if accumulation_step == cfg.gradient_accumulation_steps:
                opt.step()
                accumulation_step = 0  # Reset accumulation counter
                step += 1  # Only increment global step after optimizer step
                epoch_steps += 1
            # Store unscaled loss for logging (multiply back by accumulation steps)
            current_loss = loss.item() * cfg.gradient_accumulation_steps
            running_loss += current_loss
            epoch_loss += current_loss
            loss_history.append(current_loss)
            if step % 20 == 0:
                avg_loss = running_loss / 20
                logger.info(
                    f"LoRA adapt epoch {epoch+1} step {step}: loss={current_loss:.4f} (avg_20={avg_loss:.4f})"
                )
                running_loss = 0.0
            if cfg.max_steps and step >= cfg.max_steps:
                break
        # Log epoch summary
        if epoch_steps > 0:
            epoch_avg = epoch_loss / epoch_steps
            logger.info(
                f"Epoch {epoch+1} completed: avg_loss={epoch_avg:.4f} ({epoch_steps} steps)"
            )
        if cfg.max_steps and step >= cfg.max_steps:
            break
    # Post-training diagnostics
    if used_lora and lora_before is not None:
        try:
            deltas = []
            norms = []
            for n, p in _lora_named_params(model):
                if n in lora_before:
                    dp = (p.detach() - lora_before[n]).float()
                    deltas.append(dp.norm().item())
                    norms.append(p.detach().float().norm().item())
            if deltas:
                mean_delta = sum(deltas) / len(deltas)
                max_delta = max(deltas)
                mean_norm = sum(norms) / len(norms) if norms else float("nan")
                logger.info(
                    f"LoRA param delta norms: mean={mean_delta:.6f} max={max_delta:.6f} current_mean_norm={mean_norm:.6f}"
                )
        except Exception:
            pass
    # Enhanced post-adaptation diagnostics
    try:
        if pre_emb is not None:
            # Evaluate post-adaptation embeddings in eval mode for determinism
            was_training = model.training
            model.eval()
            with torch.no_grad():
                first_batch = next(iter(dataloader))[0].to(device)
                B = min(first_batch.shape[0], pre_emb.shape[0])
                # BEATs requires per-sample processing
                zs = []
                for i in range(B):
                    # Extract single sample: first_batch[i] should be [C, T] after indexing
                    e, _ = _extract_features_with_optional_grad(
                        first_batch[i], want_grad=False
                    )
                    zs.append(_pool_embedding(e))
                post_emb = torch.stack(zs, dim=0)
            # Compute multiple similarity metrics
            pre_n = nn.functional.normalize(pre_emb[:B], dim=-1)
            post_n = nn.functional.normalize(post_emb[:B], dim=-1)
            cos_sim = (pre_n * post_n).sum(dim=-1)
            # L2 distance
            l2_dist = torch.norm(pre_emb[:B] - post_emb[:B], dim=-1)
            # Embedding magnitude change
            pre_mag = torch.norm(pre_emb[:B], dim=-1)
            post_mag = torch.norm(post_emb[:B], dim=-1)
            mag_change = (post_mag - pre_mag) / pre_mag
            logger.info(f"Adaptation metrics (over {B} samples):")
            logger.info(
                f"  Cosine similarity pre/post: {cos_sim.mean().item():.6f} ± {cos_sim.std().item():.6f}"
            )
            logger.info(
                f"  L2 distance pre/post: {l2_dist.mean().item():.6f} ± {l2_dist.std().item():.6f}"
            )
            logger.info(
                f"  Magnitude change: {mag_change.mean().item():.6f} ± {mag_change.std().item():.6f}"
            )
            # Restore training state
            model.train(was_training)
    except Exception as e:
        logger.warning(f"Failed to compute post-adaptation diagnostics: {e}")
    # Comprehensive training summary with LoRA gradient trends
    if loss_history:
        final_losses = loss_history[-min(10, len(loss_history)) :]
        logger.info(
            "Final loss trend (last {} steps): {}",
            len(final_losses),
            [f"{loss_val:.4f}" for loss_val in final_losses],
        )
        if len(loss_history) > 20:
            early_losses = loss_history[:10]
            logger.info(
                f"Early vs final loss: {sum(early_losses)/len(early_losses):.4f} -> {sum(final_losses)/len(final_losses):.4f}"
            )
    # LoRA gradient evolution summary
    if lora_gradient_history["steps"]:
        logger.info("=== LoRA Gradient Evolution Summary ===")
        logger.info(f"Training steps: {len(lora_gradient_history['steps'])}")
        # Initial vs final gradient patterns
        initial_A = (
            lora_gradient_history["lora_A_active"][0]
            if lora_gradient_history["lora_A_active"]
            else 0
        )
        initial_B = (
            lora_gradient_history["lora_B_active"][0]
            if lora_gradient_history["lora_B_active"]
            else 0
        )
        final_A = (
            lora_gradient_history["lora_A_active"][-1]
            if lora_gradient_history["lora_A_active"]
            else 0
        )
        final_B = (
            lora_gradient_history["lora_B_active"][-1]
            if lora_gradient_history["lora_B_active"]
            else 0
        )
        # NOTE(review): the literal 72 is a fallback for when total_lora_params
        # was never bound (no diagnostics ran) — looks backbone-specific; confirm.
        expected_count = (
            total_lora_params // 2 if "total_lora_params" in locals() else 72
        )
        logger.info(
            f"LoRA A evolution: {initial_A}/{expected_count} -> {final_A}/{expected_count} active params ({final_A - initial_A:+d})"
        )
        logger.info(
            f"LoRA B evolution: {initial_B}/{expected_count} -> {final_B}/{expected_count} active params ({final_B - initial_B:+d})"
        )
        # Gradient magnitude trends
        initial_A_grad = (
            lora_gradient_history["lora_A_avg_grad"][0]
            if lora_gradient_history["lora_A_avg_grad"]
            else 0
        )
        final_A_grad = (
            lora_gradient_history["lora_A_avg_grad"][-1]
            if lora_gradient_history["lora_A_avg_grad"]
            else 0
        )
        initial_B_grad = (
            lora_gradient_history["lora_B_avg_grad"][0]
            if lora_gradient_history["lora_B_avg_grad"]
            else 0
        )
        final_B_grad = (
            lora_gradient_history["lora_B_avg_grad"][-1]
            if lora_gradient_history["lora_B_avg_grad"]
            else 0
        )
        logger.info(
            f"LoRA A avg gradient: {initial_A_grad:.6f} -> {final_A_grad:.6f} ({final_A_grad/initial_A_grad if initial_A_grad > 1e-8 else float('inf'):.2f}x)"
        )
        logger.info(
            f"LoRA B avg gradient: {initial_B_grad:.6f} -> {final_B_grad:.6f} ({final_B_grad/initial_B_grad if initial_B_grad > 1e-8 else float('inf'):.2f}x)"
        )
        # Overall utilization trend
        initial_total = (
            lora_gradient_history["total_active"][0]
            if lora_gradient_history["total_active"]
            else 0
        )
        final_total = (
            lora_gradient_history["total_active"][-1]
            if lora_gradient_history["total_active"]
            else 0
        )
        total_expected = expected_count * 2
        logger.info(
            f"Total LoRA utilization: {initial_total}/{total_expected} -> {final_total}/{total_expected} ({100*final_total/total_expected:.1f}%, {final_total-initial_total:+d})"
        )
        # Gradient ratio evolution (helpful for understanding LoRA balance)
        initial_ratio = (
            lora_gradient_history["gradient_ratio"][0]
            if lora_gradient_history["gradient_ratio"]
            else float("inf")
        )
        final_ratio = (
            lora_gradient_history["gradient_ratio"][-1]
            if lora_gradient_history["gradient_ratio"]
            else float("inf")
        )
        logger.info(
            f"Gradient ratio (B/A) evolution: {initial_ratio:.2f} -> {final_ratio:.2f}"
        )
        if len(lora_gradient_history["steps"]) > 10:
            # Show trend over middle of training
            mid_idx = len(lora_gradient_history["steps"]) // 2
            mid_total = lora_gradient_history["total_active"][mid_idx]
            logger.info(
                f"Mid-training utilization: {mid_total}/{total_expected} ({100*mid_total/total_expected:.1f}%)"
            )
            # Detect if LoRA utilization is improving
            trend = (
                "improving"
                if final_total > initial_total
                else "stable"
                if final_total == initial_total
                else "declining"
            )
            logger.info(f"LoRA utilization trend: {trend}")
    # Ensure model is in eval mode for downstream evaluation
    model.eval()
    # Validate LoRA adapter state
    if used_lora:
        try:
            # Check if LoRA adapters are properly enabled
            if hasattr(model, "set_adapter"):
                model.set_adapter("default")
            if hasattr(model, "enable_adapter_layers"):
                model.enable_adapter_layers()
            # Count active LoRA parameters
            active_lora_params = sum(
                1
                for n, p in model.named_parameters()
                if "lora_" in n and p.requires_grad
            )
            total_lora_params = sum(
                1 for n, p in model.named_parameters() if "lora_" in n
            )
            logger.info(
                f"LoRA adapter state: {active_lora_params}/{total_lora_params} LoRA params active"
            )
        except Exception as e:
            logger.warning(f"Failed to validate LoRA adapter state: {e}")
    logger.info("LoRA adaptation finished; model set to eval mode")
    return model