In [1]:
import torch
import sys
import gc
print(sys.version)
print(f"PyTorch Version: {torch.__version__}")
print(torch.cuda.is_available())
print(torch.cuda.device_count())
if torch.cuda.is_available():
    print(f"CUDA Version: {torch.version.cuda}")
    print(torch.cuda.get_device_name(0))
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
import bitsandbytes
import peft
import transformers
print(transformers.__version__)
print(f"bitsandbytes version: {bitsandbytes.__version__}")
print(f"peft version: {peft.__version__}")
print(torch.cuda.is_bf16_supported())
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
3.10.16 | packaged by Anaconda, Inc. | (main, Dec 11 2024, 16:19:12) [MSC v.1929 64 bit (AMD64)]
PyTorch Version: 2.5.1+cu121
True
1
CUDA Version: 12.1
NVIDIA GeForce RTX 4080 Laptop GPU
4.50.0.dev0
bitsandbytes version: 0.45.3
peft version: 0.15.2.dev0
True
Load dataset, base model, and tokeniser
In [2]:
from datasets import load_dataset
imdb_dataset = load_dataset("imdb")
imdb_dataset = imdb_dataset.rename_column("label", "labels")
# Split the test set into validation and test sets
test_val_split = imdb_dataset['test'].train_test_split(test_size=0.95, seed=42)
imdb_dataset['validation'] = test_val_split['train']
imdb_dataset['test'] = test_val_split['test']
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, f1_score
# Determine the number of labels
num_labels = len(set(imdb_dataset["train"]["labels"]))
print(f"Number of labels: {num_labels}")
# Load the tokenizer
model_ckpt = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
# Tokenize the whole dataset, truncate to 384 tokens
def tokenize(batch):
    return tokenizer(batch["text"], padding=True, truncation=True, max_length=384)
dataset_encoded = imdb_dataset.map(tokenize, batched=True, batch_size=None)
# Load the pretrained model for sequence classification
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = (AutoModelForSequenceClassification
         .from_pretrained(model_ckpt, num_labels=num_labels)
         .to(device))
print(model)
Number of labels: 2
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
DistilBertForSequenceClassification( (distilbert): DistilBertModel( (embeddings): Embeddings( (word_embeddings): Embedding(30522, 768, padding_idx=0) (position_embeddings): Embedding(512, 768) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) (transformer): Transformer( (layer): ModuleList( (0-5): 6 x TransformerBlock( (attention): DistilBertSdpaAttention( (dropout): Dropout(p=0.1, inplace=False) (q_lin): Linear(in_features=768, out_features=768, bias=True) (k_lin): Linear(in_features=768, out_features=768, bias=True) (v_lin): Linear(in_features=768, out_features=768, bias=True) (out_lin): Linear(in_features=768, out_features=768, bias=True) ) (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (ffn): FFN( (dropout): Dropout(p=0.1, inplace=False) (lin1): Linear(in_features=768, out_features=3072, bias=True) (lin2): Linear(in_features=3072, out_features=768, bias=True) (activation): GELUActivation() ) (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) ) ) ) ) (pre_classifier): Linear(in_features=768, out_features=768, bias=True) (classifier): Linear(in_features=768, out_features=2, bias=True) (dropout): Dropout(p=0.2, inplace=False) )
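As a quick aside (a minimal sketch, not part of the original run), one encoded example can be inspected directly; with padding=True and batch_size=None the whole split is padded to a common length, truncated at 384 tokens. The key and decoded-token output shown in the comments is what I would expect, not captured output.

sample = dataset_encoded["train"][0]
print(sample.keys())                    # text, labels, input_ids, attention_mask
print(len(sample["input_ids"]))         # padded/truncated length, at most 384
print(tokenizer.decode(sample["input_ids"][:20]))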
In [3]:
torch.autograd.set_detect_anomaly(True)
# Define the performance metrics
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1}
LoRA
In [4]:
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# Modify model by injecting LoRA and DoRA layers
# Borrowed from: https://magazine.sebastianraschka.com/p/lora-and-dora-from-scratch
class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))  # B starts at zero, so the LoRA update is zero at init
        self.alpha = alpha

    def forward(self, x):
        x = self.alpha * (x @ self.A @ self.B)
        return x

class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )

    def forward(self, x):
        return self.linear(x) + self.lora(x)

class LinearWithDoRA(nn.Module):
    def __init__(self, linear, rank, alpha, scaling_factor=1.0):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.m = nn.Parameter(torch.randn(1, linear.out_features) * std_dev)
        # self.scale = nn.Parameter(torch.tensor(float(scaling_factor)))
        self.scale = nn.Parameter(torch.full((1, linear.out_features), float(scaling_factor)))

    def forward(self, x):
        linear_output = self.linear(x)
        lora_output = self.lora(x)
        # Normalize the LoRA branch along the feature dimension, then rescale by the learned magnitude m
        lora_output_norm = lora_output / (lora_output.norm(p=2, dim=-1, keepdim=True) + 1e-9)
        dora_modification = self.scale * self.m * lora_output_norm
        return linear_output + dora_modification
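Because B is initialized to zeros, a freshly wrapped layer should reproduce the plain linear layer exactly. A minimal sanity-check sketch (illustrative, not part of the training run):

base = nn.Linear(768, 768)
wrapped = LinearWithLoRA(base, rank=16, alpha=32)
x = torch.randn(4, 768)
assert torch.allclose(base(x), wrapped(x))  # the LoRA update is exactly zero at initialization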
In [5]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
# Function to inject LoRA into specified linear layers
def inject_lora_all_attn(model, rank, alpha):
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            lora_linear = LinearWithLoRA(original_linear, rank, alpha)
            setattr(parent_module, name.split('.')[-1], lora_linear)
    return model

# Function to inject DoRA into specified linear layers
def inject_dora_all_attn(model, rank, alpha, scaling_factor=1.0):
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            dora_linear = LinearWithDoRA(original_linear, rank, alpha, scaling_factor)
            setattr(parent_module, name.split('.')[-1], dora_linear)
    return model

def count_trainable_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params, 100 * trainable_params / total_params

def freeze_model_layers(model, unfreeze_pre_classifier=False):
    # Freeze all parameters
    for param in model.parameters():
        param.requires_grad = False
    # Unfreeze LoRA- and DoRA-specific params
    for name, param in model.named_parameters():
        if (
            "lora.A" in name
            or "lora.B" in name
            or name.endswith(".m")      # For DoRA
            or name.endswith(".m_in")   # For DDoRA
            or name.endswith(".m_out")  # For DDoRA
            or "scale" in name
        ):
            param.requires_grad = True
    # Unfreeze classifier layer (always)
    for name, param in model.named_parameters():
        if name.startswith("classifier."):
            param.requires_grad = True
    # Unfreeze pre-classifier (optional)
    if unfreeze_pre_classifier:
        for name, param in model.named_parameters():
            if name.startswith("pre_classifier."):
                param.requires_grad = True
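A small illustrative check, using the helpers above (this cell is a sketch and was not in the original notebook): after injection and freezing, all 24 attention projections (6 blocks x 4 projections) should be wrapped, and only adapter, classifier, and (optionally) pre-classifier parameters should remain trainable.

import copy

demo = inject_lora_all_attn(copy.deepcopy(model), rank=16, alpha=32)
freeze_model_layers(demo, unfreeze_pre_classifier=True)
wrapped = [n for n, m in demo.named_modules() if isinstance(m, LinearWithLoRA)]
print(len(wrapped), wrapped[0])  # expect 24 wrapped layers
print([n for n, p in demo.named_parameters() if p.requires_grad][:4])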
In [6]:
import seaborn as sns
import torch
import numpy as np
import matplotlib.pyplot as plt
import copy
torch.manual_seed(137)
lora_rank = 16
lora_alpha = 32
batch_size = 32
learning_rate = 1.5e-5  # note: much lower than the learning rates used for DoRA/DDoRA below
weight_decay = 0.0
output_dir_prefix = "finetuned-imdb-"
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)
total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
eval_steps = 50
logging_steps = 50
output_dir_prefix = "finetuned-imdb-"
training_args_lora_all_attn = TrainingArguments(
    output_dir=f"{output_dir_prefix}lora-all-attn",
    num_train_epochs=5,
    #max_steps=100,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    logging_steps=logging_steps,
    save_steps=eval_steps,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    disable_tqdm=False,
    push_to_hub=False,
    report_to="none",
    log_level="error",
)
trainer_lora_all_attn = Trainer(
    model=model_lora_all_attn,
    args=training_args_lora_all_attn,
    train_dataset=dataset_encoded["train"],
    eval_dataset=dataset_encoded["validation"],
    compute_metrics=compute_metrics,
)
LoRA (All Attention) - Total parameters: 67,544,834
LoRA (All Attention) - Trainable parameters: 1,181,954 (1.75%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
In [7]:
trainer_lora_all_attn.train()
eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
print('Prediction drift')
outputs = trainer_lora_all_attn.predict(dataset_encoded["validation"])
print(outputs.predictions[:5], outputs.label_ids[:5])
print (torch.cuda.memory_summary())
print('LoRA Heatmap')
layer_names = []
b_norms = []
for name, param in model_lora_all_attn.named_parameters():
    if "lora.B" in name:
        layer_names.append(name)
        b_norms.append(param.norm().item())
sns.barplot(x=b_norms, y=layer_names, color='navy')
plt.xlabel("Weight Norm")
plt.title("LoRA B Norms by Layer")
plt.tight_layout()
plt.show()
[3910/3910 1:29:10, Epoch 5/5]
Step | Training Loss | Validation Loss | Accuracy | F1 |
---|---|---|---|---|
50 | 0.549600 | 0.334510 | 0.864800 | 0.864919 |
100 | 0.363100 | 0.314867 | 0.873600 | 0.873637 |
150 | 0.368100 | 0.283882 | 0.884000 | 0.884009 |
200 | 0.289200 | 0.270112 | 0.887200 | 0.887403 |
250 | 0.304900 | 0.246406 | 0.903200 | 0.903162 |
300 | 0.304000 | 0.256413 | 0.904800 | 0.904319 |
350 | 0.313300 | 0.257387 | 0.896000 | 0.896121 |
400 | 0.275000 | 0.360744 | 0.857600 | 0.857483 |
450 | 0.278100 | 0.277102 | 0.896800 | 0.897004 |
500 | 0.305700 | 0.243079 | 0.904000 | 0.904014 |
550 | 0.285300 | 0.258897 | 0.897600 | 0.897417 |
600 | 0.303100 | 0.258372 | 0.900800 | 0.900313 |
650 | 0.269200 | 0.246777 | 0.901600 | 0.901607 |
700 | 0.260800 | 0.241955 | 0.904000 | 0.903848 |
750 | 0.258400 | 0.243109 | 0.909600 | 0.909285 |
800 | 0.282400 | 0.338753 | 0.866400 | 0.863283 |
850 | 0.237000 | 0.239093 | 0.908800 | 0.908864 |
900 | 0.217300 | 0.261071 | 0.907200 | 0.907265 |
950 | 0.231800 | 0.263062 | 0.904000 | 0.903630 |
1000 | 0.229900 | 0.261102 | 0.907200 | 0.907155 |
1050 | 0.221300 | 0.245763 | 0.904000 | 0.903360 |
1100 | 0.218000 | 0.285926 | 0.899200 | 0.898366 |
1150 | 0.208400 | 0.238600 | 0.906400 | 0.906363 |
1200 | 0.211900 | 0.268987 | 0.906400 | 0.905927 |
1250 | 0.231800 | 0.293073 | 0.892000 | 0.892213 |
1300 | 0.231000 | 0.240379 | 0.904800 | 0.904965 |
1350 | 0.232700 | 0.251226 | 0.903200 | 0.903375 |
1400 | 0.231500 | 0.236147 | 0.908800 | 0.908771 |
1450 | 0.216500 | 0.243060 | 0.906400 | 0.906119 |
1500 | 0.218400 | 0.235548 | 0.908000 | 0.907745 |
1550 | 0.223300 | 0.237270 | 0.908800 | 0.908579 |
1600 | 0.174400 | 0.258777 | 0.914400 | 0.914256 |
1650 | 0.156200 | 0.266737 | 0.911200 | 0.910846 |
1700 | 0.168200 | 0.254983 | 0.914400 | 0.914337 |
1750 | 0.184600 | 0.265700 | 0.902400 | 0.902585 |
1800 | 0.182500 | 0.258778 | 0.907200 | 0.907287 |
1850 | 0.184000 | 0.251654 | 0.912800 | 0.912877 |
1900 | 0.154400 | 0.277561 | 0.906400 | 0.906447 |
1950 | 0.192000 | 0.241895 | 0.903200 | 0.903296 |
2000 | 0.185900 | 0.254578 | 0.913600 | 0.913789 |
2050 | 0.197500 | 0.240309 | 0.902400 | 0.902245 |
2100 | 0.192700 | 0.260683 | 0.901600 | 0.901103 |
2150 | 0.147100 | 0.262304 | 0.908000 | 0.908081 |
2200 | 0.183600 | 0.251887 | 0.910400 | 0.910439 |
2250 | 0.207200 | 0.251087 | 0.909600 | 0.909750 |
2300 | 0.196100 | 0.249030 | 0.900800 | 0.901003 |
2350 | 0.172600 | 0.248261 | 0.910400 | 0.910426 |
2400 | 0.146400 | 0.240490 | 0.919200 | 0.919168 |
2450 | 0.145300 | 0.248667 | 0.908800 | 0.908598 |
2500 | 0.108100 | 0.305569 | 0.899200 | 0.899355 |
2550 | 0.140500 | 0.277428 | 0.910400 | 0.910142 |
2600 | 0.152600 | 0.284274 | 0.916800 | 0.916913 |
2650 | 0.141600 | 0.252135 | 0.916800 | 0.916731 |
2700 | 0.133600 | 0.268836 | 0.912000 | 0.911958 |
2750 | 0.151300 | 0.262349 | 0.914400 | 0.913991 |
2800 | 0.121400 | 0.255012 | 0.917600 | 0.917494 |
2850 | 0.131700 | 0.269369 | 0.913600 | 0.913222 |
2900 | 0.152800 | 0.246611 | 0.916800 | 0.916716 |
2950 | 0.132300 | 0.251266 | 0.908000 | 0.907724 |
3000 | 0.142600 | 0.265384 | 0.918400 | 0.918532 |
3050 | 0.139300 | 0.260968 | 0.909600 | 0.909483 |
3100 | 0.137700 | 0.303503 | 0.904000 | 0.903330 |
3150 | 0.121600 | 0.266256 | 0.920000 | 0.920056 |
3200 | 0.101800 | 0.283502 | 0.911200 | 0.911288 |
3250 | 0.090600 | 0.293355 | 0.912800 | 0.912635 |
3300 | 0.112900 | 0.290351 | 0.912000 | 0.911958 |
3350 | 0.086300 | 0.298528 | 0.916000 | 0.915907 |
3400 | 0.086100 | 0.310346 | 0.915200 | 0.915280 |
3450 | 0.093600 | 0.306871 | 0.914400 | 0.914352 |
3500 | 0.107200 | 0.312322 | 0.917600 | 0.917606 |
3550 | 0.121300 | 0.315530 | 0.912800 | 0.912844 |
3600 | 0.117800 | 0.316997 | 0.915200 | 0.914956 |
3650 | 0.133400 | 0.311634 | 0.911200 | 0.911179 |
3700 | 0.107600 | 0.316201 | 0.915200 | 0.915248 |
3750 | 0.107800 | 0.312200 | 0.916000 | 0.916019 |
3800 | 0.086900 | 0.314815 | 0.912800 | 0.912844 |
3850 | 0.105500 | 0.314706 | 0.913600 | 0.913600 |
3900 | 0.069900 | 0.314938 | 0.912800 | 0.912780 |
LoRA (All Attention) Test Results: {'eval_loss': 0.22863808274269104, 'eval_accuracy': 0.9094315789473684, 'eval_f1': 0.9093754564531454, 'eval_runtime': 129.5483, 'eval_samples_per_second': 183.329, 'eval_steps_per_second': 5.735, 'epoch': 5.0} Prediction drift [[-1.9898635 1.6761417] [-2.4649115 2.2287219] [ 2.1762567 -1.8627577] [-1.866952 1.787821 ] [ 2.8843796 -2.5472975]] [1 1 0 1 0] |===========================================================================| | PyTorch CUDA memory summary, device ID 0 | |---------------------------------------------------------------------------| | CUDA OOMs: 0 | cudaMalloc retries: 0 | |===========================================================================| | Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed | |---------------------------------------------------------------------------| | Allocated memory | 553486 KiB | 3293 MiB | 90884 GiB | 90883 GiB | | from large pool | 546048 KiB | 3264 MiB | 90258 GiB | 90258 GiB | | from small pool | 7438 KiB | 31 MiB | 625 GiB | 625 GiB | |---------------------------------------------------------------------------| | Active memory | 553486 KiB | 3293 MiB | 90884 GiB | 90883 GiB | | from large pool | 546048 KiB | 3264 MiB | 90258 GiB | 90258 GiB | | from small pool | 7438 KiB | 31 MiB | 625 GiB | 625 GiB | |---------------------------------------------------------------------------| | Requested memory | 551272 KiB | 3290 MiB | 90677 GiB | 90676 GiB | | from large pool | 543836 KiB | 3260 MiB | 90054 GiB | 90053 GiB | | from small pool | 7436 KiB | 31 MiB | 623 GiB | 623 GiB | |---------------------------------------------------------------------------| | GPU reserved memory | 3514 MiB | 3514 MiB | 3514 MiB | 0 B | | from large pool | 3476 MiB | 3476 MiB | 3476 MiB | 0 B | | from small pool | 38 MiB | 38 MiB | 38 MiB | 0 B | |---------------------------------------------------------------------------| | Non-releasable memory | 67058 KiB | 308581 KiB | 40038 GiB | 40038 GiB | | from large pool | 64256 KiB | 300928 KiB | 39368 GiB | 39368 GiB | | from small pool | 2802 KiB | 9018 KiB | 670 GiB | 670 GiB | |---------------------------------------------------------------------------| | Allocations | 364 | 491 | 10562 K | 10562 K | | from large pool | 82 | 140 | 3162 K | 3162 K | | from small pool | 282 | 395 | 7400 K | 7400 K | |---------------------------------------------------------------------------| | Active allocs | 364 | 491 | 10562 K | 10562 K | | from large pool | 82 | 140 | 3162 K | 3162 K | | from small pool | 282 | 395 | 7400 K | 7400 K | |---------------------------------------------------------------------------| | GPU reserved segments | 79 | 79 | 79 | 0 | | from large pool | 60 | 60 | 60 | 0 | | from small pool | 19 | 19 | 19 | 0 | |---------------------------------------------------------------------------| | Non-releasable allocs | 27 | 44 | 4813 K | 4813 K | | from large pool | 18 | 23 | 935 K | 935 K | | from small pool | 9 | 23 | 3878 K | 3878 K | |---------------------------------------------------------------------------| | Oversize allocations | 0 | 0 | 0 | 0 | |---------------------------------------------------------------------------| | Oversize GPU segments | 0 | 0 | 0 | 0 | |===========================================================================| LoRA Heatmap
In [8]:
print('Parameter Statistics: mean.abs()')
for name, param in model_lora_all_attn.named_parameters():
    if "lora" in name:
        print(name, param.abs().mean().item())
Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.lora.A 0.20137116312980652 distilbert.transformer.layer.0.attention.q_lin.lora.B 0.00041380099719390273 distilbert.transformer.layer.0.attention.k_lin.lora.A 0.2003246247768402 distilbert.transformer.layer.0.attention.k_lin.lora.B 0.00042878554086200893 distilbert.transformer.layer.0.attention.v_lin.lora.A 0.19911907613277435 distilbert.transformer.layer.0.attention.v_lin.lora.B 0.000338834710419178 distilbert.transformer.layer.0.attention.out_lin.lora.A 0.20047757029533386 distilbert.transformer.layer.0.attention.out_lin.lora.B 0.00034657015930861235 distilbert.transformer.layer.1.attention.q_lin.lora.A 0.1993226706981659 distilbert.transformer.layer.1.attention.q_lin.lora.B 0.0004094692412763834 distilbert.transformer.layer.1.attention.k_lin.lora.A 0.1983720064163208 distilbert.transformer.layer.1.attention.k_lin.lora.B 0.000427080609370023 distilbert.transformer.layer.1.attention.v_lin.lora.A 0.20035219192504883 distilbert.transformer.layer.1.attention.v_lin.lora.B 0.00030109973158687353 distilbert.transformer.layer.1.attention.out_lin.lora.A 0.20086750388145447 distilbert.transformer.layer.1.attention.out_lin.lora.B 0.00034795209649018943 distilbert.transformer.layer.2.attention.q_lin.lora.A 0.19976946711540222 distilbert.transformer.layer.2.attention.q_lin.lora.B 0.00040535288280807436 distilbert.transformer.layer.2.attention.k_lin.lora.A 0.19660572707653046 distilbert.transformer.layer.2.attention.k_lin.lora.B 0.00038226376636885107 distilbert.transformer.layer.2.attention.v_lin.lora.A 0.19878578186035156 distilbert.transformer.layer.2.attention.v_lin.lora.B 0.00027373715420253575 distilbert.transformer.layer.2.attention.out_lin.lora.A 0.19986116886138916 distilbert.transformer.layer.2.attention.out_lin.lora.B 0.0003197303449269384 distilbert.transformer.layer.3.attention.q_lin.lora.A 0.1998608410358429 distilbert.transformer.layer.3.attention.q_lin.lora.B 0.00037932227132841945 distilbert.transformer.layer.3.attention.k_lin.lora.A 0.2004813402891159 distilbert.transformer.layer.3.attention.k_lin.lora.B 0.00038690734072588384 distilbert.transformer.layer.3.attention.v_lin.lora.A 0.2002096176147461 distilbert.transformer.layer.3.attention.v_lin.lora.B 0.00028052262496203184 distilbert.transformer.layer.3.attention.out_lin.lora.A 0.19950413703918457 distilbert.transformer.layer.3.attention.out_lin.lora.B 0.00029115224606357515 distilbert.transformer.layer.4.attention.q_lin.lora.A 0.19811344146728516 distilbert.transformer.layer.4.attention.q_lin.lora.B 0.0004094548639841378 distilbert.transformer.layer.4.attention.k_lin.lora.A 0.1975163221359253 distilbert.transformer.layer.4.attention.k_lin.lora.B 0.00039162003668025136 distilbert.transformer.layer.4.attention.v_lin.lora.A 0.1999872624874115 distilbert.transformer.layer.4.attention.v_lin.lora.B 0.000269879907136783 distilbert.transformer.layer.4.attention.out_lin.lora.A 0.19805268943309784 distilbert.transformer.layer.4.attention.out_lin.lora.B 0.00028756255051121116 distilbert.transformer.layer.5.attention.q_lin.lora.A 0.20075345039367676 distilbert.transformer.layer.5.attention.q_lin.lora.B 0.00037421341403387487 distilbert.transformer.layer.5.attention.k_lin.lora.A 0.19600680470466614 distilbert.transformer.layer.5.attention.k_lin.lora.B 0.0003485583874862641 distilbert.transformer.layer.5.attention.v_lin.lora.A 0.19779205322265625 distilbert.transformer.layer.5.attention.v_lin.lora.B 0.0002423622936476022 
distilbert.transformer.layer.5.attention.out_lin.lora.A 0.19977574050426483 distilbert.transformer.layer.5.attention.out_lin.lora.B 0.0003247036365792155
In [9]:
print('Parameter Statistics: param.norm()')
for name, param in model_lora_all_attn.named_parameters():
    if "lora" in name or name.endswith(".m"):
        print(f"{name} weight norm: {param.norm().item():.4f}")
Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.lora.A weight norm: 27.8491 distilbert.transformer.layer.0.attention.q_lin.lora.B weight norm: 0.0578 distilbert.transformer.layer.0.attention.k_lin.lora.A weight norm: 27.7708 distilbert.transformer.layer.0.attention.k_lin.lora.B weight norm: 0.0597 distilbert.transformer.layer.0.attention.v_lin.lora.A weight norm: 27.5879 distilbert.transformer.layer.0.attention.v_lin.lora.B weight norm: 0.0473 distilbert.transformer.layer.0.attention.out_lin.lora.A weight norm: 27.8318 distilbert.transformer.layer.0.attention.out_lin.lora.B weight norm: 0.0484 distilbert.transformer.layer.1.attention.q_lin.lora.A weight norm: 27.8054 distilbert.transformer.layer.1.attention.q_lin.lora.B weight norm: 0.0578 distilbert.transformer.layer.1.attention.k_lin.lora.A weight norm: 27.5557 distilbert.transformer.layer.1.attention.k_lin.lora.B weight norm: 0.0595 distilbert.transformer.layer.1.attention.v_lin.lora.A weight norm: 27.9069 distilbert.transformer.layer.1.attention.v_lin.lora.B weight norm: 0.0428 distilbert.transformer.layer.1.attention.out_lin.lora.A weight norm: 27.9222 distilbert.transformer.layer.1.attention.out_lin.lora.B weight norm: 0.0486 distilbert.transformer.layer.2.attention.q_lin.lora.A weight norm: 27.7035 distilbert.transformer.layer.2.attention.q_lin.lora.B weight norm: 0.0568 distilbert.transformer.layer.2.attention.k_lin.lora.A weight norm: 27.3314 distilbert.transformer.layer.2.attention.k_lin.lora.B weight norm: 0.0532 distilbert.transformer.layer.2.attention.v_lin.lora.A weight norm: 27.6290 distilbert.transformer.layer.2.attention.v_lin.lora.B weight norm: 0.0388 distilbert.transformer.layer.2.attention.out_lin.lora.A weight norm: 27.6716 distilbert.transformer.layer.2.attention.out_lin.lora.B weight norm: 0.0449 distilbert.transformer.layer.3.attention.q_lin.lora.A weight norm: 27.6330 distilbert.transformer.layer.3.attention.q_lin.lora.B weight norm: 0.0529 distilbert.transformer.layer.3.attention.k_lin.lora.A weight norm: 27.8436 distilbert.transformer.layer.3.attention.k_lin.lora.B weight norm: 0.0541 distilbert.transformer.layer.3.attention.v_lin.lora.A weight norm: 27.8198 distilbert.transformer.layer.3.attention.v_lin.lora.B weight norm: 0.0395 distilbert.transformer.layer.3.attention.out_lin.lora.A weight norm: 27.7259 distilbert.transformer.layer.3.attention.out_lin.lora.B weight norm: 0.0411 distilbert.transformer.layer.4.attention.q_lin.lora.A weight norm: 27.6222 distilbert.transformer.layer.4.attention.q_lin.lora.B weight norm: 0.0573 distilbert.transformer.layer.4.attention.k_lin.lora.A weight norm: 27.4509 distilbert.transformer.layer.4.attention.k_lin.lora.B weight norm: 0.0547 distilbert.transformer.layer.4.attention.v_lin.lora.A weight norm: 27.8650 distilbert.transformer.layer.4.attention.v_lin.lora.B weight norm: 0.0380 distilbert.transformer.layer.4.attention.out_lin.lora.A weight norm: 27.6073 distilbert.transformer.layer.4.attention.out_lin.lora.B weight norm: 0.0405 distilbert.transformer.layer.5.attention.q_lin.lora.A weight norm: 27.9097 distilbert.transformer.layer.5.attention.q_lin.lora.B weight norm: 0.0525 distilbert.transformer.layer.5.attention.k_lin.lora.A weight norm: 27.3422 distilbert.transformer.layer.5.attention.k_lin.lora.B weight norm: 0.0490 distilbert.transformer.layer.5.attention.v_lin.lora.A weight norm: 27.4174 distilbert.transformer.layer.5.attention.v_lin.lora.B weight norm: 0.0341 distilbert.transformer.layer.5.attention.out_lin.lora.A weight norm: 
27.7203 distilbert.transformer.layer.5.attention.out_lin.lora.B weight norm: 0.0450
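Since LinearWithLoRA computes linear(x) + alpha * x @ A @ B, the low-rank update can in principle be folded back into the frozen weight after training so inference pays no extra cost. A minimal sketch under the classes defined above (merge_lora is an illustrative helper, not part of the original notebook):

@torch.no_grad()
def merge_lora(wrapper):
    """Fold the LoRA update into the wrapped nn.Linear and return it."""
    delta = wrapper.lora.alpha * (wrapper.lora.A @ wrapper.lora.B)  # shape (in_features, out_features)
    wrapper.linear.weight += delta.T  # nn.Linear stores weight as (out_features, in_features)
    return wrapper.linear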
DoRA
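For context (a hedged paraphrase of the DoRA paper by Liu et al., 2024, not text from this notebook): DoRA decomposes the adapted weight into a learned magnitude vector m and a normalized direction, W' = m * (W0 + dW) / ||W0 + dW||_c, where dW is the low-rank LoRA update and ||.||_c is a column-wise norm. The LinearWithDoRA module above implements a simplified variant: the frozen path W0 x is left untouched, and the learned magnitude m (together with the extra scale parameter) is applied to the L2-normalized output of the LoRA branch, i.e. y = W0 x + scale * m * lora(x) / (||lora(x)||_2 + 1e-9).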
In [10]:
import copy
torch.manual_seed(137)
lora_rank = 16
lora_alpha = 32
batch_size = 32
learning_rate = 2e-2  # note: much higher than the learning rate used for the LoRA run above
weight_decay = 1e-4
scaling_factor = 2.0
output_dir_prefix = "finetuned-imdb-"
model_dora_all_attn = copy.deepcopy(model)
model_dora_all_attn = inject_dora_all_attn(model_dora_all_attn, lora_rank, lora_alpha, scaling_factor)
freeze_model_layers(model_dora_all_attn, unfreeze_pre_classifier=True)
total_params_dora, trainable_params_dora, percentage_dora = count_trainable_parameters(model_dora_all_attn)
print(f"\nDoRA (All Attention) - Total parameters: {total_params_dora:,}")
print(f"DoRA (All Attention) - Trainable parameters: {trainable_params_dora:,} ({percentage_dora:.2f}%)")
# Sanity check
print("\nTrainable parameters after freezing:")
for name, param in model_dora_all_attn.named_parameters():
    if param.requires_grad:
        print(name)
eval_steps = 50
logging_steps = 50
output_dir_prefix = "finetuned-imdb-"
training_args_dora_all_attn = TrainingArguments(
    output_dir=f"{output_dir_prefix}dora-all-attn",
    num_train_epochs=5,
    #max_steps=100,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    logging_steps=logging_steps,
    save_steps=eval_steps,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    disable_tqdm=False,
    push_to_hub=False,
    report_to="none",
    log_level="error",
)
trainer_dora_all_attn = Trainer(
    model=model_dora_all_attn,
    args=training_args_dora_all_attn,
    train_dataset=dataset_encoded["train"],
    eval_dataset=dataset_encoded["validation"],
    compute_metrics=compute_metrics,
)
DoRA (All Attention) - Total parameters: 67,581,698 DoRA (All Attention) - Trainable parameters: 1,218,818 (1.80%) Trainable parameters after freezing: distilbert.transformer.layer.0.attention.q_lin.m distilbert.transformer.layer.0.attention.q_lin.scale distilbert.transformer.layer.0.attention.q_lin.lora.A distilbert.transformer.layer.0.attention.q_lin.lora.B distilbert.transformer.layer.0.attention.k_lin.m distilbert.transformer.layer.0.attention.k_lin.scale distilbert.transformer.layer.0.attention.k_lin.lora.A distilbert.transformer.layer.0.attention.k_lin.lora.B distilbert.transformer.layer.0.attention.v_lin.m distilbert.transformer.layer.0.attention.v_lin.scale distilbert.transformer.layer.0.attention.v_lin.lora.A distilbert.transformer.layer.0.attention.v_lin.lora.B distilbert.transformer.layer.0.attention.out_lin.m distilbert.transformer.layer.0.attention.out_lin.scale distilbert.transformer.layer.0.attention.out_lin.lora.A distilbert.transformer.layer.0.attention.out_lin.lora.B distilbert.transformer.layer.1.attention.q_lin.m distilbert.transformer.layer.1.attention.q_lin.scale distilbert.transformer.layer.1.attention.q_lin.lora.A distilbert.transformer.layer.1.attention.q_lin.lora.B distilbert.transformer.layer.1.attention.k_lin.m distilbert.transformer.layer.1.attention.k_lin.scale distilbert.transformer.layer.1.attention.k_lin.lora.A distilbert.transformer.layer.1.attention.k_lin.lora.B distilbert.transformer.layer.1.attention.v_lin.m distilbert.transformer.layer.1.attention.v_lin.scale distilbert.transformer.layer.1.attention.v_lin.lora.A distilbert.transformer.layer.1.attention.v_lin.lora.B distilbert.transformer.layer.1.attention.out_lin.m distilbert.transformer.layer.1.attention.out_lin.scale distilbert.transformer.layer.1.attention.out_lin.lora.A distilbert.transformer.layer.1.attention.out_lin.lora.B distilbert.transformer.layer.2.attention.q_lin.m distilbert.transformer.layer.2.attention.q_lin.scale distilbert.transformer.layer.2.attention.q_lin.lora.A distilbert.transformer.layer.2.attention.q_lin.lora.B distilbert.transformer.layer.2.attention.k_lin.m distilbert.transformer.layer.2.attention.k_lin.scale distilbert.transformer.layer.2.attention.k_lin.lora.A distilbert.transformer.layer.2.attention.k_lin.lora.B distilbert.transformer.layer.2.attention.v_lin.m distilbert.transformer.layer.2.attention.v_lin.scale distilbert.transformer.layer.2.attention.v_lin.lora.A distilbert.transformer.layer.2.attention.v_lin.lora.B distilbert.transformer.layer.2.attention.out_lin.m distilbert.transformer.layer.2.attention.out_lin.scale distilbert.transformer.layer.2.attention.out_lin.lora.A distilbert.transformer.layer.2.attention.out_lin.lora.B distilbert.transformer.layer.3.attention.q_lin.m distilbert.transformer.layer.3.attention.q_lin.scale distilbert.transformer.layer.3.attention.q_lin.lora.A distilbert.transformer.layer.3.attention.q_lin.lora.B distilbert.transformer.layer.3.attention.k_lin.m distilbert.transformer.layer.3.attention.k_lin.scale distilbert.transformer.layer.3.attention.k_lin.lora.A distilbert.transformer.layer.3.attention.k_lin.lora.B distilbert.transformer.layer.3.attention.v_lin.m distilbert.transformer.layer.3.attention.v_lin.scale distilbert.transformer.layer.3.attention.v_lin.lora.A distilbert.transformer.layer.3.attention.v_lin.lora.B distilbert.transformer.layer.3.attention.out_lin.m distilbert.transformer.layer.3.attention.out_lin.scale distilbert.transformer.layer.3.attention.out_lin.lora.A distilbert.transformer.layer.3.attention.out_lin.lora.B 
distilbert.transformer.layer.4.attention.q_lin.m distilbert.transformer.layer.4.attention.q_lin.scale distilbert.transformer.layer.4.attention.q_lin.lora.A distilbert.transformer.layer.4.attention.q_lin.lora.B distilbert.transformer.layer.4.attention.k_lin.m distilbert.transformer.layer.4.attention.k_lin.scale distilbert.transformer.layer.4.attention.k_lin.lora.A distilbert.transformer.layer.4.attention.k_lin.lora.B distilbert.transformer.layer.4.attention.v_lin.m distilbert.transformer.layer.4.attention.v_lin.scale distilbert.transformer.layer.4.attention.v_lin.lora.A distilbert.transformer.layer.4.attention.v_lin.lora.B distilbert.transformer.layer.4.attention.out_lin.m distilbert.transformer.layer.4.attention.out_lin.scale distilbert.transformer.layer.4.attention.out_lin.lora.A distilbert.transformer.layer.4.attention.out_lin.lora.B distilbert.transformer.layer.5.attention.q_lin.m distilbert.transformer.layer.5.attention.q_lin.scale distilbert.transformer.layer.5.attention.q_lin.lora.A distilbert.transformer.layer.5.attention.q_lin.lora.B distilbert.transformer.layer.5.attention.k_lin.m distilbert.transformer.layer.5.attention.k_lin.scale distilbert.transformer.layer.5.attention.k_lin.lora.A distilbert.transformer.layer.5.attention.k_lin.lora.B distilbert.transformer.layer.5.attention.v_lin.m distilbert.transformer.layer.5.attention.v_lin.scale distilbert.transformer.layer.5.attention.v_lin.lora.A distilbert.transformer.layer.5.attention.v_lin.lora.B distilbert.transformer.layer.5.attention.out_lin.m distilbert.transformer.layer.5.attention.out_lin.scale distilbert.transformer.layer.5.attention.out_lin.lora.A distilbert.transformer.layer.5.attention.out_lin.lora.B pre_classifier.weight pre_classifier.bias classifier.weight classifier.bias
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
In [11]:
trainer_dora_all_attn.train()
eval_results_dora_all_attn = trainer_dora_all_attn.evaluate(dataset_encoded["test"])
print(f"DoRA (All Attention) Test Results: {eval_results_dora_all_attn}")
print('Prediction drift')
outputs = trainer_dora_all_attn.predict(dataset_encoded["validation"])
print(outputs.predictions[:5], outputs.label_ids[:5])
print (torch.cuda.memory_summary())
print('LoRA Heatmap')
layer_names = []
b_norms = []
for name, param in model_dora_all_attn.named_parameters():
    if "lora.B" in name:
        layer_names.append(name)
        b_norms.append(param.norm().item())
sns.barplot(x=b_norms, y=layer_names, color='navy')
plt.xlabel("Weight Norm")
plt.title("LoRA B Norms by Layer")
plt.tight_layout()
plt.show()
import torch
import numpy as np
import matplotlib.pyplot as plt
m_values = []
for name, module in model_dora_all_attn.named_modules():
    if isinstance(module, LinearWithDoRA):
        if hasattr(module, 'm'):
            m_param = module.m.detach().cpu().numpy().flatten()
            m_values.extend(m_param)
m_values = np.array(m_values)
# Analyze the distribution
print(f"Mean of m values: {np.mean(m_values):.4f}")
print(f"Standard deviation of m values: {np.std(m_values):.4f}")
print(f"Minimum m value: {np.min(m_values):.4f}")
print(f"Maximum m value: {np.max(m_values):.4f}")
# Plot a histogram
plt.hist(m_values, bins=50, alpha=0.7)
plt.title('Distribution of Learned m Values (DoRA)')
plt.xlabel('Magnitude (m)')
plt.ylabel('Frequency')
plt.show()
[3910/3910 2:15:09, Epoch 5/5]
Step | Training Loss | Validation Loss | Accuracy | F1 |
---|---|---|---|---|
50 | 1.102800 | 0.291960 | 0.870400 | 0.869901 |
100 | 0.356100 | 0.278253 | 0.894400 | 0.894522 |
150 | 0.343500 | 0.268145 | 0.897600 | 0.897671 |
200 | 0.305300 | 0.261925 | 0.897600 | 0.897630 |
250 | 0.321600 | 0.257692 | 0.904800 | 0.904558 |
300 | 0.305000 | 0.272489 | 0.887200 | 0.887449 |
350 | 0.300000 | 0.259093 | 0.901600 | 0.901282 |
400 | 0.266300 | 0.279098 | 0.884000 | 0.884259 |
450 | 0.258800 | 0.272474 | 0.892000 | 0.891513 |
500 | 0.288000 | 0.242344 | 0.905600 | 0.905585 |
550 | 0.280400 | 0.270589 | 0.901600 | 0.900959 |
600 | 0.301900 | 0.249609 | 0.897600 | 0.897794 |
650 | 0.277600 | 0.265940 | 0.904800 | 0.904469 |
700 | 0.278000 | 0.279490 | 0.911200 | 0.910675 |
750 | 0.293100 | 0.268084 | 0.897600 | 0.896787 |
800 | 0.254900 | 0.299735 | 0.896800 | 0.895963 |
850 | 0.253600 | 0.252459 | 0.896800 | 0.896935 |
900 | 0.235600 | 0.269197 | 0.900800 | 0.900881 |
950 | 0.271800 | 0.285552 | 0.902400 | 0.902503 |
1000 | 0.232000 | 0.225912 | 0.908800 | 0.908579 |
1050 | 0.219700 | 0.244564 | 0.914400 | 0.914337 |
1100 | 0.234700 | 0.229250 | 0.913600 | 0.913445 |
1150 | 0.217400 | 0.221898 | 0.905600 | 0.905600 |
1200 | 0.230000 | 0.237117 | 0.909600 | 0.909350 |
1250 | 0.221200 | 0.219938 | 0.912000 | 0.911958 |
1300 | 0.255900 | 0.222013 | 0.908000 | 0.908007 |
1350 | 0.260000 | 0.216425 | 0.915200 | 0.914956 |
1400 | 0.223400 | 0.217995 | 0.911200 | 0.911256 |
1450 | 0.219600 | 0.222689 | 0.912000 | 0.912000 |
1500 | 0.215300 | 0.224147 | 0.914400 | 0.914337 |
1550 | 0.227700 | 0.224390 | 0.919200 | 0.919251 |
1600 | 0.196300 | 0.254192 | 0.920800 | 0.920870 |
1650 | 0.182700 | 0.261529 | 0.916000 | 0.915823 |
1700 | 0.205700 | 0.254054 | 0.916800 | 0.916579 |
1750 | 0.199100 | 0.270258 | 0.912000 | 0.912000 |
1800 | 0.209500 | 0.224310 | 0.917600 | 0.917509 |
1850 | 0.211800 | 0.236137 | 0.915200 | 0.915200 |
1900 | 0.177900 | 0.244019 | 0.916800 | 0.916760 |
1950 | 0.204800 | 0.241636 | 0.916800 | 0.916836 |
2000 | 0.205300 | 0.228495 | 0.919200 | 0.919280 |
2050 | 0.212400 | 0.233124 | 0.918400 | 0.918435 |
2100 | 0.200900 | 0.225622 | 0.911200 | 0.911278 |
2150 | 0.153700 | 0.256996 | 0.912000 | 0.912061 |
2200 | 0.209800 | 0.220482 | 0.914400 | 0.914352 |
2250 | 0.196600 | 0.229680 | 0.916000 | 0.915858 |
2300 | 0.202500 | 0.218464 | 0.910400 | 0.910504 |
2350 | 0.181500 | 0.225591 | 0.917600 | 0.917509 |
2400 | 0.185800 | 0.250574 | 0.912000 | 0.911543 |
2450 | 0.157600 | 0.230985 | 0.915200 | 0.914915 |
2500 | 0.141200 | 0.245401 | 0.910400 | 0.910473 |
2550 | 0.136100 | 0.296957 | 0.918400 | 0.918165 |
2600 | 0.167800 | 0.249483 | 0.902400 | 0.902613 |
2650 | 0.157600 | 0.243417 | 0.912000 | 0.911705 |
2700 | 0.157700 | 0.241792 | 0.918400 | 0.918412 |
2750 | 0.158800 | 0.230028 | 0.912000 | 0.911683 |
2800 | 0.133300 | 0.229130 | 0.916800 | 0.916651 |
2850 | 0.145600 | 0.294840 | 0.917600 | 0.917251 |
2900 | 0.157700 | 0.267986 | 0.916000 | 0.915728 |
2950 | 0.150400 | 0.273056 | 0.912800 | 0.912384 |
3000 | 0.157900 | 0.240002 | 0.918400 | 0.918184 |
3050 | 0.157400 | 0.243159 | 0.920000 | 0.919919 |
3100 | 0.148800 | 0.250069 | 0.914400 | 0.913894 |
3150 | 0.124400 | 0.255732 | 0.916800 | 0.916700 |
3200 | 0.127000 | 0.246215 | 0.920800 | 0.920698 |
3250 | 0.116300 | 0.283555 | 0.912000 | 0.911786 |
3300 | 0.110700 | 0.277807 | 0.912000 | 0.911860 |
3350 | 0.112800 | 0.296765 | 0.916800 | 0.916634 |
3400 | 0.115000 | 0.259858 | 0.912800 | 0.912670 |
3450 | 0.117300 | 0.274300 | 0.912800 | 0.912598 |
3500 | 0.110300 | 0.284478 | 0.915200 | 0.915114 |
3550 | 0.123000 | 0.267429 | 0.916000 | 0.915923 |
3600 | 0.126300 | 0.266704 | 0.918400 | 0.918202 |
3650 | 0.131900 | 0.254044 | 0.916000 | 0.915892 |
3700 | 0.125600 | 0.256258 | 0.916800 | 0.916634 |
3750 | 0.109800 | 0.262411 | 0.916000 | 0.915938 |
3800 | 0.105500 | 0.270246 | 0.914400 | 0.914238 |
3850 | 0.115400 | 0.272160 | 0.914400 | 0.914238 |
3900 | 0.108500 | 0.268805 | 0.914400 | 0.914238 |
DoRA (All Attention) Test Results: {'eval_loss': 0.21464677155017853, 'eval_accuracy': 0.9129263157894737, 'eval_f1': 0.9128732729336847, 'eval_runtime': 146.8168, 'eval_samples_per_second': 161.766, 'eval_steps_per_second': 5.061, 'epoch': 5.0} Prediction drift [[-2.0174963 1.9359767] [-2.774578 2.4995706] [ 1.9803579 -2.1773782] [-1.6433766 1.4992952] [ 6.4219594 -7.323969 ]] [1 1 0 1 0] |===========================================================================| | PyTorch CUDA memory summary, device ID 0 | |---------------------------------------------------------------------------| | CUDA OOMs: 0 | cudaMalloc retries: 0 | |===========================================================================| | Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed | |---------------------------------------------------------------------------| | Allocated memory | 809 MiB | 5291 MiB | 230768 GiB | 230767 GiB | | from large pool | 794 MiB | 5252 MiB | 229489 GiB | 229488 GiB | | from small pool | 14 MiB | 40 MiB | 1278 GiB | 1278 GiB | |---------------------------------------------------------------------------| | Active memory | 809 MiB | 5291 MiB | 230768 GiB | 230767 GiB | | from large pool | 794 MiB | 5252 MiB | 229489 GiB | 229488 GiB | | from small pool | 14 MiB | 40 MiB | 1278 GiB | 1278 GiB | |---------------------------------------------------------------------------| | Requested memory | 805 MiB | 5287 MiB | 230551 GiB | 230550 GiB | | from large pool | 790 MiB | 5248 MiB | 229278 GiB | 229277 GiB | | from small pool | 14 MiB | 40 MiB | 1273 GiB | 1273 GiB | |---------------------------------------------------------------------------| | GPU reserved memory | 5500 MiB | 5500 MiB | 8120 MiB | 2620 MiB | | from large pool | 5456 MiB | 5456 MiB | 8048 MiB | 2592 MiB | | from small pool | 44 MiB | 44 MiB | 72 MiB | 28 MiB | |---------------------------------------------------------------------------| | Non-releasable memory | 98691 KiB | 545173 KiB | 88783 GiB | 88783 GiB | | from large pool | 91218 KiB | 540498 KiB | 87421 GiB | 87420 GiB | | from small pool | 7473 KiB | 17003 KiB | 1362 GiB | 1362 GiB | |---------------------------------------------------------------------------| | Allocations | 765 | 1012 | 26004 K | 26003 K | | from large pool | 123 | 229 | 7919 K | 7919 K | | from small pool | 642 | 851 | 18084 K | 18084 K | |---------------------------------------------------------------------------| | Active allocs | 765 | 1012 | 26004 K | 26003 K | | from large pool | 123 | 229 | 7919 K | 7919 K | | from small pool | 642 | 851 | 18084 K | 18084 K | |---------------------------------------------------------------------------| | GPU reserved segments | 134 | 134 | 181 | 47 | | from large pool | 112 | 112 | 145 | 33 | | from small pool | 22 | 22 | 36 | 14 | |---------------------------------------------------------------------------| | Non-releasable allocs | 35 | 53 | 12196 K | 12196 K | | from large pool | 17 | 26 | 2304 K | 2304 K | | from small pool | 18 | 32 | 9892 K | 9892 K | |---------------------------------------------------------------------------| | Oversize allocations | 0 | 0 | 0 | 0 | |---------------------------------------------------------------------------| | Oversize GPU segments | 0 | 0 | 0 | 0 | |===========================================================================| LoRA Heatmap
Mean of m values: 0.0014
Standard deviation of m values: 0.5832
Minimum m value: -4.2395
Maximum m value: 3.2943
In [12]:
print('Parameter Statistics: mean.abs()')
for name, param in model_dora_all_attn.named_parameters():
    if "lin.m" in name:
        print(name, param.abs().mean().item())
print('Parameter Statistics: mean.abs()')
for name, param in model_dora_all_attn.named_parameters():
    if "lin.scale" in name:
        print(name, param.abs().mean().item())
print('Parameter Statistics: mean.abs()')
for name, param in model_dora_all_attn.named_parameters():
    if "lora" in name:
        print(name, param.abs().mean().item())
Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.m 0.6335324048995972 distilbert.transformer.layer.0.attention.k_lin.m 0.5340768694877625 distilbert.transformer.layer.0.attention.v_lin.m 0.4119078516960144 distilbert.transformer.layer.0.attention.out_lin.m 0.20187006890773773 distilbert.transformer.layer.1.attention.q_lin.m 0.45603156089782715 distilbert.transformer.layer.1.attention.k_lin.m 0.39353370666503906 distilbert.transformer.layer.1.attention.v_lin.m 0.2410321682691574 distilbert.transformer.layer.1.attention.out_lin.m 0.22153238952159882 distilbert.transformer.layer.2.attention.q_lin.m 0.4081450402736664 distilbert.transformer.layer.2.attention.k_lin.m 0.44084256887435913 distilbert.transformer.layer.2.attention.v_lin.m 0.24717780947685242 distilbert.transformer.layer.2.attention.out_lin.m 0.16525350511074066 distilbert.transformer.layer.3.attention.q_lin.m 0.4876405596733093 distilbert.transformer.layer.3.attention.k_lin.m 0.4945552349090576 distilbert.transformer.layer.3.attention.v_lin.m 0.2434060424566269 distilbert.transformer.layer.3.attention.out_lin.m 0.24130374193191528 distilbert.transformer.layer.4.attention.q_lin.m 0.48577117919921875 distilbert.transformer.layer.4.attention.k_lin.m 0.5650743246078491 distilbert.transformer.layer.4.attention.v_lin.m 0.18068918585777283 distilbert.transformer.layer.4.attention.out_lin.m 0.2408987283706665 distilbert.transformer.layer.5.attention.q_lin.m 0.37161561846733093 distilbert.transformer.layer.5.attention.k_lin.m 0.6317863464355469 distilbert.transformer.layer.5.attention.v_lin.m 0.19014406204223633 distilbert.transformer.layer.5.attention.out_lin.m 0.28851351141929626 Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.scale 2.162614583969116 distilbert.transformer.layer.0.attention.k_lin.scale 1.9926241636276245 distilbert.transformer.layer.0.attention.v_lin.scale 1.9132812023162842 distilbert.transformer.layer.0.attention.out_lin.scale 1.6308307647705078 distilbert.transformer.layer.1.attention.q_lin.scale 1.9580211639404297 distilbert.transformer.layer.1.attention.k_lin.scale 1.8660268783569336 distilbert.transformer.layer.1.attention.v_lin.scale 1.6165016889572144 distilbert.transformer.layer.1.attention.out_lin.scale 1.5104354619979858 distilbert.transformer.layer.2.attention.q_lin.scale 1.838735818862915 distilbert.transformer.layer.2.attention.k_lin.scale 1.8984363079071045 distilbert.transformer.layer.2.attention.v_lin.scale 1.6771743297576904 distilbert.transformer.layer.2.attention.out_lin.scale 1.53759765625 distilbert.transformer.layer.3.attention.q_lin.scale 1.9363187551498413 distilbert.transformer.layer.3.attention.k_lin.scale 1.934944987297058 distilbert.transformer.layer.3.attention.v_lin.scale 1.6585863828659058 distilbert.transformer.layer.3.attention.out_lin.scale 1.585425853729248 distilbert.transformer.layer.4.attention.q_lin.scale 2.027859687805176 distilbert.transformer.layer.4.attention.k_lin.scale 2.082413673400879 distilbert.transformer.layer.4.attention.v_lin.scale 1.5296372175216675 distilbert.transformer.layer.4.attention.out_lin.scale 1.618304967880249 distilbert.transformer.layer.5.attention.q_lin.scale 1.9490175247192383 distilbert.transformer.layer.5.attention.k_lin.scale 2.1951510906219482 distilbert.transformer.layer.5.attention.v_lin.scale 1.6163949966430664 distilbert.transformer.layer.5.attention.out_lin.scale 1.8345234394073486 Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.lora.A 
0.6016935110092163 distilbert.transformer.layer.0.attention.q_lin.lora.B 0.31266412138938904 distilbert.transformer.layer.0.attention.k_lin.lora.A 0.6316931247711182 distilbert.transformer.layer.0.attention.k_lin.lora.B 0.2952665090560913 distilbert.transformer.layer.0.attention.v_lin.lora.A 0.6079772710800171 distilbert.transformer.layer.0.attention.v_lin.lora.B 0.2492581456899643 distilbert.transformer.layer.0.attention.out_lin.lora.A 0.4950897693634033 distilbert.transformer.layer.0.attention.out_lin.lora.B 0.18595407903194427 distilbert.transformer.layer.1.attention.q_lin.lora.A 0.6037637591362 distilbert.transformer.layer.1.attention.q_lin.lora.B 0.2575852870941162 distilbert.transformer.layer.1.attention.k_lin.lora.A 0.5715399384498596 distilbert.transformer.layer.1.attention.k_lin.lora.B 0.21878382563591003 distilbert.transformer.layer.1.attention.v_lin.lora.A 0.5992859601974487 distilbert.transformer.layer.1.attention.v_lin.lora.B 0.23800623416900635 distilbert.transformer.layer.1.attention.out_lin.lora.A 0.5236724019050598 distilbert.transformer.layer.1.attention.out_lin.lora.B 0.24507002532482147 distilbert.transformer.layer.2.attention.q_lin.lora.A 0.5646731853485107 distilbert.transformer.layer.2.attention.q_lin.lora.B 0.28050732612609863 distilbert.transformer.layer.2.attention.k_lin.lora.A 0.5719441771507263 distilbert.transformer.layer.2.attention.k_lin.lora.B 0.2973446249961853 distilbert.transformer.layer.2.attention.v_lin.lora.A 0.5592368841171265 distilbert.transformer.layer.2.attention.v_lin.lora.B 0.20374992489814758 distilbert.transformer.layer.2.attention.out_lin.lora.A 0.4538516402244568 distilbert.transformer.layer.2.attention.out_lin.lora.B 0.18519560992717743 distilbert.transformer.layer.3.attention.q_lin.lora.A 0.5853655934333801 distilbert.transformer.layer.3.attention.q_lin.lora.B 0.30569520592689514 distilbert.transformer.layer.3.attention.k_lin.lora.A 0.6070346832275391 distilbert.transformer.layer.3.attention.k_lin.lora.B 0.31933528184890747 distilbert.transformer.layer.3.attention.v_lin.lora.A 0.5087218284606934 distilbert.transformer.layer.3.attention.v_lin.lora.B 0.21019062399864197 distilbert.transformer.layer.3.attention.out_lin.lora.A 0.5895759463310242 distilbert.transformer.layer.3.attention.out_lin.lora.B 0.21719345450401306 distilbert.transformer.layer.4.attention.q_lin.lora.A 0.6171450614929199 distilbert.transformer.layer.4.attention.q_lin.lora.B 0.21375727653503418 distilbert.transformer.layer.4.attention.k_lin.lora.A 0.49676698446273804 distilbert.transformer.layer.4.attention.k_lin.lora.B 0.2762928605079651 distilbert.transformer.layer.4.attention.v_lin.lora.A 0.5012800097465515 distilbert.transformer.layer.4.attention.v_lin.lora.B 0.17582902312278748 distilbert.transformer.layer.4.attention.out_lin.lora.A 0.5615952610969543 distilbert.transformer.layer.4.attention.out_lin.lora.B 0.17805537581443787 distilbert.transformer.layer.5.attention.q_lin.lora.A 0.44116735458374023 distilbert.transformer.layer.5.attention.q_lin.lora.B 0.16018781065940857 distilbert.transformer.layer.5.attention.k_lin.lora.A 0.47184187173843384 distilbert.transformer.layer.5.attention.k_lin.lora.B 0.27603811025619507 distilbert.transformer.layer.5.attention.v_lin.lora.A 0.39102858304977417 distilbert.transformer.layer.5.attention.v_lin.lora.B 0.14989939332008362 distilbert.transformer.layer.5.attention.out_lin.lora.A 0.4704250693321228 distilbert.transformer.layer.5.attention.out_lin.lora.B 0.1333407163619995
In [13]:
print('Parameter Statistics: param.norm()')
for name, param in model_dora_all_attn.named_parameters():
    if "lora" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
print('Parameter Statistics: param.norm()')
for name, param in model_dora_all_attn.named_parameters():
    if "lin.m" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
print('Parameter Statistics: param.norm()')
for name, param in model_dora_all_attn.named_parameters():
    if "lin.scale" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.lora.A weight norm: 83.6438 distilbert.transformer.layer.0.attention.q_lin.lora.B weight norm: 50.3954 distilbert.transformer.layer.0.attention.k_lin.lora.A weight norm: 88.2270 distilbert.transformer.layer.0.attention.k_lin.lora.B weight norm: 49.0835 distilbert.transformer.layer.0.attention.v_lin.lora.A weight norm: 86.0185 distilbert.transformer.layer.0.attention.v_lin.lora.B weight norm: 41.9394 distilbert.transformer.layer.0.attention.out_lin.lora.A weight norm: 69.8777 distilbert.transformer.layer.0.attention.out_lin.lora.B weight norm: 32.6178 distilbert.transformer.layer.1.attention.q_lin.lora.A weight norm: 84.0461 distilbert.transformer.layer.1.attention.q_lin.lora.B weight norm: 42.5906 distilbert.transformer.layer.1.attention.k_lin.lora.A weight norm: 79.4528 distilbert.transformer.layer.1.attention.k_lin.lora.B weight norm: 37.7325 distilbert.transformer.layer.1.attention.v_lin.lora.A weight norm: 83.9680 distilbert.transformer.layer.1.attention.v_lin.lora.B weight norm: 39.0889 distilbert.transformer.layer.1.attention.out_lin.lora.A weight norm: 74.0652 distilbert.transformer.layer.1.attention.out_lin.lora.B weight norm: 40.8909 distilbert.transformer.layer.2.attention.q_lin.lora.A weight norm: 78.3896 distilbert.transformer.layer.2.attention.q_lin.lora.B weight norm: 45.2058 distilbert.transformer.layer.2.attention.k_lin.lora.A weight norm: 79.5360 distilbert.transformer.layer.2.attention.k_lin.lora.B weight norm: 47.1290 distilbert.transformer.layer.2.attention.v_lin.lora.A weight norm: 77.7262 distilbert.transformer.layer.2.attention.v_lin.lora.B weight norm: 34.1224 distilbert.transformer.layer.2.attention.out_lin.lora.A weight norm: 63.6211 distilbert.transformer.layer.2.attention.out_lin.lora.B weight norm: 31.4881 distilbert.transformer.layer.3.attention.q_lin.lora.A weight norm: 81.7695 distilbert.transformer.layer.3.attention.q_lin.lora.B weight norm: 49.9913 distilbert.transformer.layer.3.attention.k_lin.lora.A weight norm: 84.6226 distilbert.transformer.layer.3.attention.k_lin.lora.B weight norm: 50.7520 distilbert.transformer.layer.3.attention.v_lin.lora.A weight norm: 71.5488 distilbert.transformer.layer.3.attention.v_lin.lora.B weight norm: 34.2226 distilbert.transformer.layer.3.attention.out_lin.lora.A weight norm: 82.8060 distilbert.transformer.layer.3.attention.out_lin.lora.B weight norm: 36.0456 distilbert.transformer.layer.4.attention.q_lin.lora.A weight norm: 86.1940 distilbert.transformer.layer.4.attention.q_lin.lora.B weight norm: 37.5884 distilbert.transformer.layer.4.attention.k_lin.lora.A weight norm: 69.4333 distilbert.transformer.layer.4.attention.k_lin.lora.B weight norm: 43.1834 distilbert.transformer.layer.4.attention.v_lin.lora.A weight norm: 70.5987 distilbert.transformer.layer.4.attention.v_lin.lora.B weight norm: 28.1566 distilbert.transformer.layer.4.attention.out_lin.lora.A weight norm: 79.8895 distilbert.transformer.layer.4.attention.out_lin.lora.B weight norm: 29.1581 distilbert.transformer.layer.5.attention.q_lin.lora.A weight norm: 61.4309 distilbert.transformer.layer.5.attention.q_lin.lora.B weight norm: 28.0646 distilbert.transformer.layer.5.attention.k_lin.lora.A weight norm: 65.4785 distilbert.transformer.layer.5.attention.k_lin.lora.B weight norm: 40.8050 distilbert.transformer.layer.5.attention.v_lin.lora.A weight norm: 55.7645 distilbert.transformer.layer.5.attention.v_lin.lora.B weight norm: 23.9943 
distilbert.transformer.layer.5.attention.out_lin.lora.A weight norm: 66.5217 distilbert.transformer.layer.5.attention.out_lin.lora.B weight norm: 21.6479 Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.m weight norm: 24.3380 distilbert.transformer.layer.0.attention.k_lin.m weight norm: 21.5448 distilbert.transformer.layer.0.attention.v_lin.m weight norm: 16.9134 distilbert.transformer.layer.0.attention.out_lin.m weight norm: 10.2597 distilbert.transformer.layer.1.attention.q_lin.m weight norm: 18.8046 distilbert.transformer.layer.1.attention.k_lin.m weight norm: 17.2833 distilbert.transformer.layer.1.attention.v_lin.m weight norm: 10.7684 distilbert.transformer.layer.1.attention.out_lin.m weight norm: 9.8165 distilbert.transformer.layer.2.attention.q_lin.m weight norm: 16.5732 distilbert.transformer.layer.2.attention.k_lin.m weight norm: 17.3353 distilbert.transformer.layer.2.attention.v_lin.m weight norm: 11.1499 distilbert.transformer.layer.2.attention.out_lin.m weight norm: 8.2443 distilbert.transformer.layer.3.attention.q_lin.m weight norm: 19.2109 distilbert.transformer.layer.3.attention.k_lin.m weight norm: 19.1746 distilbert.transformer.layer.3.attention.v_lin.m weight norm: 11.2522 distilbert.transformer.layer.3.attention.out_lin.m weight norm: 11.3760 distilbert.transformer.layer.4.attention.q_lin.m weight norm: 20.1803 distilbert.transformer.layer.4.attention.k_lin.m weight norm: 21.7132 distilbert.transformer.layer.4.attention.v_lin.m weight norm: 9.0074 distilbert.transformer.layer.4.attention.out_lin.m weight norm: 12.0543 distilbert.transformer.layer.5.attention.q_lin.m weight norm: 16.7970 distilbert.transformer.layer.5.attention.k_lin.m weight norm: 22.2302 distilbert.transformer.layer.5.attention.v_lin.m weight norm: 10.2690 distilbert.transformer.layer.5.attention.out_lin.m weight norm: 14.5015 Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.scale weight norm: 63.7493 distilbert.transformer.layer.0.attention.k_lin.scale weight norm: 58.9528 distilbert.transformer.layer.0.attention.v_lin.scale weight norm: 55.6887 distilbert.transformer.layer.0.attention.out_lin.scale weight norm: 47.4388 distilbert.transformer.layer.1.attention.q_lin.scale weight norm: 57.3731 distilbert.transformer.layer.1.attention.k_lin.scale weight norm: 54.8432 distilbert.transformer.layer.1.attention.v_lin.scale weight norm: 47.2854 distilbert.transformer.layer.1.attention.out_lin.scale weight norm: 44.8088 distilbert.transformer.layer.2.attention.q_lin.scale weight norm: 53.9172 distilbert.transformer.layer.2.attention.k_lin.scale weight norm: 55.4214 distilbert.transformer.layer.2.attention.v_lin.scale weight norm: 48.5935 distilbert.transformer.layer.2.attention.out_lin.scale weight norm: 44.7601 distilbert.transformer.layer.3.attention.q_lin.scale weight norm: 56.9923 distilbert.transformer.layer.3.attention.k_lin.scale weight norm: 56.9395 distilbert.transformer.layer.3.attention.v_lin.scale weight norm: 48.5040 distilbert.transformer.layer.3.attention.out_lin.scale weight norm: 47.1696 distilbert.transformer.layer.4.attention.q_lin.scale weight norm: 59.5007 distilbert.transformer.layer.4.attention.k_lin.scale weight norm: 61.3537 distilbert.transformer.layer.4.attention.v_lin.scale weight norm: 45.5443 distilbert.transformer.layer.4.attention.out_lin.scale weight norm: 48.1659 distilbert.transformer.layer.5.attention.q_lin.scale weight norm: 56.5651 distilbert.transformer.layer.5.attention.k_lin.scale weight norm: 63.7824 
distilbert.transformer.layer.5.attention.v_lin.scale weight norm: 47.5776 distilbert.transformer.layer.5.attention.out_lin.scale weight norm: 53.4955
DDoRA, Double DoRA¶
(Double Weight-Decomposed Low-Rank Adaptation)¶
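The cell below wraps each frozen attention projection with two decompositions instead of DoRA's single output-side magnitude: a per-input magnitude m_in and scale scale_in applied to the activations before the LoRA branch, and a per-output magnitude m_out and scale scale_out applied to the L2-normalised LoRA output. Sketching the forward pass implemented by LinearWithDoubleDoRA (writing s_in, s_out, m_in, m_out for the learnable vectors and ε for the 1e-9 stabiliser in the code):

$$y = \mathrm{Linear}(x) + s_{\text{out}} \odot m_{\text{out}} \odot \frac{\mathrm{LoRA}\big(s_{\text{in}} \odot m_{\text{in}} \odot x\big)}{\lVert \mathrm{LoRA}\big(s_{\text{in}} \odot m_{\text{in}} \odot x\big)\rVert_2 + \varepsilon}$$

where Linear is the frozen pretrained layer and LoRA(·) is the low-rank branch from the LoRALayer defined earlier.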
In [14]:
class LinearWithDoubleDoRA(nn.Module):
    def __init__(self, linear, rank, alpha, scaling_factor=1.0):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        # Output-side and input-side magnitude vectors (the "double" decomposition)
        self.m_out = nn.Parameter(torch.randn(1, linear.out_features) * std_dev)
        self.m_in = nn.Parameter(torch.randn(linear.in_features, 1) * std_dev)
        # Learnable scales, initialised to scaling_factor
        self.scale_out = nn.Parameter(torch.full((1, linear.out_features), float(scaling_factor)))
        self.scale_in = nn.Parameter(torch.full((linear.in_features, 1), float(scaling_factor)))

    def forward(self, x):
        scaled_x = x * self.scale_in.T * self.m_in.T  # Broadcasting m_in + scaling over the feature dimension
        linear_output = self.linear(x)
        lora_output = self.lora(scaled_x)
        lora_output_norm = lora_output / (lora_output.norm(p=2, dim=1, keepdim=True) + 1e-9)
        dora_modification = self.scale_out * self.m_out * lora_output_norm
        return linear_output + dora_modification
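As a quick shape check (a minimal sketch, not part of the training pipeline; the 768 dimension matches the DistilBERT hidden size and the toy tensor is made up for illustration):

toy = nn.Linear(768, 768)
wrapped = LinearWithDoubleDoRA(toy, rank=16, alpha=32, scaling_factor=2.0)
x = torch.randn(4, 384, 768)   # (batch, seq_len, hidden)
print(wrapped(x).shape)        # torch.Size([4, 384, 768]) -- same shape as the plain linear layer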
In [15]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
def inject_ddora_all_attn(model, rank, alpha, scaling_factor=1.0):
    #target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin", "ffn.lin1", "ffn.lin2"]
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            dora_linear = LinearWithDoubleDoRA(original_linear, rank, alpha, scaling_factor)
            setattr(parent_module, name.split('.')[-1], dora_linear)
    return model
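Once the wrapped model exists (it is created in the next cell as model_ddora_all_attn), a small sanity-check sketch confirms that exactly the four attention projections in each of the six transformer blocks were swapped out:

wrapped = [n for n, m in model_ddora_all_attn.named_modules()
           if isinstance(m, LinearWithDoubleDoRA)]
print(len(wrapped))    # expect 24: 4 attention projections x 6 transformer blocks
print(wrapped[:4])     # e.g. the q/k/v/out projections of layer 0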
In [16]:
import seaborn as sns
import torch
import numpy as np
import matplotlib.pyplot as plt
import copy
torch.manual_seed(137)
lora_rank = 16
lora_alpha = 32
batch_size = 32
learning_rate = 0.015  # much higher than the 1e-5 used for full fine-tuning below
weight_decay = 1e-4
scaling_factor = 2.0
output_dir_prefix = "finetuned-imdb-"
model_ddora_all_attn = copy.deepcopy(model)
model_ddora_all_attn = inject_ddora_all_attn(model_ddora_all_attn, lora_rank, lora_alpha, scaling_factor)
freeze_model_layers(model_ddora_all_attn, unfreeze_pre_classifier=True)
total_params_ddora, trainable_params_ddora, percentage_ddora = count_trainable_parameters(model_ddora_all_attn)
print(f"\nDDoRA (All Attention) - Total parameters: {total_params_ddora:,}")
print(f"DDoRA (All Attention) - Trainable parameters: {trainable_params_ddora:,} ({percentage_ddora:.2f}%)")
eval_steps = 50
logging_steps = 50
output_dir_prefix = "finetuned-imdb-"
training_args_ddora_all_attn = TrainingArguments(
    output_dir=f"{output_dir_prefix}ddora-all-attn",
    num_train_epochs=5,
    #max_steps=100,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    logging_steps=logging_steps,
    save_steps=eval_steps,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    disable_tqdm=False,
    push_to_hub=False,
    report_to="none",
    log_level="error"
)
trainer_ddora_all_attn = Trainer(model=model_ddora_all_attn, args=training_args_ddora_all_attn,
                                 train_dataset=dataset_encoded["train"], eval_dataset=dataset_encoded["validation"], compute_metrics=compute_metrics)
DDoRA (All Attention) - Total parameters: 67,618,562
DDoRA (All Attention) - Trainable parameters: 1,255,682 (1.86%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
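The FutureWarning above comes from the evaluation_strategy argument; on transformers releases where that keyword has been removed, the same configuration would be spelled with eval_strategy (a sketch of the changed line only, everything else unchanged):

training_args_ddora_all_attn = TrainingArguments(
    output_dir=f"{output_dir_prefix}ddora-all-attn",
    eval_strategy="steps",   # replaces evaluation_strategy="steps"
    # ... remaining arguments as in the cell above
)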
In [17]:
# Sanity check
print("\nTrainable parameters after freezing:")
for name, param in model_ddora_all_attn.named_parameters():
    if param.requires_grad:
        print(name)
Trainable parameters after freezing: distilbert.transformer.layer.0.attention.q_lin.m_out distilbert.transformer.layer.0.attention.q_lin.m_in distilbert.transformer.layer.0.attention.q_lin.scale_out distilbert.transformer.layer.0.attention.q_lin.scale_in distilbert.transformer.layer.0.attention.q_lin.lora.A distilbert.transformer.layer.0.attention.q_lin.lora.B distilbert.transformer.layer.0.attention.k_lin.m_out distilbert.transformer.layer.0.attention.k_lin.m_in distilbert.transformer.layer.0.attention.k_lin.scale_out distilbert.transformer.layer.0.attention.k_lin.scale_in distilbert.transformer.layer.0.attention.k_lin.lora.A distilbert.transformer.layer.0.attention.k_lin.lora.B distilbert.transformer.layer.0.attention.v_lin.m_out distilbert.transformer.layer.0.attention.v_lin.m_in distilbert.transformer.layer.0.attention.v_lin.scale_out distilbert.transformer.layer.0.attention.v_lin.scale_in distilbert.transformer.layer.0.attention.v_lin.lora.A distilbert.transformer.layer.0.attention.v_lin.lora.B distilbert.transformer.layer.0.attention.out_lin.m_out distilbert.transformer.layer.0.attention.out_lin.m_in distilbert.transformer.layer.0.attention.out_lin.scale_out distilbert.transformer.layer.0.attention.out_lin.scale_in distilbert.transformer.layer.0.attention.out_lin.lora.A distilbert.transformer.layer.0.attention.out_lin.lora.B distilbert.transformer.layer.1.attention.q_lin.m_out distilbert.transformer.layer.1.attention.q_lin.m_in distilbert.transformer.layer.1.attention.q_lin.scale_out distilbert.transformer.layer.1.attention.q_lin.scale_in distilbert.transformer.layer.1.attention.q_lin.lora.A distilbert.transformer.layer.1.attention.q_lin.lora.B distilbert.transformer.layer.1.attention.k_lin.m_out distilbert.transformer.layer.1.attention.k_lin.m_in distilbert.transformer.layer.1.attention.k_lin.scale_out distilbert.transformer.layer.1.attention.k_lin.scale_in distilbert.transformer.layer.1.attention.k_lin.lora.A distilbert.transformer.layer.1.attention.k_lin.lora.B distilbert.transformer.layer.1.attention.v_lin.m_out distilbert.transformer.layer.1.attention.v_lin.m_in distilbert.transformer.layer.1.attention.v_lin.scale_out distilbert.transformer.layer.1.attention.v_lin.scale_in distilbert.transformer.layer.1.attention.v_lin.lora.A distilbert.transformer.layer.1.attention.v_lin.lora.B distilbert.transformer.layer.1.attention.out_lin.m_out distilbert.transformer.layer.1.attention.out_lin.m_in distilbert.transformer.layer.1.attention.out_lin.scale_out distilbert.transformer.layer.1.attention.out_lin.scale_in distilbert.transformer.layer.1.attention.out_lin.lora.A distilbert.transformer.layer.1.attention.out_lin.lora.B distilbert.transformer.layer.2.attention.q_lin.m_out distilbert.transformer.layer.2.attention.q_lin.m_in distilbert.transformer.layer.2.attention.q_lin.scale_out distilbert.transformer.layer.2.attention.q_lin.scale_in distilbert.transformer.layer.2.attention.q_lin.lora.A distilbert.transformer.layer.2.attention.q_lin.lora.B distilbert.transformer.layer.2.attention.k_lin.m_out distilbert.transformer.layer.2.attention.k_lin.m_in distilbert.transformer.layer.2.attention.k_lin.scale_out distilbert.transformer.layer.2.attention.k_lin.scale_in distilbert.transformer.layer.2.attention.k_lin.lora.A distilbert.transformer.layer.2.attention.k_lin.lora.B distilbert.transformer.layer.2.attention.v_lin.m_out distilbert.transformer.layer.2.attention.v_lin.m_in distilbert.transformer.layer.2.attention.v_lin.scale_out distilbert.transformer.layer.2.attention.v_lin.scale_in 
distilbert.transformer.layer.2.attention.v_lin.lora.A distilbert.transformer.layer.2.attention.v_lin.lora.B distilbert.transformer.layer.2.attention.out_lin.m_out distilbert.transformer.layer.2.attention.out_lin.m_in distilbert.transformer.layer.2.attention.out_lin.scale_out distilbert.transformer.layer.2.attention.out_lin.scale_in distilbert.transformer.layer.2.attention.out_lin.lora.A distilbert.transformer.layer.2.attention.out_lin.lora.B distilbert.transformer.layer.3.attention.q_lin.m_out distilbert.transformer.layer.3.attention.q_lin.m_in distilbert.transformer.layer.3.attention.q_lin.scale_out distilbert.transformer.layer.3.attention.q_lin.scale_in distilbert.transformer.layer.3.attention.q_lin.lora.A distilbert.transformer.layer.3.attention.q_lin.lora.B distilbert.transformer.layer.3.attention.k_lin.m_out distilbert.transformer.layer.3.attention.k_lin.m_in distilbert.transformer.layer.3.attention.k_lin.scale_out distilbert.transformer.layer.3.attention.k_lin.scale_in distilbert.transformer.layer.3.attention.k_lin.lora.A distilbert.transformer.layer.3.attention.k_lin.lora.B distilbert.transformer.layer.3.attention.v_lin.m_out distilbert.transformer.layer.3.attention.v_lin.m_in distilbert.transformer.layer.3.attention.v_lin.scale_out distilbert.transformer.layer.3.attention.v_lin.scale_in distilbert.transformer.layer.3.attention.v_lin.lora.A distilbert.transformer.layer.3.attention.v_lin.lora.B distilbert.transformer.layer.3.attention.out_lin.m_out distilbert.transformer.layer.3.attention.out_lin.m_in distilbert.transformer.layer.3.attention.out_lin.scale_out distilbert.transformer.layer.3.attention.out_lin.scale_in distilbert.transformer.layer.3.attention.out_lin.lora.A distilbert.transformer.layer.3.attention.out_lin.lora.B distilbert.transformer.layer.4.attention.q_lin.m_out distilbert.transformer.layer.4.attention.q_lin.m_in distilbert.transformer.layer.4.attention.q_lin.scale_out distilbert.transformer.layer.4.attention.q_lin.scale_in distilbert.transformer.layer.4.attention.q_lin.lora.A distilbert.transformer.layer.4.attention.q_lin.lora.B distilbert.transformer.layer.4.attention.k_lin.m_out distilbert.transformer.layer.4.attention.k_lin.m_in distilbert.transformer.layer.4.attention.k_lin.scale_out distilbert.transformer.layer.4.attention.k_lin.scale_in distilbert.transformer.layer.4.attention.k_lin.lora.A distilbert.transformer.layer.4.attention.k_lin.lora.B distilbert.transformer.layer.4.attention.v_lin.m_out distilbert.transformer.layer.4.attention.v_lin.m_in distilbert.transformer.layer.4.attention.v_lin.scale_out distilbert.transformer.layer.4.attention.v_lin.scale_in distilbert.transformer.layer.4.attention.v_lin.lora.A distilbert.transformer.layer.4.attention.v_lin.lora.B distilbert.transformer.layer.4.attention.out_lin.m_out distilbert.transformer.layer.4.attention.out_lin.m_in distilbert.transformer.layer.4.attention.out_lin.scale_out distilbert.transformer.layer.4.attention.out_lin.scale_in distilbert.transformer.layer.4.attention.out_lin.lora.A distilbert.transformer.layer.4.attention.out_lin.lora.B distilbert.transformer.layer.5.attention.q_lin.m_out distilbert.transformer.layer.5.attention.q_lin.m_in distilbert.transformer.layer.5.attention.q_lin.scale_out distilbert.transformer.layer.5.attention.q_lin.scale_in distilbert.transformer.layer.5.attention.q_lin.lora.A distilbert.transformer.layer.5.attention.q_lin.lora.B distilbert.transformer.layer.5.attention.k_lin.m_out distilbert.transformer.layer.5.attention.k_lin.m_in 
distilbert.transformer.layer.5.attention.k_lin.scale_out distilbert.transformer.layer.5.attention.k_lin.scale_in distilbert.transformer.layer.5.attention.k_lin.lora.A distilbert.transformer.layer.5.attention.k_lin.lora.B distilbert.transformer.layer.5.attention.v_lin.m_out distilbert.transformer.layer.5.attention.v_lin.m_in distilbert.transformer.layer.5.attention.v_lin.scale_out distilbert.transformer.layer.5.attention.v_lin.scale_in distilbert.transformer.layer.5.attention.v_lin.lora.A distilbert.transformer.layer.5.attention.v_lin.lora.B distilbert.transformer.layer.5.attention.out_lin.m_out distilbert.transformer.layer.5.attention.out_lin.m_in distilbert.transformer.layer.5.attention.out_lin.scale_out distilbert.transformer.layer.5.attention.out_lin.scale_in distilbert.transformer.layer.5.attention.out_lin.lora.A distilbert.transformer.layer.5.attention.out_lin.lora.B pre_classifier.weight pre_classifier.bias classifier.weight classifier.bias
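The 1,255,682 trainable parameters reported above can be cross-checked directly (a sketch; the breakdown in the comments is simple arithmetic over the shapes in the printed module list):

trainable = sum(p.numel() for p in model_ddora_all_attn.parameters() if p.requires_grad)
print(f"{trainable:,}")  # 1,255,682
# Per wrapped 768x768 projection: lora.A (16*768) + lora.B (768*16) = 24,576
#   plus m_in, m_out, scale_in, scale_out (4 * 768) = 3,072  -> 27,648
# 27,648 * 24 wrapped projections = 663,552
# pre_classifier (768*768 + 768) + classifier (768*2 + 2) = 592,130
# 663,552 + 592,130 = 1,255,682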
In [18]:
trainer_ddora_all_attn.train()
eval_results_ddora_all_attn = trainer_ddora_all_attn.evaluate(dataset_encoded["test"])
print(f"DDoRA (All Attention) Test Results: {eval_results_ddora_all_attn}")
print('Prediction drift')
outputs = trainer_ddora_all_attn.predict(dataset_encoded["validation"])
print(outputs.predictions[:5], outputs.label_ids[:5])
print('LoRA Heatmap')
layer_names = []
b_norms = []
for name, param in model_ddora_all_attn.named_parameters():
    if "lora.B" in name:
        layer_names.append(name)
        b_norms.append(param.norm().item())
sns.barplot(x=b_norms, y=layer_names, color='navy')
plt.xlabel("Weight Norm")
plt.title("LoRA B Norms by Layer")
plt.tight_layout()
plt.show()
import torch
import numpy as np
import matplotlib.pyplot as plt
m_in_values = []
m_out_values = []
for name, module in model_ddora_all_attn.named_modules():
    if isinstance(module, LinearWithDoubleDoRA):
        if hasattr(module, 'm_out'):
            m_out_param = module.m_out.detach().cpu().numpy().flatten()
            m_out_values.extend(m_out_param)
        if hasattr(module, 'm_in'):
            m_in_param = module.m_in.detach().cpu().numpy().flatten()
            m_in_values.extend(m_in_param)
# Convert to numpy arrays
m_in_values = np.array(m_in_values)
m_out_values = np.array(m_out_values)
# Summary stats
print(f"[m_out] Mean: {np.mean(m_out_values):.4f}, Std: {np.std(m_out_values):.4f}, Min: {np.min(m_out_values):.4f}, Max: {np.max(m_out_values):.4f}")
print(f"[m_in ] Mean: {np.mean(m_in_values):.4f}, Std: {np.std(m_in_values):.4f}, Min: {np.min(m_in_values):.4f}, Max: {np.max(m_in_values):.4f}")
# Plot histograms
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.hist(m_out_values, bins=50, alpha=0.7, color='olive')
plt.title('Distribution of Learned m_out Values (DDoRA)')
plt.xlabel('Magnitude (m_out)')
plt.ylabel('Frequency')
plt.subplot(1, 2, 2)
plt.hist(m_in_values, bins=50, alpha=0.7, color='navy')
plt.title('Distribution of Learned m_in Values (DDoRA)')
plt.xlabel('Magnitude (m_in)')
plt.ylabel('Frequency')
plt.tight_layout()
plt.show()
[3910/3910 2:50:30, Epoch 5/5]
Step | Training Loss | Validation Loss | Accuracy | F1 |
---|---|---|---|---|
50 | 0.770700 | 0.286238 | 0.872000 | 0.871847 |
100 | 0.349700 | 0.288429 | 0.880000 | 0.879626 |
150 | 0.364200 | 0.300074 | 0.872800 | 0.873083 |
200 | 0.288800 | 0.273452 | 0.883200 | 0.883296 |
250 | 0.299600 | 0.251706 | 0.898400 | 0.897942 |
300 | 0.296600 | 0.246586 | 0.898400 | 0.898229 |
350 | 0.294200 | 0.256276 | 0.902400 | 0.902205 |
400 | 0.267100 | 0.248215 | 0.896800 | 0.896852 |
450 | 0.259600 | 0.261862 | 0.902400 | 0.902072 |
500 | 0.286100 | 0.339048 | 0.903200 | 0.903249 |
550 | 0.280100 | 0.271318 | 0.905600 | 0.905391 |
600 | 0.296700 | 0.254167 | 0.913600 | 0.913544 |
650 | 0.248500 | 0.248038 | 0.909600 | 0.909448 |
700 | 0.272000 | 0.240011 | 0.910400 | 0.910121 |
750 | 0.256800 | 0.272073 | 0.900800 | 0.900581 |
800 | 0.244800 | 0.336190 | 0.899200 | 0.898152 |
850 | 0.232100 | 0.244281 | 0.904000 | 0.903788 |
900 | 0.217700 | 0.275926 | 0.913600 | 0.913151 |
950 | 0.255400 | 0.276984 | 0.910400 | 0.910473 |
1000 | 0.223600 | 0.227863 | 0.909600 | 0.909143 |
1050 | 0.223100 | 0.244880 | 0.909600 | 0.909593 |
1100 | 0.229500 | 0.250599 | 0.912800 | 0.912539 |
1150 | 0.203900 | 0.231664 | 0.911200 | 0.911013 |
1200 | 0.240400 | 0.231101 | 0.911200 | 0.911179 |
1250 | 0.246100 | 0.229517 | 0.914400 | 0.914454 |
1300 | 0.248700 | 0.231470 | 0.912800 | 0.912765 |
1350 | 0.243300 | 0.221713 | 0.910400 | 0.910400 |
1400 | 0.236300 | 0.219190 | 0.912800 | 0.912704 |
1450 | 0.219200 | 0.231880 | 0.913600 | 0.913625 |
1500 | 0.212400 | 0.220595 | 0.912800 | 0.912559 |
1550 | 0.212400 | 0.239685 | 0.913600 | 0.913613 |
1600 | 0.185500 | 0.284434 | 0.899200 | 0.899420 |
1650 | 0.180000 | 0.298952 | 0.914400 | 0.914183 |
1700 | 0.200800 | 0.263503 | 0.914400 | 0.914443 |
1750 | 0.197100 | 0.274289 | 0.912000 | 0.912083 |
1800 | 0.204200 | 0.241341 | 0.904800 | 0.904905 |
1850 | 0.201200 | 0.239144 | 0.911200 | 0.911220 |
1900 | 0.187300 | 0.246986 | 0.915200 | 0.915130 |
1950 | 0.199200 | 0.275895 | 0.904800 | 0.905001 |
2000 | 0.202000 | 0.238767 | 0.917600 | 0.917682 |
2050 | 0.197800 | 0.252750 | 0.917600 | 0.917272 |
2100 | 0.199800 | 0.244306 | 0.915200 | 0.915099 |
2150 | 0.151500 | 0.297675 | 0.920000 | 0.919948 |
2200 | 0.204700 | 0.226266 | 0.913600 | 0.913681 |
2250 | 0.206400 | 0.223971 | 0.920000 | 0.919873 |
2300 | 0.191300 | 0.228316 | 0.919200 | 0.919206 |
2350 | 0.177600 | 0.251830 | 0.921600 | 0.921337 |
2400 | 0.188400 | 0.279854 | 0.918400 | 0.918220 |
2450 | 0.149600 | 0.237880 | 0.915200 | 0.915082 |
2500 | 0.148100 | 0.255322 | 0.910400 | 0.910462 |
2550 | 0.144100 | 0.284158 | 0.918400 | 0.918184 |
2600 | 0.167500 | 0.257225 | 0.908000 | 0.908138 |
2650 | 0.161400 | 0.226382 | 0.917600 | 0.917594 |
2700 | 0.157400 | 0.233534 | 0.919200 | 0.919229 |
2750 | 0.163500 | 0.243015 | 0.921600 | 0.921410 |
2800 | 0.121700 | 0.245816 | 0.920800 | 0.920682 |
2850 | 0.141500 | 0.285453 | 0.922400 | 0.922071 |
2900 | 0.163300 | 0.232592 | 0.921600 | 0.921549 |
2950 | 0.155700 | 0.241530 | 0.916800 | 0.916616 |
3000 | 0.150700 | 0.241032 | 0.916800 | 0.916760 |
3050 | 0.148100 | 0.261533 | 0.916000 | 0.915994 |
3100 | 0.146800 | 0.262320 | 0.912000 | 0.911413 |
3150 | 0.135700 | 0.245895 | 0.920000 | 0.919975 |
3200 | 0.129700 | 0.245331 | 0.921600 | 0.921476 |
3250 | 0.112900 | 0.275194 | 0.917600 | 0.917581 |
3300 | 0.124800 | 0.264425 | 0.923200 | 0.923013 |
3350 | 0.115600 | 0.292223 | 0.924000 | 0.923887 |
3400 | 0.105400 | 0.271411 | 0.917600 | 0.917427 |
3450 | 0.131200 | 0.260183 | 0.920000 | 0.919934 |
3500 | 0.109600 | 0.265847 | 0.922400 | 0.922314 |
3550 | 0.116700 | 0.269032 | 0.920800 | 0.920682 |
3600 | 0.133200 | 0.260401 | 0.924000 | 0.923916 |
3650 | 0.133300 | 0.267982 | 0.924000 | 0.923840 |
3700 | 0.124600 | 0.265107 | 0.919200 | 0.919194 |
3750 | 0.104300 | 0.276479 | 0.920800 | 0.920769 |
3800 | 0.106300 | 0.271788 | 0.921600 | 0.921562 |
3850 | 0.109200 | 0.271291 | 0.922400 | 0.922343 |
3900 | 0.110700 | 0.268460 | 0.920800 | 0.920755 |
DDoRA (All Attention) Test Results: {'eval_loss': 0.20760443806648254, 'eval_accuracy': 0.9165473684210527, 'eval_f1': 0.9165350069271784, 'eval_runtime': 155.718, 'eval_samples_per_second': 152.519, 'eval_steps_per_second': 4.771, 'epoch': 5.0}
Prediction drift
[[-2.5889697 1.7642486] [-3.9092817 2.3953364] [ 2.2905545 -2.4951563] [-2.053718 1.375324 ] [ 5.5414968 -5.709928 ]] [1 1 0 1 0]
LoRA Heatmap
[m_out] Mean: -0.0008, Std: 0.5493, Min: -2.8318, Max: 3.1991
[m_in ] Mean: 0.0086, Std: 0.4519, Min: -2.4116, Max: 2.3131
In [19]:
print('Parameter Statistics: mean.abs()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.scale" in name:
        print(name, param.abs().mean().item())
print('Parameter Statistics: mean.abs()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.m" in name:
        print(name, param.abs().mean().item())
print('Parameter Statistics: mean.abs()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lora" in name:
        print(name, param.abs().mean().item())
Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.scale_out 2.0901947021484375 distilbert.transformer.layer.0.attention.q_lin.scale_in 1.9317268133163452 distilbert.transformer.layer.0.attention.k_lin.scale_out 2.0788636207580566 distilbert.transformer.layer.0.attention.k_lin.scale_in 1.8939082622528076 distilbert.transformer.layer.0.attention.v_lin.scale_out 2.0176312923431396 distilbert.transformer.layer.0.attention.v_lin.scale_in 1.6768195629119873 distilbert.transformer.layer.0.attention.out_lin.scale_out 1.7772330045700073 distilbert.transformer.layer.0.attention.out_lin.scale_in 1.8626189231872559 distilbert.transformer.layer.1.attention.q_lin.scale_out 2.018207311630249 distilbert.transformer.layer.1.attention.q_lin.scale_in 1.8735299110412598 distilbert.transformer.layer.1.attention.k_lin.scale_out 1.9069545269012451 distilbert.transformer.layer.1.attention.k_lin.scale_in 1.9223332405090332 distilbert.transformer.layer.1.attention.v_lin.scale_out 1.74566650390625 distilbert.transformer.layer.1.attention.v_lin.scale_in 1.838054895401001 distilbert.transformer.layer.1.attention.out_lin.scale_out 1.740161418914795 distilbert.transformer.layer.1.attention.out_lin.scale_in 1.8335179090499878 distilbert.transformer.layer.2.attention.q_lin.scale_out 1.9253816604614258 distilbert.transformer.layer.2.attention.q_lin.scale_in 1.9024275541305542 distilbert.transformer.layer.2.attention.k_lin.scale_out 1.9852197170257568 distilbert.transformer.layer.2.attention.k_lin.scale_in 1.8761833906173706 distilbert.transformer.layer.2.attention.v_lin.scale_out 1.796034574508667 distilbert.transformer.layer.2.attention.v_lin.scale_in 1.8118958473205566 distilbert.transformer.layer.2.attention.out_lin.scale_out 1.7673529386520386 distilbert.transformer.layer.2.attention.out_lin.scale_in 1.7620668411254883 distilbert.transformer.layer.3.attention.q_lin.scale_out 1.9725602865219116 distilbert.transformer.layer.3.attention.q_lin.scale_in 1.912266492843628 distilbert.transformer.layer.3.attention.k_lin.scale_out 1.9736626148223877 distilbert.transformer.layer.3.attention.k_lin.scale_in 1.8784260749816895 distilbert.transformer.layer.3.attention.v_lin.scale_out 1.7679297924041748 distilbert.transformer.layer.3.attention.v_lin.scale_in 1.8323858976364136 distilbert.transformer.layer.3.attention.out_lin.scale_out 1.8022470474243164 distilbert.transformer.layer.3.attention.out_lin.scale_in 1.8460712432861328 distilbert.transformer.layer.4.attention.q_lin.scale_out 1.9774129390716553 distilbert.transformer.layer.4.attention.q_lin.scale_in 1.8028367757797241 distilbert.transformer.layer.4.attention.k_lin.scale_out 2.0612895488739014 distilbert.transformer.layer.4.attention.k_lin.scale_in 1.8295701742172241 distilbert.transformer.layer.4.attention.v_lin.scale_out 1.719773292541504 distilbert.transformer.layer.4.attention.v_lin.scale_in 1.7408709526062012 distilbert.transformer.layer.4.attention.out_lin.scale_out 1.7295746803283691 distilbert.transformer.layer.4.attention.out_lin.scale_in 1.7716983556747437 distilbert.transformer.layer.5.attention.q_lin.scale_out 1.9752777814865112 distilbert.transformer.layer.5.attention.q_lin.scale_in 1.7774198055267334 distilbert.transformer.layer.5.attention.k_lin.scale_out 2.129931926727295 distilbert.transformer.layer.5.attention.k_lin.scale_in 1.8098797798156738 distilbert.transformer.layer.5.attention.v_lin.scale_out 1.919114112854004 distilbert.transformer.layer.5.attention.v_lin.scale_in 1.6814241409301758 
distilbert.transformer.layer.5.attention.out_lin.scale_out 1.7194103002548218 distilbert.transformer.layer.5.attention.out_lin.scale_in 1.7984395027160645 Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.m_out 0.5647355318069458 distilbert.transformer.layer.0.attention.q_lin.m_in 0.3798569142818451 distilbert.transformer.layer.0.attention.k_lin.m_out 0.5425061583518982 distilbert.transformer.layer.0.attention.k_lin.m_in 0.37627243995666504 distilbert.transformer.layer.0.attention.v_lin.m_out 0.480976939201355 distilbert.transformer.layer.0.attention.v_lin.m_in 0.26557350158691406 distilbert.transformer.layer.0.attention.out_lin.m_out 0.3110979497432709 distilbert.transformer.layer.0.attention.out_lin.m_in 0.29723748564720154 distilbert.transformer.layer.1.attention.q_lin.m_out 0.4283873736858368 distilbert.transformer.layer.1.attention.q_lin.m_in 0.3325177729129791 distilbert.transformer.layer.1.attention.k_lin.m_out 0.40755534172058105 distilbert.transformer.layer.1.attention.k_lin.m_in 0.34619420766830444 distilbert.transformer.layer.1.attention.v_lin.m_out 0.2836112380027771 distilbert.transformer.layer.1.attention.v_lin.m_in 0.29925063252449036 distilbert.transformer.layer.1.attention.out_lin.m_out 0.3088749349117279 distilbert.transformer.layer.1.attention.out_lin.m_in 0.27071183919906616 distilbert.transformer.layer.2.attention.q_lin.m_out 0.4211108684539795 distilbert.transformer.layer.2.attention.q_lin.m_in 0.30545711517333984 distilbert.transformer.layer.2.attention.k_lin.m_out 0.44150853157043457 distilbert.transformer.layer.2.attention.k_lin.m_in 0.3178306221961975 distilbert.transformer.layer.2.attention.v_lin.m_out 0.34029704332351685 distilbert.transformer.layer.2.attention.v_lin.m_in 0.24604985117912292 distilbert.transformer.layer.2.attention.out_lin.m_out 0.3179643154144287 distilbert.transformer.layer.2.attention.out_lin.m_in 0.2530387341976166 distilbert.transformer.layer.3.attention.q_lin.m_out 0.4518465995788574 distilbert.transformer.layer.3.attention.q_lin.m_in 0.3557422161102295 distilbert.transformer.layer.3.attention.k_lin.m_out 0.47002920508384705 distilbert.transformer.layer.3.attention.k_lin.m_in 0.35762590169906616 distilbert.transformer.layer.3.attention.v_lin.m_out 0.32528677582740784 distilbert.transformer.layer.3.attention.v_lin.m_in 0.2892145812511444 distilbert.transformer.layer.3.attention.out_lin.m_out 0.33790677785873413 distilbert.transformer.layer.3.attention.out_lin.m_in 0.28367918729782104 distilbert.transformer.layer.4.attention.q_lin.m_out 0.47639477252960205 distilbert.transformer.layer.4.attention.q_lin.m_in 0.29029154777526855 distilbert.transformer.layer.4.attention.k_lin.m_out 0.5150153636932373 distilbert.transformer.layer.4.attention.k_lin.m_in 0.3103344440460205 distilbert.transformer.layer.4.attention.v_lin.m_out 0.32499608397483826 distilbert.transformer.layer.4.attention.v_lin.m_in 0.2510720491409302 distilbert.transformer.layer.4.attention.out_lin.m_out 0.3214051127433777 distilbert.transformer.layer.4.attention.out_lin.m_in 0.2650536000728607 distilbert.transformer.layer.5.attention.q_lin.m_out 0.43587324023246765 distilbert.transformer.layer.5.attention.q_lin.m_in 0.22241328656673431 distilbert.transformer.layer.5.attention.k_lin.m_out 0.5570618510246277 distilbert.transformer.layer.5.attention.k_lin.m_in 0.2527024447917938 distilbert.transformer.layer.5.attention.v_lin.m_out 0.37015169858932495 distilbert.transformer.layer.5.attention.v_lin.m_in 0.20549951493740082 
distilbert.transformer.layer.5.attention.out_lin.m_out 0.31466561555862427 distilbert.transformer.layer.5.attention.out_lin.m_in 0.27215391397476196 Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.lora.A 0.39448559284210205 distilbert.transformer.layer.0.attention.q_lin.lora.B 0.29920345544815063 distilbert.transformer.layer.0.attention.k_lin.lora.A 0.388362854719162 distilbert.transformer.layer.0.attention.k_lin.lora.B 0.291858047246933 distilbert.transformer.layer.0.attention.v_lin.lora.A 0.3334265351295471 distilbert.transformer.layer.0.attention.v_lin.lora.B 0.26573705673217773 distilbert.transformer.layer.0.attention.out_lin.lora.A 0.33093491196632385 distilbert.transformer.layer.0.attention.out_lin.lora.B 0.2437349259853363 distilbert.transformer.layer.1.attention.q_lin.lora.A 0.3549439609050751 distilbert.transformer.layer.1.attention.q_lin.lora.B 0.23443476855754852 distilbert.transformer.layer.1.attention.k_lin.lora.A 0.3481384217739105 distilbert.transformer.layer.1.attention.k_lin.lora.B 0.27366185188293457 distilbert.transformer.layer.1.attention.v_lin.lora.A 0.3255350589752197 distilbert.transformer.layer.1.attention.v_lin.lora.B 0.1981675922870636 distilbert.transformer.layer.1.attention.out_lin.lora.A 0.31510788202285767 distilbert.transformer.layer.1.attention.out_lin.lora.B 0.24702942371368408 distilbert.transformer.layer.2.attention.q_lin.lora.A 0.32814159989356995 distilbert.transformer.layer.2.attention.q_lin.lora.B 0.2725476622581482 distilbert.transformer.layer.2.attention.k_lin.lora.A 0.33942073583602905 distilbert.transformer.layer.2.attention.k_lin.lora.B 0.2639490067958832 distilbert.transformer.layer.2.attention.v_lin.lora.A 0.28993457555770874 distilbert.transformer.layer.2.attention.v_lin.lora.B 0.19012793898582458 distilbert.transformer.layer.2.attention.out_lin.lora.A 0.30221056938171387 distilbert.transformer.layer.2.attention.out_lin.lora.B 0.23099581897258759 distilbert.transformer.layer.3.attention.q_lin.lora.A 0.3653438091278076 distilbert.transformer.layer.3.attention.q_lin.lora.B 0.2912411689758301 distilbert.transformer.layer.3.attention.k_lin.lora.A 0.37983566522598267 distilbert.transformer.layer.3.attention.k_lin.lora.B 0.2919533848762512 distilbert.transformer.layer.3.attention.v_lin.lora.A 0.3182007968425751 distilbert.transformer.layer.3.attention.v_lin.lora.B 0.22263233363628387 distilbert.transformer.layer.3.attention.out_lin.lora.A 0.3163568377494812 distilbert.transformer.layer.3.attention.out_lin.lora.B 0.22031466662883759 distilbert.transformer.layer.4.attention.q_lin.lora.A 0.32468748092651367 distilbert.transformer.layer.4.attention.q_lin.lora.B 0.25622785091400146 distilbert.transformer.layer.4.attention.k_lin.lora.A 0.3412896394729614 distilbert.transformer.layer.4.attention.k_lin.lora.B 0.2617414891719818 distilbert.transformer.layer.4.attention.v_lin.lora.A 0.3016916513442993 distilbert.transformer.layer.4.attention.v_lin.lora.B 0.18387971818447113 distilbert.transformer.layer.4.attention.out_lin.lora.A 0.31307727098464966 distilbert.transformer.layer.4.attention.out_lin.lora.B 0.20827889442443848 distilbert.transformer.layer.5.attention.q_lin.lora.A 0.2830806076526642 distilbert.transformer.layer.5.attention.q_lin.lora.B 0.23386383056640625 distilbert.transformer.layer.5.attention.k_lin.lora.A 0.3035750091075897 distilbert.transformer.layer.5.attention.k_lin.lora.B 0.2312200516462326 distilbert.transformer.layer.5.attention.v_lin.lora.A 0.28683286905288696 
distilbert.transformer.layer.5.attention.v_lin.lora.B 0.16982325911521912 distilbert.transformer.layer.5.attention.out_lin.lora.A 0.3157418370246887 distilbert.transformer.layer.5.attention.out_lin.lora.B 0.2109360694885254
In [20]:
print('Parameter Statistics: param.norm()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lora" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
print('Parameter Statistics: param.norm()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.scale" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
print('Parameter Statistics: param.norm()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.m" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.lora.A weight norm: 57.6984 distilbert.transformer.layer.0.attention.q_lin.lora.B weight norm: 42.6242 distilbert.transformer.layer.0.attention.k_lin.lora.A weight norm: 57.0737 distilbert.transformer.layer.0.attention.k_lin.lora.B weight norm: 41.5720 distilbert.transformer.layer.0.attention.v_lin.lora.A weight norm: 50.5210 distilbert.transformer.layer.0.attention.v_lin.lora.B weight norm: 38.4451 distilbert.transformer.layer.0.attention.out_lin.lora.A weight norm: 49.7002 distilbert.transformer.layer.0.attention.out_lin.lora.B weight norm: 35.1399 distilbert.transformer.layer.1.attention.q_lin.lora.A weight norm: 52.7742 distilbert.transformer.layer.1.attention.q_lin.lora.B weight norm: 34.2313 distilbert.transformer.layer.1.attention.k_lin.lora.A weight norm: 50.7043 distilbert.transformer.layer.1.attention.k_lin.lora.B weight norm: 38.7794 distilbert.transformer.layer.1.attention.v_lin.lora.A weight norm: 48.0156 distilbert.transformer.layer.1.attention.v_lin.lora.B weight norm: 28.3948 distilbert.transformer.layer.1.attention.out_lin.lora.A weight norm: 47.3806 distilbert.transformer.layer.1.attention.out_lin.lora.B weight norm: 35.4175 distilbert.transformer.layer.2.attention.q_lin.lora.A weight norm: 48.9293 distilbert.transformer.layer.2.attention.q_lin.lora.B weight norm: 38.7429 distilbert.transformer.layer.2.attention.k_lin.lora.A weight norm: 49.8689 distilbert.transformer.layer.2.attention.k_lin.lora.B weight norm: 37.3596 distilbert.transformer.layer.2.attention.v_lin.lora.A weight norm: 43.0877 distilbert.transformer.layer.2.attention.v_lin.lora.B weight norm: 27.5580 distilbert.transformer.layer.2.attention.out_lin.lora.A weight norm: 45.4100 distilbert.transformer.layer.2.attention.out_lin.lora.B weight norm: 32.9905 distilbert.transformer.layer.3.attention.q_lin.lora.A weight norm: 53.3644 distilbert.transformer.layer.3.attention.q_lin.lora.B weight norm: 41.3374 distilbert.transformer.layer.3.attention.k_lin.lora.A weight norm: 56.0531 distilbert.transformer.layer.3.attention.k_lin.lora.B weight norm: 41.8924 distilbert.transformer.layer.3.attention.v_lin.lora.A weight norm: 47.7002 distilbert.transformer.layer.3.attention.v_lin.lora.B weight norm: 31.9272 distilbert.transformer.layer.3.attention.out_lin.lora.A weight norm: 47.3052 distilbert.transformer.layer.3.attention.out_lin.lora.B weight norm: 31.5475 distilbert.transformer.layer.4.attention.q_lin.lora.A weight norm: 47.6581 distilbert.transformer.layer.4.attention.q_lin.lora.B weight norm: 36.8326 distilbert.transformer.layer.4.attention.k_lin.lora.A weight norm: 50.2827 distilbert.transformer.layer.4.attention.k_lin.lora.B weight norm: 37.8750 distilbert.transformer.layer.4.attention.v_lin.lora.A weight norm: 45.1280 distilbert.transformer.layer.4.attention.v_lin.lora.B weight norm: 26.8244 distilbert.transformer.layer.4.attention.out_lin.lora.A weight norm: 47.0836 distilbert.transformer.layer.4.attention.out_lin.lora.B weight norm: 30.0062 distilbert.transformer.layer.5.attention.q_lin.lora.A weight norm: 42.4531 distilbert.transformer.layer.5.attention.q_lin.lora.B weight norm: 33.3586 distilbert.transformer.layer.5.attention.k_lin.lora.A weight norm: 44.9119 distilbert.transformer.layer.5.attention.k_lin.lora.B weight norm: 33.6370 distilbert.transformer.layer.5.attention.v_lin.lora.A weight norm: 43.6901 distilbert.transformer.layer.5.attention.v_lin.lora.B weight norm: 24.9901 
distilbert.transformer.layer.5.attention.out_lin.lora.A weight norm: 47.2171 distilbert.transformer.layer.5.attention.out_lin.lora.B weight norm: 30.6058 Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.scale_out weight norm: 60.7115 distilbert.transformer.layer.0.attention.q_lin.scale_in weight norm: 55.4249 distilbert.transformer.layer.0.attention.k_lin.scale_out weight norm: 60.4370 distilbert.transformer.layer.0.attention.k_lin.scale_in weight norm: 54.5890 distilbert.transformer.layer.0.attention.v_lin.scale_out weight norm: 58.4866 distilbert.transformer.layer.0.attention.v_lin.scale_in weight norm: 49.1681 distilbert.transformer.layer.0.attention.out_lin.scale_out weight norm: 51.2425 distilbert.transformer.layer.0.attention.out_lin.scale_in weight norm: 53.3198 distilbert.transformer.layer.1.attention.q_lin.scale_out weight norm: 57.8829 distilbert.transformer.layer.1.attention.q_lin.scale_in weight norm: 53.6451 distilbert.transformer.layer.1.attention.k_lin.scale_out weight norm: 55.0335 distilbert.transformer.layer.1.attention.k_lin.scale_in weight norm: 54.6923 distilbert.transformer.layer.1.attention.v_lin.scale_out weight norm: 50.1200 distilbert.transformer.layer.1.attention.v_lin.scale_in weight norm: 52.6154 distilbert.transformer.layer.1.attention.out_lin.scale_out weight norm: 50.4183 distilbert.transformer.layer.1.attention.out_lin.scale_in weight norm: 52.5371 distilbert.transformer.layer.2.attention.q_lin.scale_out weight norm: 55.7336 distilbert.transformer.layer.2.attention.q_lin.scale_in weight norm: 54.1221 distilbert.transformer.layer.2.attention.k_lin.scale_out weight norm: 57.0685 distilbert.transformer.layer.2.attention.k_lin.scale_in weight norm: 53.5724 distilbert.transformer.layer.2.attention.v_lin.scale_out weight norm: 52.1679 distilbert.transformer.layer.2.attention.v_lin.scale_in weight norm: 51.4892 distilbert.transformer.layer.2.attention.out_lin.scale_out weight norm: 51.0791 distilbert.transformer.layer.2.attention.out_lin.scale_in weight norm: 50.5406 distilbert.transformer.layer.3.attention.q_lin.scale_out weight norm: 57.1301 distilbert.transformer.layer.3.attention.q_lin.scale_in weight norm: 54.7893 distilbert.transformer.layer.3.attention.k_lin.scale_out weight norm: 57.2046 distilbert.transformer.layer.3.attention.k_lin.scale_in weight norm: 54.1217 distilbert.transformer.layer.3.attention.v_lin.scale_out weight norm: 51.3382 distilbert.transformer.layer.3.attention.v_lin.scale_in weight norm: 52.4507 distilbert.transformer.layer.3.attention.out_lin.scale_out weight norm: 51.9679 distilbert.transformer.layer.3.attention.out_lin.scale_in weight norm: 52.7578 distilbert.transformer.layer.4.attention.q_lin.scale_out weight norm: 57.4674 distilbert.transformer.layer.4.attention.q_lin.scale_in weight norm: 51.6265 distilbert.transformer.layer.4.attention.k_lin.scale_out weight norm: 59.7401 distilbert.transformer.layer.4.attention.k_lin.scale_in weight norm: 52.3292 distilbert.transformer.layer.4.attention.v_lin.scale_out weight norm: 50.7752 distilbert.transformer.layer.4.attention.v_lin.scale_in weight norm: 50.0643 distilbert.transformer.layer.4.attention.out_lin.scale_out weight norm: 50.8358 distilbert.transformer.layer.4.attention.out_lin.scale_in weight norm: 50.8573 distilbert.transformer.layer.5.attention.q_lin.scale_out weight norm: 57.1890 distilbert.transformer.layer.5.attention.q_lin.scale_in weight norm: 50.6338 distilbert.transformer.layer.5.attention.k_lin.scale_out weight norm: 61.1200 
distilbert.transformer.layer.5.attention.k_lin.scale_in weight norm: 51.5444 distilbert.transformer.layer.5.attention.v_lin.scale_out weight norm: 55.0821 distilbert.transformer.layer.5.attention.v_lin.scale_in weight norm: 48.5430 distilbert.transformer.layer.5.attention.out_lin.scale_out weight norm: 50.9995 distilbert.transformer.layer.5.attention.out_lin.scale_in weight norm: 51.4657 Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.m_out weight norm: 20.0095 distilbert.transformer.layer.0.attention.q_lin.m_in weight norm: 15.0613 distilbert.transformer.layer.0.attention.k_lin.m_out weight norm: 19.6359 distilbert.transformer.layer.0.attention.k_lin.m_in weight norm: 15.1377 distilbert.transformer.layer.0.attention.v_lin.m_out weight norm: 17.2782 distilbert.transformer.layer.0.attention.v_lin.m_in weight norm: 12.7195 distilbert.transformer.layer.0.attention.out_lin.m_out weight norm: 11.9346 distilbert.transformer.layer.0.attention.out_lin.m_in weight norm: 12.8137 distilbert.transformer.layer.1.attention.q_lin.m_out weight norm: 15.6942 distilbert.transformer.layer.1.attention.q_lin.m_in weight norm: 13.5273 distilbert.transformer.layer.1.attention.k_lin.m_out weight norm: 14.8891 distilbert.transformer.layer.1.attention.k_lin.m_in weight norm: 13.4656 distilbert.transformer.layer.1.attention.v_lin.m_out weight norm: 11.5086 distilbert.transformer.layer.1.attention.v_lin.m_in weight norm: 12.5572 distilbert.transformer.layer.1.attention.out_lin.m_out weight norm: 11.9882 distilbert.transformer.layer.1.attention.out_lin.m_in weight norm: 12.0898 distilbert.transformer.layer.2.attention.q_lin.m_out weight norm: 15.6175 distilbert.transformer.layer.2.attention.q_lin.m_in weight norm: 12.3557 distilbert.transformer.layer.2.attention.k_lin.m_out weight norm: 15.7173 distilbert.transformer.layer.2.attention.k_lin.m_in weight norm: 12.8633 distilbert.transformer.layer.2.attention.v_lin.m_out weight norm: 12.7515 distilbert.transformer.layer.2.attention.v_lin.m_in weight norm: 10.7463 distilbert.transformer.layer.2.attention.out_lin.m_out weight norm: 12.1293 distilbert.transformer.layer.2.attention.out_lin.m_in weight norm: 11.4631 distilbert.transformer.layer.3.attention.q_lin.m_out weight norm: 16.4127 distilbert.transformer.layer.3.attention.q_lin.m_in weight norm: 13.9588 distilbert.transformer.layer.3.attention.k_lin.m_out weight norm: 16.8955 distilbert.transformer.layer.3.attention.k_lin.m_in weight norm: 14.1378 distilbert.transformer.layer.3.attention.v_lin.m_out weight norm: 12.3394 distilbert.transformer.layer.3.attention.v_lin.m_in weight norm: 12.3336 distilbert.transformer.layer.3.attention.out_lin.m_out weight norm: 12.4052 distilbert.transformer.layer.3.attention.out_lin.m_in weight norm: 11.9973 distilbert.transformer.layer.4.attention.q_lin.m_out weight norm: 17.3603 distilbert.transformer.layer.4.attention.q_lin.m_in weight norm: 12.2036 distilbert.transformer.layer.4.attention.k_lin.m_out weight norm: 18.3908 distilbert.transformer.layer.4.attention.k_lin.m_in weight norm: 12.4500 distilbert.transformer.layer.4.attention.v_lin.m_out weight norm: 12.6651 distilbert.transformer.layer.4.attention.v_lin.m_in weight norm: 11.2682 distilbert.transformer.layer.4.attention.out_lin.m_out weight norm: 12.6220 distilbert.transformer.layer.4.attention.out_lin.m_in weight norm: 11.9401 distilbert.transformer.layer.5.attention.q_lin.m_out weight norm: 16.4559 distilbert.transformer.layer.5.attention.q_lin.m_in weight norm: 10.3415 
distilbert.transformer.layer.5.attention.k_lin.m_out weight norm: 18.6853 distilbert.transformer.layer.5.attention.k_lin.m_in weight norm: 10.9104 distilbert.transformer.layer.5.attention.v_lin.m_out weight norm: 13.6871 distilbert.transformer.layer.5.attention.v_lin.m_in weight norm: 10.8449 distilbert.transformer.layer.5.attention.out_lin.m_out weight norm: 12.6990 distilbert.transformer.layer.5.attention.out_lin.m_in weight norm: 11.9039
In [21]:
print(torch.cuda.memory_summary())
|===========================================================================| | PyTorch CUDA memory summary, device ID 0 | |---------------------------------------------------------------------------| | CUDA OOMs: 0 | cudaMalloc retries: 0 | |===========================================================================| | Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed | |---------------------------------------------------------------------------| | Allocated memory | 1077 MiB | 7290 MiB | 406631 GiB | 406630 GiB | | from large pool | 1055 MiB | 7240 MiB | 404673 GiB | 404672 GiB | | from small pool | 22 MiB | 51 MiB | 1958 GiB | 1958 GiB | |---------------------------------------------------------------------------| | Active memory | 1077 MiB | 7290 MiB | 406631 GiB | 406630 GiB | | from large pool | 1055 MiB | 7240 MiB | 404673 GiB | 404672 GiB | | from small pool | 22 MiB | 51 MiB | 1958 GiB | 1958 GiB | |---------------------------------------------------------------------------| | Requested memory | 1072 MiB | 7285 MiB | 406404 GiB | 406403 GiB | | from large pool | 1050 MiB | 7235 MiB | 404455 GiB | 404454 GiB | | from small pool | 22 MiB | 51 MiB | 1948 GiB | 1948 GiB | |---------------------------------------------------------------------------| | GPU reserved memory | 7380 MiB | 7380 MiB | 14270 MiB | 6890 MiB | | from large pool | 7328 MiB | 7328 MiB | 14168 MiB | 6840 MiB | | from small pool | 52 MiB | 52 MiB | 102 MiB | 50 MiB | |---------------------------------------------------------------------------| | Non-releasable memory | 162149 KiB | 545173 KiB | 144872 GiB | 144872 GiB | | from large pool | 156580 KiB | 540498 KiB | 142781 GiB | 142781 GiB | | from small pool | 5569 KiB | 22942 KiB | 2091 GiB | 2091 GiB | |---------------------------------------------------------------------------| | Allocations | 1310 | 1616 | 44683 K | 44682 K | | from large pool | 164 | 318 | 13813 K | 13813 K | | from small pool | 1146 | 1451 | 30869 K | 30868 K | |---------------------------------------------------------------------------| | Active allocs | 1310 | 1616 | 44683 K | 44682 K | | from large pool | 164 | 318 | 13813 K | 13813 K | | from small pool | 1146 | 1451 | 30869 K | 30868 K | |---------------------------------------------------------------------------| | GPU reserved segments | 190 | 190 | 327 | 137 | | from large pool | 164 | 164 | 276 | 112 | | from small pool | 26 | 26 | 51 | 25 | |---------------------------------------------------------------------------| | Non-releasable allocs | 37 | 60 | 20413 K | 20413 K | | from large pool | 18 | 26 | 3326 K | 3326 K | | from small pool | 19 | 43 | 17086 K | 17086 K | |---------------------------------------------------------------------------| | Oversize allocations | 0 | 0 | 0 | 0 | |---------------------------------------------------------------------------| | Oversize GPU segments | 0 | 0 | 0 | 0 | |===========================================================================|
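For a one-number check to set beside the full summary above, PyTorch's peak-allocation counter can be read directly (sketch; this is the same figure reported in the "Allocated memory / Peak Usage" row):

peak_bytes = torch.cuda.max_memory_allocated(0)
print(f"Peak allocated: {peak_bytes / 2**20:.0f} MiB")  # summary above reports 7290 MiB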
Full DistilBERT on IMDB Sentiment¶
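For context before training every weight, a quick comparison of trainable-parameter counts (a sketch, reusing the count_trainable_parameters helper and the DDoRA variables defined in the cells above):

total_full, trainable_full, pct_full = count_trainable_parameters(model)
print(f"Full fine-tuning - Trainable parameters: {trainable_full:,} ({pct_full:.2f}%)")
print(f"DDoRA (all attn) - Trainable parameters: {trainable_params_ddora:,} ({percentage_ddora:.2f}%)")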
In [22]:
# Define the training parameters
batch_size = 32
logging_steps = len(dataset_encoded["train"]) // batch_size
model_name = f"{model_ckpt}-finetuned-imdb"
training_args = TrainingArguments(output_dir=model_name,
                                  num_train_epochs=2,
                                  learning_rate=1e-5,
                                  per_device_train_batch_size=batch_size,
                                  per_device_eval_batch_size=batch_size,
                                  weight_decay=0.01,
                                  evaluation_strategy="steps",
                                  eval_steps=20,  # Evaluate more frequently on the larger dataset
                                  logging_steps=20,
                                  save_steps=20,
                                  max_grad_norm=1.0,
                                  disable_tqdm=False,
                                  push_to_hub=False,
                                  report_to="none",
                                  log_level="error")
trainer = Trainer(model=model,
                  args=training_args,
                  compute_metrics=compute_metrics,
                  train_dataset=dataset_encoded["train"],
                  eval_dataset=dataset_encoded["validation"],
                  tokenizer=tokenizer)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn( C:\Users\alexa\AppData\Local\Temp\ipykernel_22420\3915345233.py:21: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead. trainer = Trainer(model=model,
In [23]:
# Train!
trainer.train()
print(torch.cuda.memory_summary())
# Evaluate on the test set
evaluation_results = trainer.evaluate(dataset_encoded["test"])
print(f"\nEvaluation results on the test set:")
print(evaluation_results)
# Accessing the logs (after training):
log_history = trainer.state.log_history
# Plot the loss vs steps
import matplotlib.pyplot as plt
# Extract training loss
train_steps = [entry["step"] for entry in trainer.state.log_history if "loss" in entry]
train_losses = [entry["loss"] for entry in trainer.state.log_history if "loss" in entry]
# Extract validation loss
val_steps = [entry["step"] for entry in trainer.state.log_history if "eval_loss" in entry]
val_losses = [entry["eval_loss"] for entry in trainer.state.log_history if "eval_loss" in entry]
# Plot both training and validation loss
plt.plot(train_steps, train_losses, label="Training Loss", linestyle="-")
plt.plot(val_steps, val_losses, label="Validation Loss", linestyle="--")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Training Loss vs. Steps")
plt.legend()
plt.show()
val_accuracy_steps = [entry["step"] for entry in log_history if "eval_accuracy" in entry]
val_accuracies = [entry["eval_accuracy"] for entry in log_history if "eval_accuracy" in entry]
val_f1_steps = [entry["step"] for entry in log_history if "eval_f1" in entry]
val_f1_scores = [entry["eval_f1"] for entry in log_history if "eval_f1" in entry]
plt.plot(val_accuracy_steps, val_accuracies, label="Validation Accuracy", linestyle="-")
plt.plot(val_f1_steps, val_f1_scores, label="Validation F1 score", linestyle="--")
plt.xlabel("Steps")
plt.ylabel("Score")
plt.title("Val Accuracy and F1 vs. Steps")
plt.legend()
plt.show()
[1564/1564 45:03, Epoch 2/2]
Step | Training Loss | Validation Loss | Accuracy | F1 |
---|---|---|---|---|
20 | 0.689400 | 0.676229 | 0.744000 | 0.739512 |
40 | 0.659500 | 0.588595 | 0.820800 | 0.819493 |
60 | 0.494700 | 0.388211 | 0.865600 | 0.865512 |
80 | 0.366500 | 0.328102 | 0.866400 | 0.865610 |
100 | 0.323200 | 0.317173 | 0.876800 | 0.877069 |
120 | 0.339900 | 0.284819 | 0.884000 | 0.883705 |
140 | 0.298000 | 0.275737 | 0.885600 | 0.885591 |
160 | 0.305800 | 0.292675 | 0.880800 | 0.879911 |
180 | 0.258600 | 0.272887 | 0.888800 | 0.888825 |
200 | 0.257000 | 0.272066 | 0.884800 | 0.884265 |
220 | 0.255000 | 0.281653 | 0.884800 | 0.883885 |
240 | 0.286100 | 0.253439 | 0.901600 | 0.901510 |
260 | 0.257200 | 0.248806 | 0.902400 | 0.902369 |
280 | 0.244400 | 0.249558 | 0.902400 | 0.902184 |
300 | 0.282500 | 0.258910 | 0.907200 | 0.906769 |
320 | 0.266500 | 0.241840 | 0.906400 | 0.906141 |
340 | 0.269500 | 0.239128 | 0.900800 | 0.900752 |
360 | 0.252200 | 0.258152 | 0.897600 | 0.896886 |
380 | 0.250800 | 0.246389 | 0.908800 | 0.908924 |
400 | 0.246200 | 0.252573 | 0.901600 | 0.901795 |
420 | 0.252400 | 0.269475 | 0.894400 | 0.894634 |
440 | 0.219100 | 0.262670 | 0.898400 | 0.898618 |
460 | 0.300400 | 0.263741 | 0.899200 | 0.898037 |
480 | 0.270000 | 0.264092 | 0.896000 | 0.894800 |
500 | 0.258500 | 0.240369 | 0.904800 | 0.904421 |
520 | 0.254600 | 0.230196 | 0.908000 | 0.907948 |
540 | 0.238300 | 0.233112 | 0.905600 | 0.905677 |
560 | 0.233700 | 0.233234 | 0.904000 | 0.904111 |
580 | 0.284200 | 0.229832 | 0.905600 | 0.905570 |
600 | 0.254800 | 0.229044 | 0.911200 | 0.911050 |
620 | 0.215700 | 0.228184 | 0.913600 | 0.913528 |
640 | 0.247200 | 0.244651 | 0.901600 | 0.901771 |
660 | 0.223600 | 0.232534 | 0.911200 | 0.910891 |
680 | 0.234400 | 0.248592 | 0.906400 | 0.906585 |
700 | 0.247600 | 0.223161 | 0.913600 | 0.913625 |
720 | 0.269200 | 0.223580 | 0.912000 | 0.912128 |
740 | 0.202000 | 0.222004 | 0.909600 | 0.909669 |
760 | 0.226900 | 0.233625 | 0.908800 | 0.908961 |
780 | 0.215700 | 0.219046 | 0.908800 | 0.908814 |
800 | 0.195900 | 0.222158 | 0.917600 | 0.917478 |
820 | 0.175600 | 0.259501 | 0.896800 | 0.897026 |
840 | 0.213900 | 0.228438 | 0.914400 | 0.914306 |
860 | 0.212300 | 0.244444 | 0.904800 | 0.904978 |
880 | 0.188900 | 0.224636 | 0.914400 | 0.914337 |
900 | 0.166400 | 0.226523 | 0.912000 | 0.911943 |
920 | 0.177200 | 0.237086 | 0.910400 | 0.910530 |
940 | 0.214900 | 0.226861 | 0.913600 | 0.913463 |
960 | 0.178100 | 0.239278 | 0.908000 | 0.908153 |
980 | 0.189200 | 0.224798 | 0.912000 | 0.911943 |
1000 | 0.166500 | 0.232560 | 0.913600 | 0.913660 |
1020 | 0.187100 | 0.230836 | 0.913600 | 0.913625 |
1040 | 0.154100 | 0.229299 | 0.916000 | 0.915923 |
1060 | 0.207200 | 0.226525 | 0.916800 | 0.916847 |
1080 | 0.191800 | 0.225182 | 0.916000 | 0.915907 |
1100 | 0.204000 | 0.233329 | 0.912000 | 0.912093 |
1120 | 0.165900 | 0.227521 | 0.916800 | 0.916836 |
1140 | 0.156800 | 0.226366 | 0.917600 | 0.917524 |
1160 | 0.203400 | 0.224103 | 0.914400 | 0.914443 |
1180 | 0.203600 | 0.221393 | 0.917600 | 0.917478 |
1200 | 0.174700 | 0.232756 | 0.908800 | 0.908932 |
1220 | 0.164400 | 0.225512 | 0.916000 | 0.915967 |
1240 | 0.169700 | 0.230427 | 0.908800 | 0.908886 |
1260 | 0.220100 | 0.228160 | 0.916000 | 0.915823 |
1280 | 0.177400 | 0.233028 | 0.907200 | 0.907317 |
1300 | 0.207600 | 0.226083 | 0.914400 | 0.914406 |
1320 | 0.210700 | 0.225015 | 0.911200 | 0.911245 |
1340 | 0.148000 | 0.224107 | 0.916800 | 0.916634 |
1360 | 0.170400 | 0.224322 | 0.910400 | 0.910439 |
1380 | 0.200600 | 0.223000 | 0.913600 | 0.913558 |
1400 | 0.175900 | 0.223361 | 0.911200 | 0.911245 |
1420 | 0.169600 | 0.224143 | 0.909600 | 0.909657 |
1440 | 0.157600 | 0.224942 | 0.909600 | 0.909657 |
1460 | 0.183600 | 0.223980 | 0.912000 | 0.911986 |
1480 | 0.204600 | 0.224059 | 0.912000 | 0.911972 |
1500 | 0.181200 | 0.224224 | 0.914400 | 0.914337 |
1520 | 0.175000 | 0.224378 | 0.912800 | 0.912765 |
1540 | 0.183800 | 0.225751 | 0.907200 | 0.907240 |
1560 | 0.188800 | 0.226429 | 0.908000 | 0.908070 |
|===========================================================================| | PyTorch CUDA memory summary, device ID 0 | |---------------------------------------------------------------------------| | CUDA OOMs: 0 | cudaMalloc retries: 0 | |===========================================================================| | Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed | |---------------------------------------------------------------------------| | Allocated memory | 1589 MiB | 7290 MiB | 440273 GiB | 440272 GiB | | from large pool | 1566 MiB | 7240 MiB | 438070 GiB | 438069 GiB | | from small pool | 23 MiB | 51 MiB | 2203 GiB | 2203 GiB | |---------------------------------------------------------------------------| | Active memory | 1589 MiB | 7290 MiB | 440273 GiB | 440272 GiB | | from large pool | 1566 MiB | 7240 MiB | 438070 GiB | 438069 GiB | | from small pool | 23 MiB | 51 MiB | 2203 GiB | 2203 GiB | |---------------------------------------------------------------------------| | Requested memory | 1583 MiB | 7285 MiB | 439981 GiB | 439979 GiB | | from large pool | 1560 MiB | 7235 MiB | 437788 GiB | 437786 GiB | | from small pool | 23 MiB | 51 MiB | 2192 GiB | 2192 GiB | |---------------------------------------------------------------------------| | GPU reserved memory | 5538 MiB | 7380 MiB | 18572 MiB | 13034 MiB | | from large pool | 5508 MiB | 7328 MiB | 18468 MiB | 12960 MiB | | from small pool | 30 MiB | 52 MiB | 104 MiB | 74 MiB | |---------------------------------------------------------------------------| | Non-releasable memory | 264726 KiB | 545173 KiB | 154496 GiB | 154496 GiB | | from large pool | 259656 KiB | 540498 KiB | 152152 GiB | 152152 GiB | | from small pool | 5070 KiB | 22942 KiB | 2343 GiB | 2343 GiB | |---------------------------------------------------------------------------| | Allocations | 1518 | 1736 | 48429 K | 48428 K | | from large pool | 242 | 320 | 14932 K | 14932 K | | from small pool | 1276 | 1455 | 33496 K | 33495 K | |---------------------------------------------------------------------------| | Active allocs | 1518 | 1736 | 48429 K | 48428 K | | from large pool | 242 | 320 | 14932 K | 14932 K | | from small pool | 1276 | 1455 | 33496 K | 33495 K | |---------------------------------------------------------------------------| | GPU reserved segments | 119 | 190 | 399 | 280 | | from large pool | 104 | 164 | 347 | 243 | | from small pool | 15 | 26 | 52 | 37 | |---------------------------------------------------------------------------| | Non-releasable allocs | 41 | 60 | 22134 K | 22134 K | | from large pool | 21 | 28 | 3630 K | 3630 K | | from small pool | 20 | 43 | 18503 K | 18503 K | |---------------------------------------------------------------------------| | Oversize allocations | 0 | 0 | 0 | 0 | |---------------------------------------------------------------------------| | Oversize GPU segments | 0 | 0 | 0 | 0 | |===========================================================================|
[743/743 6:00:33]
Evaluation results on the test set: {'eval_loss': 0.21304276585578918, 'eval_accuracy': 0.9208421052631579, 'eval_f1': 0.9208423095953104, 'eval_runtime': 122.0514, 'eval_samples_per_second': 194.59, 'eval_steps_per_second': 6.088, 'epoch': 2.0}
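Both runs are now evaluated on the same held-out test split, so the two result dictionaries can be put side by side (a sketch using the variables already in scope):

for run_name, res in [("DDoRA (all attention)", eval_results_ddora_all_attn),
                      ("Full fine-tuning", evaluation_results)]:
    print(f"{run_name}: accuracy={res['eval_accuracy']:.4f}, f1={res['eval_f1']:.4f}")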
Bonus - DDoRA training - continued...¶
In [24]:
trainer_ddora_all_attn.train()
eval_results_ddora_all_attn = trainer_ddora_all_attn.evaluate(dataset_encoded["test"])
print(f"DDoRA (All Attention) Test Results: {eval_results_ddora_all_attn}")
print('Prediction drift')
outputs = trainer_ddora_all_attn.predict(dataset_encoded["validation"])
print(outputs.predictions[:5], outputs.label_ids[:5])
print('LoRA Heatmap')
layer_names = []
b_norms = []
for name, param in model_ddora_all_attn.named_parameters():
    if "lora.B" in name:
        layer_names.append(name)
        b_norms.append(param.norm().item())
sns.barplot(x=b_norms, y=layer_names, color='navy')
plt.xlabel("Weight Norm")
plt.title("LoRA B Norms by Layer")
plt.tight_layout()
plt.show()
import torch
import numpy as np
import matplotlib.pyplot as plt
m_in_values = []
m_out_values = []
for name, module in model_ddora_all_attn.named_modules():
    if isinstance(module, LinearWithDoubleDoRA):
        if hasattr(module, 'm_out'):
            m_out_param = module.m_out.detach().cpu().numpy().flatten()
            m_out_values.extend(m_out_param)
        if hasattr(module, 'm_in'):
            m_in_param = module.m_in.detach().cpu().numpy().flatten()
            m_in_values.extend(m_in_param)
# Convert to numpy arrays
m_in_values = np.array(m_in_values)
m_out_values = np.array(m_out_values)
# Summary stats
print(f"[m_out] Mean: {np.mean(m_out_values):.4f}, Std: {np.std(m_out_values):.4f}, Min: {np.min(m_out_values):.4f}, Max: {np.max(m_out_values):.4f}")
print(f"[m_in ] Mean: {np.mean(m_in_values):.4f}, Std: {np.std(m_in_values):.4f}, Min: {np.min(m_in_values):.4f}, Max: {np.max(m_in_values):.4f}")
# Plot histograms
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.hist(m_out_values, bins=50, alpha=0.7, color='olive')
plt.title('Distribution of Learned m_out Values (DDoRA)')
plt.xlabel('Magnitude (m_out)')
plt.ylabel('Frequency')
plt.subplot(1, 2, 2)
plt.hist(m_in_values, bins=50, alpha=0.7, color='navy')
plt.title('Distribution of Learned m_in Values (DDoRA)')
plt.xlabel('Magnitude (m_in)')
plt.ylabel('Frequency')
plt.tight_layout()
plt.show()
[3910/3910 2:54:29, Epoch 5/5]
Step | Training Loss | Validation Loss | Accuracy | F1 |
---|---|---|---|---|
50 | 0.220400 | 0.234968 | 0.918400 | 0.918165 |
100 | 0.218500 | 0.291846 | 0.909600 | 0.909350 |
150 | 0.262700 | 0.238732 | 0.904800 | 0.904872 |
200 | 0.202600 | 0.255080 | 0.905600 | 0.905789 |
250 | 0.236200 | 0.241231 | 0.912800 | 0.912896 |
300 | 0.241200 | 0.219332 | 0.915200 | 0.915200 |
350 | 0.250000 | 0.223478 | 0.909600 | 0.909217 |
400 | 0.212100 | 0.238065 | 0.910400 | 0.910055 |
450 | 0.205200 | 0.279559 | 0.912800 | 0.912232 |
500 | 0.218100 | 0.243256 | 0.918400 | 0.918435 |
550 | 0.204100 | 0.298102 | 0.916800 | 0.916634 |
600 | 0.246300 | 0.275445 | 0.916000 | 0.915923 |
650 | 0.220000 | 0.258947 | 0.905600 | 0.905260 |
700 | 0.196100 | 0.254702 | 0.912800 | 0.912704 |
750 | 0.234500 | 0.282484 | 0.887200 | 0.885504 |
800 | 0.202900 | 0.328079 | 0.892800 | 0.891303 |
850 | 0.186800 | 0.242680 | 0.912800 | 0.912780 |
900 | 0.180200 | 0.268671 | 0.921600 | 0.921575 |
950 | 0.173000 | 0.297629 | 0.917600 | 0.917391 |
1000 | 0.196400 | 0.250007 | 0.904800 | 0.904209 |
1050 | 0.196400 | 0.241103 | 0.920800 | 0.920794 |
1100 | 0.188100 | 0.263111 | 0.920000 | 0.919934 |
1150 | 0.170100 | 0.242923 | 0.916800 | 0.916941 |
1200 | 0.195600 | 0.233525 | 0.927200 | 0.927236 |
1250 | 0.175300 | 0.244239 | 0.909600 | 0.908955 |
1300 | 0.197200 | 0.263203 | 0.914400 | 0.914565 |
1350 | 0.225300 | 0.233377 | 0.916000 | 0.915907 |
1400 | 0.193400 | 0.232386 | 0.922400 | 0.922477 |
1450 | 0.201500 | 0.216410 | 0.919200 | 0.919168 |
1500 | 0.185000 | 0.218531 | 0.916800 | 0.916458 |
1550 | 0.208200 | 0.246917 | 0.910400 | 0.910552 |
1600 | 0.156400 | 0.343675 | 0.893600 | 0.893824 |
1650 | 0.145800 | 0.321663 | 0.923200 | 0.923150 |
1700 | 0.170000 | 0.295388 | 0.918400 | 0.918467 |
1750 | 0.166200 | 0.247114 | 0.928800 | 0.928760 |
1800 | 0.154300 | 0.227517 | 0.920000 | 0.920066 |
1850 | 0.165400 | 0.281829 | 0.922400 | 0.922459 |
1900 | 0.153700 | 0.256229 | 0.924800 | 0.924710 |
1950 | 0.164500 | 0.270434 | 0.923200 | 0.923297 |
2000 | 0.167300 | 0.226990 | 0.924800 | 0.924724 |
2050 | 0.167500 | 0.239940 | 0.927200 | 0.927120 |
2100 | 0.163400 | 0.226345 | 0.918400 | 0.918146 |
2150 | 0.116000 | 0.307402 | 0.923200 | 0.923108 |
2200 | 0.172000 | 0.269007 | 0.921600 | 0.921410 |
2250 | 0.160900 | 0.270219 | 0.924800 | 0.924634 |
2300 | 0.165600 | 0.256944 | 0.921600 | 0.921521 |
2350 | 0.155100 | 0.233660 | 0.916000 | 0.915666 |
2400 | 0.138500 | 0.256974 | 0.913600 | 0.913331 |
2450 | 0.124600 | 0.244719 | 0.913600 | 0.913371 |
2500 | 0.118200 | 0.303150 | 0.920000 | 0.920109 |
2550 | 0.102700 | 0.305154 | 0.912000 | 0.911568 |
2600 | 0.140000 | 0.266935 | 0.921600 | 0.921664 |
2650 | 0.114900 | 0.281125 | 0.915200 | 0.915099 |
2700 | 0.124200 | 0.297636 | 0.916800 | 0.916746 |
2750 | 0.112800 | 0.275313 | 0.919200 | 0.919096 |
2800 | 0.109300 | 0.253661 | 0.920000 | 0.919934 |
2850 | 0.107300 | 0.281190 | 0.913600 | 0.913127 |
2900 | 0.118900 | 0.262891 | 0.920800 | 0.920850 |
2950 | 0.123200 | 0.258890 | 0.922400 | 0.922300 |
3000 | 0.130200 | 0.258712 | 0.921600 | 0.921562 |
3050 | 0.119200 | 0.269527 | 0.917600 | 0.917567 |
3100 | 0.124000 | 0.263077 | 0.913600 | 0.913127 |
3150 | 0.101900 | 0.296506 | 0.919200 | 0.919181 |
3200 | 0.093300 | 0.302173 | 0.923200 | 0.923093 |
3250 | 0.084200 | 0.305703 | 0.920800 | 0.920781 |
3300 | 0.091100 | 0.294771 | 0.921600 | 0.921506 |
3350 | 0.077800 | 0.341725 | 0.920000 | 0.919806 |
3400 | 0.087400 | 0.324175 | 0.916800 | 0.916716 |
3450 | 0.100900 | 0.307702 | 0.921600 | 0.921443 |
3500 | 0.080300 | 0.325871 | 0.913600 | 0.913573 |
3550 | 0.067700 | 0.348119 | 0.916800 | 0.916800 |
3600 | 0.109800 | 0.301693 | 0.918400 | 0.918184 |
3650 | 0.094600 | 0.309193 | 0.915200 | 0.915114 |
3700 | 0.087600 | 0.316292 | 0.916800 | 0.916598 |
3750 | 0.076500 | 0.331572 | 0.915200 | 0.915225 |
3800 | 0.082600 | 0.318606 | 0.918400 | 0.918302 |
3850 | 0.075700 | 0.319939 | 0.919200 | 0.919126 |
3900 | 0.076200 | 0.319871 | 0.919200 | 0.919126 |
DDoRA (All Attention) Test Results: {'eval_loss': 0.20528165996074677, 'eval_accuracy': 0.9178526315789474, 'eval_f1': 0.9178393215322689, 'eval_runtime': 160.0105, 'eval_samples_per_second': 148.428, 'eval_steps_per_second': 4.643, 'epoch': 5.0}
Prediction drift
[[-2.8385003  1.8467747]
 [-4.145672   3.0743616]
 [ 2.190418  -2.6564922]
 [-2.6781795  1.6880615]
 [ 4.3640585 -5.1461053]] [1 1 0 1 0]
LoRA Heatmap
[m_out] Mean: 0.0013, Std: 0.7323, Min: -3.2482, Max: 5.2270
[m_in ] Mean: 0.0093, Std: 0.5591, Min: -2.4425, Max: 2.6041
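The summary above pools m_out and m_in across all 24 adapted projections (6 layers × q/k/v/out). To see whether particular layers or projections dominate, a small optional follow-up, reusing model_ddora_all_attn and LinearWithDoubleDoRA exactly as in the cell above, could look like this:

import numpy as np

# Per-module mean/std of the learned magnitude vectors
for name, module in model_ddora_all_attn.named_modules():
    if isinstance(module, LinearWithDoubleDoRA):
        m_out = module.m_out.detach().cpu().numpy().flatten()
        m_in = module.m_in.detach().cpu().numpy().flatten()
        print(f"{name:55s} m_out {m_out.mean():+.3f} ± {m_out.std():.3f}   "
              f"m_in {m_in.mean():+.3f} ± {m_in.std():.3f}")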
In [25]:
# Re-run of the previous cell: continue training for another schedule and repeat the same diagnostics
trainer_ddora_all_attn.train()
eval_results_ddora_all_attn = trainer_ddora_all_attn.evaluate(dataset_encoded["test"])
print(f"DDoRA (All Attention) Test Results: {eval_results_ddora_all_attn}")

print('Prediction drift')
outputs = trainer_ddora_all_attn.predict(dataset_encoded["validation"])
print(outputs.predictions[:5], outputs.label_ids[:5])

print('LoRA Heatmap')
# Norm of every LoRA B matrix, plotted per adapted layer
layer_names = []
b_norms = []
for name, param in model_ddora_all_attn.named_parameters():
    if "lora.B" in name:
        layer_names.append(name)
        b_norms.append(param.norm().item())
sns.barplot(x=b_norms, y=layer_names, color='navy')
plt.xlabel("Weight Norm")
plt.title("LoRA B Norms by Layer")
plt.tight_layout()
plt.show()

import torch
import numpy as np
import matplotlib.pyplot as plt

# Gather the learned magnitude vectors from every DDoRA-wrapped linear layer
m_in_values = []
m_out_values = []
for name, module in model_ddora_all_attn.named_modules():
    if isinstance(module, LinearWithDoubleDoRA):
        if hasattr(module, 'm_out'):
            m_out_param = module.m_out.detach().cpu().numpy().flatten()
            m_out_values.extend(m_out_param)
        if hasattr(module, 'm_in'):
            m_in_param = module.m_in.detach().cpu().numpy().flatten()
            m_in_values.extend(m_in_param)

# Convert to numpy arrays
m_in_values = np.array(m_in_values)
m_out_values = np.array(m_out_values)

# Summary stats
print(f"[m_out] Mean: {np.mean(m_out_values):.4f}, Std: {np.std(m_out_values):.4f}, Min: {np.min(m_out_values):.4f}, Max: {np.max(m_out_values):.4f}")
print(f"[m_in ] Mean: {np.mean(m_in_values):.4f}, Std: {np.std(m_in_values):.4f}, Min: {np.min(m_in_values):.4f}, Max: {np.max(m_in_values):.4f}")

# Plot histograms
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.hist(m_out_values, bins=50, alpha=0.7, color='olive')
plt.title('Distribution of Learned m_out Values (DDoRA)')
plt.xlabel('Magnitude (m_out)')
plt.ylabel('Frequency')

plt.subplot(1, 2, 2)
plt.hist(m_in_values, bins=50, alpha=0.7, color='navy')
plt.title('Distribution of Learned m_in Values (DDoRA)')
plt.xlabel('Magnitude (m_in)')
plt.ylabel('Frequency')

plt.tight_layout()
plt.show()
[3910/3910 2:58:07, Epoch 5/5]
Step | Training Loss | Validation Loss | Accuracy | F1 |
---|---|---|---|---|
50 | 0.193100 | 0.231645 | 0.906400 | 0.905762 |
100 | 0.181300 | 0.284688 | 0.905600 | 0.905806 |
150 | 0.201600 | 0.228212 | 0.919200 | 0.919194 |
200 | 0.165300 | 0.329915 | 0.904000 | 0.903447 |
250 | 0.180800 | 0.242941 | 0.906400 | 0.905702 |
300 | 0.200800 | 0.223974 | 0.911200 | 0.910594 |
350 | 0.207600 | 0.294283 | 0.903200 | 0.902281 |
400 | 0.172700 | 0.272758 | 0.912800 | 0.912720 |
450 | 0.174600 | 0.307540 | 0.912000 | 0.911301 |
500 | 0.194500 | 0.255431 | 0.923200 | 0.922979 |
550 | 0.179700 | 0.227661 | 0.924000 | 0.923982 |
600 | 0.206200 | 0.250669 | 0.924800 | 0.924788 |
650 | 0.161900 | 0.223719 | 0.920800 | 0.920616 |
700 | 0.160000 | 0.327061 | 0.913600 | 0.913709 |
750 | 0.183100 | 0.310250 | 0.889600 | 0.887916 |
800 | 0.176400 | 0.296977 | 0.902400 | 0.901688 |
850 | 0.130100 | 0.293862 | 0.912800 | 0.912720 |
900 | 0.132400 | 0.327696 | 0.912000 | 0.911493 |
950 | 0.145000 | 0.389061 | 0.912800 | 0.912430 |
1000 | 0.154600 | 0.376502 | 0.908800 | 0.908494 |
1050 | 0.149700 | 0.331415 | 0.908000 | 0.908145 |
1100 | 0.150100 | 0.422993 | 0.908800 | 0.908786 |
1150 | 0.122200 | 0.281298 | 0.911200 | 0.911316 |
1200 | 0.146500 | 0.323460 | 0.913600 | 0.913747 |
1250 | 0.152600 | 0.273852 | 0.918400 | 0.918271 |
1300 | 0.174300 | 0.231292 | 0.920000 | 0.919919 |
1350 | 0.172700 | 0.253323 | 0.918400 | 0.918237 |
1400 | 0.174600 | 0.237144 | 0.910400 | 0.910008 |
1450 | 0.150800 | 0.240501 | 0.921600 | 0.921506 |
1500 | 0.196900 | 0.244910 | 0.920800 | 0.920769 |
1550 | 0.202900 | 0.221191 | 0.916000 | 0.916064 |
1600 | 0.134700 | 0.323171 | 0.916800 | 0.916824 |
1650 | 0.106900 | 0.297118 | 0.923200 | 0.922923 |
1700 | 0.138200 | 0.273596 | 0.913600 | 0.913725 |
1750 | 0.125000 | 0.324790 | 0.922400 | 0.922382 |
1800 | 0.113900 | 0.257193 | 0.908800 | 0.908915 |
1850 | 0.138000 | 0.295698 | 0.916800 | 0.916836 |
1900 | 0.110100 | 0.312183 | 0.923200 | 0.923163 |
1950 | 0.120400 | 0.276098 | 0.916000 | 0.916074 |
2000 | 0.135500 | 0.260222 | 0.923200 | 0.923233 |
2050 | 0.134200 | 0.242475 | 0.920800 | 0.920829 |
2100 | 0.124400 | 0.246645 | 0.924800 | 0.924634 |
2150 | 0.099600 | 0.298770 | 0.922400 | 0.922369 |
2200 | 0.120900 | 0.293304 | 0.923200 | 0.923136 |
2250 | 0.119700 | 0.287883 | 0.924800 | 0.924681 |
2300 | 0.125600 | 0.258974 | 0.915200 | 0.915315 |
2350 | 0.118500 | 0.262726 | 0.921600 | 0.921427 |
2400 | 0.094500 | 0.329549 | 0.912800 | 0.912178 |
2450 | 0.102100 | 0.306560 | 0.922400 | 0.922269 |
2500 | 0.100400 | 0.277154 | 0.922400 | 0.922439 |
2550 | 0.080800 | 0.327069 | 0.921600 | 0.921374 |
2600 | 0.103800 | 0.279821 | 0.918400 | 0.918526 |
2650 | 0.097400 | 0.280838 | 0.918400 | 0.918495 |
2700 | 0.089000 | 0.333324 | 0.921600 | 0.921634 |
2750 | 0.093300 | 0.278819 | 0.921600 | 0.921443 |
2800 | 0.069700 | 0.351026 | 0.924000 | 0.923916 |
2850 | 0.071100 | 0.334101 | 0.923200 | 0.923013 |
2900 | 0.089700 | 0.317389 | 0.924000 | 0.923970 |
2950 | 0.101400 | 0.295993 | 0.924800 | 0.924724 |
3000 | 0.090500 | 0.301456 | 0.925600 | 0.925594 |
3050 | 0.092200 | 0.300997 | 0.924000 | 0.924028 |
3100 | 0.091500 | 0.297800 | 0.920000 | 0.919731 |
3150 | 0.079600 | 0.339592 | 0.920800 | 0.920840 |
3200 | 0.072300 | 0.360106 | 0.922400 | 0.922369 |
3250 | 0.053700 | 0.383484 | 0.925600 | 0.925545 |
3300 | 0.072800 | 0.334235 | 0.922400 | 0.922356 |
3350 | 0.065400 | 0.371789 | 0.921600 | 0.921410 |
3400 | 0.068900 | 0.343820 | 0.924000 | 0.923930 |
3450 | 0.071900 | 0.332017 | 0.922400 | 0.922329 |
3500 | 0.046600 | 0.393188 | 0.922400 | 0.922394 |
3550 | 0.057900 | 0.388426 | 0.919200 | 0.919240 |
3600 | 0.065900 | 0.366166 | 0.923200 | 0.923176 |
3650 | 0.073700 | 0.351996 | 0.924000 | 0.923916 |
3700 | 0.070800 | 0.337321 | 0.922400 | 0.922314 |
3750 | 0.054100 | 0.353665 | 0.923200 | 0.923188 |
3800 | 0.051200 | 0.360188 | 0.924000 | 0.923957 |
3850 | 0.061000 | 0.360322 | 0.923200 | 0.923176 |
3900 | 0.063800 | 0.360985 | 0.924800 | 0.924788 |
DDoRA (All Attention) Test Results: {'eval_loss': 0.2019222378730774, 'eval_accuracy': 0.9212631578947369, 'eval_f1': 0.9212608684750585, 'eval_runtime': 157.6478, 'eval_samples_per_second': 150.652, 'eval_steps_per_second': 4.713, 'epoch': 5.0}
Prediction drift
[[-2.494168   1.6709884]
 [-4.1562243  3.2771716]
 [ 2.5441208 -2.9046338]
 [-2.77265    1.9955821]
 [ 5.211438  -5.7563157]] [1 1 0 1 0]
LoRA Heatmap
[m_out] Mean: 0.0017, Std: 0.8710, Min: -3.8374, Max: 6.9031
[m_in ] Mean: 0.0119, Std: 0.6580, Min: -2.7239, Max: 2.7596
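Across these continued runs the training loss keeps falling (down to roughly 0.05–0.09) while the validation loss drifts upward (from about 0.23 early in the run to about 0.36 at the end), so the extra epochs are mostly overfitting the adapter. If further continuation is planned, one option is to keep the checkpoint with the best validation loss and stop early. The snippet below is only a sketch of the relevant transformers knobs, not the TrainingArguments actually used in this notebook; output_dir, the eval/save cadence, and the patience value are placeholders.

from transformers import TrainingArguments, Trainer, EarlyStoppingCallback

# Sketch: keep the best checkpoint by validation loss and stop when it stops improving
training_args = TrainingArguments(
    output_dir="./ddora_all_attn_continued",   # placeholder path
    num_train_epochs=5,
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    load_best_model_at_end=True,               # reload the best checkpoint when training ends
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

trainer_sketch = Trainer(
    model=model_ddora_all_attn,
    args=training_args,
    train_dataset=dataset_encoded["train"],
    eval_dataset=dataset_encoded["validation"],
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)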
In [26]:
# Mean absolute value of each group of DDoRA parameters
print('Parameter Statistics: mean.abs()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.scale" in name:          # scale_out / scale_in vectors
        print(name, param.abs().mean().item())

print('Parameter Statistics: mean.abs()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.m" in name:              # m_out / m_in magnitude vectors
        print(name, param.abs().mean().item())

print('Parameter Statistics: mean.abs()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lora" in name:               # LoRA A and B matrices
        print(name, param.abs().mean().item())
Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.scale_out 2.211477279663086 distilbert.transformer.layer.0.attention.q_lin.scale_in 1.7990350723266602 distilbert.transformer.layer.0.attention.k_lin.scale_out 2.160134792327881 distilbert.transformer.layer.0.attention.k_lin.scale_in 1.7010796070098877 distilbert.transformer.layer.0.attention.v_lin.scale_out 2.182528257369995 distilbert.transformer.layer.0.attention.v_lin.scale_in 1.0187453031539917 distilbert.transformer.layer.0.attention.out_lin.scale_out 1.5009715557098389 distilbert.transformer.layer.0.attention.out_lin.scale_in 1.3177869319915771 distilbert.transformer.layer.1.attention.q_lin.scale_out 2.0520782470703125 distilbert.transformer.layer.1.attention.q_lin.scale_in 1.6306018829345703 distilbert.transformer.layer.1.attention.k_lin.scale_out 1.8983781337738037 distilbert.transformer.layer.1.attention.k_lin.scale_in 1.684961199760437 distilbert.transformer.layer.1.attention.v_lin.scale_out 1.3902130126953125 distilbert.transformer.layer.1.attention.v_lin.scale_in 1.4229816198349 distilbert.transformer.layer.1.attention.out_lin.scale_out 1.4535112380981445 distilbert.transformer.layer.1.attention.out_lin.scale_in 1.2360076904296875 distilbert.transformer.layer.2.attention.q_lin.scale_out 1.8091825246810913 distilbert.transformer.layer.2.attention.q_lin.scale_in 1.6466021537780762 distilbert.transformer.layer.2.attention.k_lin.scale_out 2.054741859436035 distilbert.transformer.layer.2.attention.k_lin.scale_in 1.6497936248779297 distilbert.transformer.layer.2.attention.v_lin.scale_out 1.364464282989502 distilbert.transformer.layer.2.attention.v_lin.scale_in 1.3622424602508545 distilbert.transformer.layer.2.attention.out_lin.scale_out 1.31589674949646 distilbert.transformer.layer.2.attention.out_lin.scale_in 1.1265673637390137 distilbert.transformer.layer.3.attention.q_lin.scale_out 2.026132345199585 distilbert.transformer.layer.3.attention.q_lin.scale_in 1.7109956741333008 distilbert.transformer.layer.3.attention.k_lin.scale_out 2.1449785232543945 distilbert.transformer.layer.3.attention.k_lin.scale_in 1.6678569316864014 distilbert.transformer.layer.3.attention.v_lin.scale_out 1.5447391271591187 distilbert.transformer.layer.3.attention.v_lin.scale_in 1.2968910932540894 distilbert.transformer.layer.3.attention.out_lin.scale_out 1.459152102470398 distilbert.transformer.layer.3.attention.out_lin.scale_in 1.3626108169555664 distilbert.transformer.layer.4.attention.q_lin.scale_out 1.8095883131027222 distilbert.transformer.layer.4.attention.q_lin.scale_in 1.434910535812378 distilbert.transformer.layer.4.attention.k_lin.scale_out 2.0096232891082764 distilbert.transformer.layer.4.attention.k_lin.scale_in 1.4818964004516602 distilbert.transformer.layer.4.attention.v_lin.scale_out 1.224574089050293 distilbert.transformer.layer.4.attention.v_lin.scale_in 0.9664930701255798 distilbert.transformer.layer.4.attention.out_lin.scale_out 1.186560034751892 distilbert.transformer.layer.4.attention.out_lin.scale_in 1.138476848602295 distilbert.transformer.layer.5.attention.q_lin.scale_out 1.6922223567962646 distilbert.transformer.layer.5.attention.q_lin.scale_in 1.1524317264556885 distilbert.transformer.layer.5.attention.k_lin.scale_out 1.892466425895691 distilbert.transformer.layer.5.attention.k_lin.scale_in 1.325523853302002 distilbert.transformer.layer.5.attention.v_lin.scale_out 1.4372739791870117 distilbert.transformer.layer.5.attention.v_lin.scale_in 0.7315273284912109 
distilbert.transformer.layer.5.attention.out_lin.scale_out 1.1462877988815308 distilbert.transformer.layer.5.attention.out_lin.scale_in 1.2485125064849854 Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.m_out 0.904198169708252 distilbert.transformer.layer.0.attention.q_lin.m_in 0.5918328166007996 distilbert.transformer.layer.0.attention.k_lin.m_out 0.8658651113510132 distilbert.transformer.layer.0.attention.k_lin.m_in 0.552005410194397 distilbert.transformer.layer.0.attention.v_lin.m_out 0.891939640045166 distilbert.transformer.layer.0.attention.v_lin.m_in 0.3521280288696289 distilbert.transformer.layer.0.attention.out_lin.m_out 0.5082004070281982 distilbert.transformer.layer.0.attention.out_lin.m_in 0.3994007110595703 distilbert.transformer.layer.1.attention.q_lin.m_out 0.7819246053695679 distilbert.transformer.layer.1.attention.q_lin.m_in 0.5072041153907776 distilbert.transformer.layer.1.attention.k_lin.m_out 0.7096916437149048 distilbert.transformer.layer.1.attention.k_lin.m_in 0.5101008415222168 distilbert.transformer.layer.1.attention.v_lin.m_out 0.48597097396850586 distilbert.transformer.layer.1.attention.v_lin.m_in 0.4318428933620453 distilbert.transformer.layer.1.attention.out_lin.m_out 0.510459303855896 distilbert.transformer.layer.1.attention.out_lin.m_in 0.370250940322876 distilbert.transformer.layer.2.attention.q_lin.m_out 0.672264039516449 distilbert.transformer.layer.2.attention.q_lin.m_in 0.49723148345947266 distilbert.transformer.layer.2.attention.k_lin.m_out 0.8109812140464783 distilbert.transformer.layer.2.attention.k_lin.m_in 0.5146569013595581 distilbert.transformer.layer.2.attention.v_lin.m_out 0.4366285800933838 distilbert.transformer.layer.2.attention.v_lin.m_in 0.36893153190612793 distilbert.transformer.layer.2.attention.out_lin.m_out 0.4298613965511322 distilbert.transformer.layer.2.attention.out_lin.m_in 0.3405319154262543 distilbert.transformer.layer.3.attention.q_lin.m_out 0.7905937433242798 distilbert.transformer.layer.3.attention.q_lin.m_in 0.5291155576705933 distilbert.transformer.layer.3.attention.k_lin.m_out 0.8796322345733643 distilbert.transformer.layer.3.attention.k_lin.m_in 0.539239764213562 distilbert.transformer.layer.3.attention.v_lin.m_out 0.5514059066772461 distilbert.transformer.layer.3.attention.v_lin.m_in 0.395082950592041 distilbert.transformer.layer.3.attention.out_lin.m_out 0.5047212839126587 distilbert.transformer.layer.3.attention.out_lin.m_in 0.4035267233848572 distilbert.transformer.layer.4.attention.q_lin.m_out 0.682336688041687 distilbert.transformer.layer.4.attention.q_lin.m_in 0.4449799656867981 distilbert.transformer.layer.4.attention.k_lin.m_out 0.7756460905075073 distilbert.transformer.layer.4.attention.k_lin.m_in 0.46460646390914917 distilbert.transformer.layer.4.attention.v_lin.m_out 0.41292259097099304 distilbert.transformer.layer.4.attention.v_lin.m_in 0.2883872985839844 distilbert.transformer.layer.4.attention.out_lin.m_out 0.3856971859931946 distilbert.transformer.layer.4.attention.out_lin.m_in 0.3439074456691742 distilbert.transformer.layer.5.attention.q_lin.m_out 0.5993542671203613 distilbert.transformer.layer.5.attention.q_lin.m_in 0.3158256709575653 distilbert.transformer.layer.5.attention.k_lin.m_out 0.6338485479354858 distilbert.transformer.layer.5.attention.k_lin.m_in 0.3383934199810028 distilbert.transformer.layer.5.attention.v_lin.m_out 0.4689478278160095 distilbert.transformer.layer.5.attention.v_lin.m_in 0.2039266675710678 distilbert.transformer.layer.5.attention.out_lin.m_out 
0.3742349147796631 distilbert.transformer.layer.5.attention.out_lin.m_in 0.39616748690605164 Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.lora.A 0.6073777079582214 distilbert.transformer.layer.0.attention.q_lin.lora.B 0.5936362743377686 distilbert.transformer.layer.0.attention.k_lin.lora.A 0.5834858417510986 distilbert.transformer.layer.0.attention.k_lin.lora.B 0.5944669246673584 distilbert.transformer.layer.0.attention.v_lin.lora.A 0.477506160736084 distilbert.transformer.layer.0.attention.v_lin.lora.B 0.5750714540481567 distilbert.transformer.layer.0.attention.out_lin.lora.A 0.4822663962841034 distilbert.transformer.layer.0.attention.out_lin.lora.B 0.5455904006958008 distilbert.transformer.layer.1.attention.q_lin.lora.A 0.5509202480316162 distilbert.transformer.layer.1.attention.q_lin.lora.B 0.5553864240646362 distilbert.transformer.layer.1.attention.k_lin.lora.A 0.5384150147438049 distilbert.transformer.layer.1.attention.k_lin.lora.B 0.5774441957473755 distilbert.transformer.layer.1.attention.v_lin.lora.A 0.5051205158233643 distilbert.transformer.layer.1.attention.v_lin.lora.B 0.4909871518611908 distilbert.transformer.layer.1.attention.out_lin.lora.A 0.4681047201156616 distilbert.transformer.layer.1.attention.out_lin.lora.B 0.5537931323051453 distilbert.transformer.layer.2.attention.q_lin.lora.A 0.5328584909439087 distilbert.transformer.layer.2.attention.q_lin.lora.B 0.5830724239349365 distilbert.transformer.layer.2.attention.k_lin.lora.A 0.5513193011283875 distilbert.transformer.layer.2.attention.k_lin.lora.B 0.5603161454200745 distilbert.transformer.layer.2.attention.v_lin.lora.A 0.46175330877304077 distilbert.transformer.layer.2.attention.v_lin.lora.B 0.4580882787704468 distilbert.transformer.layer.2.attention.out_lin.lora.A 0.4390673041343689 distilbert.transformer.layer.2.attention.out_lin.lora.B 0.5048208236694336 distilbert.transformer.layer.3.attention.q_lin.lora.A 0.5584718585014343 distilbert.transformer.layer.3.attention.q_lin.lora.B 0.5750592947006226 distilbert.transformer.layer.3.attention.k_lin.lora.A 0.5745105147361755 distilbert.transformer.layer.3.attention.k_lin.lora.B 0.5809159278869629 distilbert.transformer.layer.3.attention.v_lin.lora.A 0.47269874811172485 distilbert.transformer.layer.3.attention.v_lin.lora.B 0.5128231048583984 distilbert.transformer.layer.3.attention.out_lin.lora.A 0.460321843624115 distilbert.transformer.layer.3.attention.out_lin.lora.B 0.48533257842063904 distilbert.transformer.layer.4.attention.q_lin.lora.A 0.5054908990859985 distilbert.transformer.layer.4.attention.q_lin.lora.B 0.5460034608840942 distilbert.transformer.layer.4.attention.k_lin.lora.A 0.5137321352958679 distilbert.transformer.layer.4.attention.k_lin.lora.B 0.5609502196311951 distilbert.transformer.layer.4.attention.v_lin.lora.A 0.4001336991786957 distilbert.transformer.layer.4.attention.v_lin.lora.B 0.4369596838951111 distilbert.transformer.layer.4.attention.out_lin.lora.A 0.42696064710617065 distilbert.transformer.layer.4.attention.out_lin.lora.B 0.4693056344985962 distilbert.transformer.layer.5.attention.q_lin.lora.A 0.4086011052131653 distilbert.transformer.layer.5.attention.q_lin.lora.B 0.48308905959129333 distilbert.transformer.layer.5.attention.k_lin.lora.A 0.432908296585083 distilbert.transformer.layer.5.attention.k_lin.lora.B 0.4992391765117645 distilbert.transformer.layer.5.attention.v_lin.lora.A 0.34997496008872986 distilbert.transformer.layer.5.attention.v_lin.lora.B 0.45308423042297363 
distilbert.transformer.layer.5.attention.out_lin.lora.A 0.46718907356262207 distilbert.transformer.layer.5.attention.out_lin.lora.B 0.5256445407867432
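The three dumps above (and the norms printed in the next cell) are easier to compare when collected into a single table. A small optional sketch, assuming pandas is available in the environment; the substring filters mirror the ones used in the cell above:

import pandas as pd

# Rebuild the printouts as one row per DDoRA parameter, with both statistics side by side
records = []
for name, param in model_ddora_all_attn.named_parameters():
    if any(tag in name for tag in ("lora", "lin.scale", "lin.m")):
        records.append({
            "parameter": name,
            "abs_mean": param.abs().mean().item(),
            "norm": param.norm().item(),
        })

stats = pd.DataFrame(records).sort_values("parameter")
print(stats.to_string(index=False))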
In [27]:
# L2 norm of each group of DDoRA parameters
print('Parameter Statistics: param.norm()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lora" in name:               # LoRA A and B matrices
        print(f"{name} weight norm: {param.norm().item():.4f}")

print('Parameter Statistics: param.norm()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.scale" in name:          # scale_out / scale_in vectors
        print(f"{name} weight norm: {param.norm().item():.4f}")

print('Parameter Statistics: param.norm()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.m" in name:              # m_out / m_in magnitude vectors
        print(f"{name} weight norm: {param.norm().item():.4f}")
Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.lora.A weight norm: 88.6770 distilbert.transformer.layer.0.attention.q_lin.lora.B weight norm: 83.6840 distilbert.transformer.layer.0.attention.k_lin.lora.A weight norm: 85.9274 distilbert.transformer.layer.0.attention.k_lin.lora.B weight norm: 83.4974 distilbert.transformer.layer.0.attention.v_lin.lora.A weight norm: 73.3022 distilbert.transformer.layer.0.attention.v_lin.lora.B weight norm: 80.9918 distilbert.transformer.layer.0.attention.out_lin.lora.A weight norm: 73.7286 distilbert.transformer.layer.0.attention.out_lin.lora.B weight norm: 77.5629 distilbert.transformer.layer.1.attention.q_lin.lora.A weight norm: 81.9203 distilbert.transformer.layer.1.attention.q_lin.lora.B weight norm: 78.9050 distilbert.transformer.layer.1.attention.k_lin.lora.A weight norm: 80.0022 distilbert.transformer.layer.1.attention.k_lin.lora.B weight norm: 81.3284 distilbert.transformer.layer.1.attention.v_lin.lora.A weight norm: 76.1209 distilbert.transformer.layer.1.attention.v_lin.lora.B weight norm: 70.1328 distilbert.transformer.layer.1.attention.out_lin.lora.A weight norm: 71.7114 distilbert.transformer.layer.1.attention.out_lin.lora.B weight norm: 78.6895 distilbert.transformer.layer.2.attention.q_lin.lora.A weight norm: 80.5862 distilbert.transformer.layer.2.attention.q_lin.lora.B weight norm: 81.8691 distilbert.transformer.layer.2.attention.k_lin.lora.A weight norm: 81.0452 distilbert.transformer.layer.2.attention.k_lin.lora.B weight norm: 79.2133 distilbert.transformer.layer.2.attention.v_lin.lora.A weight norm: 70.1643 distilbert.transformer.layer.2.attention.v_lin.lora.B weight norm: 65.6669 distilbert.transformer.layer.2.attention.out_lin.lora.A weight norm: 68.0408 distilbert.transformer.layer.2.attention.out_lin.lora.B weight norm: 71.9416 distilbert.transformer.layer.3.attention.q_lin.lora.A weight norm: 81.9671 distilbert.transformer.layer.3.attention.q_lin.lora.B weight norm: 80.9681 distilbert.transformer.layer.3.attention.k_lin.lora.A weight norm: 84.2276 distilbert.transformer.layer.3.attention.k_lin.lora.B weight norm: 82.0434 distilbert.transformer.layer.3.attention.v_lin.lora.A weight norm: 72.1863 distilbert.transformer.layer.3.attention.v_lin.lora.B weight norm: 73.0758 distilbert.transformer.layer.3.attention.out_lin.lora.A weight norm: 70.6866 distilbert.transformer.layer.3.attention.out_lin.lora.B weight norm: 69.2462 distilbert.transformer.layer.4.attention.q_lin.lora.A weight norm: 75.9484 distilbert.transformer.layer.4.attention.q_lin.lora.B weight norm: 77.3235 distilbert.transformer.layer.4.attention.k_lin.lora.A weight norm: 76.9465 distilbert.transformer.layer.4.attention.k_lin.lora.B weight norm: 79.4074 distilbert.transformer.layer.4.attention.v_lin.lora.A weight norm: 62.7360 distilbert.transformer.layer.4.attention.v_lin.lora.B weight norm: 63.9343 distilbert.transformer.layer.4.attention.out_lin.lora.A weight norm: 66.6852 distilbert.transformer.layer.4.attention.out_lin.lora.B weight norm: 67.7626 distilbert.transformer.layer.5.attention.q_lin.lora.A weight norm: 63.4330 distilbert.transformer.layer.5.attention.q_lin.lora.B weight norm: 69.0939 distilbert.transformer.layer.5.attention.k_lin.lora.A weight norm: 66.2249 distilbert.transformer.layer.5.attention.k_lin.lora.B weight norm: 70.6252 distilbert.transformer.layer.5.attention.v_lin.lora.A weight norm: 56.1786 distilbert.transformer.layer.5.attention.v_lin.lora.B weight norm: 67.4088 
distilbert.transformer.layer.5.attention.out_lin.lora.A weight norm: 71.9610 distilbert.transformer.layer.5.attention.out_lin.lora.B weight norm: 76.3459 Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.scale_out weight norm: 67.8149 distilbert.transformer.layer.0.attention.q_lin.scale_in weight norm: 55.8788 distilbert.transformer.layer.0.attention.k_lin.scale_out weight norm: 66.3004 distilbert.transformer.layer.0.attention.k_lin.scale_in weight norm: 53.4427 distilbert.transformer.layer.0.attention.v_lin.scale_out weight norm: 67.3438 distilbert.transformer.layer.0.attention.v_lin.scale_in weight norm: 40.8041 distilbert.transformer.layer.0.attention.out_lin.scale_out weight norm: 49.3765 distilbert.transformer.layer.0.attention.out_lin.scale_in weight norm: 46.0389 distilbert.transformer.layer.1.attention.q_lin.scale_out weight norm: 63.8303 distilbert.transformer.layer.1.attention.q_lin.scale_in weight norm: 51.8581 distilbert.transformer.layer.1.attention.k_lin.scale_out weight norm: 59.3891 distilbert.transformer.layer.1.attention.k_lin.scale_in weight norm: 52.9394 distilbert.transformer.layer.1.attention.v_lin.scale_out weight norm: 47.0823 distilbert.transformer.layer.1.attention.v_lin.scale_in weight norm: 47.6743 distilbert.transformer.layer.1.attention.out_lin.scale_out weight norm: 48.7647 distilbert.transformer.layer.1.attention.out_lin.scale_in weight norm: 44.8163 distilbert.transformer.layer.2.attention.q_lin.scale_out weight norm: 57.6219 distilbert.transformer.layer.2.attention.q_lin.scale_in weight norm: 52.3231 distilbert.transformer.layer.2.attention.k_lin.scale_out weight norm: 63.9857 distilbert.transformer.layer.2.attention.k_lin.scale_in weight norm: 52.2575 distilbert.transformer.layer.2.attention.v_lin.scale_out weight norm: 46.9235 distilbert.transformer.layer.2.attention.v_lin.scale_in weight norm: 45.0910 distilbert.transformer.layer.2.attention.out_lin.scale_out weight norm: 45.7886 distilbert.transformer.layer.2.attention.out_lin.scale_in weight norm: 41.9345 distilbert.transformer.layer.3.attention.q_lin.scale_out weight norm: 63.2718 distilbert.transformer.layer.3.attention.q_lin.scale_in weight norm: 53.4324 distilbert.transformer.layer.3.attention.k_lin.scale_out weight norm: 66.0361 distilbert.transformer.layer.3.attention.k_lin.scale_in weight norm: 52.7628 distilbert.transformer.layer.3.attention.v_lin.scale_out weight norm: 51.4346 distilbert.transformer.layer.3.attention.v_lin.scale_in weight norm: 45.3650 distilbert.transformer.layer.3.attention.out_lin.scale_out weight norm: 49.2689 distilbert.transformer.layer.3.attention.out_lin.scale_in weight norm: 46.8675 distilbert.transformer.layer.4.attention.q_lin.scale_out weight norm: 58.5162 distilbert.transformer.layer.4.attention.q_lin.scale_in weight norm: 47.5943 distilbert.transformer.layer.4.attention.k_lin.scale_out weight norm: 62.9462 distilbert.transformer.layer.4.attention.k_lin.scale_in weight norm: 48.5050 distilbert.transformer.layer.4.attention.v_lin.scale_out weight norm: 45.4089 distilbert.transformer.layer.4.attention.v_lin.scale_in weight norm: 38.7889 distilbert.transformer.layer.4.attention.out_lin.scale_out weight norm: 44.1285 distilbert.transformer.layer.4.attention.out_lin.scale_in weight norm: 42.3015 distilbert.transformer.layer.5.attention.q_lin.scale_out weight norm: 55.4135 distilbert.transformer.layer.5.attention.q_lin.scale_in weight norm: 41.3333 distilbert.transformer.layer.5.attention.k_lin.scale_out weight norm: 58.2401 
distilbert.transformer.layer.5.attention.k_lin.scale_in weight norm: 44.2531 distilbert.transformer.layer.5.attention.v_lin.scale_out weight norm: 49.2695 distilbert.transformer.layer.5.attention.v_lin.scale_in weight norm: 33.0258 distilbert.transformer.layer.5.attention.out_lin.scale_out weight norm: 43.1049 distilbert.transformer.layer.5.attention.out_lin.scale_in weight norm: 44.5303 Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.m_out weight norm: 31.4448 distilbert.transformer.layer.0.attention.q_lin.m_in weight norm: 22.4398 distilbert.transformer.layer.0.attention.k_lin.m_out weight norm: 30.1271 distilbert.transformer.layer.0.attention.k_lin.m_in weight norm: 21.4311 distilbert.transformer.layer.0.attention.v_lin.m_out weight norm: 31.0435 distilbert.transformer.layer.0.attention.v_lin.m_in weight norm: 16.8509 distilbert.transformer.layer.0.attention.out_lin.m_out weight norm: 19.6845 distilbert.transformer.layer.0.attention.out_lin.m_in weight norm: 17.7480 distilbert.transformer.layer.1.attention.q_lin.m_out weight norm: 28.3183 distilbert.transformer.layer.1.attention.q_lin.m_in weight norm: 20.1127 distilbert.transformer.layer.1.attention.k_lin.m_out weight norm: 25.9295 distilbert.transformer.layer.1.attention.k_lin.m_in weight norm: 20.4797 distilbert.transformer.layer.1.attention.v_lin.m_out weight norm: 19.3969 distilbert.transformer.layer.1.attention.v_lin.m_in weight norm: 18.3722 distilbert.transformer.layer.1.attention.out_lin.m_out weight norm: 19.8745 distilbert.transformer.layer.1.attention.out_lin.m_in weight norm: 17.4045 distilbert.transformer.layer.2.attention.q_lin.m_out weight norm: 24.9047 distilbert.transformer.layer.2.attention.q_lin.m_in weight norm: 20.1862 distilbert.transformer.layer.2.attention.k_lin.m_out weight norm: 29.0122 distilbert.transformer.layer.2.attention.k_lin.m_in weight norm: 20.1494 distilbert.transformer.layer.2.attention.v_lin.m_out weight norm: 17.9841 distilbert.transformer.layer.2.attention.v_lin.m_in weight norm: 16.4528 distilbert.transformer.layer.2.attention.out_lin.m_out weight norm: 17.7306 distilbert.transformer.layer.2.attention.out_lin.m_in weight norm: 16.3007 distilbert.transformer.layer.3.attention.q_lin.m_out weight norm: 28.5406 distilbert.transformer.layer.3.attention.q_lin.m_in weight norm: 20.5190 distilbert.transformer.layer.3.attention.k_lin.m_out weight norm: 30.7927 distilbert.transformer.layer.3.attention.k_lin.m_in weight norm: 20.6228 distilbert.transformer.layer.3.attention.v_lin.m_out weight norm: 21.3742 distilbert.transformer.layer.3.attention.v_lin.m_in weight norm: 17.7191 distilbert.transformer.layer.3.attention.out_lin.m_out weight norm: 20.0696 distilbert.transformer.layer.3.attention.out_lin.m_in weight norm: 17.6262 distilbert.transformer.layer.4.attention.q_lin.m_out weight norm: 25.6429 distilbert.transformer.layer.4.attention.q_lin.m_in weight norm: 18.5982 distilbert.transformer.layer.4.attention.k_lin.m_out weight norm: 27.9177 distilbert.transformer.layer.4.attention.k_lin.m_in weight norm: 18.6512 distilbert.transformer.layer.4.attention.v_lin.m_out weight norm: 18.0866 distilbert.transformer.layer.4.attention.v_lin.m_in weight norm: 15.1881 distilbert.transformer.layer.4.attention.out_lin.m_out weight norm: 17.4686 distilbert.transformer.layer.4.attention.out_lin.m_in weight norm: 16.4471 distilbert.transformer.layer.5.attention.q_lin.m_out weight norm: 22.9480 distilbert.transformer.layer.5.attention.q_lin.m_in weight norm: 15.1683 
distilbert.transformer.layer.5.attention.k_lin.m_out weight norm: 22.7621 distilbert.transformer.layer.5.attention.k_lin.m_in weight norm: 15.4354 distilbert.transformer.layer.5.attention.v_lin.m_out weight norm: 19.4222 distilbert.transformer.layer.5.attention.v_lin.m_in weight norm: 12.9561 distilbert.transformer.layer.5.attention.out_lin.m_out weight norm: 16.6265 distilbert.transformer.layer.5.attention.out_lin.m_in weight norm: 17.4611
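The A and B norms above are reported separately, but the size of the actual low-rank update is the norm of their product. The follow-up sketch below assumes B has shape (out_features, r) and A has shape (r, in_features), i.e. the usual B @ A ordering in LoRA-style layers; this should be checked against the class definition earlier in the notebook.

import torch
from collections import defaultdict

# Frobenius norm of the combined update delta_W = B @ A per adapted layer,
# grouping the lora.A / lora.B parameters by their parent module name
pairs = defaultdict(dict)
for name, param in model_ddora_all_attn.named_parameters():
    if name.endswith("lora.A"):
        pairs[name[:-len(".lora.A")]]["A"] = param
    elif name.endswith("lora.B"):
        pairs[name[:-len(".lora.B")]]["B"] = param

with torch.no_grad():
    for layer, p in pairs.items():
        if "A" in p and "B" in p:
            delta_w = p["B"] @ p["A"]
            print(f"{layer:55s} ||B @ A|| = {delta_w.norm().item():.2f}")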
In [28]:
# Re-check the earlier `trainer` run on the validation split
evaluation_results = trainer.evaluate(dataset_encoded["validation"])
print("\nEvaluation results on the validation set:")
print(evaluation_results)

# Tail of the validation metrics logged step-by-step during that training run
print("Last 10 Validation Accuracy Steps:", val_accuracy_steps[-10:])
print("Last 10 Validation Accuracies:", val_accuracies[-10:])
print("Last 10 Validation F1 Steps:", val_f1_steps[-10:])
print("Last 10 Validation F1 Scores:", val_f1_scores[-10:])
Evaluation results on the validation set:
{'eval_loss': 0.22643864154815674, 'eval_accuracy': 0.908, 'eval_f1': 0.908069813307235, 'eval_runtime': 6.7981, 'eval_samples_per_second': 183.876, 'eval_steps_per_second': 5.884, 'epoch': 2.0}
Last 10 Validation Accuracy Steps: [1400, 1420, 1440, 1460, 1480, 1500, 1520, 1540, 1560, 1564]
Last 10 Validation Accuracies: [0.9112, 0.9096, 0.9096, 0.912, 0.912, 0.9144, 0.9128, 0.9072, 0.908, 0.9208421052631579]
Last 10 Validation F1 Steps: [1400, 1420, 1440, 1460, 1480, 1500, 1520, 1540, 1560, 1564]
Last 10 Validation F1 Scores: [0.9112445094405226, 0.9096571916118659, 0.9096571916118659, 0.9119863574704475, 0.9119722515145877, 0.9143367287669, 0.9127653425067421, 0.9072402344324614, 0.908069813307235, 0.9208423095953104]
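The last-10 snapshots are hard to judge as raw numbers; plotting the full logged history makes the plateau easier to see. This assumes val_accuracy_steps / val_accuracies and val_f1_steps / val_f1_scores are the same lists populated by the logging callback used for this trainer earlier in the notebook.

import matplotlib.pyplot as plt

# Validation accuracy and F1 curves collected during training
plt.figure(figsize=(8, 4))
plt.plot(val_accuracy_steps, val_accuracies, label="Validation accuracy")
plt.plot(val_f1_steps, val_f1_scores, label="Validation F1", linestyle="--")
plt.xlabel("Training step")
plt.ylabel("Score")
plt.title("Validation metrics over training")
plt.legend()
plt.tight_layout()
plt.show()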
In [ ]: