RLHF Explained

RLHF consists of two crucial steps:

  1. Train a Reward Model (RM).
  2. Construct a reward function from the Reward Model and the SFT Model, and train the LLM with the PPO algorithm (see the setup sketch after this list):
    1. frozen RM
    2. frozen SFT Model
    3. Actor $\pi_{\Phi}^{RL}$ initialized from the SFT Model
    4. Critic $V_\eta$ initialized from the RM
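
As a concrete illustration of this setup, here is a minimal sketch assuming Hugging Face `transformers` checkpoints; the paths, class choices, and variable names are illustrative and not from the original post:

```python
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification

sft_path = "path/to/sft-model"      # hypothetical checkpoint paths
rm_path = "path/to/reward-model"

reward_model = AutoModelForSequenceClassification.from_pretrained(rm_path, num_labels=1)  # frozen RM
ref_policy_model = AutoModelForCausalLM.from_pretrained(sft_path)                         # frozen SFT model
actor = AutoModelForCausalLM.from_pretrained(sft_path)                                    # trainable policy, initialized from SFT
critic = AutoModelForSequenceClassification.from_pretrained(rm_path, num_labels=1)        # value model, initialized from RM

# Only the actor and critic are trained; the RM and the SFT reference stay frozen.
for frozen in (reward_model, ref_policy_model):
    for p in frozen.parameters():
        p.requires_grad_(False)
```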

The objective function to maximize:

$$\operatorname{objective}(\Phi)=\mathbb{E}_{(x, y) \sim D_{\pi_{\Phi}^{RL}}}\left[r_{\theta}(x, y)-\beta \log \frac{\pi_{\Phi}^{RL}(y \mid x)}{\pi^{SFT}(y \mid x)}\right]+\gamma\, \mathbb{E}_{x \sim D_{\text{pretrain}}}\left[\log \pi_{\Phi}^{RL}(x)\right]$$
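
In implementations, this objective is typically realized by folding the KL penalty against the frozen SFT model into a per-token reward, with the RM score added on the final token; this is what `reward_func` in the pseudocode below computes. A minimal sketch follows, where the signature, tensor shapes, and the coefficient `beta` are assumptions for illustration:

```python
import torch

def reward_func(scores, old_log_probs, ref_log_probs, beta=0.02):
    # old_log_probs, ref_log_probs: (batch, response_len) token log-probabilities
    # scores: (batch,) scalar reward-model scores
    rewards = -beta * (old_log_probs - ref_log_probs)  # per-token KL-style penalty vs. the SFT model
    rewards[:, -1] += scores                           # RM score only on the last response token
    return rewards
```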

Training procedure:

```python
policy_model = load_model()
ref_policy_model = policy_model.copy()

for k in range(20000):
    # Sampling (with the actor and critic models from the previous iteration)
    prompts = sample_prompt()
    # old_log_probs: log probabilities from the previous iteration's actor model
    # old_values: expected returns estimated by the previous iteration's critic model
    responses, old_log_probs, old_values = respond(policy_model, prompts)

    # Feedback
    # frozen reward model
    scores = reward_model(prompts, responses)
    # frozen SFT model
    ref_log_probs, _ = analyze_responses(ref_policy_model, prompts, responses)
    rewards = reward_func(reward_model, scores, old_log_probs, ref_log_probs)

    # Learning: update the actor and critic models
    for epoch in range(4):
        # the values here are used to update the critic model
        log_probs, values = analyze_responses(policy_model, prompts, responses)
        advantages = advantage_func(rewards, old_values)
        actor_loss = actor_loss_func(advantages, old_log_probs, log_probs)
        critic_loss = critic_loss_func(rewards, values)
        loss = actor_loss + 0.1 * critic_loss
        train(loss, policy_model.parameters())
```

  • The frozen RM and the frozen SFT model are only used to compute the rewards.
  • The actor and the critic are updated together within the inner epoch loop.
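
The helpers `advantage_func`, `actor_loss_func`, and `critic_loss_func` are left undefined in the pseudocode. Below is a minimal sketch under standard PPO conventions, using GAE for the advantages, the clipped surrogate objective for the actor, and a mean-squared error for the critic; the hyperparameters `gamma`, `lam`, and `clip_eps`, and the choice of reward-to-go returns as the critic target, are assumptions rather than details from the post:

```python
import torch

def advantage_func(rewards, old_values, gamma=1.0, lam=0.95):
    # Generalized Advantage Estimation over response tokens.
    # rewards, old_values: (batch, response_len) tensors
    T = rewards.size(1)
    advantages, last_adv = [], 0.0
    for t in reversed(range(T)):
        next_value = old_values[:, t + 1] if t + 1 < T else 0.0
        delta = rewards[:, t] + gamma * next_value - old_values[:, t]
        last_adv = delta + gamma * lam * last_adv
        advantages.append(last_adv)
    return torch.stack(advantages[::-1], dim=1)

def actor_loss_func(advantages, old_log_probs, log_probs, clip_eps=0.2):
    # PPO clipped surrogate objective (negated, since we minimize the loss).
    ratio = torch.exp(log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()

def critic_loss_func(rewards, values):
    # Regress the critic's per-token values toward reward-to-go returns (gamma = 1 assumed),
    # matching the critic_loss_func(rewards, values) call in the training loop.
    returns = torch.flip(torch.cumsum(torch.flip(rewards, dims=[1]), dim=1), dims=[1])
    return torch.mean((returns - values) ** 2)
```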
