import torch
import sys
import gc
print(sys.version)
print(f"PyTorch Version: {torch.__version__}")
print(torch.cuda.is_available())
print(torch.cuda.device_count())
if torch.cuda.is_available():
    print(f"CUDA Version: {torch.version.cuda}")
    print(torch.cuda.get_device_name(0))
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
import bitsandbytes
import peft
import transformers
print(transformers.__version__)
print(f"bitsandbytes version: {bitsandbytes.__version__}")
print(f"peft version: {peft.__version__}")
print(torch.cuda.is_bf16_supported())
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
3.10.16 | packaged by Anaconda, Inc. | (main, Dec 11 2024, 16:19:12) [MSC v.1929 64 bit (AMD64)]
PyTorch Version: 2.5.1+cu121
True
1
CUDA Version: 12.1
NVIDIA GeForce RTX 4080 Laptop GPU
4.50.0.dev0
bitsandbytes version: 0.45.3
peft version: 0.15.2.dev0
True
Load dataset, base model, and tokeniser¶
from datasets import load_dataset
imdb_dataset = load_dataset("imdb")
imdb_dataset = imdb_dataset.rename_column("label", "labels")
# Split the test set into validation and test sets
test_val_split = imdb_dataset['test'].train_test_split(test_size=0.95, seed=42)
imdb_dataset['validation'] = test_val_split['train']
imdb_dataset['test'] = test_val_split['test']
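As a quick sanity check (not part of the original run), the resulting split sizes can be printed; with test_size=0.95 this should give 25,000 train, 1,250 validation and 23,750 test examples:
# Hypothetical check of the split sizes
for split_name, split in imdb_dataset.items():
    print(split_name, len(split))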
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, f1_score
# Determine the number of labels
num_labels = len(set(imdb_dataset["train"]["labels"]))
print(f"Number of labels: {num_labels}")
# Load the tokenizer
model_ckpt = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
# Tokenize the whole dataset, truncate to 384 tokens
def tokenize(batch):
    return tokenizer(batch["text"], padding=True, truncation=True, max_length=384)
dataset_encoded = imdb_dataset.map(tokenize, batched=True, batch_size=None)
# Load the pretrained model for sequence classification
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = (AutoModelForSequenceClassification
         .from_pretrained(model_ckpt, num_labels=num_labels)
         .to(device))
#print(model)
Number of labels: 2
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
# Helper functions
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1}
def count_trainable_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params, 100 * trainable_params / total_params
def freeze_model_layers(model, unfreeze_pre_classifier=False):
    # Freeze all parameters
    for param in model.parameters():
        param.requires_grad = False
    # Unfreeze LoRA and DoRA-specific params, including lora_norm
    for name, param in model.named_parameters():
        if (
            "lora.A" in name
            or "lora.B" in name
            or "lora_norm" in name
            or name.endswith(".m")      # For DoRA
            or name.endswith(".m_in")   # For DDoRA
            or name.endswith(".m_out")  # For DDoRA
            or "scale" in name
        ):
            param.requires_grad = True
    # Unfreeze classifier layer (always)
    for name, param in model.named_parameters():
        if name.startswith("classifier."):
            param.requires_grad = True
    # Unfreeze pre-classifier (optional)
    if unfreeze_pre_classifier:
        for name, param in model.named_parameters():
            if name.startswith("pre_classifier."):
                param.requires_grad = True
def monitor_lora_parameters(model, threshold=1e-7):
    monitor = {
        "A_abs_mean": [],
        "B_abs_mean": [],
        "A_grad_mean": [],
        "B_grad_mean": [],
        "lora_output_norm": [],
        "B_nonzero_count": [],
    }
    hooks = []
    for name, module in model.named_modules():
        if hasattr(module, "lora") and hasattr(module.lora, "A") and hasattr(module.lora, "B"):
            A_param = module.lora.A
            B_param = module.lora.B
            # Gradient hooks (directly on nn.Parameter)
            if A_param.requires_grad:
                hooks.append(A_param.register_hook(lambda grad, n=name: monitor["A_grad_mean"].append((n, grad.abs().mean().item()))))
            if B_param.requires_grad:
                hooks.append(B_param.register_hook(lambda grad, n=name: monitor["B_grad_mean"].append((n, grad.abs().mean().item()))))
            # Forward hook for value stats
            def forward_hook(mod, inp, out, n=name):
                A_mean = mod.lora.A.abs().mean().item()
                B_mean = mod.lora.B.abs().mean().item()
                B_nnz = (mod.lora.B.abs() > threshold).sum().item()
                monitor["A_abs_mean"].append((n, A_mean))
                monitor["B_abs_mean"].append((n, B_mean))
                monitor["B_nonzero_count"].append((n, B_nnz))
                monitor["lora_output_norm"].append((n, mod.last_lora_output_norm))
            hooks.append(module.register_forward_hook(forward_hook))
    return hooks, monitor
LoRA¶
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
torch.autograd.set_detect_anomaly(True)
class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha, dropout_rate=0.0):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(1e-6 * torch.randn(rank, out_dim) * std_dev)  # Not all zeroes!
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout_rate)
    def forward(self, x):
        # Dropout applied to the projection to the lower-dimensional space by A
        dropped = self.dropout(x @ self.A)
        return self.alpha * (dropped @ self.B)
class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        self.last_lora_output_norm = 0.0  # for monitoring
    def forward(self, x):
        #return self.linear(x) + self.lora(x)
        lora_out = self.lora(x)
        self.last_lora_output_norm = lora_out.norm(p=2, dim=-1).mean().item()
        return self.linear(x) + lora_out
# Function to inject LoRA into specified linear layers
def inject_lora_all_attn(model, rank, alpha, dropout_rate=0.0):
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin", "ffn.lin1", "ffn.lin2"]
    # target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            lora_linear = LinearWithLoRA(original_linear, rank, alpha)
            lora_linear.lora.dropout = nn.Dropout(dropout_rate)
            setattr(parent_module, name.split('.')[-1], lora_linear)
    return model
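As an aside (not part of the training flow in this notebook), a trained adapter wrapped by LinearWithLoRA can in principle be folded back into the frozen weight, so inference needs no extra matmul. A minimal sketch, assuming the classes above and the plain alpha scaling used in their forward pass; merge_lora_weights is a hypothetical helper name:
@torch.no_grad()
def merge_lora_weights(model):
    # Fold alpha * (A @ B) into each wrapped nn.Linear.
    # nn.Linear stores its weight as (out_features, in_features), hence the transpose.
    for module in model.modules():
        if isinstance(module, LinearWithLoRA):
            delta_w = module.lora.alpha * (module.lora.A @ module.lora.B)  # (in_dim, out_dim)
            module.linear.weight += delta_w.T
    return model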
dropout = 0.0
learning_rate = 1.5e-5
lora_rank = 16
lora_alpha = 32
weight_decay = 1e-5 # L2
batch_size = 32
output_dir_prefix = "finetuned-imdb-"
import copy
from transformers import TrainingArguments
torch.manual_seed(137)
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha, dropout)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)
total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
#print("\nTrainable parameters after freezing:")
#for name, param in model_lora_all_attn.named_parameters():
# if param.requires_grad:
# print(name)
eval_steps = 50
logging_steps = 50
training_args_lora_all_attn = TrainingArguments(
output_dir=f"{output_dir_prefix}lora-all-attn",
num_train_epochs=2,
#max_steps=100,
learning_rate=learning_rate,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
weight_decay=weight_decay,
evaluation_strategy="steps",
eval_steps=eval_steps,
logging_steps=logging_steps,
save_steps=eval_steps,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
disable_tqdm=False,
push_to_hub=False,
max_grad_norm=1.0, ##################
report_to="none",
log_level="error"
)
trainer_lora_all_attn = Trainer(
model=model_lora_all_attn,
args=training_args_lora_all_attn,
train_dataset=dataset_encoded["train"],
eval_dataset=dataset_encoded["validation"],
compute_metrics=compute_metrics,
)
hooks, monitor = monitor_lora_parameters(trainer_lora_all_attn.model)
#Train!
trainer_lora_all_attn.train()
#eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
#print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
#print (torch.cuda.memory_summary())
for hook in hooks:
    hook.remove()
# Aggregate/log after training
from collections import defaultdict
agg = defaultdict(list)
for key, vals in monitor.items():
    grouped = defaultdict(list)
    for name, val in vals:
        grouped[name].append(val)
    agg[key] = {name: sum(vs)/len(vs) for name, vs in grouped.items()}
for name in agg["A_abs_mean"]:
    print(f"{name}: |A|={agg['A_abs_mean'][name]:.4g}, |B|={agg['B_abs_mean'][name]:.4g}, "
          f"|∇A|={agg['A_grad_mean'][name]:.4g}, |∇B|={agg['B_grad_mean'][name]:.4g}, "
          f"|LoRA(x)|={agg['lora_output_norm'][name]:.4g}, B≠0={agg['B_nonzero_count'][name]:.0f}")
#print('Parameter Statistics: mean.abs()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(name, param.abs().mean().item())
#print('Parameter Statistics: param.norm()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(f"{name} weight norm: {param.norm().item():.4f}")
LoRA (All Attention) - Total parameters: 68,282,114
LoRA (All Attention) - Trainable parameters: 1,919,234 (2.81%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
Step | Training Loss | Validation Loss | Accuracy | F1 |
---|---|---|---|---|
50 | 0.504900 | 0.335573 | 0.869600 | 0.869783 |
100 | 0.341200 | 0.318371 | 0.862400 | 0.862357 |
150 | 0.361400 | 0.307639 | 0.876000 | 0.875739 |
200 | 0.303400 | 0.296421 | 0.870400 | 0.870685 |
250 | 0.324500 | 0.276511 | 0.886400 | 0.886654 |
300 | 0.295500 | 0.261722 | 0.890400 | 0.889511 |
350 | 0.312700 | 0.276859 | 0.896800 | 0.896362 |
400 | 0.281200 | 0.253817 | 0.890400 | 0.890564 |
450 | 0.257400 | 0.244995 | 0.909600 | 0.909240 |
500 | 0.284400 | 0.264725 | 0.898400 | 0.898376 |
550 | 0.297700 | 0.238351 | 0.908800 | 0.908708 |
600 | 0.287900 | 0.237734 | 0.909600 | 0.909143 |
650 | 0.234500 | 0.224168 | 0.913600 | 0.913528 |
700 | 0.247900 | 0.222799 | 0.915200 | 0.915187 |
750 | 0.256000 | 0.227275 | 0.916000 | 0.916031 |
800 | 0.241700 | 0.282384 | 0.899200 | 0.898261 |
850 | 0.190800 | 0.242396 | 0.904000 | 0.904121 |
900 | 0.169500 | 0.258911 | 0.910400 | 0.910400 |
950 | 0.208600 | 0.260492 | 0.907200 | 0.906995 |
1000 | 0.189500 | 0.241334 | 0.913600 | 0.913649 |
1050 | 0.182400 | 0.214967 | 0.924800 | 0.924764 |
1100 | 0.192600 | 0.222719 | 0.915200 | 0.915213 |
1150 | 0.170000 | 0.231555 | 0.912000 | 0.911767 |
1200 | 0.182200 | 0.242794 | 0.916800 | 0.916928 |
1250 | 0.174700 | 0.221050 | 0.917600 | 0.917461 |
1300 | 0.187400 | 0.229759 | 0.916800 | 0.916896 |
1350 | 0.173600 | 0.224172 | 0.917600 | 0.917641 |
1400 | 0.163200 | 0.224013 | 0.920800 | 0.920870 |
1450 | 0.170500 | 0.221920 | 0.920800 | 0.920806 |
1500 | 0.165600 | 0.220823 | 0.918400 | 0.918361 |
1550 | 0.171000 | 0.221188 | 0.920800 | 0.920840 |
distilbert.transformer.layer.0.attention.q_lin: |A|=0.2014, |B|=0.0002572, |∇A|=0.000277, |∇B|=0.08185, |LoRA(x)|=4.874, B≠0=12283 distilbert.transformer.layer.0.attention.k_lin: |A|=0.1992, |B|=0.0002632, |∇A|=0.0002245, |∇B|=0.07715, |LoRA(x)|=5.663, B≠0=12283 distilbert.transformer.layer.0.attention.v_lin: |A|=0.1993, |B|=0.0002103, |∇A|=0.0005365, |∇B|=0.1825, |LoRA(x)|=4.493, B≠0=12282 distilbert.transformer.layer.0.attention.out_lin: |A|=0.2004, |B|=0.000215, |∇A|=0.0004854, |∇B|=0.183, |LoRA(x)|=2.415, B≠0=12283 distilbert.transformer.layer.0.ffn.lin1: |A|=0.1998, |B|=0.0002391, |∇A|=0.000489, |∇B|=0.1047, |LoRA(x)|=15.38, B≠0=49133 distilbert.transformer.layer.0.ffn.lin2: |A|=0.1996, |B|=0.0002197, |∇A|=0.0001981, |∇B|=0.1984, |LoRA(x)|=4.228, B≠0=12282 distilbert.transformer.layer.1.attention.q_lin: |A|=0.2, |B|=0.0002441, |∇A|=0.0002117, |∇B|=0.07861, |LoRA(x)|=5.422, B≠0=12283 distilbert.transformer.layer.1.attention.k_lin: |A|=0.2008, |B|=0.0002601, |∇A|=0.0002199, |∇B|=0.07653, |LoRA(x)|=4.799, B≠0=12283 distilbert.transformer.layer.1.attention.v_lin: |A|=0.1978, |B|=0.0001886, |∇A|=0.0003676, |∇B|=0.185, |LoRA(x)|=3.285, B≠0=12282 distilbert.transformer.layer.1.attention.out_lin: |A|=0.1985, |B|=0.0002105, |∇A|=0.0004261, |∇B|=0.1959, |LoRA(x)|=2.38, B≠0=12283 distilbert.transformer.layer.1.ffn.lin1: |A|=0.198, |B|=0.000237, |∇A|=0.0004655, |∇B|=0.119, |LoRA(x)|=19.11, B≠0=49132 distilbert.transformer.layer.1.ffn.lin2: |A|=0.1989, |B|=0.0002041, |∇A|=0.0001751, |∇B|=0.1963, |LoRA(x)|=2.361, B≠0=12282 distilbert.transformer.layer.2.attention.q_lin: |A|=0.198, |B|=0.0002397, |∇A|=0.0002494, |∇B|=0.09958, |LoRA(x)|=5.855, B≠0=12283 distilbert.transformer.layer.2.attention.k_lin: |A|=0.2002, |B|=0.0002423, |∇A|=0.0002714, |∇B|=0.1016, |LoRA(x)|=5.16, B≠0=12283 distilbert.transformer.layer.2.attention.v_lin: |A|=0.2002, |B|=0.0001928, |∇A|=0.0003262, |∇B|=0.1879, |LoRA(x)|=4.02, B≠0=12282 distilbert.transformer.layer.2.attention.out_lin: |A|=0.1983, |B|=0.0001992, |∇A|=0.0003639, |∇B|=0.1642, |LoRA(x)|=2.252, B≠0=12282 distilbert.transformer.layer.2.ffn.lin1: |A|=0.1982, |B|=0.0002306, |∇A|=0.0005194, |∇B|=0.1186, |LoRA(x)|=19.02, B≠0=49132 distilbert.transformer.layer.2.ffn.lin2: |A|=0.1999, |B|=0.0002021, |∇A|=0.0002104, |∇B|=0.1721, |LoRA(x)|=2.549, B≠0=12283 distilbert.transformer.layer.3.attention.q_lin: |A|=0.2006, |B|=0.0002554, |∇A|=0.000265, |∇B|=0.09461, |LoRA(x)|=5.758, B≠0=12283 distilbert.transformer.layer.3.attention.k_lin: |A|=0.2011, |B|=0.0002541, |∇A|=0.0002945, |∇B|=0.09943, |LoRA(x)|=6.224, B≠0=12283 distilbert.transformer.layer.3.attention.v_lin: |A|=0.2007, |B|=0.0001641, |∇A|=0.0003354, |∇B|=0.2089, |LoRA(x)|=4.087, B≠0=12281 distilbert.transformer.layer.3.attention.out_lin: |A|=0.1995, |B|=0.000194, |∇A|=0.0004308, |∇B|=0.14, |LoRA(x)|=2.848, B≠0=12282 distilbert.transformer.layer.3.ffn.lin1: |A|=0.1995, |B|=0.0002212, |∇A|=0.0004882, |∇B|=0.07969, |LoRA(x)|=11.88, B≠0=49131 distilbert.transformer.layer.3.ffn.lin2: |A|=0.1998, |B|=0.0001834, |∇A|=0.0001973, |∇B|=0.1337, |LoRA(x)|=2.234, B≠0=12282 distilbert.transformer.layer.4.attention.q_lin: |A|=0.2008, |B|=0.0002504, |∇A|=0.0001735, |∇B|=0.07399, |LoRA(x)|=8.787, B≠0=12283 distilbert.transformer.layer.4.attention.k_lin: |A|=0.2, |B|=0.0002486, |∇A|=0.0003957, |∇B|=0.07656, |LoRA(x)|=6.261, B≠0=12283 distilbert.transformer.layer.4.attention.v_lin: |A|=0.1993, |B|=0.0001708, |∇A|=0.0003348, |∇B|=0.11, |LoRA(x)|=3.692, B≠0=12281 distilbert.transformer.layer.4.attention.out_lin: |A|=0.1999, |B|=0.0001745, 
|∇A|=0.000366, |∇B|=0.07387, |LoRA(x)|=2.375, B≠0=12282 distilbert.transformer.layer.4.ffn.lin1: |A|=0.2004, |B|=0.0001853, |∇A|=0.0002817, |∇B|=0.04258, |LoRA(x)|=16.52, B≠0=49128 distilbert.transformer.layer.4.ffn.lin2: |A|=0.1994, |B|=0.0001656, |∇A|=0.0001065, |∇B|=0.0697, |LoRA(x)|=2.054, B≠0=12282 distilbert.transformer.layer.5.attention.q_lin: |A|=0.1998, |B|=0.0002194, |∇A|=8.952e-05, |∇B|=0.02991, |LoRA(x)|=8.03, B≠0=12283 distilbert.transformer.layer.5.attention.k_lin: |A|=0.201, |B|=0.0002024, |∇A|=0.0002036, |∇B|=0.03099, |LoRA(x)|=5.661, B≠0=12282 distilbert.transformer.layer.5.attention.v_lin: |A|=0.1987, |B|=0.0001348, |∇A|=0.0002633, |∇B|=0.04537, |LoRA(x)|=3.664, B≠0=12280 distilbert.transformer.layer.5.attention.out_lin: |A|=0.1994, |B|=0.0001793, |∇A|=0.0002507, |∇B|=0.03697, |LoRA(x)|=9.559, B≠0=12282 distilbert.transformer.layer.5.ffn.lin1: |A|=0.1993, |B|=0.000147, |∇A|=0.0001532, |∇B|=0.01653, |LoRA(x)|=17.16, B≠0=49123 distilbert.transformer.layer.5.ffn.lin2: |A|=0.1994, |B|=0.0001564, |∇A|=5.048e-05, |∇B|=0.02768, |LoRA(x)|=5.171, B≠0=12282
Training summary¶
- Loss decreases consistently and validation accuracy improves steadily until ~step 1050. Best observed performance, at step ~1050: accuracy 92.48%, F1 92.48%, validation loss 0.214967. After that, the metrics plateau or fluctuate slightly, suggesting diminishing returns or mild overfitting.
- LoRA adapts the low-rank matrices A and B in each wrapped linear layer. |A| ≈ 0.198–0.201 across layers, as expected from the scaled normal initialisation of A. |B| ≈ 0.00015–0.00026, nearly three orders of magnitude smaller, which is also expected: B starts close to zero and grows slowly. |LoRA(x)| stays in a reasonable range of roughly 2–19.
- LoRA's effective weight update is the product A·B (times alpha). So even though A is large, B remains small early on, and the effective update is both very low-rank and small in norm, especially over the first few dozen steps. This matches the initially slow, flat loss curve. Layer utilisation is fairly even, with layers 0–3 showing balanced activity across attention and FFN modules; a quick way to inspect the effective update of a single layer is sketched after this list.
- Gradients: |∇A| ≈ 0.0002–0.0005 (small but consistent updates), while |∇B| is larger, reaching up to ~0.2, especially in v_lin and ffn.lin2, so LoRA is learning most actively through B.
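For reference, the effective update of a single wrapped layer can be inspected after training; the module path below assumes the DistilBERT layer names shown in the log above (an illustrative check, not output from this run):
with torch.no_grad():
    lora = model_lora_all_attn.distilbert.transformer.layer[0].attention.q_lin.lora
    delta_w = lora.alpha * (lora.A @ lora.B)   # effective weight update Delta W
    print(f"||Delta W|| = {delta_w.norm():.4f}, rank <= {lora.A.shape[1]}")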
dropout = 0.4
learning_rate = 1.5e-5
lora_rank = 16
lora_alpha = 32
weight_decay = 1e-5 # L2
batch_size = 32
output_dir_prefix = "finetuned-imdb-"
import copy
from transformers import TrainingArguments
torch.manual_seed(137)
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha, dropout)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)
total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
#print("\nTrainable parameters after freezing:")
#for name, param in model_lora_all_attn.named_parameters():
# if param.requires_grad:
# print(name)
eval_steps = 50
logging_steps = 50
training_args_lora_all_attn = TrainingArguments(
output_dir=f"{output_dir_prefix}lora-all-attn",
num_train_epochs=2,
#max_steps=100,
learning_rate=learning_rate,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
weight_decay=weight_decay,
evaluation_strategy="steps",
eval_steps=eval_steps,
logging_steps=logging_steps,
save_steps=eval_steps,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
disable_tqdm=False,
push_to_hub=False,
max_grad_norm=1.0, ##################
report_to="none",
log_level="error"
)
trainer_lora_all_attn = Trainer(
model=model_lora_all_attn,
args=training_args_lora_all_attn,
train_dataset=dataset_encoded["train"],
eval_dataset=dataset_encoded["validation"],
compute_metrics=compute_metrics,
)
hooks, monitor = monitor_lora_parameters(trainer_lora_all_attn.model)
#Train!
trainer_lora_all_attn.train()
#eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
#print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
#print (torch.cuda.memory_summary())
for hook in hooks:
    hook.remove()
# Aggregate/log after training
from collections import defaultdict
agg = defaultdict(list)
for key, vals in monitor.items():
    grouped = defaultdict(list)
    for name, val in vals:
        grouped[name].append(val)
    agg[key] = {name: sum(vs)/len(vs) for name, vs in grouped.items()}
for name in agg["A_abs_mean"]:
    print(f"{name}: |A|={agg['A_abs_mean'][name]:.4g}, |B|={agg['B_abs_mean'][name]:.4g}, "
          f"|∇A|={agg['A_grad_mean'][name]:.4g}, |∇B|={agg['B_grad_mean'][name]:.4g}, "
          f"|LoRA(x)|={agg['lora_output_norm'][name]:.4g}, B≠0={agg['B_nonzero_count'][name]:.0f}")
#print('Parameter Statistics: mean.abs()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(name, param.abs().mean().item())
#print('Parameter Statistics: param.norm()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(f"{name} weight norm: {param.norm().item():.4f}")
LoRA (All Attention) - Total parameters: 68,282,114
LoRA (All Attention) - Trainable parameters: 1,919,234 (2.81%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
Step | Training Loss | Validation Loss | Accuracy | F1 |
---|---|---|---|---|
50 | 0.522300 | 0.358082 | 0.854400 | 0.853802 |
100 | 0.374700 | 0.310322 | 0.862400 | 0.862513 |
150 | 0.386100 | 0.318261 | 0.869600 | 0.869504 |
200 | 0.326900 | 0.307559 | 0.869600 | 0.867806 |
250 | 0.351500 | 0.314006 | 0.877600 | 0.877870 |
300 | 0.332300 | 0.272534 | 0.892000 | 0.891881 |
350 | 0.343300 | 0.279443 | 0.893600 | 0.893502 |
400 | 0.327400 | 0.285108 | 0.890400 | 0.889991 |
450 | 0.295900 | 0.258271 | 0.898400 | 0.898208 |
500 | 0.327000 | 0.285135 | 0.883200 | 0.882071 |
550 | 0.311600 | 0.249657 | 0.904000 | 0.903745 |
600 | 0.315400 | 0.261444 | 0.902400 | 0.902096 |
650 | 0.290800 | 0.253156 | 0.896800 | 0.896605 |
700 | 0.299500 | 0.243638 | 0.900800 | 0.900537 |
750 | 0.297200 | 0.240773 | 0.908000 | 0.907932 |
800 | 0.283300 | 0.261479 | 0.903200 | 0.902599 |
850 | 0.277300 | 0.226801 | 0.907200 | 0.907317 |
900 | 0.265200 | 0.268644 | 0.904800 | 0.904983 |
950 | 0.276900 | 0.250334 | 0.909600 | 0.909370 |
1000 | 0.245300 | 0.251263 | 0.908800 | 0.908708 |
1050 | 0.241300 | 0.231262 | 0.914400 | 0.914366 |
1100 | 0.244200 | 0.239346 | 0.904000 | 0.903903 |
1150 | 0.242700 | 0.226735 | 0.908800 | 0.908741 |
1200 | 0.236600 | 0.237783 | 0.909600 | 0.909579 |
1250 | 0.220600 | 0.229475 | 0.912000 | 0.911986 |
1300 | 0.243500 | 0.234786 | 0.908800 | 0.908756 |
1350 | 0.246600 | 0.229699 | 0.905600 | 0.905699 |
1400 | 0.238500 | 0.222471 | 0.916800 | 0.916560 |
1450 | 0.221000 | 0.223475 | 0.910400 | 0.910400 |
1500 | 0.226000 | 0.222452 | 0.916000 | 0.915892 |
1550 | 0.209200 | 0.223851 | 0.915200 | 0.915145 |
distilbert.transformer.layer.0.attention.q_lin: |A|=0.2014, |B|=0.0002581, |∇A|=0.0003806, |∇B|=0.1261, |LoRA(x)|=5.469, B≠0=12283 distilbert.transformer.layer.0.attention.k_lin: |A|=0.1991, |B|=0.0002605, |∇A|=0.0003199, |∇B|=0.1166, |LoRA(x)|=6.167, B≠0=12283 distilbert.transformer.layer.0.attention.v_lin: |A|=0.1993, |B|=0.0002149, |∇A|=0.0005936, |∇B|=0.2469, |LoRA(x)|=5.135, B≠0=12283 distilbert.transformer.layer.0.attention.out_lin: |A|=0.2004, |B|=0.0002184, |∇A|=0.0005894, |∇B|=0.2659, |LoRA(x)|=2.72, B≠0=12283 distilbert.transformer.layer.0.ffn.lin1: |A|=0.1998, |B|=0.0002434, |∇A|=0.0006892, |∇B|=0.1599, |LoRA(x)|=17.21, B≠0=49132 distilbert.transformer.layer.0.ffn.lin2: |A|=0.1996, |B|=0.0002225, |∇A|=0.0002834, |∇B|=0.2989, |LoRA(x)|=4.928, B≠0=12283 distilbert.transformer.layer.1.attention.q_lin: |A|=0.2, |B|=0.0002619, |∇A|=0.0003127, |∇B|=0.1176, |LoRA(x)|=6.326, B≠0=12283 distilbert.transformer.layer.1.attention.k_lin: |A|=0.2007, |B|=0.0002593, |∇A|=0.0003151, |∇B|=0.1215, |LoRA(x)|=5.5, B≠0=12283 distilbert.transformer.layer.1.attention.v_lin: |A|=0.1978, |B|=0.0001929, |∇A|=0.0004882, |∇B|=0.257, |LoRA(x)|=3.703, B≠0=12282 distilbert.transformer.layer.1.attention.out_lin: |A|=0.1985, |B|=0.0002113, |∇A|=0.0005305, |∇B|=0.2663, |LoRA(x)|=2.635, B≠0=12283 distilbert.transformer.layer.1.ffn.lin1: |A|=0.198, |B|=0.0002375, |∇A|=0.0006094, |∇B|=0.1708, |LoRA(x)|=21.37, B≠0=49133 distilbert.transformer.layer.1.ffn.lin2: |A|=0.1989, |B|=0.0002082, |∇A|=0.0002324, |∇B|=0.2649, |LoRA(x)|=2.712, B≠0=12283 distilbert.transformer.layer.2.attention.q_lin: |A|=0.198, |B|=0.0002454, |∇A|=0.0003173, |∇B|=0.1344, |LoRA(x)|=6.335, B≠0=12283 distilbert.transformer.layer.2.attention.k_lin: |A|=0.2002, |B|=0.0002375, |∇A|=0.0003538, |∇B|=0.1436, |LoRA(x)|=5.793, B≠0=12283 distilbert.transformer.layer.2.attention.v_lin: |A|=0.2002, |B|=0.0001974, |∇A|=0.0004345, |∇B|=0.2487, |LoRA(x)|=4.533, B≠0=12282 distilbert.transformer.layer.2.attention.out_lin: |A|=0.1984, |B|=0.0001982, |∇A|=0.0004565, |∇B|=0.2209, |LoRA(x)|=2.561, B≠0=12283 distilbert.transformer.layer.2.ffn.lin1: |A|=0.1982, |B|=0.0002298, |∇A|=0.0005989, |∇B|=0.1531, |LoRA(x)|=20.14, B≠0=49131 distilbert.transformer.layer.2.ffn.lin2: |A|=0.1999, |B|=0.0002027, |∇A|=0.0002184, |∇B|=0.2143, |LoRA(x)|=2.758, B≠0=12282 distilbert.transformer.layer.3.attention.q_lin: |A|=0.2006, |B|=0.0002525, |∇A|=0.0002761, |∇B|=0.1089, |LoRA(x)|=6.166, B≠0=12283 distilbert.transformer.layer.3.attention.k_lin: |A|=0.2011, |B|=0.0002438, |∇A|=0.0003425, |∇B|=0.1276, |LoRA(x)|=6.506, B≠0=12283 distilbert.transformer.layer.3.attention.v_lin: |A|=0.2007, |B|=0.0001687, |∇A|=0.0003466, |∇B|=0.2575, |LoRA(x)|=4.666, B≠0=12282 distilbert.transformer.layer.3.attention.out_lin: |A|=0.1995, |B|=0.0001906, |∇A|=0.0004309, |∇B|=0.171, |LoRA(x)|=3.289, B≠0=12282 distilbert.transformer.layer.3.ffn.lin1: |A|=0.1995, |B|=0.0002211, |∇A|=0.0005016, |∇B|=0.1029, |LoRA(x)|=14.59, B≠0=49131 distilbert.transformer.layer.3.ffn.lin2: |A|=0.1998, |B|=0.0001856, |∇A|=0.0002331, |∇B|=0.1571, |LoRA(x)|=2.627, B≠0=12282 distilbert.transformer.layer.4.attention.q_lin: |A|=0.2008, |B|=0.0002418, |∇A|=0.0001731, |∇B|=0.08101, |LoRA(x)|=10.64, B≠0=12283 distilbert.transformer.layer.4.attention.k_lin: |A|=0.2, |B|=0.0002293, |∇A|=0.0003568, |∇B|=0.08671, |LoRA(x)|=6.066, B≠0=12283 distilbert.transformer.layer.4.attention.v_lin: |A|=0.1993, |B|=0.0001757, |∇A|=0.0003265, |∇B|=0.1101, |LoRA(x)|=4.307, B≠0=12282 distilbert.transformer.layer.4.attention.out_lin: |A|=0.1999, |B|=0.000175, 
|∇A|=0.0003284, |∇B|=0.07959, |LoRA(x)|=2.766, B≠0=12282 distilbert.transformer.layer.4.ffn.lin1: |A|=0.2004, |B|=0.0001862, |∇A|=0.0002638, |∇B|=0.04591, |LoRA(x)|=16.07, B≠0=49129 distilbert.transformer.layer.4.ffn.lin2: |A|=0.1994, |B|=0.0001647, |∇A|=0.0001085, |∇B|=0.08194, |LoRA(x)|=2.308, B≠0=12281 distilbert.transformer.layer.5.attention.q_lin: |A|=0.1998, |B|=0.0002014, |∇A|=9.653e-05, |∇B|=0.03063, |LoRA(x)|=7.295, B≠0=12282 distilbert.transformer.layer.5.attention.k_lin: |A|=0.201, |B|=0.0002036, |∇A|=0.000199, |∇B|=0.03414, |LoRA(x)|=6.269, B≠0=12283 distilbert.transformer.layer.5.attention.v_lin: |A|=0.1987, |B|=0.0001532, |∇A|=0.0002869, |∇B|=0.05179, |LoRA(x)|=5.935, B≠0=12282 distilbert.transformer.layer.5.attention.out_lin: |A|=0.1994, |B|=0.0001674, |∇A|=0.0001988, |∇B|=0.04984, |LoRA(x)|=8.743, B≠0=12282 distilbert.transformer.layer.5.ffn.lin1: |A|=0.1993, |B|=0.0001539, |∇A|=0.0001583, |∇B|=0.02406, |LoRA(x)|=18.93, B≠0=49125 distilbert.transformer.layer.5.ffn.lin2: |A|=0.1994, |B|=0.0001408, |∇A|=5.189e-05, |∇B|=0.04019, |LoRA(x)|=5.108, B≠0=12281
Training summary¶
With dropout = 0.4 (applied after the down-projection by A and before the up-projection by B), B is forced to adapt more robustly over time, with no significant loss of accuracy or F1. In effect, dropout acts as a regulariser by injecting noise into the low-rank path.
Key trends in |∇B|: dropout = 0.4 amplifies |∇B| in almost every layer, sometimes by 30–100% depending on depth and layer type, so LoRA's updates are more aggressive when dropout is used. It pushes larger updates into the B matrix, most likely because of the extra noise and sparsity during training. In short, dropout noticeably boosts the |∇B| gradient magnitudes that drive fast adaptation, capacity, and expressivity; a small sketch of the scaling mechanism follows below.
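A minimal sketch of the mechanism, independent of the model: PyTorch's nn.Dropout is inverted dropout, so during training the surviving elements of the intermediate activation x @ A are rescaled by 1/(1 - p). With p = 0.4 every kept value is roughly 1.67x larger, which inflates the signal flowing into B and, with it, |∇B|:
drop = nn.Dropout(p=0.4)
drop.train()                   # dropout is only active in training mode
h = torch.ones(10000)          # stand-in for the x @ A activations
out = drop(h)
print(out[out != 0].unique())  # kept values are scaled to 1 / (1 - 0.4) ≈ 1.6667
print(out.mean())              # ≈ 1.0 in expectation, but individual values are larger and noisier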
Initialisation of lora.A¶
Matrix A maps from the high-dimensional space down to the low-rank space: it compresses features, and its output is (optionally) passed through dropout, which adds further noise. Matrix B does the heavy lifting, projecting the compressed signal back up to the original dimension. Since LoRA is usually trained with a low learning rate and the product A @ B is scaled by a constant (alpha / rank in the standard formulation, plain alpha in the implementation above), small differences in A's initialisation get washed out within the first training steps. Changing A's init therefore mostly affects early training (learning speed, initial activation variance) but barely shifts final performance. By contrast, a poorly initialised B immediately sends noise into the frozen model's internal representations. For this reason we initialise B as nn.Parameter(1e-6 * torch.randn(rank, out_dim) * std_dev). The common alternative of all zeros (self.B = nn.Parameter(torch.zeros(rank, out_dim))) often works well, but not always, since a perfectly symmetric start can hurt adaptation. The snippet below contrasts the two choices.
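A small numerical illustration of this point (hypothetical shapes, not tied to DistilBERT): with B = 0 the adapter's initial output is exactly zero, so the wrapped layer starts identical to the frozen one, whereas the near-zero random init used above starts from a tiny but non-degenerate perturbation.
in_dim, rank, out_dim, alpha = 768, 16, 768, 32
std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
A = torch.randn(in_dim, rank) * std_dev
B_rand = 1e-6 * torch.randn(rank, out_dim) * std_dev   # init used in this notebook
B_zero = torch.zeros(rank, out_dim)                     # common alternative
x = torch.randn(4, in_dim)
print((alpha * x @ A @ B_rand).norm())   # small but non-zero
print((alpha * x @ A @ B_zero).norm())   # exactly zero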
Kaiming for A¶
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
torch.autograd.set_detect_anomaly(True)
class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha, dropout_rate=0.0):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        #self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        # Kaiming-style initialisation for A (note: sqrt(2 / (in_dim + out_dim)) uses both
        # fan-in and fan-out of the wrapped layer, so the scale is close to Xavier normal)
        std_dev2 = math.sqrt(2 / (in_dim + out_dim))
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev2)
        self.B = nn.Parameter(1e-6 * torch.randn(rank, out_dim) * std_dev)  # Not all zeroes!
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout_rate)
    def forward(self, x):
        # Dropout applied to the projection to the lower-dimensional space by A
        dropped = self.dropout(x @ self.A)
        return self.alpha * (dropped @ self.B)
class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        self.last_lora_output_norm = 0.0  # for monitoring
    def forward(self, x):
        #return self.linear(x) + self.lora(x)
        lora_out = self.lora(x)
        self.last_lora_output_norm = lora_out.norm(p=2, dim=-1).mean().item()
        return self.linear(x) + lora_out
# Function to inject LoRA into specified linear layers
def inject_lora_all_attn(model, rank, alpha, dropout_rate=0.0):
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin", "ffn.lin1", "ffn.lin2"]
    # target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            lora_linear = LinearWithLoRA(original_linear, rank, alpha)
            lora_linear.lora.dropout = nn.Dropout(dropout_rate)
            setattr(parent_module, name.split('.')[-1], lora_linear)
    return model
dropout = 0.0
learning_rate = 1.5e-5
lora_rank = 16
lora_alpha = 32
weight_decay = 1e-5 # L2
batch_size = 32
output_dir_prefix = "finetuned-imdb-"
import copy
from transformers import TrainingArguments
torch.manual_seed(137)
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha, dropout)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)
total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
#print("\nTrainable parameters after freezing:")
#for name, param in model_lora_all_attn.named_parameters():
# if param.requires_grad:
# print(name)
eval_steps = 50
logging_steps = 50
training_args_lora_all_attn = TrainingArguments(
output_dir=f"{output_dir_prefix}lora-all-attn",
num_train_epochs=1,
#max_steps=200,
learning_rate=learning_rate,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
weight_decay=weight_decay,
evaluation_strategy="steps",
eval_steps=eval_steps,
logging_steps=logging_steps,
save_steps=eval_steps,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
disable_tqdm=False,
push_to_hub=False,
max_grad_norm=1.0, ##################
report_to="none",
log_level="error"
)
trainer_lora_all_attn = Trainer(
model=model_lora_all_attn,
args=training_args_lora_all_attn,
train_dataset=dataset_encoded["train"],
eval_dataset=dataset_encoded["validation"],
compute_metrics=compute_metrics,
)
hooks, monitor = monitor_lora_parameters(trainer_lora_all_attn.model)
#Train!
trainer_lora_all_attn.train()
#eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
#print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
#print (torch.cuda.memory_summary())
for hook in hooks:
    hook.remove()
# Aggregate/log after training
from collections import defaultdict
agg = defaultdict(list)
for key, vals in monitor.items():
    grouped = defaultdict(list)
    for name, val in vals:
        grouped[name].append(val)
    agg[key] = {name: sum(vs)/len(vs) for name, vs in grouped.items()}
for name in agg["A_abs_mean"]:
    print(f"{name}: |A|={agg['A_abs_mean'][name]:.4g}, |B|={agg['B_abs_mean'][name]:.4g}, "
          f"|∇A|={agg['A_grad_mean'][name]:.4g}, |∇B|={agg['B_grad_mean'][name]:.4g}, "
          f"|LoRA(x)|={agg['lora_output_norm'][name]:.4g}, B≠0={agg['B_nonzero_count'][name]:.0f}")
#print('Parameter Statistics: mean.abs()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(name, param.abs().mean().item())
#print('Parameter Statistics: param.norm()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(f"{name} weight norm: {param.norm().item():.4f}")
LoRA (All Attention) - Total parameters: 68,282,114
LoRA (All Attention) - Trainable parameters: 1,919,234 (2.81%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
Step | Training Loss | Validation Loss | Accuracy | F1 |
---|---|---|---|---|
50 | 0.641200 | 0.498864 | 0.825600 | 0.825801 |
100 | 0.398900 | 0.378827 | 0.832000 | 0.831737 |
150 | 0.342200 | 0.294475 | 0.880000 | 0.880212 |
200 | 0.277900 | 0.277569 | 0.884800 | 0.884265 |
250 | 0.299900 | 0.280815 | 0.888000 | 0.888238 |
300 | 0.267600 | 0.269878 | 0.896800 | 0.896389 |
350 | 0.277600 | 0.252492 | 0.896800 | 0.896903 |
400 | 0.263900 | 0.252141 | 0.903200 | 0.903285 |
450 | 0.251700 | 0.254847 | 0.900000 | 0.900150 |
500 | 0.270300 | 0.241287 | 0.904000 | 0.903867 |
550 | 0.253100 | 0.240126 | 0.901600 | 0.901454 |
600 | 0.271100 | 0.235365 | 0.904000 | 0.903921 |
650 | 0.241800 | 0.243780 | 0.903200 | 0.903353 |
700 | 0.233400 | 0.235918 | 0.905600 | 0.905600 |
750 | 0.235500 | 0.234417 | 0.907200 | 0.907186 |
distilbert.transformer.layer.0.attention.q_lin: |A|=0.02909, |B|=0.0002327, |∇A|=0.0002211, |∇B|=0.007083, |LoRA(x)|=0.6601, B≠0=12282 distilbert.transformer.layer.0.attention.k_lin: |A|=0.02878, |B|=0.0002298, |∇A|=0.0001474, |∇B|=0.007112, |LoRA(x)|=0.7828, B≠0=12281 distilbert.transformer.layer.0.attention.v_lin: |A|=0.02878, |B|=0.0002287, |∇A|=0.0008088, |∇B|=0.01975, |LoRA(x)|=0.7497, B≠0=12282 distilbert.transformer.layer.0.attention.out_lin: |A|=0.02893, |B|=0.0002272, |∇A|=0.001325, |∇B|=0.02585, |LoRA(x)|=0.4959, B≠0=12282 distilbert.transformer.layer.0.ffn.lin1: |A|=0.01825, |B|=0.000217, |∇A|=0.0006296, |∇B|=0.009524, |LoRA(x)|=2.611, B≠0=49126 distilbert.transformer.layer.0.ffn.lin2: |A|=0.01823, |B|=0.0002204, |∇A|=0.0003984, |∇B|=0.01871, |LoRA(x)|=0.5402, B≠0=12282 distilbert.transformer.layer.1.attention.q_lin: |A|=0.02888, |B|=0.0002127, |∇A|=0.000235, |∇B|=0.008725, |LoRA(x)|=0.8818, B≠0=12280 distilbert.transformer.layer.1.attention.k_lin: |A|=0.02899, |B|=0.0002166, |∇A|=0.000186, |∇B|=0.007929, |LoRA(x)|=0.5567, B≠0=12281 distilbert.transformer.layer.1.attention.v_lin: |A|=0.02856, |B|=0.000197, |∇A|=0.0005595, |∇B|=0.03967, |LoRA(x)|=0.7959, B≠0=12281 distilbert.transformer.layer.1.attention.out_lin: |A|=0.02866, |B|=0.0002057, |∇A|=0.00109, |∇B|=0.03448, |LoRA(x)|=0.5205, B≠0=12282 distilbert.transformer.layer.1.ffn.lin1: |A|=0.0181, |B|=0.0002256, |∇A|=0.0005969, |∇B|=0.01394, |LoRA(x)|=4.284, B≠0=49126 distilbert.transformer.layer.1.ffn.lin2: |A|=0.01816, |B|=0.0001935, |∇A|=0.0003225, |∇B|=0.01989, |LoRA(x)|=0.2153, B≠0=12281 distilbert.transformer.layer.2.attention.q_lin: |A|=0.02859, |B|=0.0002196, |∇A|=0.0002355, |∇B|=0.01189, |LoRA(x)|=1.088, B≠0=12282 distilbert.transformer.layer.2.attention.k_lin: |A|=0.02892, |B|=0.0002278, |∇A|=0.0002871, |∇B|=0.01122, |LoRA(x)|=0.6679, B≠0=12282 distilbert.transformer.layer.2.attention.v_lin: |A|=0.02891, |B|=0.000196, |∇A|=0.0007276, |∇B|=0.03889, |LoRA(x)|=0.7448, B≠0=12281 distilbert.transformer.layer.2.attention.out_lin: |A|=0.02864, |B|=0.0002144, |∇A|=0.001153, |∇B|=0.03337, |LoRA(x)|=0.521, B≠0=12282 distilbert.transformer.layer.2.ffn.lin1: |A|=0.01813, |B|=0.0002415, |∇A|=0.001008, |∇B|=0.01763, |LoRA(x)|=5.14, B≠0=49126 distilbert.transformer.layer.2.ffn.lin2: |A|=0.01825, |B|=0.0002205, |∇A|=0.0005423, |∇B|=0.02212, |LoRA(x)|=0.3106, B≠0=12282 distilbert.transformer.layer.3.attention.q_lin: |A|=0.02898, |B|=0.0002311, |∇A|=0.000282, |∇B|=0.009444, |LoRA(x)|=1.016, B≠0=12281 distilbert.transformer.layer.3.attention.k_lin: |A|=0.02904, |B|=0.0002356, |∇A|=0.0002989, |∇B|=0.009808, |LoRA(x)|=0.9602, B≠0=12282 distilbert.transformer.layer.3.attention.v_lin: |A|=0.02898, |B|=0.0002056, |∇A|=0.0007878, |∇B|=0.05867, |LoRA(x)|=1.151, B≠0=12282 distilbert.transformer.layer.3.attention.out_lin: |A|=0.02881, |B|=0.0002671, |∇A|=0.002097, |∇B|=0.02614, |LoRA(x)|=0.4825, B≠0=12282 distilbert.transformer.layer.3.ffn.lin1: |A|=0.01824, |B|=0.0002449, |∇A|=0.001385, |∇B|=0.01429, |LoRA(x)|=3.554, B≠0=49127 distilbert.transformer.layer.3.ffn.lin2: |A|=0.01825, |B|=0.0002515, |∇A|=0.0008056, |∇B|=0.02044, |LoRA(x)|=0.3727, B≠0=12282 distilbert.transformer.layer.4.attention.q_lin: |A|=0.02899, |B|=0.0002635, |∇A|=0.000508, |∇B|=0.01414, |LoRA(x)|=2.408, B≠0=12282 distilbert.transformer.layer.4.attention.k_lin: |A|=0.02889, |B|=0.00025, |∇A|=0.000996, |∇B|=0.01001, |LoRA(x)|=1.021, B≠0=12282 distilbert.transformer.layer.4.attention.v_lin: |A|=0.02878, |B|=0.0002282, |∇A|=0.00123, |∇B|=0.03556, |LoRA(x)|=1.406, B≠0=12282 
distilbert.transformer.layer.4.attention.out_lin: |A|=0.02887, |B|=0.000246, |∇A|=0.001401, |∇B|=0.0175, |LoRA(x)|=0.645, B≠0=12282 distilbert.transformer.layer.4.ffn.lin1: |A|=0.01833, |B|=0.0002607, |∇A|=0.0008604, |∇B|=0.009056, |LoRA(x)|=5.717, B≠0=49129 distilbert.transformer.layer.4.ffn.lin2: |A|=0.01822, |B|=0.0002721, |∇A|=0.0004477, |∇B|=0.01619, |LoRA(x)|=0.7457, B≠0=12282 distilbert.transformer.layer.5.attention.q_lin: |A|=0.02888, |B|=0.0002728, |∇A|=0.0002167, |∇B|=0.00796, |LoRA(x)|=1.832, B≠0=12282 distilbert.transformer.layer.5.attention.k_lin: |A|=0.02903, |B|=0.0002349, |∇A|=0.0007984, |∇B|=0.004929, |LoRA(x)|=2.066, B≠0=12281 distilbert.transformer.layer.5.attention.v_lin: |A|=0.02869, |B|=0.0002458, |∇A|=0.001274, |∇B|=0.01634, |LoRA(x)|=1.781, B≠0=12282 distilbert.transformer.layer.5.attention.out_lin: |A|=0.02881, |B|=0.0003804, |∇A|=0.001541, |∇B|=0.008762, |LoRA(x)|=1.76, B≠0=12283 distilbert.transformer.layer.5.ffn.lin1: |A|=0.01824, |B|=0.0002594, |∇A|=0.001008, |∇B|=0.003997, |LoRA(x)|=6.319, B≠0=49127 distilbert.transformer.layer.5.ffn.lin2: |A|=0.01822, |B|=0.0003608, |∇A|=0.0004907, |∇B|=0.005384, |LoRA(x)|=1.169, B≠0=12283
Scaled Xavier for A¶
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
torch.autograd.set_detect_anomaly(True)
class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha, dropout_rate=0.0):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        #self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        # Scaled Xavier initialisation for A
        sc_f = 1.5
        bound = math.sqrt(6 / (rank + in_dim)) * sc_f
        self.A = nn.Parameter(torch.empty(in_dim, rank).uniform_(-bound, bound))
        self.B = nn.Parameter(1e-6 * torch.randn(rank, out_dim) * std_dev)  # Not all zeroes!
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout_rate)
    def forward(self, x):
        # Dropout applied to the projection to the lower-dimensional space by A
        dropped = self.dropout(x @ self.A)
        return self.alpha * (dropped @ self.B)
class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        self.last_lora_output_norm = 0.0  # for monitoring
    def forward(self, x):
        #return self.linear(x) + self.lora(x)
        lora_out = self.lora(x)
        self.last_lora_output_norm = lora_out.norm(p=2, dim=-1).mean().item()
        return self.linear(x) + lora_out
# Function to inject LoRA into specified linear layers
def inject_lora_all_attn(model, rank, alpha, dropout_rate=0.0):
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin", "ffn.lin1", "ffn.lin2"]
    # target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            lora_linear = LinearWithLoRA(original_linear, rank, alpha)
            lora_linear.lora.dropout = nn.Dropout(dropout_rate)
            setattr(parent_module, name.split('.')[-1], lora_linear)
    return model
dropout = 0.0
learning_rate = 1.5e-5
lora_rank = 16
lora_alpha = 32
weight_decay = 1e-5 # L2
batch_size = 32
output_dir_prefix = "finetuned-imdb-"
import copy
from transformers import TrainingArguments
torch.manual_seed(137)
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha, dropout)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)
total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
#print("\nTrainable parameters after freezing:")
#for name, param in model_lora_all_attn.named_parameters():
# if param.requires_grad:
# print(name)
eval_steps = 50
logging_steps = 50
training_args_lora_all_attn = TrainingArguments(
output_dir=f"{output_dir_prefix}lora-all-attn",
num_train_epochs=1,
#max_steps=200,
learning_rate=learning_rate,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
weight_decay=weight_decay,
evaluation_strategy="steps",
eval_steps=eval_steps,
logging_steps=logging_steps,
save_steps=eval_steps,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
disable_tqdm=False,
push_to_hub=False,
max_grad_norm=1.0, ##################
report_to="none",
log_level="error"
)
trainer_lora_all_attn = Trainer(
model=model_lora_all_attn,
args=training_args_lora_all_attn,
train_dataset=dataset_encoded["train"],
eval_dataset=dataset_encoded["validation"],
compute_metrics=compute_metrics,
)
hooks, monitor = monitor_lora_parameters(trainer_lora_all_attn.model)
#Train!
trainer_lora_all_attn.train()
#eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
#print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
#print (torch.cuda.memory_summary())
for hook in hooks:
    hook.remove()
# Aggregate/log after training
from collections import defaultdict
agg = defaultdict(list)
for key, vals in monitor.items():
    grouped = defaultdict(list)
    for name, val in vals:
        grouped[name].append(val)
    agg[key] = {name: sum(vs)/len(vs) for name, vs in grouped.items()}
for name in agg["A_abs_mean"]:
    print(f"{name}: |A|={agg['A_abs_mean'][name]:.4g}, |B|={agg['B_abs_mean'][name]:.4g}, "
          f"|∇A|={agg['A_grad_mean'][name]:.4g}, |∇B|={agg['B_grad_mean'][name]:.4g}, "
          f"|LoRA(x)|={agg['lora_output_norm'][name]:.4g}, B≠0={agg['B_nonzero_count'][name]:.0f}")
#print('Parameter Statistics: mean.abs()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(name, param.abs().mean().item())
#print('Parameter Statistics: param.norm()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(f"{name} weight norm: {param.norm().item():.4f}")
LoRA (All Attention) - Total parameters: 68,282,114
LoRA (All Attention) - Trainable parameters: 1,919,234 (2.81%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
Step | Training Loss | Validation Loss | Accuracy | F1 |
---|---|---|---|---|
50 | 0.559900 | 0.342404 | 0.867200 | 0.866935 |
100 | 0.335600 | 0.298127 | 0.878400 | 0.878565 |
150 | 0.330100 | 0.269384 | 0.892000 | 0.891881 |
200 | 0.256100 | 0.272999 | 0.888800 | 0.888328 |
250 | 0.297100 | 0.247388 | 0.900800 | 0.900662 |
300 | 0.267600 | 0.260106 | 0.892000 | 0.891542 |
350 | 0.270000 | 0.245957 | 0.905600 | 0.905600 |
400 | 0.255400 | 0.242078 | 0.904800 | 0.904934 |
450 | 0.244500 | 0.241970 | 0.912000 | 0.911895 |
500 | 0.257700 | 0.235908 | 0.908800 | 0.908494 |
550 | 0.241300 | 0.237024 | 0.910400 | 0.910055 |
600 | 0.257700 | 0.227201 | 0.913600 | 0.913390 |
650 | 0.231800 | 0.231505 | 0.910400 | 0.910538 |
700 | 0.227300 | 0.224408 | 0.918400 | 0.918237 |
750 | 0.222500 | 0.222870 | 0.916000 | 0.915892 |
distilbert.transformer.layer.0.attention.q_lin: |A|=0.06532, |B|=0.0002035, |∇A|=0.0001663, |∇B|=0.01523, |LoRA(x)|=1.255, B≠0=12281 distilbert.transformer.layer.0.attention.k_lin: |A|=0.06534, |B|=0.0002003, |∇A|=0.0001101, |∇B|=0.01507, |LoRA(x)|=1.249, B≠0=12281 distilbert.transformer.layer.0.attention.v_lin: |A|=0.06566, |B|=0.0001929, |∇A|=0.0006117, |∇B|=0.03535, |LoRA(x)|=1.217, B≠0=12281 distilbert.transformer.layer.0.attention.out_lin: |A|=0.0661, |B|=0.0001711, |∇A|=0.0007501, |∇B|=0.04806, |LoRA(x)|=0.6202, B≠0=12281 distilbert.transformer.layer.0.ffn.lin1: |A|=0.06538, |B|=0.0001835, |∇A|=0.0004738, |∇B|=0.0315, |LoRA(x)|=6.167, B≠0=49123 distilbert.transformer.layer.0.ffn.lin2: |A|=0.03304, |B|=0.0001718, |∇A|=0.000289, |∇B|=0.02706, |LoRA(x)|=0.4766, B≠0=12280 distilbert.transformer.layer.1.attention.q_lin: |A|=0.06618, |B|=0.0001879, |∇A|=0.0001798, |∇B|=0.01696, |LoRA(x)|=1.24, B≠0=12280 distilbert.transformer.layer.1.attention.k_lin: |A|=0.06591, |B|=0.000191, |∇A|=0.0001451, |∇B|=0.01575, |LoRA(x)|=0.9638, B≠0=12281 distilbert.transformer.layer.1.attention.v_lin: |A|=0.06548, |B|=0.0001601, |∇A|=0.0004498, |∇B|=0.06403, |LoRA(x)|=0.9895, B≠0=12280 distilbert.transformer.layer.1.attention.out_lin: |A|=0.06559, |B|=0.0001658, |∇A|=0.0006209, |∇B|=0.05379, |LoRA(x)|=0.5784, B≠0=12281 distilbert.transformer.layer.1.ffn.lin1: |A|=0.06538, |B|=0.0001922, |∇A|=0.0004596, |∇B|=0.0344, |LoRA(x)|=8.039, B≠0=49123 distilbert.transformer.layer.1.ffn.lin2: |A|=0.03304, |B|=0.0001619, |∇A|=0.0002367, |∇B|=0.0329, |LoRA(x)|=0.297, B≠0=12280 distilbert.transformer.layer.2.attention.q_lin: |A|=0.06613, |B|=0.0001945, |∇A|=0.0002159, |∇B|=0.0249, |LoRA(x)|=1.65, B≠0=12281 distilbert.transformer.layer.2.attention.k_lin: |A|=0.06571, |B|=0.0001978, |∇A|=0.0002439, |∇B|=0.02332, |LoRA(x)|=1.206, B≠0=12281 distilbert.transformer.layer.2.attention.v_lin: |A|=0.06583, |B|=0.0001605, |∇A|=0.0005263, |∇B|=0.06787, |LoRA(x)|=0.9769, B≠0=12280 distilbert.transformer.layer.2.attention.out_lin: |A|=0.06489, |B|=0.0001754, |∇A|=0.0007919, |∇B|=0.04553, |LoRA(x)|=0.5327, B≠0=12280 distilbert.transformer.layer.2.ffn.lin1: |A|=0.06513, |B|=0.0001948, |∇A|=0.0006043, |∇B|=0.04044, |LoRA(x)|=7.488, B≠0=49124 distilbert.transformer.layer.2.ffn.lin2: |A|=0.03309, |B|=0.0001835, |∇A|=0.0004721, |∇B|=0.02781, |LoRA(x)|=0.3463, B≠0=12281 distilbert.transformer.layer.3.attention.q_lin: |A|=0.066, |B|=0.0002053, |∇A|=0.0002105, |∇B|=0.02061, |LoRA(x)|=2.099, B≠0=12281 distilbert.transformer.layer.3.attention.k_lin: |A|=0.06551, |B|=0.0002091, |∇A|=0.0002618, |∇B|=0.02079, |LoRA(x)|=1.555, B≠0=12281 distilbert.transformer.layer.3.attention.v_lin: |A|=0.06551, |B|=0.0001497, |∇A|=0.0004883, |∇B|=0.08872, |LoRA(x)|=1.46, B≠0=12280 distilbert.transformer.layer.3.attention.out_lin: |A|=0.06552, |B|=0.0001923, |∇A|=0.0008196, |∇B|=0.03743, |LoRA(x)|=0.5979, B≠0=12281 distilbert.transformer.layer.3.ffn.lin1: |A|=0.06555, |B|=0.000184, |∇A|=0.0006847, |∇B|=0.04193, |LoRA(x)|=8.602, B≠0=49123 distilbert.transformer.layer.3.ffn.lin2: |A|=0.03308, |B|=0.0001783, |∇A|=0.0004434, |∇B|=0.0232, |LoRA(x)|=0.3174, B≠0=12280 distilbert.transformer.layer.4.attention.q_lin: |A|=0.06642, |B|=0.0002137, |∇A|=0.0003046, |∇B|=0.02595, |LoRA(x)|=3.838, B≠0=12281 distilbert.transformer.layer.4.attention.k_lin: |A|=0.06605, |B|=0.0002044, |∇A|=0.000532, |∇B|=0.01982, |LoRA(x)|=2.01, B≠0=12281 distilbert.transformer.layer.4.attention.v_lin: |A|=0.0658, |B|=0.0001578, |∇A|=0.0006522, |∇B|=0.06364, |LoRA(x)|=1.776, B≠0=12280 
distilbert.transformer.layer.4.attention.out_lin: |A|=0.0653, |B|=0.0001886, |∇A|=0.0008771, |∇B|=0.02987, |LoRA(x)|=0.784, B≠0=12281 distilbert.transformer.layer.4.ffn.lin1: |A|=0.06552, |B|=0.0002044, |∇A|=0.0006782, |∇B|=0.02294, |LoRA(x)|=9.608, B≠0=49125 distilbert.transformer.layer.4.ffn.lin2: |A|=0.03295, |B|=0.0002195, |∇A|=0.0004506, |∇B|=0.01567, |LoRA(x)|=0.4364, B≠0=12282 distilbert.transformer.layer.5.attention.q_lin: |A|=0.06579, |B|=0.000201, |∇A|=0.000149, |∇B|=0.01128, |LoRA(x)|=1.89, B≠0=12281 distilbert.transformer.layer.5.attention.k_lin: |A|=0.06566, |B|=0.0001923, |∇A|=0.0005654, |∇B|=0.01185, |LoRA(x)|=2.068, B≠0=12281 distilbert.transformer.layer.5.attention.v_lin: |A|=0.06563, |B|=0.0001716, |∇A|=0.0005953, |∇B|=0.03601, |LoRA(x)|=2.725, B≠0=12281 distilbert.transformer.layer.5.attention.out_lin: |A|=0.0654, |B|=0.0002853, |∇A|=0.0008672, |∇B|=0.01711, |LoRA(x)|=2.471, B≠0=12282 distilbert.transformer.layer.5.ffn.lin1: |A|=0.0657, |B|=0.0001973, |∇A|=0.0004453, |∇B|=0.009803, |LoRA(x)|=11.89, B≠0=49123 distilbert.transformer.layer.5.ffn.lin2: |A|=0.03308, |B|=0.0002221, |∇A|=0.0001933, |∇B|=0.008399, |LoRA(x)|=1.134, B≠0=12281
Orthogonal initialisation for A¶
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
torch.autograd.set_detect_anomaly(True)
class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha, dropout_rate=0.0):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        #self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        # Orthogonal initialisation for A
        self.A = nn.Parameter(torch.empty(in_dim, rank))
        nn.init.orthogonal_(self.A)
        self.B = nn.Parameter(1e-6 * torch.randn(rank, out_dim) * std_dev)  # Not all zeroes!
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout_rate)
    def forward(self, x):
        # Dropout applied to the projection to the lower-dimensional space by A
        dropped = self.dropout(x @ self.A)
        return self.alpha * (dropped @ self.B)
class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        self.last_lora_output_norm = 0.0  # for monitoring
    def forward(self, x):
        #return self.linear(x) + self.lora(x)
        lora_out = self.lora(x)
        self.last_lora_output_norm = lora_out.norm(p=2, dim=-1).mean().item()
        return self.linear(x) + lora_out
# Function to inject LoRA into specified linear layers
def inject_lora_all_attn(model, rank, alpha, dropout_rate=0.0):
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin", "ffn.lin1", "ffn.lin2"]
    # target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            lora_linear = LinearWithLoRA(original_linear, rank, alpha)
            lora_linear.lora.dropout = nn.Dropout(dropout_rate)
            setattr(parent_module, name.split('.')[-1], lora_linear)
    return model
dropout = 0.0
learning_rate = 1.5e-5
lora_rank = 16
lora_alpha = 32
weight_decay = 1e-5 # L2
batch_size = 32
output_dir_prefix = "finetuned-imdb-"
import copy
from transformers import TrainingArguments
torch.manual_seed(137)
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha, dropout)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)
total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
#print("\nTrainable parameters after freezing:")
#for name, param in model_lora_all_attn.named_parameters():
# if param.requires_grad:
# print(name)
eval_steps = 50
logging_steps = 50
training_args_lora_all_attn = TrainingArguments(
output_dir=f"{output_dir_prefix}lora-all-attn",
num_train_epochs=1,
#max_steps=200,
learning_rate=learning_rate,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
weight_decay=weight_decay,
evaluation_strategy="steps",
eval_steps=eval_steps,
logging_steps=logging_steps,
save_steps=eval_steps,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
disable_tqdm=False,
push_to_hub=False,
max_grad_norm=1.0, ##################
report_to="none",
log_level="error"
)
trainer_lora_all_attn = Trainer(
model=model_lora_all_attn,
args=training_args_lora_all_attn,
train_dataset=dataset_encoded["train"],
eval_dataset=dataset_encoded["validation"],
compute_metrics=compute_metrics,
)
hooks, monitor = monitor_lora_parameters(trainer_lora_all_attn.model)
#Train!
trainer_lora_all_attn.train()
#eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
#print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
#print (torch.cuda.memory_summary())
for hook in hooks:
    hook.remove()
# Aggregate/log after training
from collections import defaultdict
agg = defaultdict(list)
for key, vals in monitor.items():
    grouped = defaultdict(list)
    for name, val in vals:
        grouped[name].append(val)
    agg[key] = {name: sum(vs)/len(vs) for name, vs in grouped.items()}
for name in agg["A_abs_mean"]:
    print(f"{name}: |A|={agg['A_abs_mean'][name]:.4g}, |B|={agg['B_abs_mean'][name]:.4g}, "
          f"|∇A|={agg['A_grad_mean'][name]:.4g}, |∇B|={agg['B_grad_mean'][name]:.4g}, "
          f"|LoRA(x)|={agg['lora_output_norm'][name]:.4g}, B≠0={agg['B_nonzero_count'][name]:.0f}")
#print('Parameter Statistics: mean.abs()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(name, param.abs().mean().item())
#print('Parameter Statistics: param.norm()')
#for name, param in model_lora_all_attn.named_parameters():
# if "lora" in name:
# print(f"{name} weight norm: {param.norm().item():.4f}")
LoRA (All Attention) - Total parameters: 68,282,114
LoRA (All Attention) - Trainable parameters: 1,919,234 (2.81%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
Step | Training Loss | Validation Loss | Accuracy | F1 |
---|---|---|---|---|
50 | 0.631800 | 0.476356 | 0.832000 | 0.832194 |
100 | 0.386100 | 0.353101 | 0.848000 | 0.848142 |
150 | 0.338900 | 0.293442 | 0.880800 | 0.881030 |
200 | 0.273900 | 0.276142 | 0.885600 | 0.885174 |
250 | 0.299300 | 0.277328 | 0.892000 | 0.892224 |
300 | 0.267800 | 0.270556 | 0.896000 | 0.895489 |
350 | 0.275800 | 0.252276 | 0.899200 | 0.899257 |
400 | 0.261600 | 0.252071 | 0.902400 | 0.902492 |
450 | 0.248400 | 0.254714 | 0.898400 | 0.898543 |
500 | 0.268600 | 0.241609 | 0.907200 | 0.907034 |
550 | 0.251900 | 0.238318 | 0.904800 | 0.904695 |
600 | 0.268800 | 0.233752 | 0.906400 | 0.906314 |
650 | 0.239300 | 0.242723 | 0.902400 | 0.902550 |
700 | 0.233100 | 0.234716 | 0.906400 | 0.906393 |
750 | 0.233200 | 0.233564 | 0.904800 | 0.904793 |
distilbert.transformer.layer.0.attention.q_lin: |A|=0.02894, |B|=0.0002256, |∇A|=0.0002051, |∇B|=0.00704, |LoRA(x)|=0.6222, B≠0=12282 distilbert.transformer.layer.0.attention.k_lin: |A|=0.0289, |B|=0.0002237, |∇A|=0.0001388, |∇B|=0.007028, |LoRA(x)|=0.7385, B≠0=12282 distilbert.transformer.layer.0.attention.v_lin: |A|=0.02868, |B|=0.0002192, |∇A|=0.0007466, |∇B|=0.01938, |LoRA(x)|=0.6943, B≠0=12281 distilbert.transformer.layer.0.attention.out_lin: |A|=0.02873, |B|=0.0002161, |∇A|=0.001234, |∇B|=0.02534, |LoRA(x)|=0.4499, B≠0=12281 distilbert.transformer.layer.0.ffn.lin1: |A|=0.02887, |B|=0.0002077, |∇A|=0.0005898, |∇B|=0.01437, |LoRA(x)|=3.549, B≠0=49125 distilbert.transformer.layer.0.ffn.lin2: |A|=0.01439, |B|=0.0002112, |∇A|=0.0003794, |∇B|=0.0146, |LoRA(x)|=0.4078, B≠0=12281 distilbert.transformer.layer.1.attention.q_lin: |A|=0.02871, |B|=0.0002067, |∇A|=0.0002277, |∇B|=0.008702, |LoRA(x)|=0.8418, B≠0=12281 distilbert.transformer.layer.1.attention.k_lin: |A|=0.02878, |B|=0.0002133, |∇A|=0.0001856, |∇B|=0.007924, |LoRA(x)|=0.5599, B≠0=12281 distilbert.transformer.layer.1.attention.v_lin: |A|=0.02885, |B|=0.0001863, |∇A|=0.0005139, |∇B|=0.03951, |LoRA(x)|=0.7551, B≠0=12281 distilbert.transformer.layer.1.attention.out_lin: |A|=0.02884, |B|=0.0001969, |∇A|=0.001035, |∇B|=0.03354, |LoRA(x)|=0.4694, B≠0=12281 distilbert.transformer.layer.1.ffn.lin1: |A|=0.02882, |B|=0.0002174, |∇A|=0.0005914, |∇B|=0.02101, |LoRA(x)|=5.848, B≠0=49126 distilbert.transformer.layer.1.ffn.lin2: |A|=0.01441, |B|=0.0001863, |∇A|=0.0003113, |∇B|=0.01551, |LoRA(x)|=0.1606, B≠0=12281 distilbert.transformer.layer.2.attention.q_lin: |A|=0.02863, |B|=0.0002147, |∇A|=0.0002332, |∇B|=0.01161, |LoRA(x)|=0.9744, B≠0=12281 distilbert.transformer.layer.2.attention.k_lin: |A|=0.02885, |B|=0.0002233, |∇A|=0.000279, |∇B|=0.01124, |LoRA(x)|=0.6466, B≠0=12282 distilbert.transformer.layer.2.attention.v_lin: |A|=0.02879, |B|=0.0001866, |∇A|=0.0006717, |∇B|=0.03784, |LoRA(x)|=0.6731, B≠0=12281 distilbert.transformer.layer.2.attention.out_lin: |A|=0.02879, |B|=0.0002054, |∇A|=0.001122, |∇B|=0.03079, |LoRA(x)|=0.4415, B≠0=12281 distilbert.transformer.layer.2.ffn.lin1: |A|=0.02879, |B|=0.0002269, |∇A|=0.0008912, |∇B|=0.02662, |LoRA(x)|=7.34, B≠0=49126 distilbert.transformer.layer.2.ffn.lin2: |A|=0.01442, |B|=0.0002097, |∇A|=0.0005021, |∇B|=0.01674, |LoRA(x)|=0.2238, B≠0=12281 distilbert.transformer.layer.3.attention.q_lin: |A|=0.02891, |B|=0.000226, |∇A|=0.0002603, |∇B|=0.009309, |LoRA(x)|=0.9878, B≠0=12281 distilbert.transformer.layer.3.attention.k_lin: |A|=0.02882, |B|=0.0002275, |∇A|=0.000281, |∇B|=0.009663, |LoRA(x)|=0.885, B≠0=12282 distilbert.transformer.layer.3.attention.v_lin: |A|=0.02887, |B|=0.0001937, |∇A|=0.0007371, |∇B|=0.0562, |LoRA(x)|=1.023, B≠0=12281 distilbert.transformer.layer.3.attention.out_lin: |A|=0.02891, |B|=0.0002549, |∇A|=0.001971, |∇B|=0.02536, |LoRA(x)|=0.4443, B≠0=12282 distilbert.transformer.layer.3.ffn.lin1: |A|=0.02888, |B|=0.0002292, |∇A|=0.001239, |∇B|=0.02104, |LoRA(x)|=4.993, B≠0=49126 distilbert.transformer.layer.3.ffn.lin2: |A|=0.0144, |B|=0.0002404, |∇A|=0.0007727, |∇B|=0.01577, |LoRA(x)|=0.2704, B≠0=12282 distilbert.transformer.layer.4.attention.q_lin: |A|=0.02871, |B|=0.0002524, |∇A|=0.000489, |∇B|=0.01352, |LoRA(x)|=2.283, B≠0=12282 distilbert.transformer.layer.4.attention.k_lin: |A|=0.02884, |B|=0.0002416, |∇A|=0.0009936, |∇B|=0.009785, |LoRA(x)|=0.9456, B≠0=12281 distilbert.transformer.layer.4.attention.v_lin: |A|=0.02885, |B|=0.0002123, |∇A|=0.00103, |∇B|=0.0365, |LoRA(x)|=1.395, B≠0=12281 
distilbert.transformer.layer.4.attention.out_lin: |A|=0.0288, |B|=0.000233, |∇A|=0.001299, |∇B|=0.01674, |LoRA(x)|=0.5857, B≠0=12282 distilbert.transformer.layer.4.ffn.lin1: |A|=0.02887, |B|=0.0002406, |∇A|=0.0007501, |∇B|=0.01257, |LoRA(x)|=7.437, B≠0=49128 distilbert.transformer.layer.4.ffn.lin2: |A|=0.01443, |B|=0.0002652, |∇A|=0.0004402, |∇B|=0.01286, |LoRA(x)|=0.5918, B≠0=12282 distilbert.transformer.layer.5.attention.q_lin: |A|=0.02885, |B|=0.0002558, |∇A|=0.0001962, |∇B|=0.007697, |LoRA(x)|=1.709, B≠0=12281 distilbert.transformer.layer.5.attention.k_lin: |A|=0.02884, |B|=0.0002236, |∇A|=0.0007342, |∇B|=0.004787, |LoRA(x)|=1.879, B≠0=12281 distilbert.transformer.layer.5.attention.v_lin: |A|=0.02887, |B|=0.0002386, |∇A|=0.001155, |∇B|=0.01626, |LoRA(x)|=1.73, B≠0=12282 distilbert.transformer.layer.5.attention.out_lin: |A|=0.02883, |B|=0.0003673, |∇A|=0.001476, |∇B|=0.008586, |LoRA(x)|=1.781, B≠0=12283 distilbert.transformer.layer.5.ffn.lin1: |A|=0.02882, |B|=0.0002423, |∇A|=0.0007944, |∇B|=0.005601, |LoRA(x)|=9.013, B≠0=49127 distilbert.transformer.layer.5.ffn.lin2: |A|=0.0144, |B|=0.0003482, |∇A|=0.0004651, |∇B|=0.004582, |LoRA(x)|=0.9978, B≠0=12283
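As a sanity check, the 1,919,234 trainable parameters reported above can be reproduced by hand. The breakdown below is an illustrative sketch, assuming DistilBERT's hidden size of 768 and FFN size of 3072, rank 16, six adapted linear modules per transformer layer across six layers, plus the unfrozen pre_classifier and classifier heads:
# Back-of-the-envelope count of trainable parameters
hidden, ffn, rank, layers = 768, 3072, 16, 6

per_attn = rank * (hidden + hidden)              # A: hidden x rank, B: rank x hidden
per_ffn1 = rank * (hidden + ffn)                 # ffn.lin1: 768 -> 3072
per_ffn2 = rank * (ffn + hidden)                 # ffn.lin2: 3072 -> 768
per_layer = 4 * per_attn + per_ffn1 + per_ffn2   # q, k, v, out + both FFN linears

heads = (hidden * hidden + hidden) + (hidden * 2 + 2)   # pre_classifier + classifier (2 labels)

print(layers * per_layer + heads)                # 1,919,234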
Increase std.dev for B and apply dropout to boost lora.B adaptation¶
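For context, the reference LoRA recipe initialises B to zeros, so the effective update ΔW = α·AB starts at exactly zero and the wrapped layer initially reproduces the pretrained output; the cell below instead draws B from a small random distribution (scaled by 1e-3) and applies dropout to the down-projection so that B receives a stronger gradient signal from the start. A minimal sketch of the zero-init baseline for comparison (illustrative only, not necessarily the code used in earlier sections):
import torch
import torch.nn as nn

# Reference-style LoRA layer: B starts at zero, so the LoRA branch
# contributes nothing until training updates it.
class ZeroInitLoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))  # zeros, unlike the scaled random init below
        self.alpha = alpha

    def forward(self, x):
        return self.alpha * (x @ self.A @ self.B)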
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
torch.autograd.set_detect_anomaly(True)  # useful while debugging unstable gradients; adds noticeable overhead

class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha, dropout_rate=0.0):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(1e-3 * torch.randn(rank, out_dim) * std_dev)  # increased std. dev. for the random init of B
        self.alpha = alpha  # applied as a plain multiplier in forward(); the reference LoRA formulation scales by alpha / rank
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # Dropout is applied to the low-rank projection x @ A
        dropped = self.dropout(x @ self.A)
        return self.alpha * (dropped @ self.B)

class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha, dropout_rate=0.0):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha, dropout_rate)
        self.last_lora_output_norm = 0.0  # for monitoring

    def forward(self, x):
        # return self.linear(x) + self.lora(x)
        lora_out = self.lora(x)
        self.last_lora_output_norm = lora_out.norm(p=2, dim=-1).mean().item()
        return self.linear(x) + lora_out

# Function to inject LoRA into the specified linear layers
def inject_lora_all_attn(model, rank, alpha, dropout_rate=0.0):
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin", "ffn.lin1", "ffn.lin2"]
    # target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            lora_linear = LinearWithLoRA(original_linear, rank, alpha, dropout_rate)
            setattr(parent_module, name.split('.')[-1], lora_linear)
    return model
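Before training, a quick smoke test of the wrapper can be useful: the wrapped layer should preserve shapes, and, because B is no longer zero-initialised, its output should differ from the plain linear layer even before any training. A minimal sketch using the classes above (sizes are illustrative):
# Smoke test: wrap a standalone linear layer and compare outputs
probe = nn.Linear(768, 768)
wrapped = LinearWithLoRA(probe, rank=16, alpha=32)
x = torch.randn(4, 384, 768)  # (batch, seq_len, hidden)

with torch.no_grad():
    base_out = probe(x)
    full_out = wrapped(x)

print(full_out.shape)                        # same shape as the base layer output
print((full_out - base_out).norm().item())   # non-zero: the LoRA delta is active from step 0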
dropout = 0.4
learning_rate = 1.5e-5
lora_rank = 16
lora_alpha = 32
weight_decay = 1e-5 # L2
batch_size = 32
output_dir_prefix = "finetuned-imdb-"
import copy
from transformers import TrainingArguments
torch.manual_seed(137)
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha, dropout)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)
total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
#print("\nTrainable parameters after freezing:")
#for name, param in model_lora_all_attn.named_parameters():
# if param.requires_grad:
# print(name)
eval_steps = 50
logging_steps = 50
training_args_lora_all_attn = TrainingArguments(
output_dir=f"{output_dir_prefix}lora-all-attn",
num_train_epochs=2,
#max_steps=200,
learning_rate=learning_rate,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
weight_decay=weight_decay,
evaluation_strategy="steps",
eval_steps=eval_steps,
logging_steps=logging_steps,
save_steps=eval_steps,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
disable_tqdm=False,
push_to_hub=False,
max_grad_norm=1.0,  # clip gradient norm
report_to="none",
log_level="error"
)
trainer_lora_all_attn = Trainer(
model=model_lora_all_attn,
args=training_args_lora_all_attn,
train_dataset=dataset_encoded["train"],
eval_dataset=dataset_encoded["validation"],
compute_metrics=compute_metrics,
)
hooks, monitor = monitor_lora_parameters(trainer_lora_all_attn.model)
# Train!
trainer_lora_all_attn.train()
eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
print(torch.cuda.memory_summary())
for hook in hooks:
    hook.remove()

# Aggregate/log the monitored LoRA statistics after training
from collections import defaultdict
agg = defaultdict(list)
for key, vals in monitor.items():
    grouped = defaultdict(list)
    for name, val in vals:
        grouped[name].append(val)
    agg[key] = {name: sum(vs) / len(vs) for name, vs in grouped.items()}
for name in agg["A_abs_mean"]:
    print(f"{name}: |A|={agg['A_abs_mean'][name]:.4g}, |B|={agg['B_abs_mean'][name]:.4g}, "
          f"|∇A|={agg['A_grad_mean'][name]:.4g}, |∇B|={agg['B_grad_mean'][name]:.4g}, "
          f"|LoRA(x)|={agg['lora_output_norm'][name]:.4g}, B≠0={agg['B_nonzero_count'][name]:.0f}")

print('Parameter Statistics: mean.abs()')
for name, param in model_lora_all_attn.named_parameters():
    if "lora" in name:
        print(name, param.abs().mean().item())
print('Parameter Statistics: param.norm()')
for name, param in model_lora_all_attn.named_parameters():
    if "lora" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
LoRA (All Attention) - Total parameters: 68,282,114
LoRA (All Attention) - Trainable parameters: 1,919,234 (2.81%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
Step | Training Loss | Validation Loss | Accuracy | F1 |
---|---|---|---|---|
50 | 0.572400 | 0.361935 | 0.853600 | 0.853897 |
100 | 0.395200 | 0.314839 | 0.871200 | 0.871393 |
150 | 0.388700 | 0.324395 | 0.872000 | 0.872105 |
200 | 0.359600 | 0.288238 | 0.875200 | 0.875027 |
250 | 0.375300 | 0.355832 | 0.841600 | 0.841525 |
300 | 0.348600 | 0.279084 | 0.888000 | 0.887753 |
350 | 0.360400 | 0.307965 | 0.874400 | 0.874675 |
400 | 0.331700 | 0.285904 | 0.896800 | 0.897028 |
450 | 0.313200 | 0.311799 | 0.895200 | 0.895123 |
500 | 0.353700 | 0.291399 | 0.883200 | 0.882492 |
550 | 0.313900 | 0.274373 | 0.881600 | 0.881711 |
600 | 0.344400 | 0.289618 | 0.887200 | 0.886285 |
650 | 0.311900 | 0.266004 | 0.895200 | 0.895175 |
700 | 0.313900 | 0.260952 | 0.904800 | 0.904677 |
750 | 0.311800 | 0.276892 | 0.896800 | 0.896441 |
800 | 0.331900 | 0.266435 | 0.900000 | 0.899438 |
850 | 0.289800 | 0.267684 | 0.902400 | 0.902573 |
900 | 0.283600 | 0.294683 | 0.899200 | 0.899422 |
950 | 0.307300 | 0.265540 | 0.906400 | 0.906297 |
1000 | 0.270400 | 0.288483 | 0.900000 | 0.900111 |
1050 | 0.266700 | 0.258305 | 0.907200 | 0.906888 |
1100 | 0.276200 | 0.250781 | 0.905600 | 0.905555 |
1150 | 0.253400 | 0.249464 | 0.902400 | 0.902573 |
1200 | 0.278300 | 0.245663 | 0.912800 | 0.912704 |
1250 | 0.260800 | 0.250558 | 0.904800 | 0.904659 |
1300 | 0.266500 | 0.247603 | 0.907200 | 0.907308 |
1350 | 0.260500 | 0.247938 | 0.904800 | 0.904915 |
1400 | 0.250800 | 0.251896 | 0.904000 | 0.903701 |
1450 | 0.252900 | 0.245024 | 0.907200 | 0.907171 |
1500 | 0.255200 | 0.245073 | 0.904800 | 0.904639 |
1550 | 0.242100 | 0.244082 | 0.908800 | 0.908786 |
LoRA (All Attention) Test Results: {'eval_loss': 0.2366224229335785, 'eval_accuracy': 0.907621052631579, 'eval_f1': 0.9076122672728953, 'eval_runtime': 184.1411, 'eval_samples_per_second': 128.977, 'eval_steps_per_second': 4.035, 'epoch': 2.0} |===========================================================================| | PyTorch CUDA memory summary, device ID 0 | |---------------------------------------------------------------------------| | CUDA OOMs: 0 | cudaMalloc retries: 0 | |===========================================================================| | Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed | |---------------------------------------------------------------------------| | Allocated memory | 563068 KiB | 4397 MiB | 232710 GiB | 232710 GiB | | from large pool | 546990 KiB | 4344 MiB | 230922 GiB | 230922 GiB | | from small pool | 16078 KiB | 55 MiB | 1788 GiB | 1788 GiB | |---------------------------------------------------------------------------| | Active memory | 563068 KiB | 4397 MiB | 232710 GiB | 232710 GiB | | from large pool | 546990 KiB | 4344 MiB | 230922 GiB | 230922 GiB | | from small pool | 16078 KiB | 55 MiB | 1788 GiB | 1788 GiB | |---------------------------------------------------------------------------| | Requested memory | 559912 KiB | 4394 MiB | 232500 GiB | 232499 GiB | | from large pool | 543836 KiB | 4340 MiB | 230718 GiB | 230718 GiB | | from small pool | 16076 KiB | 55 MiB | 1781 GiB | 1781 GiB | |---------------------------------------------------------------------------| | GPU reserved memory | 4510 MiB | 4580 MiB | 23456 MiB | 18946 MiB | | from large pool | 4450 MiB | 4520 MiB | 23204 MiB | 18754 MiB | | from small pool | 60 MiB | 62 MiB | 252 MiB | 192 MiB | |---------------------------------------------------------------------------| | Non-releasable memory | 182404 KiB | 382935 KiB | 79215 GiB | 79215 GiB | | from large pool | 175954 KiB | 376402 KiB | 77291 GiB | 77290 GiB | | from small pool | 6450 KiB | 30566 KiB | 1924 GiB | 1924 GiB | |---------------------------------------------------------------------------| | Allocations | 436 | 623 | 30226 K | 30225 K | | from large pool | 82 | 152 | 6772 K | 6772 K | | from small pool | 354 | 515 | 23453 K | 23453 K | |---------------------------------------------------------------------------| | Active allocs | 436 | 623 | 30226 K | 30225 K | | from large pool | 82 | 152 | 6772 K | 6772 K | | from small pool | 354 | 515 | 23453 K | 23453 K | |---------------------------------------------------------------------------| | GPU reserved segments | 97 | 101 | 431 | 334 | | from large pool | 67 | 71 | 305 | 238 | | from small pool | 30 | 31 | 126 | 96 | |---------------------------------------------------------------------------| | Non-releasable allocs | 46 | 73 | 16377 K | 16377 K | | from large pool | 16 | 22 | 2502 K | 2502 K | | from small pool | 30 | 57 | 13874 K | 13874 K | |---------------------------------------------------------------------------| | Oversize allocations | 0 | 0 | 0 | 0 | |---------------------------------------------------------------------------| | Oversize GPU segments | 0 | 0 | 0 | 0 | |===========================================================================| distilbert.transformer.layer.0.attention.q_lin: |A|=0.2014, |B|=0.0003389, |∇A|=0.000502, |∇B|=0.158, |LoRA(x)|=6.894, B≠0=12285 distilbert.transformer.layer.0.attention.k_lin: |A|=0.1991, |B|=0.0003358, |∇A|=0.0004346, |∇B|=0.1471, |LoRA(x)|=7.695, B≠0=12286 distilbert.transformer.layer.0.attention.v_lin: 
|A|=0.1993, |B|=0.000301, |∇A|=0.0007981, |∇B|=0.2967, |LoRA(x)|=6.75, B≠0=12285 distilbert.transformer.layer.0.attention.out_lin: |A|=0.2003, |B|=0.0003029, |∇A|=0.0007932, |∇B|=0.321, |LoRA(x)|=3.656, B≠0=12285 distilbert.transformer.layer.0.ffn.lin1: |A|=0.1998, |B|=0.0003268, |∇A|=0.0009316, |∇B|=0.1899, |LoRA(x)|=21.53, B≠0=49143 distilbert.transformer.layer.0.ffn.lin2: |A|=0.1996, |B|=0.0003015, |∇A|=0.0003772, |∇B|=0.3445, |LoRA(x)|=6.206, B≠0=12286 distilbert.transformer.layer.1.attention.q_lin: |A|=0.2, |B|=0.0003353, |∇A|=0.0003761, |∇B|=0.1395, |LoRA(x)|=7.866, B≠0=12285 distilbert.transformer.layer.1.attention.k_lin: |A|=0.2007, |B|=0.0003366, |∇A|=0.0004358, |∇B|=0.1484, |LoRA(x)|=6.999, B≠0=12285 distilbert.transformer.layer.1.attention.v_lin: |A|=0.1978, |B|=0.0002858, |∇A|=0.0006746, |∇B|=0.2917, |LoRA(x)|=5.321, B≠0=12285 distilbert.transformer.layer.1.attention.out_lin: |A|=0.1984, |B|=0.00029, |∇A|=0.0007089, |∇B|=0.3005, |LoRA(x)|=3.779, B≠0=12285 distilbert.transformer.layer.1.ffn.lin1: |A|=0.198, |B|=0.0003159, |∇A|=0.0007903, |∇B|=0.1846, |LoRA(x)|=23.45, B≠0=49142 distilbert.transformer.layer.1.ffn.lin2: |A|=0.1989, |B|=0.000288, |∇A|=0.0002804, |∇B|=0.2852, |LoRA(x)|=3.764, B≠0=12284 distilbert.transformer.layer.2.attention.q_lin: |A|=0.198, |B|=0.0003262, |∇A|=0.0003942, |∇B|=0.1447, |LoRA(x)|=8.057, B≠0=12286 distilbert.transformer.layer.2.attention.k_lin: |A|=0.2002, |B|=0.0003176, |∇A|=0.000448, |∇B|=0.1577, |LoRA(x)|=7.385, B≠0=12285 distilbert.transformer.layer.2.attention.v_lin: |A|=0.2002, |B|=0.0002813, |∇A|=0.0005137, |∇B|=0.2423, |LoRA(x)|=6.056, B≠0=12286 distilbert.transformer.layer.2.attention.out_lin: |A|=0.1983, |B|=0.0002805, |∇A|=0.0005144, |∇B|=0.2204, |LoRA(x)|=3.703, B≠0=12285 distilbert.transformer.layer.2.ffn.lin1: |A|=0.1982, |B|=0.0003084, |∇A|=0.0006538, |∇B|=0.15, |LoRA(x)|=24.32, B≠0=49141 distilbert.transformer.layer.2.ffn.lin2: |A|=0.1999, |B|=0.0002843, |∇A|=0.000218, |∇B|=0.2078, |LoRA(x)|=3.736, B≠0=12285 distilbert.transformer.layer.3.attention.q_lin: |A|=0.2006, |B|=0.000324, |∇A|=0.0002934, |∇B|=0.1105, |LoRA(x)|=8.299, B≠0=12285 distilbert.transformer.layer.3.attention.k_lin: |A|=0.2011, |B|=0.0003226, |∇A|=0.0003681, |∇B|=0.1199, |LoRA(x)|=7.896, B≠0=12285 distilbert.transformer.layer.3.attention.v_lin: |A|=0.2007, |B|=0.00026, |∇A|=0.0003734, |∇B|=0.2297, |LoRA(x)|=6.205, B≠0=12285 distilbert.transformer.layer.3.attention.out_lin: |A|=0.1995, |B|=0.0002658, |∇A|=0.0004406, |∇B|=0.1566, |LoRA(x)|=4.282, B≠0=12285 distilbert.transformer.layer.3.ffn.lin1: |A|=0.1995, |B|=0.0003003, |∇A|=0.0005145, |∇B|=0.08879, |LoRA(x)|=18.41, B≠0=49141 distilbert.transformer.layer.3.ffn.lin2: |A|=0.1998, |B|=0.0002639, |∇A|=0.0001617, |∇B|=0.1338, |LoRA(x)|=3.407, B≠0=12285 distilbert.transformer.layer.4.attention.q_lin: |A|=0.2008, |B|=0.0003261, |∇A|=0.0001543, |∇B|=0.06833, |LoRA(x)|=12.96, B≠0=12284 distilbert.transformer.layer.4.attention.k_lin: |A|=0.2, |B|=0.0003134, |∇A|=0.0003658, |∇B|=0.07867, |LoRA(x)|=8.107, B≠0=12286 distilbert.transformer.layer.4.attention.v_lin: |A|=0.1993, |B|=0.0002635, |∇A|=0.0003247, |∇B|=0.09568, |LoRA(x)|=6.022, B≠0=12285 distilbert.transformer.layer.4.attention.out_lin: |A|=0.1999, |B|=0.0002577, |∇A|=0.0002548, |∇B|=0.0689, |LoRA(x)|=4.138, B≠0=12285 distilbert.transformer.layer.4.ffn.lin1: |A|=0.2004, |B|=0.0002738, |∇A|=0.0002338, |∇B|=0.03957, |LoRA(x)|=21.57, B≠0=49141 distilbert.transformer.layer.4.ffn.lin2: |A|=0.1994, |B|=0.0002625, |∇A|=0.0001001, |∇B|=0.06575, |LoRA(x)|=4.007, B≠0=12284 
distilbert.transformer.layer.5.attention.q_lin: |A|=0.1998, |B|=0.0003016, |∇A|=8.459e-05, |∇B|=0.02227, |LoRA(x)|=9.447, B≠0=12286 distilbert.transformer.layer.5.attention.k_lin: |A|=0.201, |B|=0.0002925, |∇A|=0.0001426, |∇B|=0.02499, |LoRA(x)|=7.739, B≠0=12285 distilbert.transformer.layer.5.attention.v_lin: |A|=0.1987, |B|=0.0002508, |∇A|=0.0002601, |∇B|=0.04984, |LoRA(x)|=8.357, B≠0=12285 distilbert.transformer.layer.5.attention.out_lin: |A|=0.1994, |B|=0.0002483, |∇A|=0.000202, |∇B|=0.0489, |LoRA(x)|=15.56, B≠0=12286 distilbert.transformer.layer.5.ffn.lin1: |A|=0.1993, |B|=0.0002525, |∇A|=0.0001525, |∇B|=0.01987, |LoRA(x)|=20.78, B≠0=49138 distilbert.transformer.layer.5.ffn.lin2: |A|=0.1994, |B|=0.0002402, |∇A|=4.512e-05, |∇B|=0.03451, |LoRA(x)|=5.875, B≠0=12285 Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.lora.A 0.20136898756027222 distilbert.transformer.layer.0.attention.q_lin.lora.B 0.00036163683398626745 distilbert.transformer.layer.0.attention.k_lin.lora.A 0.199139803647995 distilbert.transformer.layer.0.attention.k_lin.lora.B 0.0003591857384890318 distilbert.transformer.layer.0.attention.v_lin.lora.A 0.199319526553154 distilbert.transformer.layer.0.attention.v_lin.lora.B 0.00031731155468150973 distilbert.transformer.layer.0.attention.out_lin.lora.A 0.20034024119377136 distilbert.transformer.layer.0.attention.out_lin.lora.B 0.00031905301148071885 distilbert.transformer.layer.0.ffn.lin1.lora.A 0.19975820183753967 distilbert.transformer.layer.0.ffn.lin1.lora.B 0.00035009029670618474 distilbert.transformer.layer.0.ffn.lin2.lora.A 0.19959013164043427 distilbert.transformer.layer.0.ffn.lin2.lora.B 0.0003168959519825876 distilbert.transformer.layer.1.attention.q_lin.lora.A 0.1999898999929428 distilbert.transformer.layer.1.attention.q_lin.lora.B 0.00035791145637631416 distilbert.transformer.layer.1.attention.k_lin.lora.A 0.20074164867401123 distilbert.transformer.layer.1.attention.k_lin.lora.B 0.0003592821885831654 distilbert.transformer.layer.1.attention.v_lin.lora.A 0.19777165353298187 distilbert.transformer.layer.1.attention.v_lin.lora.B 0.00030186009826138616 distilbert.transformer.layer.1.attention.out_lin.lora.A 0.19843807816505432 distilbert.transformer.layer.1.attention.out_lin.lora.B 0.0003039248113054782 distilbert.transformer.layer.1.ffn.lin1.lora.A 0.1979658454656601 distilbert.transformer.layer.1.ffn.lin1.lora.B 0.00033694933517836034 distilbert.transformer.layer.1.ffn.lin2.lora.A 0.19888746738433838 distilbert.transformer.layer.1.ffn.lin2.lora.B 0.0003027324564754963 distilbert.transformer.layer.2.attention.q_lin.lora.A 0.19796311855316162 distilbert.transformer.layer.2.attention.q_lin.lora.B 0.00034728783066384494 distilbert.transformer.layer.2.attention.k_lin.lora.A 0.20021361112594604 distilbert.transformer.layer.2.attention.k_lin.lora.B 0.00033590427483431995 distilbert.transformer.layer.2.attention.v_lin.lora.A 0.20018713176250458 distilbert.transformer.layer.2.attention.v_lin.lora.B 0.00029516470385715365 distilbert.transformer.layer.2.attention.out_lin.lora.A 0.19833998382091522 distilbert.transformer.layer.2.attention.out_lin.lora.B 0.00029389417613856494 distilbert.transformer.layer.2.ffn.lin1.lora.A 0.1982342004776001 distilbert.transformer.layer.2.ffn.lin1.lora.B 0.0003286522696726024 distilbert.transformer.layer.2.ffn.lin2.lora.A 0.19985038042068481 distilbert.transformer.layer.2.ffn.lin2.lora.B 0.000297279329970479 distilbert.transformer.layer.3.attention.q_lin.lora.A 0.20064640045166016 
distilbert.transformer.layer.3.attention.q_lin.lora.B 0.00034552914439700544 distilbert.transformer.layer.3.attention.k_lin.lora.A 0.20107726752758026 distilbert.transformer.layer.3.attention.k_lin.lora.B 0.00034139861236326396 distilbert.transformer.layer.3.attention.v_lin.lora.A 0.20070061087608337 distilbert.transformer.layer.3.attention.v_lin.lora.B 0.00027113681426271796 distilbert.transformer.layer.3.attention.out_lin.lora.A 0.1994575560092926 distilbert.transformer.layer.3.attention.out_lin.lora.B 0.0002771209110505879 distilbert.transformer.layer.3.ffn.lin1.lora.A 0.19948390126228333 distilbert.transformer.layer.3.ffn.lin1.lora.B 0.00031956000020727515 distilbert.transformer.layer.3.ffn.lin2.lora.A 0.1997573971748352 distilbert.transformer.layer.3.ffn.lin2.lora.B 0.00027455651434138417 distilbert.transformer.layer.4.attention.q_lin.lora.A 0.20080135762691498 distilbert.transformer.layer.4.attention.q_lin.lora.B 0.00035053043393418193 distilbert.transformer.layer.4.attention.k_lin.lora.A 0.19999122619628906 distilbert.transformer.layer.4.attention.k_lin.lora.B 0.00033398327650502324 distilbert.transformer.layer.4.attention.v_lin.lora.A 0.199321448802948 distilbert.transformer.layer.4.attention.v_lin.lora.B 0.0002780442591756582 distilbert.transformer.layer.4.attention.out_lin.lora.A 0.19989004731178284 distilbert.transformer.layer.4.attention.out_lin.lora.B 0.00027118308935314417 distilbert.transformer.layer.4.ffn.lin1.lora.A 0.20036581158638 distilbert.transformer.layer.4.ffn.lin1.lora.B 0.00029024441028013825 distilbert.transformer.layer.4.ffn.lin2.lora.A 0.19942979514598846 distilbert.transformer.layer.4.ffn.lin2.lora.B 0.0002758408372756094 distilbert.transformer.layer.5.attention.q_lin.lora.A 0.19984808564186096 distilbert.transformer.layer.5.attention.q_lin.lora.B 0.00032156246015802026 distilbert.transformer.layer.5.attention.k_lin.lora.A 0.20102792978286743 distilbert.transformer.layer.5.attention.k_lin.lora.B 0.0003097845474258065 distilbert.transformer.layer.5.attention.v_lin.lora.A 0.19865462183952332 distilbert.transformer.layer.5.attention.v_lin.lora.B 0.0002626449568197131 distilbert.transformer.layer.5.attention.out_lin.lora.A 0.1993870586156845 distilbert.transformer.layer.5.attention.out_lin.lora.B 0.0002558119304012507 distilbert.transformer.layer.5.ffn.lin1.lora.A 0.19927702844142914 distilbert.transformer.layer.5.ffn.lin1.lora.B 0.0002630744711495936 distilbert.transformer.layer.5.ffn.lin2.lora.A 0.19940851628780365 distilbert.transformer.layer.5.ffn.lin2.lora.B 0.0002476528752595186 Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.lora.A weight norm: 27.8493 distilbert.transformer.layer.0.attention.q_lin.lora.B weight norm: 0.0506 distilbert.transformer.layer.0.attention.k_lin.lora.A weight norm: 27.5909 distilbert.transformer.layer.0.attention.k_lin.lora.B weight norm: 0.0500 distilbert.transformer.layer.0.attention.v_lin.lora.A weight norm: 27.8039 distilbert.transformer.layer.0.attention.v_lin.lora.B weight norm: 0.0444 distilbert.transformer.layer.0.attention.out_lin.lora.A weight norm: 27.9052 distilbert.transformer.layer.0.attention.out_lin.lora.B weight norm: 0.0442 distilbert.transformer.layer.0.ffn.lin1.lora.A weight norm: 27.7016 distilbert.transformer.layer.0.ffn.lin1.lora.B weight norm: 0.0976 distilbert.transformer.layer.0.ffn.lin2.lora.A weight norm: 55.5094 distilbert.transformer.layer.0.ffn.lin2.lora.B weight norm: 0.0440 distilbert.transformer.layer.1.attention.q_lin.lora.A weight norm: 27.8651 
distilbert.transformer.layer.1.attention.q_lin.lora.B weight norm: 0.0501 distilbert.transformer.layer.1.attention.k_lin.lora.A weight norm: 27.9079 distilbert.transformer.layer.1.attention.k_lin.lora.B weight norm: 0.0503 distilbert.transformer.layer.1.attention.v_lin.lora.A weight norm: 27.4143 distilbert.transformer.layer.1.attention.v_lin.lora.B weight norm: 0.0421 distilbert.transformer.layer.1.attention.out_lin.lora.A weight norm: 27.5450 distilbert.transformer.layer.1.attention.out_lin.lora.B weight norm: 0.0424 distilbert.transformer.layer.1.ffn.lin1.lora.A weight norm: 27.5116 distilbert.transformer.layer.1.ffn.lin1.lora.B weight norm: 0.0939 distilbert.transformer.layer.1.ffn.lin2.lora.A weight norm: 55.2299 distilbert.transformer.layer.1.ffn.lin2.lora.B weight norm: 0.0421 distilbert.transformer.layer.2.attention.q_lin.lora.A weight norm: 27.6563 distilbert.transformer.layer.2.attention.q_lin.lora.B weight norm: 0.0481 distilbert.transformer.layer.2.attention.k_lin.lora.A weight norm: 27.7654 distilbert.transformer.layer.2.attention.k_lin.lora.B weight norm: 0.0468 distilbert.transformer.layer.2.attention.v_lin.lora.A weight norm: 27.8159 distilbert.transformer.layer.2.attention.v_lin.lora.B weight norm: 0.0414 distilbert.transformer.layer.2.attention.out_lin.lora.A weight norm: 27.5538 distilbert.transformer.layer.2.attention.out_lin.lora.B weight norm: 0.0409 distilbert.transformer.layer.2.ffn.lin1.lora.A weight norm: 27.5591 distilbert.transformer.layer.2.ffn.lin1.lora.B weight norm: 0.0917 distilbert.transformer.layer.2.ffn.lin2.lora.A weight norm: 55.4738 distilbert.transformer.layer.2.ffn.lin2.lora.B weight norm: 0.0416 distilbert.transformer.layer.3.attention.q_lin.lora.A weight norm: 27.7736 distilbert.transformer.layer.3.attention.q_lin.lora.B weight norm: 0.0482 distilbert.transformer.layer.3.attention.k_lin.lora.A weight norm: 27.9000 distilbert.transformer.layer.3.attention.k_lin.lora.B weight norm: 0.0476 distilbert.transformer.layer.3.attention.v_lin.lora.A weight norm: 27.7956 distilbert.transformer.layer.3.attention.v_lin.lora.B weight norm: 0.0376 distilbert.transformer.layer.3.attention.out_lin.lora.A weight norm: 27.6487 distilbert.transformer.layer.3.attention.out_lin.lora.B weight norm: 0.0386 distilbert.transformer.layer.3.ffn.lin1.lora.A weight norm: 27.6616 distilbert.transformer.layer.3.ffn.lin1.lora.B weight norm: 0.0894 distilbert.transformer.layer.3.ffn.lin2.lora.A weight norm: 55.5141 distilbert.transformer.layer.3.ffn.lin2.lora.B weight norm: 0.0384 distilbert.transformer.layer.4.attention.q_lin.lora.A weight norm: 27.9841 distilbert.transformer.layer.4.attention.q_lin.lora.B weight norm: 0.0488 distilbert.transformer.layer.4.attention.k_lin.lora.A weight norm: 27.7779 distilbert.transformer.layer.4.attention.k_lin.lora.B weight norm: 0.0466 distilbert.transformer.layer.4.attention.v_lin.lora.A weight norm: 27.6690 distilbert.transformer.layer.4.attention.v_lin.lora.B weight norm: 0.0386 distilbert.transformer.layer.4.attention.out_lin.lora.A weight norm: 27.7600 distilbert.transformer.layer.4.attention.out_lin.lora.B weight norm: 0.0378 distilbert.transformer.layer.4.ffn.lin1.lora.A weight norm: 27.7930 distilbert.transformer.layer.4.ffn.lin1.lora.B weight norm: 0.0813 distilbert.transformer.layer.4.ffn.lin2.lora.A weight norm: 55.3399 distilbert.transformer.layer.4.ffn.lin2.lora.B weight norm: 0.0385 distilbert.transformer.layer.5.attention.q_lin.lora.A weight norm: 27.7641 distilbert.transformer.layer.5.attention.q_lin.lora.B weight norm: 0.0448 
distilbert.transformer.layer.5.attention.k_lin.lora.A weight norm: 27.9006 distilbert.transformer.layer.5.attention.k_lin.lora.B weight norm: 0.0432 distilbert.transformer.layer.5.attention.v_lin.lora.A weight norm: 27.5460 distilbert.transformer.layer.5.attention.v_lin.lora.B weight norm: 0.0366 distilbert.transformer.layer.5.attention.out_lin.lora.A weight norm: 27.7014 distilbert.transformer.layer.5.attention.out_lin.lora.B weight norm: 0.0356 distilbert.transformer.layer.5.ffn.lin1.lora.A weight norm: 27.6943 distilbert.transformer.layer.5.ffn.lin1.lora.B weight norm: 0.0736 distilbert.transformer.layer.5.ffn.lin2.lora.A weight norm: 55.4635 distilbert.transformer.layer.5.ffn.lin2.lora.B weight norm: 0.0346
Key Observations from the Dump¶
- Despite the small absolute values of B (|B| ≈ 0.00025–0.00033), the gradients |∇B| are markedly larger, especially in the early layers: |∇B| ≈ 0.15–0.34 in layers 0–2, dropping gradually with depth to roughly 0.02 at layer 5. This indicates strong gradient flow into B early in the network that tapers off with depth: LoRA is adapting the early layers more than the deep ones.
- |LoRA(x)| is relatively large despite the tiny B values. This is the desired behavior: the LoRA branches are contributing meaningful low-rank deltas to the layer outputs.
- Rather than B growing in magnitude, adaptation shows up as sharp, effective updates driven by ∇B, with dropout further pushing the B matrices to adapt.
- Although the LoRA B values remain small (|B| ≈ 0.00025–0.00033), this is not a sign of stagnation or undertraining; the signal is directional. Each LoRA module effectively computes ΔW = α·AB, where A holds relatively large values (the directional basis) and B holds very small coefficients (the activation selectors). The direction of AB matters much more than its norm, because alpha rescales the final output: during training the norm of AB stays suppressed while its orientation in parameter space becomes useful for adapting the model. A small B also encourages better generalization, whereas a large B would produce a large ΔW and risk overfitting, defeating the point of parameter-efficient fine-tuning. The sketch after this list makes these magnitudes concrete.
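To make the last point concrete, the effective update ΔW = α·AB can be materialised per adapted layer and compared against the frozen weight it perturbs; the relative norm stays modest even though |LoRA(x)| is substantial. A minimal sketch over the model trained above, using the attribute names from the LinearWithLoRA definition:
# Compare the effective LoRA update against the frozen base weight, per adapted layer
for name, module in model_lora_all_attn.named_modules():
    if module.__class__.__name__ == "LinearWithLoRA":
        with torch.no_grad():
            delta_w = module.lora.alpha * (module.lora.A @ module.lora.B)  # (in_dim, out_dim)
            w = module.linear.weight                                       # frozen pretrained weight
            ratio = delta_w.norm() / w.norm()
        print(f"{name}: |ΔW|={delta_w.norm():.4f}, |W|={w.norm():.4f}, ratio={ratio:.4f}")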