diff --git a/examples/gemma/configs/gemma3_270m_sow.py b/examples/gemma/configs/gemma3_270m_sow.py
new file mode 100644
index 000000000..f9cd72354
--- /dev/null
+++ b/examples/gemma/configs/gemma3_270m_sow.py
@@ -0,0 +1,129 @@
+# Copyright 2024 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Default hyperparameter configuration."""
+
+import dataclasses
+
+from train import TrainConfig
+
+
+@dataclasses.dataclass(unsafe_hash=True)
+class Config:
+  # Path to load or store sentencepiece vocab file.
+  vocab_path: str | None = None
+  # Vocabulary size if `vocab_path` is not given.
+  vocab_size: int = 35_008  # lm1b dataset vocab size: 35913 (Gemma expected vocab size: 262_144)
+  # Maximum number of characters to use for training.
+  max_corpus_chars: int = 10**7
+  # Name of TFDS translation dataset to use.
+  dataset_name: str = 'lm1b'
+  # Optional name of TFDS translation dataset to use for evaluation.
+  eval_dataset_name: str = 'lm1b'
+  # Optional name of TFDS split to use for evaluation.
+  eval_split: str = 'test'
+  # Per device batch size for training.
+  per_device_batch_size: int = 16
+  # Per device batch size for evaluation.
+  eval_per_device_batch_size: int = 16
+  # Grain prefetch number of workers.
+  prefetch_num_workers: int | None = None
+
+  # Prompts for language model sampling.
+  prompts: tuple[str, ...] = (
+      'Paris is the capital',
+      'Flax is a',
+      # From train set:
+      'The shutdown was aimed at creating efficiencies as',
+      # -> the plant was already operating at its maximum capacity of 3,000 tonnes of cellulose paste per day
+      'A big theme of this hire is that there are parts of',
+      # -> our operations that to use a pretty trite phrase , need to be taken to the next level ...
+
+      # From test set:
+      'Because of Bear Stearns , many analysts are',
+      # -> raising the odds that a 2008 recession could be worse than expected
+      'Next month , the Brazilian bourse',
+      # -> opens a London office
+  )
+  # Temperature for top-p sampling.
+  sampling_temperature: float = 0.0
+  # Top-p sampling threshold.
+  sampling_top_p: float = 0.95
+
+  # Number of steps to take during training.
+  num_train_steps: int = 100_000
+  # Number of steps to take during evaluation.
+  # Large enough to evaluate all samples: 306_688 / (32 * 8) = 1198.
+  # num_eval_steps: int = 2_000
+  num_eval_steps: int = 500
+  # Number of steps to generate predictions.
+  # -1 will use the whole eval dataset.
+  num_predict_steps: int = 50
+  # Base learning rate.
+  learning_rate: float = 0.0016
+  # Linear learning rate warmup.
+  warmup_steps: int = 1000
+  # Cross entropy loss label smoothing.
+  label_smoothing: float = 0.0
+  # Decay factor for AdamW-style weight decay.
+  weight_decay: float = 0.1
+  # Maximum length cutoff for training examples.
+  max_target_length: int = 1024
+  # Maximum length cutoff for eval examples.
+  max_eval_target_length: int = 1024
+
+  # Gemma transformer name.
+  # Possible values defined in transformer.TransformerConfig:
+  # (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_270m, gemma3_1b, gemma3_4b, ...)
+  transformer_name: str | None = "gemma3_270m"
+  # Alternatively, define the model using a dict of parameters.
+  transformer_params: dict | None = None
+
+  # Whether to save model checkpoints.
+  save_checkpoints: bool = True
+  # Whether to restore from existing model checkpoints.
+  restore_checkpoints: bool = True
+  # Save a checkpoint every this number of steps.
+  checkpoint_every_steps: int = 10_000
+  # Frequency of eval during training, e.g. every 2_000 steps.
+  eval_every_steps: int = 2_000
+  # Use bfloat16 mixed precision training instead of float32.
+  use_bfloat16: bool = True
+  # Integer for PRNG random seed.
+  seed: int = 0
+
+  # Parallelism.
+  mesh_axes: tuple[str, ...] = ('fsdp', 'tensor')
+  data_sharding: tuple[str, ...] = ('fsdp',)
+
+  fsdp_parallelism: int = -1
+  tensor_parallelism: int = 1
+
+  sow_config: dict = dataclasses.field(
+      default_factory=lambda: {
+          "rs_after_attention": True,
+          "rs_after_ffw": True,
+          "attn_logits_topk": 5,
+          "mlp_hidden_topk": 5,
+      }
+  )
+
+  def replace(self, **kwargs):
+    return dataclasses.replace(self, **kwargs)
+
+
+def get_config() -> TrainConfig:
+  """Get the default hyperparameter configuration."""
+  config = Config()
+  return TrainConfig(**dataclasses.asdict(config))
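
Note (commentary, not part of the patch): the `sow_config` dict above is what `train.py` later unpacks into `sow_lib.SowConfig` (see the train.py hunks below). A minimal sketch of loading this config, assuming it is run from `examples/gemma` so that `configs.gemma3_270m_sow` is importable:

    from configs import gemma3_270m_sow

    train_config = gemma3_270m_sow.get_config()
    assert train_config.transformer_name == "gemma3_270m"
    # The plain dict survives dataclasses.asdict() and is later unpacked as
    # sow_lib.SowConfig(**train_config.sow_config) in train.py.
    print(train_config.sow_config)
    # {'rs_after_attention': True, 'rs_after_ffw': True,
    #  'attn_logits_topk': 5, 'mlp_hidden_topk': 5}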
diff --git a/examples/gemma/sampler.py b/examples/gemma/sampler.py
index 1ab336d15..2bad6260c 100644
--- a/examples/gemma/sampler.py
+++ b/examples/gemma/sampler.py
@@ -223,7 +223,8 @@ def _sample_step(
     last_token = last_token.reshape((batch_size, 1))
 
     transformer = nnx.merge(graphdef, params)
-    logits, cache = transformer(
+    forward = nnx.capture(transformer, nnx.Intermediate)
+    (logits, cache), intermediates = forward(
         last_token,
         step_positions,
         sampler_state.cache,
@@ -267,7 +268,7 @@ def sample_best(logits):
     logits_buffer = sampler_state.logits_buffer
 
     if sampler_state.intermediates is not None:
-      sampler_state.intermediates.merge(decoding_step, transformer)
+      sampler_state.intermediates.merge(decoding_step, intermediates)
 
     done = sampler_state.done | jnp.equal(
         token_buffer[:, decoding_step + 1], self.vocab.eos_id()
@@ -350,6 +351,18 @@ def init_sample_state(
     else:
       logits_buffer = None
 
+    intermediates = None
+    if self.transformer.sow_config.is_enabled():
+      intermediates = sow_lib.init_intermediates(
+          batch_size,
+          buffer_size,
+          self.transformer.embed_dim,
+          self.transformer.num_layers,
+          self.transformer.config.num_heads,
+          self.transformer.sow_config,
+          dtype=dtype,
+      )
+
     return _SamplingState(
         decoding_step=0,
         num_input_tokens=num_input_tokens,
@@ -364,9 +377,7 @@ def init_sample_state(
         done=done,
         total_sampling_steps=total_sampling_steps,
         forbidden_token_ids=forbidden_token_ids,
-        intermediates=self.transformer.init_intermediates(
-            batch_size, buffer_size, self.transformer.sow_config, dtype=dtype
-        ),
+        intermediates=intermediates,
         temperature=temperature,
        top_p=top_p,
        seed=seed,
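
Note (commentary, not part of the patch): the key change above is that `nnx.capture(transformer, nnx.Intermediate)` wraps the forward pass so it returns `((logits, cache), intermediates)`, where `intermediates` is an `nnx.State` of the sown values rather than state left on the module. That is why the `merge` methods in `sow_lib` below now take `nnx.State` instead of `nnx.Module`, and why the layer loop indexes `intermediates.layers[layer_idx]` (iterating an `nnx.State` mapping yields its keys). A toy sketch of the capture pattern, assuming `nnx.capture` behaves as it is used in this patch:

    import jax.numpy as jnp
    from flax import nnx

    class Toy(nnx.Module):
      def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(4, 4, rngs=rngs)

      def __call__(self, x):
        h = self.linear(x)
        self.sow(nnx.Intermediate, 'hidden', h)  # same mechanism as sow_lib
        return h

    forward = nnx.capture(Toy(nnx.Rngs(0)), nnx.Intermediate)
    out, intermediates = forward(jnp.ones((1, 4)))
    # intermediates.hidden is a tuple with one entry per sow() call,
    # hence the trailing [0] in LayerIntermediates.merge below.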
diff --git a/examples/gemma/sow_lib.py b/examples/gemma/sow_lib.py
index 7580cdfe2..3f869b48d 100644
--- a/examples/gemma/sow_lib.py
+++ b/examples/gemma/sow_lib.py
@@ -18,6 +18,7 @@
 from flax import nnx
 import jax
 import jax.numpy as jnp
+from jax.sharding import auto_axes
 
 
 @jax.tree_util.register_dataclass
@@ -35,7 +36,7 @@ class LayerIntermediates:
   attn_logits_topk_values: jax.Array | None = None
   attn_logits_topk_indices: jax.Array | None = None
 
-  def merge(self, decoding_step, layer: nnx.Module):
+  def merge(self, decoding_step, layer: nnx.State):
     """Merges the intermediate activations from one step."""
     for field in dataclasses.fields(self.__class__):
@@ -47,9 +48,7 @@ def merge(self, decoding_step, layer: nnx.Module):
       # sub-module.
       try:
         if field.name.startswith('attn_'):
-          step_value = getattr(
-              layer.attn, field.name.replace('attn_', '')
-          )[0]
+          step_value = getattr(layer.attn, field.name.replace('attn_', ''))[0]
         elif field.name.startswith('mlp_'):
           step_value = getattr(layer.mlp, field.name.replace('mlp_', ''))[0]
         else:
@@ -86,24 +85,25 @@ class TransformerIntermediates:
   # Intermediate activations of each layer.
   layers: list[LayerIntermediates] = dataclasses.field(default_factory=list)
 
-  def merge(self, decoding_step, transformer: nnx.Module):
+  def merge(self, decoding_step, intermediates: nnx.State):
     """Merges the intermediate activations from one step."""
     if self.embeddings is not None:
       try:
         self.embeddings = self.embeddings.at[:, decoding_step + 1, ...].set(
-            transformer.embeddings[0][:, 0, ...]
+            intermediates.embeddings[0][:, 0, ...]
         )
       except AttributeError as exc:
         raise ValueError(
             'Embeddings are not in the step intermediates.'
         ) from exc
-    if len(self.layers) != len(transformer.layers):
+    if len(self.layers) != len(intermediates.layers):
       raise ValueError(
           'Number of layers in the transformer and intermediates do not match.'
       )
-    for layer_intermediates, layer_module in zip(
-        self.layers, transformer.layers
+    for layer_intermediates, layer_idx in zip(
+        self.layers, intermediates.layers
     ):
+      layer_module = intermediates.layers[layer_idx]
       layer_intermediates.merge(decoding_step, layer_module)
 
   def trim(self, max_length: int):
@@ -136,6 +136,15 @@ class SowConfig:
   # We use a sparse representation here to save memory.
   attn_logits_topk: int = 0
 
+  def is_enabled(self):
+    return any([
+        self.embeddings,
+        self.rs_after_attention,
+        self.rs_after_ffw,
+        self.mlp_hidden_topk,
+        self.attn_logits_topk,
+    ])
+
   def maybe_sow_embeddings(
       self,
       embeddings: jax.Array,
@@ -170,8 +179,13 @@ def maybe_sow_mlp_hidden_topk(
   ):
     """Sows top-absolute-k activations in a mlp hidden layer if configured."""
     if self.mlp_hidden_topk:
-      _, indices = jax.lax.top_k(jnp.abs(activations), self.mlp_hidden_topk)
-      values = jnp.take_along_axis(activations, indices, axis=-1)
+      shd = jax.typeof(activations).sharding
+      top_k = lambda x: jax.lax.top_k(x, self.mlp_hidden_topk)
+      _, indices = auto_axes(out_sharding=shd)(top_k)(jnp.abs(activations))
+      take_aa = lambda x, i, axis: jnp.take_along_axis(x, i, axis)
+      values = auto_axes(out_sharding=shd)(take_aa)(
+          activations, indices, axis=-1
+      )
       module.sow(nnx.Intermediate, 'hidden_topk_values', values)
       module.sow(nnx.Intermediate, 'hidden_topk_indices', indices)
 
@@ -182,6 +196,64 @@ def maybe_sow_attn_logits_topk(
   ):
     """Sows top-k attention logits if configured."""
     if self.attn_logits_topk:
-      values, indices = jax.lax.top_k(logits, self.attn_logits_topk)
+      shd = jax.typeof(logits).sharding
+      top_k = lambda x: jax.lax.top_k(x, self.attn_logits_topk)
+      values, indices = auto_axes(out_sharding=shd)(top_k)(logits)
       module.sow(nnx.Intermediate, 'logits_topk_values', values)
       module.sow(nnx.Intermediate, 'logits_topk_indices', indices)
+
+
+def init_intermediates(
+    batch_size: int,
+    buffer_size: int,
+    embed_dim: int,
+    num_layers: int,
+    num_heads: int,
+    sow_config: SowConfig,
+    dtype: jnp.dtype = jnp.float32,
+) -> TransformerIntermediates:
+  """Initializes the intermediate activations that will be filled."""
+  intermediates = TransformerIntermediates()
+  residual_stream_dummy = jnp.zeros(
+      (batch_size, buffer_size, embed_dim),
+      dtype=dtype,
+  )
+  if sow_config.embeddings:
+    intermediates.embeddings = residual_stream_dummy
+  for _ in range(num_layers):
+    layer_intermediates = LayerIntermediates()
+    if sow_config.rs_after_attention:
+      layer_intermediates.rs_after_attention = residual_stream_dummy
+    if sow_config.rs_after_ffw:
+      layer_intermediates.rs_after_ffw = residual_stream_dummy
+    if sow_config.attn_logits_topk:
+      shape = (
+          batch_size,
+          buffer_size,
+          num_heads,
+          sow_config.attn_logits_topk,
+      )
+      layer_intermediates.attn_logits_topk_values = jnp.zeros(
+          shape,
+          dtype=dtype,
+      )
+      layer_intermediates.attn_logits_topk_indices = jnp.zeros(
+          shape,
+          dtype=jnp.int32,
+      )
+    if sow_config.mlp_hidden_topk:
+      shape = (
+          batch_size,
+          buffer_size,
+          sow_config.mlp_hidden_topk,
+      )
+      layer_intermediates.mlp_hidden_topk_values = jnp.zeros(
+          shape,
+          dtype=dtype,
+      )
+      layer_intermediates.mlp_hidden_topk_indices = jnp.zeros(
+          shape,
+          dtype=jnp.int32,
+      )
+    intermediates.layers.append(layer_intermediates)
+  return intermediates
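
Note (commentary, not part of the patch): `maybe_sow_*_topk` keep only `(values, indices)` pairs, a sparse representation that is much smaller than the full activation tensor; the `auto_axes(out_sharding=...)` wrapper temporarily drops into automatic sharding so `jax.lax.top_k` can run under explicit-sharding mode while still producing outputs with the original sharding. A dense approximation can be rebuilt from the sparse pair, e.g.:

    import jax
    import jax.numpy as jnp

    activations = jax.random.normal(jax.random.key(0), (2, 1, 16))  # (batch, seq, hidden)
    k = 5
    # top_k over |x| finds the largest-magnitude activations; the gather then
    # recovers their signed values, as in maybe_sow_mlp_hidden_topk.
    _, indices = jax.lax.top_k(jnp.abs(activations), k)
    values = jnp.take_along_axis(activations, indices, axis=-1)

    # Scatter the k kept values back into a zero tensor of the original shape.
    dense = jnp.put_along_axis(
        jnp.zeros_like(activations), indices, values, axis=-1, inplace=False
    )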
diff --git a/examples/gemma/tokenizer.py b/examples/gemma/tokenizer.py
index de5b36385..0a495086b 100644
--- a/examples/gemma/tokenizer.py
+++ b/examples/gemma/tokenizer.py
@@ -180,9 +180,10 @@ class TokenizeOpNumpy:
   def __call__(self, features: dict[str, Any]) -> dict[str, Any]:
     for k in self.data_keys:
       features[k] = self.sp_processor.EncodeAsIds(features[k])
-      features[k].insert(0, self.sp_processor.bos_id)
-      features[k].append(self.sp_processor.eos_id)
-      features[k] = np.array(features[k], dtype=np.int32)
+      features[k] = np.array(
+          [self.sp_processor.bos_id()] + features[k] + [self.sp_processor.eos_id()],
+          dtype=np.int32,
+      )
     return features
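
Note (commentary, not part of the patch): besides merging three statements into one, this hunk fixes a real bug: `sp_processor.bos_id` and `sp_processor.eos_id` without parentheses are bound methods in the sentencepiece Python API, so the old code appended method objects instead of token ids. A tiny illustration with made-up ids:

    import numpy as np

    ids = [17, 23, 5]      # as returned by EncodeAsIds (a plain Python list)
    bos_id, eos_id = 2, 1  # illustrative values only
    framed = np.array([bos_id] + ids + [eos_id], dtype=np.int32)
    # -> array([ 2, 17, 23,  5,  1], dtype=int32)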
diff --git a/examples/gemma/train.py b/examples/gemma/train.py
index ca7850db9..83d30bfc6 100644
--- a/examples/gemma/train.py
+++ b/examples/gemma/train.py
@@ -21,6 +21,7 @@
 # pytype: disable=attribute-error
 import contextlib
 import dataclasses
+from functools import partial
 from pathlib import Path
 import time
@@ -145,7 +147,8 @@ def jax_train_step(
     train_metrics: nnx.Metric,
     label_smoothing: float = 0.0,
     pad_id: int = 0,
-) -> tuple[nnx.Module, nnx.Optimizer, nnx.Rngs, nnx.Metric]:
+    with_capture: bool = False,
+) -> tuple[nnx.State, nnx.State | None]:
   """Perform a single training step."""
   # X_position and X_segmentation are needed only when using "packed examples"
   # where multiple sequences are packed into the same example with this
@@ -174,23 +177,33 @@ def loss_fn(params, rngs):
     """loss function used for training."""
     module = nnx.merge(graphdef, params, nondiff)
-
-    logits, _ = module(
+    if with_capture:
+      forward = nnx.capture(module, nnx.Intermediate)
+    else:
+      forward = module
+    output = forward(
         inputs,
         positions=inputs_positions,
         attention_mask=attention_mask,
         cache=None,
         rngs=rngs,
     )
+    # output is (preds, cache) if with_capture=False
+    # and is ((preds, cache), intermediates) if with_capture=True
+    if with_capture:
+      (logits, _), intermediates = output
+    else:
+      logits, _ = output
+      intermediates = None
 
     loss_per_sample = compute_weighted_cross_entropy(
         logits, targets, weights, label_smoothing
     )
     mean_loss = loss_per_sample.mean()
-    return mean_loss, (loss_per_sample, logits)
+    return mean_loss, (loss_per_sample, logits, intermediates)
 
   grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
-  (_, (loss_per_sample, logits)), grads = grad_fn(params, rngs.fork())
+  (_, (loss_per_sample, logits, intermediates)), grads = grad_fn(params, rngs.fork())
   optimizer.update(model, grads)
 
   # Apply pad mask on logits and targets for metrics computation
@@ -200,7 +213,7 @@ def loss_fn(params, rngs):
       labels=targets,
       mask={"accuracy": input_mask},
   )
-  return nnx.state((model, optimizer, rngs, train_metrics))
+  return nnx.state((model, optimizer, rngs, train_metrics)), intermediates
 
 
 def nnx_train_step(
@@ -211,7 +224,8 @@ def nnx_train_step(
     train_metrics: nnx.Metric,
     label_smoothing: float = 0.0,
     pad_id: int = 0,
-) -> None:
+    with_capture: bool = False,
+) -> nnx.State | None:
   """Perform a single training step."""
   # X_position and X_segmentation are needed only when using "packed examples"
   # where multiple sequences are packed into the same example with this
@@ -237,23 +251,33 @@ def loss_fn(model, rngs):
     """loss function used for training."""
-
-    logits, _ = model(
+    if with_capture:
+      forward = nnx.capture(model, nnx.Intermediate)
+    else:
+      forward = model
+    output = forward(
         inputs,
         positions=inputs_positions,
         attention_mask=attention_mask,
         cache=None,
         rngs=rngs,
     )
+    # output is (preds, cache) if with_capture=False
+    # and is ((preds, cache), intermediates) if with_capture=True
+    if with_capture:
+      (logits, _), intermediates = output
+    else:
+      logits, _ = output
+      intermediates = None
 
     loss_per_sample = compute_weighted_cross_entropy(
         logits, targets, weights, label_smoothing
     )
     mean_loss = loss_per_sample.mean()
-    return mean_loss, (loss_per_sample, logits)
+    return mean_loss, (loss_per_sample, logits, intermediates)
 
   grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
-  (_, (loss_per_sample, logits)), grads = grad_fn(model, rngs.fork())
+  (_, (loss_per_sample, logits, intermediates)), grads = grad_fn(model, rngs.fork())
   optimizer.update(model, grads)
 
   # Apply pad mask on logits and targets for metrics computation
@@ -263,6 +287,7 @@ def loss_fn(model, rngs):
       labels=targets,
       mask={"accuracy": input_mask},
   )
+  return intermediates
 
 
 def eval_step(
@@ -351,6 +376,34 @@ def inspect_sharding(x):
       logging.debug(f" - Device {key.id}: {value}")
 
 
+def _filter_intermediates(x):
+  # Assume a single host, so we skip the all-gather across hosts.
+  # Converting to np removes all shardings.
+  if jnp.isdtype(x.dtype, "real floating"):
+    x = x.astype(jnp.float32)
+  x = np.asarray(x)
+  x = x[0, 0, ...].reshape(-1)
+  if len(x) < 10:
+    return x
+  if np.isdtype(x.dtype, "integral"):
+    return x.min().item(), x.max().item()
+  # Mask -inf coming from attention masks.
+  x = x[x > -1e10]
+  if len(x) > 0:
+    return x.min().item(), x.mean().item(), x.max().item()
+  return x
+
+
+def display_intermediates(intermediates, placeholder=None):
+  if placeholder is not None:
+    placeholder = jax.tree.map(lambda p: _filter_intermediates(p), placeholder)
+    placeholder = nnx.to_flat_state(placeholder)
+    for k, v in placeholder:
+      logging.info(f"{k}, stats: {v.get_value()[0]}")
+
+  return intermediates
+
+
 def train_and_evaluate(
     config: TrainConfig, workdir: str, chpt_bucket: str | None = None
 ):
@@ -431,7 +484,11 @@ def train_and_evaluate(
   rngs = nnx.Rngs(params=config.seed, dropout=config.seed)
 
   with jax.set_mesh(mesh):
-    model = transformer_lib.Transformer(model_config, rngs=rngs)
+    model = transformer_lib.Transformer(
+        model_config,
+        rngs=rngs,
+        sow_config=sow_lib.SowConfig(**config.sow_config),
+    )
     optimizer = nnx.Optimizer(
         model,
         tx=optax.adamw(
@@ -525,13 +582,12 @@ def train_and_evaluate(
       accuracy=nnx.metrics.Accuracy(),
   )
 
-  jit_fn = (
-      nnx.jit if config.use_nnx_transforms in ("all", "jit-only") else jax.jit
-  )
+  nnx_train_step_ = partial(nnx_train_step, with_capture=bool(config.sow_config))
+  jax_train_step_ = partial(jax_train_step, with_capture=bool(config.sow_config))
+
+  jit_fn = nnx.jit if config.use_nnx_transforms in ("all", "jit-only") else jax.jit
   jit_train_step = jit_fn(
-      nnx_train_step
-      if config.use_nnx_transforms in ("all", "grad-only")
-      else jax_train_step,
+      nnx_train_step_ if config.use_nnx_transforms in ("all", "grad-only") else jax_train_step_,
       static_argnames=("label_smoothing", "pad_id"),
       donate_argnames=("model", "optimizer"),
   )
@@ -548,6 +604,7 @@ def train_and_evaluate(
   # Main Train Loop
   # ---------------------------------------------------------------------------
   logging.info('Starting training loop.')
+  prev_intermediates = None  # for async display of intermediates
   hooks = []
   report_progress = periodic_actions.ReportProgress(
       num_train_steps=config.num_train_steps, writer=writer
@@ -573,7 +630,7 @@ def train_and_evaluate(
         batch = next(train_iter)
 
       with report_progress.timed('train_step'):
-        updates = jit_train_step(
+        output = jit_train_step(
             model,
             optimizer,
             rngs,
@@ -582,8 +639,12 @@ def train_and_evaluate(
             0.0,  # label_smoothing
             encoder.pad_id(),  # pad_id
         )
-        if updates is not None:
+
+        if isinstance(output, tuple):
+          updates, intermediates = output
           nnx.update((model, optimizer, rngs, train_metrics), updates)
+        else:
+          intermediates = output
 
       # Quick indication that training is happening.
       if step < 20:
@@ -597,6 +658,10 @@ def train_and_evaluate(
       for h in hooks:
         h(step)
 
+      # Display intermediates every 100 steps.
+      if step % 100 == 0:
+        prev_intermediates = display_intermediates(intermediates, prev_intermediates)
+
       # Periodic metric handling.
       if (step > 0 and step % config.eval_every_steps == 0) or is_last_step:
         with report_progress.timed('training_metrics'):
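
Note (commentary, not part of the patch): the first hunk originally added `import time` even though `import time` already appears two lines below; the duplicate is dropped here and the hunk count adjusted. Also, `with_capture` was bound as `config.sow_config is not None`, which is always true now that `sow_config` defaults to an empty dict (see train_cfg.py below), so it is bound as `bool(config.sow_config)` instead. Finally, `display_intermediates` is deliberately one display-step behind: it logs statistics of the previous call's intermediates (long since materialized on device) and returns the current ones for next time, so logging never forces a blocking sync on the step that was just dispatched. The pattern, with `run_step` as a hypothetical stand-in for `jit_train_step`:

    prev = None
    for step in range(300):
      intermediates = run_step()  # hypothetical stand-in for jit_train_step
      if step % 100 == 0:
        # Logs stats of `prev` (computed 100 steps ago, already ready),
        # then keeps `intermediates` for the next display step.
        prev = display_intermediates(intermediates, prev)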
diff --git a/examples/gemma/train_cfg.py b/examples/gemma/train_cfg.py
index eca94ca61..ec5a7d1ad 100644
--- a/examples/gemma/train_cfg.py
+++ b/examples/gemma/train_cfg.py
@@ -105,7 +105,7 @@ class TrainConfig:
   use_nnx_tree_mode: bool = False
   use_nnx_transforms: str = "no"  # ["all", "no", "grad-only", "jit-only"]
 
-  sow_config: tuple[str, ...] | None = None
+  sow_config: dict = dataclasses.field(default_factory=dict)
 
   def replace(self, **kwargs):
     return dataclasses.replace(self, **kwargs)
diff --git a/examples/gemma/transformer.py b/examples/gemma/transformer.py
index 222e49886..7da11e74d 100644
--- a/examples/gemma/transformer.py
+++ b/examples/gemma/transformer.py
@@ -130,6 +130,7 @@ def __init__(
     )
     self.final_dropout = nnx.Dropout(config.dropout_rate, deterministic=False)
     self.final_logits_softcap = config.final_logit_softcap
+    self.config = config
     self.sow_config = sow_config
     self.shd_config = config.shd_config
 
@@ -214,59 +215,6 @@ def init_cache(
         for i in range(self.num_layers)
     }
 
-  def init_intermediates(
-      self,
-      batch_size: int,
-      buffer_size: int,
-      sow_config: sow_lib.SowConfig,
-      dtype: jnp.dtype = jnp.float32,
-  ) -> sow_lib.TransformerIntermediates:
-    """Initializes the intermediate activations that will be filled."""
-    intermediates = sow_lib.TransformerIntermediates()
-    residual_stream_dummy = jnp.zeros(
-        (batch_size, buffer_size, self.embed_dim),
-        dtype=dtype,
-    )
-    if sow_config.embeddings:
-      intermediates.embeddings = residual_stream_dummy
-    for layer in self.layers:
-      layer_intermediates = sow_lib.LayerIntermediates()
-      if sow_config.rs_after_attention:
-        layer_intermediates.rs_after_attention = residual_stream_dummy
-      if sow_config.rs_after_ffw:
-        layer_intermediates.rs_after_ffw = residual_stream_dummy
-      if sow_config.attn_logits_topk:
-        shape = (
-            batch_size,
-            buffer_size,
-            layer.attn.num_heads,
-            sow_config.attn_logits_topk,
-        )
-        layer_intermediates.attn_logits_topk_values = jnp.zeros(
-            shape,
-            dtype=dtype,
-        )
-        layer_intermediates.attn_logits_topk_indices = jnp.zeros(
-            shape,
-            dtype=jnp.int32,
-        )
-      if sow_config.mlp_hidden_topk:
-        shape = (
-            batch_size,
-            buffer_size,
-            sow_config.mlp_hidden_topk,
-        )
-        layer_intermediates.mlp_hidden_topk_values = jnp.zeros(
-            shape,
-            dtype=dtype,
-        )
-        layer_intermediates.mlp_hidden_topk_indices = jnp.zeros(
-            shape,
-            dtype=jnp.int32,
-        )
-      intermediates.layers.append(layer_intermediates)
-    return intermediates
-
 
 def make_causal_attn_mask(
     input_mask: Array,
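
Note (commentary, not part of the patch): with `init_intermediates` moved out of `Transformer` into `sow_lib`, callers now pass the model dimensions explicitly (the sampler does this in `init_sample_state` above). An end-to-end sketch, assuming `examples/gemma` is on `PYTHONPATH` and that `TransformerConfig` exposes a `gemma3_270m()` constructor as the config comment suggests:

    from flax import nnx

    import sow_lib
    import transformer as transformer_lib

    sow_config = sow_lib.SowConfig(rs_after_attention=True, attn_logits_topk=5)
    config = transformer_lib.TransformerConfig.gemma3_270m()  # assumed classmethod
    model = transformer_lib.Transformer(
        config, rngs=nnx.Rngs(0), sow_config=sow_config
    )

    # What Sampler.init_sample_state now does before decoding:
    intermediates = sow_lib.init_intermediates(
        batch_size=1,
        buffer_size=128,
        embed_dim=model.embed_dim,
        num_layers=model.num_layers,
        num_heads=model.config.num_heads,
        sow_config=sow_config,
    )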