2025 - Adapter-Guided Distillation for ASR

Whisper-large-v2 + Hidden-State Alignment


-> Whisper-large-v2 → distil-whisper/distil-small.en + LoRA / Hidden-State Alignment


Teacher (Whisper-large-v2): ≈ 3 GB (FP16)
Student (distil-small) + LoRA + Optimizer + Gradients: ≈ 2.3 GB
Activations (per GPU): ≈ 6–8 GB
Safety margin: ≈ 2 GB for temporary buffers, communication, and the AMP cache

-> Single-card peak ≈ 13–15 GB VRAM -> start the experiment with FP16 + AMP

In **LoRATrainer.fwd()**, bypass PeftModel.forward() and **directly call the underlying model (model.model)** into which the LoRA weights have been injected
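A minimal sketch of that forward path, assuming a PEFT-wrapped WhisperForConditionalGeneration (the class and attribute names below are illustrative placeholders, not the project's actual trainer):

class LoRATrainer:
    def __init__(self, peft_model):
        self.model = peft_model                     # PeftModel wrapping WhisperForConditionalGeneration

    def fwd(self, input_features, decoder_input_ids):
        whisper = self.model.base_model.model       # LoRA-injected WhisperForConditionalGeneration
        out = whisper.model(                        # .model is the inner WhisperModel (encoder + decoder)
            input_features=input_features,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,              # hidden states are needed for the alignment loss
        )
        return out                                  # note: this skips the LM head (whisper.proj_out)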
For LoRA
1. Encoder fine-tuning: if the input feature distribution differs markedly from the pre-training distribution, consider adding LoRA to a few key encoder layers (e.g. the top self_attn) so the feature representations adapt better to the downstream task.
2. Bias: if bias="none" in LoraConfig leads to rough convergence, try bias="all" or bias="lora_only" to give the biases some room to adjust.


Teacher

  • Model: openai/whisper-large-v2 - 📍 ≈1.54 B parameters (FP16)
  • Input: raw waveform → 80-channel log-Mel spectrogram (mono, 16 kHz)
  • Encoder
    • Hidden size: 1 280
    • Layers: 32
    • Sequence length: ~ 1 500 frames
    • All parameters frozen
  • Decoder
    • Auto-regressive transformer LM
    • Hidden size: 1 280
    • Layers: 32
    • Output: token logits over vocabulary
    • No parameters updated

Input Audio → Encoder → Decoder (Auto-Regressive) → Transcript Tokens


Student

  • Backbone Model: distil-whisper/distil-small.en - 📍 ≈166 M parameters (FP16)
  • Hidden size: 768
  • Encoder
    • Same 80-channel log-Mel input
    • Layers: 12 (inherited from Whisper-small)
    • All parameters frozen
  • Decoder
    • Layers: 4 (pre-distilled)
    from transformers import WhisperForConditionalGeneration
    model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-small.en")
    print(len(model.model.decoder.layers))  # -> 4
    
    • Auto-regressive transformer LM
    • LoRA injection into every decoder layer:
      • Rank r ∈ {4, 8, 16}, alpha = xx, dropout = 0.1 (small datasets → 0.1–0.2; large datasets → 0.05–0.1)
      • Target modules per layer (see the LoraConfig sketch after this list)

        1. Decoder (Self-Attn): self_attn.q_proj, self_attn.k_proj, self_attn.v_proj, self_attn.out_proj
        2. Decoder (Encoder-Decoder Attn): encoder_attn.q_proj, encoder_attn.k_proj, encoder_attn.v_proj, encoder_attn.out_proj
        3. Decoder (Feed-Forward): fc1, fc2
  • Trainable parameters: with only the 40 LoRA adapter tensors → 📍 ≈1.3 M scalars (~0.8 % of the student, ~0.08 % of Whisper-large)
  • Simple 80/10/10 train/val/test split; do not touch the test set; SEED = 42
  • INT8 - Inference - Post-Training Quantization
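A minimal LoraConfig sketch for the decoder-only injection described above (r and alpha are placeholders to tune; a regex is used because a plain suffix list such as "self_attn.q_proj" would also match the frozen encoder's self-attention):

from peft import LoraConfig, get_peft_model
from transformers import WhisperForConditionalGeneration

student = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-small.en")

lora_cfg = LoraConfig(
    r=16,                  # try r in {4, 8, 16}; r = 16 lands near the ≈1.3 M trainable scalars above
    lora_alpha=32,         # placeholder, tune together with r
    lora_dropout=0.1,
    bias="none",
    target_modules=r"model\.decoder\.layers\.\d+\."
                   r"((self_attn|encoder_attn)\.(q_proj|k_proj|v_proj|out_proj)|fc1|fc2)",
)
student = get_peft_model(student, lora_cfg)
student.print_trainable_parameters()   # only the LoRA adapters are trainable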


Input Audio → Encoder (frozen, 768-d)
  ↓
Decoder (4 layers) + LoRA adapters
  ↓
Auto-regressive token logits
H_student( B×T×768 ) 
    ──proj──▶ H_proj( B×T×1280 ) 
       vs. H_teacher( B×T×1280 ) 
       loss = MSE(H_proj, H_teacher)

In deep models, we project inputs into a high-dimensional Latent Space whose coordinates are not directly observable or annotated. In this project, the dimensional mismatch between Teacher (1280-d) and Student (768-d) creates the core challenge of Hidden Space Alignment

Because no direct correspondence exists between teacher and student hidden representations, there is no ground truth for cross-dimension alignment; the model has to rely on surrogate objectives, a projection layer plus an MSE loss, to align internal representations

Knowledge distillation, feature matching, and representation alignment all depend on “proxy objectives” to bridge dimension differences. This framework directly applies reconstruction loss (MSE), soft distillation loss (KL), and hard supervision loss (CTC)

In this project, we explore methods for the student to learn from the teacher's hidden representations through dimensional projection, and visualize the alignment process


| Run | Trainable Params | ≈ % of Teacher | ≈ % of Student | LoRA rank (r) | α (LoRA) | dropout | T (temp) | β (hidden MSE) | kl_w | Unfrozen Encoder Layers | LoRA Injection Layers |
|---|---|---|---|---|---|---|---|---|---|---|---|
| distil-small.en (baseline) | | | | | | | | | | | |
| Cell 2.4 (LoRA) | | | | 32 | 16 | 0.1657 | 3.4066 | 1.9738 | 0.1522 | | |
| Cell 2.5 (large-r LoRA) | | | | 64 | 16 | 0.1 | 1.0 | 0.0 | 0.0 | | |
| Cell 2.6 (Our work) | | | | | | | | | | | |
| Cell 2.7 (auxiliary teacher) | | | | | | | | | | | |


Add a projection layer to bridge the dimension gap; the projected student hidden state is matched to the teacher hidden state with an MSE objective (the Hidden-State / Encoder Alignment Loss)

Even without a decoder, the student encoder can internalize the teacher’s linguistic knowledge by mimicking its Output Distributions and Hidden-State Representations -> 📍 KL Loss + MSE Loss


Add Projection Layer - For The Distillation

  • Student (distil-small.en): 768 dimensions
  • Teacher (whisper-large-v2): 1280 dimensions
  • Projection Layer - Linear(768 → 1280) to align student and teacher hidden dimensions
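A short sketch of that projection and the alignment loss; it assumes the student and teacher hidden states already share the same time dimension (e.g. both encoders see the same padded 30 s input), and the variable names are illustrative:

import torch
import torch.nn as nn
import torch.nn.functional as F

proj = nn.Linear(768, 1280)              # student hidden dim -> teacher hidden dim, trainable

def hidden_align_loss(s_h, t_h):
    # s_h: (B, T, 768) student hidden states, t_h: (B, T, 1280) teacher hidden states
    return F.mse_loss(proj(s_h), t_h)

# Only the LoRA adapters and this projection are optimized; both backbones stay frozen, e.g.
# optimizer = torch.optim.AdamW(
#     [p for p in student.parameters() if p.requires_grad] + list(proj.parameters()), lr=1e-4)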


Teacher (Whisper-large-v2)                     Student (distil-small.en + LoRA + Hidden Align)
─────────────────────────────                  ──────────────────────────────────────────────────

Audio Input                                     Audio Input  
1 × T samples                                   1 × T samples
     │                                               │
     ▼                                               ▼
Whisper Encoder                                 Whisper Encoder
1280-d hidden, L~1500 frames                    768-d hidden, L~499 frames
(32 layers, FROZEN)                            (12 layers, FROZEN)
     │                                               │
     │                                               │
     ├─────── Hidden States ──────────────────────── ├─── Projection Layer ───┐
     │        (B,1500,1280)                          │    (768→1280)          │
     │                                               │                        │
     ▼                                               ▼                        ▼
Whisper Decoder                                 Whisper Decoder              Aligned Hidden
(32 layers, FROZEN)                            (4 layers + LoRA)             (B,1500,1280)
     │                                               │                        │
     │                                               │                        │
     ▼                                               ▼                        │
Teacher Logits ────── Soft Targets ─────────▶ Student Logits                  │
(B,seq,vocab)         (KL Loss)               (B,seq,vocab)                   │
     │                 T=temperature              │                           │
     │                                            │                           │
     │                                            ▼                           │
     │                                      Hard Labels ◀── Ground Truth      │
     │                                      (CTC Loss)                        │
     │                                            │                           │
     │                                            │                           │
     │                                            ▼                           │
     │                                       Student Loss ◀─── MSE Loss ──────┘
     │                                            │         (Hidden Align)
     │                                            │
     ▼                                            ▼
No parameter updates                        (LoRA) + Projection parameters
(Inference only)                             ONLY these are trained


Teacher Encoder Output (t_h)
┌─────────────────────────────────────────┐
│  t_h: shape = (B, T≈1500, 1280)         │
│                                         │
│  ┌─────┐ ┌─────────────┐ ┌─────────────┐│
│  │ B   │ │  T_frames   │ │ Hidden_dim  ││
│  │     ├─│ ≈1500       ├─│ 1280        ││
│  └─────┘ └─────────────┘ └─────────────┘│
└─────────────────────────────────────────┘

Student Encoder Output (s_h)
┌──────────────────────────────────────────┐
│  s_h: shape = (B, T≈499, 768)            │
│                                          │
│  ┌─────┐ ┌─────────────┐ ┌─────────────┐ │
│  │ B   │ │  T_frames   │ │ Hidden_dim  │ │
│  │     ├─│  ≈499       ├─│ 768         │ │
│  └─────┘ └─────────────┘ └─────────────┘ │
└──────────────────────────────────────────┘


Why T≈499

  • Whisper feature extractor
  • A 30 s window of audio yields ≈3000 frames of 80-channel log-Mel features at 10 ms per frame
  • Before entering the Transformer encoder, these frames pass through two convolutional layers (the second with stride 2), which halves the sequence length; the Transformer blocks themselves do not downsample further. A fully padded 30 s window therefore gives ≈1500 encoder frames, and the ≈499 observed here corresponds to feeding roughly 10 s of features without padding them to the full 30 s window
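A quick sanity check of the 2× convolutional downsampling (the model ID is the student used in this project; the dummy 30 s input is only for shape inspection):

import torch
from transformers import WhisperFeatureExtractor, WhisperModel

fe = WhisperFeatureExtractor.from_pretrained("distil-whisper/distil-small.en")
model = WhisperModel.from_pretrained("distil-whisper/distil-small.en")

wav = torch.zeros(16000 * 30).numpy()            # 30 s of silence at 16 kHz
feats = fe(wav, sampling_rate=16000, return_tensors="pt").input_features
with torch.no_grad():
    enc = model.encoder(feats).last_hidden_state
print(feats.shape, enc.shape)                    # (1, 80, 3000) -> (1, 1500, 768)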


📍 Visual Demo of the Dynamic Alignment


Distil-Whisper is also evaluated on the ESB benchmark datasets as part of the OpenASR leaderboard, where it performs to within 0.2 % WER of Whisper


  • Always remember to do Automatic checkpoint saving
  • !pip install -U "bitsandbytes>=0.41.0"  (quote the requirement so the shell does not treat >= as a redirect)
  • Put Your Teacher model on CPU
  • MIN_DURATION = 2.0 # Same as Distil-Whisper-small.en
  • MAX_DURATION = 30.0 # Same as Whisper-large-v2 maximum acceptance length
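If the teacher does not fit next to the student in VRAM, one option (an assumption here, not a decided setup) is to load it with 8-bit post-training quantization via the transformers/bitsandbytes integration instead of keeping it on the CPU:

import torch
from transformers import BitsAndBytesConfig, WhisperForConditionalGeneration

teacher = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v2",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),   # requires a CUDA device
    device_map="auto",
)
teacher.eval()
for p in teacher.parameters():
    p.requires_grad_(False)        # teacher is inference-only, no gradients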


| Paper | Venue | Data Size |
|---|---|---|
| DistilBERT | NeurIPS 2019 | 800 M words + 2.5 B words |
| TinyBERT | EMNLP 2020 | 800 M words + 2.5 B words |
| MobileBERT | ACL 2020 | 800 M words + 2.5 B words |
| Distil-Whisper-en | | ≈ 22 000 hours of pseudo-labelled audio across 10 domains (>18 000 speakers) |
| Our Work | No target venue | XXXXX hours |


Knowledge Map


In the design of LoRA, choosing which modules to insert the adapter into, and with what rank r, is essentially a trade-off between Parameter Overhead and Adaptability

Original LoRA Paper

ΔW = B · A -> in the original paper, low-rank increments are applied only to W_q and W_v in the attention


3 Choices of LoRA Injection

1. Attention q/v only
   decoder.layers.*.encoder_attn.q_proj, encoder_attn.v_proj
   decoder.layers.*.self_attn.q_proj, self_attn.v_proj
2. Attention q/k/v + fc2
   decoder.layers.*.encoder_attn.q_proj, encoder_attn.k_proj, encoder_attn.v_proj
   decoder.layers.*.self_attn.q_proj, self_attn.k_proj, self_attn.v_proj
   decoder.layers.*.fc2
3. Full attention (q/k/v/out_proj) + fc1, fc2
   decoder.layers.*.encoder_attn.q_proj, encoder_attn.k_proj, encoder_attn.v_proj, encoder_attn.out_proj
   decoder.layers.*.self_attn.q_proj, self_attn.k_proj, self_attn.v_proj, self_attn.out_proj
   decoder.layers.*.fc1, fc2


Features - LoRA

  • End-to-end alignment: No extra alignment mechanism; the model learns acoustic-to-text alignment during training
  • Scalable functionality: Supports ASR, speech translation, and multi-language recognition
  • High decoding overhead: Requires decoder and beam search at each step, resulting in higher latency
  • Balance: r = 8 is generally the most stable trade-off between quality and cost
  • Pushing the limit: r = 16 is the most expressive, but needs more GPU memory for the extra weights and gradients


Temperature

  • Initial pilot temperature: T =
  • Search range: [ ]
  • Optuna hyperparameter: include temp as a tunable parameter
  • Guidance: prevent over-smoothing (i.e. avoid T > 5)


Hard vs. Soft Labels in Knowledge Distillation

  • Hard Labels: one-hot vectors from ground truth
    y = [0, …, 1, …, 0]
    • Strong supervision → binary certainty
    • Forces correct classification

  • Soft Labels: teacher’s softmax outputs
    p_teacher = [0.6, 0.3, 0.1]
    • Confidence & uncertainty
    • Encodes inter-class similarity

  • Combined Loss:
    L_total = L_hard + α·T²·L_soft
    L_hard: Cross-Entropy or CTC
    L_soft: KL Divergence
    α: balancing weight
    T: temperature parameter
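A compact sketch of the combined objective above (the shapes and the -100 padding convention follow common Hugging Face practice and are assumptions here, not project code):

import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, alpha=0.5, T=2.0):
    V = student_logits.size(-1)
    # Hard term: cross-entropy against ground-truth token ids (padding positions set to -100)
    l_hard = F.cross_entropy(student_logits.view(-1, V), labels.view(-1), ignore_index=-100)
    # Soft term: KL between temperature-softened teacher and student distributions
    l_soft = F.kl_div(
        F.log_softmax(student_logits.view(-1, V) / T, dim=-1),
        F.softmax(teacher_logits.view(-1, V) / T, dim=-1),
        reduction="batchmean",
    )
    return l_hard + alpha * (T ** 2) * l_soft     # L_total = L_hard + α·T²·L_soft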


Why num_workers Affects GPU Performance

The num_workers parameter in PyTorch DataLoader controls the number of CPU processes responsible for data loading and preprocessing. This directly impacts GPU utilization through data pipeline optimization

Data Pipeline Architecture

Optimal Pipeline (num_workers > 0)

  • CPU Thread 1: Load batch_1 → Preprocess → Transfer to GPU
  • CPU Thread 2: Load batch_2 → Preprocess → Queue for transfer
  • GPU: Process batch_1 while CPU prepares batch_2


Performance Comparison

Single-threaded (num_workers=0)

  • CPU: Load→Preprocess→Transfer GPU idle Load→Preprocess→Transfer
  • GPU: Idle Compute Idle

Multi-threaded (num_workers=4)

  • CPU: Continuous data preparation (4 parallel threads)

  • GPU: Continuous computation (minimal idle time)

Key Insight

  • Increasing num_workers improves apparent “CUDA kernel parallelism” not by adding GPU parallelism but by eliminating GPU starvation: multiple CPU workers keep a steady stream of preprocessed batches flowing to the GPU, maximizing hardware utilization and reducing training time
  • The optimal num_workers typically ranges from 2-4 per GPU, depending on CPU core count and I/O bottlenecks
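A minimal DataLoader sketch of this setup (the TensorDataset is a dummy stand-in for the real feature dataset):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(64, 80, 3000))      # dummy (B, mel, frames) features
loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,              # 2-4 workers per GPU is a reasonable starting point
    pin_memory=True,            # page-locked host memory enables faster, async copies to the GPU
    persistent_workers=True,    # keep workers alive across epochs to avoid respawn cost
)
for (feats,) in loader:
    pass                        # GPU compute goes here while workers prepare the next batches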


CTC Loss - Hard Supervision

Connectionist Temporal Classification (CTC) operates on the student’s latent feature sequence

\[\{\mathbf{h}_t\}_{t=1}^T\]

mapping it onto the transcript tokens

\[\{y_u\}_{u=1}^U\]

without explicit frame-wise labels

  • Frame-to-Token Alignment
  • Marginalizing Paths
  • Gradient Signal
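A hedged sketch of that CTC term with torch.nn.CTCLoss; it assumes a small linear CTC head on top of the student encoder states (the head, vocabulary size, and blank index are illustrative and not part of the stock Whisper architecture):

import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, blank_id = 51865, 0               # illustrative values
ctc_head = nn.Linear(768, vocab_size)         # assumed CTC head over the student encoder states
ctc_criterion = nn.CTCLoss(blank=blank_id, zero_infinity=True)

def ctc_term(s_h, targets, input_lengths, target_lengths):
    # s_h: (B, T, 768) student encoder states; targets: (B, U) token ids
    log_probs = F.log_softmax(ctc_head(s_h), dim=-1)    # (B, T, V)
    log_probs = log_probs.transpose(0, 1)               # CTCLoss expects (T, B, V)
    return ctc_criterion(log_probs, targets, input_lengths, target_lengths)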



KL Distillation Loss - Soft Supervision

KL Distillation Loss compares the teacher’s and student’s posterior distributions over labels at each time-step in latent space

  • Soft Distribution Matching
  • Preference Transfer
  • Capturing Uncertainty


Since the softmax outputs retain probabilities for all tokens, the KL term transfers the teacher’s uncertainty patterns—e.g., when the teacher is unsure between two phonemes, the student learns to mirror that ambiguity



Hidden-State / Encoder Alignment Loss - Representation-Level Supervision

Penalizes the discrepancy between teacher and student hidden states, enabling the student to learn the teacher’s internal “reasoning flow”

s_h = …           # (B, T, 768)
t_h = …           # (B, T, 1280)
s_proj = proj(s_h)  # Linear(768→1280) → (B, T, 1280)
mse = F.mse_loss(s_proj, t_h)
\[\mathcal{L}_{\text{hidden\_align}} = \frac{1}{B\,T\,d_t}\sum_{i=1}^{B\,T} \bigl\lVert W\,s_i + b \;-\; t_i\bigr\rVert_{2}^{2}, \quad d_t = 1280\]


where \(s_i \in \mathbb{R}^{768}\) and \(t_i \in \mathbb{R}^{1280}\) are the student and teacher hidden vectors at position \(i\), and \((W, b)\) are the learnable parameters of the projection layer.


Total Loss

\[L_{\text{total}} = L_{\mathrm{CTC}} + \alpha\,L_{\mathrm{KD}} + \beta\,L_{\mathrm{hidden\_align}}\]


  • \(L_{\mathrm{CTC}}\) is the hard cross-entropy (or CTC) loss
  • \(L_{\mathrm{KD}} = T^2\,\mathrm{KL}\bigl(p_{\text{teacher}}^{T} \,\big\Vert\, p_{\text{student}}^{T}\bigr)\) is the softened KL-divergence loss with temperature \(T\) and weight \(\alpha\)
  • \(L_{\text{hidden\_align}}\) is the projected hidden-state MSE loss with weight \(\beta\)
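Putting the three terms together in code, as a sketch that reuses the illustrative helpers from the snippets above (ctc_term, proj) rather than the project's actual trainer:

import torch
import torch.nn.functional as F

def total_loss(s_logits, t_logits, s_hidden_proj, t_hidden, l_ctc, alpha, beta, T):
    V = s_logits.size(-1)
    # softened KL term, scaled by T^2 as in the formula above
    l_kd = F.kl_div(
        F.log_softmax(s_logits.view(-1, V) / T, dim=-1),
        F.softmax(t_logits.view(-1, V) / T, dim=-1),
        reduction="batchmean",
    ) * (T ** 2)
    # hidden-state alignment after projecting the student states to 1280-d
    l_hidden = F.mse_loss(s_hidden_proj, t_hidden)
    return l_ctc + alpha * l_kd + beta * l_hidden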


Connectionist Temporal Classification (CTC) in Knowledge Distillation

  • Proposer & Year: Alex Graves et al. (2006 ICML)

  • Motivation
    • RNNs require exact frame-level alignment, which is unavailable in tasks like speech and handwriting recognition
  • Key Innovations
    • Automatic alignment
    • Path marginalization over all possible alignments
    • Blank token mechanism to handle repeats and separations



Fleurs en US - won’t be used

  • FLEURS - ASR
  • optional
    Tasks: Automatic Speech Recognition
    Languages: Afrikaans, Amharic, Arabic, … (+99)
    Size: 10K < n < 100K
    ArXiv: 2205.12446, 2106.03193
    Tags: speech-recognition
    License: cc-by-4.0
    FLEURS: 102 languages × 550 samples = 56,100 samples - won’t be used



LibriSpeech


Tasks: Automatic Speech Recognition, Audio Classification (speaker-identification)
Languages: English
Size: 100K < n < 1M
License: CC-BY-4.0

from datasets import load_dataset, Audio

# 1. Load first xxxx examples of the clean-100 training split
# 📍 streaming=True: stream the download, otherwise it will never finish
xxx + xxx = ~ 60,862
See
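A minimal streaming-load sketch for the clean-100 split (the example count N is a placeholder, since the exact numbers above are left blank):

from itertools import islice
from datasets import load_dataset, Audio

N = 1000    # placeholder sample count
ds = load_dataset("librispeech_asr", "clean", split="train.100", streaming=True)
ds = ds.cast_column("audio", Audio(sampling_rate=16_000))
subset = list(islice(ds, N))          # materialize only the first N examples
print(len(subset), subset[0]["text"][:50])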



Hyperparameter Optimization

import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler

def objective(trial):
    # Distillation loss weights
    alpha = trial.suggest_loguniform("alpha", 1e-3, 1e1)
    beta  = trial.suggest_loguniform("beta",  1e-3, 1e1)
    # Optimization hyperparameters
    lr        = trial.suggest_loguniform("lr",        1e-5, 1e-3)
    batch_size= trial.suggest_categorical("batch_size", [4, 8, 16, 32])
    dropout   = trial.suggest_float("dropout", 0.0, 0.5)
    
    # Train & evaluate with these settings (implement train_and_evaluate accordingly)
    wer = train_and_evaluate(
        alpha=alpha,
        beta=beta,
        learning_rate=lr,
        batch_size=batch_size,
        dropout=dropout,
        pruner=trial  # for early stopping
    )
    return wer

# Pruner to stop unpromising trials early
pruner  = MedianPruner(n_startup_trials=5, n_warmup_steps=100)
sampler = TPESampler()

study = optuna.create_study(
    direction="minimize",
    sampler=sampler,
    pruner=pruner
)
study.optimize(objective, n_trials=100)

print("Best hyperparameters:", study.best_params)
**Gradient Underflow**

feats = b["input_features"].half().to(device)  # cast FP32 input features to FP16

from torch.cuda.amp import GradScaler, autocast

self.scaler = GradScaler()
...
optimizer.zero_grad()
with autocast():
    loss = model(input)                  # forward runs in FP16 under autocast
self.scaler.scale(loss).backward()       # scale the loss so small FP16 gradients do not underflow
self.scaler.step(optimizer)              # unscales; skips the step if inf/NaN gradients are found
self.scaler.update()                     # adjusts the scale factor for the next iteration



PCA vs. t-SNE vs. UMAP vs. DTW


Project 1 Visualization




References


Connectionist Temporal Classification (CTC) in Knowledge Distillation

  • Proposer & Year: Alex Graves et al. (2006 ICML)
  • Motivation:
    Frame–label alignment unavailable in speech/handwriting tasks
  • Mechanism:
    1. Automatic alignment of variable‐length audio to text
    2. Marginalization over all valid alignment paths
    3. Blank token to handle repeats and separations
  • Role in Distillation:
    Provides hard supervision—ensures correct sequence output without frame-level labels


Kullback–Leibler (KL) Distillation Loss

  • Proposer & Year: Hinton et al. (2015 NIPS)
  • Motivation:
    Transfer “dark knowledge” (inter-class similarity and uncertainty) from teacher to student.
  • Mechanism:
    1. Compute teacher and student softmax distributions at each time step
    2. Apply temperature (T) to smooth distributions
    3. Minimize KL divergence between them
  • Role in Distillation:
    Provides soft supervision—guides student to match teacher’s probability patterns and improve generalization.


Rectification Loss (Representation-Level Supervision)

  • Proposer & Year: Romero et al. (ICLR 2015, “FitNets”)
  • Motivation:
    Teacher’s internal feature representations carry structural and reasoning cues beyond output labels.
  • Mechanism:
    1. Extract corresponding hidden-layer activations from teacher and student
    2. Minimize feature‐map discrepancy via L2 or similar loss
  • Role in Distillation:
    Provides intermediate supervision—aligns student’s internal representations with teacher’s, stabilizing training and preserving network–level knowledge.



Basic + Advanced Parallel

  • First use linear projection + MSE as the basic alignment (to ensure training feasibility)
  • At the same time, design a Group-wise Cross-Attention Projector (refer to ACCV2022) to capture more expressive mappings

Progressive Distillation Scheduling

In the early stage of training, only output distribution distillation + projection MSE is used;

Then gradually “unfreeze” the intermediate layer alignment loss, or add a blank frame / non-blank frame factorization strategy (refer to CTC-ASR).
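A small sketch of such a schedule (the step counts and the linear ramp are placeholders; the idea is simply to keep the extra alignment weight at zero early on and ramp it in later):

def intermediate_align_weight(step, w_max=1.0, warmup_steps=2000, ramp_steps=2000):
    # Early phase: only output-distribution KD + projection MSE are active (this weight stays 0)
    if step < warmup_steps:
        return 0.0
    # Later: linearly "unfreeze" the intermediate-layer alignment term up to w_max
    return w_max * min(1.0, (step - warmup_steps) / ramp_steps)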

Multi-stage Intermediary

If the teacher-student difference is too large, an auxiliary teacher with a more similar structure can be introduced to complete the alignment in two steps.

Outline

  • 1 Introduction
  • 2 Related Work
  • 3 Methodology
    • 3.1 Cross-Structure KD
    • 3.2 LoRA Quantization
    • 3.3 Training Details
    • 3.4 Ablation: LoRA Only
    • 3.5 Ablation: Hidden Alignment Only
      • Group-wise Cross-Attention Projector
      • An Auxiliary teacher
    • 3.6 Combined LoRA + Hidden Alignment
  • 4 Experimental Setup
    • 4.1 Datasets
    • 4.2 Hyper-parameter Search
  • 5 Results
    • 5.1 Main Results
    • 5.2 Ablation Studies
  • 6 Discussion
  • 7 Conclusion & Future Work
  • Appendices (code snippets, install, extra figs)


Topics in Cross-Architecture


VQ-VAE

Motivation for Discrete Latent Audio Representations

- **Information Bottleneck & Compression**  
  Continuous waveforms contain abundant redundancies; modeling them directly (e.g., WaveNet) is computationally expensive.  
  VQ-VAE quantizes latent vectors into a fixed-size codebook (e.g., 512 or 1024 entries), achieving lossy compression while preserving salient features.  
  Subsequent priors (PixelCNN/Transformer) operate on shorter discrete sequences, yielding far greater efficiency.

- **Advantages of Discrete Sequence Modeling**  
  Autoregressive Transformers, PixelCNNs, and RNNs excel at modeling discrete tokens.  
  Mapping audio to discrete tokens enables these models to capture long-range dependencies, linguistic structures, and prosody, without struggling in high-dimensional continuous spaces.

- **Denoising & Semantic Extraction**  
  Codebook entries tend to represent prototypical acoustic units (e.g., phonemes).  
  Quantization suppresses minor noise and emphasizes semantic/prosodic features, yielding more robust representations for synthesis and recognition.

- **Hierarchical & Multi-Rate Structure**  
  VQ-VAE-2 employs multi-scale codebooks:  
  - **Lower levels:** fine details (timbre, noise)  
  - **Higher levels:** global structure (intonation, rhythm)  
  This hierarchical discrete design is hard to achieve with continuous latent models.


1. NVIDIA Megatron-LM (GPT / BERT Pretraining)

- **Compute**: All forward/backward kernels (e.g., matrix multiplies, attention, activations) execute in **FP16** on Tensor Cores.  
- **Master Weights & Optimizer State**: A **FP32** “master copy” of model weights is maintained; FP16 gradients are accumulated into FP32 weights, then cast back to FP16 for the next iteration.  
- **Numerical Stability**: Overflow-sensitive ops (e.g., Softmax) remain in **FP32** to avoid divergence, without meaningful speed penalty.  
- **Implementation**: Uses NVIDIA Apex/AMP or NeMo’s automatic mixed-precision modules to orchestrate autocasting, dynamic loss scaling, and master-weight management.

2. Google PaLM / AlphaFold2 (Transformer & Scientific Models)

- **Core Operators** (e.g., GEMM, LayerNorm, Dropout) run in **bfloat16** (BF16)—a 16-bit format with the same exponent range as FP32 but reduced mantissa.  
- **Critical Sections** (e.g., logits normalization, loss computation) remain in **FP32** to prevent underflow/overflow.  
- **Framework Support**: JAX/Flax users typically enable BF16 by setting the computation `dtype` (e.g. the `dtype` argument of Flax modules) to `jnp.bfloat16`, while keeping parameters and numerically sensitive reductions in FP32.


3. Stable Diffusion / GAN Inference

- **Inference Precision**: Entire U-Net, scheduler, and CLIP text encoder run in **FP16**, halving memory and nearly doubling throughput compared to FP32, with negligible impact on image quality.  
- **Selective FP32 Fallback**: For maximum stability or visual fidelity (e.g., in the VAE decoder), critical layers can be temporarily cast back to **FP32**.  

4. Frontiers: FP8-LM (8-bit Mixed-Precision)

- **FP8 Framework**: Base weights, gradients, and optimizer states use **8-bit FP8** (e.g., NVIDIA’s E4M3/E5M2 formats, as used in Microsoft’s MS-AMP); key stability ops (Softmax, LayerNorm) stay in FP16/BF16 or FP32.
- **Efficiency Gains**: In GPT-175B training, FP8 mixed-precision achieved a **75% speedup** and **39% memory reduction** over BF16 baselines—without changing hyperparameters.
- **Adoption**: Offers a drop-in replacement for existing FP16/BF16 pipelines, open-sourced via MS-AMP (aka.ms/MS.AMP).  

*References*  
- Micikevicius *et al.*, “Mixed Precision Training,” NVIDIA Developer Blog, 2018.  
- Shoeybi *et al.*, “Megatron-LM: Training Multi-Billion Parameter Language Models,” *arXiv:1909.08053*, 2019.  
- Narang *et al.*, “How To Fit a Bigger Model and Train It Faster,” Hugging Face Docs, 2022.  
- Peng *et al.*, “FP8-LM: Training FP8 Large Language Models,” *arXiv:2310.18313*, 2023.



