In [1]:
import torch
import sys
import gc
print(sys.version)
print(f"PyTorch Version: {torch.__version__}")
print(torch.cuda.is_available())
print(torch.cuda.device_count())
if torch.cuda.is_available():
    print(f"CUDA Version: {torch.version.cuda}")
    print(torch.cuda.get_device_name(0))
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
import bitsandbytes
import peft
import transformers
print(transformers.__version__)
print(f"bitsandbytes version: {bitsandbytes.__version__}")
print(f"peft version: {peft.__version__}")
print(torch.cuda.is_bf16_supported())
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
3.10.16 | packaged by Anaconda, Inc. | (main, Dec 11 2024, 16:19:12) [MSC v.1929 64 bit (AMD64)]
PyTorch Version: 2.5.1+cu121
True
1
CUDA Version: 12.1
NVIDIA GeForce RTX 4080 Laptop GPU
4.50.0.dev0
bitsandbytes version: 0.45.3
peft version: 0.15.2.dev0
True
Load dataset, base model, and tokeniser
In [2]:
from datasets import load_dataset
imdb_dataset = load_dataset("imdb")
imdb_dataset = imdb_dataset.rename_column("label", "labels")
# Split the test set into validation and test sets
test_val_split = imdb_dataset['test'].train_test_split(test_size=0.95, seed=42)
imdb_dataset['validation'] = test_val_split['train']
imdb_dataset['test'] = test_val_split['test']
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, f1_score
# Determine the number of labels
num_labels = len(set(imdb_dataset["train"]["labels"]))
print(f"Number of labels: {num_labels}")
# Load the tokenizer
model_ckpt = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
# Tokenize the whole dataset, truncate to 384 tokens
def tokenize(batch):
    return tokenizer(batch["text"], padding=True, truncation=True, max_length=384)
dataset_encoded = imdb_dataset.map(tokenize, batched=True, batch_size=None)
# Load the pretrained model for sequence classification
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = (AutoModelForSequenceClassification
         .from_pretrained(model_ckpt, num_labels=num_labels)
         .to(device))
print(model)
Number of labels: 2
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear(in_features=768, out_features=3072, bias=True)
            (lin2): Linear(in_features=3072, out_features=768, bias=True)
            (activation): GELUActivation()
          )
          (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        )
      )
    )
  )
  (pre_classifier): Linear(in_features=768, out_features=768, bias=True)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)
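Because `map` is called with `batch_size=None` and `padding=True`, every review is padded up-front to the longest (truncated) sequence in the dataset. An alternative, shown below only as a sketch (not used in this run; `tokenize_no_pad` and `dataset_dynamic` are hypothetical names), is to tokenise without padding and let a `DataCollatorWithPadding` pad each batch dynamically at training time:

```python
from transformers import DataCollatorWithPadding

# Sketch of dynamic per-batch padding (an alternative to the global padding above).
def tokenize_no_pad(batch):
    return tokenizer(batch["text"], truncation=True, max_length=384)

dataset_dynamic = imdb_dataset.map(tokenize_no_pad, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Pass `data_collator=data_collator` to the Trainer so each batch is padded
# only to the length of its longest example.
```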
In [3]:
torch.autograd.set_detect_anomaly(True)  # anomaly detection helps trace NaN/inf gradients, at the cost of slower training
# Define the performance metrics
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1}
LoRA
In [4]:
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# Modify model by injecting LoRA and DoRA layers
# Borrowed from: https://magazine.sebastianraschka.com/p/lora-and-dora-from-scratch
class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))  # B starts at zero so the initial LoRA update is zero
        self.alpha = alpha
    def forward(self, x):
        x = self.alpha * (x @ self.A @ self.B)
        return x
class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )
    def forward(self, x):
        return self.linear(x) + self.lora(x)
class LinearWithDoRA(nn.Module):
    def __init__(self, linear, rank, alpha, scaling_factor=1.0):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.m = nn.Parameter(torch.randn(1, linear.out_features) * std_dev)
        #self.scale = nn.Parameter(torch.tensor(float(scaling_factor)))
        self.scale = nn.Parameter(torch.full((1, linear.out_features), float(scaling_factor)))
    def forward(self, x):
        linear_output = self.linear(x)
        lora_output = self.lora(x)
        lora_output_norm = lora_output / (lora_output.norm(p=2, dim=-1, keepdim=True) + 1e-9)        
        dora_modification = self.scale * self.m * lora_output_norm
        return linear_output + dora_modification
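Since `LinearWithLoRA` adds `alpha * (x @ A @ B)` to the frozen projection, the update can be folded into the base weight once training is finished. A minimal merge sketch, assuming the class definitions above (the helper name `merge_lora` is hypothetical):

```python
import torch
import torch.nn as nn

@torch.no_grad()
def merge_lora(layer: LinearWithLoRA) -> nn.Linear:
    # nn.Linear computes x @ W.T, while the LoRA branch computes alpha * (x @ A @ B),
    # so the folded weight is W + alpha * (A @ B).T.
    merged = nn.Linear(layer.linear.in_features, layer.linear.out_features,
                       bias=layer.linear.bias is not None)
    merged.weight.copy_(layer.linear.weight + layer.lora.alpha * (layer.lora.A @ layer.lora.B).T)
    if layer.linear.bias is not None:
        merged.bias.copy_(layer.linear.bias)
    return merged
```

The DoRA variant above cannot be merged this way, because its normalisation term depends on the input at runtime.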
In [5]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
# Function to inject LoRA into specified linear layers
def inject_lora_all_attn(model, rank, alpha):
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            lora_linear = LinearWithLoRA(original_linear, rank, alpha)
            setattr(parent_module, name.split('.')[-1], lora_linear)
    return model
# Function to inject DoRA into specified linear layers
def inject_dora_all_attn(model, rank, alpha, scaling_factor=1.0):
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            dora_linear = LinearWithDoRA(original_linear, rank, alpha, scaling_factor)
            setattr(parent_module, name.split('.')[-1], dora_linear)
    return model
def count_trainable_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params, 100 * trainable_params / total_params
def freeze_model_layers(model, unfreeze_pre_classifier=False):
    # Freeze all parameters
    for param in model.parameters():
        param.requires_grad = False
    # Unfreeze LoRA and DoRA-specific params
    for name, param in model.named_parameters():
        if (
            "lora.A" in name
            or "lora.B" in name
            or name.endswith(".m")  # For DoRA
            or name.endswith(".m_in") # For DDoRA
            or name.endswith(".m_out") # For DDoRA
            or "scale" in name 
        ):
            param.requires_grad = True
    # Unfreeze classifier layer (always)
    for name, param in model.named_parameters():
        if name.startswith("classifier."):
            param.requires_grad = True
    # unfreeze pre-classifier
    if unfreeze_pre_classifier:
        for name, param in model.named_parameters():
            if name.startswith("pre_classifier."):            
                param.requires_grad = True
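Because `LoRALayer.B` is initialised to zeros, the injected adapters contribute nothing before training, so a wrapped model should reproduce the base model's logits exactly. A quick sanity check (a sketch, not notebook output; `check_model` and the sample text are hypothetical):

```python
import copy
import torch

check_model = inject_lora_all_attn(copy.deepcopy(model), rank=16, alpha=32).to(device).eval()
sample = tokenizer("A quick smoke-test review.", return_tensors="pt").to(device)
with torch.no_grad():
    base_logits = model.eval()(**sample).logits
    lora_logits = check_model(**sample).logits
print(torch.allclose(base_logits, lora_logits, atol=1e-6))  # expected: True
```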
In [6]:
import seaborn as sns
import torch
import numpy as np
import matplotlib.pyplot as plt
import copy
torch.manual_seed(137)
lora_rank = 16
lora_alpha = 32
batch_size = 32
learning_rate = 1.5e-5  # note: much lower learning rate than the DoRA/DDoRA runs below
weight_decay = 0.0
output_dir_prefix = "finetuned-imdb-"
model_lora_all_attn = copy.deepcopy(model)
model_lora_all_attn = inject_lora_all_attn(model_lora_all_attn, lora_rank, lora_alpha)
freeze_model_layers(model_lora_all_attn, unfreeze_pre_classifier=True)
total_params_lora, trainable_params_lora, percentage_lora = count_trainable_parameters(model_lora_all_attn)
print(f"\nLoRA (All Attention) - Total parameters: {total_params_lora:,}")
print(f"LoRA (All Attention) - Trainable parameters: {trainable_params_lora:,} ({percentage_lora:.2f}%)")
eval_steps = 50
logging_steps = 50
output_dir_prefix = "finetuned-imdb-"
training_args_lora_all_attn = TrainingArguments(
    output_dir=f"{output_dir_prefix}lora-all-attn",
    num_train_epochs=5,
    #max_steps=100,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    logging_steps=logging_steps,
    save_steps=eval_steps,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    disable_tqdm=False,
    push_to_hub=False,
    report_to="none",
    log_level="error"
)
trainer_lora_all_attn = Trainer(model=model_lora_all_attn, args=training_args_lora_all_attn, 
                                train_dataset=dataset_encoded["train"], eval_dataset=dataset_encoded["validation"], compute_metrics=compute_metrics)
LoRA (All Attention) - Total parameters: 67,544,834
LoRA (All Attention) - Trainable parameters: 1,181,954 (1.75%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
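Since `load_best_model_at_end=True` tracks `eval_loss`, an `EarlyStoppingCallback` could cut the run short once validation loss stops improving; a sketch only (not used in the run below; `trainer_lora_early_stop` and the patience value are assumptions):

```python
from transformers import EarlyStoppingCallback

trainer_lora_early_stop = Trainer(
    model=model_lora_all_attn,
    args=training_args_lora_all_attn,
    train_dataset=dataset_encoded["train"],
    eval_dataset=dataset_encoded["validation"],
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)
```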
In [7]:
trainer_lora_all_attn.train()
eval_results_lora_all_attn = trainer_lora_all_attn.evaluate(dataset_encoded["test"])
print(f"LoRA (All Attention) Test Results: {eval_results_lora_all_attn}")
print('Prediction drift')
outputs = trainer_lora_all_attn.predict(dataset_encoded["validation"])
print(outputs.predictions[:5], outputs.label_ids[:5])
print (torch.cuda.memory_summary())
print('LoRA B norms by layer')
layer_names = []
b_norms = []
for name, param in model_lora_all_attn.named_parameters():
    if "lora.B" in name:
        layer_names.append(name)
        b_norms.append(param.norm().item())
sns.barplot(x=b_norms, y=layer_names, color='navy')
plt.xlabel("Weight Norm")
plt.title("LoRA B Norms by Layer")
plt.tight_layout()
plt.show()
[3910/3910 1:29:10, Epoch 5/5]
| Step | Training Loss | Validation Loss | Accuracy | F1 | 
|---|---|---|---|---|
| 50 | 0.549600 | 0.334510 | 0.864800 | 0.864919 | 
| 100 | 0.363100 | 0.314867 | 0.873600 | 0.873637 | 
| 150 | 0.368100 | 0.283882 | 0.884000 | 0.884009 | 
| 200 | 0.289200 | 0.270112 | 0.887200 | 0.887403 | 
| 250 | 0.304900 | 0.246406 | 0.903200 | 0.903162 | 
| 300 | 0.304000 | 0.256413 | 0.904800 | 0.904319 | 
| 350 | 0.313300 | 0.257387 | 0.896000 | 0.896121 | 
| 400 | 0.275000 | 0.360744 | 0.857600 | 0.857483 | 
| 450 | 0.278100 | 0.277102 | 0.896800 | 0.897004 | 
| 500 | 0.305700 | 0.243079 | 0.904000 | 0.904014 | 
| 550 | 0.285300 | 0.258897 | 0.897600 | 0.897417 | 
| 600 | 0.303100 | 0.258372 | 0.900800 | 0.900313 | 
| 650 | 0.269200 | 0.246777 | 0.901600 | 0.901607 | 
| 700 | 0.260800 | 0.241955 | 0.904000 | 0.903848 | 
| 750 | 0.258400 | 0.243109 | 0.909600 | 0.909285 | 
| 800 | 0.282400 | 0.338753 | 0.866400 | 0.863283 | 
| 850 | 0.237000 | 0.239093 | 0.908800 | 0.908864 | 
| 900 | 0.217300 | 0.261071 | 0.907200 | 0.907265 | 
| 950 | 0.231800 | 0.263062 | 0.904000 | 0.903630 | 
| 1000 | 0.229900 | 0.261102 | 0.907200 | 0.907155 | 
| 1050 | 0.221300 | 0.245763 | 0.904000 | 0.903360 | 
| 1100 | 0.218000 | 0.285926 | 0.899200 | 0.898366 | 
| 1150 | 0.208400 | 0.238600 | 0.906400 | 0.906363 | 
| 1200 | 0.211900 | 0.268987 | 0.906400 | 0.905927 | 
| 1250 | 0.231800 | 0.293073 | 0.892000 | 0.892213 | 
| 1300 | 0.231000 | 0.240379 | 0.904800 | 0.904965 | 
| 1350 | 0.232700 | 0.251226 | 0.903200 | 0.903375 | 
| 1400 | 0.231500 | 0.236147 | 0.908800 | 0.908771 | 
| 1450 | 0.216500 | 0.243060 | 0.906400 | 0.906119 | 
| 1500 | 0.218400 | 0.235548 | 0.908000 | 0.907745 | 
| 1550 | 0.223300 | 0.237270 | 0.908800 | 0.908579 | 
| 1600 | 0.174400 | 0.258777 | 0.914400 | 0.914256 | 
| 1650 | 0.156200 | 0.266737 | 0.911200 | 0.910846 | 
| 1700 | 0.168200 | 0.254983 | 0.914400 | 0.914337 | 
| 1750 | 0.184600 | 0.265700 | 0.902400 | 0.902585 | 
| 1800 | 0.182500 | 0.258778 | 0.907200 | 0.907287 | 
| 1850 | 0.184000 | 0.251654 | 0.912800 | 0.912877 | 
| 1900 | 0.154400 | 0.277561 | 0.906400 | 0.906447 | 
| 1950 | 0.192000 | 0.241895 | 0.903200 | 0.903296 | 
| 2000 | 0.185900 | 0.254578 | 0.913600 | 0.913789 | 
| 2050 | 0.197500 | 0.240309 | 0.902400 | 0.902245 | 
| 2100 | 0.192700 | 0.260683 | 0.901600 | 0.901103 | 
| 2150 | 0.147100 | 0.262304 | 0.908000 | 0.908081 | 
| 2200 | 0.183600 | 0.251887 | 0.910400 | 0.910439 | 
| 2250 | 0.207200 | 0.251087 | 0.909600 | 0.909750 | 
| 2300 | 0.196100 | 0.249030 | 0.900800 | 0.901003 | 
| 2350 | 0.172600 | 0.248261 | 0.910400 | 0.910426 | 
| 2400 | 0.146400 | 0.240490 | 0.919200 | 0.919168 | 
| 2450 | 0.145300 | 0.248667 | 0.908800 | 0.908598 | 
| 2500 | 0.108100 | 0.305569 | 0.899200 | 0.899355 | 
| 2550 | 0.140500 | 0.277428 | 0.910400 | 0.910142 | 
| 2600 | 0.152600 | 0.284274 | 0.916800 | 0.916913 | 
| 2650 | 0.141600 | 0.252135 | 0.916800 | 0.916731 | 
| 2700 | 0.133600 | 0.268836 | 0.912000 | 0.911958 | 
| 2750 | 0.151300 | 0.262349 | 0.914400 | 0.913991 | 
| 2800 | 0.121400 | 0.255012 | 0.917600 | 0.917494 | 
| 2850 | 0.131700 | 0.269369 | 0.913600 | 0.913222 | 
| 2900 | 0.152800 | 0.246611 | 0.916800 | 0.916716 | 
| 2950 | 0.132300 | 0.251266 | 0.908000 | 0.907724 | 
| 3000 | 0.142600 | 0.265384 | 0.918400 | 0.918532 | 
| 3050 | 0.139300 | 0.260968 | 0.909600 | 0.909483 | 
| 3100 | 0.137700 | 0.303503 | 0.904000 | 0.903330 | 
| 3150 | 0.121600 | 0.266256 | 0.920000 | 0.920056 | 
| 3200 | 0.101800 | 0.283502 | 0.911200 | 0.911288 | 
| 3250 | 0.090600 | 0.293355 | 0.912800 | 0.912635 | 
| 3300 | 0.112900 | 0.290351 | 0.912000 | 0.911958 | 
| 3350 | 0.086300 | 0.298528 | 0.916000 | 0.915907 | 
| 3400 | 0.086100 | 0.310346 | 0.915200 | 0.915280 | 
| 3450 | 0.093600 | 0.306871 | 0.914400 | 0.914352 | 
| 3500 | 0.107200 | 0.312322 | 0.917600 | 0.917606 | 
| 3550 | 0.121300 | 0.315530 | 0.912800 | 0.912844 | 
| 3600 | 0.117800 | 0.316997 | 0.915200 | 0.914956 | 
| 3650 | 0.133400 | 0.311634 | 0.911200 | 0.911179 | 
| 3700 | 0.107600 | 0.316201 | 0.915200 | 0.915248 | 
| 3750 | 0.107800 | 0.312200 | 0.916000 | 0.916019 | 
| 3800 | 0.086900 | 0.314815 | 0.912800 | 0.912844 | 
| 3850 | 0.105500 | 0.314706 | 0.913600 | 0.913600 | 
| 3900 | 0.069900 | 0.314938 | 0.912800 | 0.912780 | 
LoRA (All Attention) Test Results: {'eval_loss': 0.22863808274269104, 'eval_accuracy': 0.9094315789473684, 'eval_f1': 0.9093754564531454, 'eval_runtime': 129.5483, 'eval_samples_per_second': 183.329, 'eval_steps_per_second': 5.735, 'epoch': 5.0}
Prediction drift
[[-1.9898635  1.6761417]
 [-2.4649115  2.2287219]
 [ 2.1762567 -1.8627577]
 [-1.866952   1.787821 ]
 [ 2.8843796 -2.5472975]] [1 1 0 1 0]
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      | 553486 KiB |   3293 MiB |  90884 GiB |  90883 GiB |
|       from large pool | 546048 KiB |   3264 MiB |  90258 GiB |  90258 GiB |
|       from small pool |   7438 KiB |     31 MiB |    625 GiB |    625 GiB |
|---------------------------------------------------------------------------|
| Active memory         | 553486 KiB |   3293 MiB |  90884 GiB |  90883 GiB |
|       from large pool | 546048 KiB |   3264 MiB |  90258 GiB |  90258 GiB |
|       from small pool |   7438 KiB |     31 MiB |    625 GiB |    625 GiB |
|---------------------------------------------------------------------------|
| Requested memory      | 551272 KiB |   3290 MiB |  90677 GiB |  90676 GiB |
|       from large pool | 543836 KiB |   3260 MiB |  90054 GiB |  90053 GiB |
|       from small pool |   7436 KiB |     31 MiB |    623 GiB |    623 GiB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   3514 MiB |   3514 MiB |   3514 MiB |      0 B   |
|       from large pool |   3476 MiB |   3476 MiB |   3476 MiB |      0 B   |
|       from small pool |     38 MiB |     38 MiB |     38 MiB |      0 B   |
|---------------------------------------------------------------------------|
| Non-releasable memory |  67058 KiB | 308581 KiB |  40038 GiB |  40038 GiB |
|       from large pool |  64256 KiB | 300928 KiB |  39368 GiB |  39368 GiB |
|       from small pool |   2802 KiB |   9018 KiB |    670 GiB |    670 GiB |
|---------------------------------------------------------------------------|
| Allocations           |     364    |     491    |   10562 K  |   10562 K  |
|       from large pool |      82    |     140    |    3162 K  |    3162 K  |
|       from small pool |     282    |     395    |    7400 K  |    7400 K  |
|---------------------------------------------------------------------------|
| Active allocs         |     364    |     491    |   10562 K  |   10562 K  |
|       from large pool |      82    |     140    |    3162 K  |    3162 K  |
|       from small pool |     282    |     395    |    7400 K  |    7400 K  |
|---------------------------------------------------------------------------|
| GPU reserved segments |      79    |      79    |      79    |       0    |
|       from large pool |      60    |      60    |      60    |       0    |
|       from small pool |      19    |      19    |      19    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |      27    |      44    |    4813 K  |    4813 K  |
|       from large pool |      18    |      23    |     935 K  |     935 K  |
|       from small pool |       9    |      23    |    3878 K  |    3878 K  |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|
LoRA B norms by layer
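Only about 1.2 M parameters were trained, so instead of saving full checkpoints one can keep just the trainable tensors; a sketch, assuming the hand-rolled `LinearWithLoRA` layers above (the file name is arbitrary):

```python
import torch

# Collect only the parameters that require gradients: LoRA A/B matrices,
# pre_classifier and classifier weights.
trainable_state = {name: p.detach().cpu()
                   for name, p in model_lora_all_attn.named_parameters()
                   if p.requires_grad}
torch.save(trainable_state, "lora_all_attn_adapter.pt")

# To restore: re-inject LoRA into a fresh copy of the base model, then
# load_state_dict(..., strict=False) so frozen weights keep their pretrained values.
```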
In [8]:
print('Parameter Statistics: mean.abs()')
for name, param in model_lora_all_attn.named_parameters():
    if "lora" in name:
        print(name, param.abs().mean().item())
Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.lora.A 0.20137116312980652 distilbert.transformer.layer.0.attention.q_lin.lora.B 0.00041380099719390273 distilbert.transformer.layer.0.attention.k_lin.lora.A 0.2003246247768402 distilbert.transformer.layer.0.attention.k_lin.lora.B 0.00042878554086200893 distilbert.transformer.layer.0.attention.v_lin.lora.A 0.19911907613277435 distilbert.transformer.layer.0.attention.v_lin.lora.B 0.000338834710419178 distilbert.transformer.layer.0.attention.out_lin.lora.A 0.20047757029533386 distilbert.transformer.layer.0.attention.out_lin.lora.B 0.00034657015930861235 distilbert.transformer.layer.1.attention.q_lin.lora.A 0.1993226706981659 distilbert.transformer.layer.1.attention.q_lin.lora.B 0.0004094692412763834 distilbert.transformer.layer.1.attention.k_lin.lora.A 0.1983720064163208 distilbert.transformer.layer.1.attention.k_lin.lora.B 0.000427080609370023 distilbert.transformer.layer.1.attention.v_lin.lora.A 0.20035219192504883 distilbert.transformer.layer.1.attention.v_lin.lora.B 0.00030109973158687353 distilbert.transformer.layer.1.attention.out_lin.lora.A 0.20086750388145447 distilbert.transformer.layer.1.attention.out_lin.lora.B 0.00034795209649018943 distilbert.transformer.layer.2.attention.q_lin.lora.A 0.19976946711540222 distilbert.transformer.layer.2.attention.q_lin.lora.B 0.00040535288280807436 distilbert.transformer.layer.2.attention.k_lin.lora.A 0.19660572707653046 distilbert.transformer.layer.2.attention.k_lin.lora.B 0.00038226376636885107 distilbert.transformer.layer.2.attention.v_lin.lora.A 0.19878578186035156 distilbert.transformer.layer.2.attention.v_lin.lora.B 0.00027373715420253575 distilbert.transformer.layer.2.attention.out_lin.lora.A 0.19986116886138916 distilbert.transformer.layer.2.attention.out_lin.lora.B 0.0003197303449269384 distilbert.transformer.layer.3.attention.q_lin.lora.A 0.1998608410358429 distilbert.transformer.layer.3.attention.q_lin.lora.B 0.00037932227132841945 distilbert.transformer.layer.3.attention.k_lin.lora.A 0.2004813402891159 distilbert.transformer.layer.3.attention.k_lin.lora.B 0.00038690734072588384 distilbert.transformer.layer.3.attention.v_lin.lora.A 0.2002096176147461 distilbert.transformer.layer.3.attention.v_lin.lora.B 0.00028052262496203184 distilbert.transformer.layer.3.attention.out_lin.lora.A 0.19950413703918457 distilbert.transformer.layer.3.attention.out_lin.lora.B 0.00029115224606357515 distilbert.transformer.layer.4.attention.q_lin.lora.A 0.19811344146728516 distilbert.transformer.layer.4.attention.q_lin.lora.B 0.0004094548639841378 distilbert.transformer.layer.4.attention.k_lin.lora.A 0.1975163221359253 distilbert.transformer.layer.4.attention.k_lin.lora.B 0.00039162003668025136 distilbert.transformer.layer.4.attention.v_lin.lora.A 0.1999872624874115 distilbert.transformer.layer.4.attention.v_lin.lora.B 0.000269879907136783 distilbert.transformer.layer.4.attention.out_lin.lora.A 0.19805268943309784 distilbert.transformer.layer.4.attention.out_lin.lora.B 0.00028756255051121116 distilbert.transformer.layer.5.attention.q_lin.lora.A 0.20075345039367676 distilbert.transformer.layer.5.attention.q_lin.lora.B 0.00037421341403387487 distilbert.transformer.layer.5.attention.k_lin.lora.A 0.19600680470466614 distilbert.transformer.layer.5.attention.k_lin.lora.B 0.0003485583874862641 distilbert.transformer.layer.5.attention.v_lin.lora.A 0.19779205322265625 distilbert.transformer.layer.5.attention.v_lin.lora.B 0.0002423622936476022 
distilbert.transformer.layer.5.attention.out_lin.lora.A 0.19977574050426483 distilbert.transformer.layer.5.attention.out_lin.lora.B 0.0003247036365792155
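The B matrices stay small in absolute terms; a more interpretable view is the size of the effective update `alpha * (A @ B).T` relative to the frozen weight it modifies. A sketch (not notebook output):

```python
# Relative LoRA update size per adapted projection: ||alpha * (A @ B).T|| / ||W||.
for name, module in model_lora_all_attn.named_modules():
    if isinstance(module, LinearWithLoRA):
        delta = module.lora.alpha * (module.lora.A @ module.lora.B).T
        rel = (delta.norm() / module.linear.weight.norm()).item()
        print(f"{name}: relative update norm {rel:.4f}")
```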
In [9]:
print('Parameter Statistics: param.norm()')
for name, param in model_lora_all_attn.named_parameters():
    if "lora" in name or name.endswith(".m"):
        print(f"{name} weight norm: {param.norm().item():.4f}")
Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.lora.A weight norm: 27.8491 distilbert.transformer.layer.0.attention.q_lin.lora.B weight norm: 0.0578 distilbert.transformer.layer.0.attention.k_lin.lora.A weight norm: 27.7708 distilbert.transformer.layer.0.attention.k_lin.lora.B weight norm: 0.0597 distilbert.transformer.layer.0.attention.v_lin.lora.A weight norm: 27.5879 distilbert.transformer.layer.0.attention.v_lin.lora.B weight norm: 0.0473 distilbert.transformer.layer.0.attention.out_lin.lora.A weight norm: 27.8318 distilbert.transformer.layer.0.attention.out_lin.lora.B weight norm: 0.0484 distilbert.transformer.layer.1.attention.q_lin.lora.A weight norm: 27.8054 distilbert.transformer.layer.1.attention.q_lin.lora.B weight norm: 0.0578 distilbert.transformer.layer.1.attention.k_lin.lora.A weight norm: 27.5557 distilbert.transformer.layer.1.attention.k_lin.lora.B weight norm: 0.0595 distilbert.transformer.layer.1.attention.v_lin.lora.A weight norm: 27.9069 distilbert.transformer.layer.1.attention.v_lin.lora.B weight norm: 0.0428 distilbert.transformer.layer.1.attention.out_lin.lora.A weight norm: 27.9222 distilbert.transformer.layer.1.attention.out_lin.lora.B weight norm: 0.0486 distilbert.transformer.layer.2.attention.q_lin.lora.A weight norm: 27.7035 distilbert.transformer.layer.2.attention.q_lin.lora.B weight norm: 0.0568 distilbert.transformer.layer.2.attention.k_lin.lora.A weight norm: 27.3314 distilbert.transformer.layer.2.attention.k_lin.lora.B weight norm: 0.0532 distilbert.transformer.layer.2.attention.v_lin.lora.A weight norm: 27.6290 distilbert.transformer.layer.2.attention.v_lin.lora.B weight norm: 0.0388 distilbert.transformer.layer.2.attention.out_lin.lora.A weight norm: 27.6716 distilbert.transformer.layer.2.attention.out_lin.lora.B weight norm: 0.0449 distilbert.transformer.layer.3.attention.q_lin.lora.A weight norm: 27.6330 distilbert.transformer.layer.3.attention.q_lin.lora.B weight norm: 0.0529 distilbert.transformer.layer.3.attention.k_lin.lora.A weight norm: 27.8436 distilbert.transformer.layer.3.attention.k_lin.lora.B weight norm: 0.0541 distilbert.transformer.layer.3.attention.v_lin.lora.A weight norm: 27.8198 distilbert.transformer.layer.3.attention.v_lin.lora.B weight norm: 0.0395 distilbert.transformer.layer.3.attention.out_lin.lora.A weight norm: 27.7259 distilbert.transformer.layer.3.attention.out_lin.lora.B weight norm: 0.0411 distilbert.transformer.layer.4.attention.q_lin.lora.A weight norm: 27.6222 distilbert.transformer.layer.4.attention.q_lin.lora.B weight norm: 0.0573 distilbert.transformer.layer.4.attention.k_lin.lora.A weight norm: 27.4509 distilbert.transformer.layer.4.attention.k_lin.lora.B weight norm: 0.0547 distilbert.transformer.layer.4.attention.v_lin.lora.A weight norm: 27.8650 distilbert.transformer.layer.4.attention.v_lin.lora.B weight norm: 0.0380 distilbert.transformer.layer.4.attention.out_lin.lora.A weight norm: 27.6073 distilbert.transformer.layer.4.attention.out_lin.lora.B weight norm: 0.0405 distilbert.transformer.layer.5.attention.q_lin.lora.A weight norm: 27.9097 distilbert.transformer.layer.5.attention.q_lin.lora.B weight norm: 0.0525 distilbert.transformer.layer.5.attention.k_lin.lora.A weight norm: 27.3422 distilbert.transformer.layer.5.attention.k_lin.lora.B weight norm: 0.0490 distilbert.transformer.layer.5.attention.v_lin.lora.A weight norm: 27.4174 distilbert.transformer.layer.5.attention.v_lin.lora.B weight norm: 0.0341 distilbert.transformer.layer.5.attention.out_lin.lora.A weight norm: 
27.7203 distilbert.transformer.layer.5.attention.out_lin.lora.B weight norm: 0.0450
DoRA
In [10]:
import copy
torch.manual_seed(137)
lora_rank = 16
lora_alpha = 32
batch_size = 32
learning_rate = 2e-2  # note: much higher learning rate than the LoRA run above
weight_decay = 1e-4
scaling_factor = 2.0
output_dir_prefix = "finetuned-imdb-"
model_dora_all_attn = copy.deepcopy(model)
model_dora_all_attn = inject_dora_all_attn(model_dora_all_attn, lora_rank, lora_alpha, scaling_factor)
freeze_model_layers(model_dora_all_attn, unfreeze_pre_classifier=True)
total_params_dora, trainable_params_dora, percentage_dora = count_trainable_parameters(model_dora_all_attn)
print(f"\nDoRA (All Attention) - Total parameters: {total_params_dora:,}")
print(f"DoRA (All Attention) - Trainable parameters: {trainable_params_dora:,} ({percentage_dora:.2f}%)")
# Sanity check
print("\nTrainable parameters after freezing:")
for name, param in model_dora_all_attn.named_parameters():
    if param.requires_grad:
        print(name)
eval_steps = 50
logging_steps = 50
output_dir_prefix = "finetuned-imdb-"
training_args_dora_all_attn = TrainingArguments(
    output_dir=f"{output_dir_prefix}dora-all-attn",
    num_train_epochs=5,
    #max_steps=100,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    logging_steps=logging_steps,
    save_steps=eval_steps,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    disable_tqdm=False,
    push_to_hub=False,
    report_to="none",
    log_level="error"
)
trainer_dora_all_attn = Trainer(model=model_dora_all_attn, args=training_args_dora_all_attn, 
                                train_dataset=dataset_encoded["train"], eval_dataset=dataset_encoded["validation"], compute_metrics=compute_metrics)
DoRA (All Attention) - Total parameters: 67,581,698 DoRA (All Attention) - Trainable parameters: 1,218,818 (1.80%) Trainable parameters after freezing: distilbert.transformer.layer.0.attention.q_lin.m distilbert.transformer.layer.0.attention.q_lin.scale distilbert.transformer.layer.0.attention.q_lin.lora.A distilbert.transformer.layer.0.attention.q_lin.lora.B distilbert.transformer.layer.0.attention.k_lin.m distilbert.transformer.layer.0.attention.k_lin.scale distilbert.transformer.layer.0.attention.k_lin.lora.A distilbert.transformer.layer.0.attention.k_lin.lora.B distilbert.transformer.layer.0.attention.v_lin.m distilbert.transformer.layer.0.attention.v_lin.scale distilbert.transformer.layer.0.attention.v_lin.lora.A distilbert.transformer.layer.0.attention.v_lin.lora.B distilbert.transformer.layer.0.attention.out_lin.m distilbert.transformer.layer.0.attention.out_lin.scale distilbert.transformer.layer.0.attention.out_lin.lora.A distilbert.transformer.layer.0.attention.out_lin.lora.B distilbert.transformer.layer.1.attention.q_lin.m distilbert.transformer.layer.1.attention.q_lin.scale distilbert.transformer.layer.1.attention.q_lin.lora.A distilbert.transformer.layer.1.attention.q_lin.lora.B distilbert.transformer.layer.1.attention.k_lin.m distilbert.transformer.layer.1.attention.k_lin.scale distilbert.transformer.layer.1.attention.k_lin.lora.A distilbert.transformer.layer.1.attention.k_lin.lora.B distilbert.transformer.layer.1.attention.v_lin.m distilbert.transformer.layer.1.attention.v_lin.scale distilbert.transformer.layer.1.attention.v_lin.lora.A distilbert.transformer.layer.1.attention.v_lin.lora.B distilbert.transformer.layer.1.attention.out_lin.m distilbert.transformer.layer.1.attention.out_lin.scale distilbert.transformer.layer.1.attention.out_lin.lora.A distilbert.transformer.layer.1.attention.out_lin.lora.B distilbert.transformer.layer.2.attention.q_lin.m distilbert.transformer.layer.2.attention.q_lin.scale distilbert.transformer.layer.2.attention.q_lin.lora.A distilbert.transformer.layer.2.attention.q_lin.lora.B distilbert.transformer.layer.2.attention.k_lin.m distilbert.transformer.layer.2.attention.k_lin.scale distilbert.transformer.layer.2.attention.k_lin.lora.A distilbert.transformer.layer.2.attention.k_lin.lora.B distilbert.transformer.layer.2.attention.v_lin.m distilbert.transformer.layer.2.attention.v_lin.scale distilbert.transformer.layer.2.attention.v_lin.lora.A distilbert.transformer.layer.2.attention.v_lin.lora.B distilbert.transformer.layer.2.attention.out_lin.m distilbert.transformer.layer.2.attention.out_lin.scale distilbert.transformer.layer.2.attention.out_lin.lora.A distilbert.transformer.layer.2.attention.out_lin.lora.B distilbert.transformer.layer.3.attention.q_lin.m distilbert.transformer.layer.3.attention.q_lin.scale distilbert.transformer.layer.3.attention.q_lin.lora.A distilbert.transformer.layer.3.attention.q_lin.lora.B distilbert.transformer.layer.3.attention.k_lin.m distilbert.transformer.layer.3.attention.k_lin.scale distilbert.transformer.layer.3.attention.k_lin.lora.A distilbert.transformer.layer.3.attention.k_lin.lora.B distilbert.transformer.layer.3.attention.v_lin.m distilbert.transformer.layer.3.attention.v_lin.scale distilbert.transformer.layer.3.attention.v_lin.lora.A distilbert.transformer.layer.3.attention.v_lin.lora.B distilbert.transformer.layer.3.attention.out_lin.m distilbert.transformer.layer.3.attention.out_lin.scale distilbert.transformer.layer.3.attention.out_lin.lora.A distilbert.transformer.layer.3.attention.out_lin.lora.B 
distilbert.transformer.layer.4.attention.q_lin.m distilbert.transformer.layer.4.attention.q_lin.scale distilbert.transformer.layer.4.attention.q_lin.lora.A distilbert.transformer.layer.4.attention.q_lin.lora.B distilbert.transformer.layer.4.attention.k_lin.m distilbert.transformer.layer.4.attention.k_lin.scale distilbert.transformer.layer.4.attention.k_lin.lora.A distilbert.transformer.layer.4.attention.k_lin.lora.B distilbert.transformer.layer.4.attention.v_lin.m distilbert.transformer.layer.4.attention.v_lin.scale distilbert.transformer.layer.4.attention.v_lin.lora.A distilbert.transformer.layer.4.attention.v_lin.lora.B distilbert.transformer.layer.4.attention.out_lin.m distilbert.transformer.layer.4.attention.out_lin.scale distilbert.transformer.layer.4.attention.out_lin.lora.A distilbert.transformer.layer.4.attention.out_lin.lora.B distilbert.transformer.layer.5.attention.q_lin.m distilbert.transformer.layer.5.attention.q_lin.scale distilbert.transformer.layer.5.attention.q_lin.lora.A distilbert.transformer.layer.5.attention.q_lin.lora.B distilbert.transformer.layer.5.attention.k_lin.m distilbert.transformer.layer.5.attention.k_lin.scale distilbert.transformer.layer.5.attention.k_lin.lora.A distilbert.transformer.layer.5.attention.k_lin.lora.B distilbert.transformer.layer.5.attention.v_lin.m distilbert.transformer.layer.5.attention.v_lin.scale distilbert.transformer.layer.5.attention.v_lin.lora.A distilbert.transformer.layer.5.attention.v_lin.lora.B distilbert.transformer.layer.5.attention.out_lin.m distilbert.transformer.layer.5.attention.out_lin.scale distilbert.transformer.layer.5.attention.out_lin.lora.A distilbert.transformer.layer.5.attention.out_lin.lora.B pre_classifier.weight pre_classifier.bias classifier.weight classifier.bias
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
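The jump from 1,181,954 to 1,218,818 trainable parameters relative to the LoRA run is exactly the new magnitude and scale vectors; a quick check of the arithmetic:

```python
# 6 transformer layers x 4 attention projections = 24 wrapped Linear modules,
# each adding an m vector and a scale vector of length 768 (out_features).
extra = 6 * 4 * 2 * 768
print(extra)                    # 36864
print(1_218_818 - 1_181_954)    # 36864
```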
In [11]:
trainer_dora_all_attn.train()
eval_results_dora_all_attn = trainer_dora_all_attn.evaluate(dataset_encoded["test"])
print(f"DoRA (All Attention) Test Results: {eval_results_dora_all_attn}")
print('Prediction drift')
outputs = trainer_dora_all_attn.predict(dataset_encoded["validation"])
print(outputs.predictions[:5], outputs.label_ids[:5])
print (torch.cuda.memory_summary())
print('LoRA B norms by layer')
layer_names = []
b_norms = []
for name, param in model_dora_all_attn.named_parameters():
    if "lora.B" in name:
        layer_names.append(name)
        b_norms.append(param.norm().item())
sns.barplot(x=b_norms, y=layer_names, color='navy')
plt.xlabel("Weight Norm")
plt.title("LoRA B Norms by Layer")
plt.tight_layout()
plt.show()
import torch
import numpy as np
import matplotlib.pyplot as plt
m_values = []
for name, module in model_dora_all_attn.named_modules():
    if isinstance(module, LinearWithDoRA):
        if hasattr(module, 'm'):
            m_param = module.m.detach().cpu().numpy().flatten()
            m_values.extend(m_param)
m_values = np.array(m_values)
# Analyze the distribution
print(f"Mean of m values: {np.mean(m_values):.4f}")
print(f"Standard deviation of m values: {np.std(m_values):.4f}")
print(f"Minimum m value: {np.min(m_values):.4f}")
print(f"Maximum m value: {np.max(m_values):.4f}")
# Plot a histogram
plt.hist(m_values, bins=50, alpha=0.7)
plt.title('Distribution of Learned m Values (DoRA)')
plt.xlabel('Magnitude (m)')
plt.ylabel('Frequency')
plt.show()
[3910/3910 2:15:09, Epoch 5/5]
| Step | Training Loss | Validation Loss | Accuracy | F1 | 
|---|---|---|---|---|
| 50 | 1.102800 | 0.291960 | 0.870400 | 0.869901 | 
| 100 | 0.356100 | 0.278253 | 0.894400 | 0.894522 | 
| 150 | 0.343500 | 0.268145 | 0.897600 | 0.897671 | 
| 200 | 0.305300 | 0.261925 | 0.897600 | 0.897630 | 
| 250 | 0.321600 | 0.257692 | 0.904800 | 0.904558 | 
| 300 | 0.305000 | 0.272489 | 0.887200 | 0.887449 | 
| 350 | 0.300000 | 0.259093 | 0.901600 | 0.901282 | 
| 400 | 0.266300 | 0.279098 | 0.884000 | 0.884259 | 
| 450 | 0.258800 | 0.272474 | 0.892000 | 0.891513 | 
| 500 | 0.288000 | 0.242344 | 0.905600 | 0.905585 | 
| 550 | 0.280400 | 0.270589 | 0.901600 | 0.900959 | 
| 600 | 0.301900 | 0.249609 | 0.897600 | 0.897794 | 
| 650 | 0.277600 | 0.265940 | 0.904800 | 0.904469 | 
| 700 | 0.278000 | 0.279490 | 0.911200 | 0.910675 | 
| 750 | 0.293100 | 0.268084 | 0.897600 | 0.896787 | 
| 800 | 0.254900 | 0.299735 | 0.896800 | 0.895963 | 
| 850 | 0.253600 | 0.252459 | 0.896800 | 0.896935 | 
| 900 | 0.235600 | 0.269197 | 0.900800 | 0.900881 | 
| 950 | 0.271800 | 0.285552 | 0.902400 | 0.902503 | 
| 1000 | 0.232000 | 0.225912 | 0.908800 | 0.908579 | 
| 1050 | 0.219700 | 0.244564 | 0.914400 | 0.914337 | 
| 1100 | 0.234700 | 0.229250 | 0.913600 | 0.913445 | 
| 1150 | 0.217400 | 0.221898 | 0.905600 | 0.905600 | 
| 1200 | 0.230000 | 0.237117 | 0.909600 | 0.909350 | 
| 1250 | 0.221200 | 0.219938 | 0.912000 | 0.911958 | 
| 1300 | 0.255900 | 0.222013 | 0.908000 | 0.908007 | 
| 1350 | 0.260000 | 0.216425 | 0.915200 | 0.914956 | 
| 1400 | 0.223400 | 0.217995 | 0.911200 | 0.911256 | 
| 1450 | 0.219600 | 0.222689 | 0.912000 | 0.912000 | 
| 1500 | 0.215300 | 0.224147 | 0.914400 | 0.914337 | 
| 1550 | 0.227700 | 0.224390 | 0.919200 | 0.919251 | 
| 1600 | 0.196300 | 0.254192 | 0.920800 | 0.920870 | 
| 1650 | 0.182700 | 0.261529 | 0.916000 | 0.915823 | 
| 1700 | 0.205700 | 0.254054 | 0.916800 | 0.916579 | 
| 1750 | 0.199100 | 0.270258 | 0.912000 | 0.912000 | 
| 1800 | 0.209500 | 0.224310 | 0.917600 | 0.917509 | 
| 1850 | 0.211800 | 0.236137 | 0.915200 | 0.915200 | 
| 1900 | 0.177900 | 0.244019 | 0.916800 | 0.916760 | 
| 1950 | 0.204800 | 0.241636 | 0.916800 | 0.916836 | 
| 2000 | 0.205300 | 0.228495 | 0.919200 | 0.919280 | 
| 2050 | 0.212400 | 0.233124 | 0.918400 | 0.918435 | 
| 2100 | 0.200900 | 0.225622 | 0.911200 | 0.911278 | 
| 2150 | 0.153700 | 0.256996 | 0.912000 | 0.912061 | 
| 2200 | 0.209800 | 0.220482 | 0.914400 | 0.914352 | 
| 2250 | 0.196600 | 0.229680 | 0.916000 | 0.915858 | 
| 2300 | 0.202500 | 0.218464 | 0.910400 | 0.910504 | 
| 2350 | 0.181500 | 0.225591 | 0.917600 | 0.917509 | 
| 2400 | 0.185800 | 0.250574 | 0.912000 | 0.911543 | 
| 2450 | 0.157600 | 0.230985 | 0.915200 | 0.914915 | 
| 2500 | 0.141200 | 0.245401 | 0.910400 | 0.910473 | 
| 2550 | 0.136100 | 0.296957 | 0.918400 | 0.918165 | 
| 2600 | 0.167800 | 0.249483 | 0.902400 | 0.902613 | 
| 2650 | 0.157600 | 0.243417 | 0.912000 | 0.911705 | 
| 2700 | 0.157700 | 0.241792 | 0.918400 | 0.918412 | 
| 2750 | 0.158800 | 0.230028 | 0.912000 | 0.911683 | 
| 2800 | 0.133300 | 0.229130 | 0.916800 | 0.916651 | 
| 2850 | 0.145600 | 0.294840 | 0.917600 | 0.917251 | 
| 2900 | 0.157700 | 0.267986 | 0.916000 | 0.915728 | 
| 2950 | 0.150400 | 0.273056 | 0.912800 | 0.912384 | 
| 3000 | 0.157900 | 0.240002 | 0.918400 | 0.918184 | 
| 3050 | 0.157400 | 0.243159 | 0.920000 | 0.919919 | 
| 3100 | 0.148800 | 0.250069 | 0.914400 | 0.913894 | 
| 3150 | 0.124400 | 0.255732 | 0.916800 | 0.916700 | 
| 3200 | 0.127000 | 0.246215 | 0.920800 | 0.920698 | 
| 3250 | 0.116300 | 0.283555 | 0.912000 | 0.911786 | 
| 3300 | 0.110700 | 0.277807 | 0.912000 | 0.911860 | 
| 3350 | 0.112800 | 0.296765 | 0.916800 | 0.916634 | 
| 3400 | 0.115000 | 0.259858 | 0.912800 | 0.912670 | 
| 3450 | 0.117300 | 0.274300 | 0.912800 | 0.912598 | 
| 3500 | 0.110300 | 0.284478 | 0.915200 | 0.915114 | 
| 3550 | 0.123000 | 0.267429 | 0.916000 | 0.915923 | 
| 3600 | 0.126300 | 0.266704 | 0.918400 | 0.918202 | 
| 3650 | 0.131900 | 0.254044 | 0.916000 | 0.915892 | 
| 3700 | 0.125600 | 0.256258 | 0.916800 | 0.916634 | 
| 3750 | 0.109800 | 0.262411 | 0.916000 | 0.915938 | 
| 3800 | 0.105500 | 0.270246 | 0.914400 | 0.914238 | 
| 3850 | 0.115400 | 0.272160 | 0.914400 | 0.914238 | 
| 3900 | 0.108500 | 0.268805 | 0.914400 | 0.914238 | 
DoRA (All Attention) Test Results: {'eval_loss': 0.21464677155017853, 'eval_accuracy': 0.9129263157894737, 'eval_f1': 0.9128732729336847, 'eval_runtime': 146.8168, 'eval_samples_per_second': 161.766, 'eval_steps_per_second': 5.061, 'epoch': 5.0}
Prediction drift
[[-2.0174963  1.9359767]
 [-2.774578   2.4995706]
 [ 1.9803579 -2.1773782]
 [-1.6433766  1.4992952]
 [ 6.4219594 -7.323969 ]] [1 1 0 1 0]
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |    809 MiB |   5291 MiB | 230768 GiB | 230767 GiB |
|       from large pool |    794 MiB |   5252 MiB | 229489 GiB | 229488 GiB |
|       from small pool |     14 MiB |     40 MiB |   1278 GiB |   1278 GiB |
|---------------------------------------------------------------------------|
| Active memory         |    809 MiB |   5291 MiB | 230768 GiB | 230767 GiB |
|       from large pool |    794 MiB |   5252 MiB | 229489 GiB | 229488 GiB |
|       from small pool |     14 MiB |     40 MiB |   1278 GiB |   1278 GiB |
|---------------------------------------------------------------------------|
| Requested memory      |    805 MiB |   5287 MiB | 230551 GiB | 230550 GiB |
|       from large pool |    790 MiB |   5248 MiB | 229278 GiB | 229277 GiB |
|       from small pool |     14 MiB |     40 MiB |   1273 GiB |   1273 GiB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   5500 MiB |   5500 MiB |   8120 MiB |   2620 MiB |
|       from large pool |   5456 MiB |   5456 MiB |   8048 MiB |   2592 MiB |
|       from small pool |     44 MiB |     44 MiB |     72 MiB |     28 MiB |
|---------------------------------------------------------------------------|
| Non-releasable memory |  98691 KiB | 545173 KiB |  88783 GiB |  88783 GiB |
|       from large pool |  91218 KiB | 540498 KiB |  87421 GiB |  87420 GiB |
|       from small pool |   7473 KiB |  17003 KiB |   1362 GiB |   1362 GiB |
|---------------------------------------------------------------------------|
| Allocations           |     765    |    1012    |   26004 K  |   26003 K  |
|       from large pool |     123    |     229    |    7919 K  |    7919 K  |
|       from small pool |     642    |     851    |   18084 K  |   18084 K  |
|---------------------------------------------------------------------------|
| Active allocs         |     765    |    1012    |   26004 K  |   26003 K  |
|       from large pool |     123    |     229    |    7919 K  |    7919 K  |
|       from small pool |     642    |     851    |   18084 K  |   18084 K  |
|---------------------------------------------------------------------------|
| GPU reserved segments |     134    |     134    |     181    |      47    |
|       from large pool |     112    |     112    |     145    |      33    |
|       from small pool |      22    |      22    |      36    |      14    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |      35    |      53    |   12196 K  |   12196 K  |
|       from large pool |      17    |      26    |    2304 K  |    2304 K  |
|       from small pool |      18    |      32    |    9892 K  |    9892 K  |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|
LoRA B norms by layer
Mean of m values: 0.0014
Standard deviation of m values: 0.5832
Minimum m value: -4.2395
Maximum m value: 3.2943
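With both runs finished, the test metrics can be lined up directly from the evaluation dicts printed above; a small comparison sketch (pandas is an extra dependency assumed here):

```python
import pandas as pd

results = pd.DataFrame(
    [eval_results_lora_all_attn, eval_results_dora_all_attn],
    index=["LoRA (all attn)", "DoRA (all attn)"],
)
print(results[["eval_loss", "eval_accuracy", "eval_f1"]])
```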
In [12]:
print('Parameter Statistics: mean.abs()')
for name, param in model_dora_all_attn.named_parameters():
    if "lin.m" in name:
        print(name, param.abs().mean().item())
print('Parameter Statistics: mean.abs()')
for name, param in model_dora_all_attn.named_parameters():
    if "lin.scale" in name:
        print(name, param.abs().mean().item())
print('Parameter Statistics: mean.abs()')
for name, param in model_dora_all_attn.named_parameters():
    if "lora" in name:
        print(name, param.abs().mean().item())
Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.m 0.6335324048995972 distilbert.transformer.layer.0.attention.k_lin.m 0.5340768694877625 distilbert.transformer.layer.0.attention.v_lin.m 0.4119078516960144 distilbert.transformer.layer.0.attention.out_lin.m 0.20187006890773773 distilbert.transformer.layer.1.attention.q_lin.m 0.45603156089782715 distilbert.transformer.layer.1.attention.k_lin.m 0.39353370666503906 distilbert.transformer.layer.1.attention.v_lin.m 0.2410321682691574 distilbert.transformer.layer.1.attention.out_lin.m 0.22153238952159882 distilbert.transformer.layer.2.attention.q_lin.m 0.4081450402736664 distilbert.transformer.layer.2.attention.k_lin.m 0.44084256887435913 distilbert.transformer.layer.2.attention.v_lin.m 0.24717780947685242 distilbert.transformer.layer.2.attention.out_lin.m 0.16525350511074066 distilbert.transformer.layer.3.attention.q_lin.m 0.4876405596733093 distilbert.transformer.layer.3.attention.k_lin.m 0.4945552349090576 distilbert.transformer.layer.3.attention.v_lin.m 0.2434060424566269 distilbert.transformer.layer.3.attention.out_lin.m 0.24130374193191528 distilbert.transformer.layer.4.attention.q_lin.m 0.48577117919921875 distilbert.transformer.layer.4.attention.k_lin.m 0.5650743246078491 distilbert.transformer.layer.4.attention.v_lin.m 0.18068918585777283 distilbert.transformer.layer.4.attention.out_lin.m 0.2408987283706665 distilbert.transformer.layer.5.attention.q_lin.m 0.37161561846733093 distilbert.transformer.layer.5.attention.k_lin.m 0.6317863464355469 distilbert.transformer.layer.5.attention.v_lin.m 0.19014406204223633 distilbert.transformer.layer.5.attention.out_lin.m 0.28851351141929626 Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.scale 2.162614583969116 distilbert.transformer.layer.0.attention.k_lin.scale 1.9926241636276245 distilbert.transformer.layer.0.attention.v_lin.scale 1.9132812023162842 distilbert.transformer.layer.0.attention.out_lin.scale 1.6308307647705078 distilbert.transformer.layer.1.attention.q_lin.scale 1.9580211639404297 distilbert.transformer.layer.1.attention.k_lin.scale 1.8660268783569336 distilbert.transformer.layer.1.attention.v_lin.scale 1.6165016889572144 distilbert.transformer.layer.1.attention.out_lin.scale 1.5104354619979858 distilbert.transformer.layer.2.attention.q_lin.scale 1.838735818862915 distilbert.transformer.layer.2.attention.k_lin.scale 1.8984363079071045 distilbert.transformer.layer.2.attention.v_lin.scale 1.6771743297576904 distilbert.transformer.layer.2.attention.out_lin.scale 1.53759765625 distilbert.transformer.layer.3.attention.q_lin.scale 1.9363187551498413 distilbert.transformer.layer.3.attention.k_lin.scale 1.934944987297058 distilbert.transformer.layer.3.attention.v_lin.scale 1.6585863828659058 distilbert.transformer.layer.3.attention.out_lin.scale 1.585425853729248 distilbert.transformer.layer.4.attention.q_lin.scale 2.027859687805176 distilbert.transformer.layer.4.attention.k_lin.scale 2.082413673400879 distilbert.transformer.layer.4.attention.v_lin.scale 1.5296372175216675 distilbert.transformer.layer.4.attention.out_lin.scale 1.618304967880249 distilbert.transformer.layer.5.attention.q_lin.scale 1.9490175247192383 distilbert.transformer.layer.5.attention.k_lin.scale 2.1951510906219482 distilbert.transformer.layer.5.attention.v_lin.scale 1.6163949966430664 distilbert.transformer.layer.5.attention.out_lin.scale 1.8345234394073486 Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.lora.A 
0.6016935110092163 distilbert.transformer.layer.0.attention.q_lin.lora.B 0.31266412138938904 distilbert.transformer.layer.0.attention.k_lin.lora.A 0.6316931247711182 distilbert.transformer.layer.0.attention.k_lin.lora.B 0.2952665090560913 distilbert.transformer.layer.0.attention.v_lin.lora.A 0.6079772710800171 distilbert.transformer.layer.0.attention.v_lin.lora.B 0.2492581456899643 distilbert.transformer.layer.0.attention.out_lin.lora.A 0.4950897693634033 distilbert.transformer.layer.0.attention.out_lin.lora.B 0.18595407903194427 distilbert.transformer.layer.1.attention.q_lin.lora.A 0.6037637591362 distilbert.transformer.layer.1.attention.q_lin.lora.B 0.2575852870941162 distilbert.transformer.layer.1.attention.k_lin.lora.A 0.5715399384498596 distilbert.transformer.layer.1.attention.k_lin.lora.B 0.21878382563591003 distilbert.transformer.layer.1.attention.v_lin.lora.A 0.5992859601974487 distilbert.transformer.layer.1.attention.v_lin.lora.B 0.23800623416900635 distilbert.transformer.layer.1.attention.out_lin.lora.A 0.5236724019050598 distilbert.transformer.layer.1.attention.out_lin.lora.B 0.24507002532482147 distilbert.transformer.layer.2.attention.q_lin.lora.A 0.5646731853485107 distilbert.transformer.layer.2.attention.q_lin.lora.B 0.28050732612609863 distilbert.transformer.layer.2.attention.k_lin.lora.A 0.5719441771507263 distilbert.transformer.layer.2.attention.k_lin.lora.B 0.2973446249961853 distilbert.transformer.layer.2.attention.v_lin.lora.A 0.5592368841171265 distilbert.transformer.layer.2.attention.v_lin.lora.B 0.20374992489814758 distilbert.transformer.layer.2.attention.out_lin.lora.A 0.4538516402244568 distilbert.transformer.layer.2.attention.out_lin.lora.B 0.18519560992717743 distilbert.transformer.layer.3.attention.q_lin.lora.A 0.5853655934333801 distilbert.transformer.layer.3.attention.q_lin.lora.B 0.30569520592689514 distilbert.transformer.layer.3.attention.k_lin.lora.A 0.6070346832275391 distilbert.transformer.layer.3.attention.k_lin.lora.B 0.31933528184890747 distilbert.transformer.layer.3.attention.v_lin.lora.A 0.5087218284606934 distilbert.transformer.layer.3.attention.v_lin.lora.B 0.21019062399864197 distilbert.transformer.layer.3.attention.out_lin.lora.A 0.5895759463310242 distilbert.transformer.layer.3.attention.out_lin.lora.B 0.21719345450401306 distilbert.transformer.layer.4.attention.q_lin.lora.A 0.6171450614929199 distilbert.transformer.layer.4.attention.q_lin.lora.B 0.21375727653503418 distilbert.transformer.layer.4.attention.k_lin.lora.A 0.49676698446273804 distilbert.transformer.layer.4.attention.k_lin.lora.B 0.2762928605079651 distilbert.transformer.layer.4.attention.v_lin.lora.A 0.5012800097465515 distilbert.transformer.layer.4.attention.v_lin.lora.B 0.17582902312278748 distilbert.transformer.layer.4.attention.out_lin.lora.A 0.5615952610969543 distilbert.transformer.layer.4.attention.out_lin.lora.B 0.17805537581443787 distilbert.transformer.layer.5.attention.q_lin.lora.A 0.44116735458374023 distilbert.transformer.layer.5.attention.q_lin.lora.B 0.16018781065940857 distilbert.transformer.layer.5.attention.k_lin.lora.A 0.47184187173843384 distilbert.transformer.layer.5.attention.k_lin.lora.B 0.27603811025619507 distilbert.transformer.layer.5.attention.v_lin.lora.A 0.39102858304977417 distilbert.transformer.layer.5.attention.v_lin.lora.B 0.14989939332008362 distilbert.transformer.layer.5.attention.out_lin.lora.A 0.4704250693321228 distilbert.transformer.layer.5.attention.out_lin.lora.B 0.1333407163619995
In [13]:
print('Parameter Statistics: param.norm()')
for name, param in model_dora_all_attn.named_parameters():
    if "lora" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
print('Parameter Statistics: param.norm()')
for name, param in model_dora_all_attn.named_parameters():
    if "lin.m" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
print('Parameter Statistics: param.norm()')
for name, param in model_dora_all_attn.named_parameters():
    if "lin.scale" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.lora.A weight norm: 83.6438 distilbert.transformer.layer.0.attention.q_lin.lora.B weight norm: 50.3954 distilbert.transformer.layer.0.attention.k_lin.lora.A weight norm: 88.2270 distilbert.transformer.layer.0.attention.k_lin.lora.B weight norm: 49.0835 distilbert.transformer.layer.0.attention.v_lin.lora.A weight norm: 86.0185 distilbert.transformer.layer.0.attention.v_lin.lora.B weight norm: 41.9394 distilbert.transformer.layer.0.attention.out_lin.lora.A weight norm: 69.8777 distilbert.transformer.layer.0.attention.out_lin.lora.B weight norm: 32.6178 distilbert.transformer.layer.1.attention.q_lin.lora.A weight norm: 84.0461 distilbert.transformer.layer.1.attention.q_lin.lora.B weight norm: 42.5906 distilbert.transformer.layer.1.attention.k_lin.lora.A weight norm: 79.4528 distilbert.transformer.layer.1.attention.k_lin.lora.B weight norm: 37.7325 distilbert.transformer.layer.1.attention.v_lin.lora.A weight norm: 83.9680 distilbert.transformer.layer.1.attention.v_lin.lora.B weight norm: 39.0889 distilbert.transformer.layer.1.attention.out_lin.lora.A weight norm: 74.0652 distilbert.transformer.layer.1.attention.out_lin.lora.B weight norm: 40.8909 distilbert.transformer.layer.2.attention.q_lin.lora.A weight norm: 78.3896 distilbert.transformer.layer.2.attention.q_lin.lora.B weight norm: 45.2058 distilbert.transformer.layer.2.attention.k_lin.lora.A weight norm: 79.5360 distilbert.transformer.layer.2.attention.k_lin.lora.B weight norm: 47.1290 distilbert.transformer.layer.2.attention.v_lin.lora.A weight norm: 77.7262 distilbert.transformer.layer.2.attention.v_lin.lora.B weight norm: 34.1224 distilbert.transformer.layer.2.attention.out_lin.lora.A weight norm: 63.6211 distilbert.transformer.layer.2.attention.out_lin.lora.B weight norm: 31.4881 distilbert.transformer.layer.3.attention.q_lin.lora.A weight norm: 81.7695 distilbert.transformer.layer.3.attention.q_lin.lora.B weight norm: 49.9913 distilbert.transformer.layer.3.attention.k_lin.lora.A weight norm: 84.6226 distilbert.transformer.layer.3.attention.k_lin.lora.B weight norm: 50.7520 distilbert.transformer.layer.3.attention.v_lin.lora.A weight norm: 71.5488 distilbert.transformer.layer.3.attention.v_lin.lora.B weight norm: 34.2226 distilbert.transformer.layer.3.attention.out_lin.lora.A weight norm: 82.8060 distilbert.transformer.layer.3.attention.out_lin.lora.B weight norm: 36.0456 distilbert.transformer.layer.4.attention.q_lin.lora.A weight norm: 86.1940 distilbert.transformer.layer.4.attention.q_lin.lora.B weight norm: 37.5884 distilbert.transformer.layer.4.attention.k_lin.lora.A weight norm: 69.4333 distilbert.transformer.layer.4.attention.k_lin.lora.B weight norm: 43.1834 distilbert.transformer.layer.4.attention.v_lin.lora.A weight norm: 70.5987 distilbert.transformer.layer.4.attention.v_lin.lora.B weight norm: 28.1566 distilbert.transformer.layer.4.attention.out_lin.lora.A weight norm: 79.8895 distilbert.transformer.layer.4.attention.out_lin.lora.B weight norm: 29.1581 distilbert.transformer.layer.5.attention.q_lin.lora.A weight norm: 61.4309 distilbert.transformer.layer.5.attention.q_lin.lora.B weight norm: 28.0646 distilbert.transformer.layer.5.attention.k_lin.lora.A weight norm: 65.4785 distilbert.transformer.layer.5.attention.k_lin.lora.B weight norm: 40.8050 distilbert.transformer.layer.5.attention.v_lin.lora.A weight norm: 55.7645 distilbert.transformer.layer.5.attention.v_lin.lora.B weight norm: 23.9943 
distilbert.transformer.layer.5.attention.out_lin.lora.A weight norm: 66.5217 distilbert.transformer.layer.5.attention.out_lin.lora.B weight norm: 21.6479 Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.m weight norm: 24.3380 distilbert.transformer.layer.0.attention.k_lin.m weight norm: 21.5448 distilbert.transformer.layer.0.attention.v_lin.m weight norm: 16.9134 distilbert.transformer.layer.0.attention.out_lin.m weight norm: 10.2597 distilbert.transformer.layer.1.attention.q_lin.m weight norm: 18.8046 distilbert.transformer.layer.1.attention.k_lin.m weight norm: 17.2833 distilbert.transformer.layer.1.attention.v_lin.m weight norm: 10.7684 distilbert.transformer.layer.1.attention.out_lin.m weight norm: 9.8165 distilbert.transformer.layer.2.attention.q_lin.m weight norm: 16.5732 distilbert.transformer.layer.2.attention.k_lin.m weight norm: 17.3353 distilbert.transformer.layer.2.attention.v_lin.m weight norm: 11.1499 distilbert.transformer.layer.2.attention.out_lin.m weight norm: 8.2443 distilbert.transformer.layer.3.attention.q_lin.m weight norm: 19.2109 distilbert.transformer.layer.3.attention.k_lin.m weight norm: 19.1746 distilbert.transformer.layer.3.attention.v_lin.m weight norm: 11.2522 distilbert.transformer.layer.3.attention.out_lin.m weight norm: 11.3760 distilbert.transformer.layer.4.attention.q_lin.m weight norm: 20.1803 distilbert.transformer.layer.4.attention.k_lin.m weight norm: 21.7132 distilbert.transformer.layer.4.attention.v_lin.m weight norm: 9.0074 distilbert.transformer.layer.4.attention.out_lin.m weight norm: 12.0543 distilbert.transformer.layer.5.attention.q_lin.m weight norm: 16.7970 distilbert.transformer.layer.5.attention.k_lin.m weight norm: 22.2302 distilbert.transformer.layer.5.attention.v_lin.m weight norm: 10.2690 distilbert.transformer.layer.5.attention.out_lin.m weight norm: 14.5015 Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.scale weight norm: 63.7493 distilbert.transformer.layer.0.attention.k_lin.scale weight norm: 58.9528 distilbert.transformer.layer.0.attention.v_lin.scale weight norm: 55.6887 distilbert.transformer.layer.0.attention.out_lin.scale weight norm: 47.4388 distilbert.transformer.layer.1.attention.q_lin.scale weight norm: 57.3731 distilbert.transformer.layer.1.attention.k_lin.scale weight norm: 54.8432 distilbert.transformer.layer.1.attention.v_lin.scale weight norm: 47.2854 distilbert.transformer.layer.1.attention.out_lin.scale weight norm: 44.8088 distilbert.transformer.layer.2.attention.q_lin.scale weight norm: 53.9172 distilbert.transformer.layer.2.attention.k_lin.scale weight norm: 55.4214 distilbert.transformer.layer.2.attention.v_lin.scale weight norm: 48.5935 distilbert.transformer.layer.2.attention.out_lin.scale weight norm: 44.7601 distilbert.transformer.layer.3.attention.q_lin.scale weight norm: 56.9923 distilbert.transformer.layer.3.attention.k_lin.scale weight norm: 56.9395 distilbert.transformer.layer.3.attention.v_lin.scale weight norm: 48.5040 distilbert.transformer.layer.3.attention.out_lin.scale weight norm: 47.1696 distilbert.transformer.layer.4.attention.q_lin.scale weight norm: 59.5007 distilbert.transformer.layer.4.attention.k_lin.scale weight norm: 61.3537 distilbert.transformer.layer.4.attention.v_lin.scale weight norm: 45.5443 distilbert.transformer.layer.4.attention.out_lin.scale weight norm: 48.1659 distilbert.transformer.layer.5.attention.q_lin.scale weight norm: 56.5651 distilbert.transformer.layer.5.attention.k_lin.scale weight norm: 63.7824 
distilbert.transformer.layer.5.attention.v_lin.scale weight norm: 47.5776 distilbert.transformer.layer.5.attention.out_lin.scale weight norm: 53.4955
DDoRA, Double DoRA¶
(Double Weight-Decomposed Low-Rank Adaptation)¶
In [14]:
class LinearWithDoubleDoRA(nn.Module):
    def __init__(self, linear, rank, alpha, scaling_factor=1.0):
        super().__init__()
        self.linear = linear  # wrapped pretrained projection (kept frozen during fine-tuning)
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        # Learned magnitude vectors on the output and input sides of the LoRA path
        self.m_out = nn.Parameter(torch.randn(1, linear.out_features) * std_dev)
        self.m_in = nn.Parameter(torch.randn(linear.in_features, 1) * std_dev)
        # Learned per-feature scales, initialised to scaling_factor
        self.scale_out = nn.Parameter(torch.full((1, linear.out_features), float(scaling_factor)))
        self.scale_in = nn.Parameter(torch.full((linear.in_features, 1), float(scaling_factor)))
    def forward(self, x):
        scaled_x = x * self.scale_in.T * self.m_in.T  # broadcast input-side scale and magnitude over x
        linear_output = self.linear(x)
        lora_output = self.lora(scaled_x)
        # L2-normalise the LoRA output along dim=1, then rescale with the output-side vectors
        lora_output_norm = lora_output / (lora_output.norm(p=2, dim=1, keepdim=True) + 1e-9)
        dora_modification = self.scale_out * self.m_out * lora_output_norm
        return linear_output + dora_modification
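Quick shape check for the wrapper above. This is an illustrative sketch, not part of the original run: it assumes the LoRALayer class defined earlier in the notebook is in scope and simply confirms that the DDoRA output keeps the base layer's output shape.
# Sketch: wrap a dummy 768x768 projection and push a (batch, seq_len, hidden) tensor through it.
dummy = nn.Linear(768, 768)
wrapped = LinearWithDoubleDoRA(dummy, rank=16, alpha=32, scaling_factor=2.0)
x = torch.randn(4, 384, 768)
with torch.no_grad():
    y = wrapped(x)
print(y.shape)  # expected: torch.Size([4, 384, 768])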
In [15]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
def inject_ddora_all_attn(model, rank, alpha, scaling_factor=1.0):
    # Wrap every attention projection (q, k, v, out) in a LinearWithDoubleDoRA adapter
    #target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin", "ffn.lin1", "ffn.lin2"]
    target_layers = ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(layer in name for layer in target_layers):
            parent_name = name.rsplit('.', 1)[0]
            parent_module = model.get_submodule(parent_name)
            original_linear = getattr(parent_module, name.split('.')[-1])
            dora_linear = LinearWithDoubleDoRA(original_linear, rank, alpha, scaling_factor)
            setattr(parent_module, name.split('.')[-1], dora_linear)
    return model
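The swap mechanics of the function above can be seen on a toy module. This is a sketch under assumptions: it presumes LinearWithDoubleDoRA and LoRALayer from the previous cells are in scope, and the names toy, block, and proj are made up for illustration.
# Sketch: get_submodule() fetches the parent container, setattr() replaces the child in place.
toy = nn.Sequential()
toy.add_module("block", nn.Sequential())
toy.block.add_module("proj", nn.Linear(8, 8))
parent = toy.get_submodule("block")
setattr(parent, "proj", LinearWithDoubleDoRA(toy.block.proj, rank=4, alpha=8))
print(type(toy.block.proj).__name__)  # expected: LinearWithDoubleDoRA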
In [16]:
import seaborn as sns
import torch
import numpy as np
import matplotlib.pyplot as plt
import copy
torch.manual_seed(137)
lora_rank = 16
lora_alpha = 32
batch_size = 32
learning_rate = 0.015  # much higher than the 1e-5 used for the full fine-tune below
weight_decay = 1e-4
scaling_factor = 2.0
output_dir_prefix = "finetuned-imdb-"
model_ddora_all_attn = copy.deepcopy(model)
model_ddora_all_attn = inject_ddora_all_attn(model_ddora_all_attn, lora_rank, lora_alpha, scaling_factor)
freeze_model_layers(model_ddora_all_attn, unfreeze_pre_classifier=True)
total_params_ddora, trainable_params_ddora, percentage_ddora = count_trainable_parameters(model_ddora_all_attn)
print(f"\nDDoRA (All Attention) - Total parameters: {total_params_ddora:,}")
print(f"DDoRA (All Attention) - Trainable parameters: {trainable_params_ddora:,} ({percentage_ddora:.2f}%)")
eval_steps = 50
logging_steps = 50
output_dir_prefix = "finetuned-imdb-"
training_args_ddora_all_attn = TrainingArguments(
    output_dir=f"{output_dir_prefix}ddora-all-attn",
    num_train_epochs=5,
    #max_steps=100,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    logging_steps=logging_steps,
    save_steps=eval_steps,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    disable_tqdm=False,
    push_to_hub=False,
    report_to="none",
    log_level="error"
)
trainer_ddora_all_attn = Trainer(model=model_ddora_all_attn, args=training_args_ddora_all_attn, 
                                train_dataset=dataset_encoded["train"], eval_dataset=dataset_encoded["validation"], compute_metrics=compute_metrics)
DDoRA (All Attention) - Total parameters: 67,618,562 DDoRA (All Attention) - Trainable parameters: 1,255,682 (1.86%)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
In [17]:
# Sanity check
print("\nTrainable parameters after freezing:")
for name, param in model_ddora_all_attn.named_parameters():
    if param.requires_grad:
        print(name)
Trainable parameters after freezing: distilbert.transformer.layer.0.attention.q_lin.m_out distilbert.transformer.layer.0.attention.q_lin.m_in distilbert.transformer.layer.0.attention.q_lin.scale_out distilbert.transformer.layer.0.attention.q_lin.scale_in distilbert.transformer.layer.0.attention.q_lin.lora.A distilbert.transformer.layer.0.attention.q_lin.lora.B distilbert.transformer.layer.0.attention.k_lin.m_out distilbert.transformer.layer.0.attention.k_lin.m_in distilbert.transformer.layer.0.attention.k_lin.scale_out distilbert.transformer.layer.0.attention.k_lin.scale_in distilbert.transformer.layer.0.attention.k_lin.lora.A distilbert.transformer.layer.0.attention.k_lin.lora.B distilbert.transformer.layer.0.attention.v_lin.m_out distilbert.transformer.layer.0.attention.v_lin.m_in distilbert.transformer.layer.0.attention.v_lin.scale_out distilbert.transformer.layer.0.attention.v_lin.scale_in distilbert.transformer.layer.0.attention.v_lin.lora.A distilbert.transformer.layer.0.attention.v_lin.lora.B distilbert.transformer.layer.0.attention.out_lin.m_out distilbert.transformer.layer.0.attention.out_lin.m_in distilbert.transformer.layer.0.attention.out_lin.scale_out distilbert.transformer.layer.0.attention.out_lin.scale_in distilbert.transformer.layer.0.attention.out_lin.lora.A distilbert.transformer.layer.0.attention.out_lin.lora.B distilbert.transformer.layer.1.attention.q_lin.m_out distilbert.transformer.layer.1.attention.q_lin.m_in distilbert.transformer.layer.1.attention.q_lin.scale_out distilbert.transformer.layer.1.attention.q_lin.scale_in distilbert.transformer.layer.1.attention.q_lin.lora.A distilbert.transformer.layer.1.attention.q_lin.lora.B distilbert.transformer.layer.1.attention.k_lin.m_out distilbert.transformer.layer.1.attention.k_lin.m_in distilbert.transformer.layer.1.attention.k_lin.scale_out distilbert.transformer.layer.1.attention.k_lin.scale_in distilbert.transformer.layer.1.attention.k_lin.lora.A distilbert.transformer.layer.1.attention.k_lin.lora.B distilbert.transformer.layer.1.attention.v_lin.m_out distilbert.transformer.layer.1.attention.v_lin.m_in distilbert.transformer.layer.1.attention.v_lin.scale_out distilbert.transformer.layer.1.attention.v_lin.scale_in distilbert.transformer.layer.1.attention.v_lin.lora.A distilbert.transformer.layer.1.attention.v_lin.lora.B distilbert.transformer.layer.1.attention.out_lin.m_out distilbert.transformer.layer.1.attention.out_lin.m_in distilbert.transformer.layer.1.attention.out_lin.scale_out distilbert.transformer.layer.1.attention.out_lin.scale_in distilbert.transformer.layer.1.attention.out_lin.lora.A distilbert.transformer.layer.1.attention.out_lin.lora.B distilbert.transformer.layer.2.attention.q_lin.m_out distilbert.transformer.layer.2.attention.q_lin.m_in distilbert.transformer.layer.2.attention.q_lin.scale_out distilbert.transformer.layer.2.attention.q_lin.scale_in distilbert.transformer.layer.2.attention.q_lin.lora.A distilbert.transformer.layer.2.attention.q_lin.lora.B distilbert.transformer.layer.2.attention.k_lin.m_out distilbert.transformer.layer.2.attention.k_lin.m_in distilbert.transformer.layer.2.attention.k_lin.scale_out distilbert.transformer.layer.2.attention.k_lin.scale_in distilbert.transformer.layer.2.attention.k_lin.lora.A distilbert.transformer.layer.2.attention.k_lin.lora.B distilbert.transformer.layer.2.attention.v_lin.m_out distilbert.transformer.layer.2.attention.v_lin.m_in distilbert.transformer.layer.2.attention.v_lin.scale_out distilbert.transformer.layer.2.attention.v_lin.scale_in 
distilbert.transformer.layer.2.attention.v_lin.lora.A distilbert.transformer.layer.2.attention.v_lin.lora.B distilbert.transformer.layer.2.attention.out_lin.m_out distilbert.transformer.layer.2.attention.out_lin.m_in distilbert.transformer.layer.2.attention.out_lin.scale_out distilbert.transformer.layer.2.attention.out_lin.scale_in distilbert.transformer.layer.2.attention.out_lin.lora.A distilbert.transformer.layer.2.attention.out_lin.lora.B distilbert.transformer.layer.3.attention.q_lin.m_out distilbert.transformer.layer.3.attention.q_lin.m_in distilbert.transformer.layer.3.attention.q_lin.scale_out distilbert.transformer.layer.3.attention.q_lin.scale_in distilbert.transformer.layer.3.attention.q_lin.lora.A distilbert.transformer.layer.3.attention.q_lin.lora.B distilbert.transformer.layer.3.attention.k_lin.m_out distilbert.transformer.layer.3.attention.k_lin.m_in distilbert.transformer.layer.3.attention.k_lin.scale_out distilbert.transformer.layer.3.attention.k_lin.scale_in distilbert.transformer.layer.3.attention.k_lin.lora.A distilbert.transformer.layer.3.attention.k_lin.lora.B distilbert.transformer.layer.3.attention.v_lin.m_out distilbert.transformer.layer.3.attention.v_lin.m_in distilbert.transformer.layer.3.attention.v_lin.scale_out distilbert.transformer.layer.3.attention.v_lin.scale_in distilbert.transformer.layer.3.attention.v_lin.lora.A distilbert.transformer.layer.3.attention.v_lin.lora.B distilbert.transformer.layer.3.attention.out_lin.m_out distilbert.transformer.layer.3.attention.out_lin.m_in distilbert.transformer.layer.3.attention.out_lin.scale_out distilbert.transformer.layer.3.attention.out_lin.scale_in distilbert.transformer.layer.3.attention.out_lin.lora.A distilbert.transformer.layer.3.attention.out_lin.lora.B distilbert.transformer.layer.4.attention.q_lin.m_out distilbert.transformer.layer.4.attention.q_lin.m_in distilbert.transformer.layer.4.attention.q_lin.scale_out distilbert.transformer.layer.4.attention.q_lin.scale_in distilbert.transformer.layer.4.attention.q_lin.lora.A distilbert.transformer.layer.4.attention.q_lin.lora.B distilbert.transformer.layer.4.attention.k_lin.m_out distilbert.transformer.layer.4.attention.k_lin.m_in distilbert.transformer.layer.4.attention.k_lin.scale_out distilbert.transformer.layer.4.attention.k_lin.scale_in distilbert.transformer.layer.4.attention.k_lin.lora.A distilbert.transformer.layer.4.attention.k_lin.lora.B distilbert.transformer.layer.4.attention.v_lin.m_out distilbert.transformer.layer.4.attention.v_lin.m_in distilbert.transformer.layer.4.attention.v_lin.scale_out distilbert.transformer.layer.4.attention.v_lin.scale_in distilbert.transformer.layer.4.attention.v_lin.lora.A distilbert.transformer.layer.4.attention.v_lin.lora.B distilbert.transformer.layer.4.attention.out_lin.m_out distilbert.transformer.layer.4.attention.out_lin.m_in distilbert.transformer.layer.4.attention.out_lin.scale_out distilbert.transformer.layer.4.attention.out_lin.scale_in distilbert.transformer.layer.4.attention.out_lin.lora.A distilbert.transformer.layer.4.attention.out_lin.lora.B distilbert.transformer.layer.5.attention.q_lin.m_out distilbert.transformer.layer.5.attention.q_lin.m_in distilbert.transformer.layer.5.attention.q_lin.scale_out distilbert.transformer.layer.5.attention.q_lin.scale_in distilbert.transformer.layer.5.attention.q_lin.lora.A distilbert.transformer.layer.5.attention.q_lin.lora.B distilbert.transformer.layer.5.attention.k_lin.m_out distilbert.transformer.layer.5.attention.k_lin.m_in 
distilbert.transformer.layer.5.attention.k_lin.scale_out distilbert.transformer.layer.5.attention.k_lin.scale_in distilbert.transformer.layer.5.attention.k_lin.lora.A distilbert.transformer.layer.5.attention.k_lin.lora.B distilbert.transformer.layer.5.attention.v_lin.m_out distilbert.transformer.layer.5.attention.v_lin.m_in distilbert.transformer.layer.5.attention.v_lin.scale_out distilbert.transformer.layer.5.attention.v_lin.scale_in distilbert.transformer.layer.5.attention.v_lin.lora.A distilbert.transformer.layer.5.attention.v_lin.lora.B distilbert.transformer.layer.5.attention.out_lin.m_out distilbert.transformer.layer.5.attention.out_lin.m_in distilbert.transformer.layer.5.attention.out_lin.scale_out distilbert.transformer.layer.5.attention.out_lin.scale_in distilbert.transformer.layer.5.attention.out_lin.lora.A distilbert.transformer.layer.5.attention.out_lin.lora.B pre_classifier.weight pre_classifier.bias classifier.weight classifier.bias
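A complementary check, sketched here rather than taken from the original run: since only the lora, m, scale, and classifier parameters appear in the list above, the wrapped base Linear weights should all be frozen.
# Sketch: confirm no wrapped base Linear weight/bias is trainable.
base_frozen = all(
    not p.requires_grad
    for n, p in model_ddora_all_attn.named_parameters()
    if n.endswith(".linear.weight") or n.endswith(".linear.bias")
)
print(f"All wrapped base Linear parameters frozen: {base_frozen}")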
In [18]:
trainer_ddora_all_attn.train()
eval_results_ddora_all_attn = trainer_ddora_all_attn.evaluate(dataset_encoded["test"])
print(f"DDoRA (All Attention) Test Results: {eval_results_ddora_all_attn}")
print('Prediction drift')
outputs = trainer_ddora_all_attn.predict(dataset_encoded["validation"])
print(outputs.predictions[:5], outputs.label_ids[:5])
print('LoRA Heatmap')
layer_names = []
b_norms = []
for name, param in model_ddora_all_attn.named_parameters():
    if "lora.B" in name:
        layer_names.append(name)
        b_norms.append(param.norm().item())
sns.barplot(x=b_norms, y=layer_names, color='navy')
plt.xlabel("Weight Norm")
plt.title("LoRA B Norms by Layer")
plt.tight_layout()
plt.show()
import torch
import numpy as np
import matplotlib.pyplot as plt
m_in_values = []
m_out_values = []
for name, module in model_ddora_all_attn.named_modules():
    if isinstance(module, LinearWithDoubleDoRA):
        if hasattr(module, 'm_out'):
            m_out_param = module.m_out.detach().cpu().numpy().flatten()
            m_out_values.extend(m_out_param)
        if hasattr(module, 'm_in'):
            m_in_param = module.m_in.detach().cpu().numpy().flatten()
            m_in_values.extend(m_in_param)
# Convert to numpy arrays
m_in_values = np.array(m_in_values)
m_out_values = np.array(m_out_values)
# Summary stats
print(f"[m_out] Mean: {np.mean(m_out_values):.4f}, Std: {np.std(m_out_values):.4f}, Min: {np.min(m_out_values):.4f}, Max: {np.max(m_out_values):.4f}")
print(f"[m_in ] Mean: {np.mean(m_in_values):.4f}, Std: {np.std(m_in_values):.4f}, Min: {np.min(m_in_values):.4f}, Max: {np.max(m_in_values):.4f}")
# Plot histograms
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.hist(m_out_values, bins=50, alpha=0.7, color='olive')
plt.title('Distribution of Learned m_out Values (DDoRA)')
plt.xlabel('Magnitude (m_out)')
plt.ylabel('Frequency')
plt.subplot(1, 2, 2)
plt.hist(m_in_values, bins=50, alpha=0.7, color='navy')
plt.title('Distribution of Learned m_in Values (DDoRA)')
plt.xlabel('Magnitude (m_in)')
plt.ylabel('Frequency')
plt.tight_layout()
plt.show()
[3910/3910 2:50:30, Epoch 5/5]
| Step | Training Loss | Validation Loss | Accuracy | F1 | 
|---|---|---|---|---|
| 50 | 0.770700 | 0.286238 | 0.872000 | 0.871847 | 
| 100 | 0.349700 | 0.288429 | 0.880000 | 0.879626 | 
| 150 | 0.364200 | 0.300074 | 0.872800 | 0.873083 | 
| 200 | 0.288800 | 0.273452 | 0.883200 | 0.883296 | 
| 250 | 0.299600 | 0.251706 | 0.898400 | 0.897942 | 
| 300 | 0.296600 | 0.246586 | 0.898400 | 0.898229 | 
| 350 | 0.294200 | 0.256276 | 0.902400 | 0.902205 | 
| 400 | 0.267100 | 0.248215 | 0.896800 | 0.896852 | 
| 450 | 0.259600 | 0.261862 | 0.902400 | 0.902072 | 
| 500 | 0.286100 | 0.339048 | 0.903200 | 0.903249 | 
| 550 | 0.280100 | 0.271318 | 0.905600 | 0.905391 | 
| 600 | 0.296700 | 0.254167 | 0.913600 | 0.913544 | 
| 650 | 0.248500 | 0.248038 | 0.909600 | 0.909448 | 
| 700 | 0.272000 | 0.240011 | 0.910400 | 0.910121 | 
| 750 | 0.256800 | 0.272073 | 0.900800 | 0.900581 | 
| 800 | 0.244800 | 0.336190 | 0.899200 | 0.898152 | 
| 850 | 0.232100 | 0.244281 | 0.904000 | 0.903788 | 
| 900 | 0.217700 | 0.275926 | 0.913600 | 0.913151 | 
| 950 | 0.255400 | 0.276984 | 0.910400 | 0.910473 | 
| 1000 | 0.223600 | 0.227863 | 0.909600 | 0.909143 | 
| 1050 | 0.223100 | 0.244880 | 0.909600 | 0.909593 | 
| 1100 | 0.229500 | 0.250599 | 0.912800 | 0.912539 | 
| 1150 | 0.203900 | 0.231664 | 0.911200 | 0.911013 | 
| 1200 | 0.240400 | 0.231101 | 0.911200 | 0.911179 | 
| 1250 | 0.246100 | 0.229517 | 0.914400 | 0.914454 | 
| 1300 | 0.248700 | 0.231470 | 0.912800 | 0.912765 | 
| 1350 | 0.243300 | 0.221713 | 0.910400 | 0.910400 | 
| 1400 | 0.236300 | 0.219190 | 0.912800 | 0.912704 | 
| 1450 | 0.219200 | 0.231880 | 0.913600 | 0.913625 | 
| 1500 | 0.212400 | 0.220595 | 0.912800 | 0.912559 | 
| 1550 | 0.212400 | 0.239685 | 0.913600 | 0.913613 | 
| 1600 | 0.185500 | 0.284434 | 0.899200 | 0.899420 | 
| 1650 | 0.180000 | 0.298952 | 0.914400 | 0.914183 | 
| 1700 | 0.200800 | 0.263503 | 0.914400 | 0.914443 | 
| 1750 | 0.197100 | 0.274289 | 0.912000 | 0.912083 | 
| 1800 | 0.204200 | 0.241341 | 0.904800 | 0.904905 | 
| 1850 | 0.201200 | 0.239144 | 0.911200 | 0.911220 | 
| 1900 | 0.187300 | 0.246986 | 0.915200 | 0.915130 | 
| 1950 | 0.199200 | 0.275895 | 0.904800 | 0.905001 | 
| 2000 | 0.202000 | 0.238767 | 0.917600 | 0.917682 | 
| 2050 | 0.197800 | 0.252750 | 0.917600 | 0.917272 | 
| 2100 | 0.199800 | 0.244306 | 0.915200 | 0.915099 | 
| 2150 | 0.151500 | 0.297675 | 0.920000 | 0.919948 | 
| 2200 | 0.204700 | 0.226266 | 0.913600 | 0.913681 | 
| 2250 | 0.206400 | 0.223971 | 0.920000 | 0.919873 | 
| 2300 | 0.191300 | 0.228316 | 0.919200 | 0.919206 | 
| 2350 | 0.177600 | 0.251830 | 0.921600 | 0.921337 | 
| 2400 | 0.188400 | 0.279854 | 0.918400 | 0.918220 | 
| 2450 | 0.149600 | 0.237880 | 0.915200 | 0.915082 | 
| 2500 | 0.148100 | 0.255322 | 0.910400 | 0.910462 | 
| 2550 | 0.144100 | 0.284158 | 0.918400 | 0.918184 | 
| 2600 | 0.167500 | 0.257225 | 0.908000 | 0.908138 | 
| 2650 | 0.161400 | 0.226382 | 0.917600 | 0.917594 | 
| 2700 | 0.157400 | 0.233534 | 0.919200 | 0.919229 | 
| 2750 | 0.163500 | 0.243015 | 0.921600 | 0.921410 | 
| 2800 | 0.121700 | 0.245816 | 0.920800 | 0.920682 | 
| 2850 | 0.141500 | 0.285453 | 0.922400 | 0.922071 | 
| 2900 | 0.163300 | 0.232592 | 0.921600 | 0.921549 | 
| 2950 | 0.155700 | 0.241530 | 0.916800 | 0.916616 | 
| 3000 | 0.150700 | 0.241032 | 0.916800 | 0.916760 | 
| 3050 | 0.148100 | 0.261533 | 0.916000 | 0.915994 | 
| 3100 | 0.146800 | 0.262320 | 0.912000 | 0.911413 | 
| 3150 | 0.135700 | 0.245895 | 0.920000 | 0.919975 | 
| 3200 | 0.129700 | 0.245331 | 0.921600 | 0.921476 | 
| 3250 | 0.112900 | 0.275194 | 0.917600 | 0.917581 | 
| 3300 | 0.124800 | 0.264425 | 0.923200 | 0.923013 | 
| 3350 | 0.115600 | 0.292223 | 0.924000 | 0.923887 | 
| 3400 | 0.105400 | 0.271411 | 0.917600 | 0.917427 | 
| 3450 | 0.131200 | 0.260183 | 0.920000 | 0.919934 | 
| 3500 | 0.109600 | 0.265847 | 0.922400 | 0.922314 | 
| 3550 | 0.116700 | 0.269032 | 0.920800 | 0.920682 | 
| 3600 | 0.133200 | 0.260401 | 0.924000 | 0.923916 | 
| 3650 | 0.133300 | 0.267982 | 0.924000 | 0.923840 | 
| 3700 | 0.124600 | 0.265107 | 0.919200 | 0.919194 | 
| 3750 | 0.104300 | 0.276479 | 0.920800 | 0.920769 | 
| 3800 | 0.106300 | 0.271788 | 0.921600 | 0.921562 | 
| 3850 | 0.109200 | 0.271291 | 0.922400 | 0.922343 | 
| 3900 | 0.110700 | 0.268460 | 0.920800 | 0.920755 | 
DDoRA (All Attention) Test Results: {'eval_loss': 0.20760443806648254, 'eval_accuracy': 0.9165473684210527, 'eval_f1': 0.9165350069271784, 'eval_runtime': 155.718, 'eval_samples_per_second': 152.519, 'eval_steps_per_second': 4.771, 'epoch': 5.0}
Prediction drift
[[-2.5889697  1.7642486]
 [-3.9092817  2.3953364]
 [ 2.2905545 -2.4951563]
 [-2.053718   1.375324 ]
 [ 5.5414968 -5.709928 ]] [1 1 0 1 0]
LoRA Heatmap
[m_out] Mean: -0.0008, Std: 0.5493, Min: -2.8318, Max: 3.1991 [m_in ] Mean: 0.0086, Std: 0.4519, Min: -2.4116, Max: 2.3131
In [19]:
print('Parameter Statistics: mean.abs()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.scale" in name:
        print(name, param.abs().mean().item())
print('Parameter Statistics: mean.abs()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.m" in name:
        print(name, param.abs().mean().item())
print('Parameter Statistics: mean.abs()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lora" in name:
        print(name, param.abs().mean().item())
Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.scale_out 2.0901947021484375 distilbert.transformer.layer.0.attention.q_lin.scale_in 1.9317268133163452 distilbert.transformer.layer.0.attention.k_lin.scale_out 2.0788636207580566 distilbert.transformer.layer.0.attention.k_lin.scale_in 1.8939082622528076 distilbert.transformer.layer.0.attention.v_lin.scale_out 2.0176312923431396 distilbert.transformer.layer.0.attention.v_lin.scale_in 1.6768195629119873 distilbert.transformer.layer.0.attention.out_lin.scale_out 1.7772330045700073 distilbert.transformer.layer.0.attention.out_lin.scale_in 1.8626189231872559 distilbert.transformer.layer.1.attention.q_lin.scale_out 2.018207311630249 distilbert.transformer.layer.1.attention.q_lin.scale_in 1.8735299110412598 distilbert.transformer.layer.1.attention.k_lin.scale_out 1.9069545269012451 distilbert.transformer.layer.1.attention.k_lin.scale_in 1.9223332405090332 distilbert.transformer.layer.1.attention.v_lin.scale_out 1.74566650390625 distilbert.transformer.layer.1.attention.v_lin.scale_in 1.838054895401001 distilbert.transformer.layer.1.attention.out_lin.scale_out 1.740161418914795 distilbert.transformer.layer.1.attention.out_lin.scale_in 1.8335179090499878 distilbert.transformer.layer.2.attention.q_lin.scale_out 1.9253816604614258 distilbert.transformer.layer.2.attention.q_lin.scale_in 1.9024275541305542 distilbert.transformer.layer.2.attention.k_lin.scale_out 1.9852197170257568 distilbert.transformer.layer.2.attention.k_lin.scale_in 1.8761833906173706 distilbert.transformer.layer.2.attention.v_lin.scale_out 1.796034574508667 distilbert.transformer.layer.2.attention.v_lin.scale_in 1.8118958473205566 distilbert.transformer.layer.2.attention.out_lin.scale_out 1.7673529386520386 distilbert.transformer.layer.2.attention.out_lin.scale_in 1.7620668411254883 distilbert.transformer.layer.3.attention.q_lin.scale_out 1.9725602865219116 distilbert.transformer.layer.3.attention.q_lin.scale_in 1.912266492843628 distilbert.transformer.layer.3.attention.k_lin.scale_out 1.9736626148223877 distilbert.transformer.layer.3.attention.k_lin.scale_in 1.8784260749816895 distilbert.transformer.layer.3.attention.v_lin.scale_out 1.7679297924041748 distilbert.transformer.layer.3.attention.v_lin.scale_in 1.8323858976364136 distilbert.transformer.layer.3.attention.out_lin.scale_out 1.8022470474243164 distilbert.transformer.layer.3.attention.out_lin.scale_in 1.8460712432861328 distilbert.transformer.layer.4.attention.q_lin.scale_out 1.9774129390716553 distilbert.transformer.layer.4.attention.q_lin.scale_in 1.8028367757797241 distilbert.transformer.layer.4.attention.k_lin.scale_out 2.0612895488739014 distilbert.transformer.layer.4.attention.k_lin.scale_in 1.8295701742172241 distilbert.transformer.layer.4.attention.v_lin.scale_out 1.719773292541504 distilbert.transformer.layer.4.attention.v_lin.scale_in 1.7408709526062012 distilbert.transformer.layer.4.attention.out_lin.scale_out 1.7295746803283691 distilbert.transformer.layer.4.attention.out_lin.scale_in 1.7716983556747437 distilbert.transformer.layer.5.attention.q_lin.scale_out 1.9752777814865112 distilbert.transformer.layer.5.attention.q_lin.scale_in 1.7774198055267334 distilbert.transformer.layer.5.attention.k_lin.scale_out 2.129931926727295 distilbert.transformer.layer.5.attention.k_lin.scale_in 1.8098797798156738 distilbert.transformer.layer.5.attention.v_lin.scale_out 1.919114112854004 distilbert.transformer.layer.5.attention.v_lin.scale_in 1.6814241409301758 
distilbert.transformer.layer.5.attention.out_lin.scale_out 1.7194103002548218 distilbert.transformer.layer.5.attention.out_lin.scale_in 1.7984395027160645 Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.m_out 0.5647355318069458 distilbert.transformer.layer.0.attention.q_lin.m_in 0.3798569142818451 distilbert.transformer.layer.0.attention.k_lin.m_out 0.5425061583518982 distilbert.transformer.layer.0.attention.k_lin.m_in 0.37627243995666504 distilbert.transformer.layer.0.attention.v_lin.m_out 0.480976939201355 distilbert.transformer.layer.0.attention.v_lin.m_in 0.26557350158691406 distilbert.transformer.layer.0.attention.out_lin.m_out 0.3110979497432709 distilbert.transformer.layer.0.attention.out_lin.m_in 0.29723748564720154 distilbert.transformer.layer.1.attention.q_lin.m_out 0.4283873736858368 distilbert.transformer.layer.1.attention.q_lin.m_in 0.3325177729129791 distilbert.transformer.layer.1.attention.k_lin.m_out 0.40755534172058105 distilbert.transformer.layer.1.attention.k_lin.m_in 0.34619420766830444 distilbert.transformer.layer.1.attention.v_lin.m_out 0.2836112380027771 distilbert.transformer.layer.1.attention.v_lin.m_in 0.29925063252449036 distilbert.transformer.layer.1.attention.out_lin.m_out 0.3088749349117279 distilbert.transformer.layer.1.attention.out_lin.m_in 0.27071183919906616 distilbert.transformer.layer.2.attention.q_lin.m_out 0.4211108684539795 distilbert.transformer.layer.2.attention.q_lin.m_in 0.30545711517333984 distilbert.transformer.layer.2.attention.k_lin.m_out 0.44150853157043457 distilbert.transformer.layer.2.attention.k_lin.m_in 0.3178306221961975 distilbert.transformer.layer.2.attention.v_lin.m_out 0.34029704332351685 distilbert.transformer.layer.2.attention.v_lin.m_in 0.24604985117912292 distilbert.transformer.layer.2.attention.out_lin.m_out 0.3179643154144287 distilbert.transformer.layer.2.attention.out_lin.m_in 0.2530387341976166 distilbert.transformer.layer.3.attention.q_lin.m_out 0.4518465995788574 distilbert.transformer.layer.3.attention.q_lin.m_in 0.3557422161102295 distilbert.transformer.layer.3.attention.k_lin.m_out 0.47002920508384705 distilbert.transformer.layer.3.attention.k_lin.m_in 0.35762590169906616 distilbert.transformer.layer.3.attention.v_lin.m_out 0.32528677582740784 distilbert.transformer.layer.3.attention.v_lin.m_in 0.2892145812511444 distilbert.transformer.layer.3.attention.out_lin.m_out 0.33790677785873413 distilbert.transformer.layer.3.attention.out_lin.m_in 0.28367918729782104 distilbert.transformer.layer.4.attention.q_lin.m_out 0.47639477252960205 distilbert.transformer.layer.4.attention.q_lin.m_in 0.29029154777526855 distilbert.transformer.layer.4.attention.k_lin.m_out 0.5150153636932373 distilbert.transformer.layer.4.attention.k_lin.m_in 0.3103344440460205 distilbert.transformer.layer.4.attention.v_lin.m_out 0.32499608397483826 distilbert.transformer.layer.4.attention.v_lin.m_in 0.2510720491409302 distilbert.transformer.layer.4.attention.out_lin.m_out 0.3214051127433777 distilbert.transformer.layer.4.attention.out_lin.m_in 0.2650536000728607 distilbert.transformer.layer.5.attention.q_lin.m_out 0.43587324023246765 distilbert.transformer.layer.5.attention.q_lin.m_in 0.22241328656673431 distilbert.transformer.layer.5.attention.k_lin.m_out 0.5570618510246277 distilbert.transformer.layer.5.attention.k_lin.m_in 0.2527024447917938 distilbert.transformer.layer.5.attention.v_lin.m_out 0.37015169858932495 distilbert.transformer.layer.5.attention.v_lin.m_in 0.20549951493740082 
distilbert.transformer.layer.5.attention.out_lin.m_out 0.31466561555862427 distilbert.transformer.layer.5.attention.out_lin.m_in 0.27215391397476196 Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.lora.A 0.39448559284210205 distilbert.transformer.layer.0.attention.q_lin.lora.B 0.29920345544815063 distilbert.transformer.layer.0.attention.k_lin.lora.A 0.388362854719162 distilbert.transformer.layer.0.attention.k_lin.lora.B 0.291858047246933 distilbert.transformer.layer.0.attention.v_lin.lora.A 0.3334265351295471 distilbert.transformer.layer.0.attention.v_lin.lora.B 0.26573705673217773 distilbert.transformer.layer.0.attention.out_lin.lora.A 0.33093491196632385 distilbert.transformer.layer.0.attention.out_lin.lora.B 0.2437349259853363 distilbert.transformer.layer.1.attention.q_lin.lora.A 0.3549439609050751 distilbert.transformer.layer.1.attention.q_lin.lora.B 0.23443476855754852 distilbert.transformer.layer.1.attention.k_lin.lora.A 0.3481384217739105 distilbert.transformer.layer.1.attention.k_lin.lora.B 0.27366185188293457 distilbert.transformer.layer.1.attention.v_lin.lora.A 0.3255350589752197 distilbert.transformer.layer.1.attention.v_lin.lora.B 0.1981675922870636 distilbert.transformer.layer.1.attention.out_lin.lora.A 0.31510788202285767 distilbert.transformer.layer.1.attention.out_lin.lora.B 0.24702942371368408 distilbert.transformer.layer.2.attention.q_lin.lora.A 0.32814159989356995 distilbert.transformer.layer.2.attention.q_lin.lora.B 0.2725476622581482 distilbert.transformer.layer.2.attention.k_lin.lora.A 0.33942073583602905 distilbert.transformer.layer.2.attention.k_lin.lora.B 0.2639490067958832 distilbert.transformer.layer.2.attention.v_lin.lora.A 0.28993457555770874 distilbert.transformer.layer.2.attention.v_lin.lora.B 0.19012793898582458 distilbert.transformer.layer.2.attention.out_lin.lora.A 0.30221056938171387 distilbert.transformer.layer.2.attention.out_lin.lora.B 0.23099581897258759 distilbert.transformer.layer.3.attention.q_lin.lora.A 0.3653438091278076 distilbert.transformer.layer.3.attention.q_lin.lora.B 0.2912411689758301 distilbert.transformer.layer.3.attention.k_lin.lora.A 0.37983566522598267 distilbert.transformer.layer.3.attention.k_lin.lora.B 0.2919533848762512 distilbert.transformer.layer.3.attention.v_lin.lora.A 0.3182007968425751 distilbert.transformer.layer.3.attention.v_lin.lora.B 0.22263233363628387 distilbert.transformer.layer.3.attention.out_lin.lora.A 0.3163568377494812 distilbert.transformer.layer.3.attention.out_lin.lora.B 0.22031466662883759 distilbert.transformer.layer.4.attention.q_lin.lora.A 0.32468748092651367 distilbert.transformer.layer.4.attention.q_lin.lora.B 0.25622785091400146 distilbert.transformer.layer.4.attention.k_lin.lora.A 0.3412896394729614 distilbert.transformer.layer.4.attention.k_lin.lora.B 0.2617414891719818 distilbert.transformer.layer.4.attention.v_lin.lora.A 0.3016916513442993 distilbert.transformer.layer.4.attention.v_lin.lora.B 0.18387971818447113 distilbert.transformer.layer.4.attention.out_lin.lora.A 0.31307727098464966 distilbert.transformer.layer.4.attention.out_lin.lora.B 0.20827889442443848 distilbert.transformer.layer.5.attention.q_lin.lora.A 0.2830806076526642 distilbert.transformer.layer.5.attention.q_lin.lora.B 0.23386383056640625 distilbert.transformer.layer.5.attention.k_lin.lora.A 0.3035750091075897 distilbert.transformer.layer.5.attention.k_lin.lora.B 0.2312200516462326 distilbert.transformer.layer.5.attention.v_lin.lora.A 0.28683286905288696 
distilbert.transformer.layer.5.attention.v_lin.lora.B 0.16982325911521912 distilbert.transformer.layer.5.attention.out_lin.lora.A 0.3157418370246887 distilbert.transformer.layer.5.attention.out_lin.lora.B 0.2109360694885254
In [20]:
print('Parameter Statistics: param.norm()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lora" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
print('Parameter Statistics: param.norm()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.scale" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
print('Parameter Statistics: param.norm()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.m" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.lora.A weight norm: 57.6984 distilbert.transformer.layer.0.attention.q_lin.lora.B weight norm: 42.6242 distilbert.transformer.layer.0.attention.k_lin.lora.A weight norm: 57.0737 distilbert.transformer.layer.0.attention.k_lin.lora.B weight norm: 41.5720 distilbert.transformer.layer.0.attention.v_lin.lora.A weight norm: 50.5210 distilbert.transformer.layer.0.attention.v_lin.lora.B weight norm: 38.4451 distilbert.transformer.layer.0.attention.out_lin.lora.A weight norm: 49.7002 distilbert.transformer.layer.0.attention.out_lin.lora.B weight norm: 35.1399 distilbert.transformer.layer.1.attention.q_lin.lora.A weight norm: 52.7742 distilbert.transformer.layer.1.attention.q_lin.lora.B weight norm: 34.2313 distilbert.transformer.layer.1.attention.k_lin.lora.A weight norm: 50.7043 distilbert.transformer.layer.1.attention.k_lin.lora.B weight norm: 38.7794 distilbert.transformer.layer.1.attention.v_lin.lora.A weight norm: 48.0156 distilbert.transformer.layer.1.attention.v_lin.lora.B weight norm: 28.3948 distilbert.transformer.layer.1.attention.out_lin.lora.A weight norm: 47.3806 distilbert.transformer.layer.1.attention.out_lin.lora.B weight norm: 35.4175 distilbert.transformer.layer.2.attention.q_lin.lora.A weight norm: 48.9293 distilbert.transformer.layer.2.attention.q_lin.lora.B weight norm: 38.7429 distilbert.transformer.layer.2.attention.k_lin.lora.A weight norm: 49.8689 distilbert.transformer.layer.2.attention.k_lin.lora.B weight norm: 37.3596 distilbert.transformer.layer.2.attention.v_lin.lora.A weight norm: 43.0877 distilbert.transformer.layer.2.attention.v_lin.lora.B weight norm: 27.5580 distilbert.transformer.layer.2.attention.out_lin.lora.A weight norm: 45.4100 distilbert.transformer.layer.2.attention.out_lin.lora.B weight norm: 32.9905 distilbert.transformer.layer.3.attention.q_lin.lora.A weight norm: 53.3644 distilbert.transformer.layer.3.attention.q_lin.lora.B weight norm: 41.3374 distilbert.transformer.layer.3.attention.k_lin.lora.A weight norm: 56.0531 distilbert.transformer.layer.3.attention.k_lin.lora.B weight norm: 41.8924 distilbert.transformer.layer.3.attention.v_lin.lora.A weight norm: 47.7002 distilbert.transformer.layer.3.attention.v_lin.lora.B weight norm: 31.9272 distilbert.transformer.layer.3.attention.out_lin.lora.A weight norm: 47.3052 distilbert.transformer.layer.3.attention.out_lin.lora.B weight norm: 31.5475 distilbert.transformer.layer.4.attention.q_lin.lora.A weight norm: 47.6581 distilbert.transformer.layer.4.attention.q_lin.lora.B weight norm: 36.8326 distilbert.transformer.layer.4.attention.k_lin.lora.A weight norm: 50.2827 distilbert.transformer.layer.4.attention.k_lin.lora.B weight norm: 37.8750 distilbert.transformer.layer.4.attention.v_lin.lora.A weight norm: 45.1280 distilbert.transformer.layer.4.attention.v_lin.lora.B weight norm: 26.8244 distilbert.transformer.layer.4.attention.out_lin.lora.A weight norm: 47.0836 distilbert.transformer.layer.4.attention.out_lin.lora.B weight norm: 30.0062 distilbert.transformer.layer.5.attention.q_lin.lora.A weight norm: 42.4531 distilbert.transformer.layer.5.attention.q_lin.lora.B weight norm: 33.3586 distilbert.transformer.layer.5.attention.k_lin.lora.A weight norm: 44.9119 distilbert.transformer.layer.5.attention.k_lin.lora.B weight norm: 33.6370 distilbert.transformer.layer.5.attention.v_lin.lora.A weight norm: 43.6901 distilbert.transformer.layer.5.attention.v_lin.lora.B weight norm: 24.9901 
distilbert.transformer.layer.5.attention.out_lin.lora.A weight norm: 47.2171 distilbert.transformer.layer.5.attention.out_lin.lora.B weight norm: 30.6058 Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.scale_out weight norm: 60.7115 distilbert.transformer.layer.0.attention.q_lin.scale_in weight norm: 55.4249 distilbert.transformer.layer.0.attention.k_lin.scale_out weight norm: 60.4370 distilbert.transformer.layer.0.attention.k_lin.scale_in weight norm: 54.5890 distilbert.transformer.layer.0.attention.v_lin.scale_out weight norm: 58.4866 distilbert.transformer.layer.0.attention.v_lin.scale_in weight norm: 49.1681 distilbert.transformer.layer.0.attention.out_lin.scale_out weight norm: 51.2425 distilbert.transformer.layer.0.attention.out_lin.scale_in weight norm: 53.3198 distilbert.transformer.layer.1.attention.q_lin.scale_out weight norm: 57.8829 distilbert.transformer.layer.1.attention.q_lin.scale_in weight norm: 53.6451 distilbert.transformer.layer.1.attention.k_lin.scale_out weight norm: 55.0335 distilbert.transformer.layer.1.attention.k_lin.scale_in weight norm: 54.6923 distilbert.transformer.layer.1.attention.v_lin.scale_out weight norm: 50.1200 distilbert.transformer.layer.1.attention.v_lin.scale_in weight norm: 52.6154 distilbert.transformer.layer.1.attention.out_lin.scale_out weight norm: 50.4183 distilbert.transformer.layer.1.attention.out_lin.scale_in weight norm: 52.5371 distilbert.transformer.layer.2.attention.q_lin.scale_out weight norm: 55.7336 distilbert.transformer.layer.2.attention.q_lin.scale_in weight norm: 54.1221 distilbert.transformer.layer.2.attention.k_lin.scale_out weight norm: 57.0685 distilbert.transformer.layer.2.attention.k_lin.scale_in weight norm: 53.5724 distilbert.transformer.layer.2.attention.v_lin.scale_out weight norm: 52.1679 distilbert.transformer.layer.2.attention.v_lin.scale_in weight norm: 51.4892 distilbert.transformer.layer.2.attention.out_lin.scale_out weight norm: 51.0791 distilbert.transformer.layer.2.attention.out_lin.scale_in weight norm: 50.5406 distilbert.transformer.layer.3.attention.q_lin.scale_out weight norm: 57.1301 distilbert.transformer.layer.3.attention.q_lin.scale_in weight norm: 54.7893 distilbert.transformer.layer.3.attention.k_lin.scale_out weight norm: 57.2046 distilbert.transformer.layer.3.attention.k_lin.scale_in weight norm: 54.1217 distilbert.transformer.layer.3.attention.v_lin.scale_out weight norm: 51.3382 distilbert.transformer.layer.3.attention.v_lin.scale_in weight norm: 52.4507 distilbert.transformer.layer.3.attention.out_lin.scale_out weight norm: 51.9679 distilbert.transformer.layer.3.attention.out_lin.scale_in weight norm: 52.7578 distilbert.transformer.layer.4.attention.q_lin.scale_out weight norm: 57.4674 distilbert.transformer.layer.4.attention.q_lin.scale_in weight norm: 51.6265 distilbert.transformer.layer.4.attention.k_lin.scale_out weight norm: 59.7401 distilbert.transformer.layer.4.attention.k_lin.scale_in weight norm: 52.3292 distilbert.transformer.layer.4.attention.v_lin.scale_out weight norm: 50.7752 distilbert.transformer.layer.4.attention.v_lin.scale_in weight norm: 50.0643 distilbert.transformer.layer.4.attention.out_lin.scale_out weight norm: 50.8358 distilbert.transformer.layer.4.attention.out_lin.scale_in weight norm: 50.8573 distilbert.transformer.layer.5.attention.q_lin.scale_out weight norm: 57.1890 distilbert.transformer.layer.5.attention.q_lin.scale_in weight norm: 50.6338 distilbert.transformer.layer.5.attention.k_lin.scale_out weight norm: 61.1200 
distilbert.transformer.layer.5.attention.k_lin.scale_in weight norm: 51.5444 distilbert.transformer.layer.5.attention.v_lin.scale_out weight norm: 55.0821 distilbert.transformer.layer.5.attention.v_lin.scale_in weight norm: 48.5430 distilbert.transformer.layer.5.attention.out_lin.scale_out weight norm: 50.9995 distilbert.transformer.layer.5.attention.out_lin.scale_in weight norm: 51.4657 Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.m_out weight norm: 20.0095 distilbert.transformer.layer.0.attention.q_lin.m_in weight norm: 15.0613 distilbert.transformer.layer.0.attention.k_lin.m_out weight norm: 19.6359 distilbert.transformer.layer.0.attention.k_lin.m_in weight norm: 15.1377 distilbert.transformer.layer.0.attention.v_lin.m_out weight norm: 17.2782 distilbert.transformer.layer.0.attention.v_lin.m_in weight norm: 12.7195 distilbert.transformer.layer.0.attention.out_lin.m_out weight norm: 11.9346 distilbert.transformer.layer.0.attention.out_lin.m_in weight norm: 12.8137 distilbert.transformer.layer.1.attention.q_lin.m_out weight norm: 15.6942 distilbert.transformer.layer.1.attention.q_lin.m_in weight norm: 13.5273 distilbert.transformer.layer.1.attention.k_lin.m_out weight norm: 14.8891 distilbert.transformer.layer.1.attention.k_lin.m_in weight norm: 13.4656 distilbert.transformer.layer.1.attention.v_lin.m_out weight norm: 11.5086 distilbert.transformer.layer.1.attention.v_lin.m_in weight norm: 12.5572 distilbert.transformer.layer.1.attention.out_lin.m_out weight norm: 11.9882 distilbert.transformer.layer.1.attention.out_lin.m_in weight norm: 12.0898 distilbert.transformer.layer.2.attention.q_lin.m_out weight norm: 15.6175 distilbert.transformer.layer.2.attention.q_lin.m_in weight norm: 12.3557 distilbert.transformer.layer.2.attention.k_lin.m_out weight norm: 15.7173 distilbert.transformer.layer.2.attention.k_lin.m_in weight norm: 12.8633 distilbert.transformer.layer.2.attention.v_lin.m_out weight norm: 12.7515 distilbert.transformer.layer.2.attention.v_lin.m_in weight norm: 10.7463 distilbert.transformer.layer.2.attention.out_lin.m_out weight norm: 12.1293 distilbert.transformer.layer.2.attention.out_lin.m_in weight norm: 11.4631 distilbert.transformer.layer.3.attention.q_lin.m_out weight norm: 16.4127 distilbert.transformer.layer.3.attention.q_lin.m_in weight norm: 13.9588 distilbert.transformer.layer.3.attention.k_lin.m_out weight norm: 16.8955 distilbert.transformer.layer.3.attention.k_lin.m_in weight norm: 14.1378 distilbert.transformer.layer.3.attention.v_lin.m_out weight norm: 12.3394 distilbert.transformer.layer.3.attention.v_lin.m_in weight norm: 12.3336 distilbert.transformer.layer.3.attention.out_lin.m_out weight norm: 12.4052 distilbert.transformer.layer.3.attention.out_lin.m_in weight norm: 11.9973 distilbert.transformer.layer.4.attention.q_lin.m_out weight norm: 17.3603 distilbert.transformer.layer.4.attention.q_lin.m_in weight norm: 12.2036 distilbert.transformer.layer.4.attention.k_lin.m_out weight norm: 18.3908 distilbert.transformer.layer.4.attention.k_lin.m_in weight norm: 12.4500 distilbert.transformer.layer.4.attention.v_lin.m_out weight norm: 12.6651 distilbert.transformer.layer.4.attention.v_lin.m_in weight norm: 11.2682 distilbert.transformer.layer.4.attention.out_lin.m_out weight norm: 12.6220 distilbert.transformer.layer.4.attention.out_lin.m_in weight norm: 11.9401 distilbert.transformer.layer.5.attention.q_lin.m_out weight norm: 16.4559 distilbert.transformer.layer.5.attention.q_lin.m_in weight norm: 10.3415 
distilbert.transformer.layer.5.attention.k_lin.m_out weight norm: 18.6853 distilbert.transformer.layer.5.attention.k_lin.m_in weight norm: 10.9104 distilbert.transformer.layer.5.attention.v_lin.m_out weight norm: 13.6871 distilbert.transformer.layer.5.attention.v_lin.m_in weight norm: 10.8449 distilbert.transformer.layer.5.attention.out_lin.m_out weight norm: 12.6990 distilbert.transformer.layer.5.attention.out_lin.m_in weight norm: 11.9039
In [21]:
print (torch.cuda.memory_summary())
|===========================================================================| | PyTorch CUDA memory summary, device ID 0 | |---------------------------------------------------------------------------| | CUDA OOMs: 0 | cudaMalloc retries: 0 | |===========================================================================| | Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed | |---------------------------------------------------------------------------| | Allocated memory | 1077 MiB | 7290 MiB | 406631 GiB | 406630 GiB | | from large pool | 1055 MiB | 7240 MiB | 404673 GiB | 404672 GiB | | from small pool | 22 MiB | 51 MiB | 1958 GiB | 1958 GiB | |---------------------------------------------------------------------------| | Active memory | 1077 MiB | 7290 MiB | 406631 GiB | 406630 GiB | | from large pool | 1055 MiB | 7240 MiB | 404673 GiB | 404672 GiB | | from small pool | 22 MiB | 51 MiB | 1958 GiB | 1958 GiB | |---------------------------------------------------------------------------| | Requested memory | 1072 MiB | 7285 MiB | 406404 GiB | 406403 GiB | | from large pool | 1050 MiB | 7235 MiB | 404455 GiB | 404454 GiB | | from small pool | 22 MiB | 51 MiB | 1948 GiB | 1948 GiB | |---------------------------------------------------------------------------| | GPU reserved memory | 7380 MiB | 7380 MiB | 14270 MiB | 6890 MiB | | from large pool | 7328 MiB | 7328 MiB | 14168 MiB | 6840 MiB | | from small pool | 52 MiB | 52 MiB | 102 MiB | 50 MiB | |---------------------------------------------------------------------------| | Non-releasable memory | 162149 KiB | 545173 KiB | 144872 GiB | 144872 GiB | | from large pool | 156580 KiB | 540498 KiB | 142781 GiB | 142781 GiB | | from small pool | 5569 KiB | 22942 KiB | 2091 GiB | 2091 GiB | |---------------------------------------------------------------------------| | Allocations | 1310 | 1616 | 44683 K | 44682 K | | from large pool | 164 | 318 | 13813 K | 13813 K | | from small pool | 1146 | 1451 | 30869 K | 30868 K | |---------------------------------------------------------------------------| | Active allocs | 1310 | 1616 | 44683 K | 44682 K | | from large pool | 164 | 318 | 13813 K | 13813 K | | from small pool | 1146 | 1451 | 30869 K | 30868 K | |---------------------------------------------------------------------------| | GPU reserved segments | 190 | 190 | 327 | 137 | | from large pool | 164 | 164 | 276 | 112 | | from small pool | 26 | 26 | 51 | 25 | |---------------------------------------------------------------------------| | Non-releasable allocs | 37 | 60 | 20413 K | 20413 K | | from large pool | 18 | 26 | 3326 K | 3326 K | | from small pool | 19 | 43 | 17086 K | 17086 K | |---------------------------------------------------------------------------| | Oversize allocations | 0 | 0 | 0 | 0 | |---------------------------------------------------------------------------| | Oversize GPU segments | 0 | 0 | 0 | 0 | |===========================================================================|
Full DistilBERT on IMDB Sentiment¶
In [22]:
# Define the training parameters
batch_size = 32
logging_steps = len(dataset_encoded["train"]) // batch_size
model_name = f"{model_ckpt}-finetuned-imdb"
training_args = TrainingArguments(output_dir=model_name,
                                  num_train_epochs=2, 
                                  learning_rate=1e-5,
                                  per_device_train_batch_size=batch_size,
                                  per_device_eval_batch_size=batch_size,
                                  weight_decay=0.01,
                                  evaluation_strategy="steps",
                                  eval_steps=20,  # Evaluate more frequently on the larger dataset
                                  logging_steps=20,
                                  save_steps=20,
                                  max_grad_norm=1.0,
                                  disable_tqdm=False,
                                  push_to_hub=False,
                                  report_to="none",
                                  log_level="error")
trainer = Trainer(model=model,
                  args=training_args,
                  compute_metrics=compute_metrics,
                  train_dataset=dataset_encoded["train"],
                  eval_dataset=dataset_encoded["validation"],
                  tokenizer=tokenizer)
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\transformers\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn( C:\Users\alexa\AppData\Local\Temp\ipykernel_22420\3915345233.py:21: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead. trainer = Trainer(model=model,
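For a like-for-like comparison with DDoRA's 1.86% trainable parameters, the sketch below reports the trainable share of the full fine-tune, which is effectively 100% since nothing is frozen. It assumes the count_trainable_parameters helper defined in the earlier cells is in scope.
# Sketch: trainable-parameter share of the fully fine-tuned model, for comparison with DDoRA.
total_full, trainable_full, pct_full = count_trainable_parameters(model)
print(f"Full fine-tune - Total parameters: {total_full:,}")
print(f"Full fine-tune - Trainable parameters: {trainable_full:,} ({pct_full:.2f}%)")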
In [23]:
# Train!
trainer.train()
print (torch.cuda.memory_summary())
# Evaluate on the test set
evaluation_results = trainer.evaluate(dataset_encoded["test"])
print(f"\nEvaluation results on the test set:")
print(evaluation_results)
# Accessing the logs (after training):
log_history = trainer.state.log_history
# Plot the loss vs steps
import matplotlib.pyplot as plt
# Extract training loss
train_steps = [entry["step"] for entry in trainer.state.log_history if "loss" in entry]
train_losses = [entry["loss"] for entry in trainer.state.log_history if "loss" in entry]
# Extract validation loss
val_steps = [entry["step"] for entry in trainer.state.log_history if "eval_loss" in entry]
val_losses = [entry["eval_loss"] for entry in trainer.state.log_history if "eval_loss" in entry]
# Plot both training and validation loss
plt.plot(train_steps, train_losses, label="Training Loss", linestyle="-")
plt.plot(val_steps, val_losses, label="Validation Loss", linestyle="--")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Training Loss vs. Steps")
plt.legend()
plt.show()
val_accuracy_steps = [entry["step"] for entry in log_history if "eval_accuracy" in entry]
val_accuracies = [entry["eval_accuracy"] for entry in log_history if "eval_accuracy" in entry]
val_f1_steps = [entry["step"] for entry in log_history if "eval_f1" in entry]
val_f1_scores = [entry["eval_f1"] for entry in log_history if "eval_f1" in entry]
plt.plot(val_accuracy_steps, val_accuracies, label="Validation Accuracy", linestyle="-")
plt.plot(val_f1_steps, val_f1_scores, label="Validation F1 score", linestyle="--")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Val Accurace and F1 vs. Steps")
plt.legend()
plt.show()
[1564/1564 45:03, Epoch 2/2]
| Step | Training Loss | Validation Loss | Accuracy | F1 | 
|---|---|---|---|---|
| 20 | 0.689400 | 0.676229 | 0.744000 | 0.739512 | 
| 40 | 0.659500 | 0.588595 | 0.820800 | 0.819493 | 
| 60 | 0.494700 | 0.388211 | 0.865600 | 0.865512 | 
| 80 | 0.366500 | 0.328102 | 0.866400 | 0.865610 | 
| 100 | 0.323200 | 0.317173 | 0.876800 | 0.877069 | 
| 120 | 0.339900 | 0.284819 | 0.884000 | 0.883705 | 
| 140 | 0.298000 | 0.275737 | 0.885600 | 0.885591 | 
| 160 | 0.305800 | 0.292675 | 0.880800 | 0.879911 | 
| 180 | 0.258600 | 0.272887 | 0.888800 | 0.888825 | 
| 200 | 0.257000 | 0.272066 | 0.884800 | 0.884265 | 
| 220 | 0.255000 | 0.281653 | 0.884800 | 0.883885 | 
| 240 | 0.286100 | 0.253439 | 0.901600 | 0.901510 | 
| 260 | 0.257200 | 0.248806 | 0.902400 | 0.902369 | 
| 280 | 0.244400 | 0.249558 | 0.902400 | 0.902184 | 
| 300 | 0.282500 | 0.258910 | 0.907200 | 0.906769 | 
| 320 | 0.266500 | 0.241840 | 0.906400 | 0.906141 | 
| 340 | 0.269500 | 0.239128 | 0.900800 | 0.900752 | 
| 360 | 0.252200 | 0.258152 | 0.897600 | 0.896886 | 
| 380 | 0.250800 | 0.246389 | 0.908800 | 0.908924 | 
| 400 | 0.246200 | 0.252573 | 0.901600 | 0.901795 | 
| 420 | 0.252400 | 0.269475 | 0.894400 | 0.894634 | 
| 440 | 0.219100 | 0.262670 | 0.898400 | 0.898618 | 
| 460 | 0.300400 | 0.263741 | 0.899200 | 0.898037 | 
| 480 | 0.270000 | 0.264092 | 0.896000 | 0.894800 | 
| 500 | 0.258500 | 0.240369 | 0.904800 | 0.904421 | 
| 520 | 0.254600 | 0.230196 | 0.908000 | 0.907948 | 
| 540 | 0.238300 | 0.233112 | 0.905600 | 0.905677 | 
| 560 | 0.233700 | 0.233234 | 0.904000 | 0.904111 | 
| 580 | 0.284200 | 0.229832 | 0.905600 | 0.905570 | 
| 600 | 0.254800 | 0.229044 | 0.911200 | 0.911050 | 
| 620 | 0.215700 | 0.228184 | 0.913600 | 0.913528 | 
| 640 | 0.247200 | 0.244651 | 0.901600 | 0.901771 | 
| 660 | 0.223600 | 0.232534 | 0.911200 | 0.910891 | 
| 680 | 0.234400 | 0.248592 | 0.906400 | 0.906585 | 
| 700 | 0.247600 | 0.223161 | 0.913600 | 0.913625 | 
| 720 | 0.269200 | 0.223580 | 0.912000 | 0.912128 | 
| 740 | 0.202000 | 0.222004 | 0.909600 | 0.909669 | 
| 760 | 0.226900 | 0.233625 | 0.908800 | 0.908961 | 
| 780 | 0.215700 | 0.219046 | 0.908800 | 0.908814 | 
| 800 | 0.195900 | 0.222158 | 0.917600 | 0.917478 | 
| 820 | 0.175600 | 0.259501 | 0.896800 | 0.897026 | 
| 840 | 0.213900 | 0.228438 | 0.914400 | 0.914306 | 
| 860 | 0.212300 | 0.244444 | 0.904800 | 0.904978 | 
| 880 | 0.188900 | 0.224636 | 0.914400 | 0.914337 | 
| 900 | 0.166400 | 0.226523 | 0.912000 | 0.911943 | 
| 920 | 0.177200 | 0.237086 | 0.910400 | 0.910530 | 
| 940 | 0.214900 | 0.226861 | 0.913600 | 0.913463 | 
| 960 | 0.178100 | 0.239278 | 0.908000 | 0.908153 | 
| 980 | 0.189200 | 0.224798 | 0.912000 | 0.911943 | 
| 1000 | 0.166500 | 0.232560 | 0.913600 | 0.913660 | 
| 1020 | 0.187100 | 0.230836 | 0.913600 | 0.913625 | 
| 1040 | 0.154100 | 0.229299 | 0.916000 | 0.915923 | 
| 1060 | 0.207200 | 0.226525 | 0.916800 | 0.916847 | 
| 1080 | 0.191800 | 0.225182 | 0.916000 | 0.915907 | 
| 1100 | 0.204000 | 0.233329 | 0.912000 | 0.912093 | 
| 1120 | 0.165900 | 0.227521 | 0.916800 | 0.916836 | 
| 1140 | 0.156800 | 0.226366 | 0.917600 | 0.917524 | 
| 1160 | 0.203400 | 0.224103 | 0.914400 | 0.914443 | 
| 1180 | 0.203600 | 0.221393 | 0.917600 | 0.917478 | 
| 1200 | 0.174700 | 0.232756 | 0.908800 | 0.908932 | 
| 1220 | 0.164400 | 0.225512 | 0.916000 | 0.915967 | 
| 1240 | 0.169700 | 0.230427 | 0.908800 | 0.908886 | 
| 1260 | 0.220100 | 0.228160 | 0.916000 | 0.915823 | 
| 1280 | 0.177400 | 0.233028 | 0.907200 | 0.907317 | 
| 1300 | 0.207600 | 0.226083 | 0.914400 | 0.914406 | 
| 1320 | 0.210700 | 0.225015 | 0.911200 | 0.911245 | 
| 1340 | 0.148000 | 0.224107 | 0.916800 | 0.916634 | 
| 1360 | 0.170400 | 0.224322 | 0.910400 | 0.910439 | 
| 1380 | 0.200600 | 0.223000 | 0.913600 | 0.913558 | 
| 1400 | 0.175900 | 0.223361 | 0.911200 | 0.911245 | 
| 1420 | 0.169600 | 0.224143 | 0.909600 | 0.909657 | 
| 1440 | 0.157600 | 0.224942 | 0.909600 | 0.909657 | 
| 1460 | 0.183600 | 0.223980 | 0.912000 | 0.911986 | 
| 1480 | 0.204600 | 0.224059 | 0.912000 | 0.911972 | 
| 1500 | 0.181200 | 0.224224 | 0.914400 | 0.914337 | 
| 1520 | 0.175000 | 0.224378 | 0.912800 | 0.912765 | 
| 1540 | 0.183800 | 0.225751 | 0.907200 | 0.907240 | 
| 1560 | 0.188800 | 0.226429 | 0.908000 | 0.908070 | 
|===========================================================================| | PyTorch CUDA memory summary, device ID 0 | |---------------------------------------------------------------------------| | CUDA OOMs: 0 | cudaMalloc retries: 0 | |===========================================================================| | Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed | |---------------------------------------------------------------------------| | Allocated memory | 1589 MiB | 7290 MiB | 440273 GiB | 440272 GiB | | from large pool | 1566 MiB | 7240 MiB | 438070 GiB | 438069 GiB | | from small pool | 23 MiB | 51 MiB | 2203 GiB | 2203 GiB | |---------------------------------------------------------------------------| | Active memory | 1589 MiB | 7290 MiB | 440273 GiB | 440272 GiB | | from large pool | 1566 MiB | 7240 MiB | 438070 GiB | 438069 GiB | | from small pool | 23 MiB | 51 MiB | 2203 GiB | 2203 GiB | |---------------------------------------------------------------------------| | Requested memory | 1583 MiB | 7285 MiB | 439981 GiB | 439979 GiB | | from large pool | 1560 MiB | 7235 MiB | 437788 GiB | 437786 GiB | | from small pool | 23 MiB | 51 MiB | 2192 GiB | 2192 GiB | |---------------------------------------------------------------------------| | GPU reserved memory | 5538 MiB | 7380 MiB | 18572 MiB | 13034 MiB | | from large pool | 5508 MiB | 7328 MiB | 18468 MiB | 12960 MiB | | from small pool | 30 MiB | 52 MiB | 104 MiB | 74 MiB | |---------------------------------------------------------------------------| | Non-releasable memory | 264726 KiB | 545173 KiB | 154496 GiB | 154496 GiB | | from large pool | 259656 KiB | 540498 KiB | 152152 GiB | 152152 GiB | | from small pool | 5070 KiB | 22942 KiB | 2343 GiB | 2343 GiB | |---------------------------------------------------------------------------| | Allocations | 1518 | 1736 | 48429 K | 48428 K | | from large pool | 242 | 320 | 14932 K | 14932 K | | from small pool | 1276 | 1455 | 33496 K | 33495 K | |---------------------------------------------------------------------------| | Active allocs | 1518 | 1736 | 48429 K | 48428 K | | from large pool | 242 | 320 | 14932 K | 14932 K | | from small pool | 1276 | 1455 | 33496 K | 33495 K | |---------------------------------------------------------------------------| | GPU reserved segments | 119 | 190 | 399 | 280 | | from large pool | 104 | 164 | 347 | 243 | | from small pool | 15 | 26 | 52 | 37 | |---------------------------------------------------------------------------| | Non-releasable allocs | 41 | 60 | 22134 K | 22134 K | | from large pool | 21 | 28 | 3630 K | 3630 K | | from small pool | 20 | 43 | 18503 K | 18503 K | |---------------------------------------------------------------------------| | Oversize allocations | 0 | 0 | 0 | 0 | |---------------------------------------------------------------------------| | Oversize GPU segments | 0 | 0 | 0 | 0 | |===========================================================================|
[743/743 6:00:33]
Evaluation results on the test set:
{'eval_loss': 0.21304276585578918, 'eval_accuracy': 0.9208421052631579, 'eval_f1': 0.9208423095953104, 'eval_runtime': 122.0514, 'eval_samples_per_second': 194.59, 'eval_steps_per_second': 6.088, 'epoch': 2.0}
Bonus - DDoRA training - continued...¶
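Calling train() again below starts another 5-epoch run from the weights currently in memory, i.e. the best checkpoint restored by load_best_model_at_end=True. A hedged alternative, shown here only as a sketch and not what the notebook actually ran, is to pass the Trainer's resume_from_checkpoint argument, which restores the model together with the optimizer and learning-rate scheduler state from the latest checkpoint in output_dir.
# Sketch (alternative to the cell below): resume from the latest saved checkpoint instead.
# trainer_ddora_all_attn.train(resume_from_checkpoint=True)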
In [24]:
# Continue DDoRA (all attention) fine-tuning, then evaluate on the held-out test set
trainer_ddora_all_attn.train()
eval_results_ddora_all_attn = trainer_ddora_all_attn.evaluate(dataset_encoded["test"])
print(f"DDoRA (All Attention) Test Results: {eval_results_ddora_all_attn}")

# Spot-check raw logits against the labels of a few validation examples
print('Prediction drift')
outputs = trainer_ddora_all_attn.predict(dataset_encoded["validation"])
print(outputs.predictions[:5], outputs.label_ids[:5])

# Per-layer norms of the LoRA B matrices (rendered here as a bar chart rather than a heatmap)
print('LoRA Heatmap')
layer_names = []
b_norms = []
for name, param in model_ddora_all_attn.named_parameters():
    if "lora.B" in name:
        layer_names.append(name)
        b_norms.append(param.norm().item())
sns.barplot(x=b_norms, y=layer_names, color='navy')
plt.xlabel("Weight Norm")
plt.title("LoRA B Norms by Layer")
plt.tight_layout()
plt.show()

# Gather the learned m_out / m_in magnitude vectors from every LinearWithDoubleDoRA module
import torch
import numpy as np
import matplotlib.pyplot as plt
m_in_values = []
m_out_values = []
for name, module in model_ddora_all_attn.named_modules():
    if isinstance(module, LinearWithDoubleDoRA):
        if hasattr(module, 'm_out'):
            m_out_param = module.m_out.detach().cpu().numpy().flatten()
            m_out_values.extend(m_out_param)
        if hasattr(module, 'm_in'):
            m_in_param = module.m_in.detach().cpu().numpy().flatten()
            m_in_values.extend(m_in_param)
# Convert to numpy arrays
m_in_values = np.array(m_in_values)
m_out_values = np.array(m_out_values)
# Summary stats
print(f"[m_out] Mean: {np.mean(m_out_values):.4f}, Std: {np.std(m_out_values):.4f}, Min: {np.min(m_out_values):.4f}, Max: {np.max(m_out_values):.4f}")
print(f"[m_in ] Mean: {np.mean(m_in_values):.4f}, Std: {np.std(m_in_values):.4f}, Min: {np.min(m_in_values):.4f}, Max: {np.max(m_in_values):.4f}")
# Plot histograms
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.hist(m_out_values, bins=50, alpha=0.7, color='olive')
plt.title('Distribution of Learned m_out Values (DDoRA)')
plt.xlabel('Magnitude (m_out)')
plt.ylabel('Frequency')
plt.subplot(1, 2, 2)
plt.hist(m_in_values, bins=50, alpha=0.7, color='navy')
plt.title('Distribution of Learned m_in Values (DDoRA)')
plt.xlabel('Magnitude (m_in)')
plt.ylabel('Frequency')
plt.tight_layout()
plt.show()
      [3910/3910 2:54:29, Epoch 5/5]
    
| Step | Training Loss | Validation Loss | Accuracy | F1 | 
|---|---|---|---|---|
| 50 | 0.220400 | 0.234968 | 0.918400 | 0.918165 | 
| 100 | 0.218500 | 0.291846 | 0.909600 | 0.909350 | 
| 150 | 0.262700 | 0.238732 | 0.904800 | 0.904872 | 
| 200 | 0.202600 | 0.255080 | 0.905600 | 0.905789 | 
| 250 | 0.236200 | 0.241231 | 0.912800 | 0.912896 | 
| 300 | 0.241200 | 0.219332 | 0.915200 | 0.915200 | 
| 350 | 0.250000 | 0.223478 | 0.909600 | 0.909217 | 
| 400 | 0.212100 | 0.238065 | 0.910400 | 0.910055 | 
| 450 | 0.205200 | 0.279559 | 0.912800 | 0.912232 | 
| 500 | 0.218100 | 0.243256 | 0.918400 | 0.918435 | 
| 550 | 0.204100 | 0.298102 | 0.916800 | 0.916634 | 
| 600 | 0.246300 | 0.275445 | 0.916000 | 0.915923 | 
| 650 | 0.220000 | 0.258947 | 0.905600 | 0.905260 | 
| 700 | 0.196100 | 0.254702 | 0.912800 | 0.912704 | 
| 750 | 0.234500 | 0.282484 | 0.887200 | 0.885504 | 
| 800 | 0.202900 | 0.328079 | 0.892800 | 0.891303 | 
| 850 | 0.186800 | 0.242680 | 0.912800 | 0.912780 | 
| 900 | 0.180200 | 0.268671 | 0.921600 | 0.921575 | 
| 950 | 0.173000 | 0.297629 | 0.917600 | 0.917391 | 
| 1000 | 0.196400 | 0.250007 | 0.904800 | 0.904209 | 
| 1050 | 0.196400 | 0.241103 | 0.920800 | 0.920794 | 
| 1100 | 0.188100 | 0.263111 | 0.920000 | 0.919934 | 
| 1150 | 0.170100 | 0.242923 | 0.916800 | 0.916941 | 
| 1200 | 0.195600 | 0.233525 | 0.927200 | 0.927236 | 
| 1250 | 0.175300 | 0.244239 | 0.909600 | 0.908955 | 
| 1300 | 0.197200 | 0.263203 | 0.914400 | 0.914565 | 
| 1350 | 0.225300 | 0.233377 | 0.916000 | 0.915907 | 
| 1400 | 0.193400 | 0.232386 | 0.922400 | 0.922477 | 
| 1450 | 0.201500 | 0.216410 | 0.919200 | 0.919168 | 
| 1500 | 0.185000 | 0.218531 | 0.916800 | 0.916458 | 
| 1550 | 0.208200 | 0.246917 | 0.910400 | 0.910552 | 
| 1600 | 0.156400 | 0.343675 | 0.893600 | 0.893824 | 
| 1650 | 0.145800 | 0.321663 | 0.923200 | 0.923150 | 
| 1700 | 0.170000 | 0.295388 | 0.918400 | 0.918467 | 
| 1750 | 0.166200 | 0.247114 | 0.928800 | 0.928760 | 
| 1800 | 0.154300 | 0.227517 | 0.920000 | 0.920066 | 
| 1850 | 0.165400 | 0.281829 | 0.922400 | 0.922459 | 
| 1900 | 0.153700 | 0.256229 | 0.924800 | 0.924710 | 
| 1950 | 0.164500 | 0.270434 | 0.923200 | 0.923297 | 
| 2000 | 0.167300 | 0.226990 | 0.924800 | 0.924724 | 
| 2050 | 0.167500 | 0.239940 | 0.927200 | 0.927120 | 
| 2100 | 0.163400 | 0.226345 | 0.918400 | 0.918146 | 
| 2150 | 0.116000 | 0.307402 | 0.923200 | 0.923108 | 
| 2200 | 0.172000 | 0.269007 | 0.921600 | 0.921410 | 
| 2250 | 0.160900 | 0.270219 | 0.924800 | 0.924634 | 
| 2300 | 0.165600 | 0.256944 | 0.921600 | 0.921521 | 
| 2350 | 0.155100 | 0.233660 | 0.916000 | 0.915666 | 
| 2400 | 0.138500 | 0.256974 | 0.913600 | 0.913331 | 
| 2450 | 0.124600 | 0.244719 | 0.913600 | 0.913371 | 
| 2500 | 0.118200 | 0.303150 | 0.920000 | 0.920109 | 
| 2550 | 0.102700 | 0.305154 | 0.912000 | 0.911568 | 
| 2600 | 0.140000 | 0.266935 | 0.921600 | 0.921664 | 
| 2650 | 0.114900 | 0.281125 | 0.915200 | 0.915099 | 
| 2700 | 0.124200 | 0.297636 | 0.916800 | 0.916746 | 
| 2750 | 0.112800 | 0.275313 | 0.919200 | 0.919096 | 
| 2800 | 0.109300 | 0.253661 | 0.920000 | 0.919934 | 
| 2850 | 0.107300 | 0.281190 | 0.913600 | 0.913127 | 
| 2900 | 0.118900 | 0.262891 | 0.920800 | 0.920850 | 
| 2950 | 0.123200 | 0.258890 | 0.922400 | 0.922300 | 
| 3000 | 0.130200 | 0.258712 | 0.921600 | 0.921562 | 
| 3050 | 0.119200 | 0.269527 | 0.917600 | 0.917567 | 
| 3100 | 0.124000 | 0.263077 | 0.913600 | 0.913127 | 
| 3150 | 0.101900 | 0.296506 | 0.919200 | 0.919181 | 
| 3200 | 0.093300 | 0.302173 | 0.923200 | 0.923093 | 
| 3250 | 0.084200 | 0.305703 | 0.920800 | 0.920781 | 
| 3300 | 0.091100 | 0.294771 | 0.921600 | 0.921506 | 
| 3350 | 0.077800 | 0.341725 | 0.920000 | 0.919806 | 
| 3400 | 0.087400 | 0.324175 | 0.916800 | 0.916716 | 
| 3450 | 0.100900 | 0.307702 | 0.921600 | 0.921443 | 
| 3500 | 0.080300 | 0.325871 | 0.913600 | 0.913573 | 
| 3550 | 0.067700 | 0.348119 | 0.916800 | 0.916800 | 
| 3600 | 0.109800 | 0.301693 | 0.918400 | 0.918184 | 
| 3650 | 0.094600 | 0.309193 | 0.915200 | 0.915114 | 
| 3700 | 0.087600 | 0.316292 | 0.916800 | 0.916598 | 
| 3750 | 0.076500 | 0.331572 | 0.915200 | 0.915225 | 
| 3800 | 0.082600 | 0.318606 | 0.918400 | 0.918302 | 
| 3850 | 0.075700 | 0.319939 | 0.919200 | 0.919126 | 
| 3900 | 0.076200 | 0.319871 | 0.919200 | 0.919126 | 
DDoRA (All Attention) Test Results: {'eval_loss': 0.20528165996074677, 'eval_accuracy': 0.9178526315789474, 'eval_f1': 0.9178393215322689, 'eval_runtime': 160.0105, 'eval_samples_per_second': 148.428, 'eval_steps_per_second': 4.643, 'epoch': 5.0}
Prediction drift
[[-2.8385003  1.8467747]
 [-4.145672   3.0743616]
 [ 2.190418  -2.6564922]
 [-2.6781795  1.6880615]
 [ 4.3640585 -5.1461053]] [1 1 0 1 0]
LoRA Heatmap
[m_out] Mean: 0.0013, Std: 0.7323, Min: -3.2482, Max: 5.2270
[m_in ] Mean: 0.0093, Std: 0.5591, Min: -2.4425, Max: 2.6041
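The plot above is labelled "LoRA Heatmap" but is drawn as a bar chart of B-matrix norms. As a minimal sketch (not part of the original run), the same norms can be arranged into a layer × projection grid and rendered as an actual heatmap; the name parsing assumes the `distilbert.transformer.layer.{i}.attention.{q,k,v,out}_lin.lora.B` naming that appears in the parameter dumps further down.

```python
# Hypothetical rearrangement of the LoRA B norms into a 6 x 4 (layer x projection) heatmap.
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

projections = ["q_lin", "k_lin", "v_lin", "out_lin"]
n_layers = 6
grid = np.zeros((n_layers, len(projections)))

for name, param in model_ddora_all_attn.named_parameters():
    if "lora.B" in name:
        layer_idx = int(name.split("layer.")[1].split(".")[0])
        proj = next(p for p in projections if f".{p}." in name)
        grid[layer_idx, projections.index(proj)] = param.norm().item()

sns.heatmap(grid, annot=True, fmt=".1f", cmap="viridis",
            xticklabels=projections,
            yticklabels=[f"layer {i}" for i in range(n_layers)])
plt.title("LoRA B weight norms (layer x projection)")
plt.tight_layout()
plt.show()
```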
In [25]:
# Re-run of the previous cell: another training pass followed by the same diagnostics
trainer_ddora_all_attn.train()
eval_results_ddora_all_attn = trainer_ddora_all_attn.evaluate(dataset_encoded["test"])
print(f"DDoRA (All Attention) Test Results: {eval_results_ddora_all_attn}")
print('Prediction drift')
outputs = trainer_ddora_all_attn.predict(dataset_encoded["validation"])
print(outputs.predictions[:5], outputs.label_ids[:5])
print('LoRA Heatmap')
layer_names = []
b_norms = []
for name, param in model_ddora_all_attn.named_parameters():
    if "lora.B" in name:
        layer_names.append(name)
        b_norms.append(param.norm().item())
sns.barplot(x=b_norms, y=layer_names, color='navy')
plt.xlabel("Weight Norm")
plt.title("LoRA B Norms by Layer")
plt.tight_layout()
plt.show()
import torch
import numpy as np
import matplotlib.pyplot as plt
m_in_values = []
m_out_values = []
for name, module in model_ddora_all_attn.named_modules():
    if isinstance(module, LinearWithDoubleDoRA):
        if hasattr(module, 'm_out'):
            m_out_param = module.m_out.detach().cpu().numpy().flatten()
            m_out_values.extend(m_out_param)
        if hasattr(module, 'm_in'):
            m_in_param = module.m_in.detach().cpu().numpy().flatten()
            m_in_values.extend(m_in_param)
# Convert to numpy arrays
m_in_values = np.array(m_in_values)
m_out_values = np.array(m_out_values)
# Summary stats
print(f"[m_out] Mean: {np.mean(m_out_values):.4f}, Std: {np.std(m_out_values):.4f}, Min: {np.min(m_out_values):.4f}, Max: {np.max(m_out_values):.4f}")
print(f"[m_in ] Mean: {np.mean(m_in_values):.4f}, Std: {np.std(m_in_values):.4f}, Min: {np.min(m_in_values):.4f}, Max: {np.max(m_in_values):.4f}")
# Plot histograms
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.hist(m_out_values, bins=50, alpha=0.7, color='olive')
plt.title('Distribution of Learned m_out Values (DDoRA)')
plt.xlabel('Magnitude (m_out)')
plt.ylabel('Frequency')
plt.subplot(1, 2, 2)
plt.hist(m_in_values, bins=50, alpha=0.7, color='navy')
plt.title('Distribution of Learned m_in Values (DDoRA)')
plt.xlabel('Magnitude (m_in)')
plt.ylabel('Frequency')
plt.tight_layout()
plt.show()
      [3910/3910 2:58:07, Epoch 5/5]
    
| Step | Training Loss | Validation Loss | Accuracy | F1 | 
|---|---|---|---|---|
| 50 | 0.193100 | 0.231645 | 0.906400 | 0.905762 | 
| 100 | 0.181300 | 0.284688 | 0.905600 | 0.905806 | 
| 150 | 0.201600 | 0.228212 | 0.919200 | 0.919194 | 
| 200 | 0.165300 | 0.329915 | 0.904000 | 0.903447 | 
| 250 | 0.180800 | 0.242941 | 0.906400 | 0.905702 | 
| 300 | 0.200800 | 0.223974 | 0.911200 | 0.910594 | 
| 350 | 0.207600 | 0.294283 | 0.903200 | 0.902281 | 
| 400 | 0.172700 | 0.272758 | 0.912800 | 0.912720 | 
| 450 | 0.174600 | 0.307540 | 0.912000 | 0.911301 | 
| 500 | 0.194500 | 0.255431 | 0.923200 | 0.922979 | 
| 550 | 0.179700 | 0.227661 | 0.924000 | 0.923982 | 
| 600 | 0.206200 | 0.250669 | 0.924800 | 0.924788 | 
| 650 | 0.161900 | 0.223719 | 0.920800 | 0.920616 | 
| 700 | 0.160000 | 0.327061 | 0.913600 | 0.913709 | 
| 750 | 0.183100 | 0.310250 | 0.889600 | 0.887916 | 
| 800 | 0.176400 | 0.296977 | 0.902400 | 0.901688 | 
| 850 | 0.130100 | 0.293862 | 0.912800 | 0.912720 | 
| 900 | 0.132400 | 0.327696 | 0.912000 | 0.911493 | 
| 950 | 0.145000 | 0.389061 | 0.912800 | 0.912430 | 
| 1000 | 0.154600 | 0.376502 | 0.908800 | 0.908494 | 
| 1050 | 0.149700 | 0.331415 | 0.908000 | 0.908145 | 
| 1100 | 0.150100 | 0.422993 | 0.908800 | 0.908786 | 
| 1150 | 0.122200 | 0.281298 | 0.911200 | 0.911316 | 
| 1200 | 0.146500 | 0.323460 | 0.913600 | 0.913747 | 
| 1250 | 0.152600 | 0.273852 | 0.918400 | 0.918271 | 
| 1300 | 0.174300 | 0.231292 | 0.920000 | 0.919919 | 
| 1350 | 0.172700 | 0.253323 | 0.918400 | 0.918237 | 
| 1400 | 0.174600 | 0.237144 | 0.910400 | 0.910008 | 
| 1450 | 0.150800 | 0.240501 | 0.921600 | 0.921506 | 
| 1500 | 0.196900 | 0.244910 | 0.920800 | 0.920769 | 
| 1550 | 0.202900 | 0.221191 | 0.916000 | 0.916064 | 
| 1600 | 0.134700 | 0.323171 | 0.916800 | 0.916824 | 
| 1650 | 0.106900 | 0.297118 | 0.923200 | 0.922923 | 
| 1700 | 0.138200 | 0.273596 | 0.913600 | 0.913725 | 
| 1750 | 0.125000 | 0.324790 | 0.922400 | 0.922382 | 
| 1800 | 0.113900 | 0.257193 | 0.908800 | 0.908915 | 
| 1850 | 0.138000 | 0.295698 | 0.916800 | 0.916836 | 
| 1900 | 0.110100 | 0.312183 | 0.923200 | 0.923163 | 
| 1950 | 0.120400 | 0.276098 | 0.916000 | 0.916074 | 
| 2000 | 0.135500 | 0.260222 | 0.923200 | 0.923233 | 
| 2050 | 0.134200 | 0.242475 | 0.920800 | 0.920829 | 
| 2100 | 0.124400 | 0.246645 | 0.924800 | 0.924634 | 
| 2150 | 0.099600 | 0.298770 | 0.922400 | 0.922369 | 
| 2200 | 0.120900 | 0.293304 | 0.923200 | 0.923136 | 
| 2250 | 0.119700 | 0.287883 | 0.924800 | 0.924681 | 
| 2300 | 0.125600 | 0.258974 | 0.915200 | 0.915315 | 
| 2350 | 0.118500 | 0.262726 | 0.921600 | 0.921427 | 
| 2400 | 0.094500 | 0.329549 | 0.912800 | 0.912178 | 
| 2450 | 0.102100 | 0.306560 | 0.922400 | 0.922269 | 
| 2500 | 0.100400 | 0.277154 | 0.922400 | 0.922439 | 
| 2550 | 0.080800 | 0.327069 | 0.921600 | 0.921374 | 
| 2600 | 0.103800 | 0.279821 | 0.918400 | 0.918526 | 
| 2650 | 0.097400 | 0.280838 | 0.918400 | 0.918495 | 
| 2700 | 0.089000 | 0.333324 | 0.921600 | 0.921634 | 
| 2750 | 0.093300 | 0.278819 | 0.921600 | 0.921443 | 
| 2800 | 0.069700 | 0.351026 | 0.924000 | 0.923916 | 
| 2850 | 0.071100 | 0.334101 | 0.923200 | 0.923013 | 
| 2900 | 0.089700 | 0.317389 | 0.924000 | 0.923970 | 
| 2950 | 0.101400 | 0.295993 | 0.924800 | 0.924724 | 
| 3000 | 0.090500 | 0.301456 | 0.925600 | 0.925594 | 
| 3050 | 0.092200 | 0.300997 | 0.924000 | 0.924028 | 
| 3100 | 0.091500 | 0.297800 | 0.920000 | 0.919731 | 
| 3150 | 0.079600 | 0.339592 | 0.920800 | 0.920840 | 
| 3200 | 0.072300 | 0.360106 | 0.922400 | 0.922369 | 
| 3250 | 0.053700 | 0.383484 | 0.925600 | 0.925545 | 
| 3300 | 0.072800 | 0.334235 | 0.922400 | 0.922356 | 
| 3350 | 0.065400 | 0.371789 | 0.921600 | 0.921410 | 
| 3400 | 0.068900 | 0.343820 | 0.924000 | 0.923930 | 
| 3450 | 0.071900 | 0.332017 | 0.922400 | 0.922329 | 
| 3500 | 0.046600 | 0.393188 | 0.922400 | 0.922394 | 
| 3550 | 0.057900 | 0.388426 | 0.919200 | 0.919240 | 
| 3600 | 0.065900 | 0.366166 | 0.923200 | 0.923176 | 
| 3650 | 0.073700 | 0.351996 | 0.924000 | 0.923916 | 
| 3700 | 0.070800 | 0.337321 | 0.922400 | 0.922314 | 
| 3750 | 0.054100 | 0.353665 | 0.923200 | 0.923188 | 
| 3800 | 0.051200 | 0.360188 | 0.924000 | 0.923957 | 
| 3850 | 0.061000 | 0.360322 | 0.923200 | 0.923176 | 
| 3900 | 0.063800 | 0.360985 | 0.924800 | 0.924788 | 
DDoRA (All Attention) Test Results: {'eval_loss': 0.2019222378730774, 'eval_accuracy': 0.9212631578947369, 'eval_f1': 0.9212608684750585, 'eval_runtime': 157.6478, 'eval_samples_per_second': 150.652, 'eval_steps_per_second': 4.713, 'epoch': 5.0}
Prediction drift
[[-2.494168   1.6709884]
 [-4.1562243  3.2771716]
 [ 2.5441208 -2.9046338]
 [-2.77265    1.9955821]
 [ 5.211438  -5.7563157]] [1 1 0 1 0]
LoRA Heatmap
[m_out] Mean: 0.0017, Std: 0.8710, Min: -3.8374, Max: 6.9031
[m_in ] Mean: 0.0119, Std: 0.6580, Min: -2.7239, Max: 2.7596
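Since the m_out / m_in summary statistics are compared between the two runs above, a small helper can collect them programmatically instead of re-reading the printed lines. This is a sketch only; it assumes pandas is available and reuses the `LinearWithDoubleDoRA` class defined earlier in this notebook.

```python
# Sketch: return the DDoRA magnitude statistics as a dict so runs can be tabulated.
import numpy as np
import pandas as pd

def ddora_magnitude_stats(model):
    stats = {}
    for key in ("m_out", "m_in"):
        values = np.concatenate([
            getattr(module, key).detach().cpu().numpy().flatten()
            for _, module in model.named_modules()
            if isinstance(module, LinearWithDoubleDoRA) and hasattr(module, key)
        ])
        stats[key] = {"mean": values.mean(), "std": values.std(),
                      "min": values.min(), "max": values.max()}
    return stats

# Example usage: call after each training pass and compare the resulting tables.
print(pd.DataFrame(ddora_magnitude_stats(model_ddora_all_attn)).round(4))
```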
In [26]:
# Mean absolute value of the DoRA scale vectors (scale_out / scale_in)
print('Parameter Statistics: mean.abs()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.scale" in name:
        print(name, param.abs().mean().item())

# Mean absolute value of the magnitude vectors (m_out / m_in)
print('Parameter Statistics: mean.abs()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.m" in name:
        print(name, param.abs().mean().item())

# Mean absolute value of the LoRA factors (lora.A / lora.B)
print('Parameter Statistics: mean.abs()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lora" in name:
        print(name, param.abs().mean().item())
Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.scale_out 2.211477279663086 distilbert.transformer.layer.0.attention.q_lin.scale_in 1.7990350723266602 distilbert.transformer.layer.0.attention.k_lin.scale_out 2.160134792327881 distilbert.transformer.layer.0.attention.k_lin.scale_in 1.7010796070098877 distilbert.transformer.layer.0.attention.v_lin.scale_out 2.182528257369995 distilbert.transformer.layer.0.attention.v_lin.scale_in 1.0187453031539917 distilbert.transformer.layer.0.attention.out_lin.scale_out 1.5009715557098389 distilbert.transformer.layer.0.attention.out_lin.scale_in 1.3177869319915771 distilbert.transformer.layer.1.attention.q_lin.scale_out 2.0520782470703125 distilbert.transformer.layer.1.attention.q_lin.scale_in 1.6306018829345703 distilbert.transformer.layer.1.attention.k_lin.scale_out 1.8983781337738037 distilbert.transformer.layer.1.attention.k_lin.scale_in 1.684961199760437 distilbert.transformer.layer.1.attention.v_lin.scale_out 1.3902130126953125 distilbert.transformer.layer.1.attention.v_lin.scale_in 1.4229816198349 distilbert.transformer.layer.1.attention.out_lin.scale_out 1.4535112380981445 distilbert.transformer.layer.1.attention.out_lin.scale_in 1.2360076904296875 distilbert.transformer.layer.2.attention.q_lin.scale_out 1.8091825246810913 distilbert.transformer.layer.2.attention.q_lin.scale_in 1.6466021537780762 distilbert.transformer.layer.2.attention.k_lin.scale_out 2.054741859436035 distilbert.transformer.layer.2.attention.k_lin.scale_in 1.6497936248779297 distilbert.transformer.layer.2.attention.v_lin.scale_out 1.364464282989502 distilbert.transformer.layer.2.attention.v_lin.scale_in 1.3622424602508545 distilbert.transformer.layer.2.attention.out_lin.scale_out 1.31589674949646 distilbert.transformer.layer.2.attention.out_lin.scale_in 1.1265673637390137 distilbert.transformer.layer.3.attention.q_lin.scale_out 2.026132345199585 distilbert.transformer.layer.3.attention.q_lin.scale_in 1.7109956741333008 distilbert.transformer.layer.3.attention.k_lin.scale_out 2.1449785232543945 distilbert.transformer.layer.3.attention.k_lin.scale_in 1.6678569316864014 distilbert.transformer.layer.3.attention.v_lin.scale_out 1.5447391271591187 distilbert.transformer.layer.3.attention.v_lin.scale_in 1.2968910932540894 distilbert.transformer.layer.3.attention.out_lin.scale_out 1.459152102470398 distilbert.transformer.layer.3.attention.out_lin.scale_in 1.3626108169555664 distilbert.transformer.layer.4.attention.q_lin.scale_out 1.8095883131027222 distilbert.transformer.layer.4.attention.q_lin.scale_in 1.434910535812378 distilbert.transformer.layer.4.attention.k_lin.scale_out 2.0096232891082764 distilbert.transformer.layer.4.attention.k_lin.scale_in 1.4818964004516602 distilbert.transformer.layer.4.attention.v_lin.scale_out 1.224574089050293 distilbert.transformer.layer.4.attention.v_lin.scale_in 0.9664930701255798 distilbert.transformer.layer.4.attention.out_lin.scale_out 1.186560034751892 distilbert.transformer.layer.4.attention.out_lin.scale_in 1.138476848602295 distilbert.transformer.layer.5.attention.q_lin.scale_out 1.6922223567962646 distilbert.transformer.layer.5.attention.q_lin.scale_in 1.1524317264556885 distilbert.transformer.layer.5.attention.k_lin.scale_out 1.892466425895691 distilbert.transformer.layer.5.attention.k_lin.scale_in 1.325523853302002 distilbert.transformer.layer.5.attention.v_lin.scale_out 1.4372739791870117 distilbert.transformer.layer.5.attention.v_lin.scale_in 0.7315273284912109 
distilbert.transformer.layer.5.attention.out_lin.scale_out 1.1462877988815308 distilbert.transformer.layer.5.attention.out_lin.scale_in 1.2485125064849854 Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.m_out 0.904198169708252 distilbert.transformer.layer.0.attention.q_lin.m_in 0.5918328166007996 distilbert.transformer.layer.0.attention.k_lin.m_out 0.8658651113510132 distilbert.transformer.layer.0.attention.k_lin.m_in 0.552005410194397 distilbert.transformer.layer.0.attention.v_lin.m_out 0.891939640045166 distilbert.transformer.layer.0.attention.v_lin.m_in 0.3521280288696289 distilbert.transformer.layer.0.attention.out_lin.m_out 0.5082004070281982 distilbert.transformer.layer.0.attention.out_lin.m_in 0.3994007110595703 distilbert.transformer.layer.1.attention.q_lin.m_out 0.7819246053695679 distilbert.transformer.layer.1.attention.q_lin.m_in 0.5072041153907776 distilbert.transformer.layer.1.attention.k_lin.m_out 0.7096916437149048 distilbert.transformer.layer.1.attention.k_lin.m_in 0.5101008415222168 distilbert.transformer.layer.1.attention.v_lin.m_out 0.48597097396850586 distilbert.transformer.layer.1.attention.v_lin.m_in 0.4318428933620453 distilbert.transformer.layer.1.attention.out_lin.m_out 0.510459303855896 distilbert.transformer.layer.1.attention.out_lin.m_in 0.370250940322876 distilbert.transformer.layer.2.attention.q_lin.m_out 0.672264039516449 distilbert.transformer.layer.2.attention.q_lin.m_in 0.49723148345947266 distilbert.transformer.layer.2.attention.k_lin.m_out 0.8109812140464783 distilbert.transformer.layer.2.attention.k_lin.m_in 0.5146569013595581 distilbert.transformer.layer.2.attention.v_lin.m_out 0.4366285800933838 distilbert.transformer.layer.2.attention.v_lin.m_in 0.36893153190612793 distilbert.transformer.layer.2.attention.out_lin.m_out 0.4298613965511322 distilbert.transformer.layer.2.attention.out_lin.m_in 0.3405319154262543 distilbert.transformer.layer.3.attention.q_lin.m_out 0.7905937433242798 distilbert.transformer.layer.3.attention.q_lin.m_in 0.5291155576705933 distilbert.transformer.layer.3.attention.k_lin.m_out 0.8796322345733643 distilbert.transformer.layer.3.attention.k_lin.m_in 0.539239764213562 distilbert.transformer.layer.3.attention.v_lin.m_out 0.5514059066772461 distilbert.transformer.layer.3.attention.v_lin.m_in 0.395082950592041 distilbert.transformer.layer.3.attention.out_lin.m_out 0.5047212839126587 distilbert.transformer.layer.3.attention.out_lin.m_in 0.4035267233848572 distilbert.transformer.layer.4.attention.q_lin.m_out 0.682336688041687 distilbert.transformer.layer.4.attention.q_lin.m_in 0.4449799656867981 distilbert.transformer.layer.4.attention.k_lin.m_out 0.7756460905075073 distilbert.transformer.layer.4.attention.k_lin.m_in 0.46460646390914917 distilbert.transformer.layer.4.attention.v_lin.m_out 0.41292259097099304 distilbert.transformer.layer.4.attention.v_lin.m_in 0.2883872985839844 distilbert.transformer.layer.4.attention.out_lin.m_out 0.3856971859931946 distilbert.transformer.layer.4.attention.out_lin.m_in 0.3439074456691742 distilbert.transformer.layer.5.attention.q_lin.m_out 0.5993542671203613 distilbert.transformer.layer.5.attention.q_lin.m_in 0.3158256709575653 distilbert.transformer.layer.5.attention.k_lin.m_out 0.6338485479354858 distilbert.transformer.layer.5.attention.k_lin.m_in 0.3383934199810028 distilbert.transformer.layer.5.attention.v_lin.m_out 0.4689478278160095 distilbert.transformer.layer.5.attention.v_lin.m_in 0.2039266675710678 distilbert.transformer.layer.5.attention.out_lin.m_out 
0.3742349147796631 distilbert.transformer.layer.5.attention.out_lin.m_in 0.39616748690605164 Parameter Statistics: mean.abs() distilbert.transformer.layer.0.attention.q_lin.lora.A 0.6073777079582214 distilbert.transformer.layer.0.attention.q_lin.lora.B 0.5936362743377686 distilbert.transformer.layer.0.attention.k_lin.lora.A 0.5834858417510986 distilbert.transformer.layer.0.attention.k_lin.lora.B 0.5944669246673584 distilbert.transformer.layer.0.attention.v_lin.lora.A 0.477506160736084 distilbert.transformer.layer.0.attention.v_lin.lora.B 0.5750714540481567 distilbert.transformer.layer.0.attention.out_lin.lora.A 0.4822663962841034 distilbert.transformer.layer.0.attention.out_lin.lora.B 0.5455904006958008 distilbert.transformer.layer.1.attention.q_lin.lora.A 0.5509202480316162 distilbert.transformer.layer.1.attention.q_lin.lora.B 0.5553864240646362 distilbert.transformer.layer.1.attention.k_lin.lora.A 0.5384150147438049 distilbert.transformer.layer.1.attention.k_lin.lora.B 0.5774441957473755 distilbert.transformer.layer.1.attention.v_lin.lora.A 0.5051205158233643 distilbert.transformer.layer.1.attention.v_lin.lora.B 0.4909871518611908 distilbert.transformer.layer.1.attention.out_lin.lora.A 0.4681047201156616 distilbert.transformer.layer.1.attention.out_lin.lora.B 0.5537931323051453 distilbert.transformer.layer.2.attention.q_lin.lora.A 0.5328584909439087 distilbert.transformer.layer.2.attention.q_lin.lora.B 0.5830724239349365 distilbert.transformer.layer.2.attention.k_lin.lora.A 0.5513193011283875 distilbert.transformer.layer.2.attention.k_lin.lora.B 0.5603161454200745 distilbert.transformer.layer.2.attention.v_lin.lora.A 0.46175330877304077 distilbert.transformer.layer.2.attention.v_lin.lora.B 0.4580882787704468 distilbert.transformer.layer.2.attention.out_lin.lora.A 0.4390673041343689 distilbert.transformer.layer.2.attention.out_lin.lora.B 0.5048208236694336 distilbert.transformer.layer.3.attention.q_lin.lora.A 0.5584718585014343 distilbert.transformer.layer.3.attention.q_lin.lora.B 0.5750592947006226 distilbert.transformer.layer.3.attention.k_lin.lora.A 0.5745105147361755 distilbert.transformer.layer.3.attention.k_lin.lora.B 0.5809159278869629 distilbert.transformer.layer.3.attention.v_lin.lora.A 0.47269874811172485 distilbert.transformer.layer.3.attention.v_lin.lora.B 0.5128231048583984 distilbert.transformer.layer.3.attention.out_lin.lora.A 0.460321843624115 distilbert.transformer.layer.3.attention.out_lin.lora.B 0.48533257842063904 distilbert.transformer.layer.4.attention.q_lin.lora.A 0.5054908990859985 distilbert.transformer.layer.4.attention.q_lin.lora.B 0.5460034608840942 distilbert.transformer.layer.4.attention.k_lin.lora.A 0.5137321352958679 distilbert.transformer.layer.4.attention.k_lin.lora.B 0.5609502196311951 distilbert.transformer.layer.4.attention.v_lin.lora.A 0.4001336991786957 distilbert.transformer.layer.4.attention.v_lin.lora.B 0.4369596838951111 distilbert.transformer.layer.4.attention.out_lin.lora.A 0.42696064710617065 distilbert.transformer.layer.4.attention.out_lin.lora.B 0.4693056344985962 distilbert.transformer.layer.5.attention.q_lin.lora.A 0.4086011052131653 distilbert.transformer.layer.5.attention.q_lin.lora.B 0.48308905959129333 distilbert.transformer.layer.5.attention.k_lin.lora.A 0.432908296585083 distilbert.transformer.layer.5.attention.k_lin.lora.B 0.4992391765117645 distilbert.transformer.layer.5.attention.v_lin.lora.A 0.34997496008872986 distilbert.transformer.layer.5.attention.v_lin.lora.B 0.45308423042297363 
distilbert.transformer.layer.5.attention.out_lin.lora.A 0.46718907356262207 distilbert.transformer.layer.5.attention.out_lin.lora.B 0.5256445407867432
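The dump above is easier to scan as a per-layer table. A minimal sketch follows, assuming pandas is available and that, as in the output above, only the attention projections carry DDoRA parameters (so every matching name has the `distilbert.transformer.layer.<i>.attention.<module>.<param>` shape).

```python
# Sketch: pivot the mean-|value| statistics into a layer x (module, parameter) table.
import pandas as pd

rows = []
for name, param in model_ddora_all_attn.named_parameters():
    if any(tag in name for tag in ("lin.scale", "lin.m", "lora")):
        parts = name.split(".")            # distilbert.transformer.layer.<i>.attention.<module>.<param>
        rows.append({
            "layer": int(parts[3]),
            "module": parts[5],            # q_lin / k_lin / v_lin / out_lin
            "param": ".".join(parts[6:]),  # scale_out, m_in, lora.A, ...
            "mean_abs": param.abs().mean().item(),
        })

df = pd.DataFrame(rows)
print(df.pivot_table(index="layer", columns=["module", "param"], values="mean_abs").round(3))
```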
In [27]:
# L2 norm of each LoRA factor (lora.A / lora.B)
print('Parameter Statistics: param.norm()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lora" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")

# L2 norm of each DoRA scale vector (scale_out / scale_in)
print('Parameter Statistics: param.norm()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.scale" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")

# L2 norm of each magnitude vector (m_out / m_in)
print('Parameter Statistics: param.norm()')
for name, param in model_ddora_all_attn.named_parameters():
    if "lin.m" in name:
        print(f"{name} weight norm: {param.norm().item():.4f}")
Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.lora.A weight norm: 88.6770 distilbert.transformer.layer.0.attention.q_lin.lora.B weight norm: 83.6840 distilbert.transformer.layer.0.attention.k_lin.lora.A weight norm: 85.9274 distilbert.transformer.layer.0.attention.k_lin.lora.B weight norm: 83.4974 distilbert.transformer.layer.0.attention.v_lin.lora.A weight norm: 73.3022 distilbert.transformer.layer.0.attention.v_lin.lora.B weight norm: 80.9918 distilbert.transformer.layer.0.attention.out_lin.lora.A weight norm: 73.7286 distilbert.transformer.layer.0.attention.out_lin.lora.B weight norm: 77.5629 distilbert.transformer.layer.1.attention.q_lin.lora.A weight norm: 81.9203 distilbert.transformer.layer.1.attention.q_lin.lora.B weight norm: 78.9050 distilbert.transformer.layer.1.attention.k_lin.lora.A weight norm: 80.0022 distilbert.transformer.layer.1.attention.k_lin.lora.B weight norm: 81.3284 distilbert.transformer.layer.1.attention.v_lin.lora.A weight norm: 76.1209 distilbert.transformer.layer.1.attention.v_lin.lora.B weight norm: 70.1328 distilbert.transformer.layer.1.attention.out_lin.lora.A weight norm: 71.7114 distilbert.transformer.layer.1.attention.out_lin.lora.B weight norm: 78.6895 distilbert.transformer.layer.2.attention.q_lin.lora.A weight norm: 80.5862 distilbert.transformer.layer.2.attention.q_lin.lora.B weight norm: 81.8691 distilbert.transformer.layer.2.attention.k_lin.lora.A weight norm: 81.0452 distilbert.transformer.layer.2.attention.k_lin.lora.B weight norm: 79.2133 distilbert.transformer.layer.2.attention.v_lin.lora.A weight norm: 70.1643 distilbert.transformer.layer.2.attention.v_lin.lora.B weight norm: 65.6669 distilbert.transformer.layer.2.attention.out_lin.lora.A weight norm: 68.0408 distilbert.transformer.layer.2.attention.out_lin.lora.B weight norm: 71.9416 distilbert.transformer.layer.3.attention.q_lin.lora.A weight norm: 81.9671 distilbert.transformer.layer.3.attention.q_lin.lora.B weight norm: 80.9681 distilbert.transformer.layer.3.attention.k_lin.lora.A weight norm: 84.2276 distilbert.transformer.layer.3.attention.k_lin.lora.B weight norm: 82.0434 distilbert.transformer.layer.3.attention.v_lin.lora.A weight norm: 72.1863 distilbert.transformer.layer.3.attention.v_lin.lora.B weight norm: 73.0758 distilbert.transformer.layer.3.attention.out_lin.lora.A weight norm: 70.6866 distilbert.transformer.layer.3.attention.out_lin.lora.B weight norm: 69.2462 distilbert.transformer.layer.4.attention.q_lin.lora.A weight norm: 75.9484 distilbert.transformer.layer.4.attention.q_lin.lora.B weight norm: 77.3235 distilbert.transformer.layer.4.attention.k_lin.lora.A weight norm: 76.9465 distilbert.transformer.layer.4.attention.k_lin.lora.B weight norm: 79.4074 distilbert.transformer.layer.4.attention.v_lin.lora.A weight norm: 62.7360 distilbert.transformer.layer.4.attention.v_lin.lora.B weight norm: 63.9343 distilbert.transformer.layer.4.attention.out_lin.lora.A weight norm: 66.6852 distilbert.transformer.layer.4.attention.out_lin.lora.B weight norm: 67.7626 distilbert.transformer.layer.5.attention.q_lin.lora.A weight norm: 63.4330 distilbert.transformer.layer.5.attention.q_lin.lora.B weight norm: 69.0939 distilbert.transformer.layer.5.attention.k_lin.lora.A weight norm: 66.2249 distilbert.transformer.layer.5.attention.k_lin.lora.B weight norm: 70.6252 distilbert.transformer.layer.5.attention.v_lin.lora.A weight norm: 56.1786 distilbert.transformer.layer.5.attention.v_lin.lora.B weight norm: 67.4088 
distilbert.transformer.layer.5.attention.out_lin.lora.A weight norm: 71.9610 distilbert.transformer.layer.5.attention.out_lin.lora.B weight norm: 76.3459 Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.scale_out weight norm: 67.8149 distilbert.transformer.layer.0.attention.q_lin.scale_in weight norm: 55.8788 distilbert.transformer.layer.0.attention.k_lin.scale_out weight norm: 66.3004 distilbert.transformer.layer.0.attention.k_lin.scale_in weight norm: 53.4427 distilbert.transformer.layer.0.attention.v_lin.scale_out weight norm: 67.3438 distilbert.transformer.layer.0.attention.v_lin.scale_in weight norm: 40.8041 distilbert.transformer.layer.0.attention.out_lin.scale_out weight norm: 49.3765 distilbert.transformer.layer.0.attention.out_lin.scale_in weight norm: 46.0389 distilbert.transformer.layer.1.attention.q_lin.scale_out weight norm: 63.8303 distilbert.transformer.layer.1.attention.q_lin.scale_in weight norm: 51.8581 distilbert.transformer.layer.1.attention.k_lin.scale_out weight norm: 59.3891 distilbert.transformer.layer.1.attention.k_lin.scale_in weight norm: 52.9394 distilbert.transformer.layer.1.attention.v_lin.scale_out weight norm: 47.0823 distilbert.transformer.layer.1.attention.v_lin.scale_in weight norm: 47.6743 distilbert.transformer.layer.1.attention.out_lin.scale_out weight norm: 48.7647 distilbert.transformer.layer.1.attention.out_lin.scale_in weight norm: 44.8163 distilbert.transformer.layer.2.attention.q_lin.scale_out weight norm: 57.6219 distilbert.transformer.layer.2.attention.q_lin.scale_in weight norm: 52.3231 distilbert.transformer.layer.2.attention.k_lin.scale_out weight norm: 63.9857 distilbert.transformer.layer.2.attention.k_lin.scale_in weight norm: 52.2575 distilbert.transformer.layer.2.attention.v_lin.scale_out weight norm: 46.9235 distilbert.transformer.layer.2.attention.v_lin.scale_in weight norm: 45.0910 distilbert.transformer.layer.2.attention.out_lin.scale_out weight norm: 45.7886 distilbert.transformer.layer.2.attention.out_lin.scale_in weight norm: 41.9345 distilbert.transformer.layer.3.attention.q_lin.scale_out weight norm: 63.2718 distilbert.transformer.layer.3.attention.q_lin.scale_in weight norm: 53.4324 distilbert.transformer.layer.3.attention.k_lin.scale_out weight norm: 66.0361 distilbert.transformer.layer.3.attention.k_lin.scale_in weight norm: 52.7628 distilbert.transformer.layer.3.attention.v_lin.scale_out weight norm: 51.4346 distilbert.transformer.layer.3.attention.v_lin.scale_in weight norm: 45.3650 distilbert.transformer.layer.3.attention.out_lin.scale_out weight norm: 49.2689 distilbert.transformer.layer.3.attention.out_lin.scale_in weight norm: 46.8675 distilbert.transformer.layer.4.attention.q_lin.scale_out weight norm: 58.5162 distilbert.transformer.layer.4.attention.q_lin.scale_in weight norm: 47.5943 distilbert.transformer.layer.4.attention.k_lin.scale_out weight norm: 62.9462 distilbert.transformer.layer.4.attention.k_lin.scale_in weight norm: 48.5050 distilbert.transformer.layer.4.attention.v_lin.scale_out weight norm: 45.4089 distilbert.transformer.layer.4.attention.v_lin.scale_in weight norm: 38.7889 distilbert.transformer.layer.4.attention.out_lin.scale_out weight norm: 44.1285 distilbert.transformer.layer.4.attention.out_lin.scale_in weight norm: 42.3015 distilbert.transformer.layer.5.attention.q_lin.scale_out weight norm: 55.4135 distilbert.transformer.layer.5.attention.q_lin.scale_in weight norm: 41.3333 distilbert.transformer.layer.5.attention.k_lin.scale_out weight norm: 58.2401 
distilbert.transformer.layer.5.attention.k_lin.scale_in weight norm: 44.2531 distilbert.transformer.layer.5.attention.v_lin.scale_out weight norm: 49.2695 distilbert.transformer.layer.5.attention.v_lin.scale_in weight norm: 33.0258 distilbert.transformer.layer.5.attention.out_lin.scale_out weight norm: 43.1049 distilbert.transformer.layer.5.attention.out_lin.scale_in weight norm: 44.5303 Parameter Statistics: param.norm() distilbert.transformer.layer.0.attention.q_lin.m_out weight norm: 31.4448 distilbert.transformer.layer.0.attention.q_lin.m_in weight norm: 22.4398 distilbert.transformer.layer.0.attention.k_lin.m_out weight norm: 30.1271 distilbert.transformer.layer.0.attention.k_lin.m_in weight norm: 21.4311 distilbert.transformer.layer.0.attention.v_lin.m_out weight norm: 31.0435 distilbert.transformer.layer.0.attention.v_lin.m_in weight norm: 16.8509 distilbert.transformer.layer.0.attention.out_lin.m_out weight norm: 19.6845 distilbert.transformer.layer.0.attention.out_lin.m_in weight norm: 17.7480 distilbert.transformer.layer.1.attention.q_lin.m_out weight norm: 28.3183 distilbert.transformer.layer.1.attention.q_lin.m_in weight norm: 20.1127 distilbert.transformer.layer.1.attention.k_lin.m_out weight norm: 25.9295 distilbert.transformer.layer.1.attention.k_lin.m_in weight norm: 20.4797 distilbert.transformer.layer.1.attention.v_lin.m_out weight norm: 19.3969 distilbert.transformer.layer.1.attention.v_lin.m_in weight norm: 18.3722 distilbert.transformer.layer.1.attention.out_lin.m_out weight norm: 19.8745 distilbert.transformer.layer.1.attention.out_lin.m_in weight norm: 17.4045 distilbert.transformer.layer.2.attention.q_lin.m_out weight norm: 24.9047 distilbert.transformer.layer.2.attention.q_lin.m_in weight norm: 20.1862 distilbert.transformer.layer.2.attention.k_lin.m_out weight norm: 29.0122 distilbert.transformer.layer.2.attention.k_lin.m_in weight norm: 20.1494 distilbert.transformer.layer.2.attention.v_lin.m_out weight norm: 17.9841 distilbert.transformer.layer.2.attention.v_lin.m_in weight norm: 16.4528 distilbert.transformer.layer.2.attention.out_lin.m_out weight norm: 17.7306 distilbert.transformer.layer.2.attention.out_lin.m_in weight norm: 16.3007 distilbert.transformer.layer.3.attention.q_lin.m_out weight norm: 28.5406 distilbert.transformer.layer.3.attention.q_lin.m_in weight norm: 20.5190 distilbert.transformer.layer.3.attention.k_lin.m_out weight norm: 30.7927 distilbert.transformer.layer.3.attention.k_lin.m_in weight norm: 20.6228 distilbert.transformer.layer.3.attention.v_lin.m_out weight norm: 21.3742 distilbert.transformer.layer.3.attention.v_lin.m_in weight norm: 17.7191 distilbert.transformer.layer.3.attention.out_lin.m_out weight norm: 20.0696 distilbert.transformer.layer.3.attention.out_lin.m_in weight norm: 17.6262 distilbert.transformer.layer.4.attention.q_lin.m_out weight norm: 25.6429 distilbert.transformer.layer.4.attention.q_lin.m_in weight norm: 18.5982 distilbert.transformer.layer.4.attention.k_lin.m_out weight norm: 27.9177 distilbert.transformer.layer.4.attention.k_lin.m_in weight norm: 18.6512 distilbert.transformer.layer.4.attention.v_lin.m_out weight norm: 18.0866 distilbert.transformer.layer.4.attention.v_lin.m_in weight norm: 15.1881 distilbert.transformer.layer.4.attention.out_lin.m_out weight norm: 17.4686 distilbert.transformer.layer.4.attention.out_lin.m_in weight norm: 16.4471 distilbert.transformer.layer.5.attention.q_lin.m_out weight norm: 22.9480 distilbert.transformer.layer.5.attention.q_lin.m_in weight norm: 15.1683 
distilbert.transformer.layer.5.attention.k_lin.m_out weight norm: 22.7621 distilbert.transformer.layer.5.attention.k_lin.m_in weight norm: 15.4354 distilbert.transformer.layer.5.attention.v_lin.m_out weight norm: 19.4222 distilbert.transformer.layer.5.attention.v_lin.m_in weight norm: 12.9561 distilbert.transformer.layer.5.attention.out_lin.m_out weight norm: 16.6265 distilbert.transformer.layer.5.attention.out_lin.m_in weight norm: 17.4611
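These norms describe the adapter tensors in isolation, so it helps to put them next to the overall parameter budget. The sketch below assumes only that the frozen base weights have `requires_grad` set to False, as is usual for this kind of adapter setup.

```python
# Sketch: trainable (adapter) vs. total parameter count of the adapted model.
trainable = sum(p.numel() for p in model_ddora_all_attn.parameters() if p.requires_grad)
total = sum(p.numel() for p in model_ddora_all_attn.parameters())
print(f"Trainable parameters: {trainable:,} of {total:,} ({100 * trainable / total:.2f}%)")
```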
In [28]:
# Evaluate the earlier `trainer` run on the validation split and review the last logged metrics
evaluation_results = trainer.evaluate(dataset_encoded["validation"])
print("\nEvaluation results on the validation set:")
print(evaluation_results)
print("Last 10 Validation Accuracy Steps:", val_accuracy_steps[-10:])
print("Last 10 Validation Accuracies:", val_accuracies[-10:])
print("Last 10 Validation F1 Steps:", val_f1_steps[-10:])
print("Last 10 Validation F1 Scores:", val_f1_scores[-10:])
Evaluation results on the validation set:
{'eval_loss': 0.22643864154815674, 'eval_accuracy': 0.908, 'eval_f1': 0.908069813307235, 'eval_runtime': 6.7981, 'eval_samples_per_second': 183.876, 'eval_steps_per_second': 5.884, 'epoch': 2.0}
Last 10 Validation Accuracy Steps: [1400, 1420, 1440, 1460, 1480, 1500, 1520, 1540, 1560, 1564]
Last 10 Validation Accuracies: [0.9112, 0.9096, 0.9096, 0.912, 0.912, 0.9144, 0.9128, 0.9072, 0.908, 0.9208421052631579]
Last 10 Validation F1 Steps: [1400, 1420, 1440, 1460, 1480, 1500, 1520, 1540, 1560, 1564]
Last 10 Validation F1 Scores: [0.9112445094405226, 0.9096571916118659, 0.9096571916118659, 0.9119863574704475, 0.9119722515145877, 0.9143367287669, 0.9127653425067421, 0.9072402344324614, 0.908069813307235, 0.9208423095953104]
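Rather than eyeballing the last ten logged values, the best validation step can be read off the tracked lists directly. This is a sketch; `val_f1_steps` and `val_f1_scores` are the lists populated by the metric tracking set up earlier in this notebook.

```python
# Sketch: locate the step with the highest logged validation F1.
best_idx = max(range(len(val_f1_scores)), key=val_f1_scores.__getitem__)
print(f"Best validation F1 {val_f1_scores[best_idx]:.4f} at step {val_f1_steps[best_idx]}")
```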
In [ ]: