In [1]:
import torch
import sys
import gc
print(sys.version)
print(f"PyTorch Version: {torch.__version__}")
print(torch.cuda.is_available())
print(torch.cuda.device_count())

if torch.cuda.is_available():
    print(f"CUDA Version: {torch.version.cuda}")
    print(torch.cuda.get_device_name(0))

gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

import bitsandbytes
import peft
import transformers

print(transformers.__version__)

print(f"bitsandbytes version: {bitsandbytes.__version__}")
print(f"peft version: {peft.__version__}")
print(torch.cuda.is_bf16_supported())

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
3.10.16 | packaged by Anaconda, Inc. | (main, Dec 11 2024, 16:19:12) [MSC v.1929 64 bit (AMD64)]
PyTorch Version: 2.5.1+cu121
True
1
CUDA Version: 12.1
NVIDIA GeForce RTX 4080 Laptop GPU
4.50.0.dev0
bitsandbytes version: 0.45.3
peft version: 0.15.2.dev0
True

Load dataset, base model, and tokeniser¶

In [2]:
from datasets import load_dataset

imdb_dataset = load_dataset("imdb")
imdb_dataset = imdb_dataset.rename_column("label", "labels")
# Split the original IMDB test set into a small validation set (5%) and a final test set (95%)
test_val_split = imdb_dataset['test'].train_test_split(test_size=0.95, seed=42)
imdb_dataset['validation'] = test_val_split['train']
imdb_dataset['test'] = test_val_split['test']

from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, f1_score

# Determine the number of labels
num_labels = len(set(imdb_dataset["train"]["labels"]))
print(f"Number of labels: {num_labels}")

# Load the tokenizer
model_ckpt = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

# Tokenize the whole dataset, truncate to 384 tokens
def tokenize(batch):
    return tokenizer(batch["text"], padding=True, truncation=True, max_length=384)

dataset_encoded = imdb_dataset.map(tokenize, batched=True, batch_size=None)

# Load the pretrained model for sequence classification
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = (AutoModelForSequenceClassification
         .from_pretrained(model_ckpt, num_labels=num_labels)
         .to(device))
#print(model)
Number of labels: 2
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
In [3]:
# Helper functions
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1}
    
def count_trainable_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params, 100 * trainable_params / total_params

def freeze_model_layers(model, unfreeze_pre_classifier=False):
    # Freeze all parameters
    for param in model.parameters():
        param.requires_grad = False

    # Unfreeze LoRA and DoRA-specific params, including lora_norm
    for name, param in model.named_parameters():
        if (
            "lora.A" in name
            or "lora.B" in name
            or "lora_norm" in name  
            or name.endswith(".m")   # For DoRA
            or name.endswith(".m_in") # For DDoRA
            or name.endswith(".m_out") # For DDoRA
            or "scale" in name
        ):
            param.requires_grad = True

    # Unfreeze classifier layer (always)
    for name, param in model.named_parameters():
        if name.startswith("classifier."):
            param.requires_grad = True

    # unfreeze pre-classifier
    if unfreeze_pre_classifier:
        for name, param in model.named_parameters():
            if name.startswith("pre_classifier."):
                param.requires_grad = True

def monitor_lora_parameters(model, threshold=1e-7):
    monitor = {
        "A_abs_mean": [],
        "B_abs_mean": [],
        "A_grad_mean": [],
        "B_grad_mean": [],
        "lora_output_norm": [],
        "B_nonzero_count": [],
    }
    hooks = []

    for name, module in model.named_modules():
        if hasattr(module, "lora") and hasattr(module.lora, "A") and hasattr(module.lora, "B"):
            A_param = module.lora.A
            B_param = module.lora.B

            # Gradient hooks (directly on nn.Parameter)
            if A_param.requires_grad:
                hooks.append(A_param.register_hook(lambda grad, n=name: monitor["A_grad_mean"].append((n, grad.abs().mean().item()))))
            if B_param.requires_grad:
                hooks.append(B_param.register_hook(lambda grad, n=name: monitor["B_grad_mean"].append((n, grad.abs().mean().item()))))

            # Forward hook for value stats
            def forward_hook(mod, inp, out, n=name):
                A_mean = mod.lora.A.abs().mean().item()
                B_mean = mod.lora.B.abs().mean().item()
                B_nnz = (mod.lora.B.abs() > threshold).sum().item()
                monitor["A_abs_mean"].append((n, A_mean))
                monitor["B_abs_mean"].append((n, B_mean))
                monitor["B_nonzero_count"].append((n, B_nnz))
                monitor["lora_output_norm"].append((n, mod.last_lora_output_norm))

            hooks.append(module.register_forward_hook(forward_hook))

    return hooks, monitor

LoRA¶

In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
torch.autograd.set_detect_anomaly(True)

class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha, dropout_rate=0.0):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(1e-6 * torch.randn(rank, out_dim) * std_dev)  # Not all zeroes!
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # Dropout applied to the projection to the lower-dimensional space by A
        dropped = self.dropout(x @ self.A)
        return self.alpha * (dropped @ self.B)

class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        self.last_lora_output_norm = 0.0  # for monitoring

    def forward(self, x):
        #return self.linear(x) + self.lora(x)
        lora_out = self.lora(x)
        self.last_lora_output_norm = lora_out.norm(p=2, dim=-1).mean().item()
        return self.linear(x) + lora_out


# Function to inject LoRA into specified linear layers
def inject_lora_all_attn(model, rank, alpha, dropout_rate=0.0):
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin", "ffn.lin1", "ffn.lin2"]
    # target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            lora_linear = LinearWithLoRA(original_linear, rank, alpha)
            lora_linear.lora.dropout = nn.Dropout(dropout_rate)
            
            setattr(parent_module, name.split('.')[-1], lora_linear)
    return model
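
A quick sanity check of the injection on a throwaway copy can be useful; a minimal sketch, assuming the base model from the earlier cell is still in memory (the wrapped layer should expose both the frozen linear and the adapter, and the adapter's contribution should start near zero because of the 1e-6 initialisation of B):

import copy
import torch

# Each targeted nn.Linear should now be wrapped in LinearWithLoRA, and the
# adapter's initial contribution should be tiny relative to the base output.
_probe = inject_lora_all_attn(copy.deepcopy(model), rank=16, alpha=32)
q_lin = _probe.distilbert.transformer.layer[0].attention.q_lin
print(type(q_lin).__name__)  # -> LinearWithLoRA

x = torch.randn(2, 10, 768, device=next(_probe.parameters()).device)
with torch.no_grad():
    print(f"|base(x)| = {q_lin.linear(x).norm():.3f}, |lora(x)| = {q_lin.lora(x).norm():.3e}")
del _probe
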
In [5]:
dropout = 0.0
learning_rate = 1.5e-5
lora_rank = 16
lora_alpha = 32
weight_decay = 1e-5 # L2
batch_size = 32
output_dir_prefix = "finetuned-imdb-"

import copy
from transformers import TrainingArguments
torch.manual_seed(137)
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha, dropout)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)

total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
#print("\nTrainable parameters after freezing:")
#for name, param in model_lora_all_attn.named_parameters():
#    if param.requires_grad:
#        print(name)

eval_steps = 50
logging_steps = 50

training_args_lora_all_attn = TrainingArguments(
    output_dir=f"{output_dir_prefix}lora-all-attn",
    num_train_epochs=2,
    #max_steps=100,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    logging_steps=logging_steps,
    save_steps=eval_steps,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    disable_tqdm=False,
    push_to_hub=False,
    max_grad_norm=1.0,  # clip gradients to unit norm
    report_to="none",
    log_level="error"
)

    
trainer_lora_all_attn = Trainer(
    model=model_lora_all_attn,
    args=training_args_lora_all_attn,
    train_dataset=dataset_encoded["train"],
    eval_dataset=dataset_encoded["validation"],
    compute_metrics=compute_metrics,
)


hooks, monitor = monitor_lora_parameters(trainer_lora_all_attn.model)

#Train!
trainer_lora_all_attn.train()

#eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
#print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
#print (torch.cuda.memory_summary())

for hook in hooks:
    hook.remove()

# Aggregate/log after training
from collections import defaultdict
agg = defaultdict(list)
for key, vals in monitor.items():
    grouped = defaultdict(list)
    for name, val in vals:
        grouped[name].append(val)
    agg[key] = {name: sum(vs)/len(vs) for name, vs in grouped.items()}

for name in agg["A_abs_mean"]:
    print(f"{name}: |A|={agg['A_abs_mean'][name]:.4g}, |B|={agg['B_abs_mean'][name]:.4g}, "
          f"|∇A|={agg['A_grad_mean'][name]:.4g}, |∇B|={agg['B_grad_mean'][name]:.4g}, "
          f"|LoRA(x)|={agg['lora_output_norm'][name]:.4g}, B≠0={agg['B_nonzero_count'][name]:.0f}")


#print('Parameter Statistics: mean.abs()')
#for name, param in model_lora_all_attn.named_parameters():
#    if "lora" in name:
#        print(name, param.abs().mean().item())
#print('Parameter Statistics: param.norm()')
#for name, param in model_lora_all_attn.named_parameters():
#    if "lora" in name:
#        print(f"{name} weight norm: {param.norm().item():.4f}")
LoRA (All Attention) - Total parameters: 68,282,114
LoRA (All Attention) - Trainable parameters: 1,919,234 (2.81%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
[1564/1564 56:47, Epoch 2/2]
Step Training Loss Validation Loss Accuracy F1
50 0.504900 0.335573 0.869600 0.869783
100 0.341200 0.318371 0.862400 0.862357
150 0.361400 0.307639 0.876000 0.875739
200 0.303400 0.296421 0.870400 0.870685
250 0.324500 0.276511 0.886400 0.886654
300 0.295500 0.261722 0.890400 0.889511
350 0.312700 0.276859 0.896800 0.896362
400 0.281200 0.253817 0.890400 0.890564
450 0.257400 0.244995 0.909600 0.909240
500 0.284400 0.264725 0.898400 0.898376
550 0.297700 0.238351 0.908800 0.908708
600 0.287900 0.237734 0.909600 0.909143
650 0.234500 0.224168 0.913600 0.913528
700 0.247900 0.222799 0.915200 0.915187
750 0.256000 0.227275 0.916000 0.916031
800 0.241700 0.282384 0.899200 0.898261
850 0.190800 0.242396 0.904000 0.904121
900 0.169500 0.258911 0.910400 0.910400
950 0.208600 0.260492 0.907200 0.906995
1000 0.189500 0.241334 0.913600 0.913649
1050 0.182400 0.214967 0.924800 0.924764
1100 0.192600 0.222719 0.915200 0.915213
1150 0.170000 0.231555 0.912000 0.911767
1200 0.182200 0.242794 0.916800 0.916928
1250 0.174700 0.221050 0.917600 0.917461
1300 0.187400 0.229759 0.916800 0.916896
1350 0.173600 0.224172 0.917600 0.917641
1400 0.163200 0.224013 0.920800 0.920870
1450 0.170500 0.221920 0.920800 0.920806
1500 0.165600 0.220823 0.918400 0.918361
1550 0.171000 0.221188 0.920800 0.920840

distilbert.transformer.layer.0.attention.q_lin: |A|=0.2014, |B|=0.0002572, |∇A|=0.000277, |∇B|=0.08185, |LoRA(x)|=4.874, B≠0=12283
distilbert.transformer.layer.0.attention.k_lin: |A|=0.1992, |B|=0.0002632, |∇A|=0.0002245, |∇B|=0.07715, |LoRA(x)|=5.663, B≠0=12283
distilbert.transformer.layer.0.attention.v_lin: |A|=0.1993, |B|=0.0002103, |∇A|=0.0005365, |∇B|=0.1825, |LoRA(x)|=4.493, B≠0=12282
distilbert.transformer.layer.0.attention.out_lin: |A|=0.2004, |B|=0.000215, |∇A|=0.0004854, |∇B|=0.183, |LoRA(x)|=2.415, B≠0=12283
distilbert.transformer.layer.0.ffn.lin1: |A|=0.1998, |B|=0.0002391, |∇A|=0.000489, |∇B|=0.1047, |LoRA(x)|=15.38, B≠0=49133
distilbert.transformer.layer.0.ffn.lin2: |A|=0.1996, |B|=0.0002197, |∇A|=0.0001981, |∇B|=0.1984, |LoRA(x)|=4.228, B≠0=12282
distilbert.transformer.layer.1.attention.q_lin: |A|=0.2, |B|=0.0002441, |∇A|=0.0002117, |∇B|=0.07861, |LoRA(x)|=5.422, B≠0=12283
distilbert.transformer.layer.1.attention.k_lin: |A|=0.2008, |B|=0.0002601, |∇A|=0.0002199, |∇B|=0.07653, |LoRA(x)|=4.799, B≠0=12283
distilbert.transformer.layer.1.attention.v_lin: |A|=0.1978, |B|=0.0001886, |∇A|=0.0003676, |∇B|=0.185, |LoRA(x)|=3.285, B≠0=12282
distilbert.transformer.layer.1.attention.out_lin: |A|=0.1985, |B|=0.0002105, |∇A|=0.0004261, |∇B|=0.1959, |LoRA(x)|=2.38, B≠0=12283
distilbert.transformer.layer.1.ffn.lin1: |A|=0.198, |B|=0.000237, |∇A|=0.0004655, |∇B|=0.119, |LoRA(x)|=19.11, B≠0=49132
distilbert.transformer.layer.1.ffn.lin2: |A|=0.1989, |B|=0.0002041, |∇A|=0.0001751, |∇B|=0.1963, |LoRA(x)|=2.361, B≠0=12282
distilbert.transformer.layer.2.attention.q_lin: |A|=0.198, |B|=0.0002397, |∇A|=0.0002494, |∇B|=0.09958, |LoRA(x)|=5.855, B≠0=12283
distilbert.transformer.layer.2.attention.k_lin: |A|=0.2002, |B|=0.0002423, |∇A|=0.0002714, |∇B|=0.1016, |LoRA(x)|=5.16, B≠0=12283
distilbert.transformer.layer.2.attention.v_lin: |A|=0.2002, |B|=0.0001928, |∇A|=0.0003262, |∇B|=0.1879, |LoRA(x)|=4.02, B≠0=12282
distilbert.transformer.layer.2.attention.out_lin: |A|=0.1983, |B|=0.0001992, |∇A|=0.0003639, |∇B|=0.1642, |LoRA(x)|=2.252, B≠0=12282
distilbert.transformer.layer.2.ffn.lin1: |A|=0.1982, |B|=0.0002306, |∇A|=0.0005194, |∇B|=0.1186, |LoRA(x)|=19.02, B≠0=49132
distilbert.transformer.layer.2.ffn.lin2: |A|=0.1999, |B|=0.0002021, |∇A|=0.0002104, |∇B|=0.1721, |LoRA(x)|=2.549, B≠0=12283
distilbert.transformer.layer.3.attention.q_lin: |A|=0.2006, |B|=0.0002554, |∇A|=0.000265, |∇B|=0.09461, |LoRA(x)|=5.758, B≠0=12283
distilbert.transformer.layer.3.attention.k_lin: |A|=0.2011, |B|=0.0002541, |∇A|=0.0002945, |∇B|=0.09943, |LoRA(x)|=6.224, B≠0=12283
distilbert.transformer.layer.3.attention.v_lin: |A|=0.2007, |B|=0.0001641, |∇A|=0.0003354, |∇B|=0.2089, |LoRA(x)|=4.087, B≠0=12281
distilbert.transformer.layer.3.attention.out_lin: |A|=0.1995, |B|=0.000194, |∇A|=0.0004308, |∇B|=0.14, |LoRA(x)|=2.848, B≠0=12282
distilbert.transformer.layer.3.ffn.lin1: |A|=0.1995, |B|=0.0002212, |∇A|=0.0004882, |∇B|=0.07969, |LoRA(x)|=11.88, B≠0=49131
distilbert.transformer.layer.3.ffn.lin2: |A|=0.1998, |B|=0.0001834, |∇A|=0.0001973, |∇B|=0.1337, |LoRA(x)|=2.234, B≠0=12282
distilbert.transformer.layer.4.attention.q_lin: |A|=0.2008, |B|=0.0002504, |∇A|=0.0001735, |∇B|=0.07399, |LoRA(x)|=8.787, B≠0=12283
distilbert.transformer.layer.4.attention.k_lin: |A|=0.2, |B|=0.0002486, |∇A|=0.0003957, |∇B|=0.07656, |LoRA(x)|=6.261, B≠0=12283
distilbert.transformer.layer.4.attention.v_lin: |A|=0.1993, |B|=0.0001708, |∇A|=0.0003348, |∇B|=0.11, |LoRA(x)|=3.692, B≠0=12281
distilbert.transformer.layer.4.attention.out_lin: |A|=0.1999, |B|=0.0001745, |∇A|=0.000366, |∇B|=0.07387, |LoRA(x)|=2.375, B≠0=12282
distilbert.transformer.layer.4.ffn.lin1: |A|=0.2004, |B|=0.0001853, |∇A|=0.0002817, |∇B|=0.04258, |LoRA(x)|=16.52, B≠0=49128
distilbert.transformer.layer.4.ffn.lin2: |A|=0.1994, |B|=0.0001656, |∇A|=0.0001065, |∇B|=0.0697, |LoRA(x)|=2.054, B≠0=12282
distilbert.transformer.layer.5.attention.q_lin: |A|=0.1998, |B|=0.0002194, |∇A|=8.952e-05, |∇B|=0.02991, |LoRA(x)|=8.03, B≠0=12283
distilbert.transformer.layer.5.attention.k_lin: |A|=0.201, |B|=0.0002024, |∇A|=0.0002036, |∇B|=0.03099, |LoRA(x)|=5.661, B≠0=12282
distilbert.transformer.layer.5.attention.v_lin: |A|=0.1987, |B|=0.0001348, |∇A|=0.0002633, |∇B|=0.04537, |LoRA(x)|=3.664, B≠0=12280
distilbert.transformer.layer.5.attention.out_lin: |A|=0.1994, |B|=0.0001793, |∇A|=0.0002507, |∇B|=0.03697, |LoRA(x)|=9.559, B≠0=12282
distilbert.transformer.layer.5.ffn.lin1: |A|=0.1993, |B|=0.000147, |∇A|=0.0001532, |∇B|=0.01653, |LoRA(x)|=17.16, B≠0=49123
distilbert.transformer.layer.5.ffn.lin2: |A|=0.1994, |B|=0.0001564, |∇A|=5.048e-05, |∇B|=0.02768, |LoRA(x)|=5.171, B≠0=12282

Training summary¶

  1. Training loss decreases consistently and validation accuracy improves steadily until ~step 1050. Best observed performance, at step ~1050: accuracy 92.48%, F1 92.48%, validation loss 0.2150. After that the metrics plateau or fluctuate slightly, suggesting diminishing returns or mild overfitting.
  2. LoRA adapts the low-rank matrices A and B in each targeted linear layer. |A| ≈ 0.198–0.201 across layers, as expected for A drawn from a normal distribution and scaled by 1/√rank. |B| ≈ 0.00015–0.00026, roughly three orders of magnitude smaller than |A|, which is also expected: B starts near zero and moves slowly. |LoRA(x)| spans a reasonable range of ~2–19.
  3. LoRA's effective weight update is the product of the two matrices (here the adapter adds x @ A @ B, scaled by alpha). So even though A is comparatively large, B remains small early on, and the effective update is both low-rank and small in norm, especially over the first few dozen steps. This matches the initially slow, flat loss curve. Layer utilisation is fairly balanced across attention and FFN modules, particularly in layers 0–3 (a sketch that materialises the effective update follows this list).
  4. Gradients: |∇A| ≈ 0.0002–0.0005 (small but consistent updates), while |∇B| is larger, reaching ~0.2, especially in v_lin and lin2, so most of the active learning happens in B.
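
As a quick check of point 3, the adapter's effective weight update can be materialised layer by layer; a minimal sketch, assuming model_lora_all_attn from the cell above is still in memory (in this implementation the adapter adds alpha * (x @ A @ B), so the effective update in that convention is alpha * A @ B):

import torch

# Frobenius norm of the effective LoRA update versus the frozen base weight.
# Note: delta_w follows the x @ W convention (in_dim x out_dim), while the
# frozen nn.Linear weight is stored as (out_dim x in_dim); the norms are
# comparable either way.
with torch.no_grad():
    for name, module in model_lora_all_attn.named_modules():
        if hasattr(module, "lora") and hasattr(module, "linear"):
            delta_w = module.lora.alpha * (module.lora.A @ module.lora.B)
            base_w = module.linear.weight
            print(f"{name}: |dW|_F = {delta_w.norm():.4f}, |W|_F = {base_w.norm():.4f}, "
                  f"ratio = {(delta_w.norm() / base_w.norm()):.2e}")
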
In [6]:
dropout = 0.4
learning_rate = 1.5e-5
lora_rank = 16
lora_alpha = 32
weight_decay = 1e-5 # L2
batch_size = 32
output_dir_prefix = "finetuned-imdb-"

import copy
from transformers import TrainingArguments
torch.manual_seed(137)
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha, dropout)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)

total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
#print("\nTrainable parameters after freezing:")
#for name, param in model_lora_all_attn.named_parameters():
#    if param.requires_grad:
#        print(name)

eval_steps = 50
logging_steps = 50

training_args_lora_all_attn = TrainingArguments(
    output_dir=f"{output_dir_prefix}lora-all-attn",
    num_train_epochs=2,
    #max_steps=100,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    logging_steps=logging_steps,
    save_steps=eval_steps,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    disable_tqdm=False,
    push_to_hub=False,
    max_grad_norm=1.0,  # clip gradients to unit norm
    report_to="none",
    log_level="error"
)

    
trainer_lora_all_attn = Trainer(
    model=model_lora_all_attn,
    args=training_args_lora_all_attn,
    train_dataset=dataset_encoded["train"],
    eval_dataset=dataset_encoded["validation"],
    compute_metrics=compute_metrics,
)


hooks, monitor = monitor_lora_parameters(trainer_lora_all_attn.model)

#Train!
trainer_lora_all_attn.train()

#eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
#print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
#print (torch.cuda.memory_summary())

for hook in hooks:
    hook.remove()

# Aggregate/log after training
from collections import defaultdict
agg = defaultdict(list)
for key, vals in monitor.items():
    grouped = defaultdict(list)
    for name, val in vals:
        grouped[name].append(val)
    agg[key] = {name: sum(vs)/len(vs) for name, vs in grouped.items()}

for name in agg["A_abs_mean"]:
    print(f"{name}: |A|={agg['A_abs_mean'][name]:.4g}, |B|={agg['B_abs_mean'][name]:.4g}, "
          f"|∇A|={agg['A_grad_mean'][name]:.4g}, |∇B|={agg['B_grad_mean'][name]:.4g}, "
          f"|LoRA(x)|={agg['lora_output_norm'][name]:.4g}, B≠0={agg['B_nonzero_count'][name]:.0f}")


#print('Parameter Statistics: mean.abs()')
#for name, param in model_lora_all_attn.named_parameters():
#    if "lora" in name:
#        print(name, param.abs().mean().item())
#print('Parameter Statistics: param.norm()')
#for name, param in model_lora_all_attn.named_parameters():
#    if "lora" in name:
#        print(f"{name} weight norm: {param.norm().item():.4f}")
LoRA (All Attention) - Total parameters: 68,282,114
LoRA (All Attention) - Trainable parameters: 1,919,234 (2.81%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
[1564/1564 57:27, Epoch 2/2]
Step Training Loss Validation Loss Accuracy F1
50 0.522300 0.358082 0.854400 0.853802
100 0.374700 0.310322 0.862400 0.862513
150 0.386100 0.318261 0.869600 0.869504
200 0.326900 0.307559 0.869600 0.867806
250 0.351500 0.314006 0.877600 0.877870
300 0.332300 0.272534 0.892000 0.891881
350 0.343300 0.279443 0.893600 0.893502
400 0.327400 0.285108 0.890400 0.889991
450 0.295900 0.258271 0.898400 0.898208
500 0.327000 0.285135 0.883200 0.882071
550 0.311600 0.249657 0.904000 0.903745
600 0.315400 0.261444 0.902400 0.902096
650 0.290800 0.253156 0.896800 0.896605
700 0.299500 0.243638 0.900800 0.900537
750 0.297200 0.240773 0.908000 0.907932
800 0.283300 0.261479 0.903200 0.902599
850 0.277300 0.226801 0.907200 0.907317
900 0.265200 0.268644 0.904800 0.904983
950 0.276900 0.250334 0.909600 0.909370
1000 0.245300 0.251263 0.908800 0.908708
1050 0.241300 0.231262 0.914400 0.914366
1100 0.244200 0.239346 0.904000 0.903903
1150 0.242700 0.226735 0.908800 0.908741
1200 0.236600 0.237783 0.909600 0.909579
1250 0.220600 0.229475 0.912000 0.911986
1300 0.243500 0.234786 0.908800 0.908756
1350 0.246600 0.229699 0.905600 0.905699
1400 0.238500 0.222471 0.916800 0.916560
1450 0.221000 0.223475 0.910400 0.910400
1500 0.226000 0.222452 0.916000 0.915892
1550 0.209200 0.223851 0.915200 0.915145

distilbert.transformer.layer.0.attention.q_lin: |A|=0.2014, |B|=0.0002581, |∇A|=0.0003806, |∇B|=0.1261, |LoRA(x)|=5.469, B≠0=12283
distilbert.transformer.layer.0.attention.k_lin: |A|=0.1991, |B|=0.0002605, |∇A|=0.0003199, |∇B|=0.1166, |LoRA(x)|=6.167, B≠0=12283
distilbert.transformer.layer.0.attention.v_lin: |A|=0.1993, |B|=0.0002149, |∇A|=0.0005936, |∇B|=0.2469, |LoRA(x)|=5.135, B≠0=12283
distilbert.transformer.layer.0.attention.out_lin: |A|=0.2004, |B|=0.0002184, |∇A|=0.0005894, |∇B|=0.2659, |LoRA(x)|=2.72, B≠0=12283
distilbert.transformer.layer.0.ffn.lin1: |A|=0.1998, |B|=0.0002434, |∇A|=0.0006892, |∇B|=0.1599, |LoRA(x)|=17.21, B≠0=49132
distilbert.transformer.layer.0.ffn.lin2: |A|=0.1996, |B|=0.0002225, |∇A|=0.0002834, |∇B|=0.2989, |LoRA(x)|=4.928, B≠0=12283
distilbert.transformer.layer.1.attention.q_lin: |A|=0.2, |B|=0.0002619, |∇A|=0.0003127, |∇B|=0.1176, |LoRA(x)|=6.326, B≠0=12283
distilbert.transformer.layer.1.attention.k_lin: |A|=0.2007, |B|=0.0002593, |∇A|=0.0003151, |∇B|=0.1215, |LoRA(x)|=5.5, B≠0=12283
distilbert.transformer.layer.1.attention.v_lin: |A|=0.1978, |B|=0.0001929, |∇A|=0.0004882, |∇B|=0.257, |LoRA(x)|=3.703, B≠0=12282
distilbert.transformer.layer.1.attention.out_lin: |A|=0.1985, |B|=0.0002113, |∇A|=0.0005305, |∇B|=0.2663, |LoRA(x)|=2.635, B≠0=12283
distilbert.transformer.layer.1.ffn.lin1: |A|=0.198, |B|=0.0002375, |∇A|=0.0006094, |∇B|=0.1708, |LoRA(x)|=21.37, B≠0=49133
distilbert.transformer.layer.1.ffn.lin2: |A|=0.1989, |B|=0.0002082, |∇A|=0.0002324, |∇B|=0.2649, |LoRA(x)|=2.712, B≠0=12283
distilbert.transformer.layer.2.attention.q_lin: |A|=0.198, |B|=0.0002454, |∇A|=0.0003173, |∇B|=0.1344, |LoRA(x)|=6.335, B≠0=12283
distilbert.transformer.layer.2.attention.k_lin: |A|=0.2002, |B|=0.0002375, |∇A|=0.0003538, |∇B|=0.1436, |LoRA(x)|=5.793, B≠0=12283
distilbert.transformer.layer.2.attention.v_lin: |A|=0.2002, |B|=0.0001974, |∇A|=0.0004345, |∇B|=0.2487, |LoRA(x)|=4.533, B≠0=12282
distilbert.transformer.layer.2.attention.out_lin: |A|=0.1984, |B|=0.0001982, |∇A|=0.0004565, |∇B|=0.2209, |LoRA(x)|=2.561, B≠0=12283
distilbert.transformer.layer.2.ffn.lin1: |A|=0.1982, |B|=0.0002298, |∇A|=0.0005989, |∇B|=0.1531, |LoRA(x)|=20.14, B≠0=49131
distilbert.transformer.layer.2.ffn.lin2: |A|=0.1999, |B|=0.0002027, |∇A|=0.0002184, |∇B|=0.2143, |LoRA(x)|=2.758, B≠0=12282
distilbert.transformer.layer.3.attention.q_lin: |A|=0.2006, |B|=0.0002525, |∇A|=0.0002761, |∇B|=0.1089, |LoRA(x)|=6.166, B≠0=12283
distilbert.transformer.layer.3.attention.k_lin: |A|=0.2011, |B|=0.0002438, |∇A|=0.0003425, |∇B|=0.1276, |LoRA(x)|=6.506, B≠0=12283
distilbert.transformer.layer.3.attention.v_lin: |A|=0.2007, |B|=0.0001687, |∇A|=0.0003466, |∇B|=0.2575, |LoRA(x)|=4.666, B≠0=12282
distilbert.transformer.layer.3.attention.out_lin: |A|=0.1995, |B|=0.0001906, |∇A|=0.0004309, |∇B|=0.171, |LoRA(x)|=3.289, B≠0=12282
distilbert.transformer.layer.3.ffn.lin1: |A|=0.1995, |B|=0.0002211, |∇A|=0.0005016, |∇B|=0.1029, |LoRA(x)|=14.59, B≠0=49131
distilbert.transformer.layer.3.ffn.lin2: |A|=0.1998, |B|=0.0001856, |∇A|=0.0002331, |∇B|=0.1571, |LoRA(x)|=2.627, B≠0=12282
distilbert.transformer.layer.4.attention.q_lin: |A|=0.2008, |B|=0.0002418, |∇A|=0.0001731, |∇B|=0.08101, |LoRA(x)|=10.64, B≠0=12283
distilbert.transformer.layer.4.attention.k_lin: |A|=0.2, |B|=0.0002293, |∇A|=0.0003568, |∇B|=0.08671, |LoRA(x)|=6.066, B≠0=12283
distilbert.transformer.layer.4.attention.v_lin: |A|=0.1993, |B|=0.0001757, |∇A|=0.0003265, |∇B|=0.1101, |LoRA(x)|=4.307, B≠0=12282
distilbert.transformer.layer.4.attention.out_lin: |A|=0.1999, |B|=0.000175, |∇A|=0.0003284, |∇B|=0.07959, |LoRA(x)|=2.766, B≠0=12282
distilbert.transformer.layer.4.ffn.lin1: |A|=0.2004, |B|=0.0001862, |∇A|=0.0002638, |∇B|=0.04591, |LoRA(x)|=16.07, B≠0=49129
distilbert.transformer.layer.4.ffn.lin2: |A|=0.1994, |B|=0.0001647, |∇A|=0.0001085, |∇B|=0.08194, |LoRA(x)|=2.308, B≠0=12281
distilbert.transformer.layer.5.attention.q_lin: |A|=0.1998, |B|=0.0002014, |∇A|=9.653e-05, |∇B|=0.03063, |LoRA(x)|=7.295, B≠0=12282
distilbert.transformer.layer.5.attention.k_lin: |A|=0.201, |B|=0.0002036, |∇A|=0.000199, |∇B|=0.03414, |LoRA(x)|=6.269, B≠0=12283
distilbert.transformer.layer.5.attention.v_lin: |A|=0.1987, |B|=0.0001532, |∇A|=0.0002869, |∇B|=0.05179, |LoRA(x)|=5.935, B≠0=12282
distilbert.transformer.layer.5.attention.out_lin: |A|=0.1994, |B|=0.0001674, |∇A|=0.0001988, |∇B|=0.04984, |LoRA(x)|=8.743, B≠0=12282
distilbert.transformer.layer.5.ffn.lin1: |A|=0.1993, |B|=0.0001539, |∇A|=0.0001583, |∇B|=0.02406, |LoRA(x)|=18.93, B≠0=49125
distilbert.transformer.layer.5.ffn.lin2: |A|=0.1994, |B|=0.0001408, |∇A|=5.189e-05, |∇B|=0.04019, |LoRA(x)|=5.108, B≠0=12281

Training summary¶

With dropout = 0.4, applied after the down-projection by A and before the up-projection by B, B is forced to adapt more robustly over time, with no significant loss of accuracy or F1. In effect, the dropout acts as a regulariser, injecting noise into the adapter path.

Key trends in |∇B|: dropout = 0.4 amplifies |∇B| in almost every layer, sometimes by 30–100% depending on depth and layer type, so the LoRA updates to B are more aggressive when dropout is used. The larger updates are most likely a response to the extra noise and sparsity introduced during training. In short, dropout boosts the ∇B gradient magnitudes that drive fast adaptation, capacity and expressivity (a toy check of this mechanism follows below).
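
A toy forward/backward pass illustrates the rescaling mechanism behind this; a minimal sketch reusing the LoRALayer class defined earlier (the tensor shapes and the surrogate loss are illustrative only, not taken from the training runs):

import torch

torch.manual_seed(0)
x = torch.randn(8, 384, 768)   # (batch, seq_len, hidden), DistilBERT-sized activations
g = torch.randn(8, 384, 768)   # stand-in for an upstream gradient signal

for p in (0.0, 0.4):
    layer = LoRALayer(in_dim=768, out_dim=768, rank=16, alpha=32, dropout_rate=p)
    layer.train()              # keep dropout active
    loss = (layer(x) * g).mean()
    loss.backward()
    # Inverted dropout rescales the kept activations by 1/(1 - p), which inflates
    # the gradient magnitude reaching B.
    print(f"dropout={p}: |grad B| mean = {layer.B.grad.abs().mean():.3e}")
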

Initialisation of lora.A¶

Matrix A maps from the high-dimensional space down to the low-rank space: it compresses features, and its output is often followed by dropout, which adds further noise. Matrix B does the heavy lifting, bringing the compressed signal back up to the model dimension. Because LoRA is typically trained with a low learning rate and the product A @ B is scaled (by alpha in this implementation, or by alpha / rank in the standard formulation), small differences in A's initialisation get washed out within the first training steps. Changing A's init therefore mostly affects early training (learning speed, initial activation variance) but barely shifts final performance. By contrast, a poorly initialised B immediately injects noise into the frozen model's internal representations. For this reason B is initialised as nn.Parameter(1e-6 * torch.randn(rank, out_dim) * std_dev). The common alternative of all zeros (self.B = nn.Parameter(torch.zeros(rank, out_dim))) often works well too, but not always, since a perfectly symmetric start can sometimes hurt optimisation. A sketch comparing the A initialisations explored below follows.
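
For reference, the A initialisations compared in the next few cells can be collected in one place. The helper below (init_lora_A is a hypothetical name, not used elsewhere in this notebook) is only a rough scale check; the |A| magnitudes it prints roughly line up with the per-layer statistics logged after each run:

import math
import torch
import torch.nn as nn

def init_lora_A(in_dim, out_dim, rank, option="gaussian"):
    """Collect the A initialisation variants tried in this notebook (illustrative only)."""
    if option == "gaussian":       # baseline: N(0, 1) scaled by 1/sqrt(rank)
        return nn.Parameter(torch.randn(in_dim, rank) / math.sqrt(rank))
    if option == "kaiming_like":   # next cell: std = sqrt(2 / (in_dim + out_dim))
        return nn.Parameter(torch.randn(in_dim, rank) * math.sqrt(2 / (in_dim + out_dim)))
    if option == "scaled_xavier":  # uniform in +/- 1.5 * sqrt(6 / (rank + in_dim))
        bound = 1.5 * math.sqrt(6 / (rank + in_dim))
        return nn.Parameter(torch.empty(in_dim, rank).uniform_(-bound, bound))
    if option == "orthogonal":     # orthonormal columns
        A = nn.Parameter(torch.empty(in_dim, rank))
        nn.init.orthogonal_(A)
        return A
    raise ValueError(f"unknown option: {option}")

# Rough |A| scales for a DistilBERT-sized projection (768 -> 768, rank 16)
for opt in ("gaussian", "kaiming_like", "scaled_xavier", "orthogonal"):
    A = init_lora_A(768, 768, 16, opt)
    print(f"{opt:>14}: |A| mean = {A.abs().mean():.4f}")
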

Kaiming for A¶

In [7]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
torch.autograd.set_detect_anomaly(True)

class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha, dropout_rate=0.0):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        #self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        # Kaiming-style initialisation for A
        # (note: this uses in_dim + out_dim, so the std matches Xavier/Glorot normal
        #  rather than the fan_in-only textbook Kaiming formula)
        std_dev2 = math.sqrt(2 / (in_dim + out_dim))
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev2)
        self.B = nn.Parameter(1e-6 * torch.randn(rank, out_dim) * std_dev)  # Not all zeroes!
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # Dropout applied to the projection to the lower-dimensional space by A
        dropped = self.dropout(x @ self.A)
        return self.alpha * (dropped @ self.B)

class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        self.last_lora_output_norm = 0.0  # for monitoring

    def forward(self, x):
        #return self.linear(x) + self.lora(x)
        lora_out = self.lora(x)
        self.last_lora_output_norm = lora_out.norm(p=2, dim=-1).mean().item()
        return self.linear(x) + lora_out


# Function to inject LoRA into specified linear layers
def inject_lora_all_attn(model, rank, alpha, dropout_rate=0.0):
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin", "ffn.lin1", "ffn.lin2"]
    # target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            lora_linear = LinearWithLoRA(original_linear, rank, alpha)
            lora_linear.lora.dropout = nn.Dropout(dropout_rate)
            
            setattr(parent_module, name.split('.')[-1], lora_linear)
    return model
In [8]:
dropout = 0.0
learning_rate = 1.5e-5
lora_rank = 16
lora_alpha = 32
weight_decay = 1e-5 # L2
batch_size = 32
output_dir_prefix = "finetuned-imdb-"

import copy
from transformers import TrainingArguments
torch.manual_seed(137)
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha, dropout)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)

total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
#print("\nTrainable parameters after freezing:")
#for name, param in model_lora_all_attn.named_parameters():
#    if param.requires_grad:
#        print(name)

eval_steps = 50
logging_steps = 50

training_args_lora_all_attn = TrainingArguments(
    output_dir=f"{output_dir_prefix}lora-all-attn",
    num_train_epochs=1,
    #max_steps=200,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    logging_steps=logging_steps,
    save_steps=eval_steps,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    disable_tqdm=False,
    push_to_hub=False,
    max_grad_norm=1.0,  # clip gradients to unit norm
    report_to="none",
    log_level="error"
)

    
trainer_lora_all_attn = Trainer(
    model=model_lora_all_attn,
    args=training_args_lora_all_attn,
    train_dataset=dataset_encoded["train"],
    eval_dataset=dataset_encoded["validation"],
    compute_metrics=compute_metrics,
)


hooks, monitor = monitor_lora_parameters(trainer_lora_all_attn.model)

#Train!
trainer_lora_all_attn.train()

#eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
#print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
#print (torch.cuda.memory_summary())

for hook in hooks:
    hook.remove()

# Aggregate/log after training
from collections import defaultdict
agg = defaultdict(list)
for key, vals in monitor.items():
    grouped = defaultdict(list)
    for name, val in vals:
        grouped[name].append(val)
    agg[key] = {name: sum(vs)/len(vs) for name, vs in grouped.items()}

for name in agg["A_abs_mean"]:
    print(f"{name}: |A|={agg['A_abs_mean'][name]:.4g}, |B|={agg['B_abs_mean'][name]:.4g}, "
          f"|∇A|={agg['A_grad_mean'][name]:.4g}, |∇B|={agg['B_grad_mean'][name]:.4g}, "
          f"|LoRA(x)|={agg['lora_output_norm'][name]:.4g}, B≠0={agg['B_nonzero_count'][name]:.0f}")


#print('Parameter Statistics: mean.abs()')
#for name, param in model_lora_all_attn.named_parameters():
#    if "lora" in name:
#        print(name, param.abs().mean().item())
#print('Parameter Statistics: param.norm()')
#for name, param in model_lora_all_attn.named_parameters():
#    if "lora" in name:
#        print(f"{name} weight norm: {param.norm().item():.4f}")
LoRA (All Attention) - Total parameters: 68,282,114
LoRA (All Attention) - Trainable parameters: 1,919,234 (2.81%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
[782/782 28:34, Epoch 1/1]
Step Training Loss Validation Loss Accuracy F1
50 0.641200 0.498864 0.825600 0.825801
100 0.398900 0.378827 0.832000 0.831737
150 0.342200 0.294475 0.880000 0.880212
200 0.277900 0.277569 0.884800 0.884265
250 0.299900 0.280815 0.888000 0.888238
300 0.267600 0.269878 0.896800 0.896389
350 0.277600 0.252492 0.896800 0.896903
400 0.263900 0.252141 0.903200 0.903285
450 0.251700 0.254847 0.900000 0.900150
500 0.270300 0.241287 0.904000 0.903867
550 0.253100 0.240126 0.901600 0.901454
600 0.271100 0.235365 0.904000 0.903921
650 0.241800 0.243780 0.903200 0.903353
700 0.233400 0.235918 0.905600 0.905600
750 0.235500 0.234417 0.907200 0.907186

distilbert.transformer.layer.0.attention.q_lin: |A|=0.02909, |B|=0.0002327, |∇A|=0.0002211, |∇B|=0.007083, |LoRA(x)|=0.6601, B≠0=12282
distilbert.transformer.layer.0.attention.k_lin: |A|=0.02878, |B|=0.0002298, |∇A|=0.0001474, |∇B|=0.007112, |LoRA(x)|=0.7828, B≠0=12281
distilbert.transformer.layer.0.attention.v_lin: |A|=0.02878, |B|=0.0002287, |∇A|=0.0008088, |∇B|=0.01975, |LoRA(x)|=0.7497, B≠0=12282
distilbert.transformer.layer.0.attention.out_lin: |A|=0.02893, |B|=0.0002272, |∇A|=0.001325, |∇B|=0.02585, |LoRA(x)|=0.4959, B≠0=12282
distilbert.transformer.layer.0.ffn.lin1: |A|=0.01825, |B|=0.000217, |∇A|=0.0006296, |∇B|=0.009524, |LoRA(x)|=2.611, B≠0=49126
distilbert.transformer.layer.0.ffn.lin2: |A|=0.01823, |B|=0.0002204, |∇A|=0.0003984, |∇B|=0.01871, |LoRA(x)|=0.5402, B≠0=12282
distilbert.transformer.layer.1.attention.q_lin: |A|=0.02888, |B|=0.0002127, |∇A|=0.000235, |∇B|=0.008725, |LoRA(x)|=0.8818, B≠0=12280
distilbert.transformer.layer.1.attention.k_lin: |A|=0.02899, |B|=0.0002166, |∇A|=0.000186, |∇B|=0.007929, |LoRA(x)|=0.5567, B≠0=12281
distilbert.transformer.layer.1.attention.v_lin: |A|=0.02856, |B|=0.000197, |∇A|=0.0005595, |∇B|=0.03967, |LoRA(x)|=0.7959, B≠0=12281
distilbert.transformer.layer.1.attention.out_lin: |A|=0.02866, |B|=0.0002057, |∇A|=0.00109, |∇B|=0.03448, |LoRA(x)|=0.5205, B≠0=12282
distilbert.transformer.layer.1.ffn.lin1: |A|=0.0181, |B|=0.0002256, |∇A|=0.0005969, |∇B|=0.01394, |LoRA(x)|=4.284, B≠0=49126
distilbert.transformer.layer.1.ffn.lin2: |A|=0.01816, |B|=0.0001935, |∇A|=0.0003225, |∇B|=0.01989, |LoRA(x)|=0.2153, B≠0=12281
distilbert.transformer.layer.2.attention.q_lin: |A|=0.02859, |B|=0.0002196, |∇A|=0.0002355, |∇B|=0.01189, |LoRA(x)|=1.088, B≠0=12282
distilbert.transformer.layer.2.attention.k_lin: |A|=0.02892, |B|=0.0002278, |∇A|=0.0002871, |∇B|=0.01122, |LoRA(x)|=0.6679, B≠0=12282
distilbert.transformer.layer.2.attention.v_lin: |A|=0.02891, |B|=0.000196, |∇A|=0.0007276, |∇B|=0.03889, |LoRA(x)|=0.7448, B≠0=12281
distilbert.transformer.layer.2.attention.out_lin: |A|=0.02864, |B|=0.0002144, |∇A|=0.001153, |∇B|=0.03337, |LoRA(x)|=0.521, B≠0=12282
distilbert.transformer.layer.2.ffn.lin1: |A|=0.01813, |B|=0.0002415, |∇A|=0.001008, |∇B|=0.01763, |LoRA(x)|=5.14, B≠0=49126
distilbert.transformer.layer.2.ffn.lin2: |A|=0.01825, |B|=0.0002205, |∇A|=0.0005423, |∇B|=0.02212, |LoRA(x)|=0.3106, B≠0=12282
distilbert.transformer.layer.3.attention.q_lin: |A|=0.02898, |B|=0.0002311, |∇A|=0.000282, |∇B|=0.009444, |LoRA(x)|=1.016, B≠0=12281
distilbert.transformer.layer.3.attention.k_lin: |A|=0.02904, |B|=0.0002356, |∇A|=0.0002989, |∇B|=0.009808, |LoRA(x)|=0.9602, B≠0=12282
distilbert.transformer.layer.3.attention.v_lin: |A|=0.02898, |B|=0.0002056, |∇A|=0.0007878, |∇B|=0.05867, |LoRA(x)|=1.151, B≠0=12282
distilbert.transformer.layer.3.attention.out_lin: |A|=0.02881, |B|=0.0002671, |∇A|=0.002097, |∇B|=0.02614, |LoRA(x)|=0.4825, B≠0=12282
distilbert.transformer.layer.3.ffn.lin1: |A|=0.01824, |B|=0.0002449, |∇A|=0.001385, |∇B|=0.01429, |LoRA(x)|=3.554, B≠0=49127
distilbert.transformer.layer.3.ffn.lin2: |A|=0.01825, |B|=0.0002515, |∇A|=0.0008056, |∇B|=0.02044, |LoRA(x)|=0.3727, B≠0=12282
distilbert.transformer.layer.4.attention.q_lin: |A|=0.02899, |B|=0.0002635, |∇A|=0.000508, |∇B|=0.01414, |LoRA(x)|=2.408, B≠0=12282
distilbert.transformer.layer.4.attention.k_lin: |A|=0.02889, |B|=0.00025, |∇A|=0.000996, |∇B|=0.01001, |LoRA(x)|=1.021, B≠0=12282
distilbert.transformer.layer.4.attention.v_lin: |A|=0.02878, |B|=0.0002282, |∇A|=0.00123, |∇B|=0.03556, |LoRA(x)|=1.406, B≠0=12282
distilbert.transformer.layer.4.attention.out_lin: |A|=0.02887, |B|=0.000246, |∇A|=0.001401, |∇B|=0.0175, |LoRA(x)|=0.645, B≠0=12282
distilbert.transformer.layer.4.ffn.lin1: |A|=0.01833, |B|=0.0002607, |∇A|=0.0008604, |∇B|=0.009056, |LoRA(x)|=5.717, B≠0=49129
distilbert.transformer.layer.4.ffn.lin2: |A|=0.01822, |B|=0.0002721, |∇A|=0.0004477, |∇B|=0.01619, |LoRA(x)|=0.7457, B≠0=12282
distilbert.transformer.layer.5.attention.q_lin: |A|=0.02888, |B|=0.0002728, |∇A|=0.0002167, |∇B|=0.00796, |LoRA(x)|=1.832, B≠0=12282
distilbert.transformer.layer.5.attention.k_lin: |A|=0.02903, |B|=0.0002349, |∇A|=0.0007984, |∇B|=0.004929, |LoRA(x)|=2.066, B≠0=12281
distilbert.transformer.layer.5.attention.v_lin: |A|=0.02869, |B|=0.0002458, |∇A|=0.001274, |∇B|=0.01634, |LoRA(x)|=1.781, B≠0=12282
distilbert.transformer.layer.5.attention.out_lin: |A|=0.02881, |B|=0.0003804, |∇A|=0.001541, |∇B|=0.008762, |LoRA(x)|=1.76, B≠0=12283
distilbert.transformer.layer.5.ffn.lin1: |A|=0.01824, |B|=0.0002594, |∇A|=0.001008, |∇B|=0.003997, |LoRA(x)|=6.319, B≠0=49127
distilbert.transformer.layer.5.ffn.lin2: |A|=0.01822, |B|=0.0003608, |∇A|=0.0004907, |∇B|=0.005384, |LoRA(x)|=1.169, B≠0=12283

Scaled Xavier for A¶

In [9]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
torch.autograd.set_detect_anomaly(True)

class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha, dropout_rate=0.0):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        #self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        # Scaled Xavier initialisation for A
        sc_f = 1.5
        bound = math.sqrt(6 / (rank + in_dim)) * sc_f
        self.A = nn.Parameter(torch.empty(in_dim, rank).uniform_(-bound, bound))
        self.B = nn.Parameter(1e-6 * torch.randn(rank, out_dim) * std_dev)  # Not all zeroes!
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # Dropout applied to the projection to the lower-dimensional space by A
        dropped = self.dropout(x @ self.A)
        return self.alpha * (dropped @ self.B)

class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        self.last_lora_output_norm = 0.0  # for monitoring

    def forward(self, x):
        #return self.linear(x) + self.lora(x)
        lora_out = self.lora(x)
        self.last_lora_output_norm = lora_out.norm(p=2, dim=-1).mean().item()
        return self.linear(x) + lora_out


# Function to inject LoRA into specified linear layers
def inject_lora_all_attn(model, rank, alpha, dropout_rate=0.0):
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin", "ffn.lin1", "ffn.lin2"]
    # target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            lora_linear = LinearWithLoRA(original_linear, rank, alpha)
            lora_linear.lora.dropout = nn.Dropout(dropout_rate)
            
            setattr(parent_module, name.split('.')[-1], lora_linear)
    return model
In [10]:
dropout = 0.0
learning_rate = 1.5e-5
lora_rank = 16
lora_alpha = 32
weight_decay = 1e-5 # L2
batch_size = 32
output_dir_prefix = "finetuned-imdb-"

import copy
from transformers import TrainingArguments
torch.manual_seed(137)
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha, dropout)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)

total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
#print("\nTrainable parameters after freezing:")
#for name, param in model_lora_all_attn.named_parameters():
#    if param.requires_grad:
#        print(name)

eval_steps = 50
logging_steps = 50

training_args_lora_all_attn = TrainingArguments(
    output_dir=f"{output_dir_prefix}lora-all-attn",
    num_train_epochs=1,
    #max_steps=200,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    logging_steps=logging_steps,
    save_steps=eval_steps,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    disable_tqdm=False,
    push_to_hub=False,
    max_grad_norm=1.0,  # clip gradients to unit norm
    report_to="none",
    log_level="error"
)

    
trainer_lora_all_attn = Trainer(
    model=model_lora_all_attn,
    args=training_args_lora_all_attn,
    train_dataset=dataset_encoded["train"],
    eval_dataset=dataset_encoded["validation"],
    compute_metrics=compute_metrics,
)


hooks, monitor = monitor_lora_parameters(trainer_lora_all_attn.model)

#Train!
trainer_lora_all_attn.train()

#eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
#print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
#print (torch.cuda.memory_summary())

for hook in hooks:
    hook.remove()

# Aggregate/log after training
from collections import defaultdict
agg = defaultdict(list)
for key, vals in monitor.items():
    grouped = defaultdict(list)
    for name, val in vals:
        grouped[name].append(val)
    agg[key] = {name: sum(vs)/len(vs) for name, vs in grouped.items()}

for name in agg["A_abs_mean"]:
    print(f"{name}: |A|={agg['A_abs_mean'][name]:.4g}, |B|={agg['B_abs_mean'][name]:.4g}, "
          f"|∇A|={agg['A_grad_mean'][name]:.4g}, |∇B|={agg['B_grad_mean'][name]:.4g}, "
          f"|LoRA(x)|={agg['lora_output_norm'][name]:.4g}, B≠0={agg['B_nonzero_count'][name]:.0f}")


#print('Parameter Statistics: mean.abs()')
#for name, param in model_lora_all_attn.named_parameters():
#    if "lora" in name:
#        print(name, param.abs().mean().item())
#print('Parameter Statistics: param.norm()')
#for name, param in model_lora_all_attn.named_parameters():
#    if "lora" in name:
#        print(f"{name} weight norm: {param.norm().item():.4f}")
LoRA (All Attention) - Total parameters: 68,282,114
LoRA (All Attention) - Trainable parameters: 1,919,234 (2.81%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
[782/782 28:29, Epoch 1/1]
Step Training Loss Validation Loss Accuracy F1
50 0.559900 0.342404 0.867200 0.866935
100 0.335600 0.298127 0.878400 0.878565
150 0.330100 0.269384 0.892000 0.891881
200 0.256100 0.272999 0.888800 0.888328
250 0.297100 0.247388 0.900800 0.900662
300 0.267600 0.260106 0.892000 0.891542
350 0.270000 0.245957 0.905600 0.905600
400 0.255400 0.242078 0.904800 0.904934
450 0.244500 0.241970 0.912000 0.911895
500 0.257700 0.235908 0.908800 0.908494
550 0.241300 0.237024 0.910400 0.910055
600 0.257700 0.227201 0.913600 0.913390
650 0.231800 0.231505 0.910400 0.910538
700 0.227300 0.224408 0.918400 0.918237
750 0.222500 0.222870 0.916000 0.915892

distilbert.transformer.layer.0.attention.q_lin: |A|=0.06532, |B|=0.0002035, |∇A|=0.0001663, |∇B|=0.01523, |LoRA(x)|=1.255, B≠0=12281
distilbert.transformer.layer.0.attention.k_lin: |A|=0.06534, |B|=0.0002003, |∇A|=0.0001101, |∇B|=0.01507, |LoRA(x)|=1.249, B≠0=12281
distilbert.transformer.layer.0.attention.v_lin: |A|=0.06566, |B|=0.0001929, |∇A|=0.0006117, |∇B|=0.03535, |LoRA(x)|=1.217, B≠0=12281
distilbert.transformer.layer.0.attention.out_lin: |A|=0.0661, |B|=0.0001711, |∇A|=0.0007501, |∇B|=0.04806, |LoRA(x)|=0.6202, B≠0=12281
distilbert.transformer.layer.0.ffn.lin1: |A|=0.06538, |B|=0.0001835, |∇A|=0.0004738, |∇B|=0.0315, |LoRA(x)|=6.167, B≠0=49123
distilbert.transformer.layer.0.ffn.lin2: |A|=0.03304, |B|=0.0001718, |∇A|=0.000289, |∇B|=0.02706, |LoRA(x)|=0.4766, B≠0=12280
distilbert.transformer.layer.1.attention.q_lin: |A|=0.06618, |B|=0.0001879, |∇A|=0.0001798, |∇B|=0.01696, |LoRA(x)|=1.24, B≠0=12280
distilbert.transformer.layer.1.attention.k_lin: |A|=0.06591, |B|=0.000191, |∇A|=0.0001451, |∇B|=0.01575, |LoRA(x)|=0.9638, B≠0=12281
distilbert.transformer.layer.1.attention.v_lin: |A|=0.06548, |B|=0.0001601, |∇A|=0.0004498, |∇B|=0.06403, |LoRA(x)|=0.9895, B≠0=12280
distilbert.transformer.layer.1.attention.out_lin: |A|=0.06559, |B|=0.0001658, |∇A|=0.0006209, |∇B|=0.05379, |LoRA(x)|=0.5784, B≠0=12281
distilbert.transformer.layer.1.ffn.lin1: |A|=0.06538, |B|=0.0001922, |∇A|=0.0004596, |∇B|=0.0344, |LoRA(x)|=8.039, B≠0=49123
distilbert.transformer.layer.1.ffn.lin2: |A|=0.03304, |B|=0.0001619, |∇A|=0.0002367, |∇B|=0.0329, |LoRA(x)|=0.297, B≠0=12280
distilbert.transformer.layer.2.attention.q_lin: |A|=0.06613, |B|=0.0001945, |∇A|=0.0002159, |∇B|=0.0249, |LoRA(x)|=1.65, B≠0=12281
distilbert.transformer.layer.2.attention.k_lin: |A|=0.06571, |B|=0.0001978, |∇A|=0.0002439, |∇B|=0.02332, |LoRA(x)|=1.206, B≠0=12281
distilbert.transformer.layer.2.attention.v_lin: |A|=0.06583, |B|=0.0001605, |∇A|=0.0005263, |∇B|=0.06787, |LoRA(x)|=0.9769, B≠0=12280
distilbert.transformer.layer.2.attention.out_lin: |A|=0.06489, |B|=0.0001754, |∇A|=0.0007919, |∇B|=0.04553, |LoRA(x)|=0.5327, B≠0=12280
distilbert.transformer.layer.2.ffn.lin1: |A|=0.06513, |B|=0.0001948, |∇A|=0.0006043, |∇B|=0.04044, |LoRA(x)|=7.488, B≠0=49124
distilbert.transformer.layer.2.ffn.lin2: |A|=0.03309, |B|=0.0001835, |∇A|=0.0004721, |∇B|=0.02781, |LoRA(x)|=0.3463, B≠0=12281
distilbert.transformer.layer.3.attention.q_lin: |A|=0.066, |B|=0.0002053, |∇A|=0.0002105, |∇B|=0.02061, |LoRA(x)|=2.099, B≠0=12281
distilbert.transformer.layer.3.attention.k_lin: |A|=0.06551, |B|=0.0002091, |∇A|=0.0002618, |∇B|=0.02079, |LoRA(x)|=1.555, B≠0=12281
distilbert.transformer.layer.3.attention.v_lin: |A|=0.06551, |B|=0.0001497, |∇A|=0.0004883, |∇B|=0.08872, |LoRA(x)|=1.46, B≠0=12280
distilbert.transformer.layer.3.attention.out_lin: |A|=0.06552, |B|=0.0001923, |∇A|=0.0008196, |∇B|=0.03743, |LoRA(x)|=0.5979, B≠0=12281
distilbert.transformer.layer.3.ffn.lin1: |A|=0.06555, |B|=0.000184, |∇A|=0.0006847, |∇B|=0.04193, |LoRA(x)|=8.602, B≠0=49123
distilbert.transformer.layer.3.ffn.lin2: |A|=0.03308, |B|=0.0001783, |∇A|=0.0004434, |∇B|=0.0232, |LoRA(x)|=0.3174, B≠0=12280
distilbert.transformer.layer.4.attention.q_lin: |A|=0.06642, |B|=0.0002137, |∇A|=0.0003046, |∇B|=0.02595, |LoRA(x)|=3.838, B≠0=12281
distilbert.transformer.layer.4.attention.k_lin: |A|=0.06605, |B|=0.0002044, |∇A|=0.000532, |∇B|=0.01982, |LoRA(x)|=2.01, B≠0=12281
distilbert.transformer.layer.4.attention.v_lin: |A|=0.0658, |B|=0.0001578, |∇A|=0.0006522, |∇B|=0.06364, |LoRA(x)|=1.776, B≠0=12280
distilbert.transformer.layer.4.attention.out_lin: |A|=0.0653, |B|=0.0001886, |∇A|=0.0008771, |∇B|=0.02987, |LoRA(x)|=0.784, B≠0=12281
distilbert.transformer.layer.4.ffn.lin1: |A|=0.06552, |B|=0.0002044, |∇A|=0.0006782, |∇B|=0.02294, |LoRA(x)|=9.608, B≠0=49125
distilbert.transformer.layer.4.ffn.lin2: |A|=0.03295, |B|=0.0002195, |∇A|=0.0004506, |∇B|=0.01567, |LoRA(x)|=0.4364, B≠0=12282
distilbert.transformer.layer.5.attention.q_lin: |A|=0.06579, |B|=0.000201, |∇A|=0.000149, |∇B|=0.01128, |LoRA(x)|=1.89, B≠0=12281
distilbert.transformer.layer.5.attention.k_lin: |A|=0.06566, |B|=0.0001923, |∇A|=0.0005654, |∇B|=0.01185, |LoRA(x)|=2.068, B≠0=12281
distilbert.transformer.layer.5.attention.v_lin: |A|=0.06563, |B|=0.0001716, |∇A|=0.0005953, |∇B|=0.03601, |LoRA(x)|=2.725, B≠0=12281
distilbert.transformer.layer.5.attention.out_lin: |A|=0.0654, |B|=0.0002853, |∇A|=0.0008672, |∇B|=0.01711, |LoRA(x)|=2.471, B≠0=12282
distilbert.transformer.layer.5.ffn.lin1: |A|=0.0657, |B|=0.0001973, |∇A|=0.0004453, |∇B|=0.009803, |LoRA(x)|=11.89, B≠0=49123
distilbert.transformer.layer.5.ffn.lin2: |A|=0.03308, |B|=0.0002221, |∇A|=0.0001933, |∇B|=0.008399, |LoRA(x)|=1.134, B≠0=12281

Orthogonal initialisation for A¶

In [11]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
torch.autograd.set_detect_anomaly(True)

class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha, dropout_rate=0.0):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        #self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        # Orthogonal initialisation for A
        self.A = nn.Parameter(torch.empty(in_dim, rank))
        nn.init.orthogonal_(self.A)
        self.B = nn.Parameter(1e-6 * torch.randn(rank, out_dim) * std_dev)  # Not all zeroes!
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # Dropout applied to the projection to the lower-dimensional space by A
        dropped = self.dropout(x @ self.A)
        return self.alpha * (dropped @ self.B)

class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        self.last_lora_output_norm = 0.0  # for monitoring

    def forward(self, x):
        #return self.linear(x) + self.lora(x)
        lora_out = self.lora(x)
        self.last_lora_output_norm = lora_out.norm(p=2, dim=-1).mean().item()
        return self.linear(x) + lora_out


# Function to inject LoRA into specified linear layers
def inject_lora_all_attn(model, rank, alpha, dropout_rate=0.0):
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin", "ffn.lin1", "ffn.lin2"]
    # target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            lora_linear = LinearWithLoRA(original_linear, rank, alpha)
            lora_linear.lora.dropout = nn.Dropout(dropout_rate)
            
            setattr(parent_module, name.split('.')[-1], lora_linear)
    return model
In [12]:
dropout = 0.0
learning_rate = 1.5e-5
lora_rank = 16
lora_alpha = 32
weight_decay = 1e-5 # L2
batch_size = 32
output_dir_prefix = "finetuned-imdb-"

import copy
from transformers import TrainingArguments
torch.manual_seed(137)
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha, dropout)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)

total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
#print("\nTrainable parameters after freezing:")
#for name, param in model_lora_all_attn.named_parameters():
#    if param.requires_grad:
#        print(name)

eval_steps = 50
logging_steps = 50

training_args_lora_all_attn = TrainingArguments(
    output_dir=f"{output_dir_prefix}lora-all-attn",
    num_train_epochs=1,
    #max_steps=200,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    logging_steps=logging_steps,
    save_steps=eval_steps,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    disable_tqdm=False,
    push_to_hub=False,
    max_grad_norm=1.0,  # clip gradients to unit norm
    report_to="none",
    log_level="error"
)

    
trainer_lora_all_attn = Trainer(
    model=model_lora_all_attn,
    args=training_args_lora_all_attn,
    train_dataset=dataset_encoded["train"],
    eval_dataset=dataset_encoded["validation"],
    compute_metrics=compute_metrics,
)


hooks, monitor = monitor_lora_parameters(trainer_lora_all_attn.model)

#Train!
trainer_lora_all_attn.train()

#eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
#print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
#print (torch.cuda.memory_summary())

for hook in hooks:
    hook.remove()

# Aggregate/log after training
from collections import defaultdict
agg = defaultdict(list)
for key, vals in monitor.items():
    grouped = defaultdict(list)
    for name, val in vals:
        grouped[name].append(val)
    agg[key] = {name: sum(vs)/len(vs) for name, vs in grouped.items()}

for name in agg["A_abs_mean"]:
    print(f"{name}: |A|={agg['A_abs_mean'][name]:.4g}, |B|={agg['B_abs_mean'][name]:.4g}, "
          f"|∇A|={agg['A_grad_mean'][name]:.4g}, |∇B|={agg['B_grad_mean'][name]:.4g}, "
          f"|LoRA(x)|={agg['lora_output_norm'][name]:.4g}, B≠0={agg['B_nonzero_count'][name]:.0f}")


#print('Parameter Statistics: mean.abs()')
#for name, param in model_lora_all_attn.named_parameters():
#    if "lora" in name:
#        print(name, param.abs().mean().item())
#print('Parameter Statistics: param.norm()')
#for name, param in model_lora_all_attn.named_parameters():
#    if "lora" in name:
#        print(f"{name} weight norm: {param.norm().item():.4f}")
LoRA (All Attention) - Total parameters: 68,282,114
LoRA (All Attention) - Trainable parameters: 1,919,234 (2.81%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
[782/782 28:32, Epoch 1/1]
Step Training Loss Validation Loss Accuracy F1
50 0.631800 0.476356 0.832000 0.832194
100 0.386100 0.353101 0.848000 0.848142
150 0.338900 0.293442 0.880800 0.881030
200 0.273900 0.276142 0.885600 0.885174
250 0.299300 0.277328 0.892000 0.892224
300 0.267800 0.270556 0.896000 0.895489
350 0.275800 0.252276 0.899200 0.899257
400 0.261600 0.252071 0.902400 0.902492
450 0.248400 0.254714 0.898400 0.898543
500 0.268600 0.241609 0.907200 0.907034
550 0.251900 0.238318 0.904800 0.904695
600 0.268800 0.233752 0.906400 0.906314
650 0.239300 0.242723 0.902400 0.902550
700 0.233100 0.234716 0.906400 0.906393
750 0.233200 0.233564 0.904800 0.904793

distilbert.transformer.layer.0.attention.q_lin: |A|=0.02894, |B|=0.0002256, |∇A|=0.0002051, |∇B|=0.00704, |LoRA(x)|=0.6222, B≠0=12282
distilbert.transformer.layer.0.attention.k_lin: |A|=0.0289, |B|=0.0002237, |∇A|=0.0001388, |∇B|=0.007028, |LoRA(x)|=0.7385, B≠0=12282
distilbert.transformer.layer.0.attention.v_lin: |A|=0.02868, |B|=0.0002192, |∇A|=0.0007466, |∇B|=0.01938, |LoRA(x)|=0.6943, B≠0=12281
distilbert.transformer.layer.0.attention.out_lin: |A|=0.02873, |B|=0.0002161, |∇A|=0.001234, |∇B|=0.02534, |LoRA(x)|=0.4499, B≠0=12281
distilbert.transformer.layer.0.ffn.lin1: |A|=0.02887, |B|=0.0002077, |∇A|=0.0005898, |∇B|=0.01437, |LoRA(x)|=3.549, B≠0=49125
distilbert.transformer.layer.0.ffn.lin2: |A|=0.01439, |B|=0.0002112, |∇A|=0.0003794, |∇B|=0.0146, |LoRA(x)|=0.4078, B≠0=12281
distilbert.transformer.layer.1.attention.q_lin: |A|=0.02871, |B|=0.0002067, |∇A|=0.0002277, |∇B|=0.008702, |LoRA(x)|=0.8418, B≠0=12281
distilbert.transformer.layer.1.attention.k_lin: |A|=0.02878, |B|=0.0002133, |∇A|=0.0001856, |∇B|=0.007924, |LoRA(x)|=0.5599, B≠0=12281
distilbert.transformer.layer.1.attention.v_lin: |A|=0.02885, |B|=0.0001863, |∇A|=0.0005139, |∇B|=0.03951, |LoRA(x)|=0.7551, B≠0=12281
distilbert.transformer.layer.1.attention.out_lin: |A|=0.02884, |B|=0.0001969, |∇A|=0.001035, |∇B|=0.03354, |LoRA(x)|=0.4694, B≠0=12281
distilbert.transformer.layer.1.ffn.lin1: |A|=0.02882, |B|=0.0002174, |∇A|=0.0005914, |∇B|=0.02101, |LoRA(x)|=5.848, B≠0=49126
distilbert.transformer.layer.1.ffn.lin2: |A|=0.01441, |B|=0.0001863, |∇A|=0.0003113, |∇B|=0.01551, |LoRA(x)|=0.1606, B≠0=12281
distilbert.transformer.layer.2.attention.q_lin: |A|=0.02863, |B|=0.0002147, |∇A|=0.0002332, |∇B|=0.01161, |LoRA(x)|=0.9744, B≠0=12281
distilbert.transformer.layer.2.attention.k_lin: |A|=0.02885, |B|=0.0002233, |∇A|=0.000279, |∇B|=0.01124, |LoRA(x)|=0.6466, B≠0=12282
distilbert.transformer.layer.2.attention.v_lin: |A|=0.02879, |B|=0.0001866, |∇A|=0.0006717, |∇B|=0.03784, |LoRA(x)|=0.6731, B≠0=12281
distilbert.transformer.layer.2.attention.out_lin: |A|=0.02879, |B|=0.0002054, |∇A|=0.001122, |∇B|=0.03079, |LoRA(x)|=0.4415, B≠0=12281
distilbert.transformer.layer.2.ffn.lin1: |A|=0.02879, |B|=0.0002269, |∇A|=0.0008912, |∇B|=0.02662, |LoRA(x)|=7.34, B≠0=49126
distilbert.transformer.layer.2.ffn.lin2: |A|=0.01442, |B|=0.0002097, |∇A|=0.0005021, |∇B|=0.01674, |LoRA(x)|=0.2238, B≠0=12281
distilbert.transformer.layer.3.attention.q_lin: |A|=0.02891, |B|=0.000226, |∇A|=0.0002603, |∇B|=0.009309, |LoRA(x)|=0.9878, B≠0=12281
distilbert.transformer.layer.3.attention.k_lin: |A|=0.02882, |B|=0.0002275, |∇A|=0.000281, |∇B|=0.009663, |LoRA(x)|=0.885, B≠0=12282
distilbert.transformer.layer.3.attention.v_lin: |A|=0.02887, |B|=0.0001937, |∇A|=0.0007371, |∇B|=0.0562, |LoRA(x)|=1.023, B≠0=12281
distilbert.transformer.layer.3.attention.out_lin: |A|=0.02891, |B|=0.0002549, |∇A|=0.001971, |∇B|=0.02536, |LoRA(x)|=0.4443, B≠0=12282
distilbert.transformer.layer.3.ffn.lin1: |A|=0.02888, |B|=0.0002292, |∇A|=0.001239, |∇B|=0.02104, |LoRA(x)|=4.993, B≠0=49126
distilbert.transformer.layer.3.ffn.lin2: |A|=0.0144, |B|=0.0002404, |∇A|=0.0007727, |∇B|=0.01577, |LoRA(x)|=0.2704, B≠0=12282
distilbert.transformer.layer.4.attention.q_lin: |A|=0.02871, |B|=0.0002524, |∇A|=0.000489, |∇B|=0.01352, |LoRA(x)|=2.283, B≠0=12282
distilbert.transformer.layer.4.attention.k_lin: |A|=0.02884, |B|=0.0002416, |∇A|=0.0009936, |∇B|=0.009785, |LoRA(x)|=0.9456, B≠0=12281
distilbert.transformer.layer.4.attention.v_lin: |A|=0.02885, |B|=0.0002123, |∇A|=0.00103, |∇B|=0.0365, |LoRA(x)|=1.395, B≠0=12281
distilbert.transformer.layer.4.attention.out_lin: |A|=0.0288, |B|=0.000233, |∇A|=0.001299, |∇B|=0.01674, |LoRA(x)|=0.5857, B≠0=12282
distilbert.transformer.layer.4.ffn.lin1: |A|=0.02887, |B|=0.0002406, |∇A|=0.0007501, |∇B|=0.01257, |LoRA(x)|=7.437, B≠0=49128
distilbert.transformer.layer.4.ffn.lin2: |A|=0.01443, |B|=0.0002652, |∇A|=0.0004402, |∇B|=0.01286, |LoRA(x)|=0.5918, B≠0=12282
distilbert.transformer.layer.5.attention.q_lin: |A|=0.02885, |B|=0.0002558, |∇A|=0.0001962, |∇B|=0.007697, |LoRA(x)|=1.709, B≠0=12281
distilbert.transformer.layer.5.attention.k_lin: |A|=0.02884, |B|=0.0002236, |∇A|=0.0007342, |∇B|=0.004787, |LoRA(x)|=1.879, B≠0=12281
distilbert.transformer.layer.5.attention.v_lin: |A|=0.02887, |B|=0.0002386, |∇A|=0.001155, |∇B|=0.01626, |LoRA(x)|=1.73, B≠0=12282
distilbert.transformer.layer.5.attention.out_lin: |A|=0.02883, |B|=0.0003673, |∇A|=0.001476, |∇B|=0.008586, |LoRA(x)|=1.781, B≠0=12283
distilbert.transformer.layer.5.ffn.lin1: |A|=0.02882, |B|=0.0002423, |∇A|=0.0007944, |∇B|=0.005601, |LoRA(x)|=9.013, B≠0=49127
distilbert.transformer.layer.5.ffn.lin2: |A|=0.0144, |B|=0.0003482, |∇A|=0.0004651, |∇B|=0.004582, |LoRA(x)|=0.9978, B≠0=12283

Increase the initialisation std. dev. for B and apply dropout to boost lora.B adaptation¶
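For reference, a minimal sketch (not from the notebook itself) of what this initialisation gives for B before any training, using the rank of 16 from the next cell and DistilBERT's hidden size of 768: the mean absolute value comes out around 2e-4, the same order of magnitude as the |B| values reported in the dump after training.

import torch

torch.manual_seed(0)              # arbitrary seed, illustration only
rank, out_dim = 16, 768           # rank as in the next cell; 768 = DistilBERT hidden size
std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
B0 = 1e-3 * torch.randn(rank, out_dim) * std_dev
print(f"initial |B| (mean abs): {B0.abs().mean().item():.2e}")   # ≈ 2e-4, small but non-zero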

In [13]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
torch.autograd.set_detect_anomaly(True)

class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha, dropout_rate=0.0):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(1e-3 * torch.randn(rank, out_dim) * std_dev)  # Increase std.dev for B
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # Dropout applied to the projection to the lower-dimensional space by A
        dropped = self.dropout(x @ self.A)
        return self.alpha * (dropped @ self.B)

class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        self.last_lora_output_norm = 0.0  # for monitoring

    def forward(self, x):
        #return self.linear(x) + self.lora(x)
        lora_out = self.lora(x)
        self.last_lora_output_norm = lora_out.norm(p=2, dim=-1).mean().item()
        return self.linear(x) + lora_out


# Function to inject LoRA into specified linear layers
def inject_lora_all_attn(model, rank, alpha, dropout_rate=0.0):
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin", "ffn.lin1", "ffn.lin2"]
    # target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            lora_linear = LinearWithLoRA(original_linear, rank, alpha)
            lora_linear.lora.dropout = nn.Dropout(dropout_rate)  # LinearWithLoRA does not pass dropout_rate through, so set it here
            
            setattr(parent_module, name.split('.')[-1], lora_linear)
    return model
In [14]:
dropout = 0.4
learning_rate = 1.5e-5
lora_rank = 16
lora_alpha = 32
weight_decay = 1e-5 # L2
batch_size = 32
output_dir_prefix = "finetuned-imdb-"

import copy
from transformers import TrainingArguments
torch.manual_seed(137)
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha, dropout)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)

total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
#print("\nTrainable parameters after freezing:")
#for name, param in model_lora_all_attn.named_parameters():
#    if param.requires_grad:
#        print(name)

eval_steps = 50
logging_steps = 50

training_args_lora_all_attn = TrainingArguments(
    output_dir=f"{output_dir_prefix}lora-all-attn",
    num_train_epochs=2,
    #max_steps=200,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    logging_steps=logging_steps,
    save_steps=eval_steps,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    disable_tqdm=False,
    push_to_hub=False,
    max_grad_norm=1.0,  # clip the global gradient norm to 1.0
    report_to="none",
    log_level="error"
)

    
trainer_lora_all_attn = Trainer(
    model=model_lora_all_attn,
    args=training_args_lora_all_attn,
    train_dataset=dataset_encoded["train"],
    eval_dataset=dataset_encoded["validation"],
    compute_metrics=compute_metrics,
)


hooks, monitor = monitor_lora_parameters(trainer_lora_all_attn.model)

#Train!
trainer_lora_all_attn.train()

eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
print (torch.cuda.memory_summary())

for hook in hooks:
    hook.remove()

# Aggregate/log after training
from collections import defaultdict
agg = defaultdict(list)
for key, vals in monitor.items():
    grouped = defaultdict(list)
    for name, val in vals:
        grouped[name].append(val)
    agg[key] = {name: sum(vs)/len(vs) for name, vs in grouped.items()}

for name in agg["A_abs_mean"]:
    print(f"{name}: |A|={agg['A_abs_mean'][name]:.4g}, |B|={agg['B_abs_mean'][name]:.4g}, "
          f"|∇A|={agg['A_grad_mean'][name]:.4g}, |∇B|={agg['B_grad_mean'][name]:.4g}, "
          f"|LoRA(x)|={agg['lora_output_norm'][name]:.4g}, B≠0={agg['B_nonzero_count'][name]:.0f}")


print('Parameter Statistics: mean.abs()')
for name, param in model_lora_all_attn.named_parameters():
    if "lora" in name:
        print(name, param.abs().mean().item())
print('Parameter Statistics: param.norm()')
for name, param in model_lora_all_attn.named_parameters():
    if "lora" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
LoRA (All Attention) - Total parameters: 68,282,114
LoRA (All Attention) - Trainable parameters: 1,919,234 (2.81%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
[1564/1564 1:01:22, Epoch 2/2]
Step Training Loss Validation Loss Accuracy F1
50 0.572400 0.361935 0.853600 0.853897
100 0.395200 0.314839 0.871200 0.871393
150 0.388700 0.324395 0.872000 0.872105
200 0.359600 0.288238 0.875200 0.875027
250 0.375300 0.355832 0.841600 0.841525
300 0.348600 0.279084 0.888000 0.887753
350 0.360400 0.307965 0.874400 0.874675
400 0.331700 0.285904 0.896800 0.897028
450 0.313200 0.311799 0.895200 0.895123
500 0.353700 0.291399 0.883200 0.882492
550 0.313900 0.274373 0.881600 0.881711
600 0.344400 0.289618 0.887200 0.886285
650 0.311900 0.266004 0.895200 0.895175
700 0.313900 0.260952 0.904800 0.904677
750 0.311800 0.276892 0.896800 0.896441
800 0.331900 0.266435 0.900000 0.899438
850 0.289800 0.267684 0.902400 0.902573
900 0.283600 0.294683 0.899200 0.899422
950 0.307300 0.265540 0.906400 0.906297
1000 0.270400 0.288483 0.900000 0.900111
1050 0.266700 0.258305 0.907200 0.906888
1100 0.276200 0.250781 0.905600 0.905555
1150 0.253400 0.249464 0.902400 0.902573
1200 0.278300 0.245663 0.912800 0.912704
1250 0.260800 0.250558 0.904800 0.904659
1300 0.266500 0.247603 0.907200 0.907308
1350 0.260500 0.247938 0.904800 0.904915
1400 0.250800 0.251896 0.904000 0.903701
1450 0.252900 0.245024 0.907200 0.907171
1500 0.255200 0.245073 0.904800 0.904639
1550 0.242100 0.244082 0.908800 0.908786

[743/743 03:03]
LoRA (All Attention) Test Results: {'eval_loss': 0.2366224229335785, 'eval_accuracy': 0.907621052631579, 'eval_f1': 0.9076122672728953, 'eval_runtime': 184.1411, 'eval_samples_per_second': 128.977, 'eval_steps_per_second': 4.035, 'epoch': 2.0}
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      | 563068 KiB |   4397 MiB | 232710 GiB | 232710 GiB |
|       from large pool | 546990 KiB |   4344 MiB | 230922 GiB | 230922 GiB |
|       from small pool |  16078 KiB |     55 MiB |   1788 GiB |   1788 GiB |
|---------------------------------------------------------------------------|
| Active memory         | 563068 KiB |   4397 MiB | 232710 GiB | 232710 GiB |
|       from large pool | 546990 KiB |   4344 MiB | 230922 GiB | 230922 GiB |
|       from small pool |  16078 KiB |     55 MiB |   1788 GiB |   1788 GiB |
|---------------------------------------------------------------------------|
| Requested memory      | 559912 KiB |   4394 MiB | 232500 GiB | 232499 GiB |
|       from large pool | 543836 KiB |   4340 MiB | 230718 GiB | 230718 GiB |
|       from small pool |  16076 KiB |     55 MiB |   1781 GiB |   1781 GiB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   4510 MiB |   4580 MiB |  23456 MiB |  18946 MiB |
|       from large pool |   4450 MiB |   4520 MiB |  23204 MiB |  18754 MiB |
|       from small pool |     60 MiB |     62 MiB |    252 MiB |    192 MiB |
|---------------------------------------------------------------------------|
| Non-releasable memory | 182404 KiB | 382935 KiB |  79215 GiB |  79215 GiB |
|       from large pool | 175954 KiB | 376402 KiB |  77291 GiB |  77290 GiB |
|       from small pool |   6450 KiB |  30566 KiB |   1924 GiB |   1924 GiB |
|---------------------------------------------------------------------------|
| Allocations           |     436    |     623    |   30226 K  |   30225 K  |
|       from large pool |      82    |     152    |    6772 K  |    6772 K  |
|       from small pool |     354    |     515    |   23453 K  |   23453 K  |
|---------------------------------------------------------------------------|
| Active allocs         |     436    |     623    |   30226 K  |   30225 K  |
|       from large pool |      82    |     152    |    6772 K  |    6772 K  |
|       from small pool |     354    |     515    |   23453 K  |   23453 K  |
|---------------------------------------------------------------------------|
| GPU reserved segments |      97    |     101    |     431    |     334    |
|       from large pool |      67    |      71    |     305    |     238    |
|       from small pool |      30    |      31    |     126    |      96    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |      46    |      73    |   16377 K  |   16377 K  |
|       from large pool |      16    |      22    |    2502 K  |    2502 K  |
|       from small pool |      30    |      57    |   13874 K  |   13874 K  |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

distilbert.transformer.layer.0.attention.q_lin: |A|=0.2014, |B|=0.0003389, |∇A|=0.000502, |∇B|=0.158, |LoRA(x)|=6.894, B≠0=12285
distilbert.transformer.layer.0.attention.k_lin: |A|=0.1991, |B|=0.0003358, |∇A|=0.0004346, |∇B|=0.1471, |LoRA(x)|=7.695, B≠0=12286
distilbert.transformer.layer.0.attention.v_lin: |A|=0.1993, |B|=0.000301, |∇A|=0.0007981, |∇B|=0.2967, |LoRA(x)|=6.75, B≠0=12285
distilbert.transformer.layer.0.attention.out_lin: |A|=0.2003, |B|=0.0003029, |∇A|=0.0007932, |∇B|=0.321, |LoRA(x)|=3.656, B≠0=12285
distilbert.transformer.layer.0.ffn.lin1: |A|=0.1998, |B|=0.0003268, |∇A|=0.0009316, |∇B|=0.1899, |LoRA(x)|=21.53, B≠0=49143
distilbert.transformer.layer.0.ffn.lin2: |A|=0.1996, |B|=0.0003015, |∇A|=0.0003772, |∇B|=0.3445, |LoRA(x)|=6.206, B≠0=12286
distilbert.transformer.layer.1.attention.q_lin: |A|=0.2, |B|=0.0003353, |∇A|=0.0003761, |∇B|=0.1395, |LoRA(x)|=7.866, B≠0=12285
distilbert.transformer.layer.1.attention.k_lin: |A|=0.2007, |B|=0.0003366, |∇A|=0.0004358, |∇B|=0.1484, |LoRA(x)|=6.999, B≠0=12285
distilbert.transformer.layer.1.attention.v_lin: |A|=0.1978, |B|=0.0002858, |∇A|=0.0006746, |∇B|=0.2917, |LoRA(x)|=5.321, B≠0=12285
distilbert.transformer.layer.1.attention.out_lin: |A|=0.1984, |B|=0.00029, |∇A|=0.0007089, |∇B|=0.3005, |LoRA(x)|=3.779, B≠0=12285
distilbert.transformer.layer.1.ffn.lin1: |A|=0.198, |B|=0.0003159, |∇A|=0.0007903, |∇B|=0.1846, |LoRA(x)|=23.45, B≠0=49142
distilbert.transformer.layer.1.ffn.lin2: |A|=0.1989, |B|=0.000288, |∇A|=0.0002804, |∇B|=0.2852, |LoRA(x)|=3.764, B≠0=12284
distilbert.transformer.layer.2.attention.q_lin: |A|=0.198, |B|=0.0003262, |∇A|=0.0003942, |∇B|=0.1447, |LoRA(x)|=8.057, B≠0=12286
distilbert.transformer.layer.2.attention.k_lin: |A|=0.2002, |B|=0.0003176, |∇A|=0.000448, |∇B|=0.1577, |LoRA(x)|=7.385, B≠0=12285
distilbert.transformer.layer.2.attention.v_lin: |A|=0.2002, |B|=0.0002813, |∇A|=0.0005137, |∇B|=0.2423, |LoRA(x)|=6.056, B≠0=12286
distilbert.transformer.layer.2.attention.out_lin: |A|=0.1983, |B|=0.0002805, |∇A|=0.0005144, |∇B|=0.2204, |LoRA(x)|=3.703, B≠0=12285
distilbert.transformer.layer.2.ffn.lin1: |A|=0.1982, |B|=0.0003084, |∇A|=0.0006538, |∇B|=0.15, |LoRA(x)|=24.32, B≠0=49141
distilbert.transformer.layer.2.ffn.lin2: |A|=0.1999, |B|=0.0002843, |∇A|=0.000218, |∇B|=0.2078, |LoRA(x)|=3.736, B≠0=12285
distilbert.transformer.layer.3.attention.q_lin: |A|=0.2006, |B|=0.000324, |∇A|=0.0002934, |∇B|=0.1105, |LoRA(x)|=8.299, B≠0=12285
distilbert.transformer.layer.3.attention.k_lin: |A|=0.2011, |B|=0.0003226, |∇A|=0.0003681, |∇B|=0.1199, |LoRA(x)|=7.896, B≠0=12285
distilbert.transformer.layer.3.attention.v_lin: |A|=0.2007, |B|=0.00026, |∇A|=0.0003734, |∇B|=0.2297, |LoRA(x)|=6.205, B≠0=12285
distilbert.transformer.layer.3.attention.out_lin: |A|=0.1995, |B|=0.0002658, |∇A|=0.0004406, |∇B|=0.1566, |LoRA(x)|=4.282, B≠0=12285
distilbert.transformer.layer.3.ffn.lin1: |A|=0.1995, |B|=0.0003003, |∇A|=0.0005145, |∇B|=0.08879, |LoRA(x)|=18.41, B≠0=49141
distilbert.transformer.layer.3.ffn.lin2: |A|=0.1998, |B|=0.0002639, |∇A|=0.0001617, |∇B|=0.1338, |LoRA(x)|=3.407, B≠0=12285
distilbert.transformer.layer.4.attention.q_lin: |A|=0.2008, |B|=0.0003261, |∇A|=0.0001543, |∇B|=0.06833, |LoRA(x)|=12.96, B≠0=12284
distilbert.transformer.layer.4.attention.k_lin: |A|=0.2, |B|=0.0003134, |∇A|=0.0003658, |∇B|=0.07867, |LoRA(x)|=8.107, B≠0=12286
distilbert.transformer.layer.4.attention.v_lin: |A|=0.1993, |B|=0.0002635, |∇A|=0.0003247, |∇B|=0.09568, |LoRA(x)|=6.022, B≠0=12285
distilbert.transformer.layer.4.attention.out_lin: |A|=0.1999, |B|=0.0002577, |∇A|=0.0002548, |∇B|=0.0689, |LoRA(x)|=4.138, B≠0=12285
distilbert.transformer.layer.4.ffn.lin1: |A|=0.2004, |B|=0.0002738, |∇A|=0.0002338, |∇B|=0.03957, |LoRA(x)|=21.57, B≠0=49141
distilbert.transformer.layer.4.ffn.lin2: |A|=0.1994, |B|=0.0002625, |∇A|=0.0001001, |∇B|=0.06575, |LoRA(x)|=4.007, B≠0=12284
distilbert.transformer.layer.5.attention.q_lin: |A|=0.1998, |B|=0.0003016, |∇A|=8.459e-05, |∇B|=0.02227, |LoRA(x)|=9.447, B≠0=12286
distilbert.transformer.layer.5.attention.k_lin: |A|=0.201, |B|=0.0002925, |∇A|=0.0001426, |∇B|=0.02499, |LoRA(x)|=7.739, B≠0=12285
distilbert.transformer.layer.5.attention.v_lin: |A|=0.1987, |B|=0.0002508, |∇A|=0.0002601, |∇B|=0.04984, |LoRA(x)|=8.357, B≠0=12285
distilbert.transformer.layer.5.attention.out_lin: |A|=0.1994, |B|=0.0002483, |∇A|=0.000202, |∇B|=0.0489, |LoRA(x)|=15.56, B≠0=12286
distilbert.transformer.layer.5.ffn.lin1: |A|=0.1993, |B|=0.0002525, |∇A|=0.0001525, |∇B|=0.01987, |LoRA(x)|=20.78, B≠0=49138
distilbert.transformer.layer.5.ffn.lin2: |A|=0.1994, |B|=0.0002402, |∇A|=4.512e-05, |∇B|=0.03451, |LoRA(x)|=5.875, B≠0=12285
Parameter Statistics: mean.abs()
distilbert.transformer.layer.0.attention.q_lin.lora.A 0.20136898756027222
distilbert.transformer.layer.0.attention.q_lin.lora.B 0.00036163683398626745
distilbert.transformer.layer.0.attention.k_lin.lora.A 0.199139803647995
distilbert.transformer.layer.0.attention.k_lin.lora.B 0.0003591857384890318
distilbert.transformer.layer.0.attention.v_lin.lora.A 0.199319526553154
distilbert.transformer.layer.0.attention.v_lin.lora.B 0.00031731155468150973
distilbert.transformer.layer.0.attention.out_lin.lora.A 0.20034024119377136
distilbert.transformer.layer.0.attention.out_lin.lora.B 0.00031905301148071885
distilbert.transformer.layer.0.ffn.lin1.lora.A 0.19975820183753967
distilbert.transformer.layer.0.ffn.lin1.lora.B 0.00035009029670618474
distilbert.transformer.layer.0.ffn.lin2.lora.A 0.19959013164043427
distilbert.transformer.layer.0.ffn.lin2.lora.B 0.0003168959519825876
distilbert.transformer.layer.1.attention.q_lin.lora.A 0.1999898999929428
distilbert.transformer.layer.1.attention.q_lin.lora.B 0.00035791145637631416
distilbert.transformer.layer.1.attention.k_lin.lora.A 0.20074164867401123
distilbert.transformer.layer.1.attention.k_lin.lora.B 0.0003592821885831654
distilbert.transformer.layer.1.attention.v_lin.lora.A 0.19777165353298187
distilbert.transformer.layer.1.attention.v_lin.lora.B 0.00030186009826138616
distilbert.transformer.layer.1.attention.out_lin.lora.A 0.19843807816505432
distilbert.transformer.layer.1.attention.out_lin.lora.B 0.0003039248113054782
distilbert.transformer.layer.1.ffn.lin1.lora.A 0.1979658454656601
distilbert.transformer.layer.1.ffn.lin1.lora.B 0.00033694933517836034
distilbert.transformer.layer.1.ffn.lin2.lora.A 0.19888746738433838
distilbert.transformer.layer.1.ffn.lin2.lora.B 0.0003027324564754963
distilbert.transformer.layer.2.attention.q_lin.lora.A 0.19796311855316162
distilbert.transformer.layer.2.attention.q_lin.lora.B 0.00034728783066384494
distilbert.transformer.layer.2.attention.k_lin.lora.A 0.20021361112594604
distilbert.transformer.layer.2.attention.k_lin.lora.B 0.00033590427483431995
distilbert.transformer.layer.2.attention.v_lin.lora.A 0.20018713176250458
distilbert.transformer.layer.2.attention.v_lin.lora.B 0.00029516470385715365
distilbert.transformer.layer.2.attention.out_lin.lora.A 0.19833998382091522
distilbert.transformer.layer.2.attention.out_lin.lora.B 0.00029389417613856494
distilbert.transformer.layer.2.ffn.lin1.lora.A 0.1982342004776001
distilbert.transformer.layer.2.ffn.lin1.lora.B 0.0003286522696726024
distilbert.transformer.layer.2.ffn.lin2.lora.A 0.19985038042068481
distilbert.transformer.layer.2.ffn.lin2.lora.B 0.000297279329970479
distilbert.transformer.layer.3.attention.q_lin.lora.A 0.20064640045166016
distilbert.transformer.layer.3.attention.q_lin.lora.B 0.00034552914439700544
distilbert.transformer.layer.3.attention.k_lin.lora.A 0.20107726752758026
distilbert.transformer.layer.3.attention.k_lin.lora.B 0.00034139861236326396
distilbert.transformer.layer.3.attention.v_lin.lora.A 0.20070061087608337
distilbert.transformer.layer.3.attention.v_lin.lora.B 0.00027113681426271796
distilbert.transformer.layer.3.attention.out_lin.lora.A 0.1994575560092926
distilbert.transformer.layer.3.attention.out_lin.lora.B 0.0002771209110505879
distilbert.transformer.layer.3.ffn.lin1.lora.A 0.19948390126228333
distilbert.transformer.layer.3.ffn.lin1.lora.B 0.00031956000020727515
distilbert.transformer.layer.3.ffn.lin2.lora.A 0.1997573971748352
distilbert.transformer.layer.3.ffn.lin2.lora.B 0.00027455651434138417
distilbert.transformer.layer.4.attention.q_lin.lora.A 0.20080135762691498
distilbert.transformer.layer.4.attention.q_lin.lora.B 0.00035053043393418193
distilbert.transformer.layer.4.attention.k_lin.lora.A 0.19999122619628906
distilbert.transformer.layer.4.attention.k_lin.lora.B 0.00033398327650502324
distilbert.transformer.layer.4.attention.v_lin.lora.A 0.199321448802948
distilbert.transformer.layer.4.attention.v_lin.lora.B 0.0002780442591756582
distilbert.transformer.layer.4.attention.out_lin.lora.A 0.19989004731178284
distilbert.transformer.layer.4.attention.out_lin.lora.B 0.00027118308935314417
distilbert.transformer.layer.4.ffn.lin1.lora.A 0.20036581158638
distilbert.transformer.layer.4.ffn.lin1.lora.B 0.00029024441028013825
distilbert.transformer.layer.4.ffn.lin2.lora.A 0.19942979514598846
distilbert.transformer.layer.4.ffn.lin2.lora.B 0.0002758408372756094
distilbert.transformer.layer.5.attention.q_lin.lora.A 0.19984808564186096
distilbert.transformer.layer.5.attention.q_lin.lora.B 0.00032156246015802026
distilbert.transformer.layer.5.attention.k_lin.lora.A 0.20102792978286743
distilbert.transformer.layer.5.attention.k_lin.lora.B 0.0003097845474258065
distilbert.transformer.layer.5.attention.v_lin.lora.A 0.19865462183952332
distilbert.transformer.layer.5.attention.v_lin.lora.B 0.0002626449568197131
distilbert.transformer.layer.5.attention.out_lin.lora.A 0.1993870586156845
distilbert.transformer.layer.5.attention.out_lin.lora.B 0.0002558119304012507
distilbert.transformer.layer.5.ffn.lin1.lora.A 0.19927702844142914
distilbert.transformer.layer.5.ffn.lin1.lora.B 0.0002630744711495936
distilbert.transformer.layer.5.ffn.lin2.lora.A 0.19940851628780365
distilbert.transformer.layer.5.ffn.lin2.lora.B 0.0002476528752595186
Parameter Statistics: param.norm()
distilbert.transformer.layer.0.attention.q_lin.lora.A weight norm: 27.8493
distilbert.transformer.layer.0.attention.q_lin.lora.B weight norm: 0.0506
distilbert.transformer.layer.0.attention.k_lin.lora.A weight norm: 27.5909
distilbert.transformer.layer.0.attention.k_lin.lora.B weight norm: 0.0500
distilbert.transformer.layer.0.attention.v_lin.lora.A weight norm: 27.8039
distilbert.transformer.layer.0.attention.v_lin.lora.B weight norm: 0.0444
distilbert.transformer.layer.0.attention.out_lin.lora.A weight norm: 27.9052
distilbert.transformer.layer.0.attention.out_lin.lora.B weight norm: 0.0442
distilbert.transformer.layer.0.ffn.lin1.lora.A weight norm: 27.7016
distilbert.transformer.layer.0.ffn.lin1.lora.B weight norm: 0.0976
distilbert.transformer.layer.0.ffn.lin2.lora.A weight norm: 55.5094
distilbert.transformer.layer.0.ffn.lin2.lora.B weight norm: 0.0440
distilbert.transformer.layer.1.attention.q_lin.lora.A weight norm: 27.8651
distilbert.transformer.layer.1.attention.q_lin.lora.B weight norm: 0.0501
distilbert.transformer.layer.1.attention.k_lin.lora.A weight norm: 27.9079
distilbert.transformer.layer.1.attention.k_lin.lora.B weight norm: 0.0503
distilbert.transformer.layer.1.attention.v_lin.lora.A weight norm: 27.4143
distilbert.transformer.layer.1.attention.v_lin.lora.B weight norm: 0.0421
distilbert.transformer.layer.1.attention.out_lin.lora.A weight norm: 27.5450
distilbert.transformer.layer.1.attention.out_lin.lora.B weight norm: 0.0424
distilbert.transformer.layer.1.ffn.lin1.lora.A weight norm: 27.5116
distilbert.transformer.layer.1.ffn.lin1.lora.B weight norm: 0.0939
distilbert.transformer.layer.1.ffn.lin2.lora.A weight norm: 55.2299
distilbert.transformer.layer.1.ffn.lin2.lora.B weight norm: 0.0421
distilbert.transformer.layer.2.attention.q_lin.lora.A weight norm: 27.6563
distilbert.transformer.layer.2.attention.q_lin.lora.B weight norm: 0.0481
distilbert.transformer.layer.2.attention.k_lin.lora.A weight norm: 27.7654
distilbert.transformer.layer.2.attention.k_lin.lora.B weight norm: 0.0468
distilbert.transformer.layer.2.attention.v_lin.lora.A weight norm: 27.8159
distilbert.transformer.layer.2.attention.v_lin.lora.B weight norm: 0.0414
distilbert.transformer.layer.2.attention.out_lin.lora.A weight norm: 27.5538
distilbert.transformer.layer.2.attention.out_lin.lora.B weight norm: 0.0409
distilbert.transformer.layer.2.ffn.lin1.lora.A weight norm: 27.5591
distilbert.transformer.layer.2.ffn.lin1.lora.B weight norm: 0.0917
distilbert.transformer.layer.2.ffn.lin2.lora.A weight norm: 55.4738
distilbert.transformer.layer.2.ffn.lin2.lora.B weight norm: 0.0416
distilbert.transformer.layer.3.attention.q_lin.lora.A weight norm: 27.7736
distilbert.transformer.layer.3.attention.q_lin.lora.B weight norm: 0.0482
distilbert.transformer.layer.3.attention.k_lin.lora.A weight norm: 27.9000
distilbert.transformer.layer.3.attention.k_lin.lora.B weight norm: 0.0476
distilbert.transformer.layer.3.attention.v_lin.lora.A weight norm: 27.7956
distilbert.transformer.layer.3.attention.v_lin.lora.B weight norm: 0.0376
distilbert.transformer.layer.3.attention.out_lin.lora.A weight norm: 27.6487
distilbert.transformer.layer.3.attention.out_lin.lora.B weight norm: 0.0386
distilbert.transformer.layer.3.ffn.lin1.lora.A weight norm: 27.6616
distilbert.transformer.layer.3.ffn.lin1.lora.B weight norm: 0.0894
distilbert.transformer.layer.3.ffn.lin2.lora.A weight norm: 55.5141
distilbert.transformer.layer.3.ffn.lin2.lora.B weight norm: 0.0384
distilbert.transformer.layer.4.attention.q_lin.lora.A weight norm: 27.9841
distilbert.transformer.layer.4.attention.q_lin.lora.B weight norm: 0.0488
distilbert.transformer.layer.4.attention.k_lin.lora.A weight norm: 27.7779
distilbert.transformer.layer.4.attention.k_lin.lora.B weight norm: 0.0466
distilbert.transformer.layer.4.attention.v_lin.lora.A weight norm: 27.6690
distilbert.transformer.layer.4.attention.v_lin.lora.B weight norm: 0.0386
distilbert.transformer.layer.4.attention.out_lin.lora.A weight norm: 27.7600
distilbert.transformer.layer.4.attention.out_lin.lora.B weight norm: 0.0378
distilbert.transformer.layer.4.ffn.lin1.lora.A weight norm: 27.7930
distilbert.transformer.layer.4.ffn.lin1.lora.B weight norm: 0.0813
distilbert.transformer.layer.4.ffn.lin2.lora.A weight norm: 55.3399
distilbert.transformer.layer.4.ffn.lin2.lora.B weight norm: 0.0385
distilbert.transformer.layer.5.attention.q_lin.lora.A weight norm: 27.7641
distilbert.transformer.layer.5.attention.q_lin.lora.B weight norm: 0.0448
distilbert.transformer.layer.5.attention.k_lin.lora.A weight norm: 27.9006
distilbert.transformer.layer.5.attention.k_lin.lora.B weight norm: 0.0432
distilbert.transformer.layer.5.attention.v_lin.lora.A weight norm: 27.5460
distilbert.transformer.layer.5.attention.v_lin.lora.B weight norm: 0.0366
distilbert.transformer.layer.5.attention.out_lin.lora.A weight norm: 27.7014
distilbert.transformer.layer.5.attention.out_lin.lora.B weight norm: 0.0356
distilbert.transformer.layer.5.ffn.lin1.lora.A weight norm: 27.6943
distilbert.transformer.layer.5.ffn.lin1.lora.B weight norm: 0.0736
distilbert.transformer.layer.5.ffn.lin2.lora.A weight norm: 55.4635
distilbert.transformer.layer.5.ffn.lin2.lora.B weight norm: 0.0346

Key Observations from the Dump¶

  1. Despite small absolute values (e.g., |B| ≈ 0.00025–0.00033), the gradients |∇B| are significantly larger, especially in the early layers: |∇B| ≈ 0.15–0.34 in layers 0–2, tapering to roughly 0.02–0.05 by layer 5. This indicates strong gradient flow into B early in the network that falls off with depth: LoRA is adapting more in the early layers and less in the deep ones.
  2. |LoRA(x)| is comparatively large (roughly 3.4–24 across layers) despite the tiny B values. This is the desired behaviour: LoRA is learning meaningful low-rank deltas.
  3. Rather than B growing in magnitude, the adaptation shows up as sharp, effective updates driven by ∇B, with the added dropout in particular pushing B to adapt.
  4. Although the LoRA.B values remain small (|B| ≈ 0.00025–0.00033), this is not a sign of stagnation or under-training; the signal is directional. The LoRA module effectively computes ΔW = α·A·B, where A holds relatively large values (a directional basis) and B holds very small coefficients (activation selectors). The direction of A·B matters far more than its norm, since α rescales the final output: during training the norm of A·B stays suppressed while its orientation in parameter space becomes useful for adapting the model. A small B also encourages better generalisation, since a large B would produce a large ΔW and risk overfitting, defeating the point of parameter-efficient fine-tuning. (The effective ΔW can be materialised and inspected directly; see the sketch after this list.)
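A minimal sketch of that inspection, assuming the trained model_lora_all_attn from the cell above and the LinearWithLoRA/LoRALayer classes defined earlier: materialise ΔW = α·A·B for one wrapped projection and compare its Frobenius norm and rank against the frozen base weight.

import torch

# Pick one wrapped projection, e.g. the query projection of the first transformer block.
wrapped = model_lora_all_attn.distilbert.transformer.layer[0].attention.q_lin

with torch.no_grad():
    # forward() adds alpha * (x @ A @ B) to linear(x), so the effective update is
    # delta_W = alpha * A @ B, acting alongside W.T (both shaped in_dim x out_dim).
    delta_w = wrapped.lora.alpha * (wrapped.lora.A @ wrapped.lora.B)
    base_w = wrapped.linear.weight.T

    print(f"|W|       = {base_w.norm().item():.4f}")
    print(f"|ΔW|      = {delta_w.norm().item():.4f}")
    print(f"|ΔW|/|W|  = {(delta_w.norm() / base_w.norm()).item():.2e}")
    print(f"rank(ΔW)  = {torch.linalg.matrix_rank(delta_w).item()}  (bounded above by lora_rank = 16)")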
In [ ]: