In [1]:
import torch
import sys
import gc
print(sys.version)
print(f"PyTorch Version: {torch.__version__}")
print(torch.cuda.is_available())
print(torch.cuda.device_count())
if torch.cuda.is_available():
print(f"CUDA Version: {torch.version.cuda}")
print(torch.cuda.get_device_name(0))
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
import bitsandbytes
import peft
import transformers
print(transformers.__version__)
print(f"bitsandbytes version: {bitsandbytes.__version__}")
print(f"peft version: {peft.__version__}")
print(torch.cuda.is_bf16_supported())
3.10.16 | packaged by Anaconda, Inc. | (main, Dec 11 2024, 16:19:12) [MSC v.1929 64 bit (AMD64)]
PyTorch Version: 2.5.1+cu121
True
1
CUDA Version: 12.1
NVIDIA GeForce RTX 4080 Laptop GPU
4.50.0.dev0
bitsandbytes version: 0.45.3
peft version: 0.14.0
True
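Before loading the model it can be handy to check how much GPU memory is actually free after the cache cleanup above; a minimal sketch using torch.cuda.mem_get_info (only the printout formatting is ad hoc here):
# Sketch: report free / total memory on GPU 0 after emptying the cache
if torch.cuda.is_available():
    free_bytes, total_bytes = torch.cuda.mem_get_info(0)  # returns (free, total) in bytes
    print(f"GPU 0 memory: {free_bytes / 1024**3:.2f} GiB free of {total_bytes / 1024**3:.2f} GiB total")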
GRPO Fine-Tuning of Gemma-3-1B-it to Improve Its Reasoning¶
In [3]:
from transformers import AutoTokenizer
import re
import torch, numpy as np
from datasets import load_dataset, Dataset
model_name = "google/gemma-3-1b-it"
run_name="Gemma-3-1B-BBH"
SYSTEM_PROMPT = "Answer the questions CORRECTLY and CONCISELY."  # note: format_prompt below uses its own system message
dataset_name = "maveriq/bigbenchhard"
# Load a tokenizer to use its chat template
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it", trust_remote_code=True)
# add tokens:
if tokenizer.pad_token is None:
    if tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token
    else:
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})  # otherwise register a dedicated pad token
tokenizer.padding_side = "left" # For autoregressive models
# Load and format the BBH dataset
def format_prompt(example):
"""Format the prompt using the <|user|> template Gemma-3-1B expects"""
chat = [
{"role": "system", "content": "You are a helpful assistant. Answer factual questions accurately and concisely."},
{"role": "user", "content": example["input"]},
{"role": "assistant", "content": ""} # The model is to continue from here
]
prompt = tokenizer.apply_chat_template(chat, tokenize=False)
return {"prompt": prompt, "answer": example["target"]}
bbh_causal_judgement = load_dataset("maveriq/bigbenchhard", 'causal_judgement')
dataset = bbh_causal_judgement.shuffle(seed=137)
train_dataset = dataset["train"].select(range(167)).map(format_prompt, remove_columns=["input", 'target'])
eval_dataset = dataset["train"].select(range(167, 187)).map(format_prompt, remove_columns=["input", 'target'])
print(train_dataset[0])
print(train_dataset.column_names)
print(eval_dataset.column_names)
print (len(train_dataset), len(eval_dataset))
{'prompt': '<bos><start_of_turn>user\nYou are a helpful assistant. Answer factual questions accurately and concisely.\n\nHow would a typical person answer each of the following questions about causation?\nJim, Carol, Bob, and Nancy are researchers in a remote area, and they have a limited supply of electricity. Because of their limited supply, the electricity only comes on in the evenings from 8-9 PM, and they have to restrict who can use power on certain days. If four people turn on their lamps at the same time, the breaker will fail. The breaker will not fail if fewer people turn on their lamps at the same time. Jim, Carol, Bob, and Nancy are all allowed to use their lamps on Thursdays. This Thursday Jim turns on his lamp at 8 PM. Just then Carol turns on her lamp, Bob also turns on his lamp, and Nancy turns on her lamp. Since four people turned on their lamps at the same time, the circuit breaker failed. Did Jim turning on his lamp at 8 PM cause the circuit breaker to fail?\nOptions:\n- Yes\n- No<end_of_turn>\n<start_of_turn>model\n<end_of_turn>\n', 'answer': 'No'}
['prompt', 'answer']
['prompt', 'answer']
167 20
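The GRPOConfig below caps prompts at max_prompt_length=256 tokens, so it is worth checking how long the formatted BBH prompts actually are; a small sketch over the training split (exact numbers depend on the tokenizer version):
# Sketch: tokenized prompt lengths vs. the 256-token cap used later in GRPOConfig
prompt_lengths = [len(tokenizer(ex["prompt"])["input_ids"]) for ex in train_dataset]
print(f"min / median / max prompt length: {min(prompt_lengths)} / {int(np.median(prompt_lengths))} / {max(prompt_lengths)} tokens")
print(f"prompts longer than 256 tokens: {sum(l > 256 for l in prompt_lengths)}")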
Reward Functions¶
In [4]:
# Reward Functions
import re
import numpy as np
from rouge_score import rouge_scorer
def extract_yes_no(text: str) -> str:
"""
Extracts the first clear 'yes' or 'no' found anywhere in the response.
Returns '' if no answer is found, or if both 'yes' and 'no' appear (ambiguous).
"""
text = text.lower()
matches = re.findall(r'\b(yes|no)\b', text)
if not matches:
return '' # No clear answer detected
if "yes" in matches and "no" in matches:
return '' # Both found → ambiguous, reject
return matches[0] # Return the first clear "yes" or "no"
def basic_matching_reward(prompts, completions, answers) -> list[float]:
responses = [extract_yes_no(completion) for completion in completions]
correct_answers = [ans.strip().lower() for ans in answers]
#print("from basic_matching_reward: responses =", responses)
#print("from basic_matching_reward: correct_answers =", correct_answers)
rewards = []
for completion, pred, target in zip(completions, responses, correct_answers):
full_response = completion.strip().lower()
word_count = len(full_response.split())
#print("from zip in basic_matching_reward: word_count =", word_count, "full_response=", full_response)
if pred == target:
if word_count == 1: # Exactly "yes" or "no"
rewards.append(1.0)
elif word_count <= 5:
rewards.append(0.85)
else:
rewards.append(0.7) # Further penalty for excessive words
elif pred in ["yes", "no"]:
rewards.append(-0.75) # Incorrect, but at least binary
elif full_response:
rewards.append(-1.5) # Not "yes" or "no" → heavily penalized
else:
rewards.append(-1.75) # Empty response → worst-case scenario
return rewards
def rouge_reward_func(prompts, completions, answer) -> list[float]:
scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
responses = [completion for completion in completions]
scores = [scorer.score(ref_answer, response)['rougeL'].fmeasure
for response, ref_answer in zip(responses, answer)]
# Rescale around the batch mean: above-average ROUGE scores are damped (x0.8),
# below-average scores are boosted (x1.2), flattening the score distribution
avg_score = np.mean(scores)
return [s * (0.8 + 0.4 * (s < avg_score)) for s in scores]
def length_similarity_reward_func(prompts, completions, answer) -> list[float]:
responses = [completion for completion in completions]
answer_lengths = [len(ans.split()) for ans in answer]
response_lengths = [len(resp.split()) for resp in responses]
return [np.exp(-abs(resp_len - ans_len) / max(1, ans_len)) # Exponential decay
for resp_len, ans_len in zip(response_lengths, answer_lengths)]
def logic_explanation_reward(prompts, completions, answers):
"""Rewards responses that provide correct reasoning beyond just 'yes' or 'no'."""
response_texts = [completion for completion in completions]
reward_scores = []
for response in response_texts:
if "because" in response or "due to" in response or "contradicts" in response:
reward_scores.append(1.0) # Strong explanation signal
elif len(response.split()) > 10: # Longer responses *without* reasoning terms
reward_scores.append(-0.5) # Penalize unnecessary long responses
else:
reward_scores.append(0.0) # Neutral
return reward_scores
#def normalize(scores):
#scores = np.array(scores, dtype=np.float64)
#mean = np.mean(scores)
#std = np.std(scores) + 1e-8 # Add small value to prevent division by zero
#return (scores - mean) / std
def normalize(scores):
scores = np.array(scores, dtype=np.float64)
min_val, max_val = np.min(scores), np.max(scores)
return (scores - min_val) / (max_val - min_val + 1e-8) # Scale to [0,1]
def combined_reward(prompts, completions, answer):
#print("from combined_reward: completions =", completions)
matching_scores = basic_matching_reward(prompts, completions, answer)
rouge_scores = rouge_reward_func(prompts, completions, answer)
length_scores = length_similarity_reward_func(prompts, completions, answer)
explanation_scores = logic_explanation_reward(prompts, completions, answer)
#print("Before Normalization - matching_scores, rouge_scores, length_scores:",
# matching_scores, rouge_scores, length_scores)
# Apply normalization
matching_scores = normalize(matching_scores)
rouge_scores = normalize(rouge_scores)
length_scores = normalize(length_scores)
explanation_scores = normalize(explanation_scores)
#print("After Normalization - matching_scores, rouge_scores, length_scores:",
# matching_scores, rouge_scores, length_scores)
return [
0.45 * match + 0.25 * rouge + 0.15 * length + 0.15 * explanation
for match, rouge, length, explanation in zip(matching_scores, rouge_scores, length_scores, explanation_scores)
]
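A quick smoke test of the reward pipeline on a few hand-written completions (the completions below are made up; the point is only to confirm that a correct, concise answer scores highest):
# Sketch: sanity-check combined_reward on toy completions (made-up examples)
_test_prompts = ["dummy prompt"] * 3
_test_completions = [
    "No",                                   # correct and maximally concise
    "Yes, because everyone turned it on.",  # wrong answer, but contains an explanation cue
    "I am not sure what you mean.",         # no yes/no at all
]
_test_answers = ["No", "No", "No"]
print(combined_reward(_test_prompts, _test_completions, _test_answers))
# The first completion should come out with the highest combined score.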
Training settings¶
In [5]:
from transformers import AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
# trainer settings
output_dir = "./Gemma-3-1B_grpo_on_bbh24"
training_args = GRPOConfig(
output_dir=output_dir,
learning_rate=3e-5,
adam_beta1=0.9,
adam_beta2=0.99,
optim="paged_adamw_32bit",
weight_decay=0.01,
warmup_ratio=0.05,
max_steps=1601,
lr_scheduler_type='cosine',
bf16=True,
per_device_train_batch_size=8, #######
gradient_accumulation_steps=2,
num_generations=4,
max_prompt_length=256,
max_completion_length=64,
temperature=0.7,
#num_train_epochs=6,
log_on_each_node=False,
report_to="none",
logging_steps=10,
save_steps=50,
max_grad_norm=1.2, # gradient clipping
#eval_strategy="steps",
#eval_steps=10,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
#load_best_model_at_end=True, # Crucial for saving best model
#metric_for_best_model="eval_loss"
)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto", ########
trust_remote_code=True
)
# Prepare LoRA Configuration
peft_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules = ['q_proj', 'o_proj', 'k_proj', 'v_proj', 'gate_proj', 'up_proj', 'down_proj'], # Layers to target
task_type="CAUSAL_LM",
lora_dropout=0.1,
bias="none",
)
from peft import get_peft_model
model = get_peft_model(model, peft_config)
# Double-check if model is fully on GPU
print(model.hf_device_map)
model.print_trainable_parameters()
{'': 0}
trainable params: 13,045,760 || all params: 1,012,931,712 || trainable%: 1.2879
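Note that the adapter is saved below under a "qlora" name, but the base model here is loaded in plain bfloat16 without quantization. If real 4-bit QLoRA is intended, the load could look roughly like this (a sketch, not what the run above used; it trades some speed for memory):
# Sketch: 4-bit NF4 quantized load (QLoRA-style), as an alternative to the bf16 load above
from transformers import BitsAndBytesConfig
from peft import prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
model_4bit = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)
model_4bit = prepare_model_for_kbit_training(model_4bit)  # cast norms, enable input grads
model_4bit = get_peft_model(model_4bit, peft_config)      # reuse the LoraConfig defined above
model_4bit.print_trainable_parameters()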
Training¶
In [6]:
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[combined_reward],  # TRL passes prompts, completions, and the remaining dataset columns (here "answer") as keyword arguments
args=training_args,
train_dataset=train_dataset,
#eval_dataset=tokenized_eval_dataset,
peft_config=peft_config,
)
# Train!
trainer.train()
print (torch.cuda.memory_summary())
# Save the LoRA adapter weights (the directory name keeps the original "qlora" label)
trainer.model.save_pretrained("Gemma-3-1B-qlora", safe_serialization=True)
#trainer.save_model("best_Gemma-3-1B-qlora")
# Accessing the logs (after training):
log_history = trainer.state.log_history
# Plot the loss vs steps
import matplotlib.pyplot as plt
# Extract training loss
train_steps = [entry["step"] for entry in trainer.state.log_history if "loss" in entry]
train_losses = [entry["loss"] for entry in trainer.state.log_history if "loss" in entry]
# Extract validation loss
#val_steps = [entry["step"] for entry in trainer.state.log_history if "eval_loss" in entry]
#val_losses = [entry["eval_loss"] for entry in trainer.state.log_history if "eval_loss" in entry]
# Plot both training and validation loss
plt.plot(train_steps, train_losses, label="Training Loss", linestyle="-")
#plt.plot(val_steps, val_losses, label="Validation Loss", linestyle="--")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Training Loss vs. Steps")
plt.legend()
plt.show()
#trainer.eval_dataset = eval_dataset
#print("Evaluation on test set:", trainer.evaluate())
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\peft\mapping.py:185: UserWarning: The PEFT config's `base_model_name_or_path` was renamed from 'google/gemma-3-1b-it' to 'None'. Please ensure that the correct base model is loaded when loading this checkpoint.
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\peft\tuners\tuners_utils.py:197: FutureWarning: `num_logits_to_keep` is deprecated and will be removed in version 4.50 for `Gemma3ForCausalLM.forward`. Use `logits_to_keep` instead.
[1601/1601 9:15:34, Epoch 145/161]
Step | Training Loss |
---|---|
10 | 0.005200 |
20 | 0.006700 |
30 | 0.005600 |
40 | 0.007600 |
50 | 0.036100 |
60 | 0.032300 |
70 | 0.049100 |
80 | 0.975500 |
90 | 0.117700 |
100 | 0.220000 |
110 | 0.123500 |
120 | 0.067200 |
130 | 0.585100 |
140 | 0.084800 |
150 | 0.216800 |
160 | 0.540600 |
170 | 0.728100 |
180 | 0.938700 |
190 | 0.375700 |
200 | 0.703100 |
210 | 0.727100 |
220 | 0.768400 |
230 | 0.903900 |
240 | 0.544300 |
250 | 0.284200 |
260 | 0.208800 |
270 | 0.278200 |
280 | 0.357200 |
290 | 119.257700 |
300 | 0.320500 |
310 | 0.671800 |
320 | 0.353000 |
330 | 0.963300 |
340 | 0.300900 |
350 | 0.089300 |
360 | 0.184600 |
370 | 0.181000 |
380 | 0.148400 |
390 | 0.145400 |
400 | 0.246600 |
410 | 0.901100 |
420 | 0.293900 |
430 | 0.651500 |
440 | 0.816300 |
450 | 0.703300 |
460 | 1.139500 |
470 | 0.893000 |
480 | 0.358200 |
490 | 0.668500 |
500 | 0.760300 |
510 | 17.114000 |
520 | 0.560500 |
530 | 0.323600 |
540 | 0.304600 |
550 | 0.308000 |
560 | 0.223900 |
570 | 0.412700 |
580 | 0.504400 |
590 | 3.308100 |
600 | 0.986400 |
610 | 0.747700 |
620 | 0.724300 |
630 | 0.877700 |
640 | 37.628100 |
650 | 0.748200 |
660 | 105.165200 |
670 | 5.688700 |
680 | 0.671900 |
690 | 0.572100 |
700 | 0.637400 |
710 | 0.878800 |
720 | 1.018200 |
730 | 2.271400 |
740 | 0.686100 |
750 | 0.782000 |
760 | 0.632000 |
770 | 0.545200 |
780 | 0.683600 |
790 | 0.535300 |
800 | 0.550400 |
810 | 0.386800 |
820 | 0.376800 |
830 | 0.460100 |
840 | 0.279600 |
850 | 1.541100 |
860 | 0.362800 |
870 | 0.452300 |
880 | 0.424500 |
890 | 0.429600 |
900 | 0.427900 |
910 | 8.069800 |
920 | 0.716800 |
930 | 0.661700 |
940 | 0.673800 |
950 | 0.670500 |
960 | 0.448400 |
970 | 0.486700 |
980 | 0.451300 |
990 | 0.375400 |
1000 | 0.487800 |
1010 | 0.457500 |
1020 | 0.381300 |
1030 | 0.571600 |
1040 | 0.541700 |
1050 | 0.515800 |
1060 | 1.693800 |
1070 | 0.418100 |
1080 | 0.389900 |
1090 | 0.616000 |
1100 | 0.453500 |
1110 | 0.495500 |
1120 | 0.553900 |
1130 | 0.507400 |
1140 | 0.511200 |
1150 | 0.404900 |
1160 | 0.530800 |
1170 | 0.287900 |
1180 | 0.290300 |
1190 | 0.402500 |
1200 | 0.335000 |
1210 | 0.427400 |
1220 | 0.411300 |
1230 | 0.414100 |
1240 | 0.518500 |
1250 | 0.456400 |
1260 | 0.407900 |
1270 | 0.913000 |
1280 | 0.366600 |
1290 | 0.262700 |
1300 | 0.423800 |
1310 | 0.353400 |
1320 | 0.317800 |
1330 | 0.384500 |
1340 | 0.399000 |
1350 | 0.318000 |
1360 | 0.381100 |
1370 | 0.315700 |
1380 | 0.325400 |
1390 | 1.366600 |
1400 | 0.377700 |
1410 | 0.364400 |
1420 | 0.315800 |
1430 | 0.281800 |
1440 | 0.290200 |
1450 | 0.316800 |
1460 | 0.600000 |
1470 | 0.243400 |
1480 | 0.357200 |
1490 | 0.365800 |
1500 | 0.323600 |
1510 | 0.313000 |
1520 | 0.470400 |
1530 | 0.324600 |
1540 | 0.297300 |
1550 | 0.436000 |
1560 | 0.330100 |
1570 | 0.285600 |
1580 | 0.352300 |
1590 | 0.356200 |
1600 | 0.315200 |
PyTorch CUDA memory summary, device ID 0 | CUDA OOMs: 0 | cudaMalloc retries: 0
Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |
---|---|---|---|---|
Allocated memory | 2006 MiB | 7855 MiB | 1703 TiB | 1703 TiB |
Active memory | 2006 MiB | 7855 MiB | 1703 TiB | 1703 TiB |
Requested memory | 2006 MiB | 7855 MiB | 1702 TiB | 1702 TiB |
GPU reserved memory | 9316 MiB | 9316 MiB | 9316 MiB | 0 B |
Non-releasable memory | 101486 KiB | 4924 MiB | 2262 TiB | 2262 TiB |
Allocations | 1280 | 2055 | 570315 K | 570313 K |
Active allocs | 1280 | 2055 | 570315 K | 570313 K |
GPU reserved segments | 116 | 116 | 116 | 0 |
Non-releasable allocs | 114 | 266 | 202115 K | 202115 K |
Oversize allocations | 0 | 0 | 0 | 0 |
Oversize GPU segments | 0 | 0 | 0 | 0 |
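The GRPO training loss by itself is hard to interpret (note the occasional large spikes above), so plotting the logged reward is usually more informative. A sketch, assuming this TRL version writes a "reward" entry into trainer.state.log_history (recent GRPOTrainer versions do; the exact key may differ):
# Sketch: plot mean reward per logging step, if the trainer logged a "reward" metric
reward_steps = [entry["step"] for entry in trainer.state.log_history if "reward" in entry]
reward_values = [entry["reward"] for entry in trainer.state.log_history if "reward" in entry]
if reward_values:
    plt.plot(reward_steps, reward_values, label="Mean reward", linestyle="-")
    plt.xlabel("Steps")
    plt.ylabel("Reward")
    plt.title("Mean Reward vs. Steps")
    plt.legend()
    plt.show()
else:
    print("No 'reward' entries found in log_history for this TRL version.")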
Evaluation¶
In [6]:
import re
def extract_assistant_yes_no(generated_text):
pattern = r"Assistant:\s*(yes|no)"
match = re.search(pattern, generated_text, re.IGNORECASE) #added re.IGNORECASE to capture Yes, YES, No etc.
if match:
return match.group(1)  # return only the captured "yes"/"no", not the full "Assistant: ..." match
else:
return None
tp = 0
tn = 0
fp = 0
fn = 0
for example in eval_dataset:
#prompt_text1 = "".join([d['content'] for d in example["prompt"]])
prompt_text1 = "".join([d for d in example["prompt"]]) + "\n\nAssistant:"
#tokenized_input = tokenizer(prompt_text1, return_tensors='pt', padding=True, truncation=True).to('cuda')
tokenized_input = tokenizer(prompt_text1, return_tensors='pt', padding=True, truncation=True, add_special_tokens=True).to('cuda')
#generated_ids = model.generate(**tokenized_input)
generated_ids = model.generate(**tokenized_input, max_new_tokens=50, do_sample=True, temperature=0.7, top_p=0.9)
generated_text1 = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
predicted = extract_yes_no(generated_text1)
actual = example["answer"].lower()
predicted_with_assistant = extract_assistant_yes_no(generated_text1)
if predicted_with_assistant:
predicted = predicted_with_assistant.lower()  # extract_assistant_yes_no already returns just "yes" or "no"
else:
predicted = "none" #handle the case where the prediction failed.
actual = example["answer"].lower()
if actual == "yes":
if predicted == "yes":
tp += 1
elif predicted == "no":
fn += 1
else:
fn +=1 # handle cases where the prediction was not yes or no.
elif actual == "no":
if predicted == "no":
tn += 1
elif predicted == "yes":
fp += 1
else:
fp +=1 # handle cases where the prediction was not yes or no.
else:
print(f"Error: actual answer is neither yes nor no: {actual}")
print("Confusion matrix for eval_dataset", {"TP": tp, "TN": tn, "FP": fp, "FN": fn})
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Confusion matrix for eval_dataset {'TP': 8, 'TN': 2, 'FP': 5, 'FN': 5}
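From the confusion matrix above, the usual summary metrics can be derived directly (a small sketch reusing the tp/tn/fp/fn counters from the loop):
# Sketch: accuracy / precision / recall / F1 from the confusion-matrix counters above
total = tp + tn + fp + fn
accuracy = (tp + tn) / total if total else 0.0
precision = tp / (tp + fp) if (tp + fp) else 0.0
recall = tp / (tp + fn) if (tp + fn) else 0.0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
print(f"accuracy={accuracy:.2f}  precision={precision:.2f}  recall={recall:.2f}  f1={f1:.2f}")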
In [7]:
model0 = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto", ########
trust_remote_code=True
)
tp = 0
tn = 0
fp = 0
fn = 0
for example in eval_dataset:
#prompt_text1 = "".join([d['content'] for d in example["prompt"]])
prompt_text1 = "".join([d for d in example["prompt"]]) + "\n\nAssistant:"
#tokenized_input = tokenizer(prompt_text1, return_tensors='pt', padding=True, truncation=True).to('cuda')
tokenized_input = tokenizer(prompt_text1, return_tensors='pt', padding=True, truncation=True, add_special_tokens=True).to('cuda')
#generated_ids = model.generate(**tokenized_input)
generated_ids = model0.generate(**tokenized_input, max_new_tokens=50, do_sample=True, temperature=0.7, top_p=0.9)
generated_text1 = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
predicted = extract_yes_no(generated_text1)
actual = example["answer"].lower()
predicted_with_assistant = extract_assistant_yes_no(generated_text1)
if predicted_with_assistant:
predicted = predicted_with_assistant.lower()  # extract_assistant_yes_no already returns just "yes" or "no"
else:
predicted = "none" #handle the case where the prediction failed.
actual = example["answer"].lower()
if actual == "yes":
if predicted == "yes":
tp += 1
elif predicted == "no":
fn += 1
else:
fn +=1 # handle cases where the prediction was not yes or no.
elif actual == "no":
if predicted == "no":
tn += 1
elif predicted == "yes":
fp += 1
else:
fp +=1 # handle cases where the prediction was not yes or no.
else:
print(f"Error: actual answer is neither yes nor no: {actual}")
print("Confusion matrix for 'raw' model without fine-tuning for eval_dataset", {"TP": tp, "TN": tn, "FP": fp, "FN": fn})
WARNING:accelerate.big_modeling:Some parameters are on the meta device because they were offloaded to the cpu.
Confusion matrix for 'raw' model without fine-tuning for eval_dataset {'TP': 4, 'TN': 0, 'FP': 7, 'FN': 9}
The confusion matrix improved from {'TP': 4, 'TN': 0, 'FP': 7, 'FN': 9} for the raw model to {'TP': 8, 'TN': 2, 'FP': 5, 'FN': 5} after GRPO training, raising accuracy on the 20 held-out examples from 20% to 50%. However, with such a small dataset (187 examples in total, split into 167 for training and 20 for evaluation), this gain may not be statistically significant. The model also retains a bias toward "yes" answers. Is it an artifact or a genuine improvement in reasoning?
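One crude way to gauge whether 10/20 correct could be chance is a one-sided binomial test against the raw model's 20% accuracy, treated as a fixed baseline; this ignores that the baseline itself is estimated from the same 20 examples, so a paired test such as McNemar's on per-example predictions would be more appropriate. A sketch (requires scipy):
# Sketch: crude one-sided binomial check of 10/20 correct vs. a fixed 20% baseline
from scipy.stats import binomtest

result = binomtest(k=10, n=20, p=0.2, alternative="greater")
print(f"p-value: {result.pvalue:.4f}")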
In [15]:
test_prompt = "What is 7 * 13?"
inputs = tokenizer(test_prompt, return_tensors="pt", padding=True, truncation=True).to('cuda')
outputs = model.generate(**inputs, max_new_tokens=60)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
What is 7 * 13? 7 * 13 = 91 So the answer is 91.
In [16]:
test_prompt = "What is 7 * 13?"
inputs = tokenizer(test_prompt, return_tensors="pt", padding=True, truncation=True).to('cuda')
outputs = model0.generate(**inputs, max_new_tokens=60)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
What is 7 * 13? $$ 7 \times 13 = 7 \times (10 + 3) = (7 \times 10) + (7 \times 3) = 70 + 21 = 91 $$ Alternatively, $$ 7 \times
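Finally, for standalone inference the LoRA adapter can be merged back into the base weights; a minimal sketch using PEFT's merge_and_unload (the output directory name is just illustrative):
# Sketch: merge the trained LoRA adapter into the base model and save it (illustrative path)
merged_model = trainer.model.merge_and_unload()
merged_model.save_pretrained("Gemma-3-1B-bbh-merged", safe_serialization=True)
tokenizer.save_pretrained("Gemma-3-1B-bbh-merged")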