-
Notifications
You must be signed in to change notification settings - Fork 804
Expand file tree
/
Copy pathloss.py
More file actions
153 lines (143 loc) · 5.77 KB
/
loss.py
File metadata and controls
153 lines (143 loc) · 5.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from typing import TYPE_CHECKING
from pydantic import BaseModel, ConfigDict
import torch
from art.utils.group_aggregate import group_aggregate
from . import dev
if TYPE_CHECKING:
from art.unsloth.service import TrainInputs
class Loss(BaseModel):
    """Scalar loss terms and training diagnostics returned by `loss_fn`."""

    # Permit raw torch.Tensor fields; pydantic has no native tensor validator.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # Policy loss averaged over assistant tokens (weighted sum / mask count).
    mean_policy_loss: torch.Tensor
    # Weighted mean token entropy, or None when no entropies were supplied.
    mean_entropy: torch.Tensor | None
    # Un-normalized sum of the per-token (weighted, masked) policy loss.
    policy_loss_sum: torch.Tensor
    # Pearson correlation between stored (sampling-time) and recomputed
    # token probabilities — a drift/staleness diagnostic.
    probs_corr: torch.Tensor
    # Mean KL(policy || reference) estimate; set only when the KL penalty
    # branch of `loss_fn` is active, otherwise None.
    kl_policy_ref: torch.Tensor | None = None
def loss_fn(
    inputs: "TrainInputs",
    new_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor | None,
    entropies: torch.Tensor | None,
    experimental_config: dev.TrainConfig,
) -> Loss:
    """Compute the (experimental-config-driven) policy-gradient loss.

    Args:
        inputs: Batch dict with "logprobs", "advantages", "assistant_mask",
            "weights", "group_ids" and optionally "original_logprobs".
            Presumably each is a (batch, seq) tensor aligned with
            `new_logprobs` — TODO confirm against TrainInputs.
        new_logprobs: Token logprobs under the current learner policy.
        ref_logprobs: Token logprobs under a frozen reference policy, or None
            (disables the KL penalty branch).
        entropies: Per-token entropies, or None (mean_entropy becomes None).
        experimental_config: Dict-like config of experimental knobs; every
            option is read via `.get` with a default.

    Returns:
        A `Loss` with the mean/summed policy loss and diagnostics.

    Raises:
        AssertionError: On an unrecognized "kl_penalty_source" value.
    """
    # All stored inputs are shifted one position left so that position t
    # holds the label-token quantities aligned with the logits at t.
    # NaN marks "old logprob not recorded" and is patched below.
    old_logprobs = shift_tensor(inputs["logprobs"], float("nan"))
    advantages = shift_tensor(inputs["advantages"], 0.0)
    assistant_mask = shift_tensor(inputs["assistant_mask"], False).to(
        new_logprobs.dtype
    )
    weights = shift_tensor(inputs["weights"], 0.0)
    old_logprobs_mask = ~torch.isnan(old_logprobs)
    # Diagnostic: correlation between sampling-time and recomputed token
    # probabilities, over positions where an old logprob exists.
    probs_corr = torch.corrcoef(
        torch.stack(
            [
                torch.exp(old_logprobs[old_logprobs_mask]),
                torch.exp(new_logprobs[old_logprobs_mask]),
            ]
        )
    )[0, 1]
    # Assume missing old logprobs were sampled under the current policy
    # (ratio becomes exactly 1 at those positions).
    old_logprobs = torch.where(
        torch.isnan(old_logprobs),
        new_logprobs.detach(),
        old_logprobs,
    )
    logprob_diff = new_logprobs - old_logprobs
    importance_sampling_level = experimental_config.get(
        "importance_sampling_level", "token"
    )
    # Per-token importance ratio pi_new / pi_old.
    prob_ratio = torch.exp(logprob_diff)
    if importance_sampling_level != "token":
        # Sequence-level ratio: mean logprob diff per group, exponentiated.
        # NOTE(review): multiplying group_ids by assistant_mask folds all
        # non-assistant tokens into group 0 — presumably intentional; verify.
        sequence_prob_ratio = torch.exp(
            group_aggregate(
                logprob_diff,
                by=shift_tensor(inputs["group_ids"], 0) * assistant_mask,
                reduce="mean",
            )
        )
        if importance_sampling_level == "sequence":
            prob_ratio = sequence_prob_ratio
        elif importance_sampling_level == "average":
            # Arithmetic mean of token- and sequence-level ratios.
            prob_ratio = (prob_ratio + sequence_prob_ratio) / 2
        elif importance_sampling_level == "geometric_average":
            # Geometric mean of token- and sequence-level ratios.
            prob_ratio = (prob_ratio**0.5) * (sequence_prob_ratio**0.5)
    ppo = experimental_config.get("ppo", False)
    # Clipping range defaults differ by objective: tight symmetric clip for
    # PPO, wide asymmetric clip for the CISPO-style objective below.
    if ppo:
        epsilon_default = 0.2
        epsilon_high_default = None
    else:
        epsilon_default = 1.0
        epsilon_high_default = 4.0
    epsilon = experimental_config.get("epsilon", epsilon_default)
    epsilon_high = experimental_config.get("epsilon_high", epsilon_high_default)
    if epsilon_high is None:
        epsilon_high = epsilon
    # NOTE(review): despite the name, this cap is applied to ALL tokens, not
    # only negative-advantage ones — confirm that is the intended behavior.
    # The walrus truthiness test also means a configured value of 0 is a no-op.
    if max_negative_advantage_importance_sampling_weight := experimental_config.get(
        "max_negative_advantage_importance_sampling_weight", None
    ):
        prob_ratio = torch.clamp(
            prob_ratio, max=max_negative_advantage_importance_sampling_weight
        )
    if experimental_config.get("mask_prob_ratio", False):
        # Zero out (rather than clip) ratios outside the trust region.
        prob_ratio = torch.where(
            (prob_ratio > 1 - epsilon) & (prob_ratio < 1 + epsilon_high),
            prob_ratio,
            0.0,
        )
    # Kimi-K2-style advantage correction: shrink advantages by the (detached)
    # off-policy drift. tau == 0 is skipped by the truthiness test.
    if tau := experimental_config.get("kimi_k2_tau", None):
        advantages -= tau * logprob_diff.detach()
    kl_policy_ref: torch.Tensor | None = None
    kl_penalty_coef = experimental_config.get("kl_penalty_coef", 0.0)
    if kl_penalty_coef > 0 and ref_logprobs is not None:
        # Choose which logprobs to compare against the reference model.
        match experimental_config.get("kl_penalty_source", "current_learner"):
            case "sample":
                kl_source_logprobs = old_logprobs.detach()
            case "current_learner":
                kl_source_logprobs = new_logprobs.detach()
            case other:
                raise AssertionError(other)
        # Per-token KL estimate, restricted to assistant tokens; fully
        # detached so the penalty flows through `advantages`, not the KL.
        kl_per_token = (kl_source_logprobs - ref_logprobs).detach() * assistant_mask
        # epsilon in the denominator guards an all-False mask.
        avg_kl = kl_per_token.sum() / (assistant_mask.sum() + 1e-6)
        # Mean-centered penalty folded into the advantages.
        kl_penalty = kl_penalty_coef * (avg_kl - kl_per_token) * assistant_mask
        advantages = advantages + kl_penalty
        kl_policy_ref = avg_kl
    if ppo:
        # Standard PPO clipped surrogate (negated: we minimize).
        policy_loss = -torch.min(
            prob_ratio * advantages,
            torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high) * advantages,
        )
    else:
        # Modified REINFORCE or Clipped IS-weight Policy Optimization (CISPO)
        # — the clipped ratio is detached, so gradients flow only through
        # new_logprobs.
        policy_loss = -(
            torch.clip(prob_ratio.detach(), 1 - epsilon, 1 + epsilon_high)
            * advantages
            * new_logprobs
        )
    # Truncated importance sampling against the ORIGINAL (e.g. inference-
    # engine) logprobs, when recorded. upper_bound == 0 is skipped by the
    # truthiness test.
    if upper_bound := experimental_config.get("truncated_importance_sampling", None):
        if "original_logprobs" in inputs:
            original_logprobs = shift_tensor(inputs["original_logprobs"], 0.0)  # ty:ignore[invalid-key]
            # Missing originals fall back to the current policy's logprobs.
            original_logprobs = torch.where(
                torch.isnan(original_logprobs),
                new_logprobs.detach(),
                original_logprobs,
            )
            # Ratio pi_old / pi_original, clamped above by upper_bound.
            logprob_diff = old_logprobs - original_logprobs
            prob_ratio = torch.exp(logprob_diff)
            policy_loss *= torch.clamp(prob_ratio, max=upper_bound).detach()
    # Apply per-token weights and restrict to assistant tokens.
    policy_loss = policy_loss * weights * assistant_mask
    mean_policy_loss = policy_loss.sum() / (assistant_mask.sum() + 1e-6)
    # Compute mean entropy for the current step
    if entropies is not None:
        shifted_entropies = shift_tensor(entropies, 0.0)
        mean_entropy = (shifted_entropies * weights * assistant_mask).sum() / (
            assistant_mask.sum() + 1e-6
        )
    else:
        mean_entropy = None
    return Loss(
        mean_policy_loss=mean_policy_loss,
        mean_entropy=mean_entropy,
        policy_loss_sum=policy_loss.sum(),
        probs_corr=probs_corr,
        kl_policy_ref=kl_policy_ref,
    )
def shift_tensor(tensor: torch.Tensor, pad: int | float | bool) -> torch.Tensor:
return torch.nn.functional.pad(tensor[:, 1:], (0, 1), value=pad)