In [1]:
import torch
import sys
import gc
print(sys.version)
print(f"PyTorch Version: {torch.__version__}")
print(torch.cuda.is_available())
print(torch.cuda.device_count())

if torch.cuda.is_available():
    print(f"CUDA Version: {torch.version.cuda}")
    print(torch.cuda.get_device_name(0))

gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

import bitsandbytes
import peft
import transformers

print(transformers.__version__)

print(f"bitsandbytes version: {bitsandbytes.__version__}")
print(f"peft version: {peft.__version__}")
print(torch.cuda.is_bf16_supported())

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # synchronous CUDA launches for easier debugging
3.10.16 | packaged by Anaconda, Inc. | (main, Dec 11 2024, 16:19:12) [MSC v.1929 64 bit (AMD64)]
PyTorch Version: 2.5.1+cu121
True
1
CUDA Version: 12.1
NVIDIA GeForce RTX 4080 Laptop GPU
4.50.0.dev0
bitsandbytes version: 0.45.3
peft version: 0.15.2.dev0
True

Load dataset, base model, and tokeniser

In [2]:
from datasets import load_dataset

imdb_dataset = load_dataset("imdb")
imdb_dataset = imdb_dataset.rename_column("label", "labels")
# Split the test set into validation and test sets
test_val_split = imdb_dataset['test'].train_test_split(test_size=0.95, seed=42)
imdb_dataset['validation'] = test_val_split['train']
imdb_dataset['test'] = test_val_split['test']

from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, f1_score

# Determine the number of labels
num_labels = len(set(imdb_dataset["train"]["labels"]))
print(f"Number of labels: {num_labels}")

# Load the tokenizer
model_ckpt = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

# Tokenize the whole dataset, truncate to 384 tokens
def tokenize(batch):
    return tokenizer(batch["text"], padding=True, truncation=True, max_length=384)

dataset_encoded = imdb_dataset.map(tokenize, batched=True, batch_size=None)

# Load the pretrained model for sequence classification
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = (AutoModelForSequenceClassification
         .from_pretrained(model_ckpt, num_labels=num_labels)
         .to(device))
print(model)
Number of labels: 2
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear(in_features=768, out_features=3072, bias=True)
            (lin2): Linear(in_features=3072, out_features=768, bias=True)
            (activation): GELUActivation()
          )
          (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        )
      )
    )
  )
  (pre_classifier): Linear(in_features=768, out_features=768, bias=True)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)
In [3]:
torch.autograd.set_detect_anomaly(True)  # surfaces backward-pass errors immediately (slows training)

# Define the performance metrics
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1}
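
As a quick illustration of what compute_metrics receives: the Trainer calls it with an EvalPrediction holding the raw logits and the label ids. A minimal check (the dummy arrays below are illustrative only, not part of the original run):

import numpy as np
from transformers import EvalPrediction

dummy = EvalPrediction(
    predictions=np.array([[0.2, 1.3], [2.0, -0.5]]),  # raw logits for two examples
    label_ids=np.array([1, 0]),
)
print(compute_metrics(dummy))  # expected: {'accuracy': 1.0, 'f1': 1.0}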

LoRA

In [4]:
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Modify model by injecting LoRA and DoRA layers
# Borrowed from: https://magazine.sebastianraschka.com/p/lora-and-dora-from-scratch
class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))  # B starts at zero so the LoRA update is initially a no-op
        self.alpha = alpha

    def forward(self, x):
        x = self.alpha * (x @ self.A @ self.B)
        return x


class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )

    def forward(self, x):
        return self.linear(x) + self.lora(x)


class LinearWithDoRA(nn.Module):
    def __init__(self, linear, rank, alpha, scaling_factor=1.0):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.m = nn.Parameter(torch.randn(1, linear.out_features) * std_dev)
        # Per-output-feature scale; a single scalar scale would be the simpler alternative
        self.scale = nn.Parameter(torch.full((1, linear.out_features), float(scaling_factor)))

    def forward(self, x):
        linear_output = self.linear(x)
        lora_output = self.lora(x)
        # Normalize the low-rank update along the feature dimension, then rescale it
        # with the learned magnitude vector m and the per-feature scale
        lora_output_norm = lora_output / (lora_output.norm(p=2, dim=-1, keepdim=True) + 1e-9)
        dora_modification = self.scale * self.m * lora_output_norm
        return linear_output + dora_modification
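
A minimal sanity check (not part of the original run) for the wrappers above: because lora.B is initialized to zeros, both LinearWithLoRA and LinearWithDoRA reproduce the wrapped linear layer exactly at initialization, so injecting them does not perturb the pretrained model before any training step.

torch.manual_seed(0)
base = nn.Linear(8, 4)
x = torch.randn(2, 8)

wrapped_lora = LinearWithLoRA(base, rank=4, alpha=8)
wrapped_dora = LinearWithDoRA(base, rank=4, alpha=8, scaling_factor=1.0)

print(torch.allclose(base(x), wrapped_lora(x)))  # True: the B=0 update is a no-op
print(torch.allclose(base(x), wrapped_dora(x)))  # True: the DoRA term is also zero until B moves away from zero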
In [5]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

# Function to inject LoRA into specified linear layers
def inject_lora_all_attn(model, rank, alpha):
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            lora_linear = LinearWithLoRA(original_linear, rank, alpha)
            setattr(parent_module, name.split('.')[-1], lora_linear)
    return model

# Function to inject DoRA into specified linear layers
def inject_dora_all_attn(model, rank, alpha, scaling_factor=1.0):
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            dora_linear = LinearWithDoRA(original_linear, rank, alpha, scaling_factor)
            setattr(parent_module, name.split('.')[-1], dora_linear)
    return model

def count_trainable_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params, 100 * trainable_params / total_params

def freeze_model_layers(model, unfreeze_pre_classifier=False):
    # Freeze all parameters
    for param in model.parameters():
        param.requires_grad = False

    # Unfreeze LoRA-, DoRA-, and DDoRA-specific parameters
    for name, param in model.named_parameters():
        if (
            "lora.A" in name
            or "lora.B" in name
            or name.endswith(".m")  # For DoRA
            or name.endswith(".m_in") # For DDoRA
            or name.endswith(".m_out") # For DDoRA
            or "scale" in name 
        ):
            param.requires_grad = True

    # Unfreeze classifier layer (always)
    for name, param in model.named_parameters():
        if name.startswith("classifier."):
            param.requires_grad = True

    # Optionally unfreeze the pre-classifier layer
    if unfreeze_pre_classifier:
        for name, param in model.named_parameters():
            if name.startswith("pre_classifier."):            
                param.requires_grad = True
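
The injection helpers work by resolving each matched layer's parent module with get_submodule and swapping the child attribute with setattr. A toy illustration of that mechanism (ToyAttention/ToyBlock are made up for this sketch, not part of the notebook):

class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_lin = nn.Linear(8, 8)

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = ToyAttention()

toy = ToyBlock()
name = "attention.q_lin"                            # dotted name as yielded by named_modules()
parent = toy.get_submodule(name.rsplit('.', 1)[0])  # resolves toy.attention
child = name.split('.')[-1]                         # "q_lin"
setattr(parent, child, LinearWithLoRA(getattr(parent, child), rank=4, alpha=8))
print(type(toy.attention.q_lin).__name__)           # LinearWithLoRA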
In [6]:
import seaborn as sns
import torch
import numpy as np
import matplotlib.pyplot as plt

import copy
torch.manual_seed(137)

lora_rank = 16
lora_alpha = 32
batch_size = 32
learning_rate = 1.5e-5
weight_decay = 0.0
output_dir_prefix = "finetuned-imdb-"

model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)

total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")

eval_steps = 50
logging_steps = 50

training_args_lora_all_attn = TrainingArguments(
    output_dir=f"{output_dir_prefix}lora-all-attn",
    num_train_epochs=5,
    #max_steps=100,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    logging_steps=logging_steps,
    save_steps=eval_steps,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    disable_tqdm=False,
    push_to_hub=False,
    report_to="none",
    log_level="error"
)

trainer_lora_all_attn = Trainer(model=model_lora_all_attn, args=training_args_lora_all_attn, 
                                train_dataset=dataset_encoded["train"], eval_dataset=dataset_encoded["validation"], compute_metrics=compute_metrics)
LoRA (All Attention) - Total parameters: 67,544,834
LoRA (All Attention) - Trainable parameters: 1,181,954 (1.75%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
In [7]:
trainer_lora_all_attn.train()
eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
print('Prediction drift')
outputs = trainer_lora_all_attn.predict(dataset_encoded["validation"])
print(outputs.predictions[:5], outputs.label_ids[:5])

print(torch.cuda.memory_summary())

print('LoRA B norms by layer')
layer_names = []
b_norms = []

for name, param in model_lora_all_attn.named_parameters():
    if "lora.B" in name:
        layer_names.append(name)
        b_norms.append(param.norm().item())

sns.barplot(x=b_norms, y=layer_names, color='navy')
plt.xlabel("Weight Norm")
plt.title("LoRA B Norms by Layer")
plt.tight_layout()
plt.show()
[3910/3910 1:29:10, Epoch 5/5]
Step Training Loss Validation Loss Accuracy F1
50 0.549600 0.334510 0.864800 0.864919
100 0.363100 0.314867 0.873600 0.873637
150 0.368100 0.283882 0.884000 0.884009
200 0.289200 0.270112 0.887200 0.887403
250 0.304900 0.246406 0.903200 0.903162
300 0.304000 0.256413 0.904800 0.904319
350 0.313300 0.257387 0.896000 0.896121
400 0.275000 0.360744 0.857600 0.857483
450 0.278100 0.277102 0.896800 0.897004
500 0.305700 0.243079 0.904000 0.904014
550 0.285300 0.258897 0.897600 0.897417
600 0.303100 0.258372 0.900800 0.900313
650 0.269200 0.246777 0.901600 0.901607
700 0.260800 0.241955 0.904000 0.903848
750 0.258400 0.243109 0.909600 0.909285
800 0.282400 0.338753 0.866400 0.863283
850 0.237000 0.239093 0.908800 0.908864
900 0.217300 0.261071 0.907200 0.907265
950 0.231800 0.263062 0.904000 0.903630
1000 0.229900 0.261102 0.907200 0.907155
1050 0.221300 0.245763 0.904000 0.903360
1100 0.218000 0.285926 0.899200 0.898366
1150 0.208400 0.238600 0.906400 0.906363
1200 0.211900 0.268987 0.906400 0.905927
1250 0.231800 0.293073 0.892000 0.892213
1300 0.231000 0.240379 0.904800 0.904965
1350 0.232700 0.251226 0.903200 0.903375
1400 0.231500 0.236147 0.908800 0.908771
1450 0.216500 0.243060 0.906400 0.906119
1500 0.218400 0.235548 0.908000 0.907745
1550 0.223300 0.237270 0.908800 0.908579
1600 0.174400 0.258777 0.914400 0.914256
1650 0.156200 0.266737 0.911200 0.910846
1700 0.168200 0.254983 0.914400 0.914337
1750 0.184600 0.265700 0.902400 0.902585
1800 0.182500 0.258778 0.907200 0.907287
1850 0.184000 0.251654 0.912800 0.912877
1900 0.154400 0.277561 0.906400 0.906447
1950 0.192000 0.241895 0.903200 0.903296
2000 0.185900 0.254578 0.913600 0.913789
2050 0.197500 0.240309 0.902400 0.902245
2100 0.192700 0.260683 0.901600 0.901103
2150 0.147100 0.262304 0.908000 0.908081
2200 0.183600 0.251887 0.910400 0.910439
2250 0.207200 0.251087 0.909600 0.909750
2300 0.196100 0.249030 0.900800 0.901003
2350 0.172600 0.248261 0.910400 0.910426
2400 0.146400 0.240490 0.919200 0.919168
2450 0.145300 0.248667 0.908800 0.908598
2500 0.108100 0.305569 0.899200 0.899355
2550 0.140500 0.277428 0.910400 0.910142
2600 0.152600 0.284274 0.916800 0.916913
2650 0.141600 0.252135 0.916800 0.916731
2700 0.133600 0.268836 0.912000 0.911958
2750 0.151300 0.262349 0.914400 0.913991
2800 0.121400 0.255012 0.917600 0.917494
2850 0.131700 0.269369 0.913600 0.913222
2900 0.152800 0.246611 0.916800 0.916716
2950 0.132300 0.251266 0.908000 0.907724
3000 0.142600 0.265384 0.918400 0.918532
3050 0.139300 0.260968 0.909600 0.909483
3100 0.137700 0.303503 0.904000 0.903330
3150 0.121600 0.266256 0.920000 0.920056
3200 0.101800 0.283502 0.911200 0.911288
3250 0.090600 0.293355 0.912800 0.912635
3300 0.112900 0.290351 0.912000 0.911958
3350 0.086300 0.298528 0.916000 0.915907
3400 0.086100 0.310346 0.915200 0.915280
3450 0.093600 0.306871 0.914400 0.914352
3500 0.107200 0.312322 0.917600 0.917606
3550 0.121300 0.315530 0.912800 0.912844
3600 0.117800 0.316997 0.915200 0.914956
3650 0.133400 0.311634 0.911200 0.911179
3700 0.107600 0.316201 0.915200 0.915248
3750 0.107800 0.312200 0.916000 0.916019
3800 0.086900 0.314815 0.912800 0.912844
3850 0.105500 0.314706 0.913600 0.913600
3900 0.069900 0.314938 0.912800 0.912780

LoRA (All Attention) Test Results: {'eval_loss': 0.22863808274269104, 'eval_accuracy': 0.9094315789473684, 'eval_f1': 0.9093754564531454, 'eval_runtime': 129.5483, 'eval_samples_per_second': 183.329, 'eval_steps_per_second': 5.735, 'epoch': 5.0}
Prediction drift
[[-1.9898635  1.6761417]
 [-2.4649115  2.2287219]
 [ 2.1762567 -1.8627577]
 [-1.866952   1.787821 ]
 [ 2.8843796 -2.5472975]] [1 1 0 1 0]
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      | 553486 KiB |   3293 MiB |  90884 GiB |  90883 GiB |
|       from large pool | 546048 KiB |   3264 MiB |  90258 GiB |  90258 GiB |
|       from small pool |   7438 KiB |     31 MiB |    625 GiB |    625 GiB |
|---------------------------------------------------------------------------|
| Active memory         | 553486 KiB |   3293 MiB |  90884 GiB |  90883 GiB |
|       from large pool | 546048 KiB |   3264 MiB |  90258 GiB |  90258 GiB |
|       from small pool |   7438 KiB |     31 MiB |    625 GiB |    625 GiB |
|---------------------------------------------------------------------------|
| Requested memory      | 551272 KiB |   3290 MiB |  90677 GiB |  90676 GiB |
|       from large pool | 543836 KiB |   3260 MiB |  90054 GiB |  90053 GiB |
|       from small pool |   7436 KiB |     31 MiB |    623 GiB |    623 GiB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   3514 MiB |   3514 MiB |   3514 MiB |      0 B   |
|       from large pool |   3476 MiB |   3476 MiB |   3476 MiB |      0 B   |
|       from small pool |     38 MiB |     38 MiB |     38 MiB |      0 B   |
|---------------------------------------------------------------------------|
| Non-releasable memory |  67058 KiB | 308581 KiB |  40038 GiB |  40038 GiB |
|       from large pool |  64256 KiB | 300928 KiB |  39368 GiB |  39368 GiB |
|       from small pool |   2802 KiB |   9018 KiB |    670 GiB |    670 GiB |
|---------------------------------------------------------------------------|
| Allocations           |     364    |     491    |   10562 K  |   10562 K  |
|       from large pool |      82    |     140    |    3162 K  |    3162 K  |
|       from small pool |     282    |     395    |    7400 K  |    7400 K  |
|---------------------------------------------------------------------------|
| Active allocs         |     364    |     491    |   10562 K  |   10562 K  |
|       from large pool |      82    |     140    |    3162 K  |    3162 K  |
|       from small pool |     282    |     395    |    7400 K  |    7400 K  |
|---------------------------------------------------------------------------|
| GPU reserved segments |      79    |      79    |      79    |       0    |
|       from large pool |      60    |      60    |      60    |       0    |
|       from small pool |      19    |      19    |      19    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |      27    |      44    |    4813 K  |    4813 K  |
|       from large pool |      18    |      23    |     935 K  |     935 K  |
|       from small pool |       9    |      23    |    3878 K  |    3878 K  |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

LoRA B norms by layer
[Figure: bar plot of LoRA B weight norms by layer]
In [8]:
print('Parameter Statistics: mean.abs()')
for name, param in model_lora_all_attn.named_parameters():
    if "lora" in name:
        print(name, param.abs().mean().item())
Parameter Statistics: mean.abs()
distilbert.transformer.layer.0.attention.q_lin.lora.A 0.20137116312980652
distilbert.transformer.layer.0.attention.q_lin.lora.B 0.00041380099719390273
distilbert.transformer.layer.0.attention.k_lin.lora.A 0.2003246247768402
distilbert.transformer.layer.0.attention.k_lin.lora.B 0.00042878554086200893
distilbert.transformer.layer.0.attention.v_lin.lora.A 0.19911907613277435
distilbert.transformer.layer.0.attention.v_lin.lora.B 0.000338834710419178
distilbert.transformer.layer.0.attention.out_lin.lora.A 0.20047757029533386
distilbert.transformer.layer.0.attention.out_lin.lora.B 0.00034657015930861235
distilbert.transformer.layer.1.attention.q_lin.lora.A 0.1993226706981659
distilbert.transformer.layer.1.attention.q_lin.lora.B 0.0004094692412763834
distilbert.transformer.layer.1.attention.k_lin.lora.A 0.1983720064163208
distilbert.transformer.layer.1.attention.k_lin.lora.B 0.000427080609370023
distilbert.transformer.layer.1.attention.v_lin.lora.A 0.20035219192504883
distilbert.transformer.layer.1.attention.v_lin.lora.B 0.00030109973158687353
distilbert.transformer.layer.1.attention.out_lin.lora.A 0.20086750388145447
distilbert.transformer.layer.1.attention.out_lin.lora.B 0.00034795209649018943
distilbert.transformer.layer.2.attention.q_lin.lora.A 0.19976946711540222
distilbert.transformer.layer.2.attention.q_lin.lora.B 0.00040535288280807436
distilbert.transformer.layer.2.attention.k_lin.lora.A 0.19660572707653046
distilbert.transformer.layer.2.attention.k_lin.lora.B 0.00038226376636885107
distilbert.transformer.layer.2.attention.v_lin.lora.A 0.19878578186035156
distilbert.transformer.layer.2.attention.v_lin.lora.B 0.00027373715420253575
distilbert.transformer.layer.2.attention.out_lin.lora.A 0.19986116886138916
distilbert.transformer.layer.2.attention.out_lin.lora.B 0.0003197303449269384
distilbert.transformer.layer.3.attention.q_lin.lora.A 0.1998608410358429
distilbert.transformer.layer.3.attention.q_lin.lora.B 0.00037932227132841945
distilbert.transformer.layer.3.attention.k_lin.lora.A 0.2004813402891159
distilbert.transformer.layer.3.attention.k_lin.lora.B 0.00038690734072588384
distilbert.transformer.layer.3.attention.v_lin.lora.A 0.2002096176147461
distilbert.transformer.layer.3.attention.v_lin.lora.B 0.00028052262496203184
distilbert.transformer.layer.3.attention.out_lin.lora.A 0.19950413703918457
distilbert.transformer.layer.3.attention.out_lin.lora.B 0.00029115224606357515
distilbert.transformer.layer.4.attention.q_lin.lora.A 0.19811344146728516
distilbert.transformer.layer.4.attention.q_lin.lora.B 0.0004094548639841378
distilbert.transformer.layer.4.attention.k_lin.lora.A 0.1975163221359253
distilbert.transformer.layer.4.attention.k_lin.lora.B 0.00039162003668025136
distilbert.transformer.layer.4.attention.v_lin.lora.A 0.1999872624874115
distilbert.transformer.layer.4.attention.v_lin.lora.B 0.000269879907136783
distilbert.transformer.layer.4.attention.out_lin.lora.A 0.19805268943309784
distilbert.transformer.layer.4.attention.out_lin.lora.B 0.00028756255051121116
distilbert.transformer.layer.5.attention.q_lin.lora.A 0.20075345039367676
distilbert.transformer.layer.5.attention.q_lin.lora.B 0.00037421341403387487
distilbert.transformer.layer.5.attention.k_lin.lora.A 0.19600680470466614
distilbert.transformer.layer.5.attention.k_lin.lora.B 0.0003485583874862641
distilbert.transformer.layer.5.attention.v_lin.lora.A 0.19779205322265625
distilbert.transformer.layer.5.attention.v_lin.lora.B 0.0002423622936476022
distilbert.transformer.layer.5.attention.out_lin.lora.A 0.19977574050426483
distilbert.transformer.layer.5.attention.out_lin.lora.B 0.0003247036365792155
In [9]:
print('Parameter Statistics: param.norm()')
for name, param in model_lora_all_attn.named_parameters():
    if "lora" in name or name.endswith(".m"):
        print(f"{name} weight norm: {param.norm().item():.4f}")
Parameter Statistics: param.norm()
distilbert.transformer.layer.0.attention.q_lin.lora.A weight norm: 27.8491
distilbert.transformer.layer.0.attention.q_lin.lora.B weight norm: 0.0578
distilbert.transformer.layer.0.attention.k_lin.lora.A weight norm: 27.7708
distilbert.transformer.layer.0.attention.k_lin.lora.B weight norm: 0.0597
distilbert.transformer.layer.0.attention.v_lin.lora.A weight norm: 27.5879
distilbert.transformer.layer.0.attention.v_lin.lora.B weight norm: 0.0473
distilbert.transformer.layer.0.attention.out_lin.lora.A weight norm: 27.8318
distilbert.transformer.layer.0.attention.out_lin.lora.B weight norm: 0.0484
distilbert.transformer.layer.1.attention.q_lin.lora.A weight norm: 27.8054
distilbert.transformer.layer.1.attention.q_lin.lora.B weight norm: 0.0578
distilbert.transformer.layer.1.attention.k_lin.lora.A weight norm: 27.5557
distilbert.transformer.layer.1.attention.k_lin.lora.B weight norm: 0.0595
distilbert.transformer.layer.1.attention.v_lin.lora.A weight norm: 27.9069
distilbert.transformer.layer.1.attention.v_lin.lora.B weight norm: 0.0428
distilbert.transformer.layer.1.attention.out_lin.lora.A weight norm: 27.9222
distilbert.transformer.layer.1.attention.out_lin.lora.B weight norm: 0.0486
distilbert.transformer.layer.2.attention.q_lin.lora.A weight norm: 27.7035
distilbert.transformer.layer.2.attention.q_lin.lora.B weight norm: 0.0568
distilbert.transformer.layer.2.attention.k_lin.lora.A weight norm: 27.3314
distilbert.transformer.layer.2.attention.k_lin.lora.B weight norm: 0.0532
distilbert.transformer.layer.2.attention.v_lin.lora.A weight norm: 27.6290
distilbert.transformer.layer.2.attention.v_lin.lora.B weight norm: 0.0388
distilbert.transformer.layer.2.attention.out_lin.lora.A weight norm: 27.6716
distilbert.transformer.layer.2.attention.out_lin.lora.B weight norm: 0.0449
distilbert.transformer.layer.3.attention.q_lin.lora.A weight norm: 27.6330
distilbert.transformer.layer.3.attention.q_lin.lora.B weight norm: 0.0529
distilbert.transformer.layer.3.attention.k_lin.lora.A weight norm: 27.8436
distilbert.transformer.layer.3.attention.k_lin.lora.B weight norm: 0.0541
distilbert.transformer.layer.3.attention.v_lin.lora.A weight norm: 27.8198
distilbert.transformer.layer.3.attention.v_lin.lora.B weight norm: 0.0395
distilbert.transformer.layer.3.attention.out_lin.lora.A weight norm: 27.7259
distilbert.transformer.layer.3.attention.out_lin.lora.B weight norm: 0.0411
distilbert.transformer.layer.4.attention.q_lin.lora.A weight norm: 27.6222
distilbert.transformer.layer.4.attention.q_lin.lora.B weight norm: 0.0573
distilbert.transformer.layer.4.attention.k_lin.lora.A weight norm: 27.4509
distilbert.transformer.layer.4.attention.k_lin.lora.B weight norm: 0.0547
distilbert.transformer.layer.4.attention.v_lin.lora.A weight norm: 27.8650
distilbert.transformer.layer.4.attention.v_lin.lora.B weight norm: 0.0380
distilbert.transformer.layer.4.attention.out_lin.lora.A weight norm: 27.6073
distilbert.transformer.layer.4.attention.out_lin.lora.B weight norm: 0.0405
distilbert.transformer.layer.5.attention.q_lin.lora.A weight norm: 27.9097
distilbert.transformer.layer.5.attention.q_lin.lora.B weight norm: 0.0525
distilbert.transformer.layer.5.attention.k_lin.lora.A weight norm: 27.3422
distilbert.transformer.layer.5.attention.k_lin.lora.B weight norm: 0.0490
distilbert.transformer.layer.5.attention.v_lin.lora.A weight norm: 27.4174
distilbert.transformer.layer.5.attention.v_lin.lora.B weight norm: 0.0341
distilbert.transformer.layer.5.attention.out_lin.lora.A weight norm: 27.7203
distilbert.transformer.layer.5.attention.out_lin.lora.B weight norm: 0.0450

DoRA

In [10]:
import copy
torch.manual_seed(137)

lora_rank = 16
lora_alpha = 32
batch_size = 32
learning_rate = 2e-2  # note: a much higher learning rate than the LoRA run (1.5e-5)
weight_decay = 1e-4
scaling_factor = 2.0
output_dir_prefix = "finetuned-imdb-"

model_dora_all_attn = copy.deepcopy(model)
model_dora_all_attn = inject_dora_all_attn(model_dora_all_attn, lora_rank, lora_alpha, scaling_factor)
freeze_model_layers(model_dora_all_attn, unfreeze_pre_classifier=True)

total_params_dora, trainable_params_dora, percentage_dora = count_trainable_parameters(model_dora_all_attn)
print(f"\nDoRA (All Attention) - Total parameters: {total_params_dora:,}")
print(f"DoRA (All Attention) - Trainable parameters: {trainable_params_dora:,} ({percentage_dora:.2f}%)")

# Sanity check
print("\nTrainable parameters after freezing:")
for name, param in model_dora_all_attn.named_parameters():
    if param.requires_grad:
        print(name)

eval_steps = 50
logging_steps = 50

training_args_dora_all_attn = TrainingArguments(
    output_dir=f"{output_dir_prefix}dora-all-attn",
    num_train_epochs=5,
    #max_steps=100,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    logging_steps=logging_steps,
    save_steps=eval_steps,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    disable_tqdm=False,
    push_to_hub=False,
    report_to="none",
    log_level="error"
)

trainer_dora_all_attn = Trainer(model=model_dora_all_attn, args=training_args_dora_all_attn, 
                                train_dataset=dataset_encoded["train"], eval_dataset=dataset_encoded["validation"], compute_metrics=compute_metrics)
DoRA (All Attention) - Total parameters: 67,581,698
DoRA (All Attention) - Trainable parameters: 1,218,818 (1.80%)

Trainable parameters after freezing:
distilbert.transformer.layer.0.attention.q_lin.m
distilbert.transformer.layer.0.attention.q_lin.scale
distilbert.transformer.layer.0.attention.q_lin.lora.A
distilbert.transformer.layer.0.attention.q_lin.lora.B
distilbert.transformer.layer.0.attention.k_lin.m
distilbert.transformer.layer.0.attention.k_lin.scale
distilbert.transformer.layer.0.attention.k_lin.lora.A
distilbert.transformer.layer.0.attention.k_lin.lora.B
distilbert.transformer.layer.0.attention.v_lin.m
distilbert.transformer.layer.0.attention.v_lin.scale
distilbert.transformer.layer.0.attention.v_lin.lora.A
distilbert.transformer.layer.0.attention.v_lin.lora.B
distilbert.transformer.layer.0.attention.out_lin.m
distilbert.transformer.layer.0.attention.out_lin.scale
distilbert.transformer.layer.0.attention.out_lin.lora.A
distilbert.transformer.layer.0.attention.out_lin.lora.B
distilbert.transformer.layer.1.attention.q_lin.m
distilbert.transformer.layer.1.attention.q_lin.scale
distilbert.transformer.layer.1.attention.q_lin.lora.A
distilbert.transformer.layer.1.attention.q_lin.lora.B
distilbert.transformer.layer.1.attention.k_lin.m
distilbert.transformer.layer.1.attention.k_lin.scale
distilbert.transformer.layer.1.attention.k_lin.lora.A
distilbert.transformer.layer.1.attention.k_lin.lora.B
distilbert.transformer.layer.1.attention.v_lin.m
distilbert.transformer.layer.1.attention.v_lin.scale
distilbert.transformer.layer.1.attention.v_lin.lora.A
distilbert.transformer.layer.1.attention.v_lin.lora.B
distilbert.transformer.layer.1.attention.out_lin.m
distilbert.transformer.layer.1.attention.out_lin.scale
distilbert.transformer.layer.1.attention.out_lin.lora.A
distilbert.transformer.layer.1.attention.out_lin.lora.B
distilbert.transformer.layer.2.attention.q_lin.m
distilbert.transformer.layer.2.attention.q_lin.scale
distilbert.transformer.layer.2.attention.q_lin.lora.A
distilbert.transformer.layer.2.attention.q_lin.lora.B
distilbert.transformer.layer.2.attention.k_lin.m
distilbert.transformer.layer.2.attention.k_lin.scale
distilbert.transformer.layer.2.attention.k_lin.lora.A
distilbert.transformer.layer.2.attention.k_lin.lora.B
distilbert.transformer.layer.2.attention.v_lin.m
distilbert.transformer.layer.2.attention.v_lin.scale
distilbert.transformer.layer.2.attention.v_lin.lora.A
distilbert.transformer.layer.2.attention.v_lin.lora.B
distilbert.transformer.layer.2.attention.out_lin.m
distilbert.transformer.layer.2.attention.out_lin.scale
distilbert.transformer.layer.2.attention.out_lin.lora.A
distilbert.transformer.layer.2.attention.out_lin.lora.B
distilbert.transformer.layer.3.attention.q_lin.m
distilbert.transformer.layer.3.attention.q_lin.scale
distilbert.transformer.layer.3.attention.q_lin.lora.A
distilbert.transformer.layer.3.attention.q_lin.lora.B
distilbert.transformer.layer.3.attention.k_lin.m
distilbert.transformer.layer.3.attention.k_lin.scale
distilbert.transformer.layer.3.attention.k_lin.lora.A
distilbert.transformer.layer.3.attention.k_lin.lora.B
distilbert.transformer.layer.3.attention.v_lin.m
distilbert.transformer.layer.3.attention.v_lin.scale
distilbert.transformer.layer.3.attention.v_lin.lora.A
distilbert.transformer.layer.3.attention.v_lin.lora.B
distilbert.transformer.layer.3.attention.out_lin.m
distilbert.transformer.layer.3.attention.out_lin.scale
distilbert.transformer.layer.3.attention.out_lin.lora.A
distilbert.transformer.layer.3.attention.out_lin.lora.B
distilbert.transformer.layer.4.attention.q_lin.m
distilbert.transformer.layer.4.attention.q_lin.scale
distilbert.transformer.layer.4.attention.q_lin.lora.A
distilbert.transformer.layer.4.attention.q_lin.lora.B
distilbert.transformer.layer.4.attention.k_lin.m
distilbert.transformer.layer.4.attention.k_lin.scale
distilbert.transformer.layer.4.attention.k_lin.lora.A
distilbert.transformer.layer.4.attention.k_lin.lora.B
distilbert.transformer.layer.4.attention.v_lin.m
distilbert.transformer.layer.4.attention.v_lin.scale
distilbert.transformer.layer.4.attention.v_lin.lora.A
distilbert.transformer.layer.4.attention.v_lin.lora.B
distilbert.transformer.layer.4.attention.out_lin.m
distilbert.transformer.layer.4.attention.out_lin.scale
distilbert.transformer.layer.4.attention.out_lin.lora.A
distilbert.transformer.layer.4.attention.out_lin.lora.B
distilbert.transformer.layer.5.attention.q_lin.m
distilbert.transformer.layer.5.attention.q_lin.scale
distilbert.transformer.layer.5.attention.q_lin.lora.A
distilbert.transformer.layer.5.attention.q_lin.lora.B
distilbert.transformer.layer.5.attention.k_lin.m
distilbert.transformer.layer.5.attention.k_lin.scale
distilbert.transformer.layer.5.attention.k_lin.lora.A
distilbert.transformer.layer.5.attention.k_lin.lora.B
distilbert.transformer.layer.5.attention.v_lin.m
distilbert.transformer.layer.5.attention.v_lin.scale
distilbert.transformer.layer.5.attention.v_lin.lora.A
distilbert.transformer.layer.5.attention.v_lin.lora.B
distilbert.transformer.layer.5.attention.out_lin.m
distilbert.transformer.layer.5.attention.out_lin.scale
distilbert.transformer.layer.5.attention.out_lin.lora.A
distilbert.transformer.layer.5.attention.out_lin.lora.B
pre_classifier.weight
pre_classifier.bias
classifier.weight
classifier.bias
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
In [11]:
trainer_dora_all_attn.train()
eval_results_dora_all_attn = trainer_dora_all_attn.evaluate(dataset_encoded["test"])
print(f"DoRA (All Attention) Test Results: {eval_results_dora_all_attn}")
print('Prediction drift')
outputs = trainer_dora_all_attn.predict(dataset_encoded["validation"])
print(outputs.predictions[:5], outputs.label_ids[:5])

print(torch.cuda.memory_summary())

print('LoRA B norms by layer')
layer_names = []
b_norms = []

for name, param in model_dora_all_attn.named_parameters():
    if "lora.B" in name:
        layer_names.append(name)
        b_norms.append(param.norm().item())

sns.barplot(x=b_norms, y=layer_names, color='navy')
plt.xlabel("Weight Norm")
plt.title("LoRA B Norms by Layer")
plt.tight_layout()
plt.show()

m_values = []

for name, module in model_dora_all_attn.named_modules():
    if isinstance(module, LinearWithDoRA):
        if hasattr(module, 'm'):
            m_param = module.m.detach().cpu().numpy().flatten()
            m_values.extend(m_param)

m_values = np.array(m_values)

# Analyze the distribution
print(f"Mean of m values: {np.mean(m_values):.4f}")
print(f"Standard deviation of m values: {np.std(m_values):.4f}")
print(f"Minimum m value: {np.min(m_values):.4f}")
print(f"Maximum m value: {np.max(m_values):.4f}")

# Plot a histogram
plt.hist(m_values, bins=50, alpha=0.7)
plt.title('Distribution of Learned m Values (DoRA)')
plt.xlabel('Magnitude (m)')
plt.ylabel('Frequency')
plt.show()
[3910/3910 2:15:09, Epoch 5/5]
Step Training Loss Validation Loss Accuracy F1
50 1.102800 0.291960 0.870400 0.869901
100 0.356100 0.278253 0.894400 0.894522
150 0.343500 0.268145 0.897600 0.897671
200 0.305300 0.261925 0.897600 0.897630
250 0.321600 0.257692 0.904800 0.904558
300 0.305000 0.272489 0.887200 0.887449
350 0.300000 0.259093 0.901600 0.901282
400 0.266300 0.279098 0.884000 0.884259
450 0.258800 0.272474 0.892000 0.891513
500 0.288000 0.242344 0.905600 0.905585
550 0.280400 0.270589 0.901600 0.900959
600 0.301900 0.249609 0.897600 0.897794
650 0.277600 0.265940 0.904800 0.904469
700 0.278000 0.279490 0.911200 0.910675
750 0.293100 0.268084 0.897600 0.896787
800 0.254900 0.299735 0.896800 0.895963
850 0.253600 0.252459 0.896800 0.896935
900 0.235600 0.269197 0.900800 0.900881
950 0.271800 0.285552 0.902400 0.902503
1000 0.232000 0.225912 0.908800 0.908579
1050 0.219700 0.244564 0.914400 0.914337
1100 0.234700 0.229250 0.913600 0.913445
1150 0.217400 0.221898 0.905600 0.905600
1200 0.230000 0.237117 0.909600 0.909350
1250 0.221200 0.219938 0.912000 0.911958
1300 0.255900 0.222013 0.908000 0.908007
1350 0.260000 0.216425 0.915200 0.914956
1400 0.223400 0.217995 0.911200 0.911256
1450 0.219600 0.222689 0.912000 0.912000
1500 0.215300 0.224147 0.914400 0.914337
1550 0.227700 0.224390 0.919200 0.919251
1600 0.196300 0.254192 0.920800 0.920870
1650 0.182700 0.261529 0.916000 0.915823
1700 0.205700 0.254054 0.916800 0.916579
1750 0.199100 0.270258 0.912000 0.912000
1800 0.209500 0.224310 0.917600 0.917509
1850 0.211800 0.236137 0.915200 0.915200
1900 0.177900 0.244019 0.916800 0.916760
1950 0.204800 0.241636 0.916800 0.916836
2000 0.205300 0.228495 0.919200 0.919280
2050 0.212400 0.233124 0.918400 0.918435
2100 0.200900 0.225622 0.911200 0.911278
2150 0.153700 0.256996 0.912000 0.912061
2200 0.209800 0.220482 0.914400 0.914352
2250 0.196600 0.229680 0.916000 0.915858
2300 0.202500 0.218464 0.910400 0.910504
2350 0.181500 0.225591 0.917600 0.917509
2400 0.185800 0.250574 0.912000 0.911543
2450 0.157600 0.230985 0.915200 0.914915
2500 0.141200 0.245401 0.910400 0.910473
2550 0.136100 0.296957 0.918400 0.918165
2600 0.167800 0.249483 0.902400 0.902613
2650 0.157600 0.243417 0.912000 0.911705
2700 0.157700 0.241792 0.918400 0.918412
2750 0.158800 0.230028 0.912000 0.911683
2800 0.133300 0.229130 0.916800 0.916651
2850 0.145600 0.294840 0.917600 0.917251
2900 0.157700 0.267986 0.916000 0.915728
2950 0.150400 0.273056 0.912800 0.912384
3000 0.157900 0.240002 0.918400 0.918184
3050 0.157400 0.243159 0.920000 0.919919
3100 0.148800 0.250069 0.914400 0.913894
3150 0.124400 0.255732 0.916800 0.916700
3200 0.127000 0.246215 0.920800 0.920698
3250 0.116300 0.283555 0.912000 0.911786
3300 0.110700 0.277807 0.912000 0.911860
3350 0.112800 0.296765 0.916800 0.916634
3400 0.115000 0.259858 0.912800 0.912670
3450 0.117300 0.274300 0.912800 0.912598
3500 0.110300 0.284478 0.915200 0.915114
3550 0.123000 0.267429 0.916000 0.915923
3600 0.126300 0.266704 0.918400 0.918202
3650 0.131900 0.254044 0.916000 0.915892
3700 0.125600 0.256258 0.916800 0.916634
3750 0.109800 0.262411 0.916000 0.915938
3800 0.105500 0.270246 0.914400 0.914238
3850 0.115400 0.272160 0.914400 0.914238
3900 0.108500 0.268805 0.914400 0.914238

DoRA (All Attention) Test Results: {'eval_loss': 0.21464677155017853, 'eval_accuracy': 0.9129263157894737, 'eval_f1': 0.9128732729336847, 'eval_runtime': 146.8168, 'eval_samples_per_second': 161.766, 'eval_steps_per_second': 5.061, 'epoch': 5.0}
Prediction drift
[[-2.0174963  1.9359767]
 [-2.774578   2.4995706]
 [ 1.9803579 -2.1773782]
 [-1.6433766  1.4992952]
 [ 6.4219594 -7.323969 ]] [1 1 0 1 0]
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |    809 MiB |   5291 MiB | 230768 GiB | 230767 GiB |
|       from large pool |    794 MiB |   5252 MiB | 229489 GiB | 229488 GiB |
|       from small pool |     14 MiB |     40 MiB |   1278 GiB |   1278 GiB |
|---------------------------------------------------------------------------|
| Active memory         |    809 MiB |   5291 MiB | 230768 GiB | 230767 GiB |
|       from large pool |    794 MiB |   5252 MiB | 229489 GiB | 229488 GiB |
|       from small pool |     14 MiB |     40 MiB |   1278 GiB |   1278 GiB |
|---------------------------------------------------------------------------|
| Requested memory      |    805 MiB |   5287 MiB | 230551 GiB | 230550 GiB |
|       from large pool |    790 MiB |   5248 MiB | 229278 GiB | 229277 GiB |
|       from small pool |     14 MiB |     40 MiB |   1273 GiB |   1273 GiB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   5500 MiB |   5500 MiB |   8120 MiB |   2620 MiB |
|       from large pool |   5456 MiB |   5456 MiB |   8048 MiB |   2592 MiB |
|       from small pool |     44 MiB |     44 MiB |     72 MiB |     28 MiB |
|---------------------------------------------------------------------------|
| Non-releasable memory |  98691 KiB | 545173 KiB |  88783 GiB |  88783 GiB |
|       from large pool |  91218 KiB | 540498 KiB |  87421 GiB |  87420 GiB |
|       from small pool |   7473 KiB |  17003 KiB |   1362 GiB |   1362 GiB |
|---------------------------------------------------------------------------|
| Allocations           |     765    |    1012    |   26004 K  |   26003 K  |
|       from large pool |     123    |     229    |    7919 K  |    7919 K  |
|       from small pool |     642    |     851    |   18084 K  |   18084 K  |
|---------------------------------------------------------------------------|
| Active allocs         |     765    |    1012    |   26004 K  |   26003 K  |
|       from large pool |     123    |     229    |    7919 K  |    7919 K  |
|       from small pool |     642    |     851    |   18084 K  |   18084 K  |
|---------------------------------------------------------------------------|
| GPU reserved segments |     134    |     134    |     181    |      47    |
|       from large pool |     112    |     112    |     145    |      33    |
|       from small pool |      22    |      22    |      36    |      14    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |      35    |      53    |   12196 K  |   12196 K  |
|       from large pool |      17    |      26    |    2304 K  |    2304 K  |
|       from small pool |      18    |      32    |    9892 K  |    9892 K  |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

LoRA B norms by layer
[Figure: bar plot of LoRA B weight norms by layer (DoRA model)]
Mean of m values: 0.0014
Standard deviation of m values: 0.5832
Minimum m value: -4.2395
Maximum m value: 3.2943
[Figure: histogram of the learned m values (DoRA)]
In [12]:
print('Parameter Statistics: mean.abs()')
for name, param in model_dora_all_attn.named_parameters():
    if "lin.m" in name:
        print(name, param.abs().mean().item())
print('Parameter Statistics: mean.abs()')
for name, param in model_dora_all_attn.named_parameters():
    if "lin.scale" in name:
        print(name, param.abs().mean().item())
print('Parameter Statistics: mean.abs()')
for name, param in model_dora_all_attn.named_parameters():
    if "lora" in name:
        print(name, param.abs().mean().item())
Parameter Statistics: mean.abs()
distilbert.transformer.layer.0.attention.q_lin.m 0.6335324048995972
distilbert.transformer.layer.0.attention.k_lin.m 0.5340768694877625
distilbert.transformer.layer.0.attention.v_lin.m 0.4119078516960144
distilbert.transformer.layer.0.attention.out_lin.m 0.20187006890773773
distilbert.transformer.layer.1.attention.q_lin.m 0.45603156089782715
distilbert.transformer.layer.1.attention.k_lin.m 0.39353370666503906
distilbert.transformer.layer.1.attention.v_lin.m 0.2410321682691574
distilbert.transformer.layer.1.attention.out_lin.m 0.22153238952159882
distilbert.transformer.layer.2.attention.q_lin.m 0.4081450402736664
distilbert.transformer.layer.2.attention.k_lin.m 0.44084256887435913
distilbert.transformer.layer.2.attention.v_lin.m 0.24717780947685242
distilbert.transformer.layer.2.attention.out_lin.m 0.16525350511074066
distilbert.transformer.layer.3.attention.q_lin.m 0.4876405596733093
distilbert.transformer.layer.3.attention.k_lin.m 0.4945552349090576
distilbert.transformer.layer.3.attention.v_lin.m 0.2434060424566269
distilbert.transformer.layer.3.attention.out_lin.m 0.24130374193191528
distilbert.transformer.layer.4.attention.q_lin.m 0.48577117919921875
distilbert.transformer.layer.4.attention.k_lin.m 0.5650743246078491
distilbert.transformer.layer.4.attention.v_lin.m 0.18068918585777283
distilbert.transformer.layer.4.attention.out_lin.m 0.2408987283706665
distilbert.transformer.layer.5.attention.q_lin.m 0.37161561846733093
distilbert.transformer.layer.5.attention.k_lin.m 0.6317863464355469
distilbert.transformer.layer.5.attention.v_lin.m 0.19014406204223633
distilbert.transformer.layer.5.attention.out_lin.m 0.28851351141929626
Parameter Statistics: mean.abs()
distilbert.transformer.layer.0.attention.q_lin.scale 2.162614583969116
distilbert.transformer.layer.0.attention.k_lin.scale 1.9926241636276245
distilbert.transformer.layer.0.attention.v_lin.scale 1.9132812023162842
distilbert.transformer.layer.0.attention.out_lin.scale 1.6308307647705078
distilbert.transformer.layer.1.attention.q_lin.scale 1.9580211639404297
distilbert.transformer.layer.1.attention.k_lin.scale 1.8660268783569336
distilbert.transformer.layer.1.attention.v_lin.scale 1.6165016889572144
distilbert.transformer.layer.1.attention.out_lin.scale 1.5104354619979858
distilbert.transformer.layer.2.attention.q_lin.scale 1.838735818862915
distilbert.transformer.layer.2.attention.k_lin.scale 1.8984363079071045
distilbert.transformer.layer.2.attention.v_lin.scale 1.6771743297576904
distilbert.transformer.layer.2.attention.out_lin.scale 1.53759765625
distilbert.transformer.layer.3.attention.q_lin.scale 1.9363187551498413
distilbert.transformer.layer.3.attention.k_lin.scale 1.934944987297058
distilbert.transformer.layer.3.attention.v_lin.scale 1.6585863828659058
distilbert.transformer.layer.3.attention.out_lin.scale 1.585425853729248
distilbert.transformer.layer.4.attention.q_lin.scale 2.027859687805176
distilbert.transformer.layer.4.attention.k_lin.scale 2.082413673400879
distilbert.transformer.layer.4.attention.v_lin.scale 1.5296372175216675
distilbert.transformer.layer.4.attention.out_lin.scale 1.618304967880249
distilbert.transformer.layer.5.attention.q_lin.scale 1.9490175247192383
distilbert.transformer.layer.5.attention.k_lin.scale 2.1951510906219482
distilbert.transformer.layer.5.attention.v_lin.scale 1.6163949966430664
distilbert.transformer.layer.5.attention.out_lin.scale 1.8345234394073486
Parameter Statistics: mean.abs()
distilbert.transformer.layer.0.attention.q_lin.lora.A 0.6016935110092163
distilbert.transformer.layer.0.attention.q_lin.lora.B 0.31266412138938904
distilbert.transformer.layer.0.attention.k_lin.lora.A 0.6316931247711182
distilbert.transformer.layer.0.attention.k_lin.lora.B 0.2952665090560913
distilbert.transformer.layer.0.attention.v_lin.lora.A 0.6079772710800171
distilbert.transformer.layer.0.attention.v_lin.lora.B 0.2492581456899643
distilbert.transformer.layer.0.attention.out_lin.lora.A 0.4950897693634033
distilbert.transformer.layer.0.attention.out_lin.lora.B 0.18595407903194427
distilbert.transformer.layer.1.attention.q_lin.lora.A 0.6037637591362
distilbert.transformer.layer.1.attention.q_lin.lora.B 0.2575852870941162
distilbert.transformer.layer.1.attention.k_lin.lora.A 0.5715399384498596
distilbert.transformer.layer.1.attention.k_lin.lora.B 0.21878382563591003
distilbert.transformer.layer.1.attention.v_lin.lora.A 0.5992859601974487
distilbert.transformer.layer.1.attention.v_lin.lora.B 0.23800623416900635
distilbert.transformer.layer.1.attention.out_lin.lora.A 0.5236724019050598
distilbert.transformer.layer.1.attention.out_lin.lora.B 0.24507002532482147
distilbert.transformer.layer.2.attention.q_lin.lora.A 0.5646731853485107
distilbert.transformer.layer.2.attention.q_lin.lora.B 0.28050732612609863
distilbert.transformer.layer.2.attention.k_lin.lora.A 0.5719441771507263
distilbert.transformer.layer.2.attention.k_lin.lora.B 0.2973446249961853
distilbert.transformer.layer.2.attention.v_lin.lora.A 0.5592368841171265
distilbert.transformer.layer.2.attention.v_lin.lora.B 0.20374992489814758
distilbert.transformer.layer.2.attention.out_lin.lora.A 0.4538516402244568
distilbert.transformer.layer.2.attention.out_lin.lora.B 0.18519560992717743
distilbert.transformer.layer.3.attention.q_lin.lora.A 0.5853655934333801
distilbert.transformer.layer.3.attention.q_lin.lora.B 0.30569520592689514
distilbert.transformer.layer.3.attention.k_lin.lora.A 0.6070346832275391
distilbert.transformer.layer.3.attention.k_lin.lora.B 0.31933528184890747
distilbert.transformer.layer.3.attention.v_lin.lora.A 0.5087218284606934
distilbert.transformer.layer.3.attention.v_lin.lora.B 0.21019062399864197
distilbert.transformer.layer.3.attention.out_lin.lora.A 0.5895759463310242
distilbert.transformer.layer.3.attention.out_lin.lora.B 0.21719345450401306
distilbert.transformer.layer.4.attention.q_lin.lora.A 0.6171450614929199
distilbert.transformer.layer.4.attention.q_lin.lora.B 0.21375727653503418
distilbert.transformer.layer.4.attention.k_lin.lora.A 0.49676698446273804
distilbert.transformer.layer.4.attention.k_lin.lora.B 0.2762928605079651
distilbert.transformer.layer.4.attention.v_lin.lora.A 0.5012800097465515
distilbert.transformer.layer.4.attention.v_lin.lora.B 0.17582902312278748
distilbert.transformer.layer.4.attention.out_lin.lora.A 0.5615952610969543
distilbert.transformer.layer.4.attention.out_lin.lora.B 0.17805537581443787
distilbert.transformer.layer.5.attention.q_lin.lora.A 0.44116735458374023
distilbert.transformer.layer.5.attention.q_lin.lora.B 0.16018781065940857
distilbert.transformer.layer.5.attention.k_lin.lora.A 0.47184187173843384
distilbert.transformer.layer.5.attention.k_lin.lora.B 0.27603811025619507
distilbert.transformer.layer.5.attention.v_lin.lora.A 0.39102858304977417
distilbert.transformer.layer.5.attention.v_lin.lora.B 0.14989939332008362
distilbert.transformer.layer.5.attention.out_lin.lora.A 0.4704250693321228
distilbert.transformer.layer.5.attention.out_lin.lora.B 0.1333407163619995
In [13]:
print('Parameter Statistics: param.norm()')
for name, param in model_dora_all_attn.named_parameters():
    if "lora" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
print('Parameter Statistics: param.norm()')
for name, param in model_dora_all_attn.named_parameters():
    if "lin.m" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
print('Parameter Statistics: param.norm()')
for name, param in model_dora_all_attn.named_parameters():
    if "lin.scale" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
Parameter Statistics: param.norm()
distilbert.transformer.layer.0.attention.q_lin.lora.A weight norm: 83.6438
distilbert.transformer.layer.0.attention.q_lin.lora.B weight norm: 50.3954
distilbert.transformer.layer.0.attention.k_lin.lora.A weight norm: 88.2270
distilbert.transformer.layer.0.attention.k_lin.lora.B weight norm: 49.0835
distilbert.transformer.layer.0.attention.v_lin.lora.A weight norm: 86.0185
distilbert.transformer.layer.0.attention.v_lin.lora.B weight norm: 41.9394
distilbert.transformer.layer.0.attention.out_lin.lora.A weight norm: 69.8777
distilbert.transformer.layer.0.attention.out_lin.lora.B weight norm: 32.6178
distilbert.transformer.layer.1.attention.q_lin.lora.A weight norm: 84.0461
distilbert.transformer.layer.1.attention.q_lin.lora.B weight norm: 42.5906
distilbert.transformer.layer.1.attention.k_lin.lora.A weight norm: 79.4528
distilbert.transformer.layer.1.attention.k_lin.lora.B weight norm: 37.7325
distilbert.transformer.layer.1.attention.v_lin.lora.A weight norm: 83.9680
distilbert.transformer.layer.1.attention.v_lin.lora.B weight norm: 39.0889
distilbert.transformer.layer.1.attention.out_lin.lora.A weight norm: 74.0652
distilbert.transformer.layer.1.attention.out_lin.lora.B weight norm: 40.8909
distilbert.transformer.layer.2.attention.q_lin.lora.A weight norm: 78.3896
distilbert.transformer.layer.2.attention.q_lin.lora.B weight norm: 45.2058
distilbert.transformer.layer.2.attention.k_lin.lora.A weight norm: 79.5360
distilbert.transformer.layer.2.attention.k_lin.lora.B weight norm: 47.1290
distilbert.transformer.layer.2.attention.v_lin.lora.A weight norm: 77.7262
distilbert.transformer.layer.2.attention.v_lin.lora.B weight norm: 34.1224
distilbert.transformer.layer.2.attention.out_lin.lora.A weight norm: 63.6211
distilbert.transformer.layer.2.attention.out_lin.lora.B weight norm: 31.4881
distilbert.transformer.layer.3.attention.q_lin.lora.A weight norm: 81.7695
distilbert.transformer.layer.3.attention.q_lin.lora.B weight norm: 49.9913
distilbert.transformer.layer.3.attention.k_lin.lora.A weight norm: 84.6226
distilbert.transformer.layer.3.attention.k_lin.lora.B weight norm: 50.7520
distilbert.transformer.layer.3.attention.v_lin.lora.A weight norm: 71.5488
distilbert.transformer.layer.3.attention.v_lin.lora.B weight norm: 34.2226
distilbert.transformer.layer.3.attention.out_lin.lora.A weight norm: 82.8060
distilbert.transformer.layer.3.attention.out_lin.lora.B weight norm: 36.0456
distilbert.transformer.layer.4.attention.q_lin.lora.A weight norm: 86.1940
distilbert.transformer.layer.4.attention.q_lin.lora.B weight norm: 37.5884
distilbert.transformer.layer.4.attention.k_lin.lora.A weight norm: 69.4333
distilbert.transformer.layer.4.attention.k_lin.lora.B weight norm: 43.1834
distilbert.transformer.layer.4.attention.v_lin.lora.A weight norm: 70.5987
distilbert.transformer.layer.4.attention.v_lin.lora.B weight norm: 28.1566
distilbert.transformer.layer.4.attention.out_lin.lora.A weight norm: 79.8895
distilbert.transformer.layer.4.attention.out_lin.lora.B weight norm: 29.1581
distilbert.transformer.layer.5.attention.q_lin.lora.A weight norm: 61.4309
distilbert.transformer.layer.5.attention.q_lin.lora.B weight norm: 28.0646
distilbert.transformer.layer.5.attention.k_lin.lora.A weight norm: 65.4785
distilbert.transformer.layer.5.attention.k_lin.lora.B weight norm: 40.8050
distilbert.transformer.layer.5.attention.v_lin.lora.A weight norm: 55.7645
distilbert.transformer.layer.5.attention.v_lin.lora.B weight norm: 23.9943
distilbert.transformer.layer.5.attention.out_lin.lora.A weight norm: 66.5217
distilbert.transformer.layer.5.attention.out_lin.lora.B weight norm: 21.6479
Parameter Statistics: param.norm()
distilbert.transformer.layer.0.attention.q_lin.m weight norm: 24.3380
distilbert.transformer.layer.0.attention.k_lin.m weight norm: 21.5448
distilbert.transformer.layer.0.attention.v_lin.m weight norm: 16.9134
distilbert.transformer.layer.0.attention.out_lin.m weight norm: 10.2597
distilbert.transformer.layer.1.attention.q_lin.m weight norm: 18.8046
distilbert.transformer.layer.1.attention.k_lin.m weight norm: 17.2833
distilbert.transformer.layer.1.attention.v_lin.m weight norm: 10.7684
distilbert.transformer.layer.1.attention.out_lin.m weight norm: 9.8165
distilbert.transformer.layer.2.attention.q_lin.m weight norm: 16.5732
distilbert.transformer.layer.2.attention.k_lin.m weight norm: 17.3353
distilbert.transformer.layer.2.attention.v_lin.m weight norm: 11.1499
distilbert.transformer.layer.2.attention.out_lin.m weight norm: 8.2443
distilbert.transformer.layer.3.attention.q_lin.m weight norm: 19.2109
distilbert.transformer.layer.3.attention.k_lin.m weight norm: 19.1746
distilbert.transformer.layer.3.attention.v_lin.m weight norm: 11.2522
distilbert.transformer.layer.3.attention.out_lin.m weight norm: 11.3760
distilbert.transformer.layer.4.attention.q_lin.m weight norm: 20.1803
distilbert.transformer.layer.4.attention.k_lin.m weight norm: 21.7132
distilbert.transformer.layer.4.attention.v_lin.m weight norm: 9.0074
distilbert.transformer.layer.4.attention.out_lin.m weight norm: 12.0543
distilbert.transformer.layer.5.attention.q_lin.m weight norm: 16.7970
distilbert.transformer.layer.5.attention.k_lin.m weight norm: 22.2302
distilbert.transformer.layer.5.attention.v_lin.m weight norm: 10.2690
distilbert.transformer.layer.5.attention.out_lin.m weight norm: 14.5015
Parameter Statistics: param.norm()
distilbert.transformer.layer.0.attention.q_lin.scale weight norm: 63.7493
distilbert.transformer.layer.0.attention.k_lin.scale weight norm: 58.9528
distilbert.transformer.layer.0.attention.v_lin.scale weight norm: 55.6887
distilbert.transformer.layer.0.attention.out_lin.scale weight norm: 47.4388
distilbert.transformer.layer.1.attention.q_lin.scale weight norm: 57.3731
distilbert.transformer.layer.1.attention.k_lin.scale weight norm: 54.8432
distilbert.transformer.layer.1.attention.v_lin.scale weight norm: 47.2854
distilbert.transformer.layer.1.attention.out_lin.scale weight norm: 44.8088
distilbert.transformer.layer.2.attention.q_lin.scale weight norm: 53.9172
distilbert.transformer.layer.2.attention.k_lin.scale weight norm: 55.4214
distilbert.transformer.layer.2.attention.v_lin.scale weight norm: 48.5935
distilbert.transformer.layer.2.attention.out_lin.scale weight norm: 44.7601
distilbert.transformer.layer.3.attention.q_lin.scale weight norm: 56.9923
distilbert.transformer.layer.3.attention.k_lin.scale weight norm: 56.9395
distilbert.transformer.layer.3.attention.v_lin.scale weight norm: 48.5040
distilbert.transformer.layer.3.attention.out_lin.scale weight norm: 47.1696
distilbert.transformer.layer.4.attention.q_lin.scale weight norm: 59.5007
distilbert.transformer.layer.4.attention.k_lin.scale weight norm: 61.3537
distilbert.transformer.layer.4.attention.v_lin.scale weight norm: 45.5443
distilbert.transformer.layer.4.attention.out_lin.scale weight norm: 48.1659
distilbert.transformer.layer.5.attention.q_lin.scale weight norm: 56.5651
distilbert.transformer.layer.5.attention.k_lin.scale weight norm: 63.7824
distilbert.transformer.layer.5.attention.v_lin.scale weight norm: 47.5776
distilbert.transformer.layer.5.attention.out_lin.scale weight norm: 53.4955

DDoRA, Double DoRA¶

(Double Weight-Decomposed Low-Rank Adaptation)¶
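DDoRA extends the DoRA idea to both sides of each wrapped linear layer: a learned per-input-feature magnitude and scale (m_in, scale_in) gate the input before it enters the LoRA path, and a learned per-output-feature magnitude and scale (m_out, scale_out) rescale the normalised LoRA update that is added to the frozen linear output. In the notation of the class below, the forward pass is roughly

output = W·x + scale_out * m_out * normalize(LoRA(scale_in * m_in * x))

where W is the frozen pretrained weight, normalize is an L2 normalisation of the LoRA update, and all products broadcast elementwise.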

In [14]:
class LinearWithDoubleDoRA(nn.Module):
    def __init__(self, linear, rank, alpha, scaling_factor=1.0):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        # Learned magnitude vectors: one per output feature and one per input feature
        self.m_out = nn.Parameter(torch.randn(1, linear.out_features) * std_dev)
        self.m_in = nn.Parameter(torch.randn(linear.in_features, 1) * std_dev)
        # Learned scale vectors, initialised to scaling_factor
        self.scale_out = nn.Parameter(torch.full((1, linear.out_features), float(scaling_factor)))
        self.scale_in = nn.Parameter(torch.full((linear.in_features, 1), float(scaling_factor)))

    def forward(self, x):
        # Gate and scale each input feature before it enters the LoRA path (broadcast over batch/sequence)
        scaled_x = x * self.scale_in.T * self.m_in.T
        linear_output = self.linear(x)
        lora_output = self.lora(scaled_x)
        # L2-normalise the LoRA update along dim=1, then rescale it per output feature
        lora_output_norm = lora_output / (lora_output.norm(p=2, dim=1, keepdim=True) + 1e-9)
        dora_modification = self.scale_out * self.m_out * lora_output_norm
        return linear_output + dora_modification
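A quick shape check helps confirm the wrapper is drop-in compatible with the layers it replaces. This is a minimal sketch that relies on the LoRALayer and LinearWithDoubleDoRA definitions above; the 768/384 sizes mirror DistilBERT's hidden size and the truncation length used earlier:

import torch
import torch.nn as nn

_lin = nn.Linear(768, 768)
_ddora = LinearWithDoubleDoRA(_lin, rank=16, alpha=32, scaling_factor=2.0)
_x = torch.randn(4, 384, 768)      # (batch, seq_len, hidden)
print(_ddora(_x).shape)            # torch.Size([4, 384, 768]) -- same shape as the plain nn.Linear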
In [15]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

def inject_ddora_all_attn(model, rank, alpha, scaling_factor=1.0):
    # Wrap every attention projection (q, k, v, out) in each transformer layer
    #target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin", "ffn.lin1", "ffn.lin2"]
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            dora_linear = LinearWithDoubleDoRA(original_linear, rank, alpha, scaling_factor)
            setattr(parent_module, name.split('.')[-1], dora_linear)
    return model
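The swap itself uses a standard PyTorch pattern: resolve the parent module from the dotted name returned by named_modules(), then setattr the wrapper in place of the original linear. The same pattern on a toy nested module (a standalone sketch, unrelated to the notebook's model):

import torch.nn as nn

toy = nn.Sequential(nn.Sequential(nn.Linear(8, 8)))
name = "0.0"                                         # dotted name as reported by named_modules()
parent = toy.get_submodule(name.rsplit('.', 1)[0])   # the inner Sequential
setattr(parent, name.split('.')[-1], nn.Identity())  # replace the Linear in place
print(toy)                                           # the Linear is now an Identity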
In [16]:
import seaborn as sns
import torch
import numpy as np
import matplotlib.pyplot as plt

import copy
torch.manual_seed(137)

lora_rank = 16
lora_alpha = 32
batch_size = 32
learning_rate = 0.015
weight_decay = 1e-4
scaling_factor = 2.0
output_dir_prefix = "finetuned-imdb-"

model_ddora_all_attn = copy.deepcopy(model)
model_ddora_all_attn = inject_ddora_all_attn(model_ddora_all_attn, lora_rank, lora_alpha, scaling_factor)
freeze_model_layers(model_ddora_all_attn, unfreeze_pre_classifier=True)

total_params_ddora, trainable_params_ddora, percentage_ddora = count_trainable_parameters(model_ddora_all_attn)
print(f"\nDDoRA (All Attention) - Total parameters: {total_params_ddora:,}")
print(f"DDoRA (All Attention) - Trainable parameters: {trainable_params_ddora:,} ({percentage_ddora:.2f}%)")

eval_steps = 50
logging_steps = 50
output_dir_prefix = "finetuned-imdb-"

training_args_ddora_all_attn = TrainingArguments(
    output_dir=f"{output_dir_prefix}ddora-all-attn",
    num_train_epochs=5,
    #max_steps=100,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    logging_steps=logging_steps,
    save_steps=eval_steps,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    disable_tqdm=False,
    push_to_hub=False,
    report_to="none",
    log_level="error"
)

trainer_ddora_all_attn = Trainer(model=model_ddora_all_attn, args=training_args_ddora_all_attn, 
                                train_dataset=dataset_encoded["train"], eval_dataset=dataset_encoded["validation"], compute_metrics=compute_metrics)
DDoRA (All Attention) - Total parameters: 67,618,562
DDoRA (All Attention) - Trainable parameters: 1,255,682 (1.86%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
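The trainable-parameter count is easy to reproduce from the shapes: each of the 24 wrapped projections (6 layers x 4 attention linears, all 768x768) carries two rank-16 LoRA matrices plus four 768-dimensional vectors, and the unfrozen classification head contributes the remainder. A quick cross-check (sketch, derived from those shapes):

n_wrapped = sum(isinstance(m, LinearWithDoubleDoRA) for _, m in model_ddora_all_attn.named_modules())
print(n_wrapped)                             # 24 wrapped attention projections

per_module = 2 * (768 * 16) + 4 * 768        # lora.A + lora.B + m_in/m_out + scale_in/scale_out
head = (768 * 768 + 768) + (768 * 2 + 2)     # pre_classifier + classifier
print(24 * per_module + head)                # 1,255,682 -- matches the count printed above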
In [17]:
# Sanity check
print("\nTrainable parameters after freezing:")
for name, param in model_ddora_all_attn.named_parameters():
    if param.requires_grad:
        print(name)
Trainable parameters after freezing:
distilbert.transformer.layer.0.attention.q_lin.m_out
distilbert.transformer.layer.0.attention.q_lin.m_in
distilbert.transformer.layer.0.attention.q_lin.scale_out
distilbert.transformer.layer.0.attention.q_lin.scale_in
distilbert.transformer.layer.0.attention.q_lin.lora.A
distilbert.transformer.layer.0.attention.q_lin.lora.B
distilbert.transformer.layer.0.attention.k_lin.m_out
distilbert.transformer.layer.0.attention.k_lin.m_in
distilbert.transformer.layer.0.attention.k_lin.scale_out
distilbert.transformer.layer.0.attention.k_lin.scale_in
distilbert.transformer.layer.0.attention.k_lin.lora.A
distilbert.transformer.layer.0.attention.k_lin.lora.B
distilbert.transformer.layer.0.attention.v_lin.m_out
distilbert.transformer.layer.0.attention.v_lin.m_in
distilbert.transformer.layer.0.attention.v_lin.scale_out
distilbert.transformer.layer.0.attention.v_lin.scale_in
distilbert.transformer.layer.0.attention.v_lin.lora.A
distilbert.transformer.layer.0.attention.v_lin.lora.B
distilbert.transformer.layer.0.attention.out_lin.m_out
distilbert.transformer.layer.0.attention.out_lin.m_in
distilbert.transformer.layer.0.attention.out_lin.scale_out
distilbert.transformer.layer.0.attention.out_lin.scale_in
distilbert.transformer.layer.0.attention.out_lin.lora.A
distilbert.transformer.layer.0.attention.out_lin.lora.B
distilbert.transformer.layer.1.attention.q_lin.m_out
distilbert.transformer.layer.1.attention.q_lin.m_in
distilbert.transformer.layer.1.attention.q_lin.scale_out
distilbert.transformer.layer.1.attention.q_lin.scale_in
distilbert.transformer.layer.1.attention.q_lin.lora.A
distilbert.transformer.layer.1.attention.q_lin.lora.B
distilbert.transformer.layer.1.attention.k_lin.m_out
distilbert.transformer.layer.1.attention.k_lin.m_in
distilbert.transformer.layer.1.attention.k_lin.scale_out
distilbert.transformer.layer.1.attention.k_lin.scale_in
distilbert.transformer.layer.1.attention.k_lin.lora.A
distilbert.transformer.layer.1.attention.k_lin.lora.B
distilbert.transformer.layer.1.attention.v_lin.m_out
distilbert.transformer.layer.1.attention.v_lin.m_in
distilbert.transformer.layer.1.attention.v_lin.scale_out
distilbert.transformer.layer.1.attention.v_lin.scale_in
distilbert.transformer.layer.1.attention.v_lin.lora.A
distilbert.transformer.layer.1.attention.v_lin.lora.B
distilbert.transformer.layer.1.attention.out_lin.m_out
distilbert.transformer.layer.1.attention.out_lin.m_in
distilbert.transformer.layer.1.attention.out_lin.scale_out
distilbert.transformer.layer.1.attention.out_lin.scale_in
distilbert.transformer.layer.1.attention.out_lin.lora.A
distilbert.transformer.layer.1.attention.out_lin.lora.B
distilbert.transformer.layer.2.attention.q_lin.m_out
distilbert.transformer.layer.2.attention.q_lin.m_in
distilbert.transformer.layer.2.attention.q_lin.scale_out
distilbert.transformer.layer.2.attention.q_lin.scale_in
distilbert.transformer.layer.2.attention.q_lin.lora.A
distilbert.transformer.layer.2.attention.q_lin.lora.B
distilbert.transformer.layer.2.attention.k_lin.m_out
distilbert.transformer.layer.2.attention.k_lin.m_in
distilbert.transformer.layer.2.attention.k_lin.scale_out
distilbert.transformer.layer.2.attention.k_lin.scale_in
distilbert.transformer.layer.2.attention.k_lin.lora.A
distilbert.transformer.layer.2.attention.k_lin.lora.B
distilbert.transformer.layer.2.attention.v_lin.m_out
distilbert.transformer.layer.2.attention.v_lin.m_in
distilbert.transformer.layer.2.attention.v_lin.scale_out
distilbert.transformer.layer.2.attention.v_lin.scale_in
distilbert.transformer.layer.2.attention.v_lin.lora.A
distilbert.transformer.layer.2.attention.v_lin.lora.B
distilbert.transformer.layer.2.attention.out_lin.m_out
distilbert.transformer.layer.2.attention.out_lin.m_in
distilbert.transformer.layer.2.attention.out_lin.scale_out
distilbert.transformer.layer.2.attention.out_lin.scale_in
distilbert.transformer.layer.2.attention.out_lin.lora.A
distilbert.transformer.layer.2.attention.out_lin.lora.B
distilbert.transformer.layer.3.attention.q_lin.m_out
distilbert.transformer.layer.3.attention.q_lin.m_in
distilbert.transformer.layer.3.attention.q_lin.scale_out
distilbert.transformer.layer.3.attention.q_lin.scale_in
distilbert.transformer.layer.3.attention.q_lin.lora.A
distilbert.transformer.layer.3.attention.q_lin.lora.B
distilbert.transformer.layer.3.attention.k_lin.m_out
distilbert.transformer.layer.3.attention.k_lin.m_in
distilbert.transformer.layer.3.attention.k_lin.scale_out
distilbert.transformer.layer.3.attention.k_lin.scale_in
distilbert.transformer.layer.3.attention.k_lin.lora.A
distilbert.transformer.layer.3.attention.k_lin.lora.B
distilbert.transformer.layer.3.attention.v_lin.m_out
distilbert.transformer.layer.3.attention.v_lin.m_in
distilbert.transformer.layer.3.attention.v_lin.scale_out
distilbert.transformer.layer.3.attention.v_lin.scale_in
distilbert.transformer.layer.3.attention.v_lin.lora.A
distilbert.transformer.layer.3.attention.v_lin.lora.B
distilbert.transformer.layer.3.attention.out_lin.m_out
distilbert.transformer.layer.3.attention.out_lin.m_in
distilbert.transformer.layer.3.attention.out_lin.scale_out
distilbert.transformer.layer.3.attention.out_lin.scale_in
distilbert.transformer.layer.3.attention.out_lin.lora.A
distilbert.transformer.layer.3.attention.out_lin.lora.B
distilbert.transformer.layer.4.attention.q_lin.m_out
distilbert.transformer.layer.4.attention.q_lin.m_in
distilbert.transformer.layer.4.attention.q_lin.scale_out
distilbert.transformer.layer.4.attention.q_lin.scale_in
distilbert.transformer.layer.4.attention.q_lin.lora.A
distilbert.transformer.layer.4.attention.q_lin.lora.B
distilbert.transformer.layer.4.attention.k_lin.m_out
distilbert.transformer.layer.4.attention.k_lin.m_in
distilbert.transformer.layer.4.attention.k_lin.scale_out
distilbert.transformer.layer.4.attention.k_lin.scale_in
distilbert.transformer.layer.4.attention.k_lin.lora.A
distilbert.transformer.layer.4.attention.k_lin.lora.B
distilbert.transformer.layer.4.attention.v_lin.m_out
distilbert.transformer.layer.4.attention.v_lin.m_in
distilbert.transformer.layer.4.attention.v_lin.scale_out
distilbert.transformer.layer.4.attention.v_lin.scale_in
distilbert.transformer.layer.4.attention.v_lin.lora.A
distilbert.transformer.layer.4.attention.v_lin.lora.B
distilbert.transformer.layer.4.attention.out_lin.m_out
distilbert.transformer.layer.4.attention.out_lin.m_in
distilbert.transformer.layer.4.attention.out_lin.scale_out
distilbert.transformer.layer.4.attention.out_lin.scale_in
distilbert.transformer.layer.4.attention.out_lin.lora.A
distilbert.transformer.layer.4.attention.out_lin.lora.B
distilbert.transformer.layer.5.attention.q_lin.m_out
distilbert.transformer.layer.5.attention.q_lin.m_in
distilbert.transformer.layer.5.attention.q_lin.scale_out
distilbert.transformer.layer.5.attention.q_lin.scale_in
distilbert.transformer.layer.5.attention.q_lin.lora.A
distilbert.transformer.layer.5.attention.q_lin.lora.B
distilbert.transformer.layer.5.attention.k_lin.m_out
distilbert.transformer.layer.5.attention.k_lin.m_in
distilbert.transformer.layer.5.attention.k_lin.scale_out
distilbert.transformer.layer.5.attention.k_lin.scale_in
distilbert.transformer.layer.5.attention.k_lin.lora.A
distilbert.transformer.layer.5.attention.k_lin.lora.B
distilbert.transformer.layer.5.attention.v_lin.m_out
distilbert.transformer.layer.5.attention.v_lin.m_in
distilbert.transformer.layer.5.attention.v_lin.scale_out
distilbert.transformer.layer.5.attention.v_lin.scale_in
distilbert.transformer.layer.5.attention.v_lin.lora.A
distilbert.transformer.layer.5.attention.v_lin.lora.B
distilbert.transformer.layer.5.attention.out_lin.m_out
distilbert.transformer.layer.5.attention.out_lin.m_in
distilbert.transformer.layer.5.attention.out_lin.scale_out
distilbert.transformer.layer.5.attention.out_lin.scale_in
distilbert.transformer.layer.5.attention.out_lin.lora.A
distilbert.transformer.layer.5.attention.out_lin.lora.B
pre_classifier.weight
pre_classifier.bias
classifier.weight
classifier.bias
In [18]:
trainer_ddora_all_attn.train()
eval_results_ddora_all_attn = trainer_ddora_all_attn.evaluate(dataset_encoded["test"])
print(f"DDoRA (All Attention) Test Results: {eval_results_ddora_all_attn}")

print('Prediction drift')
outputs = trainer_ddora_all_attn.predict(dataset_encoded["validation"])
print(outputs.predictions[:5], outputs.label_ids[:5])

print('LoRA Heatmap')
layer_names = []
b_norms = []

for name, param in model_ddora_all_attn.named_parameters():
    if "lora.B" in name:
        layer_names.append(name)
        b_norms.append(param.norm().item())

sns.barplot(x=b_norms, y=layer_names, color='navy')
plt.xlabel("Weight Norm")
plt.title("LoRA B Norms by Layer")
plt.tight_layout()
plt.show()

import torch
import numpy as np
import matplotlib.pyplot as plt

m_in_values = []
m_out_values = []

for name, module in model_ddora_all_attn.named_modules():
    if isinstance(module, LinearWithDoubleDoRA):
        if hasattr(module, 'm_out'):
            m_out_param = module.m_out.detach().cpu().numpy().flatten()
            m_out_values.extend(m_out_param)
        if hasattr(module, 'm_in'):
            m_in_param = module.m_in.detach().cpu().numpy().flatten()
            m_in_values.extend(m_in_param)

# Convert to numpy arrays
m_in_values = np.array(m_in_values)
m_out_values = np.array(m_out_values)

# Summary stats
print(f"[m_out] Mean: {np.mean(m_out_values):.4f}, Std: {np.std(m_out_values):.4f}, Min: {np.min(m_out_values):.4f}, Max: {np.max(m_out_values):.4f}")
print(f"[m_in ] Mean: {np.mean(m_in_values):.4f}, Std: {np.std(m_in_values):.4f}, Min: {np.min(m_in_values):.4f}, Max: {np.max(m_in_values):.4f}")

# Plot histograms
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.hist(m_out_values, bins=50, alpha=0.7, color='olive')
plt.title('Distribution of Learned m_out Values (DDoRA)')
plt.xlabel('Magnitude (m_out)')
plt.ylabel('Frequency')

plt.subplot(1, 2, 2)
plt.hist(m_in_values, bins=50, alpha=0.7, color='navy')
plt.title('Distribution of Learned m_in Values (DDoRA)')
plt.xlabel('Magnitude (m_in)')
plt.ylabel('Frequency')

plt.tight_layout()
plt.show()
[3910/3910 2:50:30, Epoch 5/5]
Step Training Loss Validation Loss Accuracy F1
50 0.770700 0.286238 0.872000 0.871847
100 0.349700 0.288429 0.880000 0.879626
150 0.364200 0.300074 0.872800 0.873083
200 0.288800 0.273452 0.883200 0.883296
250 0.299600 0.251706 0.898400 0.897942
300 0.296600 0.246586 0.898400 0.898229
350 0.294200 0.256276 0.902400 0.902205
400 0.267100 0.248215 0.896800 0.896852
450 0.259600 0.261862 0.902400 0.902072
500 0.286100 0.339048 0.903200 0.903249
550 0.280100 0.271318 0.905600 0.905391
600 0.296700 0.254167 0.913600 0.913544
650 0.248500 0.248038 0.909600 0.909448
700 0.272000 0.240011 0.910400 0.910121
750 0.256800 0.272073 0.900800 0.900581
800 0.244800 0.336190 0.899200 0.898152
850 0.232100 0.244281 0.904000 0.903788
900 0.217700 0.275926 0.913600 0.913151
950 0.255400 0.276984 0.910400 0.910473
1000 0.223600 0.227863 0.909600 0.909143
1050 0.223100 0.244880 0.909600 0.909593
1100 0.229500 0.250599 0.912800 0.912539
1150 0.203900 0.231664 0.911200 0.911013
1200 0.240400 0.231101 0.911200 0.911179
1250 0.246100 0.229517 0.914400 0.914454
1300 0.248700 0.231470 0.912800 0.912765
1350 0.243300 0.221713 0.910400 0.910400
1400 0.236300 0.219190 0.912800 0.912704
1450 0.219200 0.231880 0.913600 0.913625
1500 0.212400 0.220595 0.912800 0.912559
1550 0.212400 0.239685 0.913600 0.913613
1600 0.185500 0.284434 0.899200 0.899420
1650 0.180000 0.298952 0.914400 0.914183
1700 0.200800 0.263503 0.914400 0.914443
1750 0.197100 0.274289 0.912000 0.912083
1800 0.204200 0.241341 0.904800 0.904905
1850 0.201200 0.239144 0.911200 0.911220
1900 0.187300 0.246986 0.915200 0.915130
1950 0.199200 0.275895 0.904800 0.905001
2000 0.202000 0.238767 0.917600 0.917682
2050 0.197800 0.252750 0.917600 0.917272
2100 0.199800 0.244306 0.915200 0.915099
2150 0.151500 0.297675 0.920000 0.919948
2200 0.204700 0.226266 0.913600 0.913681
2250 0.206400 0.223971 0.920000 0.919873
2300 0.191300 0.228316 0.919200 0.919206
2350 0.177600 0.251830 0.921600 0.921337
2400 0.188400 0.279854 0.918400 0.918220
2450 0.149600 0.237880 0.915200 0.915082
2500 0.148100 0.255322 0.910400 0.910462
2550 0.144100 0.284158 0.918400 0.918184
2600 0.167500 0.257225 0.908000 0.908138
2650 0.161400 0.226382 0.917600 0.917594
2700 0.157400 0.233534 0.919200 0.919229
2750 0.163500 0.243015 0.921600 0.921410
2800 0.121700 0.245816 0.920800 0.920682
2850 0.141500 0.285453 0.922400 0.922071
2900 0.163300 0.232592 0.921600 0.921549
2950 0.155700 0.241530 0.916800 0.916616
3000 0.150700 0.241032 0.916800 0.916760
3050 0.148100 0.261533 0.916000 0.915994
3100 0.146800 0.262320 0.912000 0.911413
3150 0.135700 0.245895 0.920000 0.919975
3200 0.129700 0.245331 0.921600 0.921476
3250 0.112900 0.275194 0.917600 0.917581
3300 0.124800 0.264425 0.923200 0.923013
3350 0.115600 0.292223 0.924000 0.923887
3400 0.105400 0.271411 0.917600 0.917427
3450 0.131200 0.260183 0.920000 0.919934
3500 0.109600 0.265847 0.922400 0.922314
3550 0.116700 0.269032 0.920800 0.920682
3600 0.133200 0.260401 0.924000 0.923916
3650 0.133300 0.267982 0.924000 0.923840
3700 0.124600 0.265107 0.919200 0.919194
3750 0.104300 0.276479 0.920800 0.920769
3800 0.106300 0.271788 0.921600 0.921562
3850 0.109200 0.271291 0.922400 0.922343
3900 0.110700 0.268460 0.920800 0.920755

DDoRA (All Attention) Test Results: {'eval_loss': 0.20760443806648254, 'eval_accuracy': 0.9165473684210527, 'eval_f1': 0.9165350069271784, 'eval_runtime': 155.718, 'eval_samples_per_second': 152.519, 'eval_steps_per_second': 4.771, 'epoch': 5.0}
Prediction drift
[[-2.5889697  1.7642486]
 [-3.9092817  2.3953364]
 [ 2.2905545 -2.4951563]
 [-2.053718   1.375324 ]
 [ 5.5414968 -5.709928 ]] [1 1 0 1 0]
LoRA Heatmap
[Figure: bar plot of LoRA B weight norms by layer]
[m_out] Mean: -0.0008, Std: 0.5493, Min: -2.8318, Max: 3.1991
[m_in ] Mean: 0.0086, Std: 0.4519, Min: -2.4116, Max: 2.3131
[Figure: histograms of learned m_out and m_in values (DDoRA)]
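The "Prediction drift" output above shows raw logits; taking the argmax per row recovers the predicted classes, which for the five rows shown agree with the labels. A one-line check on the predict output from the cell above (sketch):

preds = outputs.predictions.argmax(axis=-1)
print(preds[:5], outputs.label_ids[:5])      # [1 1 0 1 0] [1 1 0 1 0]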
In [19]:
print('Parameter Statistics: mean.abs()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.scale" in name:
        print(name, param.abs().mean().item())
print('Parameter Statistics: mean.abs()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.m" in name:
        print(name, param.abs().mean().item())
print('Parameter Statistics: mean.abs()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lora" in name:
        print(name, param.abs().mean().item())
Parameter Statistics: mean.abs()
distilbert.transformer.layer.0.attention.q_lin.scale_out 2.0901947021484375
distilbert.transformer.layer.0.attention.q_lin.scale_in 1.9317268133163452
distilbert.transformer.layer.0.attention.k_lin.scale_out 2.0788636207580566
distilbert.transformer.layer.0.attention.k_lin.scale_in 1.8939082622528076
distilbert.transformer.layer.0.attention.v_lin.scale_out 2.0176312923431396
distilbert.transformer.layer.0.attention.v_lin.scale_in 1.6768195629119873
distilbert.transformer.layer.0.attention.out_lin.scale_out 1.7772330045700073
distilbert.transformer.layer.0.attention.out_lin.scale_in 1.8626189231872559
distilbert.transformer.layer.1.attention.q_lin.scale_out 2.018207311630249
distilbert.transformer.layer.1.attention.q_lin.scale_in 1.8735299110412598
distilbert.transformer.layer.1.attention.k_lin.scale_out 1.9069545269012451
distilbert.transformer.layer.1.attention.k_lin.scale_in 1.9223332405090332
distilbert.transformer.layer.1.attention.v_lin.scale_out 1.74566650390625
distilbert.transformer.layer.1.attention.v_lin.scale_in 1.838054895401001
distilbert.transformer.layer.1.attention.out_lin.scale_out 1.740161418914795
distilbert.transformer.layer.1.attention.out_lin.scale_in 1.8335179090499878
distilbert.transformer.layer.2.attention.q_lin.scale_out 1.9253816604614258
distilbert.transformer.layer.2.attention.q_lin.scale_in 1.9024275541305542
distilbert.transformer.layer.2.attention.k_lin.scale_out 1.9852197170257568
distilbert.transformer.layer.2.attention.k_lin.scale_in 1.8761833906173706
distilbert.transformer.layer.2.attention.v_lin.scale_out 1.796034574508667
distilbert.transformer.layer.2.attention.v_lin.scale_in 1.8118958473205566
distilbert.transformer.layer.2.attention.out_lin.scale_out 1.7673529386520386
distilbert.transformer.layer.2.attention.out_lin.scale_in 1.7620668411254883
distilbert.transformer.layer.3.attention.q_lin.scale_out 1.9725602865219116
distilbert.transformer.layer.3.attention.q_lin.scale_in 1.912266492843628
distilbert.transformer.layer.3.attention.k_lin.scale_out 1.9736626148223877
distilbert.transformer.layer.3.attention.k_lin.scale_in 1.8784260749816895
distilbert.transformer.layer.3.attention.v_lin.scale_out 1.7679297924041748
distilbert.transformer.layer.3.attention.v_lin.scale_in 1.8323858976364136
distilbert.transformer.layer.3.attention.out_lin.scale_out 1.8022470474243164
distilbert.transformer.layer.3.attention.out_lin.scale_in 1.8460712432861328
distilbert.transformer.layer.4.attention.q_lin.scale_out 1.9774129390716553
distilbert.transformer.layer.4.attention.q_lin.scale_in 1.8028367757797241
distilbert.transformer.layer.4.attention.k_lin.scale_out 2.0612895488739014
distilbert.transformer.layer.4.attention.k_lin.scale_in 1.8295701742172241
distilbert.transformer.layer.4.attention.v_lin.scale_out 1.719773292541504
distilbert.transformer.layer.4.attention.v_lin.scale_in 1.7408709526062012
distilbert.transformer.layer.4.attention.out_lin.scale_out 1.7295746803283691
distilbert.transformer.layer.4.attention.out_lin.scale_in 1.7716983556747437
distilbert.transformer.layer.5.attention.q_lin.scale_out 1.9752777814865112
distilbert.transformer.layer.5.attention.q_lin.scale_in 1.7774198055267334
distilbert.transformer.layer.5.attention.k_lin.scale_out 2.129931926727295
distilbert.transformer.layer.5.attention.k_lin.scale_in 1.8098797798156738
distilbert.transformer.layer.5.attention.v_lin.scale_out 1.919114112854004
distilbert.transformer.layer.5.attention.v_lin.scale_in 1.6814241409301758
distilbert.transformer.layer.5.attention.out_lin.scale_out 1.7194103002548218
distilbert.transformer.layer.5.attention.out_lin.scale_in 1.7984395027160645
Parameter Statistics: mean.abs()
distilbert.transformer.layer.0.attention.q_lin.m_out 0.5647355318069458
distilbert.transformer.layer.0.attention.q_lin.m_in 0.3798569142818451
distilbert.transformer.layer.0.attention.k_lin.m_out 0.5425061583518982
distilbert.transformer.layer.0.attention.k_lin.m_in 0.37627243995666504
distilbert.transformer.layer.0.attention.v_lin.m_out 0.480976939201355
distilbert.transformer.layer.0.attention.v_lin.m_in 0.26557350158691406
distilbert.transformer.layer.0.attention.out_lin.m_out 0.3110979497432709
distilbert.transformer.layer.0.attention.out_lin.m_in 0.29723748564720154
distilbert.transformer.layer.1.attention.q_lin.m_out 0.4283873736858368
distilbert.transformer.layer.1.attention.q_lin.m_in 0.3325177729129791
distilbert.transformer.layer.1.attention.k_lin.m_out 0.40755534172058105
distilbert.transformer.layer.1.attention.k_lin.m_in 0.34619420766830444
distilbert.transformer.layer.1.attention.v_lin.m_out 0.2836112380027771
distilbert.transformer.layer.1.attention.v_lin.m_in 0.29925063252449036
distilbert.transformer.layer.1.attention.out_lin.m_out 0.3088749349117279
distilbert.transformer.layer.1.attention.out_lin.m_in 0.27071183919906616
distilbert.transformer.layer.2.attention.q_lin.m_out 0.4211108684539795
distilbert.transformer.layer.2.attention.q_lin.m_in 0.30545711517333984
distilbert.transformer.layer.2.attention.k_lin.m_out 0.44150853157043457
distilbert.transformer.layer.2.attention.k_lin.m_in 0.3178306221961975
distilbert.transformer.layer.2.attention.v_lin.m_out 0.34029704332351685
distilbert.transformer.layer.2.attention.v_lin.m_in 0.24604985117912292
distilbert.transformer.layer.2.attention.out_lin.m_out 0.3179643154144287
distilbert.transformer.layer.2.attention.out_lin.m_in 0.2530387341976166
distilbert.transformer.layer.3.attention.q_lin.m_out 0.4518465995788574
distilbert.transformer.layer.3.attention.q_lin.m_in 0.3557422161102295
distilbert.transformer.layer.3.attention.k_lin.m_out 0.47002920508384705
distilbert.transformer.layer.3.attention.k_lin.m_in 0.35762590169906616
distilbert.transformer.layer.3.attention.v_lin.m_out 0.32528677582740784
distilbert.transformer.layer.3.attention.v_lin.m_in 0.2892145812511444
distilbert.transformer.layer.3.attention.out_lin.m_out 0.33790677785873413
distilbert.transformer.layer.3.attention.out_lin.m_in 0.28367918729782104
distilbert.transformer.layer.4.attention.q_lin.m_out 0.47639477252960205
distilbert.transformer.layer.4.attention.q_lin.m_in 0.29029154777526855
distilbert.transformer.layer.4.attention.k_lin.m_out 0.5150153636932373
distilbert.transformer.layer.4.attention.k_lin.m_in 0.3103344440460205
distilbert.transformer.layer.4.attention.v_lin.m_out 0.32499608397483826
distilbert.transformer.layer.4.attention.v_lin.m_in 0.2510720491409302
distilbert.transformer.layer.4.attention.out_lin.m_out 0.3214051127433777
distilbert.transformer.layer.4.attention.out_lin.m_in 0.2650536000728607
distilbert.transformer.layer.5.attention.q_lin.m_out 0.43587324023246765
distilbert.transformer.layer.5.attention.q_lin.m_in 0.22241328656673431
distilbert.transformer.layer.5.attention.k_lin.m_out 0.5570618510246277
distilbert.transformer.layer.5.attention.k_lin.m_in 0.2527024447917938
distilbert.transformer.layer.5.attention.v_lin.m_out 0.37015169858932495
distilbert.transformer.layer.5.attention.v_lin.m_in 0.20549951493740082
distilbert.transformer.layer.5.attention.out_lin.m_out 0.31466561555862427
distilbert.transformer.layer.5.attention.out_lin.m_in 0.27215391397476196
Parameter Statistics: mean.abs()
distilbert.transformer.layer.0.attention.q_lin.lora.A 0.39448559284210205
distilbert.transformer.layer.0.attention.q_lin.lora.B 0.29920345544815063
distilbert.transformer.layer.0.attention.k_lin.lora.A 0.388362854719162
distilbert.transformer.layer.0.attention.k_lin.lora.B 0.291858047246933
distilbert.transformer.layer.0.attention.v_lin.lora.A 0.3334265351295471
distilbert.transformer.layer.0.attention.v_lin.lora.B 0.26573705673217773
distilbert.transformer.layer.0.attention.out_lin.lora.A 0.33093491196632385
distilbert.transformer.layer.0.attention.out_lin.lora.B 0.2437349259853363
distilbert.transformer.layer.1.attention.q_lin.lora.A 0.3549439609050751
distilbert.transformer.layer.1.attention.q_lin.lora.B 0.23443476855754852
distilbert.transformer.layer.1.attention.k_lin.lora.A 0.3481384217739105
distilbert.transformer.layer.1.attention.k_lin.lora.B 0.27366185188293457
distilbert.transformer.layer.1.attention.v_lin.lora.A 0.3255350589752197
distilbert.transformer.layer.1.attention.v_lin.lora.B 0.1981675922870636
distilbert.transformer.layer.1.attention.out_lin.lora.A 0.31510788202285767
distilbert.transformer.layer.1.attention.out_lin.lora.B 0.24702942371368408
distilbert.transformer.layer.2.attention.q_lin.lora.A 0.32814159989356995
distilbert.transformer.layer.2.attention.q_lin.lora.B 0.2725476622581482
distilbert.transformer.layer.2.attention.k_lin.lora.A 0.33942073583602905
distilbert.transformer.layer.2.attention.k_lin.lora.B 0.2639490067958832
distilbert.transformer.layer.2.attention.v_lin.lora.A 0.28993457555770874
distilbert.transformer.layer.2.attention.v_lin.lora.B 0.19012793898582458
distilbert.transformer.layer.2.attention.out_lin.lora.A 0.30221056938171387
distilbert.transformer.layer.2.attention.out_lin.lora.B 0.23099581897258759
distilbert.transformer.layer.3.attention.q_lin.lora.A 0.3653438091278076
distilbert.transformer.layer.3.attention.q_lin.lora.B 0.2912411689758301
distilbert.transformer.layer.3.attention.k_lin.lora.A 0.37983566522598267
distilbert.transformer.layer.3.attention.k_lin.lora.B 0.2919533848762512
distilbert.transformer.layer.3.attention.v_lin.lora.A 0.3182007968425751
distilbert.transformer.layer.3.attention.v_lin.lora.B 0.22263233363628387
distilbert.transformer.layer.3.attention.out_lin.lora.A 0.3163568377494812
distilbert.transformer.layer.3.attention.out_lin.lora.B 0.22031466662883759
distilbert.transformer.layer.4.attention.q_lin.lora.A 0.32468748092651367
distilbert.transformer.layer.4.attention.q_lin.lora.B 0.25622785091400146
distilbert.transformer.layer.4.attention.k_lin.lora.A 0.3412896394729614
distilbert.transformer.layer.4.attention.k_lin.lora.B 0.2617414891719818
distilbert.transformer.layer.4.attention.v_lin.lora.A 0.3016916513442993
distilbert.transformer.layer.4.attention.v_lin.lora.B 0.18387971818447113
distilbert.transformer.layer.4.attention.out_lin.lora.A 0.31307727098464966
distilbert.transformer.layer.4.attention.out_lin.lora.B 0.20827889442443848
distilbert.transformer.layer.5.attention.q_lin.lora.A 0.2830806076526642
distilbert.transformer.layer.5.attention.q_lin.lora.B 0.23386383056640625
distilbert.transformer.layer.5.attention.k_lin.lora.A 0.3035750091075897
distilbert.transformer.layer.5.attention.k_lin.lora.B 0.2312200516462326
distilbert.transformer.layer.5.attention.v_lin.lora.A 0.28683286905288696
distilbert.transformer.layer.5.attention.v_lin.lora.B 0.16982325911521912
distilbert.transformer.layer.5.attention.out_lin.lora.A 0.3157418370246887
distilbert.transformer.layer.5.attention.out_lin.lora.B 0.2109360694885254
In [20]:
print('Parameter Statistics: param.norm()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lora" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
print('Parameter Statistics: param.norm()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.scale" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
print('Parameter Statistics: param.norm()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.m" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
Parameter Statistics: param.norm()
distilbert.transformer.layer.0.attention.q_lin.lora.A weight norm: 57.6984
distilbert.transformer.layer.0.attention.q_lin.lora.B weight norm: 42.6242
distilbert.transformer.layer.0.attention.k_lin.lora.A weight norm: 57.0737
distilbert.transformer.layer.0.attention.k_lin.lora.B weight norm: 41.5720
distilbert.transformer.layer.0.attention.v_lin.lora.A weight norm: 50.5210
distilbert.transformer.layer.0.attention.v_lin.lora.B weight norm: 38.4451
distilbert.transformer.layer.0.attention.out_lin.lora.A weight norm: 49.7002
distilbert.transformer.layer.0.attention.out_lin.lora.B weight norm: 35.1399
distilbert.transformer.layer.1.attention.q_lin.lora.A weight norm: 52.7742
distilbert.transformer.layer.1.attention.q_lin.lora.B weight norm: 34.2313
distilbert.transformer.layer.1.attention.k_lin.lora.A weight norm: 50.7043
distilbert.transformer.layer.1.attention.k_lin.lora.B weight norm: 38.7794
distilbert.transformer.layer.1.attention.v_lin.lora.A weight norm: 48.0156
distilbert.transformer.layer.1.attention.v_lin.lora.B weight norm: 28.3948
distilbert.transformer.layer.1.attention.out_lin.lora.A weight norm: 47.3806
distilbert.transformer.layer.1.attention.out_lin.lora.B weight norm: 35.4175
distilbert.transformer.layer.2.attention.q_lin.lora.A weight norm: 48.9293
distilbert.transformer.layer.2.attention.q_lin.lora.B weight norm: 38.7429
distilbert.transformer.layer.2.attention.k_lin.lora.A weight norm: 49.8689
distilbert.transformer.layer.2.attention.k_lin.lora.B weight norm: 37.3596
distilbert.transformer.layer.2.attention.v_lin.lora.A weight norm: 43.0877
distilbert.transformer.layer.2.attention.v_lin.lora.B weight norm: 27.5580
distilbert.transformer.layer.2.attention.out_lin.lora.A weight norm: 45.4100
distilbert.transformer.layer.2.attention.out_lin.lora.B weight norm: 32.9905
distilbert.transformer.layer.3.attention.q_lin.lora.A weight norm: 53.3644
distilbert.transformer.layer.3.attention.q_lin.lora.B weight norm: 41.3374
distilbert.transformer.layer.3.attention.k_lin.lora.A weight norm: 56.0531
distilbert.transformer.layer.3.attention.k_lin.lora.B weight norm: 41.8924
distilbert.transformer.layer.3.attention.v_lin.lora.A weight norm: 47.7002
distilbert.transformer.layer.3.attention.v_lin.lora.B weight norm: 31.9272
distilbert.transformer.layer.3.attention.out_lin.lora.A weight norm: 47.3052
distilbert.transformer.layer.3.attention.out_lin.lora.B weight norm: 31.5475
distilbert.transformer.layer.4.attention.q_lin.lora.A weight norm: 47.6581
distilbert.transformer.layer.4.attention.q_lin.lora.B weight norm: 36.8326
distilbert.transformer.layer.4.attention.k_lin.lora.A weight norm: 50.2827
distilbert.transformer.layer.4.attention.k_lin.lora.B weight norm: 37.8750
distilbert.transformer.layer.4.attention.v_lin.lora.A weight norm: 45.1280
distilbert.transformer.layer.4.attention.v_lin.lora.B weight norm: 26.8244
distilbert.transformer.layer.4.attention.out_lin.lora.A weight norm: 47.0836
distilbert.transformer.layer.4.attention.out_lin.lora.B weight norm: 30.0062
distilbert.transformer.layer.5.attention.q_lin.lora.A weight norm: 42.4531
distilbert.transformer.layer.5.attention.q_lin.lora.B weight norm: 33.3586
distilbert.transformer.layer.5.attention.k_lin.lora.A weight norm: 44.9119
distilbert.transformer.layer.5.attention.k_lin.lora.B weight norm: 33.6370
distilbert.transformer.layer.5.attention.v_lin.lora.A weight norm: 43.6901
distilbert.transformer.layer.5.attention.v_lin.lora.B weight norm: 24.9901
distilbert.transformer.layer.5.attention.out_lin.lora.A weight norm: 47.2171
distilbert.transformer.layer.5.attention.out_lin.lora.B weight norm: 30.6058
Parameter Statistics: param.norm()
distilbert.transformer.layer.0.attention.q_lin.scale_out weight norm: 60.7115
distilbert.transformer.layer.0.attention.q_lin.scale_in weight norm: 55.4249
distilbert.transformer.layer.0.attention.k_lin.scale_out weight norm: 60.4370
distilbert.transformer.layer.0.attention.k_lin.scale_in weight norm: 54.5890
distilbert.transformer.layer.0.attention.v_lin.scale_out weight norm: 58.4866
distilbert.transformer.layer.0.attention.v_lin.scale_in weight norm: 49.1681
distilbert.transformer.layer.0.attention.out_lin.scale_out weight norm: 51.2425
distilbert.transformer.layer.0.attention.out_lin.scale_in weight norm: 53.3198
distilbert.transformer.layer.1.attention.q_lin.scale_out weight norm: 57.8829
distilbert.transformer.layer.1.attention.q_lin.scale_in weight norm: 53.6451
distilbert.transformer.layer.1.attention.k_lin.scale_out weight norm: 55.0335
distilbert.transformer.layer.1.attention.k_lin.scale_in weight norm: 54.6923
distilbert.transformer.layer.1.attention.v_lin.scale_out weight norm: 50.1200
distilbert.transformer.layer.1.attention.v_lin.scale_in weight norm: 52.6154
distilbert.transformer.layer.1.attention.out_lin.scale_out weight norm: 50.4183
distilbert.transformer.layer.1.attention.out_lin.scale_in weight norm: 52.5371
distilbert.transformer.layer.2.attention.q_lin.scale_out weight norm: 55.7336
distilbert.transformer.layer.2.attention.q_lin.scale_in weight norm: 54.1221
distilbert.transformer.layer.2.attention.k_lin.scale_out weight norm: 57.0685
distilbert.transformer.layer.2.attention.k_lin.scale_in weight norm: 53.5724
distilbert.transformer.layer.2.attention.v_lin.scale_out weight norm: 52.1679
distilbert.transformer.layer.2.attention.v_lin.scale_in weight norm: 51.4892
distilbert.transformer.layer.2.attention.out_lin.scale_out weight norm: 51.0791
distilbert.transformer.layer.2.attention.out_lin.scale_in weight norm: 50.5406
distilbert.transformer.layer.3.attention.q_lin.scale_out weight norm: 57.1301
distilbert.transformer.layer.3.attention.q_lin.scale_in weight norm: 54.7893
distilbert.transformer.layer.3.attention.k_lin.scale_out weight norm: 57.2046
distilbert.transformer.layer.3.attention.k_lin.scale_in weight norm: 54.1217
distilbert.transformer.layer.3.attention.v_lin.scale_out weight norm: 51.3382
distilbert.transformer.layer.3.attention.v_lin.scale_in weight norm: 52.4507
distilbert.transformer.layer.3.attention.out_lin.scale_out weight norm: 51.9679
distilbert.transformer.layer.3.attention.out_lin.scale_in weight norm: 52.7578
distilbert.transformer.layer.4.attention.q_lin.scale_out weight norm: 57.4674
distilbert.transformer.layer.4.attention.q_lin.scale_in weight norm: 51.6265
distilbert.transformer.layer.4.attention.k_lin.scale_out weight norm: 59.7401
distilbert.transformer.layer.4.attention.k_lin.scale_in weight norm: 52.3292
distilbert.transformer.layer.4.attention.v_lin.scale_out weight norm: 50.7752
distilbert.transformer.layer.4.attention.v_lin.scale_in weight norm: 50.0643
distilbert.transformer.layer.4.attention.out_lin.scale_out weight norm: 50.8358
distilbert.transformer.layer.4.attention.out_lin.scale_in weight norm: 50.8573
distilbert.transformer.layer.5.attention.q_lin.scale_out weight norm: 57.1890
distilbert.transformer.layer.5.attention.q_lin.scale_in weight norm: 50.6338
distilbert.transformer.layer.5.attention.k_lin.scale_out weight norm: 61.1200
distilbert.transformer.layer.5.attention.k_lin.scale_in weight norm: 51.5444
distilbert.transformer.layer.5.attention.v_lin.scale_out weight norm: 55.0821
distilbert.transformer.layer.5.attention.v_lin.scale_in weight norm: 48.5430
distilbert.transformer.layer.5.attention.out_lin.scale_out weight norm: 50.9995
distilbert.transformer.layer.5.attention.out_lin.scale_in weight norm: 51.4657
Parameter Statistics: param.norm()
distilbert.transformer.layer.0.attention.q_lin.m_out weight norm: 20.0095
distilbert.transformer.layer.0.attention.q_lin.m_in weight norm: 15.0613
distilbert.transformer.layer.0.attention.k_lin.m_out weight norm: 19.6359
distilbert.transformer.layer.0.attention.k_lin.m_in weight norm: 15.1377
distilbert.transformer.layer.0.attention.v_lin.m_out weight norm: 17.2782
distilbert.transformer.layer.0.attention.v_lin.m_in weight norm: 12.7195
distilbert.transformer.layer.0.attention.out_lin.m_out weight norm: 11.9346
distilbert.transformer.layer.0.attention.out_lin.m_in weight norm: 12.8137
distilbert.transformer.layer.1.attention.q_lin.m_out weight norm: 15.6942
distilbert.transformer.layer.1.attention.q_lin.m_in weight norm: 13.5273
distilbert.transformer.layer.1.attention.k_lin.m_out weight norm: 14.8891
distilbert.transformer.layer.1.attention.k_lin.m_in weight norm: 13.4656
distilbert.transformer.layer.1.attention.v_lin.m_out weight norm: 11.5086
distilbert.transformer.layer.1.attention.v_lin.m_in weight norm: 12.5572
distilbert.transformer.layer.1.attention.out_lin.m_out weight norm: 11.9882
distilbert.transformer.layer.1.attention.out_lin.m_in weight norm: 12.0898
distilbert.transformer.layer.2.attention.q_lin.m_out weight norm: 15.6175
distilbert.transformer.layer.2.attention.q_lin.m_in weight norm: 12.3557
distilbert.transformer.layer.2.attention.k_lin.m_out weight norm: 15.7173
distilbert.transformer.layer.2.attention.k_lin.m_in weight norm: 12.8633
distilbert.transformer.layer.2.attention.v_lin.m_out weight norm: 12.7515
distilbert.transformer.layer.2.attention.v_lin.m_in weight norm: 10.7463
distilbert.transformer.layer.2.attention.out_lin.m_out weight norm: 12.1293
distilbert.transformer.layer.2.attention.out_lin.m_in weight norm: 11.4631
distilbert.transformer.layer.3.attention.q_lin.m_out weight norm: 16.4127
distilbert.transformer.layer.3.attention.q_lin.m_in weight norm: 13.9588
distilbert.transformer.layer.3.attention.k_lin.m_out weight norm: 16.8955
distilbert.transformer.layer.3.attention.k_lin.m_in weight norm: 14.1378
distilbert.transformer.layer.3.attention.v_lin.m_out weight norm: 12.3394
distilbert.transformer.layer.3.attention.v_lin.m_in weight norm: 12.3336
distilbert.transformer.layer.3.attention.out_lin.m_out weight norm: 12.4052
distilbert.transformer.layer.3.attention.out_lin.m_in weight norm: 11.9973
distilbert.transformer.layer.4.attention.q_lin.m_out weight norm: 17.3603
distilbert.transformer.layer.4.attention.q_lin.m_in weight norm: 12.2036
distilbert.transformer.layer.4.attention.k_lin.m_out weight norm: 18.3908
distilbert.transformer.layer.4.attention.k_lin.m_in weight norm: 12.4500
distilbert.transformer.layer.4.attention.v_lin.m_out weight norm: 12.6651
distilbert.transformer.layer.4.attention.v_lin.m_in weight norm: 11.2682
distilbert.transformer.layer.4.attention.out_lin.m_out weight norm: 12.6220
distilbert.transformer.layer.4.attention.out_lin.m_in weight norm: 11.9401
distilbert.transformer.layer.5.attention.q_lin.m_out weight norm: 16.4559
distilbert.transformer.layer.5.attention.q_lin.m_in weight norm: 10.3415
distilbert.transformer.layer.5.attention.k_lin.m_out weight norm: 18.6853
distilbert.transformer.layer.5.attention.k_lin.m_in weight norm: 10.9104
distilbert.transformer.layer.5.attention.v_lin.m_out weight norm: 13.6871
distilbert.transformer.layer.5.attention.v_lin.m_in weight norm: 10.8449
distilbert.transformer.layer.5.attention.out_lin.m_out weight norm: 12.6990
distilbert.transformer.layer.5.attention.out_lin.m_in weight norm: 11.9039
In [21]:
print (torch.cuda.memory_summary())
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   1077 MiB |   7290 MiB | 406631 GiB | 406630 GiB |
|       from large pool |   1055 MiB |   7240 MiB | 404673 GiB | 404672 GiB |
|       from small pool |     22 MiB |     51 MiB |   1958 GiB |   1958 GiB |
|---------------------------------------------------------------------------|
| Active memory         |   1077 MiB |   7290 MiB | 406631 GiB | 406630 GiB |
|       from large pool |   1055 MiB |   7240 MiB | 404673 GiB | 404672 GiB |
|       from small pool |     22 MiB |     51 MiB |   1958 GiB |   1958 GiB |
|---------------------------------------------------------------------------|
| Requested memory      |   1072 MiB |   7285 MiB | 406404 GiB | 406403 GiB |
|       from large pool |   1050 MiB |   7235 MiB | 404455 GiB | 404454 GiB |
|       from small pool |     22 MiB |     51 MiB |   1948 GiB |   1948 GiB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   7380 MiB |   7380 MiB |  14270 MiB |   6890 MiB |
|       from large pool |   7328 MiB |   7328 MiB |  14168 MiB |   6840 MiB |
|       from small pool |     52 MiB |     52 MiB |    102 MiB |     50 MiB |
|---------------------------------------------------------------------------|
| Non-releasable memory | 162149 KiB | 545173 KiB | 144872 GiB | 144872 GiB |
|       from large pool | 156580 KiB | 540498 KiB | 142781 GiB | 142781 GiB |
|       from small pool |   5569 KiB |  22942 KiB |   2091 GiB |   2091 GiB |
|---------------------------------------------------------------------------|
| Allocations           |    1310    |    1616    |   44683 K  |   44682 K  |
|       from large pool |     164    |     318    |   13813 K  |   13813 K  |
|       from small pool |    1146    |    1451    |   30869 K  |   30868 K  |
|---------------------------------------------------------------------------|
| Active allocs         |    1310    |    1616    |   44683 K  |   44682 K  |
|       from large pool |     164    |     318    |   13813 K  |   13813 K  |
|       from small pool |    1146    |    1451    |   30869 K  |   30868 K  |
|---------------------------------------------------------------------------|
| GPU reserved segments |     190    |     190    |     327    |     137    |
|       from large pool |     164    |     164    |     276    |     112    |
|       from small pool |      26    |      26    |      51    |      25    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |      37    |      60    |   20413 K  |   20413 K  |
|       from large pool |      18    |      26    |    3326 K  |    3326 K  |
|       from small pool |      19    |      43    |   17086 K  |   17086 K  |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

Full DistilBERT on IMDB Sentiment¶
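As a baseline for the parameter-efficient runs above, this section fine-tunes the full model: every weight of DistilBERT plus the classification head is trainable, roughly 67M parameters versus the ~1.26M updated by DDoRA. A quick count on the original model object (sketch):

full_ft_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Full fine-tuning trainable parameters: {full_ft_trainable:,}")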

In [22]:
# Define the training parameters
batch_size = 32
logging_steps = len(dataset_encoded["train"]) // batch_size
model_name = f"{model_ckpt}-finetuned-imdb"
training_args = TrainingArguments(output_dir=model_name,
                                  num_train_epochs=2, 
                                  learning_rate=1e-5,
                                  per_device_train_batch_size=batch_size,
                                  per_device_eval_batch_size=batch_size,
                                  weight_decay=0.01,
                                  evaluation_strategy="steps",
                                  eval_steps=20,  # Evaluate more frequently on the larger dataset
                                  logging_steps=20,
                                  save_steps=20,
                                  max_grad_norm=1.0,
                                  disable_tqdm=False,
                                  push_to_hub=False,
                                  report_to="none",
                                  log_level="error")

trainer = Trainer(model=model,
                  args=training_args,
                  compute_metrics=compute_metrics,
                  train_dataset=dataset_encoded["train"],
                  eval_dataset=dataset_encoded["validation"],
                  tokenizer=tokenizer)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
C:\Users\alexa\AppData\Local\Temp\ipykernel_22420\3915345233.py:21: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.
  trainer = Trainer(model=model,
In [23]:
# Train!
trainer.train()
print (torch.cuda.memory_summary())
# Evaluate on the test set
evaluation_results = trainer.evaluate(dataset_encoded["test"])
print(f"\nEvaluation results on the test set:")
print(evaluation_results)

# Accessing the logs (after training):
log_history = trainer.state.log_history
# Plot the loss vs steps
import matplotlib.pyplot as plt

# Extract training loss
train_steps = [entry["step"] for entry in trainer.state.log_history if "loss" in entry]
train_losses = [entry["loss"] for entry in trainer.state.log_history if "loss" in entry]

# Extract validation loss
val_steps = [entry["step"] for entry in trainer.state.log_history if "eval_loss" in entry]
val_losses = [entry["eval_loss"] for entry in trainer.state.log_history if "eval_loss" in entry]

# Plot both training and validation loss
plt.plot(train_steps, train_losses, label="Training Loss", linestyle="-")
plt.plot(val_steps, val_losses, label="Validation Loss", linestyle="--")

plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Training Loss vs. Steps")
plt.legend()
plt.show()

val_accuracy_steps = [entry["step"] for entry in log_history if "eval_accuracy" in entry]
val_accuracies = [entry["eval_accuracy"] for entry in log_history if "eval_accuracy" in entry]

val_f1_steps = [entry["step"] for entry in log_history if "eval_f1" in entry]
val_f1_scores = [entry["eval_f1"] for entry in log_history if "eval_f1" in entry]

plt.plot(val_accuracy_steps, val_accuracies, label="Validation Accuracy", linestyle="-")
plt.plot(val_f1_steps, val_f1_scores, label="Validation F1 score", linestyle="--")

plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Val Accurace and F1 vs. Steps")
plt.legend()
plt.show()
[1564/1564 45:03, Epoch 2/2]
Step Training Loss Validation Loss Accuracy F1
20 0.689400 0.676229 0.744000 0.739512
40 0.659500 0.588595 0.820800 0.819493
60 0.494700 0.388211 0.865600 0.865512
80 0.366500 0.328102 0.866400 0.865610
100 0.323200 0.317173 0.876800 0.877069
120 0.339900 0.284819 0.884000 0.883705
140 0.298000 0.275737 0.885600 0.885591
160 0.305800 0.292675 0.880800 0.879911
180 0.258600 0.272887 0.888800 0.888825
200 0.257000 0.272066 0.884800 0.884265
220 0.255000 0.281653 0.884800 0.883885
240 0.286100 0.253439 0.901600 0.901510
260 0.257200 0.248806 0.902400 0.902369
280 0.244400 0.249558 0.902400 0.902184
300 0.282500 0.258910 0.907200 0.906769
320 0.266500 0.241840 0.906400 0.906141
340 0.269500 0.239128 0.900800 0.900752
360 0.252200 0.258152 0.897600 0.896886
380 0.250800 0.246389 0.908800 0.908924
400 0.246200 0.252573 0.901600 0.901795
420 0.252400 0.269475 0.894400 0.894634
440 0.219100 0.262670 0.898400 0.898618
460 0.300400 0.263741 0.899200 0.898037
480 0.270000 0.264092 0.896000 0.894800
500 0.258500 0.240369 0.904800 0.904421
520 0.254600 0.230196 0.908000 0.907948
540 0.238300 0.233112 0.905600 0.905677
560 0.233700 0.233234 0.904000 0.904111
580 0.284200 0.229832 0.905600 0.905570
600 0.254800 0.229044 0.911200 0.911050
620 0.215700 0.228184 0.913600 0.913528
640 0.247200 0.244651 0.901600 0.901771
660 0.223600 0.232534 0.911200 0.910891
680 0.234400 0.248592 0.906400 0.906585
700 0.247600 0.223161 0.913600 0.913625
720 0.269200 0.223580 0.912000 0.912128
740 0.202000 0.222004 0.909600 0.909669
760 0.226900 0.233625 0.908800 0.908961
780 0.215700 0.219046 0.908800 0.908814
800 0.195900 0.222158 0.917600 0.917478
820 0.175600 0.259501 0.896800 0.897026
840 0.213900 0.228438 0.914400 0.914306
860 0.212300 0.244444 0.904800 0.904978
880 0.188900 0.224636 0.914400 0.914337
900 0.166400 0.226523 0.912000 0.911943
920 0.177200 0.237086 0.910400 0.910530
940 0.214900 0.226861 0.913600 0.913463
960 0.178100 0.239278 0.908000 0.908153
980 0.189200 0.224798 0.912000 0.911943
1000 0.166500 0.232560 0.913600 0.913660
1020 0.187100 0.230836 0.913600 0.913625
1040 0.154100 0.229299 0.916000 0.915923
1060 0.207200 0.226525 0.916800 0.916847
1080 0.191800 0.225182 0.916000 0.915907
1100 0.204000 0.233329 0.912000 0.912093
1120 0.165900 0.227521 0.916800 0.916836
1140 0.156800 0.226366 0.917600 0.917524
1160 0.203400 0.224103 0.914400 0.914443
1180 0.203600 0.221393 0.917600 0.917478
1200 0.174700 0.232756 0.908800 0.908932
1220 0.164400 0.225512 0.916000 0.915967
1240 0.169700 0.230427 0.908800 0.908886
1260 0.220100 0.228160 0.916000 0.915823
1280 0.177400 0.233028 0.907200 0.907317
1300 0.207600 0.226083 0.914400 0.914406
1320 0.210700 0.225015 0.911200 0.911245
1340 0.148000 0.224107 0.916800 0.916634
1360 0.170400 0.224322 0.910400 0.910439
1380 0.200600 0.223000 0.913600 0.913558
1400 0.175900 0.223361 0.911200 0.911245
1420 0.169600 0.224143 0.909600 0.909657
1440 0.157600 0.224942 0.909600 0.909657
1460 0.183600 0.223980 0.912000 0.911986
1480 0.204600 0.224059 0.912000 0.911972
1500 0.181200 0.224224 0.914400 0.914337
1520 0.175000 0.224378 0.912800 0.912765
1540 0.183800 0.225751 0.907200 0.907240
1560 0.188800 0.226429 0.908000 0.908070

|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   1589 MiB |   7290 MiB | 440273 GiB | 440272 GiB |
|       from large pool |   1566 MiB |   7240 MiB | 438070 GiB | 438069 GiB |
|       from small pool |     23 MiB |     51 MiB |   2203 GiB |   2203 GiB |
|---------------------------------------------------------------------------|
| Active memory         |   1589 MiB |   7290 MiB | 440273 GiB | 440272 GiB |
|       from large pool |   1566 MiB |   7240 MiB | 438070 GiB | 438069 GiB |
|       from small pool |     23 MiB |     51 MiB |   2203 GiB |   2203 GiB |
|---------------------------------------------------------------------------|
| Requested memory      |   1583 MiB |   7285 MiB | 439981 GiB | 439979 GiB |
|       from large pool |   1560 MiB |   7235 MiB | 437788 GiB | 437786 GiB |
|       from small pool |     23 MiB |     51 MiB |   2192 GiB |   2192 GiB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   5538 MiB |   7380 MiB |  18572 MiB |  13034 MiB |
|       from large pool |   5508 MiB |   7328 MiB |  18468 MiB |  12960 MiB |
|       from small pool |     30 MiB |     52 MiB |    104 MiB |     74 MiB |
|---------------------------------------------------------------------------|
| Non-releasable memory | 264726 KiB | 545173 KiB | 154496 GiB | 154496 GiB |
|       from large pool | 259656 KiB | 540498 KiB | 152152 GiB | 152152 GiB |
|       from small pool |   5070 KiB |  22942 KiB |   2343 GiB |   2343 GiB |
|---------------------------------------------------------------------------|
| Allocations           |    1518    |    1736    |   48429 K  |   48428 K  |
|       from large pool |     242    |     320    |   14932 K  |   14932 K  |
|       from small pool |    1276    |    1455    |   33496 K  |   33495 K  |
|---------------------------------------------------------------------------|
| Active allocs         |    1518    |    1736    |   48429 K  |   48428 K  |
|       from large pool |     242    |     320    |   14932 K  |   14932 K  |
|       from small pool |    1276    |    1455    |   33496 K  |   33495 K  |
|---------------------------------------------------------------------------|
| GPU reserved segments |     119    |     190    |     399    |     280    |
|       from large pool |     104    |     164    |     347    |     243    |
|       from small pool |      15    |      26    |      52    |      37    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |      41    |      60    |   22134 K  |   22134 K  |
|       from large pool |      21    |      28    |    3630 K  |    3630 K  |
|       from small pool |      20    |      43    |   18503 K  |   18503 K  |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

[743/743 6:00:33]
Evaluation results on the test set:
{'eval_loss': 0.21304276585578918, 'eval_accuracy': 0.9208421052631579, 'eval_f1': 0.9208423095953104, 'eval_runtime': 122.0514, 'eval_samples_per_second': 194.59, 'eval_steps_per_second': 6.088, 'epoch': 2.0}
[Figure output]
[Figure output]
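The allocator report above matches the format of torch.cuda.memory_summary(). A minimal sketch for reproducing that kind of report after a run (assuming a single CUDA device at index 0, as used throughout this notebook):

import torch

if torch.cuda.is_available():
    # Full per-pool allocator statistics, as printed above
    print(torch.cuda.memory_summary(device=0, abbreviated=False))
    # A single headline number that is handy to track between runs
    print(f"Peak allocated: {torch.cuda.max_memory_allocated(0) / 2**20:.0f} MiB")

Note that the "Tot Alloc" column is cumulative over the whole session, which is why it reaches hundreds of TiB while current usage stays below 2 GiB.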

Bonus - DDoRA training - continued...¶

In [24]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Continue training the DDoRA (all-attention) model, then re-evaluate on the held-out test set
trainer_ddora_all_attn.train()
eval_results_ddora_all_attn = trainer_ddora_all_attn.evaluate(dataset_encoded["test"])
print(f"DDoRA (All Attention) Test Results: {eval_results_ddora_all_attn}")

print('Prediction drift')
outputs = trainer_ddora_all_attn.predict(dataset_encoded["validation"])
print(outputs.predictions[:5], outputs.label_ids[:5])

# Bar plot of the norm of each LoRA B matrix, one bar per wrapped attention projection
print('LoRA B-matrix norms by layer')
layer_names = []
b_norms = []

for name, param in model_ddora_all_attn.named_parameters():
    if "lora.B" in name:
        layer_names.append(name)
        b_norms.append(param.norm().item())

sns.barplot(x=b_norms, y=layer_names, color='navy')
plt.xlabel("Weight Norm")
plt.title("LoRA B Norms by Layer")
plt.tight_layout()
plt.show()

# Collect the learned m_out / m_in magnitude vectors from every DDoRA-wrapped attention projection
m_in_values = []
m_out_values = []

for name, module in model_ddora_all_attn.named_modules():
    if isinstance(module, LinearWithDoubleDoRA):
        if hasattr(module, 'm_out'):
            m_out_param = module.m_out.detach().cpu().numpy().flatten()
            m_out_values.extend(m_out_param)
        if hasattr(module, 'm_in'):
            m_in_param = module.m_in.detach().cpu().numpy().flatten()
            m_in_values.extend(m_in_param)

# Convert to numpy arrays
m_in_values = np.array(m_in_values)
m_out_values = np.array(m_out_values)

# Summary stats
print(f"[m_out] Mean: {np.mean(m_out_values):.4f}, Std: {np.std(m_out_values):.4f}, Min: {np.min(m_out_values):.4f}, Max: {np.max(m_out_values):.4f}")
print(f"[m_in ] Mean: {np.mean(m_in_values):.4f}, Std: {np.std(m_in_values):.4f}, Min: {np.min(m_in_values):.4f}, Max: {np.max(m_in_values):.4f}")

# Plot histograms
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.hist(m_out_values, bins=50, alpha=0.7, color='olive')
plt.title('Distribution of Learned m_out Values (DDoRA)')
plt.xlabel('Magnitude (m_out)')
plt.ylabel('Frequency')

plt.subplot(1, 2, 2)
plt.hist(m_in_values, bins=50, alpha=0.7, color='navy')
plt.title('Distribution of Learned m_in Values (DDoRA)')
plt.xlabel('Magnitude (m_in)')
plt.ylabel('Frequency')

plt.tight_layout()
plt.show()
[3910/3910 2:54:29, Epoch 5/5]
Step Training Loss Validation Loss Accuracy F1
50 0.220400 0.234968 0.918400 0.918165
100 0.218500 0.291846 0.909600 0.909350
150 0.262700 0.238732 0.904800 0.904872
200 0.202600 0.255080 0.905600 0.905789
250 0.236200 0.241231 0.912800 0.912896
300 0.241200 0.219332 0.915200 0.915200
350 0.250000 0.223478 0.909600 0.909217
400 0.212100 0.238065 0.910400 0.910055
450 0.205200 0.279559 0.912800 0.912232
500 0.218100 0.243256 0.918400 0.918435
550 0.204100 0.298102 0.916800 0.916634
600 0.246300 0.275445 0.916000 0.915923
650 0.220000 0.258947 0.905600 0.905260
700 0.196100 0.254702 0.912800 0.912704
750 0.234500 0.282484 0.887200 0.885504
800 0.202900 0.328079 0.892800 0.891303
850 0.186800 0.242680 0.912800 0.912780
900 0.180200 0.268671 0.921600 0.921575
950 0.173000 0.297629 0.917600 0.917391
1000 0.196400 0.250007 0.904800 0.904209
1050 0.196400 0.241103 0.920800 0.920794
1100 0.188100 0.263111 0.920000 0.919934
1150 0.170100 0.242923 0.916800 0.916941
1200 0.195600 0.233525 0.927200 0.927236
1250 0.175300 0.244239 0.909600 0.908955
1300 0.197200 0.263203 0.914400 0.914565
1350 0.225300 0.233377 0.916000 0.915907
1400 0.193400 0.232386 0.922400 0.922477
1450 0.201500 0.216410 0.919200 0.919168
1500 0.185000 0.218531 0.916800 0.916458
1550 0.208200 0.246917 0.910400 0.910552
1600 0.156400 0.343675 0.893600 0.893824
1650 0.145800 0.321663 0.923200 0.923150
1700 0.170000 0.295388 0.918400 0.918467
1750 0.166200 0.247114 0.928800 0.928760
1800 0.154300 0.227517 0.920000 0.920066
1850 0.165400 0.281829 0.922400 0.922459
1900 0.153700 0.256229 0.924800 0.924710
1950 0.164500 0.270434 0.923200 0.923297
2000 0.167300 0.226990 0.924800 0.924724
2050 0.167500 0.239940 0.927200 0.927120
2100 0.163400 0.226345 0.918400 0.918146
2150 0.116000 0.307402 0.923200 0.923108
2200 0.172000 0.269007 0.921600 0.921410
2250 0.160900 0.270219 0.924800 0.924634
2300 0.165600 0.256944 0.921600 0.921521
2350 0.155100 0.233660 0.916000 0.915666
2400 0.138500 0.256974 0.913600 0.913331
2450 0.124600 0.244719 0.913600 0.913371
2500 0.118200 0.303150 0.920000 0.920109
2550 0.102700 0.305154 0.912000 0.911568
2600 0.140000 0.266935 0.921600 0.921664
2650 0.114900 0.281125 0.915200 0.915099
2700 0.124200 0.297636 0.916800 0.916746
2750 0.112800 0.275313 0.919200 0.919096
2800 0.109300 0.253661 0.920000 0.919934
2850 0.107300 0.281190 0.913600 0.913127
2900 0.118900 0.262891 0.920800 0.920850
2950 0.123200 0.258890 0.922400 0.922300
3000 0.130200 0.258712 0.921600 0.921562
3050 0.119200 0.269527 0.917600 0.917567
3100 0.124000 0.263077 0.913600 0.913127
3150 0.101900 0.296506 0.919200 0.919181
3200 0.093300 0.302173 0.923200 0.923093
3250 0.084200 0.305703 0.920800 0.920781
3300 0.091100 0.294771 0.921600 0.921506
3350 0.077800 0.341725 0.920000 0.919806
3400 0.087400 0.324175 0.916800 0.916716
3450 0.100900 0.307702 0.921600 0.921443
3500 0.080300 0.325871 0.913600 0.913573
3550 0.067700 0.348119 0.916800 0.916800
3600 0.109800 0.301693 0.918400 0.918184
3650 0.094600 0.309193 0.915200 0.915114
3700 0.087600 0.316292 0.916800 0.916598
3750 0.076500 0.331572 0.915200 0.915225
3800 0.082600 0.318606 0.918400 0.918302
3850 0.075700 0.319939 0.919200 0.919126
3900 0.076200 0.319871 0.919200 0.919126

DDoRA (All Attention) Test Results: {'eval_loss': 0.20528165996074677, 'eval_accuracy': 0.9178526315789474, 'eval_f1': 0.9178393215322689, 'eval_runtime': 160.0105, 'eval_samples_per_second': 148.428, 'eval_steps_per_second': 4.643, 'epoch': 5.0}
Prediction drift
[[-2.8385003  1.8467747]
 [-4.145672   3.0743616]
 [ 2.190418  -2.6564922]
 [-2.6781795  1.6880615]
 [ 4.3640585 -5.1461053]] [1 1 0 1 0]
LoRA B-matrix norms by layer
[Figure: LoRA B Norms by Layer (bar plot)]
[m_out] Mean: 0.0013, Std: 0.7323, Min: -3.2482, Max: 5.2270
[m_in ] Mean: 0.0093, Std: 0.5591, Min: -2.4425, Max: 2.6041
[Figure: histograms of learned m_out and m_in values (DDoRA)]
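The "Prediction drift" lines above are raw logits from Trainer.predict shown next to the gold labels. A minimal sketch for turning them into class predictions and a quick agreement score (assuming `outputs` is the PredictionOutput returned by trainer_ddora_all_attn.predict above):

import numpy as np

logits = outputs.predictions            # shape (num_examples, num_labels)
preds = np.argmax(logits, axis=-1)      # predicted class per example
print(preds[:5], outputs.label_ids[:5])
print(f"Validation agreement: {(preds == outputs.label_ids).mean():.4f}")

For the five examples shown above the argmax gives [1 1 0 1 0], which matches the gold labels exactly.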
In [25]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Another round of DDoRA (all-attention) training, followed by the same test-set evaluation and diagnostics
trainer_ddora_all_attn.train()
eval_results_ddora_all_attn = trainer_ddora_all_attn.evaluate(dataset_encoded["test"])
print(f"DDoRA (All Attention) Test Results: {eval_results_ddora_all_attn}")

print('Prediction drift')
outputs = trainer_ddora_all_attn.predict(dataset_encoded["validation"])
print(outputs.predictions[:5], outputs.label_ids[:5])

# Bar plot of the norm of each LoRA B matrix, one bar per wrapped attention projection
print('LoRA B-matrix norms by layer')
layer_names = []
b_norms = []

for name, param in model_ddora_all_attn.named_parameters():
    if "lora.B" in name:
        layer_names.append(name)
        b_norms.append(param.norm().item())

sns.barplot(x=b_norms, y=layer_names, color='navy')
plt.xlabel("Weight Norm")
plt.title("LoRA B Norms by Layer")
plt.tight_layout()
plt.show()

# Collect the learned m_out / m_in magnitude vectors from every DDoRA-wrapped attention projection
m_in_values = []
m_out_values = []

for name, module in model_ddora_all_attn.named_modules():
    if isinstance(module, LinearWithDoubleDoRA):
        if hasattr(module, 'm_out'):
            m_out_param = module.m_out.detach().cpu().numpy().flatten()
            m_out_values.extend(m_out_param)
        if hasattr(module, 'm_in'):
            m_in_param = module.m_in.detach().cpu().numpy().flatten()
            m_in_values.extend(m_in_param)

# Convert to numpy arrays
m_in_values = np.array(m_in_values)
m_out_values = np.array(m_out_values)

# Summary stats
print(f"[m_out] Mean: {np.mean(m_out_values):.4f}, Std: {np.std(m_out_values):.4f}, Min: {np.min(m_out_values):.4f}, Max: {np.max(m_out_values):.4f}")
print(f"[m_in ] Mean: {np.mean(m_in_values):.4f}, Std: {np.std(m_in_values):.4f}, Min: {np.min(m_in_values):.4f}, Max: {np.max(m_in_values):.4f}")

# Plot histograms
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.hist(m_out_values, bins=50, alpha=0.7, color='olive')
plt.title('Distribution of Learned m_out Values (DDoRA)')
plt.xlabel('Magnitude (m_out)')
plt.ylabel('Frequency')

plt.subplot(1, 2, 2)
plt.hist(m_in_values, bins=50, alpha=0.7, color='navy')
plt.title('Distribution of Learned m_in Values (DDoRA)')
plt.xlabel('Magnitude (m_in)')
plt.ylabel('Frequency')

plt.tight_layout()
plt.show()
[3910/3910 2:58:07, Epoch 5/5]
Step Training Loss Validation Loss Accuracy F1
50 0.193100 0.231645 0.906400 0.905762
100 0.181300 0.284688 0.905600 0.905806
150 0.201600 0.228212 0.919200 0.919194
200 0.165300 0.329915 0.904000 0.903447
250 0.180800 0.242941 0.906400 0.905702
300 0.200800 0.223974 0.911200 0.910594
350 0.207600 0.294283 0.903200 0.902281
400 0.172700 0.272758 0.912800 0.912720
450 0.174600 0.307540 0.912000 0.911301
500 0.194500 0.255431 0.923200 0.922979
550 0.179700 0.227661 0.924000 0.923982
600 0.206200 0.250669 0.924800 0.924788
650 0.161900 0.223719 0.920800 0.920616
700 0.160000 0.327061 0.913600 0.913709
750 0.183100 0.310250 0.889600 0.887916
800 0.176400 0.296977 0.902400 0.901688
850 0.130100 0.293862 0.912800 0.912720
900 0.132400 0.327696 0.912000 0.911493
950 0.145000 0.389061 0.912800 0.912430
1000 0.154600 0.376502 0.908800 0.908494
1050 0.149700 0.331415 0.908000 0.908145
1100 0.150100 0.422993 0.908800 0.908786
1150 0.122200 0.281298 0.911200 0.911316
1200 0.146500 0.323460 0.913600 0.913747
1250 0.152600 0.273852 0.918400 0.918271
1300 0.174300 0.231292 0.920000 0.919919
1350 0.172700 0.253323 0.918400 0.918237
1400 0.174600 0.237144 0.910400 0.910008
1450 0.150800 0.240501 0.921600 0.921506
1500 0.196900 0.244910 0.920800 0.920769
1550 0.202900 0.221191 0.916000 0.916064
1600 0.134700 0.323171 0.916800 0.916824
1650 0.106900 0.297118 0.923200 0.922923
1700 0.138200 0.273596 0.913600 0.913725
1750 0.125000 0.324790 0.922400 0.922382
1800 0.113900 0.257193 0.908800 0.908915
1850 0.138000 0.295698 0.916800 0.916836
1900 0.110100 0.312183 0.923200 0.923163
1950 0.120400 0.276098 0.916000 0.916074
2000 0.135500 0.260222 0.923200 0.923233
2050 0.134200 0.242475 0.920800 0.920829
2100 0.124400 0.246645 0.924800 0.924634
2150 0.099600 0.298770 0.922400 0.922369
2200 0.120900 0.293304 0.923200 0.923136
2250 0.119700 0.287883 0.924800 0.924681
2300 0.125600 0.258974 0.915200 0.915315
2350 0.118500 0.262726 0.921600 0.921427
2400 0.094500 0.329549 0.912800 0.912178
2450 0.102100 0.306560 0.922400 0.922269
2500 0.100400 0.277154 0.922400 0.922439
2550 0.080800 0.327069 0.921600 0.921374
2600 0.103800 0.279821 0.918400 0.918526
2650 0.097400 0.280838 0.918400 0.918495
2700 0.089000 0.333324 0.921600 0.921634
2750 0.093300 0.278819 0.921600 0.921443
2800 0.069700 0.351026 0.924000 0.923916
2850 0.071100 0.334101 0.923200 0.923013
2900 0.089700 0.317389 0.924000 0.923970
2950 0.101400 0.295993 0.924800 0.924724
3000 0.090500 0.301456 0.925600 0.925594
3050 0.092200 0.300997 0.924000 0.924028
3100 0.091500 0.297800 0.920000 0.919731
3150 0.079600 0.339592 0.920800 0.920840
3200 0.072300 0.360106 0.922400 0.922369
3250 0.053700 0.383484 0.925600 0.925545
3300 0.072800 0.334235 0.922400 0.922356
3350 0.065400 0.371789 0.921600 0.921410
3400 0.068900 0.343820 0.924000 0.923930
3450 0.071900 0.332017 0.922400 0.922329
3500 0.046600 0.393188 0.922400 0.922394
3550 0.057900 0.388426 0.919200 0.919240
3600 0.065900 0.366166 0.923200 0.923176
3650 0.073700 0.351996 0.924000 0.923916
3700 0.070800 0.337321 0.922400 0.922314
3750 0.054100 0.353665 0.923200 0.923188
3800 0.051200 0.360188 0.924000 0.923957
3850 0.061000 0.360322 0.923200 0.923176
3900 0.063800 0.360985 0.924800 0.924788

DDoRA (All Attention) Test Results: {'eval_loss': 0.2019222378730774, 'eval_accuracy': 0.9212631578947369, 'eval_f1': 0.9212608684750585, 'eval_runtime': 157.6478, 'eval_samples_per_second': 150.652, 'eval_steps_per_second': 4.713, 'epoch': 5.0}
Prediction drift
[[-2.494168   1.6709884]
 [-4.1562243  3.2771716]
 [ 2.5441208 -2.9046338]
 [-2.77265    1.9955821]
 [ 5.211438  -5.7563157]] [1 1 0 1 0]
LoRA B-matrix norms by layer
[Figure: LoRA B Norms by Layer (bar plot)]
[m_out] Mean: 0.0017, Std: 0.8710, Min: -3.8374, Max: 6.9031
[m_in ] Mean: 0.0119, Std: 0.6580, Min: -2.7239, Max: 2.7596
[Figure: histograms of learned m_out and m_in values (DDoRA)]
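A useful companion number to these adapter statistics is how small the trainable DDoRA footprint is relative to the full model. A minimal sketch that reports how many parameters actually receive gradients (no assumptions beyond the model_ddora_all_attn object used above):

# Count trainable vs. total parameters of the DDoRA-wrapped model
trainable = sum(p.numel() for p in model_ddora_all_attn.parameters() if p.requires_grad)
total = sum(p.numel() for p in model_ddora_all_attn.parameters())
print(f"Trainable parameters: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")

The exact ratio depends on the LoRA rank chosen earlier in the notebook.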
In [26]:
# Mean absolute value of each DDoRA adapter parameter group, per attention projection
print('Parameter statistics, abs().mean(): scale_out / scale_in')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.scale" in name:
        print(name, param.abs().mean().item())
print('Parameter statistics, abs().mean(): m_out / m_in')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.m" in name:
        print(name, param.abs().mean().item())
print('Parameter statistics, abs().mean(): lora.A / lora.B')
for name, param in model_ddora_all_attn.named_parameters():
    if "lora" in name:
        print(name, param.abs().mean().item())
Parameter statistics, abs().mean(): scale_out / scale_in
distilbert.transformer.layer.0.attention.q_lin.scale_out 2.211477279663086
distilbert.transformer.layer.0.attention.q_lin.scale_in 1.7990350723266602
distilbert.transformer.layer.0.attention.k_lin.scale_out 2.160134792327881
distilbert.transformer.layer.0.attention.k_lin.scale_in 1.7010796070098877
distilbert.transformer.layer.0.attention.v_lin.scale_out 2.182528257369995
distilbert.transformer.layer.0.attention.v_lin.scale_in 1.0187453031539917
distilbert.transformer.layer.0.attention.out_lin.scale_out 1.5009715557098389
distilbert.transformer.layer.0.attention.out_lin.scale_in 1.3177869319915771
distilbert.transformer.layer.1.attention.q_lin.scale_out 2.0520782470703125
distilbert.transformer.layer.1.attention.q_lin.scale_in 1.6306018829345703
distilbert.transformer.layer.1.attention.k_lin.scale_out 1.8983781337738037
distilbert.transformer.layer.1.attention.k_lin.scale_in 1.684961199760437
distilbert.transformer.layer.1.attention.v_lin.scale_out 1.3902130126953125
distilbert.transformer.layer.1.attention.v_lin.scale_in 1.4229816198349
distilbert.transformer.layer.1.attention.out_lin.scale_out 1.4535112380981445
distilbert.transformer.layer.1.attention.out_lin.scale_in 1.2360076904296875
distilbert.transformer.layer.2.attention.q_lin.scale_out 1.8091825246810913
distilbert.transformer.layer.2.attention.q_lin.scale_in 1.6466021537780762
distilbert.transformer.layer.2.attention.k_lin.scale_out 2.054741859436035
distilbert.transformer.layer.2.attention.k_lin.scale_in 1.6497936248779297
distilbert.transformer.layer.2.attention.v_lin.scale_out 1.364464282989502
distilbert.transformer.layer.2.attention.v_lin.scale_in 1.3622424602508545
distilbert.transformer.layer.2.attention.out_lin.scale_out 1.31589674949646
distilbert.transformer.layer.2.attention.out_lin.scale_in 1.1265673637390137
distilbert.transformer.layer.3.attention.q_lin.scale_out 2.026132345199585
distilbert.transformer.layer.3.attention.q_lin.scale_in 1.7109956741333008
distilbert.transformer.layer.3.attention.k_lin.scale_out 2.1449785232543945
distilbert.transformer.layer.3.attention.k_lin.scale_in 1.6678569316864014
distilbert.transformer.layer.3.attention.v_lin.scale_out 1.5447391271591187
distilbert.transformer.layer.3.attention.v_lin.scale_in 1.2968910932540894
distilbert.transformer.layer.3.attention.out_lin.scale_out 1.459152102470398
distilbert.transformer.layer.3.attention.out_lin.scale_in 1.3626108169555664
distilbert.transformer.layer.4.attention.q_lin.scale_out 1.8095883131027222
distilbert.transformer.layer.4.attention.q_lin.scale_in 1.434910535812378
distilbert.transformer.layer.4.attention.k_lin.scale_out 2.0096232891082764
distilbert.transformer.layer.4.attention.k_lin.scale_in 1.4818964004516602
distilbert.transformer.layer.4.attention.v_lin.scale_out 1.224574089050293
distilbert.transformer.layer.4.attention.v_lin.scale_in 0.9664930701255798
distilbert.transformer.layer.4.attention.out_lin.scale_out 1.186560034751892
distilbert.transformer.layer.4.attention.out_lin.scale_in 1.138476848602295
distilbert.transformer.layer.5.attention.q_lin.scale_out 1.6922223567962646
distilbert.transformer.layer.5.attention.q_lin.scale_in 1.1524317264556885
distilbert.transformer.layer.5.attention.k_lin.scale_out 1.892466425895691
distilbert.transformer.layer.5.attention.k_lin.scale_in 1.325523853302002
distilbert.transformer.layer.5.attention.v_lin.scale_out 1.4372739791870117
distilbert.transformer.layer.5.attention.v_lin.scale_in 0.7315273284912109
distilbert.transformer.layer.5.attention.out_lin.scale_out 1.1462877988815308
distilbert.transformer.layer.5.attention.out_lin.scale_in 1.2485125064849854
Parameter statistics, abs().mean(): m_out / m_in
distilbert.transformer.layer.0.attention.q_lin.m_out 0.904198169708252
distilbert.transformer.layer.0.attention.q_lin.m_in 0.5918328166007996
distilbert.transformer.layer.0.attention.k_lin.m_out 0.8658651113510132
distilbert.transformer.layer.0.attention.k_lin.m_in 0.552005410194397
distilbert.transformer.layer.0.attention.v_lin.m_out 0.891939640045166
distilbert.transformer.layer.0.attention.v_lin.m_in 0.3521280288696289
distilbert.transformer.layer.0.attention.out_lin.m_out 0.5082004070281982
distilbert.transformer.layer.0.attention.out_lin.m_in 0.3994007110595703
distilbert.transformer.layer.1.attention.q_lin.m_out 0.7819246053695679
distilbert.transformer.layer.1.attention.q_lin.m_in 0.5072041153907776
distilbert.transformer.layer.1.attention.k_lin.m_out 0.7096916437149048
distilbert.transformer.layer.1.attention.k_lin.m_in 0.5101008415222168
distilbert.transformer.layer.1.attention.v_lin.m_out 0.48597097396850586
distilbert.transformer.layer.1.attention.v_lin.m_in 0.4318428933620453
distilbert.transformer.layer.1.attention.out_lin.m_out 0.510459303855896
distilbert.transformer.layer.1.attention.out_lin.m_in 0.370250940322876
distilbert.transformer.layer.2.attention.q_lin.m_out 0.672264039516449
distilbert.transformer.layer.2.attention.q_lin.m_in 0.49723148345947266
distilbert.transformer.layer.2.attention.k_lin.m_out 0.8109812140464783
distilbert.transformer.layer.2.attention.k_lin.m_in 0.5146569013595581
distilbert.transformer.layer.2.attention.v_lin.m_out 0.4366285800933838
distilbert.transformer.layer.2.attention.v_lin.m_in 0.36893153190612793
distilbert.transformer.layer.2.attention.out_lin.m_out 0.4298613965511322
distilbert.transformer.layer.2.attention.out_lin.m_in 0.3405319154262543
distilbert.transformer.layer.3.attention.q_lin.m_out 0.7905937433242798
distilbert.transformer.layer.3.attention.q_lin.m_in 0.5291155576705933
distilbert.transformer.layer.3.attention.k_lin.m_out 0.8796322345733643
distilbert.transformer.layer.3.attention.k_lin.m_in 0.539239764213562
distilbert.transformer.layer.3.attention.v_lin.m_out 0.5514059066772461
distilbert.transformer.layer.3.attention.v_lin.m_in 0.395082950592041
distilbert.transformer.layer.3.attention.out_lin.m_out 0.5047212839126587
distilbert.transformer.layer.3.attention.out_lin.m_in 0.4035267233848572
distilbert.transformer.layer.4.attention.q_lin.m_out 0.682336688041687
distilbert.transformer.layer.4.attention.q_lin.m_in 0.4449799656867981
distilbert.transformer.layer.4.attention.k_lin.m_out 0.7756460905075073
distilbert.transformer.layer.4.attention.k_lin.m_in 0.46460646390914917
distilbert.transformer.layer.4.attention.v_lin.m_out 0.41292259097099304
distilbert.transformer.layer.4.attention.v_lin.m_in 0.2883872985839844
distilbert.transformer.layer.4.attention.out_lin.m_out 0.3856971859931946
distilbert.transformer.layer.4.attention.out_lin.m_in 0.3439074456691742
distilbert.transformer.layer.5.attention.q_lin.m_out 0.5993542671203613
distilbert.transformer.layer.5.attention.q_lin.m_in 0.3158256709575653
distilbert.transformer.layer.5.attention.k_lin.m_out 0.6338485479354858
distilbert.transformer.layer.5.attention.k_lin.m_in 0.3383934199810028
distilbert.transformer.layer.5.attention.v_lin.m_out 0.4689478278160095
distilbert.transformer.layer.5.attention.v_lin.m_in 0.2039266675710678
distilbert.transformer.layer.5.attention.out_lin.m_out 0.3742349147796631
distilbert.transformer.layer.5.attention.out_lin.m_in 0.39616748690605164
Parameter statistics, abs().mean(): lora.A / lora.B
distilbert.transformer.layer.0.attention.q_lin.lora.A 0.6073777079582214
distilbert.transformer.layer.0.attention.q_lin.lora.B 0.5936362743377686
distilbert.transformer.layer.0.attention.k_lin.lora.A 0.5834858417510986
distilbert.transformer.layer.0.attention.k_lin.lora.B 0.5944669246673584
distilbert.transformer.layer.0.attention.v_lin.lora.A 0.477506160736084
distilbert.transformer.layer.0.attention.v_lin.lora.B 0.5750714540481567
distilbert.transformer.layer.0.attention.out_lin.lora.A 0.4822663962841034
distilbert.transformer.layer.0.attention.out_lin.lora.B 0.5455904006958008
distilbert.transformer.layer.1.attention.q_lin.lora.A 0.5509202480316162
distilbert.transformer.layer.1.attention.q_lin.lora.B 0.5553864240646362
distilbert.transformer.layer.1.attention.k_lin.lora.A 0.5384150147438049
distilbert.transformer.layer.1.attention.k_lin.lora.B 0.5774441957473755
distilbert.transformer.layer.1.attention.v_lin.lora.A 0.5051205158233643
distilbert.transformer.layer.1.attention.v_lin.lora.B 0.4909871518611908
distilbert.transformer.layer.1.attention.out_lin.lora.A 0.4681047201156616
distilbert.transformer.layer.1.attention.out_lin.lora.B 0.5537931323051453
distilbert.transformer.layer.2.attention.q_lin.lora.A 0.5328584909439087
distilbert.transformer.layer.2.attention.q_lin.lora.B 0.5830724239349365
distilbert.transformer.layer.2.attention.k_lin.lora.A 0.5513193011283875
distilbert.transformer.layer.2.attention.k_lin.lora.B 0.5603161454200745
distilbert.transformer.layer.2.attention.v_lin.lora.A 0.46175330877304077
distilbert.transformer.layer.2.attention.v_lin.lora.B 0.4580882787704468
distilbert.transformer.layer.2.attention.out_lin.lora.A 0.4390673041343689
distilbert.transformer.layer.2.attention.out_lin.lora.B 0.5048208236694336
distilbert.transformer.layer.3.attention.q_lin.lora.A 0.5584718585014343
distilbert.transformer.layer.3.attention.q_lin.lora.B 0.5750592947006226
distilbert.transformer.layer.3.attention.k_lin.lora.A 0.5745105147361755
distilbert.transformer.layer.3.attention.k_lin.lora.B 0.5809159278869629
distilbert.transformer.layer.3.attention.v_lin.lora.A 0.47269874811172485
distilbert.transformer.layer.3.attention.v_lin.lora.B 0.5128231048583984
distilbert.transformer.layer.3.attention.out_lin.lora.A 0.460321843624115
distilbert.transformer.layer.3.attention.out_lin.lora.B 0.48533257842063904
distilbert.transformer.layer.4.attention.q_lin.lora.A 0.5054908990859985
distilbert.transformer.layer.4.attention.q_lin.lora.B 0.5460034608840942
distilbert.transformer.layer.4.attention.k_lin.lora.A 0.5137321352958679
distilbert.transformer.layer.4.attention.k_lin.lora.B 0.5609502196311951
distilbert.transformer.layer.4.attention.v_lin.lora.A 0.4001336991786957
distilbert.transformer.layer.4.attention.v_lin.lora.B 0.4369596838951111
distilbert.transformer.layer.4.attention.out_lin.lora.A 0.42696064710617065
distilbert.transformer.layer.4.attention.out_lin.lora.B 0.4693056344985962
distilbert.transformer.layer.5.attention.q_lin.lora.A 0.4086011052131653
distilbert.transformer.layer.5.attention.q_lin.lora.B 0.48308905959129333
distilbert.transformer.layer.5.attention.k_lin.lora.A 0.432908296585083
distilbert.transformer.layer.5.attention.k_lin.lora.B 0.4992391765117645
distilbert.transformer.layer.5.attention.v_lin.lora.A 0.34997496008872986
distilbert.transformer.layer.5.attention.v_lin.lora.B 0.45308423042297363
distilbert.transformer.layer.5.attention.out_lin.lora.A 0.46718907356262207
distilbert.transformer.layer.5.attention.out_lin.lora.B 0.5256445407867432
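The flat printout above is hard to scan across layers. A minimal sketch that collects the same abs().mean() statistics into a pivot table, one row per (layer, projection) and one column per adapter parameter; pandas is assumed to be available here, as it is not otherwise used in this notebook:

import re
import pandas as pd

pattern = re.compile(r"layer\.(\d+)\.attention\.(\w+_lin)\.(scale_out|scale_in|m_out|m_in|lora\.A|lora\.B)$")
rows = []
for name, param in model_ddora_all_attn.named_parameters():
    match = pattern.search(name)
    if match:
        rows.append({"layer": int(match.group(1)),
                     "proj": match.group(2),
                     "param": match.group(3),
                     "mean_abs": param.abs().mean().item()})

stats = pd.DataFrame(rows)
# One row per (layer, projection), one column per adapter parameter type
print(stats.pivot_table(index=["layer", "proj"], columns="param", values="mean_abs").round(3))

Swapping mean_abs for param.norm().item() gives the same side-by-side view for the norms printed in the next cell.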
In [27]:
# L2 norm of each DDoRA adapter parameter group, per attention projection
print('Parameter statistics, param.norm(): lora.A / lora.B')
for name, param in model_ddora_all_attn.named_parameters():
    if "lora" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
print('Parameter statistics, param.norm(): scale_out / scale_in')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.scale" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
print('Parameter statistics, param.norm(): m_out / m_in')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.m" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
Parameter statistics, param.norm(): lora.A / lora.B
distilbert.transformer.layer.0.attention.q_lin.lora.A weight norm: 88.6770
distilbert.transformer.layer.0.attention.q_lin.lora.B weight norm: 83.6840
distilbert.transformer.layer.0.attention.k_lin.lora.A weight norm: 85.9274
distilbert.transformer.layer.0.attention.k_lin.lora.B weight norm: 83.4974
distilbert.transformer.layer.0.attention.v_lin.lora.A weight norm: 73.3022
distilbert.transformer.layer.0.attention.v_lin.lora.B weight norm: 80.9918
distilbert.transformer.layer.0.attention.out_lin.lora.A weight norm: 73.7286
distilbert.transformer.layer.0.attention.out_lin.lora.B weight norm: 77.5629
distilbert.transformer.layer.1.attention.q_lin.lora.A weight norm: 81.9203
distilbert.transformer.layer.1.attention.q_lin.lora.B weight norm: 78.9050
distilbert.transformer.layer.1.attention.k_lin.lora.A weight norm: 80.0022
distilbert.transformer.layer.1.attention.k_lin.lora.B weight norm: 81.3284
distilbert.transformer.layer.1.attention.v_lin.lora.A weight norm: 76.1209
distilbert.transformer.layer.1.attention.v_lin.lora.B weight norm: 70.1328
distilbert.transformer.layer.1.attention.out_lin.lora.A weight norm: 71.7114
distilbert.transformer.layer.1.attention.out_lin.lora.B weight norm: 78.6895
distilbert.transformer.layer.2.attention.q_lin.lora.A weight norm: 80.5862
distilbert.transformer.layer.2.attention.q_lin.lora.B weight norm: 81.8691
distilbert.transformer.layer.2.attention.k_lin.lora.A weight norm: 81.0452
distilbert.transformer.layer.2.attention.k_lin.lora.B weight norm: 79.2133
distilbert.transformer.layer.2.attention.v_lin.lora.A weight norm: 70.1643
distilbert.transformer.layer.2.attention.v_lin.lora.B weight norm: 65.6669
distilbert.transformer.layer.2.attention.out_lin.lora.A weight norm: 68.0408
distilbert.transformer.layer.2.attention.out_lin.lora.B weight norm: 71.9416
distilbert.transformer.layer.3.attention.q_lin.lora.A weight norm: 81.9671
distilbert.transformer.layer.3.attention.q_lin.lora.B weight norm: 80.9681
distilbert.transformer.layer.3.attention.k_lin.lora.A weight norm: 84.2276
distilbert.transformer.layer.3.attention.k_lin.lora.B weight norm: 82.0434
distilbert.transformer.layer.3.attention.v_lin.lora.A weight norm: 72.1863
distilbert.transformer.layer.3.attention.v_lin.lora.B weight norm: 73.0758
distilbert.transformer.layer.3.attention.out_lin.lora.A weight norm: 70.6866
distilbert.transformer.layer.3.attention.out_lin.lora.B weight norm: 69.2462
distilbert.transformer.layer.4.attention.q_lin.lora.A weight norm: 75.9484
distilbert.transformer.layer.4.attention.q_lin.lora.B weight norm: 77.3235
distilbert.transformer.layer.4.attention.k_lin.lora.A weight norm: 76.9465
distilbert.transformer.layer.4.attention.k_lin.lora.B weight norm: 79.4074
distilbert.transformer.layer.4.attention.v_lin.lora.A weight norm: 62.7360
distilbert.transformer.layer.4.attention.v_lin.lora.B weight norm: 63.9343
distilbert.transformer.layer.4.attention.out_lin.lora.A weight norm: 66.6852
distilbert.transformer.layer.4.attention.out_lin.lora.B weight norm: 67.7626
distilbert.transformer.layer.5.attention.q_lin.lora.A weight norm: 63.4330
distilbert.transformer.layer.5.attention.q_lin.lora.B weight norm: 69.0939
distilbert.transformer.layer.5.attention.k_lin.lora.A weight norm: 66.2249
distilbert.transformer.layer.5.attention.k_lin.lora.B weight norm: 70.6252
distilbert.transformer.layer.5.attention.v_lin.lora.A weight norm: 56.1786
distilbert.transformer.layer.5.attention.v_lin.lora.B weight norm: 67.4088
distilbert.transformer.layer.5.attention.out_lin.lora.A weight norm: 71.9610
distilbert.transformer.layer.5.attention.out_lin.lora.B weight norm: 76.3459
Parameter statistics, param.norm(): scale_out / scale_in
distilbert.transformer.layer.0.attention.q_lin.scale_out weight norm: 67.8149
distilbert.transformer.layer.0.attention.q_lin.scale_in weight norm: 55.8788
distilbert.transformer.layer.0.attention.k_lin.scale_out weight norm: 66.3004
distilbert.transformer.layer.0.attention.k_lin.scale_in weight norm: 53.4427
distilbert.transformer.layer.0.attention.v_lin.scale_out weight norm: 67.3438
distilbert.transformer.layer.0.attention.v_lin.scale_in weight norm: 40.8041
distilbert.transformer.layer.0.attention.out_lin.scale_out weight norm: 49.3765
distilbert.transformer.layer.0.attention.out_lin.scale_in weight norm: 46.0389
distilbert.transformer.layer.1.attention.q_lin.scale_out weight norm: 63.8303
distilbert.transformer.layer.1.attention.q_lin.scale_in weight norm: 51.8581
distilbert.transformer.layer.1.attention.k_lin.scale_out weight norm: 59.3891
distilbert.transformer.layer.1.attention.k_lin.scale_in weight norm: 52.9394
distilbert.transformer.layer.1.attention.v_lin.scale_out weight norm: 47.0823
distilbert.transformer.layer.1.attention.v_lin.scale_in weight norm: 47.6743
distilbert.transformer.layer.1.attention.out_lin.scale_out weight norm: 48.7647
distilbert.transformer.layer.1.attention.out_lin.scale_in weight norm: 44.8163
distilbert.transformer.layer.2.attention.q_lin.scale_out weight norm: 57.6219
distilbert.transformer.layer.2.attention.q_lin.scale_in weight norm: 52.3231
distilbert.transformer.layer.2.attention.k_lin.scale_out weight norm: 63.9857
distilbert.transformer.layer.2.attention.k_lin.scale_in weight norm: 52.2575
distilbert.transformer.layer.2.attention.v_lin.scale_out weight norm: 46.9235
distilbert.transformer.layer.2.attention.v_lin.scale_in weight norm: 45.0910
distilbert.transformer.layer.2.attention.out_lin.scale_out weight norm: 45.7886
distilbert.transformer.layer.2.attention.out_lin.scale_in weight norm: 41.9345
distilbert.transformer.layer.3.attention.q_lin.scale_out weight norm: 63.2718
distilbert.transformer.layer.3.attention.q_lin.scale_in weight norm: 53.4324
distilbert.transformer.layer.3.attention.k_lin.scale_out weight norm: 66.0361
distilbert.transformer.layer.3.attention.k_lin.scale_in weight norm: 52.7628
distilbert.transformer.layer.3.attention.v_lin.scale_out weight norm: 51.4346
distilbert.transformer.layer.3.attention.v_lin.scale_in weight norm: 45.3650
distilbert.transformer.layer.3.attention.out_lin.scale_out weight norm: 49.2689
distilbert.transformer.layer.3.attention.out_lin.scale_in weight norm: 46.8675
distilbert.transformer.layer.4.attention.q_lin.scale_out weight norm: 58.5162
distilbert.transformer.layer.4.attention.q_lin.scale_in weight norm: 47.5943
distilbert.transformer.layer.4.attention.k_lin.scale_out weight norm: 62.9462
distilbert.transformer.layer.4.attention.k_lin.scale_in weight norm: 48.5050
distilbert.transformer.layer.4.attention.v_lin.scale_out weight norm: 45.4089
distilbert.transformer.layer.4.attention.v_lin.scale_in weight norm: 38.7889
distilbert.transformer.layer.4.attention.out_lin.scale_out weight norm: 44.1285
distilbert.transformer.layer.4.attention.out_lin.scale_in weight norm: 42.3015
distilbert.transformer.layer.5.attention.q_lin.scale_out weight norm: 55.4135
distilbert.transformer.layer.5.attention.q_lin.scale_in weight norm: 41.3333
distilbert.transformer.layer.5.attention.k_lin.scale_out weight norm: 58.2401
distilbert.transformer.layer.5.attention.k_lin.scale_in weight norm: 44.2531
distilbert.transformer.layer.5.attention.v_lin.scale_out weight norm: 49.2695
distilbert.transformer.layer.5.attention.v_lin.scale_in weight norm: 33.0258
distilbert.transformer.layer.5.attention.out_lin.scale_out weight norm: 43.1049
distilbert.transformer.layer.5.attention.out_lin.scale_in weight norm: 44.5303
Parameter statistics, param.norm(): m_out / m_in
distilbert.transformer.layer.0.attention.q_lin.m_out weight norm: 31.4448
distilbert.transformer.layer.0.attention.q_lin.m_in weight norm: 22.4398
distilbert.transformer.layer.0.attention.k_lin.m_out weight norm: 30.1271
distilbert.transformer.layer.0.attention.k_lin.m_in weight norm: 21.4311
distilbert.transformer.layer.0.attention.v_lin.m_out weight norm: 31.0435
distilbert.transformer.layer.0.attention.v_lin.m_in weight norm: 16.8509
distilbert.transformer.layer.0.attention.out_lin.m_out weight norm: 19.6845
distilbert.transformer.layer.0.attention.out_lin.m_in weight norm: 17.7480
distilbert.transformer.layer.1.attention.q_lin.m_out weight norm: 28.3183
distilbert.transformer.layer.1.attention.q_lin.m_in weight norm: 20.1127
distilbert.transformer.layer.1.attention.k_lin.m_out weight norm: 25.9295
distilbert.transformer.layer.1.attention.k_lin.m_in weight norm: 20.4797
distilbert.transformer.layer.1.attention.v_lin.m_out weight norm: 19.3969
distilbert.transformer.layer.1.attention.v_lin.m_in weight norm: 18.3722
distilbert.transformer.layer.1.attention.out_lin.m_out weight norm: 19.8745
distilbert.transformer.layer.1.attention.out_lin.m_in weight norm: 17.4045
distilbert.transformer.layer.2.attention.q_lin.m_out weight norm: 24.9047
distilbert.transformer.layer.2.attention.q_lin.m_in weight norm: 20.1862
distilbert.transformer.layer.2.attention.k_lin.m_out weight norm: 29.0122
distilbert.transformer.layer.2.attention.k_lin.m_in weight norm: 20.1494
distilbert.transformer.layer.2.attention.v_lin.m_out weight norm: 17.9841
distilbert.transformer.layer.2.attention.v_lin.m_in weight norm: 16.4528
distilbert.transformer.layer.2.attention.out_lin.m_out weight norm: 17.7306
distilbert.transformer.layer.2.attention.out_lin.m_in weight norm: 16.3007
distilbert.transformer.layer.3.attention.q_lin.m_out weight norm: 28.5406
distilbert.transformer.layer.3.attention.q_lin.m_in weight norm: 20.5190
distilbert.transformer.layer.3.attention.k_lin.m_out weight norm: 30.7927
distilbert.transformer.layer.3.attention.k_lin.m_in weight norm: 20.6228
distilbert.transformer.layer.3.attention.v_lin.m_out weight norm: 21.3742
distilbert.transformer.layer.3.attention.v_lin.m_in weight norm: 17.7191
distilbert.transformer.layer.3.attention.out_lin.m_out weight norm: 20.0696
distilbert.transformer.layer.3.attention.out_lin.m_in weight norm: 17.6262
distilbert.transformer.layer.4.attention.q_lin.m_out weight norm: 25.6429
distilbert.transformer.layer.4.attention.q_lin.m_in weight norm: 18.5982
distilbert.transformer.layer.4.attention.k_lin.m_out weight norm: 27.9177
distilbert.transformer.layer.4.attention.k_lin.m_in weight norm: 18.6512
distilbert.transformer.layer.4.attention.v_lin.m_out weight norm: 18.0866
distilbert.transformer.layer.4.attention.v_lin.m_in weight norm: 15.1881
distilbert.transformer.layer.4.attention.out_lin.m_out weight norm: 17.4686
distilbert.transformer.layer.4.attention.out_lin.m_in weight norm: 16.4471
distilbert.transformer.layer.5.attention.q_lin.m_out weight norm: 22.9480
distilbert.transformer.layer.5.attention.q_lin.m_in weight norm: 15.1683
distilbert.transformer.layer.5.attention.k_lin.m_out weight norm: 22.7621
distilbert.transformer.layer.5.attention.k_lin.m_in weight norm: 15.4354
distilbert.transformer.layer.5.attention.v_lin.m_out weight norm: 19.4222
distilbert.transformer.layer.5.attention.v_lin.m_in weight norm: 12.9561
distilbert.transformer.layer.5.attention.out_lin.m_out weight norm: 16.6265
distilbert.transformer.layer.5.attention.out_lin.m_in weight norm: 17.4611
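These norms are consistent with the abs().mean() values from the previous cell: for a vector v with n entries, ||v||_2 = sqrt(n) * rms(v), and for roughly zero-mean Gaussian entries rms is about sqrt(pi/2) ≈ 1.25 times mean|v|. A quick check for layer 0's q_lin.m_out, assuming it has one entry per output feature (768 for these projections):

import math

n, mean_abs = 768, 0.9042                                            # size and abs().mean() printed above for layer.0 q_lin.m_out
predicted_norm = math.sqrt(n) * mean_abs * math.sqrt(math.pi / 2)    # E|x| = sigma * sqrt(2/pi) for x ~ N(0, sigma^2)
print(f"predicted norm ≈ {predicted_norm:.1f}")                      # ≈ 31.4, vs. 31.4448 reported above

The close agreement suggests the learned m_out entries are roughly zero-mean and Gaussian-shaped, consistent with the histograms plotted earlier.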
In [28]:
# Validation-set check for the earlier 2-epoch `trainer` run (its step log appears further above)
evaluation_results = trainer.evaluate(dataset_encoded["validation"])
print("\nEvaluation results on the validation set:")
print(evaluation_results)
print("Last 10 Validation Accuracy Steps:", val_accuracy_steps[-10:])
print("Last 10 Validation Accuracies:", val_accuracies[-10:])
print("Last 10 Validation F1 Steps:", val_f1_steps[-10:])
print("Last 10 Validation F1 Scores:", val_f1_scores[-10:])
Evaluation results on the validation set:
{'eval_loss': 0.22643864154815674, 'eval_accuracy': 0.908, 'eval_f1': 0.908069813307235, 'eval_runtime': 6.7981, 'eval_samples_per_second': 183.876, 'eval_steps_per_second': 5.884, 'epoch': 2.0}
Last 10 Validation Accuracy Steps: [1400, 1420, 1440, 1460, 1480, 1500, 1520, 1540, 1560, 1564]
Last 10 Validation Accuracies: [0.9112, 0.9096, 0.9096, 0.912, 0.912, 0.9144, 0.9128, 0.9072, 0.908, 0.9208421052631579]
Last 10 Validation F1 Steps: [1400, 1420, 1440, 1460, 1480, 1500, 1520, 1540, 1560, 1564]
Last 10 Validation F1 Scores: [0.9112445094405226, 0.9096571916118659, 0.9096571916118659, 0.9119863574704475, 0.9119722515145877, 0.9143367287669, 0.9127653425067421, 0.9072402344324614, 0.908069813307235, 0.9208423095953104]
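Rather than eyeballing the printed lists, the same step/accuracy/F1 arrays can be plotted directly. A minimal sketch using the val_accuracy_steps, val_accuracies, val_f1_steps and val_f1_scores lists referenced above:

import matplotlib.pyplot as plt

plt.figure(figsize=(8, 4))
plt.plot(val_accuracy_steps[-10:], val_accuracies[-10:], marker="o", label="Accuracy")
plt.plot(val_f1_steps[-10:], val_f1_scores[-10:], marker="s", label="F1")
plt.xlabel("Training step")
plt.ylabel("Validation score")
plt.title("Validation accuracy and F1 over the last 10 evaluations")
plt.legend()
plt.tight_layout()
plt.show()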
In [ ]: