In [1]:
import torch
import sys
import gc
print(sys.version)
print(f"PyTorch Version: {torch.__version__}")
print(torch.cuda.is_available())
print(torch.cuda.device_count())

if torch.cuda.is_available():
    print(f"CUDA Version: {torch.version.cuda}")
    print(torch.cuda.get_device_name(0))

gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

import bitsandbytes
import peft
import transformers

print(transformers.__version__)

print(f"bitsandbytes version: {bitsandbytes.__version__}")
print(f"peft version: {peft.__version__}")
print(torch.cuda.is_bf16_supported())
3.10.16 | packaged by Anaconda, Inc. | (main, Dec 11 2024, 16:19:12) [MSC v.1929 64 bit (AMD64)]
PyTorch Version: 2.5.1+cu121
True
1
CUDA Version: 12.1
NVIDIA GeForce RTX 4080 Laptop GPU
4.50.0.dev0
bitsandbytes version: 0.45.3
peft version: 0.14.0
True

GRPO Fine-Tuning of Gemma-3-1B-it to Improve Its Reasoning

In [3]:
from transformers import AutoTokenizer
import re
import torch, numpy as np
from datasets import load_dataset, Dataset

model_name = "google/gemma-3-1b-it"
run_name = "Gemma-3-1B-BBH"

# Note: this constant is not used below; format_prompt() embeds its own system message.
SYSTEM_PROMPT = "Answer the questions CORRECTLY and CONCISELY."

dataset_name = "maveriq/bigbenchhard"

# Load a tokenizer to use its chat template
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Ensure a pad token exists: fall back to the EOS token, otherwise add a dedicated <PAD> token
if tokenizer.pad_token is None:
    if tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    else:
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})

tokenizer.padding_side = "left"  # left padding for autoregressive (decoder-only) models

# Load and format the BBH dataset
def format_prompt(example):
    """Format the prompt with the <start_of_turn> chat template that Gemma-3-1B expects."""
    chat = [
        {"role": "system", "content": "You are a helpful assistant. Answer factual questions accurately and concisely."},
        {"role": "user", "content": example["input"]},
        {"role": "assistant", "content": ""}  # empty assistant turn; the model continues from the prompt
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False)
    return {"prompt": prompt, "answer": example["target"]}


bbh_causal_judgement = load_dataset(dataset_name, "causal_judgement")
dataset = bbh_causal_judgement.shuffle(seed=137)
train_dataset = dataset["train"].select(range(167)).map(format_prompt, remove_columns=["input", "target"])
eval_dataset = dataset["train"].select(range(167, 187)).map(format_prompt, remove_columns=["input", "target"])

print(train_dataset[0])

print(train_dataset.column_names)
print(eval_dataset.column_names)
print (len(train_dataset), len(eval_dataset))
{'prompt': '<bos><start_of_turn>user\nYou are a helpful assistant. Answer factual questions accurately and concisely.\n\nHow would a typical person answer each of the following questions about causation?\nJim, Carol, Bob, and Nancy are researchers in a remote area, and they have a limited supply of electricity. Because of their limited supply, the electricity only comes on in the evenings from 8-9 PM, and they have to restrict who can use power on certain days. If four people turn on their lamps at the same time, the breaker will fail. The breaker will not fail if fewer people turn on their lamps at the same time. Jim, Carol, Bob, and Nancy are all allowed to use their lamps on Thursdays. This Thursday Jim turns on his lamp at 8 PM. Just then Carol turns on her lamp, Bob also turns on his lamp, and Nancy turns on her lamp. Since four people turned on their lamps at the same time, the circuit breaker failed. Did Jim turning on his lamp at 8 PM cause the circuit breaker to fail?\nOptions:\n- Yes\n- No<end_of_turn>\n<start_of_turn>model\n<end_of_turn>\n', 'answer': 'No'}
['prompt', 'answer']
['prompt', 'answer']
167 20
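Note that appending an empty assistant turn yields a closed <start_of_turn>model ... <end_of_turn> block, as the printed prompt above shows. An alternative (sketched below, not used in this run) is to let apply_chat_template open the model turn itself with add_generation_prompt=True:

In [ ]:
# Sketch of an alternative prompt format (not used in this run):
# add_generation_prompt=True appends an open "<start_of_turn>model" header
# instead of an empty, closed assistant turn.
def format_prompt_open_turn(example):
    chat = [
        {"role": "system", "content": "You are a helpful assistant. Answer factual questions accurately and concisely."},
        {"role": "user", "content": example["input"]},
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    return {"prompt": prompt, "answer": example["target"]}

print(format_prompt_open_turn(bbh_causal_judgement["train"][0])["prompt"])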

Reward Functions

In [4]:
# Reward Functions
import re
import numpy as np
from rouge_score import rouge_scorer

def extract_yes_no(text: str) -> str:
    """
    Extract a standalone 'yes' or 'no' from the response.
    Returns '' if no clear answer is found, or if both 'yes' and 'no' appear (ambiguous).
    """
    text = text.lower()
    matches = re.findall(r'\b(yes|no)\b', text)

    if not matches:
        return ''  # No clear answer detected
    if "yes" in matches and "no" in matches:
        return ''  # Both found → ambiguous, reject
    return matches[0]  # Return the first clear "yes" or "no"

def basic_matching_reward(prompts, completions, answers) -> list[float]:
    responses = [extract_yes_no(completion) for completion in completions]
    correct_answers = [ans.strip().lower() for ans in answers]
    #print("from basic_matching_reward: responses =", responses)
    #print("from basic_matching_reward: correct_answers =", correct_answers)
    rewards = []
    
    for completion, pred, target in zip(completions, responses, correct_answers):
        full_response = completion.strip().lower()
        word_count = len(full_response.split())
        #print("from zip in basic_matching_reward: word_count =", word_count, "full_response=", full_response)

        if pred == target:
            if word_count == 1:  # Exactly "yes" or "no"
                rewards.append(1.0)
            elif word_count <= 5:
                rewards.append(0.85)
            else:
                rewards.append(0.7)  # Further penalty for excessive words
        elif pred in ["yes", "no"]:
            rewards.append(-0.75)  # Incorrect, but at least binary
        elif full_response:  
            rewards.append(-1.5)  # Not "yes" or "no" → heavily penalized
        else:
            rewards.append(-1.75)  # Empty response → worst-case scenario

    return rewards

def rouge_reward_func(prompts, completions, answer) -> list[float]:
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

    scores = [scorer.score(ref_answer, response)['rougeL'].fmeasure
              for response, ref_answer in zip(completions, answer)]

    # Rescale relative to the batch average: scores below the mean are boosted (x1.2),
    # scores at or above the mean are slightly damped (x0.8)
    avg_score = np.mean(scores)
    return [s * (0.8 + 0.4 * (s < avg_score)) for s in scores]


def length_similarity_reward_func(prompts, completions, answer) -> list[float]:
    answer_lengths = [len(ans.split()) for ans in answer]
    response_lengths = [len(resp.split()) for resp in completions]

    # Exponential decay with the relative difference in word counts
    return [np.exp(-abs(resp_len - ans_len) / max(1, ans_len))
            for resp_len, ans_len in zip(response_lengths, answer_lengths)]

def logic_explanation_reward(prompts, completions, answers):
    """Rewards responses that provide reasoning beyond just 'yes' or 'no'."""
    reward_scores = []
    for response in completions:
        if "because" in response or "due to" in response or "contradicts" in response:
            reward_scores.append(1.0)   # Explanation-signal keywords present
        elif len(response.split()) > 10:
            reward_scores.append(-0.5)  # Penalize long responses *without* reasoning terms
        else:
            reward_scores.append(0.0)   # Neutral

    return reward_scores


#def normalize(scores):
    #scores = np.array(scores, dtype=np.float64)
    #mean = np.mean(scores)
    #std = np.std(scores) + 1e-8  # Add small value to prevent division by zero
    #return (scores - mean) / std
def normalize(scores):
    scores = np.array(scores, dtype=np.float64)
    min_val, max_val = np.min(scores), np.max(scores)
    return (scores - min_val) / (max_val - min_val + 1e-8)  # Scale to [0,1]

def combined_reward(prompts, completions, answer):
    #print("from combined_reward: completions =", completions)
    matching_scores = basic_matching_reward(prompts, completions, answer)
    rouge_scores = rouge_reward_func(prompts, completions, answer)
    length_scores = length_similarity_reward_func(prompts, completions, answer)
    explanation_scores = logic_explanation_reward(prompts, completions, answer)

    #print("Before Normalization - matching_scores, rouge_scores, length_scores:", 
    #      matching_scores, rouge_scores, length_scores)

    # Apply normalization
    matching_scores = normalize(matching_scores)
    rouge_scores = normalize(rouge_scores)
    length_scores = normalize(length_scores)
    explanation_scores = normalize(explanation_scores)

    #print("After Normalization - matching_scores, rouge_scores, length_scores:", 
    #      matching_scores, rouge_scores, length_scores)
    return [
        0.45 * match + 0.25 * rouge + 0.15 * length + 0.15 * explanation  
        for match, rouge, length, explanation in zip(matching_scores, rouge_scores, length_scores, explanation_scores)
    ]
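Before wiring these functions into the trainer, a quick sanity check on a toy batch (hand-written completions, not drawn from the dataset) helps confirm that the combined reward favors a concise correct answer and penalizes a missing yes/no:

In [ ]:
# Toy sanity check of the reward shaping (illustrative inputs, not from BBH)
toy_prompts = ["Did Jim turning on his lamp cause the breaker to fail?\nOptions:\n- Yes\n- No"] * 3
toy_completions = [
    "No",                                                                # exact, concise, correct
    "No, because all four lamps were needed to overload the breaker.",   # correct with reasoning
    "It depends on how you define causation in this scenario, really.",  # no clear yes/no answer
]
toy_answers = ["No", "No", "No"]

print(combined_reward(toy_prompts, toy_completions, toy_answers))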

Training Settings

In [5]:
from transformers import AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

# trainer settings
output_dir = "./Gemma-3-1B_grpo_on_bbh24"
training_args = GRPOConfig(
    output_dir=output_dir,
    learning_rate=3e-5,
    adam_beta1=0.9,
    adam_beta2=0.99,
    optim="paged_adamw_32bit",
    weight_decay=0.01,
    warmup_ratio=0.05,
    max_steps=1601,
    lr_scheduler_type='cosine',    
    bf16=True,
    per_device_train_batch_size=8, #######
    gradient_accumulation_steps=2, 
    num_generations=4,
    max_prompt_length=256,
    max_completion_length=64,
    temperature=0.7,
    #num_train_epochs=6, 
    log_on_each_node=False,    
    report_to="none",
    logging_steps=10,
    save_steps=50,
    max_grad_norm=1.2, # gradient clipping
    #eval_strategy="steps",
    #eval_steps=10,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    #load_best_model_at_end=True,  # Crucial for saving best model
    #metric_for_best_model="eval_loss"    
)


model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto", ########
    trust_remote_code=True  
)

# Prepare LoRA Configuration
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules = ['q_proj', 'o_proj', 'k_proj', 'v_proj', 'gate_proj', 'up_proj', 'down_proj'], # Layers to target
    task_type="CAUSAL_LM",
    lora_dropout=0.1,
    bias="none",
)

model = get_peft_model(model, peft_config)

# Double-check if model is fully on GPU
print(model.hf_device_map)

model.print_trainable_parameters()
{'': 0}
trainable params: 13,045,760 || all params: 1,012,931,712 || trainable%: 1.2879
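The base model above is kept in bf16, so this run is LoRA rather than true QLoRA. Since bitsandbytes and prepare_model_for_kbit_training are already imported, and the training log further below warns that eager attention is recommended for Gemma3, here is a sketch (an alternative, not what was run) of how a 4-bit load could look:

In [ ]:
# Sketch of an alternative 4-bit (QLoRA-style) load - NOT what was run above,
# where the base model stays in bf16. attn_implementation="eager" follows the
# Gemma3 training recommendation emitted as a warning during training below.
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model_4bit = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    attn_implementation="eager",
    device_map="auto",
    trust_remote_code=True,
)
model_4bit = prepare_model_for_kbit_training(model_4bit)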

Training

In [6]:
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[combined_reward],  # receives prompts, completions, and the dataset's "answer" column
    args=training_args,
    train_dataset=train_dataset,
    #eval_dataset=eval_dataset,
    peft_config=peft_config,
)

# Train!
trainer.train()

print (torch.cuda.memory_summary())

# Save the LoRA adapter weights (the base model was loaded in bf16, not quantized)
trainer.model.save_pretrained("Gemma-3-1B-qlora", safe_serialization=True)
#trainer.save_model("best_Gemma-3-1B-qlora")

# Access the training logs (after training)
log_history = trainer.state.log_history

# Plot the loss vs. steps
import matplotlib.pyplot as plt

# Extract training loss
train_steps = [entry["step"] for entry in log_history if "loss" in entry]
train_losses = [entry["loss"] for entry in log_history if "loss" in entry]

# Extract validation loss
#val_steps = [entry["step"] for entry in trainer.state.log_history if "eval_loss" in entry]
#val_losses = [entry["eval_loss"] for entry in trainer.state.log_history if "eval_loss" in entry]

# Plot both training and validation loss
plt.plot(train_steps, train_losses, label="Training Loss", linestyle="-")
#plt.plot(val_steps, val_losses, label="Validation Loss", linestyle="--")

plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Training Loss vs. Steps")
plt.legend()
plt.show()

#trainer.eval_dataset = eval_dataset
#print("Evaluation on test set:", trainer.evaluate())
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\peft\mapping.py:185: UserWarning: The PEFT config's `base_model_name_or_path` was renamed from 'google/gemma-3-1b-it' to 'None'. Please ensure that the correct base model is loaded when loading this checkpoint.
  warnings.warn(
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
C:\Users\alexa\miniconda3\envs\grpo_env\lib\site-packages\peft\tuners\tuners_utils.py:197: FutureWarning: `num_logits_to_keep` is deprecated and will be removed in version 4.50 for `Gemma3ForCausalLM.forward`. Use `logits_to_keep` instead.
  return self.model.forward(*args, **kwargs)
[1601/1601 9:15:34, Epoch 145/161]
Step Training Loss
10 0.005200
20 0.006700
30 0.005600
40 0.007600
50 0.036100
60 0.032300
70 0.049100
80 0.975500
90 0.117700
100 0.220000
110 0.123500
120 0.067200
130 0.585100
140 0.084800
150 0.216800
160 0.540600
170 0.728100
180 0.938700
190 0.375700
200 0.703100
210 0.727100
220 0.768400
230 0.903900
240 0.544300
250 0.284200
260 0.208800
270 0.278200
280 0.357200
290 119.257700
300 0.320500
310 0.671800
320 0.353000
330 0.963300
340 0.300900
350 0.089300
360 0.184600
370 0.181000
380 0.148400
390 0.145400
400 0.246600
410 0.901100
420 0.293900
430 0.651500
440 0.816300
450 0.703300
460 1.139500
470 0.893000
480 0.358200
490 0.668500
500 0.760300
510 17.114000
520 0.560500
530 0.323600
540 0.304600
550 0.308000
560 0.223900
570 0.412700
580 0.504400
590 3.308100
600 0.986400
610 0.747700
620 0.724300
630 0.877700
640 37.628100
650 0.748200
660 105.165200
670 5.688700
680 0.671900
690 0.572100
700 0.637400
710 0.878800
720 1.018200
730 2.271400
740 0.686100
750 0.782000
760 0.632000
770 0.545200
780 0.683600
790 0.535300
800 0.550400
810 0.386800
820 0.376800
830 0.460100
840 0.279600
850 1.541100
860 0.362800
870 0.452300
880 0.424500
890 0.429600
900 0.427900
910 8.069800
920 0.716800
930 0.661700
940 0.673800
950 0.670500
960 0.448400
970 0.486700
980 0.451300
990 0.375400
1000 0.487800
1010 0.457500
1020 0.381300
1030 0.571600
1040 0.541700
1050 0.515800
1060 1.693800
1070 0.418100
1080 0.389900
1090 0.616000
1100 0.453500
1110 0.495500
1120 0.553900
1130 0.507400
1140 0.511200
1150 0.404900
1160 0.530800
1170 0.287900
1180 0.290300
1190 0.402500
1200 0.335000
1210 0.427400
1220 0.411300
1230 0.414100
1240 0.518500
1250 0.456400
1260 0.407900
1270 0.913000
1280 0.366600
1290 0.262700
1300 0.423800
1310 0.353400
1320 0.317800
1330 0.384500
1340 0.399000
1350 0.318000
1360 0.381100
1370 0.315700
1380 0.325400
1390 1.366600
1400 0.377700
1410 0.364400
1420 0.315800
1430 0.281800
1440 0.290200
1450 0.316800
1460 0.600000
1470 0.243400
1480 0.357200
1490 0.365800
1500 0.323600
1510 0.313000
1520 0.470400
1530 0.324600
1540 0.297300
1550 0.436000
1560 0.330100
1570 0.285600
1580 0.352300
1590 0.356200
1600 0.315200

|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   2006 MiB |   7855 MiB |   1703 TiB |   1703 TiB |
|       from large pool |   1893 MiB |   7691 MiB |   1646 TiB |   1646 TiB |
|       from small pool |    113 MiB |    188 MiB |     56 TiB |     56 TiB |
|---------------------------------------------------------------------------|
| Active memory         |   2006 MiB |   7855 MiB |   1703 TiB |   1703 TiB |
|       from large pool |   1893 MiB |   7691 MiB |   1646 TiB |   1646 TiB |
|       from small pool |    113 MiB |    188 MiB |     56 TiB |     56 TiB |
|---------------------------------------------------------------------------|
| Requested memory      |   2006 MiB |   7855 MiB |   1702 TiB |   1702 TiB |
|       from large pool |   1893 MiB |   7691 MiB |   1645 TiB |   1645 TiB |
|       from small pool |    112 MiB |    188 MiB |     56 TiB |     56 TiB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   9316 MiB |   9316 MiB |   9316 MiB |      0 B   |
|       from large pool |   9124 MiB |   9124 MiB |   9124 MiB |      0 B   |
|       from small pool |    192 MiB |    192 MiB |    192 MiB |      0 B   |
|---------------------------------------------------------------------------|
| Non-releasable memory | 101486 KiB |   4924 MiB |   2262 TiB |   2262 TiB |
|       from large pool |  47232 KiB |   4872 MiB |   2204 TiB |   2204 TiB |
|       from small pool |  54254 KiB |     54 MiB |     58 TiB |     58 TiB |
|---------------------------------------------------------------------------|
| Allocations           |    1280    |    2055    |  570315 K  |  570313 K  |
|       from large pool |     133    |     206    |   55098 K  |   55098 K  |
|       from small pool |    1147    |    1890    |  515216 K  |  515215 K  |
|---------------------------------------------------------------------------|
| Active allocs         |    1280    |    2055    |  570315 K  |  570313 K  |
|       from large pool |     133    |     206    |   55098 K  |   55098 K  |
|       from small pool |    1147    |    1890    |  515216 K  |  515215 K  |
|---------------------------------------------------------------------------|
| GPU reserved segments |     116    |     116    |     116    |       0    |
|       from large pool |      20    |      20    |      20    |       0    |
|       from small pool |      96    |      96    |      96    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |     114    |     266    |  202115 K  |  202115 K  |
|       from large pool |       4    |      36    |   20769 K  |   20769 K  |
|       from small pool |     110    |     253    |  181346 K  |  181346 K  |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

[Figure: training loss vs. steps]
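To reuse the trained adapter later without rerunning the trainer, the saved directory can be loaded back onto a fresh base model and, if desired, merged into it. A minimal sketch, assuming the "Gemma-3-1B-qlora" directory written above:

In [ ]:
# Sketch: reload the saved adapter onto a fresh base model for inference
# (assumes the "Gemma-3-1B-qlora" directory saved above)
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
tuned = PeftModel.from_pretrained(base, "Gemma-3-1B-qlora")
tuned = tuned.merge_and_unload()  # fold the LoRA weights into the base model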

Evaluation

In [6]:
import re

def extract_assistant_yes_no(generated_text):
    """Extract the yes/no answer that follows the 'Assistant:' marker, or None if absent."""
    pattern = r"Assistant:\s*(yes|no)"
    match = re.search(pattern, generated_text, re.IGNORECASE)  # re.IGNORECASE captures Yes, YES, No, etc.
    return match.group(1) if match else None

tp = 0
tn = 0
fp = 0
fn = 0
for example in eval_dataset:
    # example["prompt"] is already a formatted string; append an explicit "Assistant:" cue
    prompt_text1 = example["prompt"] + "\n\nAssistant:"
    tokenized_input = tokenizer(prompt_text1, return_tensors='pt', padding=True, truncation=True, add_special_tokens=True).to('cuda')
    generated_ids = model.generate(**tokenized_input, max_new_tokens=50, do_sample=True, temperature=0.7, top_p=0.9)
    generated_text1 = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    predicted_with_assistant = extract_assistant_yes_no(generated_text1)
    if predicted_with_assistant:
        predicted = predicted_with_assistant.lower()  # just "yes" or "no"
    else:
        predicted = "none"  # no clear yes/no after "Assistant:"

    actual = example["answer"].lower()

    if actual == "yes":
        if predicted == "yes":
            tp += 1
        else:
            fn += 1  # "no" or an unclear answer both count as a missed "yes"
    elif actual == "no":
        if predicted == "no":
            tn += 1
        else:
            fp += 1  # "yes" or an unclear answer both count as a false positive
    else:
        print(f"Error: actual answer is neither yes nor no: {actual}")
print("Confusion matrix for eval_dataset", {"TP": tp, "TN": tn, "FP": fp, "FN": fn})
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Confusion matrix for eval_dataset {'TP': 8, 'TN': 2, 'FP': 5, 'FN': 5}
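The raw counts can be turned into the usual summary metrics with a small helper (assuming the tp/tn/fp/fn variables from the loop above):

In [ ]:
# Derive summary metrics from the confusion-matrix counts of the loop above
def summarize(tp, tn, fp, fn):
    total = tp + tn + fp + fn
    accuracy = (tp + tn) / total
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

print(summarize(tp, tn, fp, fn))  # fine-tuned model: TP=8, TN=2, FP=5, FN=5 -> accuracy 0.50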
In [7]:
model0 = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto", ########
    trust_remote_code=True  
)

tp = 0
tn = 0
fp = 0
fn = 0
for example in eval_dataset:
    # same evaluation loop as above, but generating with the untrained base model
    prompt_text1 = example["prompt"] + "\n\nAssistant:"
    tokenized_input = tokenizer(prompt_text1, return_tensors='pt', padding=True, truncation=True, add_special_tokens=True).to('cuda')
    generated_ids = model0.generate(**tokenized_input, max_new_tokens=50, do_sample=True, temperature=0.7, top_p=0.9)
    generated_text1 = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    predicted_with_assistant = extract_assistant_yes_no(generated_text1)
    if predicted_with_assistant:
        predicted = predicted_with_assistant.lower()  # just "yes" or "no"
    else:
        predicted = "none"  # no clear yes/no after "Assistant:"

    actual = example["answer"].lower()

    if actual == "yes":
        if predicted == "yes":
            tp += 1
        else:
            fn += 1  # "no" or an unclear answer both count as a missed "yes"
    elif actual == "no":
        if predicted == "no":
            tn += 1
        else:
            fp += 1  # "yes" or an unclear answer both count as a false positive
    else:
        print(f"Error: actual answer is neither yes nor no: {actual}")
print("Confusion matrix for 'raw' model without fine-tuning for eval_dataset", {"TP": tp, "TN": tn, "FP": fp, "FN": fn})
WARNING:accelerate.big_modeling:Some parameters are on the meta device because they were offloaded to the cpu.
Confusion matrix for 'raw' model without fine-tuning for eval_dataset {'TP': 4, 'TN': 0, 'FP': 7, 'FN': 9}

The confusion matrix improved from {'TP': 4, 'TN': 0, 'FP': 7, 'FN': 9} for the raw model to {'TP': 8, 'TN': 2, 'FP': 5, 'FN': 5} after training, raising accuracy on the 20 held-out examples from 20% to 50%. However, with only 187 examples in total (split between training and evaluation), this improvement may not be statistically significant. Moreover, both models show a bias toward "yes" answers, so it remains unclear whether the gain reflects a genuine improvement in reasoning or is an artifact of the small evaluation set.
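One way to quantify the uncertainty from only 20 evaluation examples is a normal-approximation confidence interval on each accuracy estimate. A back-of-the-envelope sketch using just the counts reported above:

In [ ]:
# Rough 95% confidence intervals for accuracy on n=20 eval examples
# (normal approximation; a back-of-the-envelope check, not a formal test)
import math

def acc_ci(correct, n, z=1.96):
    p = correct / n
    half = z * math.sqrt(p * (1 - p) / n)
    return p, max(0.0, p - half), min(1.0, p + half)

for label, correct in [("raw model", 4), ("fine-tuned", 10)]:
    p, lo, hi = acc_ci(correct, 20)
    print(f"{label}: accuracy {p:.2f}, 95% CI [{lo:.2f}, {hi:.2f}]")

With these counts the intervals come out wide and overlapping, which supports the caution above about drawing firm conclusions from 20 examples.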

In [15]:
test_prompt = "What is 7 * 13?"
inputs = tokenizer(test_prompt, return_tensors="pt", padding=True, truncation=True).to('cuda')
outputs = model.generate(**inputs, max_new_tokens=60)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
What is 7 * 13?

7 * 13 = 91

So the answer is 91.

In [16]:
test_prompt = "What is 7 * 13?"
inputs = tokenizer(test_prompt, return_tensors="pt", padding=True, truncation=True).to('cuda')
outputs = model0.generate(**inputs, max_new_tokens=60)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
What is 7 * 13?
$$ 7 \times 13 = 7 \times (10 + 3) = (7 \times 10) + (7 \times 3) = 70 + 21 = 91 $$
Alternatively,
$$ 7 \times