import torch
import sys
import gc
print(sys.version)
print(f"PyTorch Version: {torch.__version__}")
print(torch.cuda.is_available())
print(torch.cuda.device_count())
if torch.cuda.is_available():
print(f"CUDA Version: {torch.version.cuda}")
print(torch.cuda.get_device_name(0))
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
import bitsandbytes
import peft
import transformers
print(transformers.__version__)
print(f"bitsandbytes version: {bitsandbytes.__version__}")
print(f"peft version: {peft.__version__}")
print(torch.cuda.is_bf16_supported())
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
3.10.16 | packaged by Anaconda, Inc. | (main, Dec 11 2024, 16:19:12) [MSC v.1929 64 bit (AMD64)]
PyTorch Version: 2.5.1+cu121
True
1
CUDA Version: 12.1
NVIDIA GeForce RTX 4080 Laptop GPU
4.50.0.dev0
bitsandbytes version: 0.45.3
peft version: 0.15.2.dev0
True
Load dataset, base model, and tokeniser
from datasets import load_dataset
imdb_dataset = load_dataset("imdb")
imdb_dataset = imdb_dataset.rename_column("label", "labels")
# Split the test set into validation and test sets
test_val_split = imdb_dataset['test'].train_test_split(test_size=0.95, seed=42)
imdb_dataset['validation'] = test_val_split['train']
imdb_dataset['test'] = test_val_split['test']
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, f1_score
# Determine the number of labels
num_labels = len(set(imdb_dataset["train"]["labels"]))
print(f"Number of labels: {num_labels}")
# Load the tokenizer
model_ckpt = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
# Tokenize the whole dataset, truncate to 384 tokens
def tokenize(batch):
return tokenizer(batch["text"], padding=True, truncation=True, max_length=384)
dataset_encoded = imdb_dataset.map(tokenize, batched=True, batch_size=None)
# Load the pretrained model for sequence classification
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = (AutoModelForSequenceClassification
.from_pretrained(model_ckpt, num_labels=num_labels)
.to(device))
#print(model)
Number of labels: 2
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
# Helper functions
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
def compute_metrics(pred):
labels = pred.label_ids
preds = pred.predictions.argmax(-1)
f1 = f1_score(labels, preds, average="weighted")
acc = accuracy_score(labels, preds)
return {"accuracy": acc, "f1": f1}
def count_trainable_parameters(model):
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
return total_params, trainable_params, 100 * trainable_params / total_params
def freeze_model_layers(model, unfreeze_pre_classifier=False):
# Freeze all parameters
for param in model.parameters():
param.requires_grad = False
# Unfreeze LoRA and DoRA-specific params, including lora_norm
for name, param in model.named_parameters():
if (
"lora.A" in name
or "lora.B" in name
or "lora_norm" in name
or name.endswith(".m") # For DoRA
or name.endswith(".m_in") # For DDoRA
or name.endswith(".m_out") # For DDoRA
or "scale" in name
):
param.requires_grad = True
# Unfreeze classifier layer (always)
for name, param in model.named_parameters():
if name.startswith("classifier."):
param.requires_grad = True
# unfreeze pre-classifier
if unfreeze_pre_classifier:
for name, param in model.named_parameters():
if name.startswith("pre_classifier."):
param.requires_grad = True
def monitor_lora_parameters(model, threshold=1e-7):
monitor = {
"A_abs_mean": [],
"B_abs_mean": [],
"A_grad_mean": [],
"B_grad_mean": [],
"lora_output_norm": [],
"B_nonzero_count": [],
}
hooks = []
for name, module in model.named_modules():
if hasattr(module, "lora") and hasattr(module.lora, "A") and hasattr(module.lora, "B"):
A_param = module.lora.A
B_param = module.lora.B
# Gradient hooks (directly on nn.Parameter)
if A_param.requires_grad:
hooks.append(A_param.register_hook(lambda grad, n=name: monitor["A_grad_mean"].append((n, grad.abs().mean().item()))))
if B_param.requires_grad:
hooks.append(B_param.register_hook(lambda grad, n=name: monitor["B_grad_mean"].append((n, grad.abs().mean().item()))))
# Forward hook for value stats
def forward_hook(mod, inp, out, n=name):
A_mean = mod.lora.A.abs().mean().item()
B_mean = mod.lora.B.abs().mean().item()
B_nnz = (mod.lora.B.abs() > threshold).sum().item()
monitor["A_abs_mean"].append((n, A_mean))
monitor["B_abs_mean"].append((n, B_mean))
monitor["B_nonzero_count"].append((n, B_nnz))
monitor["lora_output_norm"].append((n, mod.last_lora_output_norm))
hooks.append(module.register_forward_hook(forward_hook))
return hooks, monitor
LoRA
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
torch.autograd.set_detect_anomaly(True)
class LoRALayer(nn.Module):
def __init__(self, in_dim, out_dim, rank, alpha, dropout_rate=0.0):
super().__init__()
std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
self.B = nn.Parameter(1e-6 * torch.randn(rank, out_dim) * std_dev) # Not all zeroes!
self.alpha = alpha
self.dropout = nn.Dropout(dropout_rate)
def forward(self, x):
# Dropout applied to the projection to the lower-dimensional space by A
dropped = self.dropout(x @ self.A)
return self.alpha * (dropped @ self.B)
class LinearWithLoRA(nn.Module):
def __init__(self, linear, rank, alpha):
super().__init__()
self.linear = linear
self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
self.last_lora_output_norm = 0.0 # for monitoring
def forward(self, x):
#return self.linear(x) + self.lora(x)
lora_out = self.lora(x)
self.last_lora_output_norm = lora_out.norm(p=2, dim=-1).mean().item()
return self.linear(x) + lora_out
# Function to inject LoRA into specified linear layers
def inject_lora_all_attn(model, rank, alpha, dropout_rate=0.0):
target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin", "ffn.lin1", "ffn.lin2"]
# target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
for name, module in model.named_modules():
if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
parent_name = name.rsplit('.', 1)[0]
parent_module = model.get_submodule(parent_name)
original_linear = getattr(parent_module, name.split('.')[-1])
lora_linear = LinearWithLoRA(original_linear, rank, alpha)
lora_linear.lora.dropout = nn.Dropout(dropout_rate)
setattr(parent_module, name.split('.')[-1], lora_linear)
return model
dropout = 0.0
learning_rate = 1.5e-5
lora_rank = 16
lora_alpha = 32
weight_decay = 1e-5 # L2
batch_size = 32
output_dir_prefix = "finetuned-imdb-"
import copy
from transformers import TrainingArguments
torch.manual_seed(137)
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha, dropout)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)
total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
#print("\nTrainable parameters after freezing:")
#for name, param in model_lora_all_attn.named_parameters():
# if param.requires_grad:
# print(name)
eval_steps = 50
logging_steps = 50
training_args_lora_all_attn = TrainingArguments(
output_dir=f"{output_dir_prefix}lora-all-attn",
num_train_epochs=2,
#max_steps=100,
learning_rate=learning_rate,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
weight_decay=weight_decay,
evaluation_strategy="steps",
eval_steps=eval_steps,
logging_steps=logging_steps,
save_steps=eval_steps,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
disable_tqdm=False,
push_to_hub=False,
max_grad_norm=1.0, ##################
report_to="none",
log_level="error"
)
trainer_lora_all_attn = Trainer(
model=model_lora_all_attn,
args=training_args_lora_all_attn,
train_dataset=dataset_encoded["train"],
eval_dataset=dataset_encoded["validation"],
compute_metrics=compute_metrics,
)
hooks, monitor = monitor_lora_parameters(trainer_lora_all_attn.model)
#Train!
trainer_lora_all_attn.train()
#eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
#print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
#print (torch.cuda.memory_summary())
for hook in hooks:
hook.remove()
# Aggregate/log after training
from collections import defaultdict
agg = defaultdict(list)
for key, vals in monitor.items():
grouped = defaultdict(list)
for name, val in vals:
grouped[name].append(val)
agg[key] = {name: sum(vs)/len(vs) for name, vs in grouped.items()}
for name in agg["A_abs_mean"]:
print(f"{name}: |A|={agg['A_abs_mean'][name]:.4g}, |B|={agg['B_abs_mean'][name]:.4g}, "
f"|∇A|={agg['A_grad_mean'][name]:.4g}, |∇B|={agg['B_grad_mean'][name]:.4g}, "
f"|LoRA(x)|={agg['lora_output_norm'][name]:.4g}, B≠0={agg['B_nonzero_count'][name]:.0f}")
#print('Parameter Statistics: mean.abs()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(name, param.abs().mean().item())
#print('Parameter Statistics: param.norm()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(f"{name} weight norm: {param.norm().item():.4f}")
LoRA (All Attention) - Total parameters: 68,282,114
LoRA (All Attention) - Trainable parameters: 1,919,234 (2.81%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
| Step | Training Loss | Validation Loss | Accuracy | F1 |
|---|---|---|---|---|
| 50 | 0.504900 | 0.335573 | 0.869600 | 0.869783 |
| 100 | 0.341200 | 0.318371 | 0.862400 | 0.862357 |
| 150 | 0.361400 | 0.307639 | 0.876000 | 0.875739 |
| 200 | 0.303400 | 0.296421 | 0.870400 | 0.870685 |
| 250 | 0.324500 | 0.276511 | 0.886400 | 0.886654 |
| 300 | 0.295500 | 0.261722 | 0.890400 | 0.889511 |
| 350 | 0.312700 | 0.276859 | 0.896800 | 0.896362 |
| 400 | 0.281200 | 0.253817 | 0.890400 | 0.890564 |
| 450 | 0.257400 | 0.244995 | 0.909600 | 0.909240 |
| 500 | 0.284400 | 0.264725 | 0.898400 | 0.898376 |
| 550 | 0.297700 | 0.238351 | 0.908800 | 0.908708 |
| 600 | 0.287900 | 0.237734 | 0.909600 | 0.909143 |
| 650 | 0.234500 | 0.224168 | 0.913600 | 0.913528 |
| 700 | 0.247900 | 0.222799 | 0.915200 | 0.915187 |
| 750 | 0.256000 | 0.227275 | 0.916000 | 0.916031 |
| 800 | 0.241700 | 0.282384 | 0.899200 | 0.898261 |
| 850 | 0.190800 | 0.242396 | 0.904000 | 0.904121 |
| 900 | 0.169500 | 0.258911 | 0.910400 | 0.910400 |
| 950 | 0.208600 | 0.260492 | 0.907200 | 0.906995 |
| 1000 | 0.189500 | 0.241334 | 0.913600 | 0.913649 |
| 1050 | 0.182400 | 0.214967 | 0.924800 | 0.924764 |
| 1100 | 0.192600 | 0.222719 | 0.915200 | 0.915213 |
| 1150 | 0.170000 | 0.231555 | 0.912000 | 0.911767 |
| 1200 | 0.182200 | 0.242794 | 0.916800 | 0.916928 |
| 1250 | 0.174700 | 0.221050 | 0.917600 | 0.917461 |
| 1300 | 0.187400 | 0.229759 | 0.916800 | 0.916896 |
| 1350 | 0.173600 | 0.224172 | 0.917600 | 0.917641 |
| 1400 | 0.163200 | 0.224013 | 0.920800 | 0.920870 |
| 1450 | 0.170500 | 0.221920 | 0.920800 | 0.920806 |
| 1500 | 0.165600 | 0.220823 | 0.918400 | 0.918361 |
| 1550 | 0.171000 | 0.221188 | 0.920800 | 0.920840 |
distilbert.transformer.layer.0.attention.q_lin: |A|=0.2014, |B|=0.0002572, |∇A|=0.000277, |∇B|=0.08185, |LoRA(x)|=4.874, B≠0=12283 distilbert.transformer.layer.0.attention.k_lin: |A|=0.1992, |B|=0.0002632, |∇A|=0.0002245, |∇B|=0.07715, |LoRA(x)|=5.663, B≠0=12283 distilbert.transformer.layer.0.attention.v_lin: |A|=0.1993, |B|=0.0002103, |∇A|=0.0005365, |∇B|=0.1825, |LoRA(x)|=4.493, B≠0=12282 distilbert.transformer.layer.0.attention.out_lin: |A|=0.2004, |B|=0.000215, |∇A|=0.0004854, |∇B|=0.183, |LoRA(x)|=2.415, B≠0=12283 distilbert.transformer.layer.0.ffn.lin1: |A|=0.1998, |B|=0.0002391, |∇A|=0.000489, |∇B|=0.1047, |LoRA(x)|=15.38, B≠0=49133 distilbert.transformer.layer.0.ffn.lin2: |A|=0.1996, |B|=0.0002197, |∇A|=0.0001981, |∇B|=0.1984, |LoRA(x)|=4.228, B≠0=12282 distilbert.transformer.layer.1.attention.q_lin: |A|=0.2, |B|=0.0002441, |∇A|=0.0002117, |∇B|=0.07861, |LoRA(x)|=5.422, B≠0=12283 distilbert.transformer.layer.1.attention.k_lin: |A|=0.2008, |B|=0.0002601, |∇A|=0.0002199, |∇B|=0.07653, |LoRA(x)|=4.799, B≠0=12283 distilbert.transformer.layer.1.attention.v_lin: |A|=0.1978, |B|=0.0001886, |∇A|=0.0003676, |∇B|=0.185, |LoRA(x)|=3.285, B≠0=12282 distilbert.transformer.layer.1.attention.out_lin: |A|=0.1985, |B|=0.0002105, |∇A|=0.0004261, |∇B|=0.1959, |LoRA(x)|=2.38, B≠0=12283 distilbert.transformer.layer.1.ffn.lin1: |A|=0.198, |B|=0.000237, |∇A|=0.0004655, |∇B|=0.119, |LoRA(x)|=19.11, B≠0=49132 distilbert.transformer.layer.1.ffn.lin2: |A|=0.1989, |B|=0.0002041, |∇A|=0.0001751, |∇B|=0.1963, |LoRA(x)|=2.361, B≠0=12282 distilbert.transformer.layer.2.attention.q_lin: |A|=0.198, |B|=0.0002397, |∇A|=0.0002494, |∇B|=0.09958, |LoRA(x)|=5.855, B≠0=12283 distilbert.transformer.layer.2.attention.k_lin: |A|=0.2002, |B|=0.0002423, |∇A|=0.0002714, |∇B|=0.1016, |LoRA(x)|=5.16, B≠0=12283 distilbert.transformer.layer.2.attention.v_lin: |A|=0.2002, |B|=0.0001928, |∇A|=0.0003262, |∇B|=0.1879, |LoRA(x)|=4.02, B≠0=12282 distilbert.transformer.layer.2.attention.out_lin: |A|=0.1983, |B|=0.0001992, |∇A|=0.0003639, |∇B|=0.1642, |LoRA(x)|=2.252, B≠0=12282 distilbert.transformer.layer.2.ffn.lin1: |A|=0.1982, |B|=0.0002306, |∇A|=0.0005194, |∇B|=0.1186, |LoRA(x)|=19.02, B≠0=49132 distilbert.transformer.layer.2.ffn.lin2: |A|=0.1999, |B|=0.0002021, |∇A|=0.0002104, |∇B|=0.1721, |LoRA(x)|=2.549, B≠0=12283 distilbert.transformer.layer.3.attention.q_lin: |A|=0.2006, |B|=0.0002554, |∇A|=0.000265, |∇B|=0.09461, |LoRA(x)|=5.758, B≠0=12283 distilbert.transformer.layer.3.attention.k_lin: |A|=0.2011, |B|=0.0002541, |∇A|=0.0002945, |∇B|=0.09943, |LoRA(x)|=6.224, B≠0=12283 distilbert.transformer.layer.3.attention.v_lin: |A|=0.2007, |B|=0.0001641, |∇A|=0.0003354, |∇B|=0.2089, |LoRA(x)|=4.087, B≠0=12281 distilbert.transformer.layer.3.attention.out_lin: |A|=0.1995, |B|=0.000194, |∇A|=0.0004308, |∇B|=0.14, |LoRA(x)|=2.848, B≠0=12282 distilbert.transformer.layer.3.ffn.lin1: |A|=0.1995, |B|=0.0002212, |∇A|=0.0004882, |∇B|=0.07969, |LoRA(x)|=11.88, B≠0=49131 distilbert.transformer.layer.3.ffn.lin2: |A|=0.1998, |B|=0.0001834, |∇A|=0.0001973, |∇B|=0.1337, |LoRA(x)|=2.234, B≠0=12282 distilbert.transformer.layer.4.attention.q_lin: |A|=0.2008, |B|=0.0002504, |∇A|=0.0001735, |∇B|=0.07399, |LoRA(x)|=8.787, B≠0=12283 distilbert.transformer.layer.4.attention.k_lin: |A|=0.2, |B|=0.0002486, |∇A|=0.0003957, |∇B|=0.07656, |LoRA(x)|=6.261, B≠0=12283 distilbert.transformer.layer.4.attention.v_lin: |A|=0.1993, |B|=0.0001708, |∇A|=0.0003348, |∇B|=0.11, |LoRA(x)|=3.692, B≠0=12281 distilbert.transformer.layer.4.attention.out_lin: |A|=0.1999, |B|=0.0001745, 
|∇A|=0.000366, |∇B|=0.07387, |LoRA(x)|=2.375, B≠0=12282 distilbert.transformer.layer.4.ffn.lin1: |A|=0.2004, |B|=0.0001853, |∇A|=0.0002817, |∇B|=0.04258, |LoRA(x)|=16.52, B≠0=49128 distilbert.transformer.layer.4.ffn.lin2: |A|=0.1994, |B|=0.0001656, |∇A|=0.0001065, |∇B|=0.0697, |LoRA(x)|=2.054, B≠0=12282 distilbert.transformer.layer.5.attention.q_lin: |A|=0.1998, |B|=0.0002194, |∇A|=8.952e-05, |∇B|=0.02991, |LoRA(x)|=8.03, B≠0=12283 distilbert.transformer.layer.5.attention.k_lin: |A|=0.201, |B|=0.0002024, |∇A|=0.0002036, |∇B|=0.03099, |LoRA(x)|=5.661, B≠0=12282 distilbert.transformer.layer.5.attention.v_lin: |A|=0.1987, |B|=0.0001348, |∇A|=0.0002633, |∇B|=0.04537, |LoRA(x)|=3.664, B≠0=12280 distilbert.transformer.layer.5.attention.out_lin: |A|=0.1994, |B|=0.0001793, |∇A|=0.0002507, |∇B|=0.03697, |LoRA(x)|=9.559, B≠0=12282 distilbert.transformer.layer.5.ffn.lin1: |A|=0.1993, |B|=0.000147, |∇A|=0.0001532, |∇B|=0.01653, |LoRA(x)|=17.16, B≠0=49123 distilbert.transformer.layer.5.ffn.lin2: |A|=0.1994, |B|=0.0001564, |∇A|=5.048e-05, |∇B|=0.02768, |LoRA(x)|=5.171, B≠0=12282
Training summary
- Loss decreases consistently and validation accuracy improves steadily until ~step 1050. Best observed performance at step ~1050: accuracy 92.48%, F1 92.48%, validation loss 0.214967. After that, the metrics plateau or fluctuate slightly, suggesting diminishing returns or mild overfitting.
- LoRA adapts low-rank matrices A and B in each targeted linear layer. |A| ≈ 0.198–0.201 across layers, as expected from the Gaussian initialisation scaled by 1/sqrt(rank). |B| ≈ 0.00015–0.00026, nearly three orders of magnitude smaller than |A|, which is expected since B starts close to zero and learns slowly. |LoRA(x)| stays in a reasonable range of roughly 2–19.
- LoRA's effective weight update is the product of the two low-rank matrices (alpha * A @ B in this code's convention; see the sketch after this list). So even though A is large, B remains small early on, and the effective update is very low-rank and small in norm, especially during the first few dozen steps. This is consistent with the slow or flat loss change seen initially. Layer utilisation is well balanced, especially in layers 0–3, with comparable activity across attention and FFN modules.
- Gradients: |∇A| ≈ 0.0002–0.0005 (small but consistent updates); |∇B| is larger, ranging up to ~0.2, especially in v_lin and lin2, so LoRA is learning more actively through B.
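To make the "effective update" point concrete, here is a minimal post-training sketch. It assumes the trained model_lora_all_attn from the cells above and DistilBERT's standard module path (distilbert.transformer.layer[0].attention.q_lin); the merging comment describes an option, not something done in this notebook.
# Dense equivalent of the LoRA contribution for layer 0's query projection,
# in this code's convention: x @ (alpha * A @ B)
q_lin = model_lora_all_attn.distilbert.transformer.layer[0].attention.q_lin  # a LinearWithLoRA after injection
delta_w = (q_lin.lora.alpha * (q_lin.lora.A @ q_lin.lora.B)).detach()        # shape (in_features, out_features)
print(f"|ΔW| = {delta_w.norm().item():.4f} vs |W| = {q_lin.linear.weight.norm().item():.4f}")
# To merge the adapter, delta_w.T would be added to q_lin.linear.weight,
# since nn.Linear stores its weight as (out_features, in_features).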
dropout = 0.4
learning_rate = 1.5e-5
lora_rank = 16
lora_alpha = 32
weight_decay = 1e-5 # L2
batch_size = 32
output_dir_prefix = "finetuned-imdb-"
import copy
from transformers import TrainingArguments
torch.manual_seed(137)
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha, dropout)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)
total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
#print("\nTrainable parameters after freezing:")
#for name, param in model_lora_all_attn.named_parameters():
# if param.requires_grad:
# print(name)
eval_steps = 50
logging_steps = 50
training_args_lora_all_attn = TrainingArguments(
output_dir=f"{output_dir_prefix}lora-all-attn",
num_train_epochs=2,
#max_steps=100,
learning_rate=learning_rate,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
weight_decay=weight_decay,
evaluation_strategy="steps",
eval_steps=eval_steps,
logging_steps=logging_steps,
save_steps=eval_steps,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
disable_tqdm=False,
push_to_hub=False,
max_grad_norm=1.0, ##################
report_to="none",
log_level="error"
)
trainer_lora_all_attn = Trainer(
model=model_lora_all_attn,
args=training_args_lora_all_attn,
train_dataset=dataset_encoded["train"],
eval_dataset=dataset_encoded["validation"],
compute_metrics=compute_metrics,
)
hooks, monitor = monitor_lora_parameters(trainer_lora_all_attn.model)
#Train!
trainer_lora_all_attn.train()
#eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
#print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
#print (torch.cuda.memory_summary())
for hook in hooks:
hook.remove()
# Aggregate/log after training
from collections import defaultdict
agg = defaultdict(list)
for key, vals in monitor.items():
grouped = defaultdict(list)
for name, val in vals:
grouped[name].append(val)
agg[key] = {name: sum(vs)/len(vs) for name, vs in grouped.items()}
for name in agg["A_abs_mean"]:
print(f"{name}: |A|={agg['A_abs_mean'][name]:.4g}, |B|={agg['B_abs_mean'][name]:.4g}, "
f"|∇A|={agg['A_grad_mean'][name]:.4g}, |∇B|={agg['B_grad_mean'][name]:.4g}, "
f"|LoRA(x)|={agg['lora_output_norm'][name]:.4g}, B≠0={agg['B_nonzero_count'][name]:.0f}")
#print('Parameter Statistics: mean.abs()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(name, param.abs().mean().item())
#print('Parameter Statistics: param.norm()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(f"{name} weight norm: {param.norm().item():.4f}")
LoRA (All Attention) - Total parameters: 68,282,114
LoRA (All Attention) - Trainable parameters: 1,919,234 (2.81%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
| Step | Training Loss | Validation Loss | Accuracy | F1 |
|---|---|---|---|---|
| 50 | 0.522300 | 0.358082 | 0.854400 | 0.853802 |
| 100 | 0.374700 | 0.310322 | 0.862400 | 0.862513 |
| 150 | 0.386100 | 0.318261 | 0.869600 | 0.869504 |
| 200 | 0.326900 | 0.307559 | 0.869600 | 0.867806 |
| 250 | 0.351500 | 0.314006 | 0.877600 | 0.877870 |
| 300 | 0.332300 | 0.272534 | 0.892000 | 0.891881 |
| 350 | 0.343300 | 0.279443 | 0.893600 | 0.893502 |
| 400 | 0.327400 | 0.285108 | 0.890400 | 0.889991 |
| 450 | 0.295900 | 0.258271 | 0.898400 | 0.898208 |
| 500 | 0.327000 | 0.285135 | 0.883200 | 0.882071 |
| 550 | 0.311600 | 0.249657 | 0.904000 | 0.903745 |
| 600 | 0.315400 | 0.261444 | 0.902400 | 0.902096 |
| 650 | 0.290800 | 0.253156 | 0.896800 | 0.896605 |
| 700 | 0.299500 | 0.243638 | 0.900800 | 0.900537 |
| 750 | 0.297200 | 0.240773 | 0.908000 | 0.907932 |
| 800 | 0.283300 | 0.261479 | 0.903200 | 0.902599 |
| 850 | 0.277300 | 0.226801 | 0.907200 | 0.907317 |
| 900 | 0.265200 | 0.268644 | 0.904800 | 0.904983 |
| 950 | 0.276900 | 0.250334 | 0.909600 | 0.909370 |
| 1000 | 0.245300 | 0.251263 | 0.908800 | 0.908708 |
| 1050 | 0.241300 | 0.231262 | 0.914400 | 0.914366 |
| 1100 | 0.244200 | 0.239346 | 0.904000 | 0.903903 |
| 1150 | 0.242700 | 0.226735 | 0.908800 | 0.908741 |
| 1200 | 0.236600 | 0.237783 | 0.909600 | 0.909579 |
| 1250 | 0.220600 | 0.229475 | 0.912000 | 0.911986 |
| 1300 | 0.243500 | 0.234786 | 0.908800 | 0.908756 |
| 1350 | 0.246600 | 0.229699 | 0.905600 | 0.905699 |
| 1400 | 0.238500 | 0.222471 | 0.916800 | 0.916560 |
| 1450 | 0.221000 | 0.223475 | 0.910400 | 0.910400 |
| 1500 | 0.226000 | 0.222452 | 0.916000 | 0.915892 |
| 1550 | 0.209200 | 0.223851 | 0.915200 | 0.915145 |
distilbert.transformer.layer.0.attention.q_lin: |A|=0.2014, |B|=0.0002581, |∇A|=0.0003806, |∇B|=0.1261, |LoRA(x)|=5.469, B≠0=12283 distilbert.transformer.layer.0.attention.k_lin: |A|=0.1991, |B|=0.0002605, |∇A|=0.0003199, |∇B|=0.1166, |LoRA(x)|=6.167, B≠0=12283 distilbert.transformer.layer.0.attention.v_lin: |A|=0.1993, |B|=0.0002149, |∇A|=0.0005936, |∇B|=0.2469, |LoRA(x)|=5.135, B≠0=12283 distilbert.transformer.layer.0.attention.out_lin: |A|=0.2004, |B|=0.0002184, |∇A|=0.0005894, |∇B|=0.2659, |LoRA(x)|=2.72, B≠0=12283 distilbert.transformer.layer.0.ffn.lin1: |A|=0.1998, |B|=0.0002434, |∇A|=0.0006892, |∇B|=0.1599, |LoRA(x)|=17.21, B≠0=49132 distilbert.transformer.layer.0.ffn.lin2: |A|=0.1996, |B|=0.0002225, |∇A|=0.0002834, |∇B|=0.2989, |LoRA(x)|=4.928, B≠0=12283 distilbert.transformer.layer.1.attention.q_lin: |A|=0.2, |B|=0.0002619, |∇A|=0.0003127, |∇B|=0.1176, |LoRA(x)|=6.326, B≠0=12283 distilbert.transformer.layer.1.attention.k_lin: |A|=0.2007, |B|=0.0002593, |∇A|=0.0003151, |∇B|=0.1215, |LoRA(x)|=5.5, B≠0=12283 distilbert.transformer.layer.1.attention.v_lin: |A|=0.1978, |B|=0.0001929, |∇A|=0.0004882, |∇B|=0.257, |LoRA(x)|=3.703, B≠0=12282 distilbert.transformer.layer.1.attention.out_lin: |A|=0.1985, |B|=0.0002113, |∇A|=0.0005305, |∇B|=0.2663, |LoRA(x)|=2.635, B≠0=12283 distilbert.transformer.layer.1.ffn.lin1: |A|=0.198, |B|=0.0002375, |∇A|=0.0006094, |∇B|=0.1708, |LoRA(x)|=21.37, B≠0=49133 distilbert.transformer.layer.1.ffn.lin2: |A|=0.1989, |B|=0.0002082, |∇A|=0.0002324, |∇B|=0.2649, |LoRA(x)|=2.712, B≠0=12283 distilbert.transformer.layer.2.attention.q_lin: |A|=0.198, |B|=0.0002454, |∇A|=0.0003173, |∇B|=0.1344, |LoRA(x)|=6.335, B≠0=12283 distilbert.transformer.layer.2.attention.k_lin: |A|=0.2002, |B|=0.0002375, |∇A|=0.0003538, |∇B|=0.1436, |LoRA(x)|=5.793, B≠0=12283 distilbert.transformer.layer.2.attention.v_lin: |A|=0.2002, |B|=0.0001974, |∇A|=0.0004345, |∇B|=0.2487, |LoRA(x)|=4.533, B≠0=12282 distilbert.transformer.layer.2.attention.out_lin: |A|=0.1984, |B|=0.0001982, |∇A|=0.0004565, |∇B|=0.2209, |LoRA(x)|=2.561, B≠0=12283 distilbert.transformer.layer.2.ffn.lin1: |A|=0.1982, |B|=0.0002298, |∇A|=0.0005989, |∇B|=0.1531, |LoRA(x)|=20.14, B≠0=49131 distilbert.transformer.layer.2.ffn.lin2: |A|=0.1999, |B|=0.0002027, |∇A|=0.0002184, |∇B|=0.2143, |LoRA(x)|=2.758, B≠0=12282 distilbert.transformer.layer.3.attention.q_lin: |A|=0.2006, |B|=0.0002525, |∇A|=0.0002761, |∇B|=0.1089, |LoRA(x)|=6.166, B≠0=12283 distilbert.transformer.layer.3.attention.k_lin: |A|=0.2011, |B|=0.0002438, |∇A|=0.0003425, |∇B|=0.1276, |LoRA(x)|=6.506, B≠0=12283 distilbert.transformer.layer.3.attention.v_lin: |A|=0.2007, |B|=0.0001687, |∇A|=0.0003466, |∇B|=0.2575, |LoRA(x)|=4.666, B≠0=12282 distilbert.transformer.layer.3.attention.out_lin: |A|=0.1995, |B|=0.0001906, |∇A|=0.0004309, |∇B|=0.171, |LoRA(x)|=3.289, B≠0=12282 distilbert.transformer.layer.3.ffn.lin1: |A|=0.1995, |B|=0.0002211, |∇A|=0.0005016, |∇B|=0.1029, |LoRA(x)|=14.59, B≠0=49131 distilbert.transformer.layer.3.ffn.lin2: |A|=0.1998, |B|=0.0001856, |∇A|=0.0002331, |∇B|=0.1571, |LoRA(x)|=2.627, B≠0=12282 distilbert.transformer.layer.4.attention.q_lin: |A|=0.2008, |B|=0.0002418, |∇A|=0.0001731, |∇B|=0.08101, |LoRA(x)|=10.64, B≠0=12283 distilbert.transformer.layer.4.attention.k_lin: |A|=0.2, |B|=0.0002293, |∇A|=0.0003568, |∇B|=0.08671, |LoRA(x)|=6.066, B≠0=12283 distilbert.transformer.layer.4.attention.v_lin: |A|=0.1993, |B|=0.0001757, |∇A|=0.0003265, |∇B|=0.1101, |LoRA(x)|=4.307, B≠0=12282 distilbert.transformer.layer.4.attention.out_lin: |A|=0.1999, |B|=0.000175, 
|∇A|=0.0003284, |∇B|=0.07959, |LoRA(x)|=2.766, B≠0=12282 distilbert.transformer.layer.4.ffn.lin1: |A|=0.2004, |B|=0.0001862, |∇A|=0.0002638, |∇B|=0.04591, |LoRA(x)|=16.07, B≠0=49129 distilbert.transformer.layer.4.ffn.lin2: |A|=0.1994, |B|=0.0001647, |∇A|=0.0001085, |∇B|=0.08194, |LoRA(x)|=2.308, B≠0=12281 distilbert.transformer.layer.5.attention.q_lin: |A|=0.1998, |B|=0.0002014, |∇A|=9.653e-05, |∇B|=0.03063, |LoRA(x)|=7.295, B≠0=12282 distilbert.transformer.layer.5.attention.k_lin: |A|=0.201, |B|=0.0002036, |∇A|=0.000199, |∇B|=0.03414, |LoRA(x)|=6.269, B≠0=12283 distilbert.transformer.layer.5.attention.v_lin: |A|=0.1987, |B|=0.0001532, |∇A|=0.0002869, |∇B|=0.05179, |LoRA(x)|=5.935, B≠0=12282 distilbert.transformer.layer.5.attention.out_lin: |A|=0.1994, |B|=0.0001674, |∇A|=0.0001988, |∇B|=0.04984, |LoRA(x)|=8.743, B≠0=12282 distilbert.transformer.layer.5.ffn.lin1: |A|=0.1993, |B|=0.0001539, |∇A|=0.0001583, |∇B|=0.02406, |LoRA(x)|=18.93, B≠0=49125 distilbert.transformer.layer.5.ffn.lin2: |A|=0.1994, |B|=0.0001408, |∇A|=5.189e-05, |∇B|=0.04019, |LoRA(x)|=5.108, B≠0=12281
Training summary
With dropout = 0.4 (applied after the down-projection by A and before the up-projection by B), B is forced to adapt more robustly over time, without any significant loss of accuracy or F1. In effect, the dropout acts as a regulariser, injecting noise into the low-rank activations.
Key trends in |∇B|: dropout = 0.4 amplifies |∇B| in almost every layer, sometimes by 30–100% depending on depth and layer type. LoRA therefore updates B more aggressively when dropout is used, probably because of the added noise and sparsity during training. Since larger ∇B magnitudes are what drive fast adaptation, capacity and expressivity in the adapter, this is the main effect of dropout here.
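A small standalone illustration (not tied to the trained model) of where this dropout acts: nn.Dropout zeroes a fraction p of the rank-16 activations x @ A during training and rescales the survivors by 1/(1 - p), so B sees sparser but individually larger inputs, which is consistent with the larger |∇B| values observed above.
import torch
import torch.nn as nn
torch.manual_seed(0)
drop = nn.Dropout(0.4)
h = torch.randn(8, 16)    # stand-in for x @ A: batch of 8, rank 16
drop.train()
h_train = drop(h)         # roughly 40% of entries zeroed; survivors scaled by 1 / (1 - 0.4) ≈ 1.67
drop.eval()
h_eval = drop(h)          # identity at evaluation time
print((h_train == 0).float().mean().item())  # close to 0.4
print(h[0, :5])
print(h_train[0, :5])     # zeros where dropped; surviving values are 1.67x larger than in h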
Initialisation of lora.A
Matrix A maps from the high-dimensional space down to the low-rank space: it compresses features. Its output is optionally passed through dropout, which adds further noise. Matrix B does the heavy lifting, projecting the compressed signal back up. Since LoRA is typically trained with a low learning rate and the product A @ B is scaled by alpha (or alpha / rank in many implementations), small differences in A's initialisation get washed out during the early training steps. Changing A's init therefore mostly affects early training (learning speed, initial activation variance) but does not shift final performance much. By contrast, a poorly initialised B sends noise straight into the frozen model's internal representations. For this reason B is initialised here as nn.Parameter(1e-6 * torch.randn(rank, out_dim) * std_dev). The common alternative of all zeroes (self.B = nn.Parameter(torch.zeros(rank, out_dim)), as in the original LoRA paper) often works well too, but not always, since a perfectly symmetric start can sometimes hinder learning.
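To make the scales concrete, a short standalone sketch (reusing rank, alpha and std_dev as above, with 768x768 as a stand-in for an attention projection) compares the initial effective update under the near-zero B used in this notebook and the all-zero alternative:
import torch
in_dim, out_dim, rank, alpha = 768, 768, 16, 32
std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
A = torch.randn(in_dim, rank) * std_dev                     # default Gaussian init for A used above
B_near_zero = 1e-6 * torch.randn(rank, out_dim) * std_dev   # near-zero init used in this notebook
B_zeros = torch.zeros(rank, out_dim)                        # all-zero alternative (original LoRA choice)
print((alpha * A @ B_near_zero).norm().item())  # tiny but non-zero: frozen model barely perturbed, symmetry broken
print((alpha * A @ B_zeros).norm().item())      # exactly 0.0: the adapter starts as a no-op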
Kaiming for A
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
torch.autograd.set_detect_anomaly(True)
class LoRALayer(nn.Module):
def __init__(self, in_dim, out_dim, rank, alpha, dropout_rate=0.0):
super().__init__()
std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
#self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
# Kaiming-style initialisation for A
# (note: the std below uses in_dim + out_dim, a Glorot/Xavier-style denominator; classic Kaiming/He would use sqrt(2 / in_dim))
std_dev2 = math.sqrt(2 / (in_dim + out_dim))
self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev2)
self.B = nn.Parameter(1e-6 * torch.randn(rank, out_dim) * std_dev) # Not all zeroes!
self.alpha = alpha
self.dropout = nn.Dropout(dropout_rate)
def forward(self, x):
# Dropout applied to the projection to the lower-dimensional space by A
dropped = self.dropout(x @ self.A)
return self.alpha * (dropped @ self.B)
class LinearWithLoRA(nn.Module):
def __init__(self, linear, rank, alpha):
super().__init__()
self.linear = linear
self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
self.last_lora_output_norm = 0.0 # for monitoring
def forward(self, x):
#return self.linear(x) + self.lora(x)
lora_out = self.lora(x)
self.last_lora_output_norm = lora_out.norm(p=2, dim=-1).mean().item()
return self.linear(x) + lora_out
# Function to inject LoRA into specified linear layers
def inject_lora_all_attn(model, rank, alpha, dropout_rate=0.0):
target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin", "ffn.lin1", "ffn.lin2"]
# target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
for name, module in model.named_modules():
if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
parent_name = name.rsplit('.', 1)[0]
parent_module = model.get_submodule(parent_name)
original_linear = getattr(parent_module, name.split('.')[-1])
lora_linear = LinearWithLoRA(original_linear, rank, alpha)
lora_linear.lora.dropout = nn.Dropout(dropout_rate)
setattr(parent_module, name.split('.')[-1], lora_linear)
return model
dropout = 0.0
learning_rate = 1.5e-5
lora_rank = 16
lora_alpha = 32
weight_decay = 1e-5 # L2
batch_size = 32
output_dir_prefix = "finetuned-imdb-"
import copy
from transformers import TrainingArguments
torch.manual_seed(137)
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha, dropout)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)
total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
#print("\nTrainable parameters after freezing:")
#for name, param in model_lora_all_attn.named_parameters():
# if param.requires_grad:
# print(name)
eval_steps = 50
logging_steps = 50
training_args_lora_all_attn = TrainingArguments(
output_dir=f"{output_dir_prefix}lora-all-attn",
num_train_epochs=1,
#max_steps=200,
learning_rate=learning_rate,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
weight_decay=weight_decay,
evaluation_strategy="steps",
eval_steps=eval_steps,
logging_steps=logging_steps,
save_steps=eval_steps,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
disable_tqdm=False,
push_to_hub=False,
max_grad_norm=1.0, ##################
report_to="none",
log_level="error"
)
trainer_lora_all_attn = Trainer(
model=model_lora_all_attn,
args=training_args_lora_all_attn,
train_dataset=dataset_encoded["train"],
eval_dataset=dataset_encoded["validation"],
compute_metrics=compute_metrics,
)
hooks, monitor = monitor_lora_parameters(trainer_lora_all_attn.model)
#Train!
trainer_lora_all_attn.train()
#eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
#print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
#print (torch.cuda.memory_summary())
for hook in hooks:
hook.remove()
# Aggregate/log after training
from collections import defaultdict
agg = defaultdict(list)
for key, vals in monitor.items():
grouped = defaultdict(list)
for name, val in vals:
grouped[name].append(val)
agg[key] = {name: sum(vs)/len(vs) for name, vs in grouped.items()}
for name in agg["A_abs_mean"]:
print(f"{name}: |A|={agg['A_abs_mean'][name]:.4g}, |B|={agg['B_abs_mean'][name]:.4g}, "
f"|∇A|={agg['A_grad_mean'][name]:.4g}, |∇B|={agg['B_grad_mean'][name]:.4g}, "
f"|LoRA(x)|={agg['lora_output_norm'][name]:.4g}, B≠0={agg['B_nonzero_count'][name]:.0f}")
#print('Parameter Statistics: mean.abs()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(name, param.abs().mean().item())
#print('Parameter Statistics: param.norm()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(f"{name} weight norm: {param.norm().item():.4f}")
LoRA (All Attention) - Total parameters: 68,282,114
LoRA (All Attention) - Trainable parameters: 1,919,234 (2.81%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
| Step | Training Loss | Validation Loss | Accuracy | F1 |
|---|---|---|---|---|
| 50 | 0.641200 | 0.498864 | 0.825600 | 0.825801 |
| 100 | 0.398900 | 0.378827 | 0.832000 | 0.831737 |
| 150 | 0.342200 | 0.294475 | 0.880000 | 0.880212 |
| 200 | 0.277900 | 0.277569 | 0.884800 | 0.884265 |
| 250 | 0.299900 | 0.280815 | 0.888000 | 0.888238 |
| 300 | 0.267600 | 0.269878 | 0.896800 | 0.896389 |
| 350 | 0.277600 | 0.252492 | 0.896800 | 0.896903 |
| 400 | 0.263900 | 0.252141 | 0.903200 | 0.903285 |
| 450 | 0.251700 | 0.254847 | 0.900000 | 0.900150 |
| 500 | 0.270300 | 0.241287 | 0.904000 | 0.903867 |
| 550 | 0.253100 | 0.240126 | 0.901600 | 0.901454 |
| 600 | 0.271100 | 0.235365 | 0.904000 | 0.903921 |
| 650 | 0.241800 | 0.243780 | 0.903200 | 0.903353 |
| 700 | 0.233400 | 0.235918 | 0.905600 | 0.905600 |
| 750 | 0.235500 | 0.234417 | 0.907200 | 0.907186 |
distilbert.transformer.layer.0.attention.q_lin: |A|=0.02909, |B|=0.0002327, |∇A|=0.0002211, |∇B|=0.007083, |LoRA(x)|=0.6601, B≠0=12282 distilbert.transformer.layer.0.attention.k_lin: |A|=0.02878, |B|=0.0002298, |∇A|=0.0001474, |∇B|=0.007112, |LoRA(x)|=0.7828, B≠0=12281 distilbert.transformer.layer.0.attention.v_lin: |A|=0.02878, |B|=0.0002287, |∇A|=0.0008088, |∇B|=0.01975, |LoRA(x)|=0.7497, B≠0=12282 distilbert.transformer.layer.0.attention.out_lin: |A|=0.02893, |B|=0.0002272, |∇A|=0.001325, |∇B|=0.02585, |LoRA(x)|=0.4959, B≠0=12282 distilbert.transformer.layer.0.ffn.lin1: |A|=0.01825, |B|=0.000217, |∇A|=0.0006296, |∇B|=0.009524, |LoRA(x)|=2.611, B≠0=49126 distilbert.transformer.layer.0.ffn.lin2: |A|=0.01823, |B|=0.0002204, |∇A|=0.0003984, |∇B|=0.01871, |LoRA(x)|=0.5402, B≠0=12282 distilbert.transformer.layer.1.attention.q_lin: |A|=0.02888, |B|=0.0002127, |∇A|=0.000235, |∇B|=0.008725, |LoRA(x)|=0.8818, B≠0=12280 distilbert.transformer.layer.1.attention.k_lin: |A|=0.02899, |B|=0.0002166, |∇A|=0.000186, |∇B|=0.007929, |LoRA(x)|=0.5567, B≠0=12281 distilbert.transformer.layer.1.attention.v_lin: |A|=0.02856, |B|=0.000197, |∇A|=0.0005595, |∇B|=0.03967, |LoRA(x)|=0.7959, B≠0=12281 distilbert.transformer.layer.1.attention.out_lin: |A|=0.02866, |B|=0.0002057, |∇A|=0.00109, |∇B|=0.03448, |LoRA(x)|=0.5205, B≠0=12282 distilbert.transformer.layer.1.ffn.lin1: |A|=0.0181, |B|=0.0002256, |∇A|=0.0005969, |∇B|=0.01394, |LoRA(x)|=4.284, B≠0=49126 distilbert.transformer.layer.1.ffn.lin2: |A|=0.01816, |B|=0.0001935, |∇A|=0.0003225, |∇B|=0.01989, |LoRA(x)|=0.2153, B≠0=12281 distilbert.transformer.layer.2.attention.q_lin: |A|=0.02859, |B|=0.0002196, |∇A|=0.0002355, |∇B|=0.01189, |LoRA(x)|=1.088, B≠0=12282 distilbert.transformer.layer.2.attention.k_lin: |A|=0.02892, |B|=0.0002278, |∇A|=0.0002871, |∇B|=0.01122, |LoRA(x)|=0.6679, B≠0=12282 distilbert.transformer.layer.2.attention.v_lin: |A|=0.02891, |B|=0.000196, |∇A|=0.0007276, |∇B|=0.03889, |LoRA(x)|=0.7448, B≠0=12281 distilbert.transformer.layer.2.attention.out_lin: |A|=0.02864, |B|=0.0002144, |∇A|=0.001153, |∇B|=0.03337, |LoRA(x)|=0.521, B≠0=12282 distilbert.transformer.layer.2.ffn.lin1: |A|=0.01813, |B|=0.0002415, |∇A|=0.001008, |∇B|=0.01763, |LoRA(x)|=5.14, B≠0=49126 distilbert.transformer.layer.2.ffn.lin2: |A|=0.01825, |B|=0.0002205, |∇A|=0.0005423, |∇B|=0.02212, |LoRA(x)|=0.3106, B≠0=12282 distilbert.transformer.layer.3.attention.q_lin: |A|=0.02898, |B|=0.0002311, |∇A|=0.000282, |∇B|=0.009444, |LoRA(x)|=1.016, B≠0=12281 distilbert.transformer.layer.3.attention.k_lin: |A|=0.02904, |B|=0.0002356, |∇A|=0.0002989, |∇B|=0.009808, |LoRA(x)|=0.9602, B≠0=12282 distilbert.transformer.layer.3.attention.v_lin: |A|=0.02898, |B|=0.0002056, |∇A|=0.0007878, |∇B|=0.05867, |LoRA(x)|=1.151, B≠0=12282 distilbert.transformer.layer.3.attention.out_lin: |A|=0.02881, |B|=0.0002671, |∇A|=0.002097, |∇B|=0.02614, |LoRA(x)|=0.4825, B≠0=12282 distilbert.transformer.layer.3.ffn.lin1: |A|=0.01824, |B|=0.0002449, |∇A|=0.001385, |∇B|=0.01429, |LoRA(x)|=3.554, B≠0=49127 distilbert.transformer.layer.3.ffn.lin2: |A|=0.01825, |B|=0.0002515, |∇A|=0.0008056, |∇B|=0.02044, |LoRA(x)|=0.3727, B≠0=12282 distilbert.transformer.layer.4.attention.q_lin: |A|=0.02899, |B|=0.0002635, |∇A|=0.000508, |∇B|=0.01414, |LoRA(x)|=2.408, B≠0=12282 distilbert.transformer.layer.4.attention.k_lin: |A|=0.02889, |B|=0.00025, |∇A|=0.000996, |∇B|=0.01001, |LoRA(x)|=1.021, B≠0=12282 distilbert.transformer.layer.4.attention.v_lin: |A|=0.02878, |B|=0.0002282, |∇A|=0.00123, |∇B|=0.03556, |LoRA(x)|=1.406, B≠0=12282 
distilbert.transformer.layer.4.attention.out_lin: |A|=0.02887, |B|=0.000246, |∇A|=0.001401, |∇B|=0.0175, |LoRA(x)|=0.645, B≠0=12282 distilbert.transformer.layer.4.ffn.lin1: |A|=0.01833, |B|=0.0002607, |∇A|=0.0008604, |∇B|=0.009056, |LoRA(x)|=5.717, B≠0=49129 distilbert.transformer.layer.4.ffn.lin2: |A|=0.01822, |B|=0.0002721, |∇A|=0.0004477, |∇B|=0.01619, |LoRA(x)|=0.7457, B≠0=12282 distilbert.transformer.layer.5.attention.q_lin: |A|=0.02888, |B|=0.0002728, |∇A|=0.0002167, |∇B|=0.00796, |LoRA(x)|=1.832, B≠0=12282 distilbert.transformer.layer.5.attention.k_lin: |A|=0.02903, |B|=0.0002349, |∇A|=0.0007984, |∇B|=0.004929, |LoRA(x)|=2.066, B≠0=12281 distilbert.transformer.layer.5.attention.v_lin: |A|=0.02869, |B|=0.0002458, |∇A|=0.001274, |∇B|=0.01634, |LoRA(x)|=1.781, B≠0=12282 distilbert.transformer.layer.5.attention.out_lin: |A|=0.02881, |B|=0.0003804, |∇A|=0.001541, |∇B|=0.008762, |LoRA(x)|=1.76, B≠0=12283 distilbert.transformer.layer.5.ffn.lin1: |A|=0.01824, |B|=0.0002594, |∇A|=0.001008, |∇B|=0.003997, |LoRA(x)|=6.319, B≠0=49127 distilbert.transformer.layer.5.ffn.lin2: |A|=0.01822, |B|=0.0003608, |∇A|=0.0004907, |∇B|=0.005384, |LoRA(x)|=1.169, B≠0=12283
Scaled Xavier for A
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
torch.autograd.set_detect_anomaly(True)
class LoRALayer(nn.Module):
def __init__(self, in_dim, out_dim, rank, alpha, dropout_rate=0.0):
super().__init__()
std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
#self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
# Scaled Xavier initialisation for A
sc_f = 1.5
bound = math.sqrt(6 / (rank + in_dim)) * sc_f
self.A = nn.Parameter(torch.empty(in_dim, rank).uniform_(-bound, bound))
self.B = nn.Parameter(1e-6 * torch.randn(rank, out_dim) * std_dev) # Not all zeroes!
self.alpha = alpha
self.dropout = nn.Dropout(dropout_rate)
def forward(self, x):
# Dropout applied to the projection to the lower-dimensional space by A
dropped = self.dropout(x @ self.A)
return self.alpha * (dropped @ self.B)
class LinearWithLoRA(nn.Module):
def __init__(self, linear, rank, alpha):
super().__init__()
self.linear = linear
self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
self.last_lora_output_norm = 0.0 # for monitoring
def forward(self, x):
#return self.linear(x) + self.lora(x)
lora_out = self.lora(x)
self.last_lora_output_norm = lora_out.norm(p=2, dim=-1).mean().item()
return self.linear(x) + lora_out
# Function to inject LoRA into specified linear layers
def inject_lora_all_attn(model, rank, alpha, dropout_rate=0.0):
target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin", "ffn.lin1", "ffn.lin2"]
# target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
for name, module in model.named_modules():
if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
parent_name = name.rsplit('.', 1)[0]
parent_module = model.get_submodule(parent_name)
original_linear = getattr(parent_module, name.split('.')[-1])
lora_linear = LinearWithLoRA(original_linear, rank, alpha)
lora_linear.lora.dropout = nn.Dropout(dropout_rate)
setattr(parent_module, name.split('.')[-1], lora_linear)
return model
dropout = 0.0
learning_rate = 1.5e-5
lora_rank = 16
lora_alpha = 32
weight_decay = 1e-5 # L2
batch_size = 32
output_dir_prefix = "finetuned-imdb-"
import copy
from transformers import TrainingArguments
torch.manual_seed(137)
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha, dropout)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)
total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
#print("\nTrainable parameters after freezing:")
#for name, param in model_lora_all_attn.named_parameters():
# if param.requires_grad:
# print(name)
eval_steps = 50
logging_steps = 50
training_args_lora_all_attn = TrainingArguments(
output_dir=f"{output_dir_prefix}lora-all-attn",
num_train_epochs=1,
#max_steps=200,
learning_rate=learning_rate,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
weight_decay=weight_decay,
evaluation_strategy="steps",
eval_steps=eval_steps,
logging_steps=logging_steps,
save_steps=eval_steps,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
disable_tqdm=False,
push_to_hub=False,
max_grad_norm=1.0, ##################
report_to="none",
log_level="error"
)
trainer_lora_all_attn = Trainer(
model=model_lora_all_attn,
args=training_args_lora_all_attn,
train_dataset=dataset_encoded["train"],
eval_dataset=dataset_encoded["validation"],
compute_metrics=compute_metrics,
)
hooks, monitor = monitor_lora_parameters(trainer_lora_all_attn.model)
#Train!
trainer_lora_all_attn.train()
#eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
#print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
#print (torch.cuda.memory_summary())
for hook in hooks:
hook.remove()
# Aggregate/log after training
from collections import defaultdict
agg = defaultdict(list)
for key, vals in monitor.items():
grouped = defaultdict(list)
for name, val in vals:
grouped[name].append(val)
agg[key] = {name: sum(vs)/len(vs) for name, vs in grouped.items()}
for name in agg["A_abs_mean"]:
print(f"{name}: |A|={agg['A_abs_mean'][name]:.4g}, |B|={agg['B_abs_mean'][name]:.4g}, "
f"|∇A|={agg['A_grad_mean'][name]:.4g}, |∇B|={agg['B_grad_mean'][name]:.4g}, "
f"|LoRA(x)|={agg['lora_output_norm'][name]:.4g}, B≠0={agg['B_nonzero_count'][name]:.0f}")
#print('Parameter Statistics: mean.abs()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(name, param.abs().mean().item())
#print('Parameter Statistics: param.norm()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(f"{name} weight norm: {param.norm().item():.4f}")
LoRA (All Attention) - Total parameters: 68,282,114
LoRA (All Attention) - Trainable parameters: 1,919,234 (2.81%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
| Step | Training Loss | Validation Loss | Accuracy | F1 |
|---|---|---|---|---|
| 50 | 0.559900 | 0.342404 | 0.867200 | 0.866935 |
| 100 | 0.335600 | 0.298127 | 0.878400 | 0.878565 |
| 150 | 0.330100 | 0.269384 | 0.892000 | 0.891881 |
| 200 | 0.256100 | 0.272999 | 0.888800 | 0.888328 |
| 250 | 0.297100 | 0.247388 | 0.900800 | 0.900662 |
| 300 | 0.267600 | 0.260106 | 0.892000 | 0.891542 |
| 350 | 0.270000 | 0.245957 | 0.905600 | 0.905600 |
| 400 | 0.255400 | 0.242078 | 0.904800 | 0.904934 |
| 450 | 0.244500 | 0.241970 | 0.912000 | 0.911895 |
| 500 | 0.257700 | 0.235908 | 0.908800 | 0.908494 |
| 550 | 0.241300 | 0.237024 | 0.910400 | 0.910055 |
| 600 | 0.257700 | 0.227201 | 0.913600 | 0.913390 |
| 650 | 0.231800 | 0.231505 | 0.910400 | 0.910538 |
| 700 | 0.227300 | 0.224408 | 0.918400 | 0.918237 |
| 750 | 0.222500 | 0.222870 | 0.916000 | 0.915892 |
distilbert.transformer.layer.0.attention.q_lin: |A|=0.06532, |B|=0.0002035, |∇A|=0.0001663, |∇B|=0.01523, |LoRA(x)|=1.255, B≠0=12281 distilbert.transformer.layer.0.attention.k_lin: |A|=0.06534, |B|=0.0002003, |∇A|=0.0001101, |∇B|=0.01507, |LoRA(x)|=1.249, B≠0=12281 distilbert.transformer.layer.0.attention.v_lin: |A|=0.06566, |B|=0.0001929, |∇A|=0.0006117, |∇B|=0.03535, |LoRA(x)|=1.217, B≠0=12281 distilbert.transformer.layer.0.attention.out_lin: |A|=0.0661, |B|=0.0001711, |∇A|=0.0007501, |∇B|=0.04806, |LoRA(x)|=0.6202, B≠0=12281 distilbert.transformer.layer.0.ffn.lin1: |A|=0.06538, |B|=0.0001835, |∇A|=0.0004738, |∇B|=0.0315, |LoRA(x)|=6.167, B≠0=49123 distilbert.transformer.layer.0.ffn.lin2: |A|=0.03304, |B|=0.0001718, |∇A|=0.000289, |∇B|=0.02706, |LoRA(x)|=0.4766, B≠0=12280 distilbert.transformer.layer.1.attention.q_lin: |A|=0.06618, |B|=0.0001879, |∇A|=0.0001798, |∇B|=0.01696, |LoRA(x)|=1.24, B≠0=12280 distilbert.transformer.layer.1.attention.k_lin: |A|=0.06591, |B|=0.000191, |∇A|=0.0001451, |∇B|=0.01575, |LoRA(x)|=0.9638, B≠0=12281 distilbert.transformer.layer.1.attention.v_lin: |A|=0.06548, |B|=0.0001601, |∇A|=0.0004498, |∇B|=0.06403, |LoRA(x)|=0.9895, B≠0=12280 distilbert.transformer.layer.1.attention.out_lin: |A|=0.06559, |B|=0.0001658, |∇A|=0.0006209, |∇B|=0.05379, |LoRA(x)|=0.5784, B≠0=12281 distilbert.transformer.layer.1.ffn.lin1: |A|=0.06538, |B|=0.0001922, |∇A|=0.0004596, |∇B|=0.0344, |LoRA(x)|=8.039, B≠0=49123 distilbert.transformer.layer.1.ffn.lin2: |A|=0.03304, |B|=0.0001619, |∇A|=0.0002367, |∇B|=0.0329, |LoRA(x)|=0.297, B≠0=12280 distilbert.transformer.layer.2.attention.q_lin: |A|=0.06613, |B|=0.0001945, |∇A|=0.0002159, |∇B|=0.0249, |LoRA(x)|=1.65, B≠0=12281 distilbert.transformer.layer.2.attention.k_lin: |A|=0.06571, |B|=0.0001978, |∇A|=0.0002439, |∇B|=0.02332, |LoRA(x)|=1.206, B≠0=12281 distilbert.transformer.layer.2.attention.v_lin: |A|=0.06583, |B|=0.0001605, |∇A|=0.0005263, |∇B|=0.06787, |LoRA(x)|=0.9769, B≠0=12280 distilbert.transformer.layer.2.attention.out_lin: |A|=0.06489, |B|=0.0001754, |∇A|=0.0007919, |∇B|=0.04553, |LoRA(x)|=0.5327, B≠0=12280 distilbert.transformer.layer.2.ffn.lin1: |A|=0.06513, |B|=0.0001948, |∇A|=0.0006043, |∇B|=0.04044, |LoRA(x)|=7.488, B≠0=49124 distilbert.transformer.layer.2.ffn.lin2: |A|=0.03309, |B|=0.0001835, |∇A|=0.0004721, |∇B|=0.02781, |LoRA(x)|=0.3463, B≠0=12281 distilbert.transformer.layer.3.attention.q_lin: |A|=0.066, |B|=0.0002053, |∇A|=0.0002105, |∇B|=0.02061, |LoRA(x)|=2.099, B≠0=12281 distilbert.transformer.layer.3.attention.k_lin: |A|=0.06551, |B|=0.0002091, |∇A|=0.0002618, |∇B|=0.02079, |LoRA(x)|=1.555, B≠0=12281 distilbert.transformer.layer.3.attention.v_lin: |A|=0.06551, |B|=0.0001497, |∇A|=0.0004883, |∇B|=0.08872, |LoRA(x)|=1.46, B≠0=12280 distilbert.transformer.layer.3.attention.out_lin: |A|=0.06552, |B|=0.0001923, |∇A|=0.0008196, |∇B|=0.03743, |LoRA(x)|=0.5979, B≠0=12281 distilbert.transformer.layer.3.ffn.lin1: |A|=0.06555, |B|=0.000184, |∇A|=0.0006847, |∇B|=0.04193, |LoRA(x)|=8.602, B≠0=49123 distilbert.transformer.layer.3.ffn.lin2: |A|=0.03308, |B|=0.0001783, |∇A|=0.0004434, |∇B|=0.0232, |LoRA(x)|=0.3174, B≠0=12280 distilbert.transformer.layer.4.attention.q_lin: |A|=0.06642, |B|=0.0002137, |∇A|=0.0003046, |∇B|=0.02595, |LoRA(x)|=3.838, B≠0=12281 distilbert.transformer.layer.4.attention.k_lin: |A|=0.06605, |B|=0.0002044, |∇A|=0.000532, |∇B|=0.01982, |LoRA(x)|=2.01, B≠0=12281 distilbert.transformer.layer.4.attention.v_lin: |A|=0.0658, |B|=0.0001578, |∇A|=0.0006522, |∇B|=0.06364, |LoRA(x)|=1.776, B≠0=12280 
distilbert.transformer.layer.4.attention.out_lin: |A|=0.0653, |B|=0.0001886, |∇A|=0.0008771, |∇B|=0.02987, |LoRA(x)|=0.784, B≠0=12281 distilbert.transformer.layer.4.ffn.lin1: |A|=0.06552, |B|=0.0002044, |∇A|=0.0006782, |∇B|=0.02294, |LoRA(x)|=9.608, B≠0=49125 distilbert.transformer.layer.4.ffn.lin2: |A|=0.03295, |B|=0.0002195, |∇A|=0.0004506, |∇B|=0.01567, |LoRA(x)|=0.4364, B≠0=12282 distilbert.transformer.layer.5.attention.q_lin: |A|=0.06579, |B|=0.000201, |∇A|=0.000149, |∇B|=0.01128, |LoRA(x)|=1.89, B≠0=12281 distilbert.transformer.layer.5.attention.k_lin: |A|=0.06566, |B|=0.0001923, |∇A|=0.0005654, |∇B|=0.01185, |LoRA(x)|=2.068, B≠0=12281 distilbert.transformer.layer.5.attention.v_lin: |A|=0.06563, |B|=0.0001716, |∇A|=0.0005953, |∇B|=0.03601, |LoRA(x)|=2.725, B≠0=12281 distilbert.transformer.layer.5.attention.out_lin: |A|=0.0654, |B|=0.0002853, |∇A|=0.0008672, |∇B|=0.01711, |LoRA(x)|=2.471, B≠0=12282 distilbert.transformer.layer.5.ffn.lin1: |A|=0.0657, |B|=0.0001973, |∇A|=0.0004453, |∇B|=0.009803, |LoRA(x)|=11.89, B≠0=49123 distilbert.transformer.layer.5.ffn.lin2: |A|=0.03308, |B|=0.0002221, |∇A|=0.0001933, |∇B|=0.008399, |LoRA(x)|=1.134, B≠0=12281
Orthogonal initialisation for A
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
torch.autograd.set_detect_anomaly(True)
class LoRALayer(nn.Module):
def __init__(self, in_dim, out_dim, rank, alpha, dropout_rate=0.0):
super().__init__()
std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
#self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
# Orthogonal initialisation for A
self.A = nn.Parameter(torch.empty(in_dim, rank))
nn.init.orthogonal_(self.A)
self.B = nn.Parameter(1e-6 * torch.randn(rank, out_dim) * std_dev) # Not all zeroes!
self.alpha = alpha
self.dropout = nn.Dropout(dropout_rate)
def forward(self, x):
# Dropout applied to the projection to the lower-dimensional space by A
dropped = self.dropout(x @ self.A)
return self.alpha * (dropped @ self.B)
class LinearWithLoRA(nn.Module):
def __init__(self, linear, rank, alpha):
super().__init__()
self.linear = linear
self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
self.last_lora_output_norm = 0.0 # for monitoring
def forward(self, x):
#return self.linear(x) + self.lora(x)
lora_out = self.lora(x)
self.last_lora_output_norm = lora_out.norm(p=2, dim=-1).mean().item()
return self.linear(x) + lora_out
# Function to inject LoRA into specified linear layers
def inject_lora_all_attn(model, rank, alpha, dropout_rate=0.0):
target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin", "ffn.lin1", "ffn.lin2"]
# target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
for name, module in model.named_modules():
if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
parent_name = name.rsplit('.', 1)[0]
parent_module = model.get_submodule(parent_name)
original_linear = getattr(parent_module, name.split('.')[-1])
lora_linear = LinearWithLoRA(original_linear, rank, alpha)
lora_linear.lora.dropout = nn.Dropout(dropout_rate)
setattr(parent_module, name.split('.')[-1], lora_linear)
return model
dropout = 0.0
learning_rate = 1.5e-5
lora_rank = 16
lora_alpha = 32
weight_decay = 1e-5 # L2
batch_size = 32
output_dir_prefix = "finetuned-imdb-"
import copy
from transformers import TrainingArguments
torch.manual_seed(137)
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha, dropout)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)
total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
#print("\nTrainable parameters after freezing:")
#for name, param in model_lora_all_attn.named_parameters():
# if param.requires_grad:
# print(name)
eval_steps = 50
logging_steps = 50
training_args_lora_all_attn = TrainingArguments(
output_dir=f"{output_dir_prefix}lora-all-attn",
num_train_epochs=1,
#max_steps=200,
learning_rate=learning_rate,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
weight_decay=weight_decay,
evaluation_strategy="steps",
eval_steps=eval_steps,
logging_steps=logging_steps,
save_steps=eval_steps,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
disable_tqdm=False,
push_to_hub=False,
max_grad_norm=1.0, ##################
report_to="none",
log_level="error"
)
trainer_lora_all_attn = Trainer(
model=model_lora_all_attn,
args=training_args_lora_all_attn,
train_dataset=dataset_encoded["train"],
eval_dataset=dataset_encoded["validation"],
compute_metrics=compute_metrics,
)
hooks, monitor = monitor_lora_parameters(trainer_lora_all_attn.model)
#Train!
trainer_lora_all_attn.train()
#eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
#print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
#print (torch.cuda.memory_summary())
for hook in hooks:
hook.remove()
# Aggregate/log after training
from collections import defaultdict
agg = defaultdict(list)
for key, vals in monitor.items():
grouped = defaultdict(list)
for name, val in vals:
grouped[name].append(val)
agg[key] = {name: sum(vs)/len(vs) for name, vs in grouped.items()}
for name in agg["A_abs_mean"]:
print(f"{name}: |A|={agg['A_abs_mean'][name]:.4g}, |B|={agg['B_abs_mean'][name]:.4g}, "
f"|∇A|={agg['A_grad_mean'][name]:.4g}, |∇B|={agg['B_grad_mean'][name]:.4g}, "
f"|LoRA(x)|={agg['lora_output_norm'][name]:.4g}, B≠0={agg['B_nonzero_count'][name]:.0f}")
#print('Parameter Statistics: mean.abs()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(name, param.abs().mean().item())
#print('Parameter Statistics: param.norm()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(f"{name} weight norm: {param.norm().item():.4f}")
LoRA (All Attention) - Total parameters: 68,282,114
LoRA (All Attention) - Trainable parameters: 1,919,234 (2.81%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
| Step | Training Loss | Validation Loss | Accuracy | F1 |
|---|---|---|---|---|
| 50 | 0.631800 | 0.476356 | 0.832000 | 0.832194 |
| 100 | 0.386100 | 0.353101 | 0.848000 | 0.848142 |
| 150 | 0.338900 | 0.293442 | 0.880800 | 0.881030 |
| 200 | 0.273900 | 0.276142 | 0.885600 | 0.885174 |
| 250 | 0.299300 | 0.277328 | 0.892000 | 0.892224 |
| 300 | 0.267800 | 0.270556 | 0.896000 | 0.895489 |
| 350 | 0.275800 | 0.252276 | 0.899200 | 0.899257 |
| 400 | 0.261600 | 0.252071 | 0.902400 | 0.902492 |
| 450 | 0.248400 | 0.254714 | 0.898400 | 0.898543 |
| 500 | 0.268600 | 0.241609 | 0.907200 | 0.907034 |
| 550 | 0.251900 | 0.238318 | 0.904800 | 0.904695 |
| 600 | 0.268800 | 0.233752 | 0.906400 | 0.906314 |
| 650 | 0.239300 | 0.242723 | 0.902400 | 0.902550 |
| 700 | 0.233100 | 0.234716 | 0.906400 | 0.906393 |
| 750 | 0.233200 | 0.233564 | 0.904800 | 0.904793 |
distilbert.transformer.layer.0.attention.q_lin: |A|=0.02894, |B|=0.0002256, |∇A|=0.0002051, |∇B|=0.00704, |LoRA(x)|=0.6222, B≠0=12282
distilbert.transformer.layer.0.attention.k_lin: |A|=0.0289, |B|=0.0002237, |∇A|=0.0001388, |∇B|=0.007028, |LoRA(x)|=0.7385, B≠0=12282
distilbert.transformer.layer.0.attention.v_lin: |A|=0.02868, |B|=0.0002192, |∇A|=0.0007466, |∇B|=0.01938, |LoRA(x)|=0.6943, B≠0=12281
distilbert.transformer.layer.0.attention.out_lin: |A|=0.02873, |B|=0.0002161, |∇A|=0.001234, |∇B|=0.02534, |LoRA(x)|=0.4499, B≠0=12281
distilbert.transformer.layer.0.ffn.lin1: |A|=0.02887, |B|=0.0002077, |∇A|=0.0005898, |∇B|=0.01437, |LoRA(x)|=3.549, B≠0=49125
distilbert.transformer.layer.0.ffn.lin2: |A|=0.01439, |B|=0.0002112, |∇A|=0.0003794, |∇B|=0.0146, |LoRA(x)|=0.4078, B≠0=12281
distilbert.transformer.layer.1.attention.q_lin: |A|=0.02871, |B|=0.0002067, |∇A|=0.0002277, |∇B|=0.008702, |LoRA(x)|=0.8418, B≠0=12281
distilbert.transformer.layer.1.attention.k_lin: |A|=0.02878, |B|=0.0002133, |∇A|=0.0001856, |∇B|=0.007924, |LoRA(x)|=0.5599, B≠0=12281
distilbert.transformer.layer.1.attention.v_lin: |A|=0.02885, |B|=0.0001863, |∇A|=0.0005139, |∇B|=0.03951, |LoRA(x)|=0.7551, B≠0=12281
distilbert.transformer.layer.1.attention.out_lin: |A|=0.02884, |B|=0.0001969, |∇A|=0.001035, |∇B|=0.03354, |LoRA(x)|=0.4694, B≠0=12281
distilbert.transformer.layer.1.ffn.lin1: |A|=0.02882, |B|=0.0002174, |∇A|=0.0005914, |∇B|=0.02101, |LoRA(x)|=5.848, B≠0=49126
distilbert.transformer.layer.1.ffn.lin2: |A|=0.01441, |B|=0.0001863, |∇A|=0.0003113, |∇B|=0.01551, |LoRA(x)|=0.1606, B≠0=12281
distilbert.transformer.layer.2.attention.q_lin: |A|=0.02863, |B|=0.0002147, |∇A|=0.0002332, |∇B|=0.01161, |LoRA(x)|=0.9744, B≠0=12281
distilbert.transformer.layer.2.attention.k_lin: |A|=0.02885, |B|=0.0002233, |∇A|=0.000279, |∇B|=0.01124, |LoRA(x)|=0.6466, B≠0=12282
distilbert.transformer.layer.2.attention.v_lin: |A|=0.02879, |B|=0.0001866, |∇A|=0.0006717, |∇B|=0.03784, |LoRA(x)|=0.6731, B≠0=12281
distilbert.transformer.layer.2.attention.out_lin: |A|=0.02879, |B|=0.0002054, |∇A|=0.001122, |∇B|=0.03079, |LoRA(x)|=0.4415, B≠0=12281
distilbert.transformer.layer.2.ffn.lin1: |A|=0.02879, |B|=0.0002269, |∇A|=0.0008912, |∇B|=0.02662, |LoRA(x)|=7.34, B≠0=49126
distilbert.transformer.layer.2.ffn.lin2: |A|=0.01442, |B|=0.0002097, |∇A|=0.0005021, |∇B|=0.01674, |LoRA(x)|=0.2238, B≠0=12281
distilbert.transformer.layer.3.attention.q_lin: |A|=0.02891, |B|=0.000226, |∇A|=0.0002603, |∇B|=0.009309, |LoRA(x)|=0.9878, B≠0=12281
distilbert.transformer.layer.3.attention.k_lin: |A|=0.02882, |B|=0.0002275, |∇A|=0.000281, |∇B|=0.009663, |LoRA(x)|=0.885, B≠0=12282
distilbert.transformer.layer.3.attention.v_lin: |A|=0.02887, |B|=0.0001937, |∇A|=0.0007371, |∇B|=0.0562, |LoRA(x)|=1.023, B≠0=12281
distilbert.transformer.layer.3.attention.out_lin: |A|=0.02891, |B|=0.0002549, |∇A|=0.001971, |∇B|=0.02536, |LoRA(x)|=0.4443, B≠0=12282
distilbert.transformer.layer.3.ffn.lin1: |A|=0.02888, |B|=0.0002292, |∇A|=0.001239, |∇B|=0.02104, |LoRA(x)|=4.993, B≠0=49126
distilbert.transformer.layer.3.ffn.lin2: |A|=0.0144, |B|=0.0002404, |∇A|=0.0007727, |∇B|=0.01577, |LoRA(x)|=0.2704, B≠0=12282
distilbert.transformer.layer.4.attention.q_lin: |A|=0.02871, |B|=0.0002524, |∇A|=0.000489, |∇B|=0.01352, |LoRA(x)|=2.283, B≠0=12282
distilbert.transformer.layer.4.attention.k_lin: |A|=0.02884, |B|=0.0002416, |∇A|=0.0009936, |∇B|=0.009785, |LoRA(x)|=0.9456, B≠0=12281
distilbert.transformer.layer.4.attention.v_lin: |A|=0.02885, |B|=0.0002123, |∇A|=0.00103, |∇B|=0.0365, |LoRA(x)|=1.395, B≠0=12281
distilbert.transformer.layer.4.attention.out_lin: |A|=0.0288, |B|=0.000233, |∇A|=0.001299, |∇B|=0.01674, |LoRA(x)|=0.5857, B≠0=12282
distilbert.transformer.layer.4.ffn.lin1: |A|=0.02887, |B|=0.0002406, |∇A|=0.0007501, |∇B|=0.01257, |LoRA(x)|=7.437, B≠0=49128
distilbert.transformer.layer.4.ffn.lin2: |A|=0.01443, |B|=0.0002652, |∇A|=0.0004402, |∇B|=0.01286, |LoRA(x)|=0.5918, B≠0=12282
distilbert.transformer.layer.5.attention.q_lin: |A|=0.02885, |B|=0.0002558, |∇A|=0.0001962, |∇B|=0.007697, |LoRA(x)|=1.709, B≠0=12281
distilbert.transformer.layer.5.attention.k_lin: |A|=0.02884, |B|=0.0002236, |∇A|=0.0007342, |∇B|=0.004787, |LoRA(x)|=1.879, B≠0=12281
distilbert.transformer.layer.5.attention.v_lin: |A|=0.02887, |B|=0.0002386, |∇A|=0.001155, |∇B|=0.01626, |LoRA(x)|=1.73, B≠0=12282
distilbert.transformer.layer.5.attention.out_lin: |A|=0.02883, |B|=0.0003673, |∇A|=0.001476, |∇B|=0.008586, |LoRA(x)|=1.781, B≠0=12283
distilbert.transformer.layer.5.ffn.lin1: |A|=0.02882, |B|=0.0002423, |∇A|=0.0007944, |∇B|=0.005601, |LoRA(x)|=9.013, B≠0=49127
distilbert.transformer.layer.5.ffn.lin2: |A|=0.0144, |B|=0.0003482, |∇A|=0.0004651, |∇B|=0.004582, |LoRA(x)|=0.9978, B≠0=12283
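The per-module dump above is dense. A minimal sketch of a hypothetical helper (not part of the original run) that condenses the `agg` dictionary built in the cell above into one row per transformer layer:
# Hypothetical helper: average a monitored metric over the modules of each transformer layer,
# using the `agg` dict produced by the aggregation loop above.
import re
from collections import defaultdict

def summarize_by_layer(agg, metric="B_grad_mean"):
    per_layer = defaultdict(list)
    for module_name, value in agg[metric].items():
        match = re.search(r"layer\.(\d+)\.", module_name)
        if match:
            per_layer[int(match.group(1))].append(value)
    for layer_idx in sorted(per_layer):
        values = per_layer[layer_idx]
        print(f"layer {layer_idx}: mean {metric} = {sum(values) / len(values):.4g} ({len(values)} modules)")

summarize_by_layer(agg, "B_grad_mean")       # average gradient flow into B, per layer
summarize_by_layer(agg, "lora_output_norm")  # average magnitude of the LoRA contribution, per layer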
Increase std.dev for B and apply dropout to boost lora.B adaptation¶
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
torch.autograd.set_detect_anomaly(True)  # debugging aid for NaN/Inf gradients; noticeably slows training
class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha, dropout_rate=0.0):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(1e-3 * torch.randn(rank, out_dim) * std_dev)  # Increased std.dev for B: small random values instead of zeros
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # Dropout applied to the projection into the lower-dimensional space by A
        dropped = self.dropout(x @ self.A)
        return self.alpha * (dropped @ self.B)

class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha, dropout_rate=0.0):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha, dropout_rate)
        self.last_lora_output_norm = 0.0  # for monitoring

    def forward(self, x):
        lora_out = self.lora(x)
        self.last_lora_output_norm = lora_out.norm(p=2, dim=-1).mean().item()
        return self.linear(x) + lora_out

# Function to inject LoRA into the specified linear layers
def inject_lora_all_attn(model, rank, alpha, dropout_rate=0.0):
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin", "ffn.lin1", "ffn.lin2"]
    # target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            # Dropout rate is passed through to the LoRALayer rather than patched in afterwards
            lora_linear = LinearWithLoRA(original_linear, rank, alpha, dropout_rate)
            setattr(parent_module, name.split('.')[-1], lora_linear)
    return model
dropout = 0.4
learning_rate = 1.5e-5
lora_rank = 16
lora_alpha = 32
weight_decay = 1e-5 # L2
batch_size = 32
output_dir_prefix = "finetuned-imdb-"
import copy
from transformers import TrainingArguments
torch.manual_seed(137)
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha, dropout)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)
total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
#print("\nTrainable parameters after freezing:")
#for name, param in model_lora_all_attn.named_parameters():
# if param.requires_grad:
# print(name)
eval_steps = 50
logging_steps = 50
training_args_lora_all_attn = TrainingArguments(
output_dir=f"{output_dir_prefix}lora-all-attn",
num_train_epochs=2,
#max_steps=200,
learning_rate=learning_rate,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
weight_decay=weight_decay,
evaluation_strategy="steps",
eval_steps=eval_steps,
logging_steps=logging_steps,
save_steps=eval_steps,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
disable_tqdm=False,
push_to_hub=False,
max_grad_norm=1.0,  # gradient clipping
report_to="none",
log_level="error"
)
trainer_lora_all_attn = Trainer(
model=model_lora_all_attn,
args=training_args_lora_all_attn,
train_dataset=dataset_encoded["train"],
eval_dataset=dataset_encoded["validation"],
compute_metrics=compute_metrics,
)
hooks, monitor = monitor_lora_parameters(trainer_lora_all_attn.model)
#Train!
trainer_lora_all_attn.train()
eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
print(torch.cuda.memory_summary())
for hook in hooks:
    hook.remove()
# Aggregate/log after training
from collections import defaultdict
agg = defaultdict(list)
for key, vals in monitor.items():
    grouped = defaultdict(list)
    for name, val in vals:
        grouped[name].append(val)
    agg[key] = {name: sum(vs)/len(vs) for name, vs in grouped.items()}
for name in agg["A_abs_mean"]:
    print(f"{name}: |A|={agg['A_abs_mean'][name]:.4g}, |B|={agg['B_abs_mean'][name]:.4g}, "
          f"|∇A|={agg['A_grad_mean'][name]:.4g}, |∇B|={agg['B_grad_mean'][name]:.4g}, "
          f"|LoRA(x)|={agg['lora_output_norm'][name]:.4g}, B≠0={agg['B_nonzero_count'][name]:.0f}")
print('Parameter Statistics: mean.abs()')
for name, param in model_lora_all_attn.named_parameters():
    if "lora" in name:
        print(name, param.abs().mean().item())
print('Parameter Statistics: param.norm()')
for name, param in model_lora_all_attn.named_parameters():
    if "lora" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
LoRA (All Attention) - Total parameters: 68,282,114
LoRA (All Attention) - Trainable parameters: 1,919,234 (2.81%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
| Step | Training Loss | Validation Loss | Accuracy | F1 |
|---|---|---|---|---|
| 50 | 0.572400 | 0.361935 | 0.853600 | 0.853897 |
| 100 | 0.395200 | 0.314839 | 0.871200 | 0.871393 |
| 150 | 0.388700 | 0.324395 | 0.872000 | 0.872105 |
| 200 | 0.359600 | 0.288238 | 0.875200 | 0.875027 |
| 250 | 0.375300 | 0.355832 | 0.841600 | 0.841525 |
| 300 | 0.348600 | 0.279084 | 0.888000 | 0.887753 |
| 350 | 0.360400 | 0.307965 | 0.874400 | 0.874675 |
| 400 | 0.331700 | 0.285904 | 0.896800 | 0.897028 |
| 450 | 0.313200 | 0.311799 | 0.895200 | 0.895123 |
| 500 | 0.353700 | 0.291399 | 0.883200 | 0.882492 |
| 550 | 0.313900 | 0.274373 | 0.881600 | 0.881711 |
| 600 | 0.344400 | 0.289618 | 0.887200 | 0.886285 |
| 650 | 0.311900 | 0.266004 | 0.895200 | 0.895175 |
| 700 | 0.313900 | 0.260952 | 0.904800 | 0.904677 |
| 750 | 0.311800 | 0.276892 | 0.896800 | 0.896441 |
| 800 | 0.331900 | 0.266435 | 0.900000 | 0.899438 |
| 850 | 0.289800 | 0.267684 | 0.902400 | 0.902573 |
| 900 | 0.283600 | 0.294683 | 0.899200 | 0.899422 |
| 950 | 0.307300 | 0.265540 | 0.906400 | 0.906297 |
| 1000 | 0.270400 | 0.288483 | 0.900000 | 0.900111 |
| 1050 | 0.266700 | 0.258305 | 0.907200 | 0.906888 |
| 1100 | 0.276200 | 0.250781 | 0.905600 | 0.905555 |
| 1150 | 0.253400 | 0.249464 | 0.902400 | 0.902573 |
| 1200 | 0.278300 | 0.245663 | 0.912800 | 0.912704 |
| 1250 | 0.260800 | 0.250558 | 0.904800 | 0.904659 |
| 1300 | 0.266500 | 0.247603 | 0.907200 | 0.907308 |
| 1350 | 0.260500 | 0.247938 | 0.904800 | 0.904915 |
| 1400 | 0.250800 | 0.251896 | 0.904000 | 0.903701 |
| 1450 | 0.252900 | 0.245024 | 0.907200 | 0.907171 |
| 1500 | 0.255200 | 0.245073 | 0.904800 | 0.904639 |
| 1550 | 0.242100 | 0.244082 | 0.908800 | 0.908786 |
LoRA (All Attention) Test Results: {'eval_loss': 0.2366224229335785, 'eval_accuracy': 0.907621052631579, 'eval_f1': 0.9076122672728953, 'eval_runtime': 184.1411, 'eval_samples_per_second': 128.977, 'eval_steps_per_second': 4.035, 'epoch': 2.0}
|===========================================================================|
| PyTorch CUDA memory summary, device ID 0 |
|---------------------------------------------------------------------------|
| CUDA OOMs: 0 | cudaMalloc retries: 0 |
|===========================================================================|
| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |
|---------------------------------------------------------------------------|
| Allocated memory | 563068 KiB | 4397 MiB | 232710 GiB | 232710 GiB |
| from large pool | 546990 KiB | 4344 MiB | 230922 GiB | 230922 GiB |
| from small pool | 16078 KiB | 55 MiB | 1788 GiB | 1788 GiB |
|---------------------------------------------------------------------------|
| Active memory | 563068 KiB | 4397 MiB | 232710 GiB | 232710 GiB |
| from large pool | 546990 KiB | 4344 MiB | 230922 GiB | 230922 GiB |
| from small pool | 16078 KiB | 55 MiB | 1788 GiB | 1788 GiB |
|---------------------------------------------------------------------------|
| Requested memory | 559912 KiB | 4394 MiB | 232500 GiB | 232499 GiB |
| from large pool | 543836 KiB | 4340 MiB | 230718 GiB | 230718 GiB |
| from small pool | 16076 KiB | 55 MiB | 1781 GiB | 1781 GiB |
|---------------------------------------------------------------------------|
| GPU reserved memory | 4510 MiB | 4580 MiB | 23456 MiB | 18946 MiB |
| from large pool | 4450 MiB | 4520 MiB | 23204 MiB | 18754 MiB |
| from small pool | 60 MiB | 62 MiB | 252 MiB | 192 MiB |
|---------------------------------------------------------------------------|
| Non-releasable memory | 182404 KiB | 382935 KiB | 79215 GiB | 79215 GiB |
| from large pool | 175954 KiB | 376402 KiB | 77291 GiB | 77290 GiB |
| from small pool | 6450 KiB | 30566 KiB | 1924 GiB | 1924 GiB |
|---------------------------------------------------------------------------|
| Allocations | 436 | 623 | 30226 K | 30225 K |
| from large pool | 82 | 152 | 6772 K | 6772 K |
| from small pool | 354 | 515 | 23453 K | 23453 K |
|---------------------------------------------------------------------------|
| Active allocs | 436 | 623 | 30226 K | 30225 K |
| from large pool | 82 | 152 | 6772 K | 6772 K |
| from small pool | 354 | 515 | 23453 K | 23453 K |
|---------------------------------------------------------------------------|
| GPU reserved segments | 97 | 101 | 431 | 334 |
| from large pool | 67 | 71 | 305 | 238 |
| from small pool | 30 | 31 | 126 | 96 |
|---------------------------------------------------------------------------|
| Non-releasable allocs | 46 | 73 | 16377 K | 16377 K |
| from large pool | 16 | 22 | 2502 K | 2502 K |
| from small pool | 30 | 57 | 13874 K | 13874 K |
|---------------------------------------------------------------------------|
| Oversize allocations | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| Oversize GPU segments | 0 | 0 | 0 | 0 |
|===========================================================================|
distilbert.transformer.layer.0.attention.q_lin: |A|=0.2014, |B|=0.0003389, |∇A|=0.000502, |∇B|=0.158, |LoRA(x)|=6.894, B≠0=12285
distilbert.transformer.layer.0.attention.k_lin: |A|=0.1991, |B|=0.0003358, |∇A|=0.0004346, |∇B|=0.1471, |LoRA(x)|=7.695, B≠0=12286
distilbert.transformer.layer.0.attention.v_lin: |A|=0.1993, |B|=0.000301, |∇A|=0.0007981, |∇B|=0.2967, |LoRA(x)|=6.75, B≠0=12285
distilbert.transformer.layer.0.attention.out_lin: |A|=0.2003, |B|=0.0003029, |∇A|=0.0007932, |∇B|=0.321, |LoRA(x)|=3.656, B≠0=12285
distilbert.transformer.layer.0.ffn.lin1: |A|=0.1998, |B|=0.0003268, |∇A|=0.0009316, |∇B|=0.1899, |LoRA(x)|=21.53, B≠0=49143
distilbert.transformer.layer.0.ffn.lin2: |A|=0.1996, |B|=0.0003015, |∇A|=0.0003772, |∇B|=0.3445, |LoRA(x)|=6.206, B≠0=12286
distilbert.transformer.layer.1.attention.q_lin: |A|=0.2, |B|=0.0003353, |∇A|=0.0003761, |∇B|=0.1395, |LoRA(x)|=7.866, B≠0=12285
distilbert.transformer.layer.1.attention.k_lin: |A|=0.2007, |B|=0.0003366, |∇A|=0.0004358, |∇B|=0.1484, |LoRA(x)|=6.999, B≠0=12285
distilbert.transformer.layer.1.attention.v_lin: |A|=0.1978, |B|=0.0002858, |∇A|=0.0006746, |∇B|=0.2917, |LoRA(x)|=5.321, B≠0=12285
distilbert.transformer.layer.1.attention.out_lin: |A|=0.1984, |B|=0.00029, |∇A|=0.0007089, |∇B|=0.3005, |LoRA(x)|=3.779, B≠0=12285
distilbert.transformer.layer.1.ffn.lin1: |A|=0.198, |B|=0.0003159, |∇A|=0.0007903, |∇B|=0.1846, |LoRA(x)|=23.45, B≠0=49142
distilbert.transformer.layer.1.ffn.lin2: |A|=0.1989, |B|=0.000288, |∇A|=0.0002804, |∇B|=0.2852, |LoRA(x)|=3.764, B≠0=12284
distilbert.transformer.layer.2.attention.q_lin: |A|=0.198, |B|=0.0003262, |∇A|=0.0003942, |∇B|=0.1447, |LoRA(x)|=8.057, B≠0=12286
distilbert.transformer.layer.2.attention.k_lin: |A|=0.2002, |B|=0.0003176, |∇A|=0.000448, |∇B|=0.1577, |LoRA(x)|=7.385, B≠0=12285
distilbert.transformer.layer.2.attention.v_lin: |A|=0.2002, |B|=0.0002813, |∇A|=0.0005137, |∇B|=0.2423, |LoRA(x)|=6.056, B≠0=12286
distilbert.transformer.layer.2.attention.out_lin: |A|=0.1983, |B|=0.0002805, |∇A|=0.0005144, |∇B|=0.2204, |LoRA(x)|=3.703, B≠0=12285
distilbert.transformer.layer.2.ffn.lin1: |A|=0.1982, |B|=0.0003084, |∇A|=0.0006538, |∇B|=0.15, |LoRA(x)|=24.32, B≠0=49141
distilbert.transformer.layer.2.ffn.lin2: |A|=0.1999, |B|=0.0002843, |∇A|=0.000218, |∇B|=0.2078, |LoRA(x)|=3.736, B≠0=12285
distilbert.transformer.layer.3.attention.q_lin: |A|=0.2006, |B|=0.000324, |∇A|=0.0002934, |∇B|=0.1105, |LoRA(x)|=8.299, B≠0=12285
distilbert.transformer.layer.3.attention.k_lin: |A|=0.2011, |B|=0.0003226, |∇A|=0.0003681, |∇B|=0.1199, |LoRA(x)|=7.896, B≠0=12285
distilbert.transformer.layer.3.attention.v_lin: |A|=0.2007, |B|=0.00026, |∇A|=0.0003734, |∇B|=0.2297, |LoRA(x)|=6.205, B≠0=12285
distilbert.transformer.layer.3.attention.out_lin: |A|=0.1995, |B|=0.0002658, |∇A|=0.0004406, |∇B|=0.1566, |LoRA(x)|=4.282, B≠0=12285
distilbert.transformer.layer.3.ffn.lin1: |A|=0.1995, |B|=0.0003003, |∇A|=0.0005145, |∇B|=0.08879, |LoRA(x)|=18.41, B≠0=49141
distilbert.transformer.layer.3.ffn.lin2: |A|=0.1998, |B|=0.0002639, |∇A|=0.0001617, |∇B|=0.1338, |LoRA(x)|=3.407, B≠0=12285
distilbert.transformer.layer.4.attention.q_lin: |A|=0.2008, |B|=0.0003261, |∇A|=0.0001543, |∇B|=0.06833, |LoRA(x)|=12.96, B≠0=12284
distilbert.transformer.layer.4.attention.k_lin: |A|=0.2, |B|=0.0003134, |∇A|=0.0003658, |∇B|=0.07867, |LoRA(x)|=8.107, B≠0=12286
distilbert.transformer.layer.4.attention.v_lin: |A|=0.1993, |B|=0.0002635, |∇A|=0.0003247, |∇B|=0.09568, |LoRA(x)|=6.022, B≠0=12285
distilbert.transformer.layer.4.attention.out_lin: |A|=0.1999, |B|=0.0002577, |∇A|=0.0002548, |∇B|=0.0689, |LoRA(x)|=4.138, B≠0=12285
distilbert.transformer.layer.4.ffn.lin1: |A|=0.2004, |B|=0.0002738, |∇A|=0.0002338, |∇B|=0.03957, |LoRA(x)|=21.57, B≠0=49141
distilbert.transformer.layer.4.ffn.lin2: |A|=0.1994, |B|=0.0002625, |∇A|=0.0001001, |∇B|=0.06575, |LoRA(x)|=4.007, B≠0=12284
distilbert.transformer.layer.5.attention.q_lin: |A|=0.1998, |B|=0.0003016, |∇A|=8.459e-05, |∇B|=0.02227, |LoRA(x)|=9.447, B≠0=12286
distilbert.transformer.layer.5.attention.k_lin: |A|=0.201, |B|=0.0002925, |∇A|=0.0001426, |∇B|=0.02499, |LoRA(x)|=7.739, B≠0=12285
distilbert.transformer.layer.5.attention.v_lin: |A|=0.1987, |B|=0.0002508, |∇A|=0.0002601, |∇B|=0.04984, |LoRA(x)|=8.357, B≠0=12285
distilbert.transformer.layer.5.attention.out_lin: |A|=0.1994, |B|=0.0002483, |∇A|=0.000202, |∇B|=0.0489, |LoRA(x)|=15.56, B≠0=12286
distilbert.transformer.layer.5.ffn.lin1: |A|=0.1993, |B|=0.0002525, |∇A|=0.0001525, |∇B|=0.01987, |LoRA(x)|=20.78, B≠0=49138
distilbert.transformer.layer.5.ffn.lin2: |A|=0.1994, |B|=0.0002402, |∇A|=4.512e-05, |∇B|=0.03451, |LoRA(x)|=5.875, B≠0=12285
Parameter Statistics: mean.abs()
distilbert.transformer.layer.0.attention.q_lin.lora.A 0.20136898756027222
distilbert.transformer.layer.0.attention.q_lin.lora.B 0.00036163683398626745
distilbert.transformer.layer.0.attention.k_lin.lora.A 0.199139803647995
distilbert.transformer.layer.0.attention.k_lin.lora.B 0.0003591857384890318
distilbert.transformer.layer.0.attention.v_lin.lora.A 0.199319526553154
distilbert.transformer.layer.0.attention.v_lin.lora.B 0.00031731155468150973
distilbert.transformer.layer.0.attention.out_lin.lora.A 0.20034024119377136
distilbert.transformer.layer.0.attention.out_lin.lora.B 0.00031905301148071885
distilbert.transformer.layer.0.ffn.lin1.lora.A 0.19975820183753967
distilbert.transformer.layer.0.ffn.lin1.lora.B 0.00035009029670618474
distilbert.transformer.layer.0.ffn.lin2.lora.A 0.19959013164043427
distilbert.transformer.layer.0.ffn.lin2.lora.B 0.0003168959519825876
distilbert.transformer.layer.1.attention.q_lin.lora.A 0.1999898999929428
distilbert.transformer.layer.1.attention.q_lin.lora.B 0.00035791145637631416
distilbert.transformer.layer.1.attention.k_lin.lora.A 0.20074164867401123
distilbert.transformer.layer.1.attention.k_lin.lora.B 0.0003592821885831654
distilbert.transformer.layer.1.attention.v_lin.lora.A 0.19777165353298187
distilbert.transformer.layer.1.attention.v_lin.lora.B 0.00030186009826138616
distilbert.transformer.layer.1.attention.out_lin.lora.A 0.19843807816505432
distilbert.transformer.layer.1.attention.out_lin.lora.B 0.0003039248113054782
distilbert.transformer.layer.1.ffn.lin1.lora.A 0.1979658454656601
distilbert.transformer.layer.1.ffn.lin1.lora.B 0.00033694933517836034
distilbert.transformer.layer.1.ffn.lin2.lora.A 0.19888746738433838
distilbert.transformer.layer.1.ffn.lin2.lora.B 0.0003027324564754963
distilbert.transformer.layer.2.attention.q_lin.lora.A 0.19796311855316162
distilbert.transformer.layer.2.attention.q_lin.lora.B 0.00034728783066384494
distilbert.transformer.layer.2.attention.k_lin.lora.A 0.20021361112594604
distilbert.transformer.layer.2.attention.k_lin.lora.B 0.00033590427483431995
distilbert.transformer.layer.2.attention.v_lin.lora.A 0.20018713176250458
distilbert.transformer.layer.2.attention.v_lin.lora.B 0.00029516470385715365
distilbert.transformer.layer.2.attention.out_lin.lora.A 0.19833998382091522
distilbert.transformer.layer.2.attention.out_lin.lora.B 0.00029389417613856494
distilbert.transformer.layer.2.ffn.lin1.lora.A 0.1982342004776001
distilbert.transformer.layer.2.ffn.lin1.lora.B 0.0003286522696726024
distilbert.transformer.layer.2.ffn.lin2.lora.A 0.19985038042068481
distilbert.transformer.layer.2.ffn.lin2.lora.B 0.000297279329970479
distilbert.transformer.layer.3.attention.q_lin.lora.A 0.20064640045166016
distilbert.transformer.layer.3.attention.q_lin.lora.B 0.00034552914439700544
distilbert.transformer.layer.3.attention.k_lin.lora.A 0.20107726752758026
distilbert.transformer.layer.3.attention.k_lin.lora.B 0.00034139861236326396
distilbert.transformer.layer.3.attention.v_lin.lora.A 0.20070061087608337
distilbert.transformer.layer.3.attention.v_lin.lora.B 0.00027113681426271796
distilbert.transformer.layer.3.attention.out_lin.lora.A 0.1994575560092926
distilbert.transformer.layer.3.attention.out_lin.lora.B 0.0002771209110505879
distilbert.transformer.layer.3.ffn.lin1.lora.A 0.19948390126228333
distilbert.transformer.layer.3.ffn.lin1.lora.B 0.00031956000020727515
distilbert.transformer.layer.3.ffn.lin2.lora.A 0.1997573971748352
distilbert.transformer.layer.3.ffn.lin2.lora.B 0.00027455651434138417
distilbert.transformer.layer.4.attention.q_lin.lora.A 0.20080135762691498
distilbert.transformer.layer.4.attention.q_lin.lora.B 0.00035053043393418193
distilbert.transformer.layer.4.attention.k_lin.lora.A 0.19999122619628906
distilbert.transformer.layer.4.attention.k_lin.lora.B 0.00033398327650502324
distilbert.transformer.layer.4.attention.v_lin.lora.A 0.199321448802948
distilbert.transformer.layer.4.attention.v_lin.lora.B 0.0002780442591756582
distilbert.transformer.layer.4.attention.out_lin.lora.A 0.19989004731178284
distilbert.transformer.layer.4.attention.out_lin.lora.B 0.00027118308935314417
distilbert.transformer.layer.4.ffn.lin1.lora.A 0.20036581158638
distilbert.transformer.layer.4.ffn.lin1.lora.B 0.00029024441028013825
distilbert.transformer.layer.4.ffn.lin2.lora.A 0.19942979514598846
distilbert.transformer.layer.4.ffn.lin2.lora.B 0.0002758408372756094
distilbert.transformer.layer.5.attention.q_lin.lora.A 0.19984808564186096
distilbert.transformer.layer.5.attention.q_lin.lora.B 0.00032156246015802026
distilbert.transformer.layer.5.attention.k_lin.lora.A 0.20102792978286743
distilbert.transformer.layer.5.attention.k_lin.lora.B 0.0003097845474258065
distilbert.transformer.layer.5.attention.v_lin.lora.A 0.19865462183952332
distilbert.transformer.layer.5.attention.v_lin.lora.B 0.0002626449568197131
distilbert.transformer.layer.5.attention.out_lin.lora.A 0.1993870586156845
distilbert.transformer.layer.5.attention.out_lin.lora.B 0.0002558119304012507
distilbert.transformer.layer.5.ffn.lin1.lora.A 0.19927702844142914
distilbert.transformer.layer.5.ffn.lin1.lora.B 0.0002630744711495936
distilbert.transformer.layer.5.ffn.lin2.lora.A 0.19940851628780365
distilbert.transformer.layer.5.ffn.lin2.lora.B 0.0002476528752595186
Parameter Statistics: param.norm()
distilbert.transformer.layer.0.attention.q_lin.lora.A weight norm: 27.8493
distilbert.transformer.layer.0.attention.q_lin.lora.B weight norm: 0.0506
distilbert.transformer.layer.0.attention.k_lin.lora.A weight norm: 27.5909
distilbert.transformer.layer.0.attention.k_lin.lora.B weight norm: 0.0500
distilbert.transformer.layer.0.attention.v_lin.lora.A weight norm: 27.8039
distilbert.transformer.layer.0.attention.v_lin.lora.B weight norm: 0.0444
distilbert.transformer.layer.0.attention.out_lin.lora.A weight norm: 27.9052
distilbert.transformer.layer.0.attention.out_lin.lora.B weight norm: 0.0442
distilbert.transformer.layer.0.ffn.lin1.lora.A weight norm: 27.7016
distilbert.transformer.layer.0.ffn.lin1.lora.B weight norm: 0.0976
distilbert.transformer.layer.0.ffn.lin2.lora.A weight norm: 55.5094
distilbert.transformer.layer.0.ffn.lin2.lora.B weight norm: 0.0440
distilbert.transformer.layer.1.attention.q_lin.lora.A weight norm: 27.8651
distilbert.transformer.layer.1.attention.q_lin.lora.B weight norm: 0.0501
distilbert.transformer.layer.1.attention.k_lin.lora.A weight norm: 27.9079
distilbert.transformer.layer.1.attention.k_lin.lora.B weight norm: 0.0503
distilbert.transformer.layer.1.attention.v_lin.lora.A weight norm: 27.4143
distilbert.transformer.layer.1.attention.v_lin.lora.B weight norm: 0.0421
distilbert.transformer.layer.1.attention.out_lin.lora.A weight norm: 27.5450
distilbert.transformer.layer.1.attention.out_lin.lora.B weight norm: 0.0424
distilbert.transformer.layer.1.ffn.lin1.lora.A weight norm: 27.5116
distilbert.transformer.layer.1.ffn.lin1.lora.B weight norm: 0.0939
distilbert.transformer.layer.1.ffn.lin2.lora.A weight norm: 55.2299
distilbert.transformer.layer.1.ffn.lin2.lora.B weight norm: 0.0421
distilbert.transformer.layer.2.attention.q_lin.lora.A weight norm: 27.6563
distilbert.transformer.layer.2.attention.q_lin.lora.B weight norm: 0.0481
distilbert.transformer.layer.2.attention.k_lin.lora.A weight norm: 27.7654
distilbert.transformer.layer.2.attention.k_lin.lora.B weight norm: 0.0468
distilbert.transformer.layer.2.attention.v_lin.lora.A weight norm: 27.8159
distilbert.transformer.layer.2.attention.v_lin.lora.B weight norm: 0.0414
distilbert.transformer.layer.2.attention.out_lin.lora.A weight norm: 27.5538
distilbert.transformer.layer.2.attention.out_lin.lora.B weight norm: 0.0409
distilbert.transformer.layer.2.ffn.lin1.lora.A weight norm: 27.5591
distilbert.transformer.layer.2.ffn.lin1.lora.B weight norm: 0.0917
distilbert.transformer.layer.2.ffn.lin2.lora.A weight norm: 55.4738
distilbert.transformer.layer.2.ffn.lin2.lora.B weight norm: 0.0416
distilbert.transformer.layer.3.attention.q_lin.lora.A weight norm: 27.7736
distilbert.transformer.layer.3.attention.q_lin.lora.B weight norm: 0.0482
distilbert.transformer.layer.3.attention.k_lin.lora.A weight norm: 27.9000
distilbert.transformer.layer.3.attention.k_lin.lora.B weight norm: 0.0476
distilbert.transformer.layer.3.attention.v_lin.lora.A weight norm: 27.7956
distilbert.transformer.layer.3.attention.v_lin.lora.B weight norm: 0.0376
distilbert.transformer.layer.3.attention.out_lin.lora.A weight norm: 27.6487
distilbert.transformer.layer.3.attention.out_lin.lora.B weight norm: 0.0386
distilbert.transformer.layer.3.ffn.lin1.lora.A weight norm: 27.6616
distilbert.transformer.layer.3.ffn.lin1.lora.B weight norm: 0.0894
distilbert.transformer.layer.3.ffn.lin2.lora.A weight norm: 55.5141
distilbert.transformer.layer.3.ffn.lin2.lora.B weight norm: 0.0384
distilbert.transformer.layer.4.attention.q_lin.lora.A weight norm: 27.9841
distilbert.transformer.layer.4.attention.q_lin.lora.B weight norm: 0.0488
distilbert.transformer.layer.4.attention.k_lin.lora.A weight norm: 27.7779
distilbert.transformer.layer.4.attention.k_lin.lora.B weight norm: 0.0466
distilbert.transformer.layer.4.attention.v_lin.lora.A weight norm: 27.6690
distilbert.transformer.layer.4.attention.v_lin.lora.B weight norm: 0.0386
distilbert.transformer.layer.4.attention.out_lin.lora.A weight norm: 27.7600
distilbert.transformer.layer.4.attention.out_lin.lora.B weight norm: 0.0378
distilbert.transformer.layer.4.ffn.lin1.lora.A weight norm: 27.7930
distilbert.transformer.layer.4.ffn.lin1.lora.B weight norm: 0.0813
distilbert.transformer.layer.4.ffn.lin2.lora.A weight norm: 55.3399
distilbert.transformer.layer.4.ffn.lin2.lora.B weight norm: 0.0385
distilbert.transformer.layer.5.attention.q_lin.lora.A weight norm: 27.7641
distilbert.transformer.layer.5.attention.q_lin.lora.B weight norm: 0.0448
distilbert.transformer.layer.5.attention.k_lin.lora.A weight norm: 27.9006
distilbert.transformer.layer.5.attention.k_lin.lora.B weight norm: 0.0432
distilbert.transformer.layer.5.attention.v_lin.lora.A weight norm: 27.5460
distilbert.transformer.layer.5.attention.v_lin.lora.B weight norm: 0.0366
distilbert.transformer.layer.5.attention.out_lin.lora.A weight norm: 27.7014
distilbert.transformer.layer.5.attention.out_lin.lora.B weight norm: 0.0356
distilbert.transformer.layer.5.ffn.lin1.lora.A weight norm: 27.6943
distilbert.transformer.layer.5.ffn.lin1.lora.B weight norm: 0.0736
distilbert.transformer.layer.5.ffn.lin2.lora.A weight norm: 55.4635
distilbert.transformer.layer.5.ffn.lin2.lora.B weight norm: 0.0346
Key Observations from the Dump¶
- Despite the small absolute values (e.g., |B| ≈ 0.00025–0.00033), the gradients |∇B| are significantly larger, especially in the early layers: |∇B| ≈ 0.15–0.34 in layers 0–2, dropping gradually with depth to roughly 0.02–0.05 at layer 5. This suggests strong gradient flow into B early in the network that tapers off with depth: LoRA is adapting more in the early layers and less in the deep ones.
- |LoRA(x)| is relatively large despite the tiny B values. This is the desired behaviour: LoRA is learning meaningful low-rank deltas.
- Instead of B growing in magnitude, we see sharp, effective updates guided by ∇B, especially with dropout pushing the adaptation.
- Although the LoRA.B values remain small (|B| ≈ 0.00025–0.00033), this is not a sign of stagnation or undertraining. The signal is directional: the LoRA module effectively computes ΔW = α⋅A⋅B, where A holds relatively large values (the directional basis) and B holds very small coefficients (the activation selectors). The direction of A⋅B matters far more than its norm, because α rescales the final output. Here the norm of A⋅B stays suppressed during training while its orientation in parameter space becomes useful for adapting the model. Small B also encourages better generalization: a large B would produce a large ΔW and risk overfitting, which would defeat the point of parameter-efficient fine-tuning. The sketch after this list makes the scale of ΔW relative to the frozen weights explicit.
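To make the last point concrete, here is a minimal sketch (not part of the original run) that forms the effective update ΔW = α⋅A⋅B for every wrapped layer of model_lora_all_attn and compares its norm with the norm of the frozen base weight; it assumes the LinearWithLoRA class defined above.
# Minimal sketch: compare ||alpha * A @ B|| against the frozen weight norm for each wrapped layer.
# Assumes `model_lora_all_attn` and `LinearWithLoRA` from the cells above.
import torch

with torch.no_grad():
    for name, module in model_lora_all_attn.named_modules():
        if isinstance(module, LinearWithLoRA):
            # A: (in_dim, rank), B: (rank, out_dim) -> delta has shape (in_dim, out_dim),
            # i.e. the transpose of linear.weight; the transpose does not change the norm.
            delta_w = module.lora.alpha * (module.lora.A @ module.lora.B)
            base_norm = module.linear.weight.norm().item()
            ratio = delta_w.norm().item() / base_norm
            print(f"{name}: ||ΔW||={delta_w.norm().item():.4f}, ||W||={base_norm:.4f}, ratio={ratio:.2%}")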