Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions examples/gemma/configs/gemma3_270m_sow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Default Hyperparameter configuration."""

import dataclasses

from train import TrainConfig


@dataclasses.dataclass(unsafe_hash=True)
class Config:
# Path to load or store sentencepiece vocab file.
vocab_path: str | None = None
# Vocabulary size if `vocab_path` is not given.
vocab_size: int = 35_008 # lm1b dataset vocab size: 35913 (Gemma expected vocab size: 262_144)
# Maximum number of characters to use for training.
max_corpus_chars: int = 10**7
# Name of TFDS translation dataset to use.
dataset_name: str = 'lm1b'
# Optional name of TFDS translation dataset to use for evaluation.
eval_dataset_name: str = 'lm1b'
# Optional name of TFDS split to use for evaluation.
eval_split: str = 'test'
# Per device batch size for training.
per_device_batch_size: int = 16
# Per device batch size for training.
eval_per_device_batch_size: int = 16
# Grain prefetch number of workers.
prefetch_num_workers: int | None = None

# Prompt for language model sampling
prompts: tuple[str, ...] = (
'Paris is a the capital',
'Flax is a',
# From train set:
'The shutdown was aimed at creating efficiencies as',
# -> the plant was already operating at its maximum capacity of 3,000 tonnes of cellulose paste per day
'A big theme of this hire is that there are parts of',
# -> our operations that to use a pretty trite phrase , need to be taken to the next level ...

# From test set:
'Because of Bear Stearns , many analysts are',
# -> raising the odds that a 2008 recession could be worse than expected
'Next month , the Brazilian bourse',
# -> opens a London office',
)
# Temperature for top_p sampling.
sampling_temperature: float = 0.0
# Top-p sampling threshold.
sampling_top_p: float = 0.95

# Number of steps to take during training.
num_train_steps: int = 100_000
# Number of steps to take during evaluation.
# Large enough to evaluate all samples: 306_688 / (32 * 8) = 1198
# num_eval_steps: int = 2_000
num_eval_steps: int = 500
# Number of steps to generate predictions.
# -1 will use the whole eval dataset.
num_predict_steps: int = 50
# Base learning rate.
learning_rate: float = 0.0016
# Linear learning rate warmup.
warmup_steps: int = 1000
# Cross entropy loss label smoothing.
label_smoothing: float = 0.0
# Decay factor for AdamW style weight decay.
weight_decay: float = 0.1
# Maximum length cutoff for training examples.
max_target_length: int = 1024
# Maximum length cutoff for eval examples.
max_eval_target_length: int = 1024

# Gemma transformer name.
# Possible values defined in transformer.TransformerConfig:
# (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_270m, gemma3_1b, gemma3_4b, ...)
transformer_name: str | None = "gemma3_270m"
# or alternatively define the model using the dict of parameters
transformer_params: dict | None = None

# Whether to save model checkpoints.
save_checkpoints: bool = True
# Whether to restore from existing model checkpoints.
restore_checkpoints: bool = True
# Save a checkpoint every these number of steps.
checkpoint_every_steps: int = 10_000
# Frequency of eval during training, e.g. every 2_000 steps.
eval_every_steps: int = 2_000
# Use bfloat16 mixed precision training instead of float32.
use_bfloat16: bool = True
# Integer for PRNG random seed.
seed: int = 0

# Parallelism
mesh_axes: tuple[str, ...] = ('fsdp', 'tensor')
data_sharding: tuple[str, ...] = ('fsdp', )

fsdp_parallelism: int = -1
tensor_parallelism: int = 1

sow_config: dict = dataclasses.field(
default_factory=lambda: {
"rs_after_attention": True,
"rs_after_ffw": True,
"attn_logits_topk": 5,
"mlp_hidden_topk": 5,
}
)

def replace(self, **kwargs):
return dataclasses.replace(self, **kwargs)


def get_config() -> TrainConfig:
"""Get the default hyperparameter configuration."""
config = Config()
return TrainConfig(**dataclasses.asdict(config))
21 changes: 16 additions & 5 deletions examples/gemma/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,8 @@ def _sample_step(
last_token = last_token.reshape((batch_size, 1))

transformer = nnx.merge(graphdef, params)
logits, cache = transformer(
forward = nnx.capture(transformer, nnx.Intermediate)
(logits, cache), intermediates = forward(
last_token,
step_positions,
sampler_state.cache,
Expand Down Expand Up @@ -267,7 +268,7 @@ def sample_best(logits):
logits_buffer = sampler_state.logits_buffer

if sampler_state.intermediates is not None:
sampler_state.intermediates.merge(decoding_step, transformer)
sampler_state.intermediates.merge(decoding_step, intermediates)

done = sampler_state.done | jnp.equal(
token_buffer[:, decoding_step + 1], self.vocab.eos_id()
Expand Down Expand Up @@ -350,6 +351,18 @@ def init_sample_state(
else:
logits_buffer = None

intermediates = None
if self.transformer.sow_config.is_enabled():
intermediates = sow_lib.init_intermediates(
batch_size,
buffer_size,
self.transformer.embed_dim,
self.transformer.num_layers,
self.transformer.config.num_heads,
self.transformer.sow_config,
dtype=dtype
)

return _SamplingState(
decoding_step=0,
num_input_tokens=num_input_tokens,
Expand All @@ -364,9 +377,7 @@ def init_sample_state(
done=done,
total_sampling_steps=total_sampling_steps,
forbidden_token_ids=forbidden_token_ids,
intermediates=self.transformer.init_intermediates(
batch_size, buffer_size, self.transformer.sow_config, dtype=dtype
),
intermediates=intermediates,
temperature=temperature,
top_p=top_p,
seed=seed,
Expand Down
96 changes: 84 additions & 12 deletions examples/gemma/sow_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from flax import nnx
import jax
import jax.numpy as jnp
from jax.sharding import auto_axes


@jax.tree_util.register_dataclass
Expand All @@ -35,7 +36,7 @@ class LayerIntermediates:
attn_logits_topk_values: jax.Array | None = None
attn_logits_topk_indices: jax.Array | None = None

def merge(self, decoding_step, layer: nnx.Module):
def merge(self, decoding_step, layer: nnx.State):
"""Merges the intermediate activations from one step."""

for field in dataclasses.fields(self.__class__):
Expand All @@ -47,9 +48,7 @@ def merge(self, decoding_step, layer: nnx.Module):
# sub-module.
try:
if field.name.startswith('attn_'):
step_value = getattr(
layer.attn, field.name.replace('attn_', '')
)[0]
step_value = getattr(layer.attn, field.name.replace('attn_', ''))[0]
elif field.name.startswith('mlp_'):
step_value = getattr(layer.mlp, field.name.replace('mlp_', ''))[0]
else:
Expand Down Expand Up @@ -86,24 +85,25 @@ class TransformerIntermediates:
# Intermediate activations of each layer.
layers: list[LayerIntermediates] = dataclasses.field(default_factory=list)

def merge(self, decoding_step, transformer: nnx.Module):
def merge(self, decoding_step, intermediates: nnx.State):
"""Merges the intermediate activations from one step."""
if self.embeddings is not None:
try:
self.embeddings = self.embeddings.at[:, decoding_step + 1, ...].set(
transformer.embeddings[0][:, 0, ...]
intermediates.embeddings[0][:, 0, ...]
)
except AttributeError as exc:
raise ValueError(
'Embeddings are not in the step intermediates.'
) from exc
if len(self.layers) != len(transformer.layers):
if len(self.layers) != len(intermediates.layers):
raise ValueError(
'Number of layers in the transformer and intermediates do not match.'
)
for layer_intermediates, layer_module in zip(
self.layers, transformer.layers
for layer_intermediates, layer_idx in zip(
self.layers, intermediates.layers
):
layer_module = intermediates.layers[layer_idx]
layer_intermediates.merge(decoding_step, layer_module)

def trim(self, max_length: int):
Expand Down Expand Up @@ -136,6 +136,15 @@ class SowConfig:
# We use a sparse representation here to save memory.
attn_logits_topk: int = 0

def is_enabled(self):
return any([
self.embeddings,
self.rs_after_attention,
self.rs_after_ffw,
self.mlp_hidden_topk,
self.attn_logits_topk,
])

def maybe_sow_embeddings(
self,
embeddings: jax.Array,
Expand Down Expand Up @@ -170,8 +179,13 @@ def maybe_sow_mlp_hidden_topk(
):
"""Sows top-absolute-k activations in a mlp hidden layer if configured."""
if self.mlp_hidden_topk:
_, indices = jax.lax.top_k(jnp.abs(activations), self.mlp_hidden_topk)
values = jnp.take_along_axis(activations, indices, axis=-1)
shd = jax.typeof(activations).sharding
top_k = lambda x: jax.lax.top_k(x, self.mlp_hidden_topk)
_, indices = auto_axes(out_sharding=shd)(top_k)(jnp.abs(activations))
take_aa = lambda x, i, axis: jnp.take_along_axis(x, i, axis)
values = auto_axes(out_sharding=shd)(take_aa)(
activations, indices, axis=-1
)
module.sow(nnx.Intermediate, 'hidden_topk_values', values)
module.sow(nnx.Intermediate, 'hidden_topk_indices', indices)

Expand All @@ -182,6 +196,64 @@ def maybe_sow_attn_logits_topk(
):
"""Sows top-k attention logits if configured."""
if self.attn_logits_topk:
values, indices = jax.lax.top_k(logits, self.attn_logits_topk)
shd = jax.typeof(logits).sharding
top_k = lambda x: jax.lax.top_k(x, self.attn_logits_topk)
values, indices = auto_axes(out_sharding=shd)(top_k)(logits)
module.sow(nnx.Intermediate, 'logits_topk_values', values)
module.sow(nnx.Intermediate, 'logits_topk_indices', indices)


def init_intermediates(
batch_size: int,
buffer_size: int,
embed_dim: int,
num_layers: int,
num_heads: int,
sow_config: SowConfig,
dtype: jnp.dtype = jnp.float32,
) -> TransformerIntermediates:
"""Initializes the intermediate activations that will be filled."""
intermediates = TransformerIntermediates()
residual_stream_dummy = jnp.zeros(
(batch_size, buffer_size, embed_dim),
dtype=dtype,
)
if sow_config.embeddings:
intermediates.embeddings = residual_stream_dummy
for _ in range(num_layers):
layer_intermediates = LayerIntermediates()
if sow_config.rs_after_attention:
layer_intermediates.rs_after_attention = residual_stream_dummy
if sow_config.rs_after_ffw:
layer_intermediates.rs_after_ffw = residual_stream_dummy
if sow_config.attn_logits_topk:
shape = (
batch_size,
buffer_size,
num_heads,
sow_config.attn_logits_topk,
)
layer_intermediates.attn_logits_topk_values = jnp.zeros(
shape,
dtype=dtype,
)
layer_intermediates.attn_logits_topk_indices = jnp.zeros(
shape,
dtype=jnp.int32,
)
if sow_config.mlp_hidden_topk:
shape = (
batch_size,
buffer_size,
sow_config.mlp_hidden_topk,
)
layer_intermediates.mlp_hidden_topk_values = jnp.zeros(
shape,
dtype=dtype,
)
layer_intermediates.mlp_hidden_topk_indices = jnp.zeros(
shape,
dtype=jnp.int32,
)
intermediates.layers.append(layer_intermediates)
return intermediates
7 changes: 4 additions & 3 deletions examples/gemma/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,10 @@ class TokenizeOpNumpy:
def __call__(self, features: dict[str, Any]) -> dict[str, Any]:
for k in self.data_keys:
features[k] = self.sp_processor.EncodeAsIds(features[k])
features[k].insert(0, self.sp_processor.bos_id)
features[k].append(self.sp_processor.eos_id)
features[k] = np.array(features[k], dtype=np.int32)
features[k] = np.array(
[self.sp_processor.bos_id()] + features[k] + [self.sp_processor.eos_id()],
dtype=np.int32
)
return features


Expand Down
Loading
Loading