• AGI
  • 《探索智能演化:GRPO训练在GRPO Llama-1B上的绚丽实验之旅》

在深邃的机器学习星空中,科研人员怀揣着对人工智能真谛的无限好奇,不断探索如何使模型“思考”得更像人类。今天,我们就来一场穿越代码与实验细节的奇幻旅行,揭开 GRPO(Generalized Reward Policy Optimization)训练在 GRPO Llama-1B 上的神秘面纱。这不仅是一段技术细节的堆砌,更是一场充满智慧火花与代码诗意的探索之旅。本文将以通俗易懂且风趣幽默的自然杂志风格,为你呈现这一过程中的每个精彩片段,从数据集准备、模型配置,到奖励函数的设计,再到硬件设备带来的神秘变量,我们将一一拆解。


🌍 起航:从问题到解决方案的起点

在许多 AI 实验中,如何引导模型做出正确推理、生成格式正确的答案是个永恒的命题。正如我们在剧本中看到的那样,GRPO Llama-1B 脚本中设定了如下系统提示(system prompt):

Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>

这段提示语不仅要求模型回答问题,更要求模型给出推理过程。它借助 XML 风格的标签 <reasoning><answer> 来明确区分模型思考的过程和最终答案。想象一下,这是一个魔法咒语,让模型在回答问题时不得不展示出思维轨迹,就像是在公开展示“心灵的演算过程”。

而为了在实验中验证这种提示策略的有效性,作者特别选用了 GSM8K 数据集——一个包含小学数学问题的大规模数据集,成为训练和测试模型推理能力的试金石。通过对这些问题进行细致剖析,科学家们不仅追求准确答案,更期待看到模型如何通过严谨的“思维”步骤,给出合理解释。


🧩 数据集的奥秘:从 GSM8K 开始

在 GRPO Llama-1B 的整个实验过程中,数据集的作用不可忽视。我们看到以下函数被用于加载 GSM8K 数据集:

def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split]
    data = data.map(lambda x: {
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    })
    return data

这段代码的精妙之处在于它不仅加载了数据,更重新构造了每个问题的输入格式。每个数据样本包含两个部分:一个系统提示(始终提醒模型按照特定格式回答)以及用户提出的问题。与此同时,正确答案通过提取函数 extract_hash_answer 被精确分离。这样一来,无论是训练还是测试模型时,输入输出都在严格的框架内运行,帮助模型建立统一的输出格式。

把它想象成烹饪一道精致大餐:原材料(GSM8K 的问题与答案)经过严谨的预处理(格式化成 prompt 和 answer),在“厨房”(训练脚本)中经过复杂工艺的烹饪,最终呈现出一道令人赞叹的“智能料理”。


🤖 模型与配置:细节决定成败

在这个实验中,模型的选择和配置同样令人瞩目。代码中展示了两条不同的逻辑分支,一支来自 Llama 模型,另一支则来自 Qwen 模型。这里的关键是:模型名称的选择直接影响了输出目录和运行名称。例如,在代码里这样书写:

model_name = "Qwen/Qwen2.5-1.5B-Instruct"
if "Llama" in model_name:
    output_dir = "outputs/Llama-1B-GRPO"
    run_name = "Llama-1B-GRPO-gsm8k"
else:
    output_dir = "outputs/Qwen-1.5B-GRPO"
    run_name = "Qwen-1.5B-GRPO-gsm8k"

这一判断逻辑确保了模型版本与所使用的硬件或参数调优方案对应。我们还看到,通过加载 AutoModelForCausalLM.from_pretrained 方法,模型被加载到 GPU 上,同时采用 torch_dtype=torch.bfloat16 参数来平衡精度和内存占用。而 attn_implementation="flash_attention_2" 则代表了作者在实验中对注意力机制的优化尝试,旨在加速生成过程并降低显存使用。这正如在高速列车上安装了新型空气动力学系统,让列车在运行过程中既快速又高效。


🔧 工具箱:奖励函数的魔法

在基于强化学习的策略训练中,奖励函数的重要性不言而喻。GRPO 训练通过多个奖励函数为模型输出打分,进而引导模型在对问题给出答案时兼顾准确性与格式规范。脚本中定义了多个奖励函数,例如:

  • correctness_reward_func :专注于答案的正确性。如果模型生成的答案与标准答案一致,则奖励为 2.0,否则为 0.0。具体实现中,从模型输出中提取 <answer> 标签内的内容,再对比答案是否一致。如果答案形式正确,就会获得奖励 2.0,否则毫发无损地给予 0 分。

  • int_reward_func :这个函数主要检测模型生成的答案是否为数字,采用方式检查字符串是否纯数字。如果是,则奖励 0.5,否则没有奖励。这样的设计正好适用于 GSM8K 中大部分答案都是整数的情况。

  • strict_format_reward_funcsoft_format_reward_func 则着眼于输出文本格式:前者检查输出文本是否完全符合预期的 XML 格式,后者则允许一定的自由度、较宽松匹配正确格式。代码中利用正则表达式来匹配文本,这个过程如同考古学家细致检查文物的每一处细节,确保其符号完整无缺。

  • xmlcount_reward_func 则用来计算输出文本中 XML 标签的数量,奖励依据标签出现的次数与位置是否正确。不过,这个打分机制还引入了惩罚因子,用来抵消重复标签的多余奖励。通过下面的公式描述:

    count=δstart×0.125+δend×0.125+δanswer_start×0.125+δanswer_end×0.125penaltycount = \delta_{\text{start}} \times 0.125 + \delta_{\text{end}} \times 0.125 + \delta_{\text{answer\_start}} \times 0.125 + \delta_{\text{answer\_end}} \times 0.125 - \text{penalty}

    其中每个 δ\delta 用以确认特定标签是否唯一出现,而 penalty 则根据额外分隔符的数量进行微调。这样的设计灵感来自计分板上分值的精确计算,任何多余的分割符都将被扣分,确保模型输出的“文章”严谨规范。

此外,代码中的 extract_xml_answer 函数用来从模型生成的长文本中提取出 <answer> 标签之间的数据,这个过程基于字符串分割技巧实现,非常巧妙地从冗余文本中筛选出关键信息。

试想,整个流程犹如一位严格的评委,用多个标准对选手的表现进行打分:除了答案是否正确外,还考察思路是否条理清楚、格式是否严谨。多维度打分使得训练过程更具层次感,很大程度上帮助模型在训练过程中快速收敛,提升推理准确度和输出格式的一致性。


🔬 强化学习的灵魂:GRPO 策略及其超参调优

强化学习(RL)在自然语言处理中的应用正日益火热,而 GRPO 就是其中一颗冉冉升起的新星。其核心思想在于利用奖励信号来优化策略,逐步调整模型生成的答案,推动模型在保持正确答案的基础上,还能“学会”如何解释自己的推理过程。GRPO 配置被封装在 GRPOConfig 中,其参数设计反映了对细节的极致追求,例如:

  • 学习率(learning_rate)设为 5×1065 \times 10^{-6},相对较低,确保更新过程足够平稳。
  • adam_beta1adam_beta2 分别设为 0.9 和 0.99,确保优化器在权重更新时平滑而有效。
  • weight_decaywarmup_ratio 分别引入正则化和学习率调度,帮助模型在训练初期缓慢适应,避免过拟合。
  • num_generations 指定了生成多少个候选答案,模型在内部进行自我对比,从而挑选最佳答案,这一设计体现了“多视角评估”的训练理念。

通过调整这些超参数,研究人员可以在训练中观察模型数值输出的变化,例如在一段讨论中有研究者提到,将 beta 参数调整到 0.01 后,模型在 GSM8K 测试集上的得分能够从 41.6% 提升到 51% 以上。这里的 beta 指的是 KL 散度系数,用于平衡模型生成分布与真实分布之间的距离。公式可以简单表示为:

KL Loss=β×DKL(PQ)\text{KL Loss} = \beta \times D_{\text{KL}}(P \parallel Q)

其中 PP 为模型当前生成的分布,QQ 为参考分布,β\beta 则调控了奖励信号的强度。正是这种精细调节,使得 GRPO 能够迅速收敛,并在小样本数据集上表现出惊人的改善。

另外,为了进一步优化训练过程,不少研究者还尝试通过第三方工具 Optuna 进行超参调优,自动寻找最佳奖励函数门槛和权重设置。通过动态调整奖励函数中的权重(例如 xmlcount_weight, soft_format_weight, strict_format_weight 等),研究人员不仅能在训练过程中观察到模型性能的细微变化,还能通过大规模试验自动化地寻找全局最佳参数组合,确保训练过程的高效稳定。


💻 硬件与软件的变奏曲:从 H100 到 7900XTX 的奇妙差异

近年来,在硬件性能不断突飞猛进的今天,模型训练不仅仅受限于算法本身,还深受底层硬件、软件库版本影响。GRPO Llama-1B 的实验记录中,讨论了在 H100 与 7900XTX 两种 GPU 之间训练结果的差异。对同一代码,不同硬件可能会因为库的版本差异而产生意想不到的结果。例如,研究者发现 H100 与 7900XTX 分别采用了不同版本的 TRL 库:

  • 7900XTX 使用的是 trl @ git+https://github.com/huggingface/trl.git@2ce36ae889f286dad91dc3ac6b55904864bf9254
  • 而 H100 则采用 trl @ git+https://github.com/huggingface/trl.git@1c35a48b50f54b92c6b820437aaf75c4e3d777ce

这一版本差异可能会引起意想不到的训练行为,如奖励函数的数值反馈不同,生成文本格式略有偏差等。正如文章中一位研究者调侃:“我还没弄明白哪个包出了 bug,但显然在两台机器上,模型的表现天差地别。”这种现象提醒我们,在科研的道路上,每一个小细节都可能成为影响实验结果的重要因素。

此外,研究人员对于同一代码在多 GPU 与单 GPU 环境下的表现也进行了对比,有人成功在单 GPU 上运行了 peft 调优方案,也有人面临多 GPU 分布式训练中不可调和的问题。对此,社区内讨论激烈,不乏一些调试技巧,如在部分版本环境中禁用 peft,以确保所有层都得到更新;或通过调用 vLLM 直接进行生成,从而规避分布式训练时可能存在的种种坑。正是这种开放的科研讨论,使得整个社区不断进步,激发出无数解决方案和调试策略。


🎯 实验数据与性能评估:从 41.6% 到 51% 的飞跃

实验数据永远是检验理论最有力的证据。通过在 GSM8K 数据集上对比基线模型与经过 GRPO 强化训练后的模型,有学者报道在一些参数优化后,模型在测试集上的准确率从 41.6% 一举提升至 51% 左右。这一 10% 多的绝对提升,虽然看似微小,但在 AI 领域中,无疑是一个令人振奋的进步。尤其是在面对千上万道题目时,这种提分效果可以大幅提高模型实际应用中的可靠性和表现。

从代码角度来看,评价函数 evaluate_model 被设计成对每个测试样本生成候选答案,并提取其中 <answer> 的部分,再与标准答案精确对比。其核心逻辑如下:

def evaluate_model(model, tokenizer, dataset):
    model.eval()
    correct_predictions = 0
    total_predictions = 0
    for i in range(len(dataset)):
        sample = dataset[i]
        prompt = sample['prompt']
        true_answer = sample['answer']
        inputs = tokenizer.apply_chat_template(prompt, return_tensors="pt").to("cuda")
        with torch.no_grad():
            outputs = model.generate(inputs, max_new_tokens=786, pad_token_id=tokenizer.eos_token_id)
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        predicted_answer = extract_xml_answer(generated_text)
        if predicted_answer == true_answer:
            correct_predictions += 1
        total_predictions += 1
    accuracy = correct_predictions / total_predictions
    return accuracy

这一段代码不仅体现了文本生成过程,还采用了贪婪解码(greedy decoding)的策略,为了确保每个生成答案都在固定格式下被正确解析。这样的测试过程真实反映了在实际生产环境中,模型如何处理复杂问题,并将每个数字、每个符号都精准还原,犹如一部精密运转的机械装置。

同时,为了不断改进模型性能,研究者们还使用了 Optuna 这类自动化超参调优工具,对整个训练过程中的各项奖励权重与优化超参进行全局搜索,从而为 GRPO 模型找到最佳运行配置。Optuna 的引入为实验注入了自动探索的基因,使得奖励函数的每个参数都能在多次试验中找到最优‘平衡点’。这种自动探索方法无疑为未来大规模自适应 AI 训练提供了无限可能。


🌟 科学故事与幽默插曲:实验室里的奇闻轶事

科研的道路往往充满坎坷,但也少不了有趣的插曲。在 GitHub 评论区中,研究者们不仅讨论着代码的细节,还调侃着硬件环境的奇异表现:

“我在 H100 上遇到的问题太多了,简直无法调试;而 7900XTX 则仿佛一台‘古董’,越老越有韵味……”

有人甚至调侃说:“在多 GPU 时代,让我怀疑到底是算法让模型变聪明了,还是显卡黑科技在捣鬼。”这种调侃不仅活跃了严肃的讨论氛围,也让每一位忙碌于代码调试的工程师在实验室的长夜里找到了一丝慰藉和笑意。

每一次调试,每一次参数的微调,都仿佛是一场探险,仿佛技术人员在为未知领域开辟新路。正如一位评论者所言:“构建工具本身就是一门艺术,而用工具构建工具,更像是一场没有终点的马拉松。”

有趣的是,在实验过程中,奖励函数也会受到无意间的“副作用”影响。有时,生成文本中会出现越来越多的“(\boxed{ANSWE})”,这是代码中没有明确奖励的部分,却意外成为模型训练过程中的“隐形幽灵”。有位研究者便提出疑问:“为什么会出现越来越多的 (\boxed{ANSWE}) ?代码对这个没有任何奖励,却在训练过程中越来越多?”这样的疑问既引发了讨论,也彰显了科学探索过程中总有新问题等待我们去挖掘和解答。


📊 图表与示例:用数据讲故事

为了更直观地理解实验过程,我们可以构建一些 Markdown 格式的表格,来展示不同设置下的模型表现。以下是一张简单的对比表:

模型版本基线准确率GRPO 后准确率调整参数说明
Qwen2.5-0.5B 原始模型41.6%未经过 GRPO 强化训练
Qwen2.5-0.5B + GRPO51% 左右将 beta 调整为 0.01,学习率 2e-6,生成长度限制 512
Llama-1B + GRPO待实验不同基础模型下的 GRPO 尝试

另一端,我们也可以用伪代码流程图来展示整个训练流程:

【数据处理】
      │
      ▼
[加载 GSM8K 数据集]
      │
      ▼
[构造系统 prompt 和用户提问]
      │
      ▼
[模型生成答案]
      │
      ▼
[提取 <answer> 标签内容]
      │
      ▼
[计算各项奖励函数得分]
      │
      ▼
[梯度反向传播与模型更新]

这样的流程图让我们更直观地见识到,从数据导入到模型生成答案,再到奖励信号作用下的梯度更新,每一步都是精心设计与高效运行的重要环节。


🔍 解析代码:关键细节逐一剖析

翻开这段 GRPO 训练代码,每一行都承载着研发者无数个挑灯夜战、思索优化的痕迹。

  1. 第一部分导入了必备的库:正则表达式(re)、PyTorch、datasets、transformers、peft、trl 等。这些库各司其职,共同构筑了一个完整的自然语言处理训练流水线。

  2. 数据预处理部分,利用 load_dataset 加载 GSM8K 数据集,并使用 map 函数重构数据格式。这里的 extract_hash_answerextract_xml_answer 函数,分别用来提取正确答案和从生成文本中解析出的答案。这种设计正如给模型设置了一个双重保险机制,既能精准评估模型输出,也可以核对模型“思维”的轨迹。

  3. 奖励函数部分尤其值得一提。每个奖励函数都采用了不同的策略:

    • 如果模型输出完全正确,correctness_reward_func 返回 2.0 的奖励;
    • int_reward_func 则检查输出是否为数字,利用 .isdigit() 方法给予奖励 0.5;
    • strict_format_reward_funcsoft_format_reward_func 则通过正则表达式判断输出格式,前者要求严格符合预定格式,后者则允许一定自由度;
    • xmlcount_reward_func 通过计算 XML 标签的出现频率和正确性,来细化打分。整个奖励函数体系的设计,恰似一场精密的“评分表”,每项细节都事关模型最终学习的方向。
  4. 模型加载部分,通过 AutoModelForCausalLM.from_pretrained 加载预训练模型,并在 GPU(cuda)上运行。这里也展示了如何设置诸如 torch_dtype=torch.bfloat16 以及 attn_implementation="flash_attention_2" 等参数,确保训练过程中能兼顾高效与准确率。

  5. 最后,GRPO 训练器 GRPOTrainer 的初始化,以及调用 .train() 开始训练的那一刻,就像拉开了通向未来智能的一道序幕。整个过程不仅仅是代码在运行,更是无数科学家智慧与激情碰撞的闪光点。


🚀 前沿展望:从 GRPO 到未来 AGI

随着实验的不断深入,我们可以看到 GRPO 训练在自然语言处理和强化学习中的巨大潜能。对比传统的直接监督学习(Supervised Fine-Tuning, SFT),GRPO 通过奖励函数对模型推理过程进行干预,有效提高了生成答案的合理性与准确性。正如一位资深工程师所言:“调整奖励函数,就像调整万有引力常数,每一个微小的变化都可能决定整个宇宙的走向。”

事实上,科研社区对 GRPO 的讨论不仅局限于实验参数,还有关于如何利用多模型反馈来进行自我进化的思考。有人提出,通过多样化模型(如 Qwen、Llama 等)之间的协同作用,我们有可能构建起更为强大的人工智能系统。未来的 AGI(人工通用智能)或许正是由这种多样化协作与细致奖励调节构成的——一个由无数微小改进累积而成的宏大奇迹。

同时,Optuna 这样的超参调优工具的引入,使得训练过程更加自动化、科学化。将来的研究中,我们不仅可以利用固定的超参数进行实验,更可以依托自动调优算法,在庞大参数空间中寻找最优组合。这样一来,研究工作将从手动调参的繁琐中解放出来,转而更加专注于算法原理与理论探讨。

GRPO 的出现也使得我们对模型“思考”的过程有了更深层次的认识。通过训练模型同时给出推理步骤与最终答案,我们看到的不仅仅是机器喷薄而出的单一结果,而是一幅完整的思维地图。这样的进步,正是向“可解释人工智能”迈进的重要一步,为未来 AI 系统的透明性与信任机制打下了坚实基础。


🎈 结语:代码与科学背后的无限探索

一路走来,我们从简单的系统提示起步,到复杂奖励函数的设计,从 GSM8K 数据集的精细处理到超参调优的自动化探索,无不展示了科研人员在构建模型训练体系过程中所付出的心血与智慧。GRPO Llama-1B 实验不仅为 AI 训练领域注入了新的思路,更让我们看到在未来 AGI 的道路上,每一段代码都将成为通向未知世界的一把钥匙。

正如那句古老的谚语所说:“万丈高楼平地起。”每一个细微的改进、每一次参数调优,就像是垒起一砖一瓦,最终构筑出智能时代的宏伟大厦。我们相信,在众多热衷于此领域的科研者不断努力下,不仅能使模型变得更加准确、思维更加透明,还能让“人机共生”的未来真正走进每一个角落。

无论你是刚踏入 AI 世界的初学者,还是在前沿领域苦心钻研的专家,这段 GRPO 实验的历程都给我们带来了无限启示和鼓舞。或许明天,就在你我手中运行的一行代码中,便蕴藏着通向未来智能奇迹的奥秘。让我们共同期待,探索下一个令人心跳加速的创新时刻!


📚 参考文献

  1. willccbb. GRPO Llama-1B · GitHub. https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb
  2. Hugging Face TRL 文档. https://huggingface.co/docs/trl
  3. GSM8K 数据集介绍. https://openai.com/gsm8k
  4. Optuna 超参数调优工具. https://optuna.org/
  5. 相关讨论及实验分享 (X/Twitter 讨论帖子)

GRPO 与 GSM8K 训练:代码解析与社区讨论综述

近年来,针对大语言模型如何通过强化学习(RL)进行微调,从而提高回答准确率和格式一致性,成为了社区关注的热点。本文将结合 GRPO Llama-1B 的示例代码 以及讨论社区中的交流,对代码整体结构、奖励函数设计、训练流程以及实际应用中的问题与优化策略进行详细梳理和解析。


1. 背景介绍

在该示例代码中,作者使用了 GRPO(可理解为一种基于奖励优化(Reward Policy Optimization)的算法)对 GSM8K 数据集进行强化学习调优。目标是让模型输出符合预定格式的回答,同时确保答案的正确性。本次调优主要依赖于 Hugging Face 的 trl 库,并使用了 PEFT(参数高效微调)和 Flash-Attention 等优化组件来加速模型训练。


2. 数据预处理与提示模板

2.1 系统提示与答案格式

代码中定义了一个固定的系统提示(SYSTEM_PROMPT):

“Respond in the following format:
<reasoning> ... </reasoning>
<answer> ... </answer>

这一提示起到了规范输出格式的作用,要求模型在回答时输出推理过程和最后答案,并分别用 XML 风格的标签 <reasoning><answer> 包裹。除了系统提示外,还定义了一个模板 XML_COT_FORMAT,用于格式化链式思考(Chain-of-Thought,CoT),该模板将推理过程与答案嵌入到预定格式中。

2.2 数据集构造

数据集使用的是 OpenAI 的 GSM8K——一个数学问题数据集。函数 get_gsm8k_questions 加载了数据集的指定 split(如“train”或“test”),并将每个数据样本处理成一个字典格式,其中包括:

  • prompt:一个对话列表,最开始是系统角色提供的格式提示,然后由用户角色给出具体问题。
  • answer:通过辅助函数 extract_hash_answer 将原答案文本(多以 “####” 标记)提取出来。

这种数据预处理方式不仅保证了数据的一致性,同时也为后续强化学习过程中使用奖励函数评估输出提供了标准化输入。


3. 奖励函数设计

在 RL 调优中,奖励函数(Reward Function)起到了至关重要的作用。该代码中设计了多种奖励函数来衡量模型输出与预期结果之间的差距,具体包括:

3.1 正确性奖励

  • correctness_reward_func
    该函数首先提取用户问题最后的 prompt,然后解析模型输出中的 <answer> 部分,并与正确答案(经过预处理后的结果)进行严格对比。如果答案完全一致,则给予较高奖励(例如 2.0 分),否则返回 0 分。
    此外,函数会打印调试信息,展示问题、目标答案、模型完整输出和从中提取的答案,便于检查可能产生的问题。

3.2 格式奖励

为了鼓励模型输出符合预期的 XML 格式(即同时包含 <reasoning><answer> 标签),设计了两种奖励函数:


  • strict_format_reward_func
    使用严格的正则表达式(regex)检测输出内容是否精准匹配预定格式(包括换行符等细节),匹配成功给予固定奖励(例如 0.5 分)。

  • soft_format_reward_func
    使用相对宽松的正则表达式,允许一定的变动性,但依然要求输出包含 <reasoning><answer> 标签。匹配成功同样给予奖励。

3.3 数字格式奖励

  • int_reward_func
    针对 GSM8K 数据集中答案本质上为正整数,该函数检查提取出来的答案是否为数字,若是则给予一定奖励(例如 0.5 分),否则为 0。

3.4 XML 内部结构奖励

  • xmlcount_reward_func
    函数内部通过 count_xml 对输出文本进行简单计数:
    • 检查是否恰好出现 <reasoning>\n\n</reasoning>\n\n<answer>\n\n</answer>\n 标记,并相应地加分;
    • 对于额外的或者不合理的分段,则进行轻微扣分。

这种设计鼓励模型在回答时严格遵守预定格式,但同时也允许一定的灵活性。


4. 模型加载与训练配置

4.1 模型加载

代码中根据模型名称(例如是否包含 “Llama” 来区别使用模型文件夹)确定输出路径,并通过 Hugging Face 的 AutoModelForCausalLM 与相关配置加载模型。本示例中默认使用的是 Qwen 系列模型,如 "Qwen/Qwen2.5-1.5B-Instruct",同时设置模型的量化类型(torch.bfloat16)并采用 Flash-Attention 来加速计算。

4.2 Tokenizer 配置

加载 Tokenizer 后,将其 pad_token 设为模型的 eos_token,保证输入长度对齐,适应生成任务要求。

4.3 训练器与 GRPO 配置

利用 GRPOConfig 对训练超参数进行设定,例如:

  • 学习率(例如 5e-6)
  • 权重衰减、warmup 比例、梯度积累步数
  • 每个 batch 的样本数、每次生成的样本数(num_generations)等
    此外,还设置了多种奖励函数,将它们传递给 GRPOTrainer。这样在训练过程中,模型不仅通过语言建模损失进行更新,还会依据奖励函数反馈进一步优化输出,使其同时满足正确性与格式要求。

4.4 可选的 PEFT 配置

代码中使用的 LoRA(低秩适配)配置(当前被注释掉)则为减少微调过程中模型参数更新量提供了可能。这部分配置在大规模多 GPU 训练场景下曾经引起一些兼容性问题,相关讨论在社区中也不断被提及。


5. 社区讨论与优化探索

在代码的 GitHub 讨论区,社区成员就以下几个问题展开了热烈讨论:

5.1 同一代码在不同硬件平台上的差异

  • H100 vs. 7900XTX
    部分评论提到,同一份代码在 H100 与 7900XTX 上运行时,生成结果存在差异。从进一步对比 pip freeze 中的 TRL 代码版本(不同 commit,如 1c35… 与 2ce36…),可以看出库内更新可能引入了细微差异。社区成员也讨论了如何锁定特定版本并调试潜在的兼容性问题。

5.2 奖励函数参数与超参数调优

  • 提到类似 “beta” 的参数(实际上指的是 KL 系数)对训练效果的影响。通过调整诸如学习率、warmup_ratio 以及生成样本数等超参数,部分团队在 GSM8K 上的准确率得到了显著提升(例如从 41.6% 到 51% 左右)。
  • 有成员尝试通过 Optuna 自动化调整整个奖励函数的权重(例如修改 correct_reward、strict_match_reward 等),使得训练过程免受手动调节困扰,从而进一步挖掘出最佳组合。这种方法为奖励设计提供了灵活性,并帮助找到更优的超参数组合。

5.3 格式化输出问题

  • 社区中有讨论为什么模型有时会输出诸如 “(\boxed{ANSWE})” 的字符,这种内容并非代码中设计奖励时关注的格式部分。这提示开发者在奖励设计时,需要额外考虑如何屏蔽与奖励无关的符号,确保奖励反馈更准确。

5.4 vLLM 的应用场景讨论

  • 部分用户分享了如何在单 GPU 或多 GPU 环境下使用 vLLM 加速生成,有助于快速评估模型在推理端的输出质量。利用 vLLM 可以节省时间,使得周期性评估(例如使用 TrainerCallback)成为可能。

6. 总结与展望

本文通过对 GRPO 调优代码的详细解析,我们可以看出:

  • 使用预定义格式化提示和严格的奖励函数设计,有助于引导模型在回答问题时做到既准确又格式统一。
  • 多种奖励函数的组合(正确性、格式、数字检查以及 XML 内部结构计数)可以针对具体任务(如 GSM8K 数学问答)提供多维度优化。
  • 社区在相同代码基础上,通过不同硬件环境的测试、微调超参数以及引入 Optuna 自动调参,均在探索如何进一步提高模型性能。这些讨论为同类项目提供了宝贵的实践经验和优化思路。

未来,随着社区在 RL 调优和奖励设计上的不断探索,或许能发现更自动化、鲁棒性更强的方法来统一和提升大语言模型的输出质量。开源社区的力量也将在这样的试验和探索中发挥越来越重要的作用。


下面给出对grpo_demo.py文件的详细解读,帮助你理解每个部分的作用。


1. 文件总体作用

该文件主要用于使用 GRPO(Graded Reward Policy Optimization)算法对因果语言模型(Causal LM)进行强化学习式微调。目标是让模型在回答问题时满足指定格式(例如,通过 XML 标签将推理过程和答案分隔开来),同时通过一系列奖励函数对生成结果进行打分,从而促使模型输出既符合格式要求又能正确回答问题的内容。


2. 模块导入与常量定义

首先导入了所需的库和模块:

  • 正则表达式模块 (re):用于验证输出格式。
  • PyTorch (torch):支持 GPU 加速以及张量计算。
  • Datasets 库 (load_dataset, Dataset):用于加载 GSM8k 数据集。
  • Transformers 模块 (AutoTokenizer, AutoModelForCausalLM):用于加载预训练模型和分词器。
  • PEFT 模块 (LoraConfig):支持 LoRA(低秩适应)微调方法。
  • TRL(Transformer Reinforcement Learning)模块 (GRPOConfig, GRPOTrainer):用于强化学习策略微调。

此外,还定义了两个常量:

  • SYSTEM_PROMPT
    一个系统提示字符串,明确要求模型按照特定 XML 格式(包含 <reasoning><answer> 标签)来回复答案。

  • XML_COT_FORMAT
    一个格式化模板,用于将推理过程和答案包装成 XML 样式输出。


3. 数据处理函数

3.1. 提取 XML 格式答案

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()
  • 作用:从模型的输出文本中提取 <answer> 标签内部的答案部分,该函数假定输出文本严格包含 <answer></answer> 标记。

3.2. 提取 Hash 标记的答案

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()
  • 作用:用于从数据集中提取答案。部分数据可能使用 “####” 将答案包裹起来,函数提取 “####” 后面的内容,如果不存在 “####” 则返回 None

3.3. GSM8k 数据集加载与预处理

def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: {  # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            #{'role': 'assistant', 'content': XML_COT_FORMAT.format(
            #    reasoning="9 is divisble by 3 and 8 is divisible by 2, but 7 is prime.",
            #    answer="7"
            #)},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore
  • 作用

    1. 加载 gsm8k 数据集(主要为数学推理题目)。
    2. 使用 map 方法遍历数据,将每个样本预处理成一个包含 promptanswer 的字典:
      • prompt 为一个消息列表,其中包含系统角色(给出回答格式要求)和用户角色(问题内容)。
      • answer 则利用 extract_hash_answer 将原始答案进行抽取。
  • 备注:部分代码被注释掉,注释内展示了如何使用 1-shot 提示(通过提供一个示范问答)来引导模型,但目前默认仅使用单一用户问题。


4. 奖励函数设计

为了引导模型产生符合要求的回答,文件中定义了多个奖励函数,每个函数根据输出的不同特性给予奖励分数。

4.1. 正确性奖励函数

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
  • 作用
    对生成的答案进行正确性检验:
    • 首先从生成的输出中提取 XML 格式内的答案。
    • 将提取的答案与标准答案进行比较,相同则奖励 2.0 分,不同则奖励 0 分。
    • 同时打印调试信息,便于查看问题、标准答案和模型生成的完整输出细节。

4.2. 数字格式奖励函数

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
  • 作用
    检查模型输出的答案是否完全由数字组成。如果是,则奖励 0.5 分;否则不奖励。

4.3. 严格格式奖励函数

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses] 
    return [0.5 if match else 0.0 for match in matches]
  • 作用
    利用正则表达式严格校验生成文本的格式,要求:

    • 文本以 <reasoning> 开始,紧跟换行和非贪婪模式文本,再到 </reasoning>

    • 接着 <answer> 标签及相应内容,最后以 </answer> 结束。


      若完全匹配指定格式则奖励 0.5 分,否则奖励 0 分。

4.4. 宽松格式奖励函数

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses] 
    return [0.5 if match else 0.0 for match in matches]
  • 作用
    与严格格式不同,这里使用一个宽松一点的正则表达式,只要包含 <reasoning><answer> 标签及内容即可,奖励同样为 0.5 分。

4.5. 基于 XML 标签计数的奖励函数

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count
  • 作用
    以更细粒度的方式检查模型输出中的 XML 标签是否正确:
    • 对每个关键标签(<reasoning>\n\n</reasoning>\n\n<answer>\n\n</answer>)各自计分,保证恰好出现一次。
    • 同时如果在答案部分存在额外字符,则按字符长度扣分,用以惩罚格式不严谨的回答。

随后函数 xmlcount_reward_func 会对完整生成结果应用 count_xml 得分:

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

5. 模型、训练配置及训练过程

5.1. 模型名称和输出目录配置

#model_name = "meta-llama/Llama-3.2-1B-Instruct"
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

if "Llama" in model_name:
    output_dir = "outputs/Llama-1B-GRPO"
    run_name = "Llama-1B-GRPO-gsm8k"
else:
    output_dir="outputs/Qwen-1.5B-GRPO"
    run_name="Qwen-1.5B-GRPO-gsm8k"
  • 说明
    文件中可以选择不同的预训练模型(例如 Llama 或 Qwen)。当前选用的是 Qwen 模型,因此训练输出目录和运行名称均以 “Qwen” 命名。

5.2. 训练参数配置

training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,
    max_prompt_length=256,
    max_completion_length=786,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
)
  • 作用
    利用 GRPOConfig 定义训练过程中的超参数,包括:
    • 学习率Adam 优化器的 beta 参数权重衰减
    • warmup比例cosine 学习率调度器
    • 混合精度(bf16)
    • 批量大小梯度累积步数
    • 每个 prompt 生成回答时的生成样本数(num_generations:16);
    • 最大 prompt 与回答长度;
    • 训练轮数、保存步数以及梯度裁剪等配置;
    • 日志输出(例如报告到 wandb)。

5.3. LoRA 参数配置(目前未启用)

peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)
  • 说明
    如果使用 PEFT 进行低秩适应微调,可以通过此配置定义:
    • 低秩维度 r、比例系数 lora_alpha
    • 针对模型中的哪些模块(如查询、键、值投影等)进行 LoRA 应用;
    • 指定任务类型和 dropout 概率。
      需要注意的是,在后续初始化 GRPOTrainer 时,此配置被注释掉,说明当前没有启用 PEFT。

5.4. 模型和 Tokenizer 加载

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=None
).to("cuda")
        
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
  • 作用
    • 使用 AutoModelForCausalLM 从 Hugging Face Hub 加载预训练模型,并指定使用 bfloat16 数据类型以及 flash attention 优化(若支持)。
    • 加载对应的分词器并将 pad_token 设置为 eos_token,防止模型输入时因缺少 pad_token 而报错。

5.5. GRPOTrainer 实例化与训练

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func],
    args=training_args,
    train_dataset=dataset,
    #peft_config=peft_config
)
trainer.train()
  • 作用
    • 实例化 GRPOTrainer 时传入:
    • 模型分词器(此处分词器作为 processing_class,负责数据预处理)。
    • 奖励函数列表:包括对格式(XML 标签)和答案正确性等多维度的奖励函数,组成混合奖励信号。
    • 训练参数training_args)与预处理后的 训练数据集dataset)。
    • 可选的 peft_config(当前被注释掉)。
    • 通过调用 trainer.train() 启动训练过程。

总结

该文件通过以下几个步骤实现了强化学习式的微调流程:

  1. 数据准备
    从 GSM8k 数据集中提取问题及答案,同时添加系统提示以强制模型输出带有 <reasoning><answer> 标签的格式。

  2. 奖励函数设计
    定义多个奖励函数,从格式正确性、答案是否为整数字、XML 标签的完整性及答案正确性等多个角度评分,帮助模型逐步调整输出格式与内容。

  3. 训练配置
    基于 GRPOConfig 指定训练中的所有超参数,并加载预训练模型与分词器,保证训练过程在 GPU 上高效运行。

  4. 训练过程
    使用 GRPOTrainer 将模型、数据、奖励函数、训练参数等组合在一起,开始强化学习式微调,从而使模型在回答问题时既满足特定格式,又尽可能给出正确答案。

通过这些步骤,文件展示了一种基于奖励信号引导预训练语言模型生成结构化、规范化回答的方案,并且可以通过调整奖励函数和超参数来个性化训练效果。

沪ICP备2024052574号-2