In [1]:
# Environment sanity check: interpreter, PyTorch/CUDA stack, and the
# fine-tuning libraries (bitsandbytes for 8-bit quantization, peft for LoRA).
import torch
import sys
import gc

print(sys.version)
print(f"PyTorch Version: {torch.__version__}")
print(torch.cuda.is_available())
print(torch.cuda.device_count())

if torch.cuda.is_available():
    print(f"CUDA Version: {torch.version.cuda}")
    print(torch.cuda.get_device_name(0))

# Release any cached GPU memory left over from a previous kernel session.
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

import bitsandbytes
import peft

print(f"bitsandbytes version: {bitsandbytes.__version__}")
print(f"peft version: {peft.__version__}")
print(torch.cuda.is_bf16_supported())
3.10.16 | packaged by Anaconda, Inc. | (main, Dec 11 2024, 16:19:12) [MSC v.1929 64 bit (AMD64)]
PyTorch Version: 2.5.1+cu121
True
1
CUDA Version: 12.1
NVIDIA GeForce RTX 4080 Laptop GPU
bitsandbytes version: 0.43.1
peft version: 0.11.1
True
In [2]:
from transformers import AutoTokenizer
import re
import torch
from datasets import load_dataset, Dataset

model_name = "Qwen/Qwen2.5-0.5B-Instruct"
run_name="Qwen-0.5B-SFT-ultrachat"

#Load data
SYSTEM_PROMPT = "You are a Taylor Swift expert. Answer CORRECTLY and CONCISELY questions about Taylor Swift's life, achievements, songs, and more."

dataset_name = "lamini/taylor_swift"

def get_data(dataset_name, split="train") -> Dataset:
    """Load the Q&A dataset and add a chat-style `messages` column.

    Args:
        dataset_name: Hugging Face Hub dataset identifier.
        split: Dataset split to load ("train" or "test").

    Returns:
        Dataset where each row gains `messages`: a user turn containing the
        system prompt plus the question, followed by the assistant answer.
    """
    # Load only the requested split; `load_dataset(name)[split]` would
    # download and prepare every split before indexing.
    data = load_dataset(dataset_name, split=split)

    def format_as_qwen_chat(example):
        # Prepend the system prompt to the user turn (single-turn Q&A data).
        messages = [
            {"role": "user", "content": f"{SYSTEM_PROMPT}\n\n{example['question']}"},
            {"role": "assistant", "content": example["answer"]}
        ]
        return {"messages": messages}

    return data.map(format_as_qwen_chat)

train_dataset = get_data(dataset_name, split="train")
eval_dataset = get_data(dataset_name, split="test")
print(train_dataset.column_names)
print(len(train_dataset), len(eval_dataset))
Map:   0%|          | 0/783 [00:00<?, ? examples/s]
Map:   0%|          | 0/87 [00:00<?, ? examples/s]
['question', 'answer', 'input_ids', 'attention_mask', 'labels', 'messages']
783 87
In [4]:
# Load tokenizer and configure padding/special tokens for SFT.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Qwen models should have an EOS token
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})
# Qwen models should have bos_token:
tokenizer.add_special_tokens({"bos_token": tokenizer.eos_token})
tokenizer.bos_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "right"

def format_prompt(example):
    """Format and tokenize multi-turn chat data using Qwen's ChatML markup.

    Builds `<|im_start|>role ... <|im_end|>` strings for each conversation,
    tokenizes them padded/truncated to 512 tokens, and creates `labels`
    with padding positions masked to -100 so they are ignored by the loss.
    """
    formatted_chats = []

    for messages in example["messages"]:
        formatted_chat = ""
        for message in messages:
            role = message["role"]
            content = message["content"]
            if role == "user":
                formatted_chat += f"<|im_start|>user\n{content}\n<|im_end|>\n"
            elif role == "assistant":
                formatted_chat += f"<|im_start|>assistant\n{content}\n<|im_end|>\n"

        formatted_chats.append(formatted_chat)

    # Tokenize
    tokens = tokenizer(formatted_chats, padding="max_length", truncation=True, max_length=512)
    # Mask padding with -100 (the CrossEntropyLoss ignore index). A plain
    # `input_ids.copy()` would train on padding — and since pad_token is the
    # EOS token here, the model would be pushed to emit EOS at every padded
    # position.
    tokens["labels"] = [
        [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
        for ids, attn in zip(tokens["input_ids"], tokens["attention_mask"])
    ]

    return tokens

train_dataset = train_dataset.map(format_prompt, batched=True)
eval_dataset = eval_dataset.map(format_prompt, batched=True)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Map:   0%|          | 0/783 [00:00<?, ? examples/s]
Map:   0%|          | 0/87 [00:00<?, ? examples/s]
In [6]:
# 8-bit QLoRA setup: quantized base model + LoRA adapters on attention layers.
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

# 8-bit quantization configuration for QLoRA
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,       # quantize base weights to int8
    llm_int8_threshold=6.0,  # outlier threshold for int8 matmul
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    #config=config, # add if you specifically need to change dropout rates
    quantization_config=bnb_config,  # Enables 8-bit QLoRA
    device_map="auto",               # Efficient GPU allocation
    trust_remote_code=True,          # Required for Qwen models
)

# Prepare LoRA Configuration
peft_config = LoraConfig(
    lora_alpha=32,      # LoRA scaling
    lora_dropout=0.05,  # dropout inside the LoRA layers
    r=16,               # adapter rank; lowering (e.g. to 8) avoids instability in low-bit models
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['q_proj', 'o_proj', 'k_proj', 'v_proj'],  # attention projections to adapt
)

# Cast/patch the quantized model for k-bit training, then attach adapters.
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)

# Double-check if model is fully on GPU
print(model.hf_device_map)
model.print_trainable_parameters()
{'': 0}
trainable params: 2,162,688 || all params: 496,195,456 || trainable%: 0.4359
In [7]:
#### Training Configuration
from transformers import TrainingArguments
from trl import SFTTrainer, SFTConfig

output_dir = "./resultsSFTonTS"

# Training arguments. SFTConfig extends TrainingArguments with SFT-specific
# fields such as max_seq_length; passing max_seq_length directly to
# SFTTrainer is deprecated (FutureWarning) and removed in trl 1.0.
training_arguments = SFTConfig(
    output_dir=output_dir,
    per_device_train_batch_size=2,  ### decrease to 1, then gradient_accumulation_steps=8 or even more could work well
    gradient_accumulation_steps=6,
    optim="paged_adamw_32bit",      # paged optimizer avoids OOM spikes with quantized models
    learning_rate=1e-5,
    weight_decay=0.005,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    report_to="none",
    logging_steps=10,
    save_steps=10,
    eval_strategy="steps",
    eval_steps=10,
    bf16=True,
    num_train_epochs=6,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    load_best_model_at_end=True,  # Crucial for saving best model
    metric_for_best_model="eval_loss",
    max_seq_length=512,           # moved here from the SFTTrainer call
)

# Set supervised fine-tuning parameters
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=training_arguments,
    peft_config=peft_config,
)
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\utils\_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': max_seq_length. Will not be supported from version '1.0.0'.

Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
  warnings.warn(message, FutureWarning)
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\transformers\training_args.py:1965: FutureWarning: `--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--hub_token` instead.
  warnings.warn(
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\trl\trainer\sft_trainer.py:269: UserWarning: You passed a `max_seq_length` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\trl\trainer\sft_trainer.py:355: UserWarning: You passed a `dataset_kwargs` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\trl\trainer\sft_trainer.py:494: UserWarning: You passed a dataset that is already processed (contains an `input_ids` field) together with a valid formatting function. Therefore `formatting_func` will be ignored.
  warnings.warn(
In [8]:
# Training!
trainer.train()
print(torch.cuda.memory_summary())

# Save QLoRA adapter weights (safetensors format)
trainer.model.save_pretrained("Qwen-0.5B-qlora", safe_serialization=True)
# Note: trainer already holds eval_dataset from construction, so no
# reassignment is needed before evaluate().
print("Evaluation on test set:", trainer.evaluate())
trainer.save_model("best_Qwen-0.5B-qlora")

# Accessing the logs (after training):
log_history = trainer.state.log_history
# Plot the loss vs steps
import matplotlib.pyplot as plt

# Extract training loss (reuse the captured log_history instead of
# re-reading trainer.state.log_history for every comprehension)
train_steps = [entry["step"] for entry in log_history if "loss" in entry]
train_losses = [entry["loss"] for entry in log_history if "loss" in entry]

# Extract validation loss
val_steps = [entry["step"] for entry in log_history if "eval_loss" in entry]
val_losses = [entry["eval_loss"] for entry in log_history if "eval_loss" in entry]

# Plot both training and validation loss
plt.plot(train_steps, train_losses, label="Training Loss", linestyle="-")
plt.plot(val_steps, val_losses, label="Validation Loss", linestyle="--")

plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Training and Validation Loss vs. Steps")
plt.legend()

# Saving the plot — must run BEFORE plt.show(), which clears the figure:
#plt.savefig("training_loss_plot5.png")
plt.show()
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
[390/390 47:07, Epoch 5/6]
Step Training Loss Validation Loss
10 3.744500 3.368035
20 3.611200 3.139828
30 3.126000 2.841521
40 2.922800 2.623208
50 2.727500 2.465018
60 2.390800 2.289523
70 2.200700 2.104997
80 2.152100 1.963779
90 1.995600 1.843933
100 1.810300 1.740988
110 1.688800 1.630660
120 1.596300 1.527745
130 1.482000 1.424480
140 1.358400 1.338545
150 1.289100 1.280527
160 1.249500 1.233615
170 1.233700 1.203846
180 1.166500 1.181909
190 1.134600 1.165817
200 1.079000 1.157527
210 1.095000 1.153008
220 1.092800 1.147384
230 1.110900 1.141840
240 1.082100 1.141329
250 1.170700 1.139104
260 1.046200 1.138668
270 1.148400 1.135208
280 1.090400 1.132997
290 1.050800 1.131783
300 1.085500 1.131931
310 1.075800 1.131498
320 1.127900 1.131557
330 1.044600 1.131117
340 1.070800 1.128806
350 1.114400 1.131586
360 1.104000 1.127242
370 1.098500 1.128094
380 1.103100 1.130532
390 1.069000 1.129407

C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   1174 MiB |   8413 MiB |  90644 GiB |  90643 GiB |
|       from large pool |   1073 MiB |   8306 MiB |  85270 GiB |  85269 GiB |
|       from small pool |    100 MiB |    277 MiB |   5374 GiB |   5374 GiB |
|---------------------------------------------------------------------------|
| Active memory         |   1174 MiB |   8413 MiB |  90644 GiB |  90643 GiB |
|       from large pool |   1073 MiB |   8306 MiB |  85270 GiB |  85269 GiB |
|       from small pool |    100 MiB |    277 MiB |   5374 GiB |   5374 GiB |
|---------------------------------------------------------------------------|
| Requested memory      |   1168 MiB |   8399 MiB |  89549 GiB |  89548 GiB |
|       from large pool |   1068 MiB |   8293 MiB |  84179 GiB |  84178 GiB |
|       from small pool |    100 MiB |    277 MiB |   5369 GiB |   5369 GiB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |  12724 MiB |  12760 MiB |   3064 GiB |   3052 GiB |
|       from large pool |  12438 MiB |  12478 MiB |   3036 GiB |   3024 GiB |
|       from small pool |    286 MiB |    294 MiB |     27 GiB |     27 GiB |
|---------------------------------------------------------------------------|
| Non-releasable memory |   1149 MiB |   1936 MiB |  79981 GiB |  79980 GiB |
|       from large pool |   1020 MiB |   1806 MiB |  74490 GiB |  74489 GiB |
|       from small pool |    129 MiB |    189 MiB |   5490 GiB |   5490 GiB |
|---------------------------------------------------------------------------|
| Allocations           |    1572    |    2579    |   47690 K  |   47688 K  |
|       from large pool |     139    |     307    |   12519 K  |   12519 K  |
|       from small pool |    1433    |    2320    |   35170 K  |   35169 K  |
|---------------------------------------------------------------------------|
| Active allocs         |    1572    |    2579    |   47690 K  |   47688 K  |
|       from large pool |     139    |     307    |   12519 K  |   12519 K  |
|       from small pool |    1433    |    2320    |   35170 K  |   35169 K  |
|---------------------------------------------------------------------------|
| GPU reserved segments |     202    |     218    |   19206    |   19004    |
|       from large pool |      59    |      83    |    4924    |    4865    |
|       from small pool |     143    |     147    |   14282    |   14139    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |     528    |     718    |   29295 K  |   29294 K  |
|       from large pool |      63    |      76    |    9344 K  |    9344 K  |
|       from small pool |     465    |     648    |   19950 K  |   19949 K  |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\huggingface_hub\file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
C:\Users\alexa\miniconda3\envs\dpo_env\lib\site-packages\bitsandbytes\autograd\_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
[11/11 00:07]
Evaluation on test set: {'eval_loss': 1.127242088317871, 'eval_runtime': 8.4641, 'eval_samples_per_second': 10.279, 'eval_steps_per_second': 1.3, 'epoch': 5.969387755102041}
No description has been provided for this image
In [9]:
eval_dataset.column_names
Out[9]:
['question', 'answer', 'input_ids', 'attention_mask', 'labels', 'messages']
In [11]:
print(eval_dataset[0]['question'], eval_dataset[0]['answer'], eval_dataset[0]['messages'])
Has Taylor Swift written songs for other artists? Yes, Taylor Swift has written songs for other artists. Some notable examples include This Is What You Came For by Calvin Harris featuring Rihanna, Better Man by Little Big Town, and You'll Always Find Your Way Back Home by Miley Cyrus. [{'content': "You are a Taylor Swift expert. Answer CORRECTLY and CONCISELY questions about Taylor Swift's life, achievements, songs, and more.\n\nHas Taylor Swift written songs for other artists?", 'role': 'user'}, {'content': "Yes, Taylor Swift has written songs for other artists. Some notable examples include This Is What You Came For by Calvin Harris featuring Rihanna, Better Man by Little Big Town, and You'll Always Find Your Way Back Home by Miley Cyrus.", 'role': 'assistant'}]
In [27]:
import torch, numpy as np
from torch.nn.utils.rnn import pad_sequence

# Make sure the model config carries a pad token id (needed for batched padding below).
if model.config.pad_token_id is None:
    # NOTE: `tokenizer.pad_token_id or tokenizer.eos_token_id` would wrongly fall
    # through to EOS when the pad token id is 0 (0 is falsy) — check for None instead.
    _pad_id = tokenizer.pad_token_id
    model.config.pad_token_id = _pad_id if _pad_id is not None else tokenizer.eos_token_id


def calculate_perplexity(model, dataset, batch_size=8):
    """Compute per-batch perplexity over `dataset`, ignoring masked label tokens.

    Args:
        model: causal LM whose forward accepts (input_ids, attention_mask, labels)
            and returns an object with a `.loss` (mean cross-entropy over
            non-ignored label positions; positions with label -100 are ignored).
        dataset: HF Dataset with "input_ids", "attention_mask" and "labels"
            columns (lists of ints per example).
        batch_size: number of examples per forward pass.

    Returns:
        List of floats — one perplexity (exp of the batch-mean token loss) per
        batch. NOTE: averaging this list weights every batch equally regardless
        of its token count, so the mean differs between batch_size=8 and
        batch_size=1; for a token-weighted corpus perplexity, aggregate total
        loss over total tokens instead.
    """
    device = model.device
    model.eval()

    perplexities = []
    pad_token_id = model.config.pad_token_id  # used only to pad input_ids

    for start in range(0, len(dataset), batch_size):
        batch = dataset.select(range(start, min(start + batch_size, len(dataset))))

        # Convert each example's columns to 1-D long tensors.
        input_tensors = [torch.tensor(ids, dtype=torch.long) for ids in batch["input_ids"]]
        attention_masks = [torch.tensor(mask, dtype=torch.long) for mask in batch["attention_mask"]]
        label_tensors = [torch.tensor(labels, dtype=torch.long) for labels in batch["labels"]]

        # Pad to the longest sequence in the batch.
        input_tensors = pad_sequence(input_tensors, batch_first=True, padding_value=pad_token_id).to(device)
        attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0).to(device)
        # Pad labels with -100 directly. Padding with pad_token_id and then
        # replacing every pad_token_id with -100 (as before) also masks genuine
        # occurrences of that token inside real labels — e.g. the assistant's
        # EOS token when pad == eos — silently removing them from the loss.
        label_tensors = pad_sequence(label_tensors, batch_first=True, padding_value=-100).to(device)

        with torch.no_grad():
            outputs = model(input_ids=input_tensors, attention_mask=attention_masks, labels=label_tensors)

        # exp(mean token loss) for this batch.
        perplexities.append(torch.exp(outputs.loss.detach()).item())

    return perplexities


# Fold LoRA adapters into the base weights for plain inference.
merged_model = model.merge_and_unload()

# Evaluate with two batch sizes to compare the (unweighted) batch-mean perplexities.
for bs in (8, 1):
    perplexities = calculate_perplexity(merged_model, eval_dataset, batch_size=bs)
    print(f"First 10 Perplexities with batch_size={bs}:", perplexities[:10])
    average_perplexity = np.mean(perplexities)
    print("Average Perplexity:", average_perplexity)
First 10 Perplexities with batch_size=8: [3.6886284351348877, 4.2651286125183105, 4.588165283203125, 3.9980263710021973, 4.835107803344727, 3.789165496826172, 4.394216537475586, 4.687495231628418, 4.260453701019287, 4.364171028137207]
Average Perplexity: 4.309888861396096
First 10 Perplexities with batch_size=1: [5.015169143676758, 2.977295160293579, 3.9224281311035156, 4.028156280517578, 3.3580939769744873, 5.4735894203186035, 3.588649272918701, 2.7137558460235596, 2.361142873764038, 7.101618766784668]
Average Perplexity: 4.6114512202383455
In [ ]: