大语言模型 (LLM) 的强大能力毋庸置疑,但其庞大的参数量也带来了巨大的计算资源需求。如何在有限的硬件条件下,例如单张消费级显卡 T4 上,高效地进行大模型微调,成为了许多开发者关注的焦点。

本文将基于一个实用的 Notebook 示例,深入探讨如何使用 GRPO (Gradient Ratio Policy Optimization) 算法,在单张 T4 GPU 上对 Qwen2.5-0.5B 这一开源大模型进行全参数微调。我们将详细解析代码,并解释背后的优化策略,帮助读者理解如何在资源受限的环境下也能玩转大模型微调。

性能提升

通过本文介绍的方法,我们可以在单 T4 GPU 上,仅用约 150 步(约 30 分钟)的训练,就将 Qwen2.5-0.5B-Instruct 模型在 GSM8K 数据集上的数学解题能力从 22.4% 提升至 48.6%

训练过程中的奖励 (reward) 曲线图如下所示,可以看到模型性能在训练过程中稳步提升:

核心优化策略

为了在单 T4 GPU 上顺利完成大模型的全参数微调,Notebook 中采用了以下关键优化策略:

  • 优化的 TRL 库分支: 使用了 andyl98 大佬针对 TRL 库的一个分支,该分支引入了批量计算 logprobs 的功能,显著降低了显存占用。同时,作者还在其基础上进一步优化了 logprobs 计算函数,以更有效地减少 VRAM 使用。
  • 8-bit AdamW 优化器: 采用 8 位 AdamW 优化器,相比于传统的 32 位或 16 位优化器,可以大幅降低优化器状态所占用的显存。
  • 显存分配限制: 通过设置环境变量 PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:128',显式地限制 PyTorch 的显存分配策略,避免显存碎片化,更有效地利用有限的显存空间。

进阶优化 (可选)

如果您使用的是 Ampere 架构或更新的 NVIDIA GPU (例如 A100, RTX 30/40 系列),还可以通过以下方式进一步降低显存占用:

  • 启用 flash_attention_2: 在加载模型时,设置 attn_implementation="flash_attention_2",利用 Flash Attention 2 加速注意力计算,并减少显存占用 (T4 GPU 不支持 Flash Attention 2)。
  • 使用 Liger-Kernel 封装: 使用 Liger-Kernel 提供的封装器加载模型,可以进一步优化模型推理性能和显存效率。

代码详解

接下来,我们将逐个代码块进行详细解读,并添加中文注释,帮助您理解 Notebook 的实现细节。

1. 安装依赖库

%%capture
!pip install uv  # 使用 uv 包管理器,速度更快

!uv pip install --system git+https://github.com/qunash/trl-1.git@grpo-vram-optimization # 安装优化的 TRL 库分支,来自 qunash 的 fork
!uv pip install --system triton==2.2.0 # 安装特定版本的 triton,用于支持 flash attention 等
!uv pip install --system vllm # 安装 vllm,用于快速模型推理
!uv pip install --system bitsandbytes # 安装 bitsandbytes,用于 8-bit 优化器等

代码解释:

  • %%capture: 这是一个 Jupyter Notebook 的 magic command,用于捕获单元格的输出,避免安装过程中的大量信息刷屏。
  • !pip install uv: 使用 uv 包管理器代替 pipuv 据称速度更快。
  • !uv pip install ...: 使用 uv pip 安装必要的 Python 库。
    • git+https://github.com/qunash/trl-1.git@grpo-vram-optimization: 从 GitHub 安装指定的 TRL 库分支,该分支包含了显存优化。
    • triton==2.2.0: 安装 triton 库的 2.2.0 版本,triton 用于构建高性能的深度学习算子,例如 Flash Attention。
    • vllm: 安装 vllm 库,这是一个快速且高效的大模型推理库,可以显著提升模型推理速度,并降低显存占用。
    • bitsandbytes: 安装 bitsandbytes 库,该库提供了 8-bit 优化器 (例如 8-bit AdamW) 和量化等功能,用于降低显存占用。
  • --system: 将库安装到系统 Python 环境中。

2. 代码设置与数据预处理

import os # 导入 os 库,用于与操作系统交互,例如设置环境变量
import re # 导入 re 库,用于正则表达式操作,例如提取答案
import torch # 导入 PyTorch 库
from datasets import load_dataset, Dataset # 导入 datasets 库,用于加载和处理数据集
from transformers import AutoTokenizer, AutoModelForCausalLM # 导入 transformers 库,用于加载预训练模型和 tokenizer
from trl.trainer import GRPOConfig, GRPOTrainer # 导入 trl 库中的 GRPOConfig 和 GRPOTrainer,用于 GRPO 训练

# 定义 R1 风格的系统提示词,用于引导模型进行 reasoning 和 answer 的格式输出
R1_STYLE_SYSTEM_PROMPT = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user
with the answer. The reasoning process and answer are enclosed within <reasoning> </reasoning> and
<answer> </answer> tags, respectively, i.e., <reasoning> reasoning process here </reasoning>
<answer> answer here </answer>."""

# 定义任务特定的指令,这里要求答案必须是单个整数
TASK_SPECIFIC_INSTRUCTIONS = "The answer must be a single integer."

# 数据预处理函数
def preprocess_dataset(dataset_name, split="train", chunk_size=1000) -> Dataset:
    dataset = load_dataset(dataset_name, 'main')[split] # 加载指定数据集和 split (例如 'train')

    # 从文本中提取 "####" 后面的答案
    def extract_hash_answer(text: str) -> str | None:
        try:
            return text.split("####")[1].strip() # 使用 "####" 分割文本,取第二部分并去除首尾空格
        except IndexError:
            return None # 如果分割失败,返回 None

    # 处理批次数据
    def process_batch(batch):
        # 构建 prompt 列表,每个 prompt 包含 system prompt, example conversation, 和用户问题
        prompts = [[
            {'role': 'system', 'content': R1_STYLE_SYSTEM_PROMPT + "\n" + TASK_SPECIFIC_INSTRUCTIONS}, # 系统提示词
            {'role': 'user', 'content': "What is 2+2?"}, # 示例用户问题
            {'role': 'assistant', 'content': "<reasoning>To calculate 2+2, we simply add the numbers together: 2 + 2 = 4.</reasoning>\n<answer>4</answer>"}, # 示例助手回答,包含 reasoning 和 answer 标签
            {'role': 'user', 'content': q.strip()} # 当前批次的用户问题
        ] for q in batch['question']] # 遍历批次中的问题

        return {
            'prompt': prompts, # 返回构建好的 prompt 列表
            'answer': [extract_hash_answer(a) for a in batch['answer']] # 返回提取出的答案列表
        }

    return dataset.map(process_batch, batched=True, batch_size=chunk_size) # 使用 map 函数对数据集进行批处理

dataset_name = 'openai/gsm8k' # 设置数据集名称为 gsm8k (数学应用题数据集)
dataset = preprocess_dataset(dataset_name, chunk_size=500) # 预处理数据集,批次大小为 500

# 从 XML 格式文本中提取答案,例如从 <answer>42</answer> 中提取 "42"
def extract_xml_answer(text: str) -> str:
    try:
        answer = text.split("<answer>")[-1].split("</answer>")[0].strip() # 使用 <answer> 和 </answer> 分割文本,提取答案并去除首尾空格
        return answer
    except IndexError:
        return "" # 如果提取失败,返回空字符串

# 定义奖励函数 (reward functions)

# 格式奖励函数:检查模型输出是否符合 XML 格式要求 (包含 <reasoning> 和 <answer> 标签)
def format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has the correct format."""
    pattern = r"^<reasoning>.*?</reasoning>\s*<answer>.*?</answer>$" # 定义正则表达式,匹配以 <reasoning> 开头,</reasoning> 结尾,中间任意字符,然后是 <answer> 开头,</answer> 结尾,中间任意字符的格式
    responses = [completion[0]["content"] for completion in completions] # 从 completions 中提取模型生成的文本内容
    matches = [bool(re.match(pattern, r)) for r in responses] # 使用正则表达式匹配生成的文本是否符合格式
    return [1.0 if match else 0.0 for match in matches] # 如果匹配,奖励 1.0,否则奖励 0.0

# 正确性奖励函数:检查模型生成的答案是否与标准答案一致
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward function that checks if the answer is correct."""
    responses = [completion[0]['content'] for completion in completions] # 提取模型生成的文本内容
    extracted_responses = [extract_xml_answer(r) for r in responses] # 从生成的文本中提取 XML 格式的答案
    print(f"Question: {prompts[0][-1]['content']}\nAnswer: {answer[0]}\nResponse: {responses[0]}\nExtracted: {extracted_responses[0]}") # 打印问题、标准答案、模型完整输出和提取出的答案,用于调试
    print(''.join('' if r == a else '' for r, a in zip(extracted_responses, answer))) # 打印 ✅ 或 ❌,表示答案是否正确
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)] # 如果答案正确,奖励 2.0,否则奖励 0.0

# model_name = "Qwen/Qwen2.5-0.5B" # 可以选择不带 Instruct 的版本,这里使用 Instruct 版本
model_name = "Qwen/Qwen2.5-0.5B-Instruct" # 设置模型名称为 Qwen2.5-0.5B-Instruct

output_dir = f"outputs/{model_name.split('/')[-1]}-GRPO" # 设置输出目录,例如 outputs/Qwen2.5-0.5B-Instruct-GRPO
run_name = f"{model_name.split('/')[-1]}-{dataset_name.split('/')[-1]}" # 设置 run 名称,例如 Qwen2.5-0.5B-Instruct-gsm8k

# 设置显存相关的环境变量,限制 PyTorch 显存分配策略,防止 OOM
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

max_prompt_length=256 # 设置最大 prompt 长度
max_completion_length=512 # 设置最大 completion 长度

# 配置 GRPO 训练参数
training_args = GRPOConfig(
    output_dir=output_dir, # 输出目录
    run_name=run_name, # run 名称
    learning_rate=1e-5, # 学习率
    beta=0.005, # divergence coefficient,控制策略偏离参考模型的程度,值越大更新越保守,默认 0.04
    optim="adamw_8bit", # 使用 8-bit AdamW 优化器
    adam_beta1=0.9, # AdamW beta1 参数
    adam_beta2=0.99, # AdamW beta2 参数
    weight_decay=0.1, # 权重衰减
    warmup_ratio=0.1, # 学习率 warmup 比例
    lr_scheduler_type='cosine', # 学习率调度器类型为 cosine
    logging_steps=1, # 每隔多少步记录日志
    bf16=True, # 使用 bfloat16 精度训练
    per_device_train_batch_size=4, # 每个设备上的训练 batch size
    num_generations=4,  # group size,GRPO 算法中的 group size
    gradient_accumulation_steps=4, # 梯度累积步数,用于增大有效 batch size
    max_prompt_length=max_prompt_length, # 最大 prompt 长度
    max_completion_length=max_completion_length, # 最大 completion 长度
    num_train_epochs=1, # 训练 epoch 数
    save_steps=100, # 每隔多少步保存模型 checkpoint
    max_grad_norm=0.1, # 最大梯度裁剪范数
    report_to="wandb", # 使用 wandb 记录训练日志 (需要安装 wandb 并登录)
    log_on_each_node=False, # 只在主节点记录日志
    use_vllm=True, # 使用 vllm 进行推理加速
    vllm_init_kwargs={ # vllm 初始化参数
        "device": "cuda:0", # 指定设备为 cuda:0
        "gpu_memory_utilization": 0.3, # 设置 vllm 显存利用率为 30%
        "max_model_len": max_prompt_length + max_completion_length, # 设置 vllm 最大模型长度
        "dtype": "half", # 设置 vllm 数据类型为 half (float16/bfloat16,根据 GPU 支持自动选择)
        # "enable_chunked_prefill": True, # 启用 chunked prefill (vllm 的优化技术)
        # "max_num_batched_tokens": 2048, # 设置 vllm 最大 batch tokens 数
    },
    gradient_checkpointing=True, # 启用梯度检查点,减少显存占用
    gradient_checkpointing_kwargs={"use_reentrant": False}, # 梯度检查点参数,use_reentrant=False 可以节省显存
    logit_computation_mini_batch_size=1, # logit 计算的 mini batch size,用于进一步减少显存占用
    enable_profiling=False # 是否启用 profiling
)

# 加载预训练模型
model = AutoModelForCausalLM.from_pretrained(
    model_name, # 模型名称
    torch_dtype=torch.bfloat16, # 指定数据类型为 bfloat16
    # attn_implementation="flash_attention_2", # T4 不支持 flash_attention_2,如果使用 A100 等 Ampere 架构 GPU 可以启用
    device_map="auto", # 自动选择设备 (GPU 或 CPU)
)

# 加载 tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    model_name, # tokenizer 名称
    model_max_length=training_args.max_completion_length, # 设置 tokenizer 最大长度
)
tokenizer.pad_token = tokenizer.eos_token # 将 pad token 设置为 eos token

# 初始化 GRPO Trainer
trainer = GRPOTrainer(
    model=model, # 传入模型
    processing_class=tokenizer, # 传入 tokenizer
    reward_funcs=[ # 传入奖励函数列表
        format_reward_func, # 格式奖励函数
        correctness_reward_func # 正确性奖励函数
    ],
    args=training_args, # 传入训练参数
    train_dataset=dataset, # 传入训练数据集
)

# 开始训练
trainer.train()

代码解释:

  • 导入库: 导入了 PyTorch, Hugging Face datasets, transformers, 和 trl 等必要的库。
  • 定义提示词和指令: R1_STYLE_SYSTEM_PROMPTTASK_SPECIFIC_INSTRUCTIONS 定义了用于引导模型生成特定格式答案的提示词。
  • preprocess_dataset 函数: 该函数负责加载 GSM8K 数据集,并将其处理成 GRPO 训练所需的格式,包括构建包含系统提示、示例对话和用户问题的 prompt,以及提取标准答案。
  • extract_xml_answer 函数: 用于从模型生成的 XML 格式文本中提取答案 (例如 <answer>42</answer>)。
  • 奖励函数:
    • format_reward_func: 判断模型输出是否符合 XML 格式要求,符合则奖励 1.0,否则奖励 0.0。
    • correctness_reward_func: 判断模型生成的答案是否与标准答案一致,一致则奖励 2.0,否则奖励 0.0。 该函数还会打印一些调试信息,方便观察训练过程。
  • 模型和 Tokenizer 加载: 使用 AutoModelForCausalLM.from_pretrainedAutoTokenizer.from_pretrained 加载预训练的 Qwen2.5-0.5B-Instruct 模型和 tokenizer。
  • GRPOConfig: 配置 GRPO 训练的各项参数,例如学习率、batch size、优化器、显存优化策略等。 其中,optim="adamw_8bit" 指定使用 8-bit AdamW 优化器,gradient_checkpointing=True 启用梯度检查点,vllm_init_kwargs 配置了 vllm 的相关参数。
  • GRPOTrainer 初始化: 使用配置好的模型、tokenizer、奖励函数、训练参数和数据集,初始化 GRPOTrainer
  • trainer.train(): 开始进行 GRPO 训练。

3. 模型评估

import torch # 导入 PyTorch
from datasets import load_dataset # 导入 datasets 库
from transformers import AutoTokenizer # 导入 transformers 库
from vllm import LLM, SamplingParams # 导入 vllm 库,用于快速推理
from tqdm.notebook import tqdm # 导入 tqdm,用于显示进度条
import numpy as np # 导入 numpy
from typing import List, Dict # 导入 typing,用于类型提示
import json # 导入 json,用于保存 json 文件
from datetime import datetime # 导入 datetime,用于获取当前时间
import logging # 导入 logging,用于日志记录

# 禁用 VLLM 的进度条,避免在 Notebook 中输出过多信息
logging.getLogger("vllm").setLevel(logging.WARNING)

# 从训练脚本中复制常量,保持评估时 prompt 格式一致
R1_STYLE_SYSTEM_PROMPT = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user
with the answer. The reasoning process and answer are enclosed within <reasoning> </reasoning> and
<answer> </answer> tags, respectively, i.e., <reasoning> reasoning process here </reasoning>
<answer> answer here </answer>."""

TASK_SPECIFIC_INSTRUCTIONS = "The answer must be a single integer."

# 复制 XML 答案提取函数
def extract_xml_answer(text: str) -> str:
    try:
        answer = text.split("<answer>")[-1].split("</answer>")[0].strip()
        return answer
    except IndexError:
        return ""

# 复制 Hash 答案提取函数 (用于提取 GSM8K 数据集中的标准答案)
def extract_hash_answer(text: str) -> str | None:
    try:
        return text.split("####")[1].strip()
    except IndexError:
        return None

# 评估模型函数
def evaluate_model(
    model_path: str, # 模型路径,checkpoint 路径
    batch_size: int = 4, # 评估 batch size
    num_samples: int = None, # 评估样本数量,None 表示评估整个数据集
    save_results: bool = True, # 是否保存评估结果到 json 文件
    gpu_memory_utilization: float = 0.3, # vllm 显存利用率,与训练时保持一致
) -> Dict:
    print("Initializing evaluation...") # 打印开始评估信息

    # 初始化 VLLM,并显示加载进度条
    with tqdm(total=2, desc="Loading model components") as pbar:
        llm = LLM( # 初始化 vllm LLM 对象
            model=model_path, # 模型路径
            dtype="half", # 数据类型为 half (float16/bfloat16)
            gpu_memory_utilization=gpu_memory_utilization, # 显存利用率
            max_model_len=768, # 最大模型长度
            device="cuda:0", # 设备为 cuda:0
            enable_chunked_prefill=True, # 启用 chunked prefill
        )
        pbar.update(1) # 更新进度条

        tokenizer = AutoTokenizer.from_pretrained( # 加载 tokenizer
            model_path, # tokenizer 路径,与模型路径一致
            model_max_length=768, # 最大长度
            padding_side='right', # padding 方向为 right
            truncation_side='right' # truncation 方向为 right
        )
        pbar.update(1) # 更新进度条

    # 设置 vllm 推理参数
    sampling_params = SamplingParams(
        temperature=0.0, # temperature 为 0,采用 greedy decoding
        max_tokens=512,  # 最大生成 tokens 数,与训练时的 max_completion_length 保持一致
        stop_token_ids=[tokenizer.eos_token_id], # 停止生成的 token id,这里设置为 eos token
    )

    # 加载测试数据集
    print("Loading dataset...") # 打印加载数据集信息
    dataset = load_dataset('openai/gsm8k', 'main', split='test') # 加载 gsm8k 测试集
    if num_samples: # 如果指定了评估样本数量
        dataset = dataset.select(range(num_samples)) # 则只选择指定数量的样本
    total_samples = len(dataset) # 获取总样本数
    print(f"Loaded {total_samples} samples") # 打印加载的样本数

    results = [] # 存储评估结果
    correct = 0 # 记录正确答案数量
    total = 0 # 记录总样本数量

    # 创建 tqdm 进度条
    progress_bar = tqdm(
        total=total_samples, # 总样本数
        desc="Processing samples", # 进度条描述
        unit="examples", # 单位为 examples
        dynamic_ncols=True, # 动态调整进度条宽度
    )

    progress_bar.set_postfix({ # 设置进度条后缀,显示 accuracy 和 correct 数量
        'acc': '0.00%',
        'correct': '0',
    })

    # 批量处理样本
    for i in range(0, total_samples, batch_size): # 遍历所有样本,步长为 batch_size
        batch_data = dataset[i:i + batch_size] # 获取当前批次的数据
        current_batch_size = len(batch_data['question']) # 获取当前批次的实际大小

        # 准备 prompts,格式与训练时保持一致
        prompts = [
            [
                {'role': 'system', 'content': R1_STYLE_SYSTEM_PROMPT + "\n" + TASK_SPECIFIC_INSTRUCTIONS},
                {'role': 'user', 'content': "What is 2+2?"},
                {'role': 'assistant', 'content': "<reasoning>To calculate 2+2, we simply add the numbers together: 2 + 2 = 4.</reasoning>\n<answer>4</answer>"},
                {'role': 'user', 'content': q.strip()}
            ] for q in batch_data['question']
        ]

        # 将 prompt 转换为 chat format,使用 tokenizer 的 apply_chat_template 方法
        formatted_prompts = [
            tokenizer.apply_chat_template(
                p,
                tokenize=False, # 不进行 tokenize,因为 vllm 内部会处理
                add_generation_prompt=True # 添加 generation prompt
            )
            for p in prompts
        ]

        # 使用 vllm 进行推理,生成 responses
        outputs = llm.generate(
            formatted_prompts,
            sampling_params,
        )

        # 处理每个 response
        for j, output in enumerate(outputs):
            response = output.outputs[0].text # 获取生成的文本 response

            # 提取生成的答案和标准答案
            generated_answer = extract_xml_answer(response) # 提取生成的 XML 格式答案
            true_answer = extract_hash_answer(batch_data['answer'][j]) # 提取标准答案

            # 存储结果
            result = {
                'question': batch_data['question'][j], # 问题
                'true_answer': true_answer, # 标准答案
                'generated_answer': generated_answer, # 生成的答案
                'full_response': response, # 完整 response
                'correct': generated_answer == true_answer # 是否正确
            }
            results.append(result) # 将结果添加到 results 列表中

            # 更新 metrics
            if generated_answer == true_answer: # 如果答案正确
                correct += 1 # 正确答案数加 1
            total += 1 # 总样本数加 1

        # 更新进度条
        progress_bar.update(current_batch_size) # 更新进度条进度
        progress_bar.set_postfix({ # 更新进度条后缀
            'acc': f'{(correct/total)*100:.2f}%', # 更新 accuracy
            'correct': f'{correct}/{total}', # 更新 correct 数量
        })

    progress_bar.close() # 关闭进度条

    # 计算最终 metrics
    accuracy = correct / total if total > 0 else 0 # 计算 accuracy
    metrics = { # 构建 metrics 字典
        'accuracy': accuracy, # accuracy
        'correct': correct, # 正确答案数
        'total': total, # 总样本数
        'model_path': model_path, # 模型路径
        'timestamp': datetime.now().isoformat() # 当前时间戳
    }

    # 保存评估结果到 json 文件
    if save_results: # 如果 save_results 为 True
        save_path = f"gsm8k_eval_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" # 构建保存路径
        with open(save_path, 'w') as f: # 打开文件
            json.dump({ # 将 metrics 和 results 写入 json 文件
                'metrics': metrics,
                'results': results
            }, f, indent=2) # 使用 indent=2 格式化 json 文件
        print(f"\nResults saved to {save_path}") # 打印保存路径

    return metrics # 返回 metrics 字典

print("Starting GSM8K evaluation...") # 打印开始评估信息
checkpoint_path = "outputs/Qwen2.5-0.5B-Instruct-GRPO/checkpoint-latest"  # 设置 checkpoint 路径,需要根据实际情况修改

# 运行评估函数
metrics = evaluate_model(
    model_path=checkpoint_path, # 传入 checkpoint 路径
    batch_size=4, # batch size
    num_samples=None, # 评估所有样本
    save_results=True, # 保存结果
    gpu_memory_utilization=0.3, # 显存利用率
)

print("\nFinal Evaluation Results:") # 打印最终评估结果
print(f"Accuracy: {metrics['accuracy']:.2%}") # 打印 accuracy,保留两位小数百分比
print(f"Correct: {metrics['correct']}/{metrics['total']}") # 打印 correct 数量和总样本数

代码解释:

  • 导入库: 导入了 PyTorch, datasets, transformers, vllm, tqdm 等库,用于模型评估。
  • 常量和函数复用: 复制了训练脚本中的 R1_STYLE_SYSTEM_PROMPT, TASK_SPECIFIC_INSTRUCTIONS, extract_xml_answer, extract_hash_answer 等常量和函数,确保评估时 prompt 格式和答案提取方式与训练时一致。
  • evaluate_model 函数: 该函数负责评估模型的性能。
    • 初始化 VLLM: 使用 vllm.LLM 初始化 vllm 模型,加载指定路径的 checkpoint,并配置显存利用率等参数。
    • 加载 Tokenizer: 加载与模型对应的 tokenizer。
    • 设置推理参数: 使用 vllm.SamplingParams 设置推理参数,例如 temperature (设置为 0 以进行 greedy decoding), max_tokens 等。
    • 加载测试数据集: 加载 GSM8K 测试集。
    • 批量推理: 循环遍历测试数据集,分批进行推理,使用 llm.generate 生成模型的 responses。
    • 答案提取和评估: 从模型 response 中提取生成的答案,与标准答案进行比较,计算 accuracy 等指标。
    • 结果保存: 将评估结果 (包括 metrics 和详细的 results) 保存到 json 文件。
  • 运行评估: 调用 evaluate_model 函数,传入训练好的模型 checkpoint 路径,开始进行模型评估。
  • 打印评估结果: 打印最终的 accuracy 和 correct 数量。

实验结果

运行评估代码后,您将会看到模型在 GSM8K 测试集上的评估结果,包括 Accuracy 和 Correct 数量。经过 GRPO 微调后,Qwen2.5-0.5B-Instruct 模型在 GSM8K 数据集上的准确率可以从 22.4% 提升到 48.6%,证明了该方法的有效性。

总结

本文详细解析了如何使用 GRPO 算法在单张 T4 GPU 上对 Qwen2.5-0.5B 模型进行全参数微调,并取得了显著的性能提升。 核心的优化策略包括使用优化的 TRL 库分支、8-bit AdamW 优化器以及显存分配限制等。 这些技术使得在资源受限的条件下训练大模型成为可能。

希望本文能够帮助读者理解大模型微调的原理和实践方法,并能够在自己的项目中应用这些技术,充分发挥大模型的潜力。


结语

感谢您的阅读!如果您在实践过程中遇到任何问题,欢迎留言交流。