安爸-超级家庭

极客说|强化学习(RL)与有监督微调(SFT)的选择以及奖励函数的优化

安爸 发布于

作者: 魏新宇 – 微软 AI 全球黑带高级技术专家

排版:Alan Wang

「极客说」 是一档专注 AI 时代开发者分享的专栏,我们邀请来自微软以及技术社区专家,带来最前沿的技术干货与实践经验。在这里,您将看到深度教程、最佳实践和创新解决方案。关注「极客说」,与行业顶尖专家一起探索科技的无限可能!投稿请联系:17278094563(微信号)

本文首先将阐述强化学习(RL)和监督微调(SFT)在实现方式上的区别,然后通过一个具体案例,详细说明如何对奖励函数进行优化。

从简单例子入手理解 SFT 和 RL

监督微调(SFT)- 像老师教学生

监督微调(Supervised Fine-Tuning,简称 SFT)相当于作为老师,自己先列出很多问题,再告诉模型标准的回答,比如用数据(训练集)教它:

问题 标准答案
1加1多少? 2
苹果什么颜色? 红色
太阳从哪边升起? 东方

我们让模型一遍又一遍模仿训练语料中的标准答案,直到我们符合要求。

SFT 具体步骤(算法的介绍)

  1. 我们拿出一个问题:苹果什么颜色?
  2. 模型自己尝试回答:比如它乱回答成 蓝色
  3. 我们就立马纠正,告诉它正确的答案应该是红色,给它一个明确的误差信号: [ 误差 = – log P(“红色”) ]
  4. 然后模型用这个误差信号帮助它更新自己说法,让下次“红色”概率增加。

所以,监督学习过程如下:

for 问题, 标准答案 in 数据集:     模型答案 = 模型生成(问题)     误差 = 计算交叉熵Loss(模型答案, 标准答案)     模型更新(误差)

优点:安全、稳定

缺点:模型永远只能模仿,不太能创造性地发现新答案。

强化学习(RL)– 让模型自己摸索

强化不直接教标准答案,而是用“鼓励”和“惩罚”引导模型。

我们问模型:“1加1等于?”

  • 它如果乱说了:“香蕉!”,我们立刻给个负面奖励(-1);
  • 如果它说对了:“2”,我们给它正面奖励(+2)。

模型得到这些奖励和惩罚之后,会慢慢去摸索和记忆,知道怎么才能得到更多奖励(而不是直接告诉它标准答案)。

强化学习大致算法:

# RL过程: for 问题 in 数据集:     # 让鹦鹉自由生成多个答案(探索)     多个答案 = 模型生成多个可行答案(问题)      # 每个答案给奖励     for 每个答案 in 多个答案:         奖励 = 奖励函数(每个答案)         更新策略(奖励 * log(生成该答案概率))

优势:模型能够自己发现最优策略,能主动“探索”,学得更主动;

危险:但探索过猛容易产生 KL 爆冲、梯度爆炸、最终模型崩盘。

方面 监督微调(SFT) 强化学习(RL)
本质 模仿老师标准答案 只靠鼓励&惩罚自己摸索
学习速度 快,稳定 慢,可能反复波动
创造性 低,比较死板 高,能探索创造
数据需求 标准答案必须充足 只需奖励的反馈信号
常见问题 很少出现大的稳定性问题 易出现 KL 爆冲→ 梯度爆炸→ 崩盘

SFT 和 RL 选择

大多数情况下训练模型先 SFT 再 RL 更安全、更高效,尤其是对能力尚弱的小模型或需要严格格式输出的任务。不过这并不是绝对法则,下面补充几点可作为快速校验的要点。

为什么“先 SFT 后 RL”通常更好

  • 训练稳定性
    • 直接 RL(尤其是小模型)容易出现 KL 爆冲、梯度爆炸,模型甚至崩盘。
    • SFT 先把策略锚定在“基本正确、格式合规”的空间,再让 RL 微调,KL 跳变小很多,收敛更稳。
  • 数据利用效率
    • SFT 等价于“先喂答案教基础功”;RL 更像“在掌握基础后练举一反三”。
    • 如果一开始就 RL,模型会在大量无意义探索上浪费步数。
  • 人工标注成本
    • SFT 阶段可用少量高质量标注(或合成高质量标注)直接模仿;
    • RL 阶段只用奖励信号即可继续放大效果,二者配合能节省标注量。

直接 RL 的合理场景

  1. 几乎没有标注数据、但可以自动计算奖励,例如:解数独、玩 Atari 游戏,环境本身给出分数。
  2. 大模型已具备强基础能力 GPT-4、Claude 3-Sonnet 这一级别,格式和基本推理已比较稳,直接 RL(或 RLAIF)效果也可接受。
  3. 任务鼓励高多样性、无法提供单一“标准答案” 如创意写作、对话风格优化,仅用偏好打分即可训练。

实践经验速查表

情况 建议策略 备注
有一批高质量标注 先 SFT,后 RL 主流 RLHF/GRPO Pipeline
只有合成弱标注 可尝试短 SFT + RL 先对齐格式再放大能力
纯交互式/环境奖励 直接 RL/在线 RL 如游戏、机器人控制
预算极低、模型极小 先小规模 SFT,再视情况决定是否加 RL RL 计算开销更大
  1. 我们的奖励函数是不是完全依赖“答案==标准答案”? 如果是,说明我们已经有明确标注;SFT 通常先做更划算。
  2. 我们有多大 GPU/TPU 预算? RL(尤其 GRPO/PPO)往往需要比 SFT 高 2-4 倍的算力。
  3. 任务对“推理链”可解释性要求高吗? 先 SFT(教会标签格式)再 RL(提升正确率)更容易满足可解释输出。

结论

“先 SFT 再 RL”并非硬性规定,但在绝大多数需要结构化输出、且有可用标注的场景下是最省心、最稳妥的路径。只有当标注极少或任务天然提供可计算奖励时,才会优先考虑“直接 RL”。

RL 常见问题

前文提到的 RL 常见的 KL 爆冲、梯度爆炸、模型崩盘问题,本小节详细介绍。

术语 问题本质(到底是哪里出了问题) 属于哪种概念 错误的具体表现(学术描述)
KL 爆冲 模型输出分布变化太剧烈、速度太快 输出分布问题 (Distribution-level issue) KL 散度指标短期内急剧上升(如超过10或更高);策略模型与参考模型的概率分布差异迅速扩大;导致模型输出质量急剧下降,如明显的文本内容混乱、重复或断句异常;
梯度爆炸 训练过程参数更新数值过大、模型变得不稳定 训练过程中参数更新的问题 (Training stability issue) 反向传播过程中梯度范数(Gradient norm)异常剧增(数值巨大甚至趋于无穷或NaN);训练损失(loss)数值异常增大甚至跳跃至无穷大或NaN;模型权重参数技术层面上更新幅度异常大,导致网络计算存溢出或数值退化;
模型崩盘 模型生成的内容变得单一、呆板、无法泛化 模型最终表现问题 (Final-generation quality issue) 模型生成内容多样性急速降低,信息熵(Entropy)显著减小;输出分布退化到极少的模式(如mode collapse),文本表现为反复生成单一或少数固定答案;在训练集之外的数据上表现能力急剧下降,泛化能力严重受损。

一般情况下,这三个问题会组成一条「连锁反应」:

奖励函数设计不佳或超参错误       ↓↓导致↓↓    KL爆冲 --> 梯度爆炸 --> 模型参数剧烈变化或NaN       ↓↓进一步导致↓↓    模型崩盘 (输出单一、低质)

KL 爆冲

KL 散度(Kullback–Leibler Divergence)本质上衡量的确实是两个概率分布之间的差距。在 DPO(Direct Preference Optimization)方法中,参考模型(reference model)训练中模型(policy model)之间计算的就是 KL 散度

用简单例子解释一下:

假设默认模型只会讲三句话:“我们好”、“谢谢”、“再见”。

它现在的“说话概率”(也可以叫“原始概率分布”)是:

鹦鹉自己的当前概率分布(P 分布) 概率
我们好 0.6
谢谢 0.3
再见 0.1

我们心目中理想的“模型应该说话的概率分布”(目标概率分布)是:

我们想要的目标概率分布(Q 分布) 概率
我们好 0.2
谢谢 0.7
再见 0.1

我们希望模型朝着目标概率(Q 分布)学习,但它原本的习惯是当前概率(P 分布)

这时候,为了知道我们的鹦鹉目前的 概率分布 P目标概率分布 Q 差距有多远。

  • KL 散度越小 \= 两个概率越接近
  • KL 散度越大 \= 两个概率分布的差距越明显

在例子中,如果原来模型会说:“我们好(Hello)”,但我们想教它说:”谢谢(Thank you)”,那么就有了:

  • 一个原始模型的分布(Original distribution):擅长说“我们好”;
  • 一个目标模型的分布(Target distribution):我们希望它能学会说“谢谢”。

假设我们给了模型过分高的奖励,比如只要提到“谢谢”,我们奖励20分。模型会在几步内学得太猛,突然所有问题只回复:“谢谢谢谢!”这就是 KL 距离瞬间爆发。

KL 爆冲发生以后,需要用算法调整 KL 惩罚系数(β)

Loss 总 = 奖励损失 + β × KL 散度

提高 β,比如 0.01 → 0.1,约束模型变化的幅度。

梯度爆炸

深度学习中很常见的梯度爆炸问题主要是指:

  • 网络在训练过程中因为某次更新的梯度过大,导致模型参数 突然变化过大,从而网络可能变得不稳定甚至崩溃。

最常见导致梯度爆炸的情况,很少是简单的代码 Bug;事实上更多是算法超参设置不当或数值计算不稳定导致的:

  • 学习率(LR)过大:如原本建议的学习率是 1e-5,但使用了过高学习率(如 1e-2 或者更高),一次参数更新迈步过大,造成梯度过大。
  • 奖励信号设计不合理(尺度过大):有时设计奖励信号时,没有进行归一化处理,例如我们奖励的值过大(比如正常奖励是 ±1,却给了数百甚至上万),导致更新步幅过猛,产生极大的梯度数值。
  • 网络结构本身设计或优化器配置不好:比如神经网络某些层的初始化不合理,或梯度累计出现了数值问题,使得运动过程中梯度持续放大。
  • 未使用梯度裁剪或裁剪设置值过大:如果训练过程中未用梯度裁剪方法,或梯度裁剪的上限值设置过大(如10以上),一旦梯度猛增就不能约束,即可引发梯度爆炸。

算法表现为梯度值剧烈变大甚至 NaN。

模型崩盘

模型崩盘的本质含义是:

  • 模型的参数被 “过度优化” 到单一或极少数的策略上(也称为 Mode Collapse);
  • 策略分布发生严重的退化,模型无法再生成丰富、多样化的内容。

模型崩盘有典型的指标,例如:

  • 输出的熵大幅降低(Entropy↓),表示语言多样性消失;
  • 生成内容变得单一固定,重复度极高;
  • 在训练数据以外的泛化能力和稳健性大幅下降。

算法上,熵的定义是:

熵值 = -sum( p(X_i)*log(p(X_i)) ) # 熵越低,表示模型生成的语言越单调单一,越接近崩盘

一种典型的模型崩盘的表现是:

  • 训练前语言多样性熵值 ≈ 8 到 10;
  • 训练后模型崩盘,语言熵值 下降至 1~2 左右。

模型崩盘最常见的直接原因是源于强化学习训练过程本身的一系列内在问题(尤其是强化学习),例如:

  • 奖励函数过于单一和简单:导致模型倾向走极端,重复一种行为;
  • 长时间训练、KL 问题持续未解决:模型能力持续退化,最终彻底丢失多样性;
  • 连续出现梯度爆炸但未干预:参数持续异常更新,模型能力根本不能正常保留;
  • 数据质量较低或过拟合于一种模式:模型长时间反复学习有限模式,无法泛化。

如果出现上述问题我们还继续训练,鹦鹉最后脑袋就真的弄坏了。比如它彻底只会一招,一问就吐出“苹果苹果”或彻底傻掉不回话,再训练也没用(模型崩溃)。

SFT 与 GRPO 的两阶段训练

接下来,参考 repo 中 code 目录下的训练代码,我们详细介绍 SFT 和 GRPO 的区别。

阶段 代码调用 HF 数据集仓库名 配置(子集) 网址示例
SFT get_limo()load_dataset(“GAIR/LIMO”) GAIR/LIMO 无子配置(默认) https://huggingface.co/datasets/GAIR/LIMO
GRPO get_gsm8k_questions()load_dataset(“openai/gsm8k”, “main”) openai/gsm8k “main” https://huggingface.co/datasets/openai/gsm8k

说明

  1. SFT 阶段脚本里会对 GAIR/LIMO.select(range(1600)) 之类抽样;原始仓库约 817 条(train),938 条(dev+test)。
  2. GRPO 阶段在 openai/gsm8k“main” 配置上取 train split,再 select(range(3500)) 抽子集做 RL;test split 用于离线评测。

要在 Hugging Face Hub 搜索 “GAIR/LIMO” 和 “openai/gsm8k” 即可查看与下载完整数据。

阶段 数据集 (行数) 内容示例 目的
SFT GAIR/LIMO 约 1600 条 K-12 数学题 + 官方解答已包好 & 教模型先学「写作模板 + 推理语气」
GRPO GSM8K-train 3500 条 小学应用题,只有真值数字 用奖励(格式+数值)做 RL 提升正确率

两阶段各自“训练了什么”?

阶段①:SFT(Supervised Fine-Tuning)

  • 训练信号
    • 交叉熵(Cross-Entropy),对教师答案逐 token 强制对齐。
  • 学到内容
    • XML 模板必须完整闭合。
    • 里如何写链式思考(First … Therefore …)。
    • 标签里只出现一个纯数字。
    • LoRA 参数被拉近“正确格式 + 基本推理”的低损失区。
  • 不学/很少学到
    • GSM8K 真值数字(因为数据集不同)。
    • 高阶数学技巧(量太少、只有 1 epoch)。

阶段②:GRPO(Reinforcement Learning, KL-regularized)

  • 训练信号
    • 数值奖励 cor_reward:完全命中 +2;其余 0。
    • 格式奖励 fmt_reward:模板满足 +1;否则 0。
    • 惩罚项 KL:防止行为过度偏离基座。
  • 学到内容
    • 如何把 数字精确等于真值(Exact-Match)。
    • 在保持模板的同时优化上一步数字。
    • 探索 ‑> 投票 ‑> 精修的策略(num_generations=8 + 众数投票)。
  • 不再关注
    • 语言流畅度/用词:奖励里没有相应项。
    • 训练集 LIMO 里的叙述风格(如果在奖励里没加 BLEU/Rouge)。

在两阶段训练中:

  • SFT 主要把模型往“格式正确 + 推理语气自然”方向拉; 在真值层面,由于 LIMO 的答案和 GSM8K 不重叠,加的数值知识有限。
  • GRPO 不仅训练数学,还继续用 fmt_reward 维持格式; 如果把格式奖励权重调成 0,格式率会显著下降。
  • SFT 阶段也会略提升数学(因为 LIMO 题目是算数题),只是提升幅度小; GRPO 阶段才用 3 500 条 GSM8K + 180 步强化专门优化数字。
  • 最终格式 90 %+ 依然是两阶段共同作用的结果——SFT 给起点,GRPO 用奖励守住。

两阶段原始数据集字段如下:

数据集 question solution answer 其它
GAIR/LIMO (SFT 用)
openai/gsm8k (main) (GRPO & 评测) √(含 “#### 72”)

训练脚本脚本 map 之后变成如下格式:

阶段 产生列 字段内容 来自原始列 说明 / 用途
SFT(get_limo prompt question + + solution + + + answer + question / solution / answer 教模型写模板与推理文字
_ completion “” (空串手动补) completion_only_loss=False,CE 覆盖整串
GRPO (get_gsm8k_questions) prompt SYSTEM_PROMPT ⏎ question (无标签、无答案) question 作为 RL 生成起点
_ answer 纯数字 72 (由 split(“####”)[1] 提取) answer 属于真值,供 cor_reward 比对

这样便于核对:

  • LIMO 的 solution 被嵌入 prompt → 模型在 SFT 时学习;
  • GSM8K 的 answer 纯数字保留,供 GRPO 奖励使用;
  • LIMO 的 answer 在 SFT 时只是模板演示,不参与 RL。

SFT 中的训练格式解释

如上一段内容解释,SFT map 后的训练语料并没有 completion 字段。

在 SFTConfig 训练代码里设置了

completion_only_loss=False

这表示“不要只对 completion 计算损失,而是对整条 prompt 进行 teacher-forcing”。在这种模式下,SFTTrainer 并不需要单独的 completion 字段——只要有一列 prompt 含完整参考答案即可。

  1. 但 SFTTrainer 源码要求数据集中必须存在 completion 这一列(无论用不用)。为了省事就补了空 字符串占位,使得字段齐全、代码不报错。
  2. 为什么不把 answer 放进 completion? 如果我们设 completion_only_loss=True,那就需要把 25 部分挪到 completion,让 prompt 只包含系统提示 + question + 。 当前脚本选用整串 CE 方式,所以 completion 留空即可。

简而言之:

  • completion=”” 是占位;
  • 真正的教师文本(含 solution 和 answer)已经在 prompt 里,交叉熵对整串计算,所以不会损失任何监督信息。

SFT 训练损失函数的构建

三种构建 SFT 损失函数方案

方案 prompt 内容 completion 内容 交叉熵作用范围 优点 潜在副作用
A 整串 CE (当前脚本) 系统提示+题干+reasoning+answer 空串 “” 全部 token 1. 格式 & 语言一次性学完2. 梯度稠密,收敛快 1. 数字 token 只占几步,权重被稀释2. 容易过拟合冗长 COT
B 仅 CE 系统提示+题干+reasoning 25 answer 部分 1. 数值权重高,目标集中 1. reasoning 无监督,质量全靠自发2. 格式只约束 answer 标签
C reasoning + answer CE 系统提示+题干 25 reasoning + answer 1. COT 与数字都有监督2. prompt 更短,显存省 1. 格式仍要自己生成 头标签2. 题干与 COT 之间缺乏直接 token 连接,梯度较稀疏

把 COT + answer 全放到 completion(方案 C)会发生什么?

  1. prompt 只剩 “系统提示 + 题干”,长度变短 → 同批显存更低;
  2. model 在训练时只要“读题干 → 预测 reasoning+answer”, 形成经典的 Instruction → Target 教师强制结构;
  3. 优势
    • 数字与 reasoning token 都在 loss 中,权重不被系统提示稀释;
    • prompt 更短,长题目不易溢出 max_seq_length
  4. 可能副作用
    • 如果 很长,占用了 90 % 的 loss,数字又被稀释;
    • 需要保证 首 token 可由题干直接预测, 否则梯度稀疏(题干→ 标签 gap);
    • 格式标签 仍在 completion 内,CE 会学到,但如果生成时 Temperature>0,模型还是可能漏标签,需要 RL 或格式奖励二次约束。

如何选择?

目标 推荐方案 原因
想快速让模型生成“整段 推理+答案”,对语言质量要求高 整串 CE(方案 A) 梯度稠密、格式稳
只在乎最后数字准确率,推理可忽略 answer-only CE(方案 B)+ 后续 RL 数值权重集中
既要推理文本、又想让数字权重更高(数据集大) reasoning+answer CE(方案 C) Prompt 短、两类 token 都被监督,但需多 epoch

在我们当前“小数据、 1 epoch”的设置下,整串 CE 提供最稠密梯度;如果未来扩充 LIMO 到数万条并跑多 epoch,可以考虑方案 C,并在 RL 阶段继续用格式奖励守护模板,以获得更高数值准确率且不过拟合冗长 COT。

如果改成方案 C——把整个

<reasoning>……</reasoning><answer>……</answer>

都放进 completion,只让交叉熵监督这段文本,但增加一个格式类奖励仍然是最稳妥的做法。理由与操作要点如下。

  • 方案 C 把 COT+答案放在 completion 后,模型有潜力更关注数值,但仍可能在生成时漏标签;
  • 保留一个(或低权重)的 fmt_reward 作为安全带是最保险的配置;
  • 可根据任务需要把格式奖励权重动态调低或改成惩罚式,以兼顾准确率与模板稳定性。

设计欠佳奖励函数(优化前的奖励函数)

在强化学习训练中,答案正确性的判断通常通过自动化脚本实现,而非依赖人工标注的表格。以下是具体实现逻辑。

格式奖励函数(format_reward_func)

目标

确保模型输出符合预设的 XML 标签结构

代码实现

import re def format_reward_func(completions, **kwargs):     """检查输出是否符合XML标签格式"""     pattern = r"^<reasoning>[\s\S]*?<\/reasoning>\s*<answer>[\s\S]*?<\/answer>$"     responses = [completion[0]["content"] for completion in completions]     rewards = [1.0if re.match(pattern, response) else0.0for response in responses]     return rewards

逻辑解析

  • 正则表达式匹配: 使用正则表达式 r”^[\s\S]*?<\/reasoning>\s*[\s\S]*?<\/answer>$” 严格检查输出是否包含完整的 标签,且顺序正确。
  • 奖励分配: 符合格式则奖励 1.0 分,否则 0.0 分。

正确性奖励函数(correctness_reward_func)

目标

验证模型输出的数值答案是否与标准答案一致。

代码实现

def correctness_reward_func(completions, answer, **kwargs):     """检查答案是否正确"""     responses = [completion[0]["content"] for completion in completions]     extracted_responses = [extract_last_xml_answer(response) for response in responses]     rewards = [         2.0if extracted == correct else0.0         for extracted, correct inzip(extracted_responses, answer)     ]     return rewards

依赖函数 extract_last_xml_answer

def extract_last_xml_answer(response):     """从XML标签中提取答案(若格式错误,则取最后一个数字)"""     try:         # 尝试解析XML标签         answer = re.search(r"<answer>(.*?)</answer>", response).group(1).strip()         return answer     except:         # 格式错误时,提取最后一个数字         numbers = re.findall(r"\d+\.?\d*", response)         return numbers[-1] if numbers else""

逻辑解析

  • 答案提取:优先从 标签中提取答案;若标签缺失或格式错误,则提取输出中的最后一个数字。
  • 奖励分配:答案与标准答案一致则奖励 2.0 分,否则 0.0 分。

总奖励计算

  • 总分范围0.0 \~ 3.0 总奖励(1.0) + 正确性奖励(2.0)。
  • 归一化处理: GRPO 算法会对组内奖励进行相对归一化(组内个体奖励减去组平均奖励),以平衡探索与利用。

关键设计考量

  • 格式与正确性的权重 正确性奖励(2.0)权重高于格式奖励(1.0),体现“答案正确性优先于格式”的设计原则。
  • 容错机制 即使格式错误,仍尝试提取最后一个数字作为答案,避免因格式问题完全丢弃有效答案。
  • 正则表达式严格性 格式检查使用严格匹配(^…$),确保标签闭合且无多余内容,强制模型学习结构化输出。

奖励函数优化思路

奖励函数优化主要包含:

  1. 细化奖励分值
  2. 增加群组投票

数字奖励 cor_reward (0 / 1 / 2 分)

XML_RE  = re.compile(r"<answer>(.*?)</answer>", re.S) _num    = lambda x: re.sub(r"[%$,]", "", x).strip() def _extract_nums(text: str):     return [_num(m) for m in XML_RE.findall(text)] def cor_reward(completions, **kw):     answers = kw.get("answer") or kw.get("answers") or []     rewards = []     for cand_list, gt inzip(completions, answers):         # 1) 收集 8 条回答里的所有 <answer>…</answer> 数字         nums = [             n             for c in cand_list             for n in _extract_nums(c["content"])         ]         # 2) 若一个数字都没抓到 → 直接 0 分         ifnot nums:             rewards.append(0.0)             continue         # 3) 群组投票:出现次数最多的数字         vote = Counter(nums).most_common(1)[0][0]         # 4) 评分:完全对 +2,差 1 +1,其余 0         diff = abs(int(vote) - int(gt)) if vote.isdigit() and gt.isdigit() else999         if   diff == 0: rewards.append(2.0)         elif diff == 1: rewards.append(1.0)         else:           rewards.append(0.0)     return rewards

详细步骤

1. _extract_nums()

  • 用正则在单条回答文本里找所有
  • _num() 去掉 $ % , 等符号,得纯数字字符串。

2. 组内投票(majority vote)

vote = Counter(nums).most_common(1)[0][0]
  • 把 8 条回答汇总得到的数字列表 nums 做统计;
  • 选出现频率最高的那一个(若并列,Counter 取第一出现)。
  • 投票的好处:
    • 抑制偶然的随机数;
    • 让模型有动力让多个回答趋向一致=正确数值。

3. 分级奖励 diff = |vote – ground_truth| – diff == 0 → +2 (完全正确) – diff == 1 → +1 (只差 1,也给部分梯度) – else → 0 (远离真值) 这样 early 训练阶段更容易拿到非零 reward,梯度稠密,KL 更平滑。

输出示例

batch_size = 8 cor_reward → [2,1,0,2,0,1,0,2] fmt_reward → [1,1,0,1,1,1,0,1] total_reward → [3,2,0,3,1,2,0,3]

新旧奖励函数对比:

_ 旧版 cor_reward 新版 cor_reward (群组投票+部分分)
采样数目 只看第 1 条回答 利用 8 条回答,众数投票
评分标准 完全对 +2,否则 0 完全对 +2;差 1 +1;其余 0
梯度稀疏 早期大量 0 分 早期平均 reward 即可达 1.0+
格式耦合 格式奖励独立;数字奖励看所有回答

结果:

  • fmt_reward_mean 更快爬到 0.9;
  • cor_reward_mean 抬到 \~1.2(≈30% 完全对 + 35% 差 1);
  • KL 控制在 <0.2,训练稳定;
  • 总 reward 1.8→2.1 左右,比旧版提升约 10 %。

群组 vote 的合理性研究

先投票再对真值”合不合理,要看我们希望奖励函数起什么作用。

合理方面

  1. 自洽性(Self-Consistency)的经验规律 OpenAI、Google 论文都表明: “同一 prompt 让模型多生成几条推理,用众数/平均值作为最终答案, 准确率往往高于单条输出。” 投票奖励把这个经验直接注入 RL:
    • 如果 8 条里 ≥4 条写 42,那 42 很可能就是正确答案;
    • 早期即使 8 条回答各不相同,也能把出现次数最多的那个作为 “模型当前最确信” 的猜测。
  2. 梯度密度更高
    • 纯 0/2 模式:完全错 = 0,很容易 reward 全 0;
    • 投票 + 差 1 给 1 分:早期也能拿到非零 reward,梯度方向更连续。
  3. 利用并行生成的计算成本 既然我们已经花显存一次性生成了 8 条回答,把它们全都用来评奖要 比只看第一条更物超所值。
  4. 格式门控 + 数值投票分离 先用格式奖励约束输出形状,再用投票奖励评数值;两部分可独立调 权重,互不干扰。

局限性

  1. “集体跑偏” 如果模型内部存在系统性错误(8 条都写 41,但真值 42),投票仍会 选错。此时 reward 仍给 0 / 1,梯度作用有限。
  2. 并列众数的歧义 Counter.most_common(1) 默认返回先出现的数字; 若票数打平,选择具有随机性,可能带来噪声。 → 可以设阈值:只有票数 ≥4 才用众数,否则 reward=0。
  3. 差值阈值的 trade-off ‑ 差 1 给 1 分能 densify 梯度; ‑ 但如果阈值太宽(差 5 也给分)会削弱“完全正确”的驱动力。
  4. 生成条数与开销 num_generations=8 对 A100 2B 模型还算轻;如果用更大的模型或者 更长 completion,生成 8 条会拖慢训练。

如何让投票更棒(进一步优化的可能性)

表格 建议 作用
票数阈值 vote, cnt = Counter(nums).most_common(1)[0]; if cnt < 4: reward=0 避免 2∶2∶2∶2 平票时的随机性
格式过滤 if not XML_RE.match(c[“content”]): continue 不让无效回答影响投票
置信度加权 用 logits 概率给数字加权平均,而非纯计数 兼顾概率信息
多众数比较 如果出现两个众数且都差 0/1,各给 1.5 分 减少随机性噪声
差值梯度 score = max(0, 2 – diff)(差 2 得 0、差 1 得 1) reward 曲线更平滑

结论

  • 小批量、有限步数的 RL 微调而言,“投票→对真值” 的奖励能显著缓解 0/2 稀疏问题,让 early reward 更快爬升,并在格式合规率上带来明显增益,是一个合理且常用的技巧。
  • 如果我们更在意“对/错的严格区分”,可以保持 diff==0 才给分; 若更在意收敛速度和平滑梯度,保留差 1 + 部分分会更友好。

因此,是否保留投票机制取决于:

  • 我们能否接受多生成几条回答的时间 / 显存成本;
  • 我们更关注最终极限正确率(可考虑后期关闭差 1 奖励), 还是关注训练效率和稳定性(保留投票 + 部分分)。

训练结果指标解读

SFTTrainer 日志里出现字段:

字段 含义 典型范围 计算方式
loss teacher-forcing 交叉熵平均值(越低越好) 0.7 → 0.3 CrossEntropy(outputs, labels)
mean_token_accuracy token 级 top-1 准确率 0.65 → 0.80 1 – ppl 近似值
num_tokens 当前 step 处理的 token 数 batch×seq_len 统计 tokenizer 输入长度
train_runtime 整个 epoch 耗时 (最终行) 280-300 s end_time – start_time
train_samples_per_second 每秒处理样本数 ≈(batch/step)/sec HF Trainer 统计
train_steps_per_second 每秒更新步数 ≈1 / step_latency HF Trainer 统计
train_loss 全 epoch 的 loss 平均值(最终行) 0.85 所有 step loss 加权平均

SFT、GRPO 通用字段:

字段 含义 备注
epoch 当前步对应的 epoch 比例 0.08 = 8% 已训练进度
loss SFT:交叉熵;GRPO:KL − reward GRPO 中越“低”未必越好
grad_norm 当前梯度 L2 范数,过大可能爆炸 通常 0.1-3 之间
learning_rate 每 step 动态学习率 线性/余弦调度
num_tokens step 内处理 token 数 生成任务包含 prompt+completion
logging_steps n 步打印一次,决定日志行粒度 配置里的 logging_steps

GRPOTrainer 特有字段:

字段名 (日志 key) 含义 判读规则 好/坏典型阈值
rewards/cor_reward/mean 数字奖励均值(+2 完全正确;+1 仅差 1;0 其余) ↑ 越高越好 ≥1.2 ⇒ \~30 % 完全对
rewards/fmt_reward/mean XML 格式奖励均值(满足模板得 +1) ↑ 越高越好 ≥0.90 ⇒ 90 % 合规
reward cor + fmt 的批均值 ∈ [0, 3] ↑ 越高越好 ≥2.0 ⇒ 整体表现佳
reward_std 批内 reward 的标准差 中等即可 0.3–0.8 正常;>1 波动大
frac_reward_zero_std reward=0 的样本比例 ↓ 越低越好 <0.3 ⇒ 梯度稠密
kl 策略与底座模型的 KL 散度 中等最好 0.05–0.25 安全;>0.4 可能发散
loss β·KL – reward(GRPO 目标) 趋势即可 上下波动属正常
grad_norm 当前梯度 L2 范数 ↓ 避免爆 ≤1 稳定;>5 需调梯度裁剪
completions/mean_length 8 条回答平均 token 长度 监控长度 80–110 正常;<50 推理不足
completions/clipped_ratio 回答被 max_completion_length 截断的比例 ↓ 越低越好 <0.2 最佳;>0.4 考虑加长
epoch 已训练进度 (0-1 = 0-100 %) 用来对齐时间点

备注:

  1. fmt_reward/mean ≥ 0.9 → 模板输出稳定。
  2. cor_reward/mean ≥ 1.2 → 30 % 以上完全正确(好)。
  3. kl < 0.3 → 更新稳定;若突涨,需减小学习率 / β。
  4. frac_reward_zero_std < 0.3 → 奖励信号足够密集。
  5. completions/clipped_ratio > 0.4 → 说明 128 token 不够,可调大。

奖励函数优化对训练效果实测

奖励函数优化前

训练

source .venv/bin/activate root@a100vm:~/Gemma-2-2B-IT-GRPO# pwd /root/Gemma-2-2B-IT-GRPO root@a100vm:~/Gemma-2-2B-IT-GRPO# python gemma-grpo2.py  root@a100vm:~/Gemma-2-2B-IT-GRPO# python  gemma-instruct-grpo2.py

训练中的资源利用率

(Gemma-2-2B-IT-GRPO) root@a100vm:~/Gemma-2-2B-IT-GRPO# nvidia-smi Sun Jun 1511:44:162025        +-----------------------------------------------------------------------------------------+ | NVIDIA-SMI 560.35.05              Driver Version: 560.35.05      CUDA Version: 12.6     | |-----------------------------------------+------------------------+----------------------+ | GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC | | Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. | |                                         |                        |               MIG M. | |=========================================+========================+======================| |   0  NVIDIA A100 80GB PCIe          Off |   00000001:00:00.0 Off |                    0 | | N/A   49C    P0            109W /  300W |   80793MiB /  81920MiB |     48%      Default | |                                         |                        |             Disabled | +-----------------------------------------+------------------------+----------------------+                                                                                           +-----------------------------------------------------------------------------------------+ | Processes:                                                                              | |  GPU   GI   CI        PID   Type   Process name                              GPU Memory | |        ID   ID                                                               Usage      | |=========================================================================================| |    0   N/A  N/A    449318      C   python                                      80780MiB | +-----------------------------------------------------------------------------------------+

评估

python3 -m venv ~/eval-env source ~/eval-env/bin/activate pip install "torch>=2.1""transformers>=4.49" datasets tqdm pip install accelerate python gsm8k-eval-tf2.py --model_dir gemma-grpo-only python gsm8k-eval-tf2.py --model_dir gemma-sft-grpo

评估脚本执行结果

纯 GRPO

---------------------------------------- Input tokens  avg=140.5  max=269 Output tokens avg=90.9  max=257 Correct format     : 1142/1319 (86.6%) Plausibly correct  : 566/1319 (42.9%) Exact correct      : 559/1319 (42.4%) ========================================

SFT+GRPO

---------------------------------------- Input tokens  avg=140.5  max=269 Output tokens avg=74.7  max=257 Correct format     : 1192/1319 (90.4%) Plausibly correct  : 504/1319 (38.2%) Exact correct      : 500/1319 (37.9%) ======================================== (eval-env) root@a100vm:~/Gemma-2-2B-IT-GRPO# 

奖励函数优化后

训练

source .venv/bin/activate root@a100vm:~/Gemma-2-2B-IT-GRPO# pwd /root/Gemma-2-2B-IT-GRPO root@a100vm:~/Gemma-2-2B-IT-GRPO# python gemma-grpo3.py  root@a100vm:~/Gemma-2-2B-IT-GRPO# python  gemma-instruct-grpo3.py

评估

python3 -m venv ~/eval-env source ~/eval-env/bin/activate pip install "torch>=2.1""transformers>=4.49" datasets tqdm pip install accelerate python gsm8k-eval-tf2.py --model_dir gemma-grpo-only python gsm8k-eval-tf2.py --model_dir gemma-sft-grpo

评估脚本执行结果

仅 GRPO

---------------------------------------- Input tokens  avg=140.5  max=269 Output tokens avg=92.2  max=257 Correct format     : 1120/1319 (84.9%) Plausibly correct  : 665/1319 (50.4%) Exact correct      : 657/1319 (49.8%) ======================================== (eval-env) root@a100vm:~/Gemma-2-2B-IT-GRPO# 

SFT+GRPO

---------------------------------------- Input tokens  avg=140.5  max=269 Output tokens avg=75.5  max=257 Correct format     : 1161/1319 (88.0%) Plausibly correct  : 506/1319 (38.4%) Exact correct      : 505/1319 (38.3%) ======================================== (eval-env) root@a100vm:~/Gemma-2-2B-IT-GRPO# 

奖励函数优化前后对比(仅仅对比 GRPO)

指标 旧奖励 新奖励 差值 (B-A) 谁更优
Correct format(格式合规率) 86.6 % (1142/1319) 84.9 % (1120/1319) – 1.7 pt 持平
Exact correct(完全命中真值) 42.4 % (559/1319) 49.8 % (657/1319) + 7.4 pt B
Plausibly correct(数字或 XML 命中) 42.9 % (566/1319) 50.4 % (665/1319) + 7.5 pt B
输出长度 avg / max 90.9 / 257 92.2 / 257 +1.3 tok 持平

结论

  1. 数值准确率 新奖励把完全正确率提升了约 7 个百分点,这是只奖励 exact-match 的直接收益。
  2. 格式合规率 基本持平(因为并没有优化格式奖励)
  3. 后续细化奖励规则,增加训练 step 数,准确率有望继续提升。

魏新宇

微软 AI 全球黑带高级技术专家

著有《大语言模型原理、训练及应用》《金融级 IT 架构与运维》《OpenShift 在企业中的实践》v1&v2、《云原生应用构建》。

(文:AIGC开放社区)

极客说|强化学习(RL)与有监督微调(SFT)的选择以及奖励函数的优化最先出现在每时AI


扫描二维码,在手机上阅读