2025 - Adapter-Guided Distillation for ASR
Whisper-large-v2 + Hidden-State Alignment
-> Whisper-large-v2 → distil-whisper/distil-small.en
+ LoRA / Hidden-State Alignment
Teacher (Whisper-large-v2): ≈ 3 GB (FP16)
Student (distil-small) + LoRA + Optimizer + Gradients: ≈ 2.3 GB
Activations (per GPU): ≈ 6–8 GB
Headroom: ≈ 2 GB for temporary buffers, communication, AMP cache
-> Single-card peak ≈ 13–15 GB VRAM -> start the experiments with FP16 + AMP
In **LoRATrainer.fwd()**, bypass PeftModel.forward() and **call the underlying LoRA-injected model directly** (the model reachable as model.model / base_model.model)
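A minimal sketch of what this could look like, assuming `self.model` is the `PeftModel` returned by `get_peft_model()` (so `base_model.model` is the LoRA-injected `WhisperForConditionalGeneration`); `LoRATrainer` and `fwd()` are this project's own names, so treat the signature as illustrative rather than the actual implementation:

```python
class LoRATrainer:
    # ... (dataloaders, optimizer, etc. omitted)
    def fwd(self, feats, decoder_input_ids):
        # Bypass PeftModel.forward() and call the LoRA-injected Whisper model
        # directly, so we can request hidden states for the alignment loss.
        base = self.model.base_model.model   # WhisperForConditionalGeneration + LoRA
        out = base(
            input_features=feats,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,        # needed for hidden-state alignment
            return_dict=True,
        )
        return out.logits, out.decoder_hidden_states[-1]
```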
For LoRA
1. Encoder fine-tuning: if the input feature distribution differs substantially from that seen in pre-training, consider also adding LoRA to a few key encoder layers (e.g., the top self_attn blocks) so the feature representations adapt better to the downstream task.
2. Bias: if bias="none" in LoraConfig leads to rough convergence, try bias="all" or bias="lora_only" to give the biases some room to adjust.
Teacher
- Model:
openai/whisper-large-v2
- 📍 ≈1.54 B parameters (FP16)
- Input: raw waveform → 80-channel log-Mel spectrogram (mono, 16 kHz)
- Encoder
- Hidden size: 1 280
- Layers: 32
- Sequence length: ~ 1 500 frames
- All parameters frozen
- Decoder
- Auto-regressive transformer LM
- Hidden size: 1 280
- Layers: 32
- Output: token logits over vocabulary
- No parameters updated
Input Audio → Encoder → Decoder (Auto-Regressive) → Transcript Tokens
Student
- Backbone Model: distil-whisper/distil-small.en
- 📍 ≈166 M parameters (FP16)
- Hidden size: 768
- Encoder
- Same 80-channel log-Mel input
- Layers: 12 (inherited from Whisper-small)
- All parameters frozen
- Decoder
- Layers: 4 (pre-distilled)
from transformers import WhisperForConditionalGeneration
model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-small.en")
print(len(model.model.decoder.layers))  # 4
- Auto-regressive transformer LM
- LoRA injection into every decoder layer:
- Rank r ∈ [4, 8, 16], alpha = xx, dropout = 0.1 (small data volume → 0.1–0.2; large data volume → 0.05–0.1); a LoraConfig sketch follows the module list below
Target modules per layer
- Decoder (Self-Attn)
self_attn.q_proj
,self_attn.k_proj
,self_attn.v_proj
,self_attn.out_proj
- Decoder(Encoder-Decoder Attn)
encoder_attn.q_proj
,encoder_attn.k_proj
,encoder_attn.v_proj
,encoder_attn.out_proj
- Decoder(Feed-Forward)
fc1
,fc2
- Trainable parameters: training only the 40 LoRA adapter tensors → 📍 ≈1.3 M scalars (≈0.8 % of the student, ≈0.08 % of Whisper-large)
- Simple 80/10/10 train/val/test split; do not touch the test set; SEED = 42
- INT8 post-training quantization for inference
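Putting the LoRA settings listed above together, a minimal `LoraConfig` sketch (the `r` and `lora_alpha` values are placeholders from the sweep above, and the regex restricts injection to decoder layers only):

```python
from peft import LoraConfig, get_peft_model
from transformers import WhisperForConditionalGeneration

student = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-small.en")

# Full module names are matched against this regex, so only decoder layers get adapters.
decoder_targets = (
    r"model\.decoder\.layers\.\d+\."
    r"((self_attn|encoder_attn)\.(q_proj|k_proj|v_proj|out_proj)|fc1|fc2)"
)

lora_cfg = LoraConfig(
    r=8,               # from the r ∈ [4, 8, 16] sweep above
    lora_alpha=16,     # placeholder for the elided "alpha = xx"
    lora_dropout=0.1,
    bias="none",       # try "lora_only" / "all" if convergence is rough (see the tip above)
    target_modules=decoder_targets,
)
student = get_peft_model(student, lora_cfg)
student.print_trainable_parameters()
```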
Input Audio → Encoder (frozen, 768-d)
↓
Decoder (4 layers) + LoRA adapters
↓
Auto-regressive token logits
H_student( B×T×768 )
──proj──▶ H_proj( B×T×1280 )
vs. H_teacher( B×T×1280 )
loss = MSE(H_proj, H_teacher)
In deep models, we project inputs into a high-dimensional Latent Space
whose coordinates are not directly observable or annotated. In this project, the dimensional mismatch between Teacher (1280-d) and Student (768-d)
creates the core challenge of Hidden Space Alignment
Because no direct correspondence exists
between teacher and student hidden representations, there is no ground truth for cross-dimension alignment
— the model relies on surrogate objectives such as projection layers and MSE loss to align internal representations
Knowledge distillation, feature matching, and representation alignment
all depend on “proxy objectives” to bridge dimension differences. This framework directly applies reconstruction loss (MSE), soft distillation loss (KL), and hard supervision loss (CTC)
In this project, we explore methods for the student to learn from the teacher's hidden representations through dimensional projection,
and we visualize the alignment process
Run | Trainable Params | ≈ % of Teacher | ≈ % of Student | LoRA rank (r) | α (LoRA) | dropout | T (temp) | β (hidden MSE) | kl_w | Unfrozen Encoder Layers | LoRA Injection Layers |
---|---|---|---|---|---|---|---|---|---|---|---|
distil-small.en (baseline) | | | | | | | | | | | |
Cell 2.4 (LoRA) | | | | 32 | 16 | 0.1657 | 3.4066 | 1.9738 | 0.1522 | | |
Cell 2.5 (large-r LoRA) | | | | 64 | 16 | 0.1 | 1.0 | 0.0 | 0.0 | | |
Cell 2.6 (Our work) | | | | | | | | | | | |
Cell 2.7 (auxiliary teacher) | | | | | | | | | | | |
Add a projection layer to ensure dimension alignment; the projected student hidden state is aligned with the teacher hidden state via MSE (the Hidden-State / Encoder Alignment Loss).
Even without a decoder, the student encoder can internalize the teacher's linguistic knowledge by mimicking its Output Distributions and Hidden-State Representations
→ 📍 KL Loss + MSE Loss
Add Projection Layer - For The Distillation
- Student (distil-small.en, Whisper-small encoder): 768 dimensions
- Teacher (Whisper-large-v2): 1280 dimensions
Projection Layer
- Linear(768 → 1280) to align student and teacher hidden dimensions
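A minimal sketch of this projection head (the class name is ours; only the `nn.Linear(768, 1280)` mapping is prescribed by the notes above):

```python
import torch.nn as nn

class HiddenProjector(nn.Module):
    """Project student hidden states (768-d) into the teacher's 1280-d space."""
    def __init__(self, d_student: int = 768, d_teacher: int = 1280):
        super().__init__()
        self.proj = nn.Linear(d_student, d_teacher)

    def forward(self, s_h):           # s_h: (B, T, 768)
        return self.proj(s_h)         # -> (B, T, 1280)
```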
Teacher (Whisper-large-v2) Student (distil-small.en + LoRA + Hidden Align)
───────────────────────────── ──────────────────────────────────────────────────
Audio Input Audio Input
1 × T samples 1 × T samples
│ │
▼ ▼
Whisper Encoder Whisper Encoder
1280-d hidden, L~1500 frames 768-d hidden, L~499 frames
(32 layers, FROZEN) (12 layers, FROZEN)
│ │
│ │
├─────── Hidden States ──────────────────────── ├─── Projection Layer ───┐
│ (B,1500,1280) │ (768→1280) │
│ │ │
▼ ▼ ▼
Whisper Decoder Whisper Decoder Aligned Hidden
(32 layers, FROZEN) (4 layers +/ LoRA) (B,1500,1280)
│ │ │
│ │ │
▼ ▼ │
Teacher Logits ────── Soft Targets ─────────▶ Student Logits │
(B,seq,vocab) (KL Loss) (B,seq,vocab) │
│ T=temperature │ │
│ │ │
│ ▼ │
│ Hard Labels ◀── Ground Truth │
│ (CTC Loss) │
│ │ │
│ │ │
│ ▼ │
│ Student Loss ◀─── MSE Loss ──────┘
│ │ (Hidden Align)
│ │
▼ ▼
No parameter updates (LoRA) + Projection parameters
(Inference only) ONLY these are trained
Teacher Encoder Output (t_h)
┌─────────────────────────────────────────┐
│ t_h: shape = (B, T≈1500, 1280) │
│ │
│ ┌─────┐ ┌─────────────┐ ┌─────────────┐│
│ │ B │ │ T_frames │ │ Hidden_dim ││
│ │ ├─│ ≈1500 ├─│ 1280 ││
│ └─────┘ └─────────────┘ └─────────────┘│
└─────────────────────────────────────────┘
Student Encoder Output (s_h)
┌──────────────────────────────────────────┐
│ s_h: shape = (B, T≈499, 768) │
│ │
│ ┌─────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ B │ │ T_frames │ │ Hidden_dim │ │
│ │ ├─│ ≈499 ├─│ 768 │ │
│ └─────┘ └─────────────┘ └─────────────┘ │
└──────────────────────────────────────────┘
Why T≈499
- Whisper feature extractor
- The original audio (30 s) generates about 3000 frames of 80-dimensional log-Mel features at a granularity of 10 ms per frame
- Before being fed to the Transformer encoder, these 3000 frames pass through Whisper's two convolutional layers (the second with stride 2), which halves the sequence to ≈1500 frames; the Transformer blocks themselves do not change the length. The ≈499 frames observed for the student come from shorter, unpadded inputs (≈10 s ≈ 1000 mel frames → ≈499 after the stride-2 conv), not from extra downsampling inside the Transformer
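A quick sanity check of this frame arithmetic, using the Hugging Face feature extractor defaults (which pad/truncate to the 30 s window):

```python
import numpy as np
import torch
from transformers import WhisperFeatureExtractor, WhisperForConditionalGeneration

fe = WhisperFeatureExtractor.from_pretrained("distil-whisper/distil-small.en")
model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-small.en")

audio = np.zeros(16_000 * 30, dtype=np.float32)             # 30 s of silence at 16 kHz
feats = fe(audio, sampling_rate=16_000, return_tensors="pt").input_features
print(feats.shape)                                          # torch.Size([1, 80, 3000])

with torch.no_grad():
    enc = model.model.encoder(feats).last_hidden_state
print(enc.shape)                                            # torch.Size([1, 1500, 768])
```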
📍 Visual Demo of the Dynamic Alignment
Distil-Whisper is also evaluated on the ESB benchmark datasets as part of the OpenASR leaderboard, where it performs to within 0.2 % WER of Whisper
- Always remember to do Automatic checkpoint saving
- !pip install -U "bitsandbytes>=0.41.0"
- Put Your Teacher model on CPU
- MIN_DURATION = 2.0 # Same as Distil-Whisper-small.en
- MAX_DURATION = 30.0 # Same as Whisper-large-v2 maximum acceptance length
Paper | Venue | Data Size |
---|---|---|
DistilBERT | NeurIPS 2019 | 800 M words + 2.5 B words |
TinyBERT | EMNLP 2020 | 800 M words + 2.5 B words |
MobileBERT | ACL 2020 | 800 M words + 2.5 B words |
Distil-Whisper-en | | ≈ 22 000 hours of pseudo-labelled audio across 10 domains (>18 000 speakers) |
Our Work | No Target Venue | ≈ XXXXX hours |
In the design of LoRA, choosing which modules to insert the adapter into, and with what rank r, is essentially a trade-off between Parameter Overhead
and Adaptability
Original LoRA Paper
ΔW = B · A -> only low-rank increments are added to W_q and W_v in the attention
3 Choices of LoRA Injection
1. Attention q/v only (the original LoRA recipe):
   - decoder.layers.*.self_attn.q_proj, self_attn.v_proj
   - decoder.layers.*.encoder_attn.q_proj, encoder_attn.v_proj
2. Full q/k/v attention + fc2:
   - decoder.layers.*.self_attn.q_proj, self_attn.k_proj, self_attn.v_proj
   - decoder.layers.*.encoder_attn.q_proj, encoder_attn.k_proj, encoder_attn.v_proj
   - decoder.layers.*.fc2
3. Full attention (incl. out_proj) + both feed-forward layers:
   - decoder.layers.*.self_attn.q_proj, self_attn.k_proj, self_attn.v_proj, self_attn.out_proj
   - decoder.layers.*.encoder_attn.q_proj, encoder_attn.k_proj, encoder_attn.v_proj, encoder_attn.out_proj
   - decoder.layers.*.fc1, fc2
Features - LoRA
- End-to-end alignment: No extra alignment mechanism; the model learns acoustic-to-text alignment during training
- Scalable functionality: Supports ASR, speech translation, and multi-language recognition
- High decoding overhead: Requires decoder and beam search at each step, resulting in higher latency
- Balance: r = 8 is generally the most stable trade-off between quality and cost
- Pushing the limit: r = 16 is the most expressive, but needs more GPU memory for the extra weights and gradients
Temperature
- Initial pilot temperature: T =
- Search range: [ ]
- Optuna hyperparameter: include temp as a tunable parameter
- Guidance: prevent over-smoothing (i.e. avoid T > 5)
Hard vs. Soft Labels in Knowledge Distillation
- Hard Labels: one-hot vectors from ground truth, y = [0, …, 1, …, 0]
  • Strong supervision → binary certainty
  • Forces correct classification
- Soft Labels: teacher's softmax outputs, p_teacher = [0.6, 0.3, 0.1]
  • Confidence & uncertainty
  • Encodes inter-class similarity
- Combined Loss: L_total = L_hard + α·T²·L_soft
  • L_hard: Cross-Entropy or CTC
  • L_soft: KL Divergence
  • α: balancing weight
  • T: temperature parameter
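A hedged sketch of this combined objective (the function and variable names are ours; `hard_loss` stands for whichever CE/CTC term is used):

```python
import torch.nn.functional as F

def hard_soft_loss(student_logits, teacher_logits, hard_loss, alpha=0.5, T=2.0):
    # Softened distributions: T > 1 spreads probability mass over more tokens
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    # KL(p_teacher || p_student); the T^2 factor keeps gradient magnitudes comparable across T
    l_soft = F.kl_div(log_p_student, p_teacher, reduction="batchmean")
    return hard_loss + alpha * (T ** 2) * l_soft
```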
Why num_workers Affects GPU Performance
The num_workers parameter in PyTorch DataLoader controls the number of CPU processes responsible for data loading and preprocessing. This directly impacts GPU utilization through data pipeline optimization
Data Pipeline Architecture
Optimal Pipeline (num_workers > 0)
- CPU worker 1: Load batch_1 → Preprocess → Transfer to GPU
- CPU worker 2: Load batch_2 → Preprocess → Queue for transfer
- GPU: Process batch_1 while the CPU workers prepare batch_2
Performance Comparison
Single-process loading (num_workers=0)
- CPU: Load → Preprocess → Transfer | (GPU idle) | Load → Preprocess → Transfer
- GPU: Idle | Compute | Idle
Multi-worker loading (num_workers=4)
- CPU: Continuous data preparation (4 parallel worker processes)
- GPU: Continuous computation (minimal idle time)
Key Insight
- Increasing num_workers does not add GPU ("CUDA kernel") parallelism; it improves throughput by eliminating GPU starvation. Multiple CPU workers keep a steady stream of preprocessed batches flowing to the GPU, maximizing hardware utilization and reducing training time
- The optimal num_workers is typically 2–4 per GPU, depending on CPU core count and I/O bottlenecks
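A minimal sketch of the corresponding `DataLoader` setup (`train_dataset` and `collate_fn` are placeholders for this project's own objects):

```python
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,             # 2-4 workers per GPU is a common starting point
    pin_memory=True,           # page-locked buffers speed up host-to-GPU copies
    persistent_workers=True,   # keep workers alive between epochs
)
```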
CTC Loss - Hard Supervision
Connectionist Temporal Classification (CTC) operates on the student's latent feature sequence \(\{\mathbf{h}_t\}_{t=1}^T\), mapping it onto the transcript tokens \(\{y_u\}_{u=1}^U\) without explicit frame-wise labels
- Frame-to-Token Alignment
- Marginalizing Paths
- Gradient Signal
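A hedged sketch of this CTC term; `ctc_head` is an assumed linear projection on top of the student's latent features (Whisper itself has no CTC head), the vocabulary size and `blank=0` id are illustrative, and the dummy tensors stand in for real batches:

```python
import torch
import torch.nn.functional as F

B, T, U, H, V = 4, 499, 32, 768, 1000          # batch, frames, tokens, hidden, vocab (illustrative)
ctc_head = torch.nn.Linear(H, V)               # assumed CTC projection over the latent features

h_student = torch.randn(B, T, H)               # student latent feature sequence {h_t}
target_tokens = torch.randint(1, V, (B, U))    # padded transcript ids (blank id 0 excluded)
input_lengths = torch.full((B,), T, dtype=torch.long)
target_lengths = torch.full((B,), U, dtype=torch.long)

log_probs = F.log_softmax(ctc_head(h_student), dim=-1).transpose(0, 1)  # ctc_loss expects (T, B, V)
loss_ctc = F.ctc_loss(log_probs, target_tokens, input_lengths, target_lengths,
                      blank=0, zero_infinity=True)
```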
KL Distillation Loss - Soft Supervision
KL Distillation Loss compares the teacher’s and student’s posterior distributions over labels at each time-step in latent space
- Soft Distribution Matching
- Preference Transfer
- Capturing Uncertainty
Since the softmax outputs retain probabilities for all tokens, the KL term transfers the teacher’s uncertainty patterns—e.g., when the teacher is unsure between two phonemes, the student learns to mirror that ambiguity
Hidden-State / Encoder Alignment Loss - Representation-Level Supervision
Aligns the differences between teacher and student hidden states, enabling the student to learn the teacher’s internal “reasoning flow”
import torch.nn.functional as F

s_h = …              # student hidden states (B, T, 768)
t_h = …              # teacher hidden states (B, T, 1280)
s_proj = proj(s_h)   # proj = nn.Linear(768, 1280) → (B, T, 1280)
mse = F.mse_loss(s_proj, t_h)
where \(\mathbf{s}_t\) and \(\mathbf{t}_t\) are the student and teacher hidden states at frame \(t\), and \((W, \mathbf{b})\) are the learnable projection parameters, so \(L_{\mathrm{hidden\_align}} = \mathrm{MSE}(W\mathbf{s}_t + \mathbf{b},\, \mathbf{t}_t)\) averaged over batch, frames, and hidden dimensions.
Total Loss
\[L_{\text{total}} = L_{\mathrm{CTC}} + \alpha\,L_{\mathrm{KD}} + \beta\,L_{\mathrm{hidden\_align}}\]
- \(L_{\mathrm{CTC}}\) is the hard cross-entropy (or CTC) loss
- \(L_{\mathrm{KD}} = T^2\,\mathrm{KL}\!\left(p_{\mathrm{teacher}}^{T} \,\|\, p_{\mathrm{student}}^{T}\right)\) is the softened KL-divergence loss with temperature \(T\) and weight \(\alpha\)
- \(L_{\mathrm{hidden\_align}}\) is the projected hidden-state MSE loss with weight \(\beta\)
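In code, combining the three terms with the weights named as in the equation above (the individual losses are assumed to be computed as in the earlier snippets):

```python
def total_loss(loss_ctc, loss_kd, loss_hidden_align, alpha, beta):
    # loss_kd is assumed to already carry the T^2 factor, as in the equation above
    return loss_ctc + alpha * loss_kd + beta * loss_hidden_align
```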
Connectionist Temporal Classification (CTC) in Knowledge Distillation
-
Proposer & Year: Alex Graves et al. (2006 ICML)
- Motivation
- RNNs require exact frame-level alignment, which is unavailable in tasks like speech and handwriting recognition
- Key Innovations
- Automatic alignment
- Path marginalization over all possible alignments
- Blank token mechanism to handle repeats and separations
Fleurs en US - won’t be used
- FLEURS - ASR
- optional
- Tasks: Automatic Speech Recognition
- Languages: Afrikaans, Amharic, Arabic, … (+99)
- Size: 10K < n < 100K
- ArXiv: 2205.12446, 2106.03193
- Tags: speech-recognition
- License: cc-by-4.0
- FLEURS: 102 languages × 550 samples = 56,100 samples - won't be used
LibriSpeech
Tasks: Automatic Speech Recognition, Audio Classification (speaker-identification)
Languages: English
Size: 100K < n < 1M
License: CC-BY-4.0
from datasets import load_dataset, Audio
# 1. Load the first xxxx examples of the clean-100 training split
# 📍 Use streaming=True; otherwise the full download will never finish
# xxx + xxx ≈ 60,862 examples in total
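A hedged sketch of the streaming load (the config/split names follow the Hugging Face `librispeech_asr` dataset card; `N_TRAIN` is an illustrative placeholder, since the real example count is elided above):

```python
from datasets import load_dataset, Audio

N_TRAIN = 1000   # illustrative placeholder only; the real count is the elided "xxxx" above

stream = load_dataset("librispeech_asr", "clean", split="train.100", streaming=True)
stream = stream.cast_column("audio", Audio(sampling_rate=16_000))

# Materialize only the first N_TRAIN examples instead of downloading the full split.
train_subset = list(stream.take(N_TRAIN))
```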
Hyperparameter Optimization
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
def objective(trial):
    # Distillation loss weights (log-uniform search)
    alpha = trial.suggest_float("alpha", 1e-3, 1e1, log=True)
    beta = trial.suggest_float("beta", 1e-3, 1e1, log=True)
    # Optimization hyperparameters
    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    batch_size = trial.suggest_categorical("batch_size", [4, 8, 16, 32])
    dropout = trial.suggest_float("dropout", 0.0, 0.5)
    # Train & evaluate with these settings (implement train_and_evaluate accordingly)
    wer = train_and_evaluate(
        alpha=alpha,
        beta=beta,
        learning_rate=lr,
        batch_size=batch_size,
        dropout=dropout,
        pruner=trial,  # pass the trial so intermediate WER can be reported and pruned
    )
    return wer
# Pruner to stop unpromising trials early
pruner = MedianPruner(n_startup_trials=5, n_warmup_steps=100)
sampler = TPESampler()
study = optuna.create_study(
    direction="minimize",
    sampler=sampler,
    pruner=pruner,
)
study.optimize(objective, n_trials=100)
print("Best hyperparameters:", study.best_params)
**Gradient Underflow**
from torch.cuda.amp import GradScaler, autocast

feats = b["input_features"].half().to(device)   # cast FP32 features to FP16
self.scaler = GradScaler()
...
optimizer.zero_grad()
with autocast():
    loss = model(input)                  # forward ops run in FP16 under autocast
self.scaler.scale(loss).backward()       # scale the loss so small FP16 gradients don't flush to zero
self.scaler.step(optimizer)              # unscales gradients; skips the step if inf/NaN is detected
self.scaler.update()
PCA vs. t-SNE vs. UMAP vs. DTW
References
Connectionist Temporal Classification (CTC) in Knowledge Distillation
- Proposer & Year: Alex Graves et al. (2006 ICML)
- Motivation:
  Frame–label alignment is unavailable in speech/handwriting tasks
- Mechanism:
- Automatic alignment of variable‐length audio to text
- Marginalization over all valid alignment paths
- Blank token to handle repeats and separations
- Role in Distillation:
Provides hard supervision—ensures correct sequence output without frame-level labels
Kullback–Leibler (KL) Distillation Loss
- Proposer & Year: Hinton et al. (2015, arXiv:1503.02531)
- Motivation:
  Transfer "dark knowledge" (inter-class similarity and uncertainty) from teacher to student.
- Mechanism:
- Compute teacher and student softmax distributions at each time step
- Apply temperature (T) to smooth distributions
- Minimize KL divergence between them
- Role in Distillation:
Provides soft supervision—guides student to match teacher’s probability patterns and improve generalization.
Representation Alignment (Hint) Loss (Representation-Level Supervision)
- Proposer & Year: Romero et al. (ICLR 2015, "FitNets")
- Motivation:
  The teacher's internal feature representations carry structural and reasoning cues beyond output labels.
- Mechanism:
- Extract corresponding hidden-layer activations from teacher and student
- Minimize feature‐map discrepancy via L2 or similar loss
- Role in Distillation:
Provides intermediate supervision—aligns student’s internal representations with teacher’s, stabilizing training and preserving network–level knowledge.
Basic + Advanced Parallel
- First use linear projection + MSE as the basic alignment (to ensure training feasibility)
- At the same time, design a Group-wise Cross-Attention Projector (refer to ACCV2022) to capture more expressive mappings
Progressive Distillation Scheduling
In the early stage of training, use only output-distribution distillation + projection MSE;
then gradually "unfreeze" the intermediate-layer alignment loss, or add a blank-frame / non-blank-frame factorization strategy (cf. the factorized CTC-ASR distillation work). A minimal schedule sketch follows below.
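A minimal sketch of such a schedule (the function name, warm-up length, and linear ramp are assumptions for illustration, not values from this project):

```python
def hidden_align_weight(step, warmup_steps=2000, beta_max=1.0):
    """Keep the intermediate-layer alignment term off early on, then ramp it in."""
    if step < warmup_steps:
        return 0.0                                         # output distillation + projection MSE only
    ramp = min(1.0, (step - warmup_steps) / warmup_steps)  # linear ramp over another warm-up window
    return beta_max * ramp
```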
Multi-stage Intermediary
If the teacher-student difference is too large, an auxiliary teacher with a more similar structure can be introduced to complete the alignment in two steps.
Outline
- 1 Introduction
- 2 Related Work
- 3 Methodology
- 3.1 Cross-Structure KD
- 3.2 LoRA Quantization
- 3.3 Training Details
- 3.4 Ablation: LoRA Only
- 3.5 Ablation: Hidden Alignment Only
- Group-wise Cross-Attention Projector
- An Auxiliary teacher
- 3.6 Combined LoRA + Hidden Alignment
- 4 Experimental Setup
- 4.1 Datasets
- 4.2 Hyper-parameter Search
- 5 Results
- 5.1 Main Results
- 5.2 Ablation Studies
- 6 Discussion
- 7 Conclusion & Future Work
- Appendices (code snippets, install, extra figs)
Topics in Cross-Architecture
- 2025 - Cross-Architecture Knowledge Distillation for Speech Enhancement: From Cmgan to Unet
- 2024 - Factorized and progressive knowledge distillation for CTC-based ASR models
VQ-VAE
Motivation for Discrete Latent Audio Representations
- **Information Bottleneck & Compression**
Continuous waveforms contain abundant redundancies; modeling them directly (e.g., WaveNet) is computationally expensive.
VQ-VAE quantizes latent vectors into a fixed-size codebook (e.g., 512 or 1024 entries), achieving lossy compression while preserving salient features.
Subsequent priors (PixelCNN/Transformer) operate on shorter discrete sequences, yielding far greater efficiency.
- **Advantages of Discrete Sequence Modeling**
Autoregressive Transformers, PixelCNNs, and RNNs excel at modeling discrete tokens.
Mapping audio to discrete tokens enables these models to capture long-range dependencies, linguistic structures, and prosody, without struggling in high-dimensional continuous spaces.
- **Denoising & Semantic Extraction**
Codebook entries tend to represent prototypical acoustic units (e.g., phonemes).
Quantization suppresses minor noise and emphasizes semantic/prosodic features, yielding more robust representations for synthesis and recognition.
- **Hierarchical & Multi-Rate Structure**
VQ-VAE-2 employs multi-scale codebooks:
- **Lower levels:** fine details (timbre, noise)
- **Higher levels:** global structure (intonation, rhythm)
This hierarchical discrete design is hard to achieve with continuous latent models.
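A toy sketch of the quantization step described above (shapes and codebook size are illustrative; a trained VQ-VAE would use learned codebook vectors and a straight-through gradient):

```python
import torch

B, T, D, K = 2, 50, 64, 512                       # batch, frames, latent dim, codebook size
z_e = torch.randn(B, T, D)                        # continuous encoder latents
codebook = torch.randn(K, D)                      # codebook entries (e.g., 512 of them)

dist = torch.cdist(z_e.reshape(-1, D), codebook)  # (B*T, K) distances to every entry
codes = dist.argmin(dim=-1)                       # nearest codebook index per frame
z_q = codebook[codes].reshape(B, T, D)            # quantized latents passed to the prior
```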
1. NVIDIA Megatron-LM (GPT / BERT Pretraining)
- **Compute**: All forward/backward kernels (e.g., matrix multiplies, attention, activations) execute in **FP16** on Tensor Cores.
- **Master Weights & Optimizer State**: A **FP32** “master copy” of model weights is maintained; FP16 gradients are accumulated into FP32 weights, then cast back to FP16 for the next iteration.
- **Numerical Stability**: Overflow-sensitive ops (e.g., Softmax) remain in **FP32** to avoid divergence, without meaningful speed penalty.
- **Implementation**: Uses NVIDIA Apex/AMP or NeMo's automatic mixed-precision modules to orchestrate autocasting, dynamic loss scaling, and master-weight management.
2. Google PaLM / AlphaFold2 (Transformer & Scientific Models)
- **Core Operators** (e.g., GEMM, LayerNorm, Dropout) run in **bfloat16** (BF16)—a 16-bit format with the same exponent range as FP32 but reduced mantissa.
- **Critical Sections** (e.g., logits normalization, loss computation) remain in **FP32** to prevent underflow/overflow.
- **Framework Support**: In JAX/Flax, BF16 is typically enabled by constructing modules and parameters with `dtype=jnp.bfloat16`, while loss computation and sensitive reductions are kept in FP32.
3. Stable Diffusion / GAN Inference
- **Inference Precision**: Entire U-Net, scheduler, and CLIP text encoder run in **FP16**, halving memory and nearly doubling throughput compared to FP32, with negligible impact on image quality.
- **Selective FP32 Fallback**: For maximum stability or visual fidelity (e.g., in the VAE decoder), critical layers can be temporarily cast back to **FP32**.
4. Frontiers: FP8-LM (8-bit Mixed-Precision)
- **FP8 Framework**: Base weights, gradients, and optimizer states use **8-bit FP8** formats (e.g., the E4M3/E5M2 formats supported by NVIDIA Transformer Engine and Microsoft's MS-AMP); key stability ops (Softmax, LayerNorm) stay in FP16/BF16 or FP32.
- **Efficiency Gains**: In GPT-175B training, FP8 mixed-precision achieved a **75% speedup** and **39% memory reduction** over BF16 baselines, without changing hyperparameters.
- **Adoption**: Offers a drop-in replacement for existing FP16/BF16 pipelines, open-sourced via MS-AMP (aka.ms/MS.AMP).
*References*
- Micikevicius *et al.*, “Mixed Precision Training,” NVIDIA Developer Blog, 2018.
- Shoeybi *et al.*, “Megatron-LM: Training Multi-Billion Parameter Language Models,” *arXiv:1909.08053*, 2019.
- Narang *et al.*, “How To Fit a Bigger Model and Train It Faster,” Hugging Face Docs, 2022.
- Peng *et al.*, “FP8-LM: Training FP8 Large Language Models,” *arXiv:2310.18313*, 2023.