diff --git a/flash_lm/.gitignore b/flash_lm/.gitignore
new file mode 100644
index 000000000..6756efebf
--- /dev/null
+++ b/flash_lm/.gitignore
@@ -0,0 +1 @@
+*checkpoints/
diff --git a/flash_lm/README.md b/flash_lm/README.md
new file mode 100644
index 000000000..cf96aff43
--- /dev/null
+++ b/flash_lm/README.md
@@ -0,0 +1,75 @@
+# flash_lm
+
+Train an on-device LM with MLX.
+
+### Install
+
+Install dependencies:
+
+```
+pip install -r requirements.txt
+```
+
+Install MLX on macOS:
+
+```
+pip install mlx
+```
+
+Or for CUDA:
+
+```
+pip install mlx[cuda13]
+```
+
+### Training
+
+For pretraining:
+
+```
+python pretrain.py
+```
+
+For supervised fine-tuning (SFT), which starts from a pretrained checkpoint (`--checkpoint-dir`, default `checkpoints`):
+
+```
+python sft.py
+```
+
+### Generation
+
+The model can be easily converted to a format compatible with `mlx_lm` for
+generation.
+
+Install `mlx-lm`:
+
+```
+pip install mlx-lm
+```
+
+Or for CUDA:
+
+```
+pip install mlx-lm[cuda13]
+```
+
+Then convert a given checkpoint:
+
+```
+python convert.py --checkpoint-dir path/to/checkpoint --save-dir path/to/mlx_lm_model
+```
+
+Then use any `mlx-lm` command or API:
+
+```
+mlx_lm.generate --model path/to/mlx_lm_model --prompt "Hi"
+```
+
+### Next Steps
+
+To customize the model, change the default config (`configs/base_600m.yaml`) or
+make a new config and pass it with `--config`:
+
+```
+python pretrain.py --config my_custom_config.yaml
+```
diff --git a/flash_lm/configs/base_600m.yaml b/flash_lm/configs/base_600m.yaml
new file mode 100644
index 000000000..3e0c8e790
--- /dev/null
+++ b/flash_lm/configs/base_600m.yaml
@@ -0,0 +1,25 @@
+model:
+  model_type: "transformer"
+  hidden_size: 1024
+  head_dim: 128
+  vocab_size: 128256
+  intermediate_size: 3072
+  num_attention_heads: 16
+  num_key_value_heads: 8
+  num_hidden_layers: 28
+
+seed: 0
+batch_size: 2
+context_size: 2048
+optim: "adam"
+weight_decay: 0.1
+learning_rate: 1e-4
+num_steps: 1000000
+warmup_steps: 1000
+decay_steps: 1000
+max_grad_norm: 5
+data_type: "bfloat16"
+
+steps_per_eval: 100000
+steps_per_report: 10
+steps_per_checkpoint: 100000
diff --git a/flash_lm/configs/fp8_600m.yaml b/flash_lm/configs/fp8_600m.yaml
new file mode 100644
index 000000000..c20fa556f
--- /dev/null
+++ b/flash_lm/configs/fp8_600m.yaml
@@ -0,0 +1,27 @@
+model:
+  model_type: "transformer"
+  hidden_size: 1024
+  head_dim: 128
+  vocab_size: 128256
+  intermediate_size: 3072
+  num_attention_heads: 16
+  num_key_value_heads: 8
+  num_hidden_layers: 28
+  quantization:
+    mode: "mxfp8"
+
+seed: 0
+batch_size: 2
+context_size: 2048
+optim: "adam"
+weight_decay: 0.1
+learning_rate: 1e-4
+num_steps: 1000000
+warmup_steps: 1000
+decay_steps: 1000
+max_grad_norm: 5
+data_type: "bfloat16"
+
+steps_per_eval: 10000
+steps_per_report: 10
+steps_per_checkpoint: 10000
diff --git a/flash_lm/configs/tiny.yaml b/flash_lm/configs/tiny.yaml
new file mode 100644
index 000000000..6d57899c3
--- /dev/null
+++ b/flash_lm/configs/tiny.yaml
@@ -0,0 +1,25 @@
+model:
+  model_type: "transformer"
+  hidden_size: 512
+  head_dim: 128
+  vocab_size: 128256
+  intermediate_size: 512
+  num_attention_heads: 4
+  num_key_value_heads: 2
+  num_hidden_layers: 4
+
+seed: 0
+batch_size: 8
+context_size: 2048
+optim: "adam"
+weight_decay: 0.1
+learning_rate: 1e-4
+num_steps: 1000000
+warmup_steps: 1000
+decay_steps: 1000
+max_grad_norm: 5
+data_type: "bfloat16"
+
+steps_per_eval: 100000
+steps_per_report: 10
+steps_per_checkpoint: 100000
diff --git a/flash_lm/convert.py b/flash_lm/convert.py
new file mode 100644
index 000000000..8ab56c19e
--- /dev/null
+++ b/flash_lm/convert.py
@@ -0,0 +1,45 @@
+import argparse
+import json
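+# convert.py assembles a folder that mlx-lm can load: the tokenizer files, a
+# config.json (including a "model_file" entry naming the copied model
+# definition), the models/transformer.py source, and the model.safetensors
+# weights from the chosen checkpoint.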
+import shutil +from pathlib import Path + +import utils + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Convert a model for use with mlx-lm.", + ) + parser.add_argument( + "--checkpoint-dir", + default="checkpoints", + type=str, + help="Path to checkpoint", + ) + parser.add_argument( + "--save-dir", + default="mlx_lm_checkpoint", + type=str, + help="Location to save the mlx_lm ready model", + ) + args = parser.parse_args() + + tokenizer = utils.load_tokenizer() + save_dir = Path(args.save_dir) + save_dir.mkdir(exist_ok=True, parents=True) + + # Save tokenizer + tokenizer.save_pretrained(save_dir) + + checkpoint_dir = Path(args.checkpoint_dir) + config = utils.load_config(checkpoint_dir).model + config["model_file"] = f"{config['model_type']}.py" + config = dict(sorted(config.items())) + with open(save_dir / "config.json", "w") as fid: + json.dump(config, fid, indent=4) + + for file in [ + "models/transformer.py", + checkpoint_dir / "model.safetensors", + ]: + dst_path = save_dir / Path(file).name + shutil.copy(file, dst_path) diff --git a/flash_lm/models/transformer.py b/flash_lm/models/transformer.py new file mode 100644 index 000000000..a8f316360 --- /dev/null +++ b/flash_lm/models/transformer.py @@ -0,0 +1,233 @@ +import inspect +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + + +@dataclass +class ModelArgs: + vocab_size: int + hidden_size: int + head_dim: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + num_key_value_heads: int + mlp_gate_dim: Optional[int] = None + rope_theta: float = 10_000 + rms_norm_eps: float = 1e-5 + tie_word_embeddings: bool = True + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + self.head_dim = head_dim = args.head_dim + self.scale = head_dim**-0.5 + + self.qkv_proj = nn.Linear( + dim, (n_heads + 2 * n_kv_heads) * head_dim, bias=False + ) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + + self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps) + self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps) + self.rope = nn.RoPE( + head_dim, + base=args.rope_theta, + traditional=False, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[Any] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = mx.split( + self.qkv_proj(x), + [ + self.n_heads * self.head_dim, + (self.n_heads + self.n_kv_heads) * self.head_dim, + ], + axis=-1, + ) + + queries = self.q_norm(queries.reshape(B, L, self.n_heads, -1)).transpose( + 0, 2, 1, 3 + ) + keys = self.k_norm(keys.reshape(B, L, self.n_kv_heads, -1)).transpose( + 0, 2, 1, 3 + ) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return 
self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_up_proj = nn.Linear(dim, 2 * hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + + def __call__(self, x) -> mx.array: + gate, up = mx.split(self.gate_up_proj(x), 2, axis=-1) + return self.down_proj(nn.silu(gate) * up) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + +class LanguageModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + if cache is None: + cache = [None] * len(self.layers) + + # TODO get that from cache or something + mask = "causal" + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model = LanguageModel(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out = self.model(inputs, cache) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out + + @property + def layers(self): + return self.model.layers + + +class SimpleCache: + def __init__(self): + self.keys = None + self.values = None + self.offset = 0 + + def update_and_fetch(self, keys, values): + if self.keys is None: + self.keys = keys + self.values = values + else: + self.keys = mx.concatenate([self.keys, keys], axis=-2) + self.values = mx.concatenate([self.values, values], axis=-2) + self.offset = self.keys.shape[-2] + + return self.keys, self.values + + +if __name__ == "__main__": + mx.random.seed(0) + args = ModelArgs( + vocab_size=100, + hidden_size=512, + head_dim=32, + num_hidden_layers=4, + intermediate_size=512, + num_attention_heads=4, + num_key_value_heads=2, + ) + model = Model(args) + x = mx.random.randint(0, 100, shape=(4, 1024)) + y = model(x) + mx.eval(y) + + cache = [SimpleCache() for _ in model.layers] + + x = mx.random.randint(0, 100, shape=(1, 1024)) + y = model(x, cache) + mx.eval(y) + + x = mx.random.randint(0, 100, shape=(1, 1)) + y = model(x, cache) + mx.eval(y) diff --git a/flash_lm/pretrain.py b/flash_lm/pretrain.py new file mode 100644 index 000000000..ea2f01901 --- /dev/null +++ b/flash_lm/pretrain.py @@ -0,0 +1,243 @@ +import argparse +import math +import 
os +import time +from functools import partial +from pathlib import Path + +import datasets +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +import numpy as np + +# import evaluate +import utils +import wandb +from mlx.nn.utils import average_gradients +from mlx.utils import tree_map_with_path, tree_reduce + + +def load_data(tokenizer, data_path="allenai/dolma3_mix-150B-1025", valid_size=1000): + group = mx.distributed.init() + size = group.size() + rank = group.rank() + ds = datasets.load_dataset( + data_path, + split="train", + streaming=True, + ) + ds = ds.shard(num_shards=size, index=rank) + ds = ds.shuffle(buffer_size=10000) + + def tokenize(d): + tokens = tokenizer.encode(d["text"], add_special_tokens=False) + tokens.append(tokenizer.eos_token_id) + return {"data": tokens} + + ds = ds.map(tokenize) + local_valid_size = valid_size // size + valid_ds = ds.take(local_valid_size) + train_ds = ds.skip(local_valid_size) + return train_ds, valid_ds + + +def iterate_batches(dataset, context_size, batch_size): + """ + Simply concatenate documents until the batch is full. + """ + dataset = iter(dataset) + seq_len = context_size + 1 + batch = np.empty((batch_size * seq_len), np.int32) + d_next = [] + while True: + i = 0 + while i < len(batch): + if len(d_next) > 0: + d = d_next + d_next = [] + else: + d = next(dataset, None) + if d is None: + break + d = d["data"] + e = i + len(d) + if e > len(batch): + trim = e - len(batch) + d_next = d[-trim:] + d = d[:-trim] + e = len(batch) + batch[i:e] = d + i += len(d) + # Iterator ended + if i < len(batch): + break + yield batch.reshape(batch_size, seq_len) + + +def main(config, save_dir): + + np.random.seed(config.seed) + mx.random.seed(config.seed) + + rank, world_size = utils.init_distributed() + batch_size = config.batch_size + context_size = config.context_size + + # data.download_eval_bundle() + + optimizer = utils.load_optimizer(config) + tokenizer = utils.load_tokenizer() + train_set, valid_set = load_data(tokenizer) + + model = utils.load_model(config.model) + dtype = getattr(mx, config.data_type) + + # Quantize the model if specified in the config + quant_params = set() + if quant := config.model.get("quantization", False): + + def class_predicate(p, m): + if isinstance(m, nn.Linear): + quant_params.add(p + ".weight") + quant_params.add(p + ".scales") + return True + return False + + nn.quantize( + model, + mode=quant["mode"], + quantize_input=True, + class_predicate=class_predicate, + ) + + @mx.compile + def loss_fn(params, sample): + model.update( + tree_map_with_path( + lambda p, x: x.astype(dtype) if p not in quant_params else x, params + ) + ) + inputs = sample[:, :-1] + targets = sample[:, 1:] + + logits = model(inputs).astype(mx.float32) + losses = nn.losses.cross_entropy(logits, targets, reduction="none") + return losses.sum() / targets.size + + state = [optimizer.state, mx.random.state] + + @partial(mx.compile, inputs=state, outputs=state) + def step(sample, params): + loss, grads = mx.value_and_grad(loss_fn)(params, sample) + grads = average_gradients(grads, all_reduce_size=4e9) + grads, grad_norm = optim.clip_grad_norm(grads, max_norm=config.max_grad_norm) + params = optimizer.apply_gradients(grads, params) + return loss, grad_norm, params + + def eval_fn(params, dataset): + data_it = iterate_batches( + dataset, + context_size=context_size, + batch_size=batch_size, + ) + losses = 0 + ntoks = 0 + toks_per_batch = context_size * batch_size + for sample in data_it: + loss = loss_fn(params, mx.array(sample)) + 
losses += loss * toks_per_batch + mx.eval(losses) + ntoks += toks_per_batch + return losses / ntoks + + params = model.trainable_parameters() + nparams = tree_reduce(lambda acc, p: acc + p.size, params, 0) + if rank == 0: + print(f"Model has {nparams} parameters.") + mx.eval(params) + + train_iterator = iterate_batches( + train_set, + context_size=config.context_size, + batch_size=config.batch_size, + ) + + metrics = utils.Metrics() + tokens = 0 + tic = time.perf_counter() + for it, sample in zip(range(0, config.num_steps), train_iterator): + loss, grad_norm, params = step(mx.array(sample), params) + loss = mx.distributed.all_sum(loss) / world_size + grad_norm = mx.distributed.all_sum(grad_norm) / world_size + mx.eval(loss, grad_norm, params, state) + metrics.train_loss.append(loss.item()) + metrics.grad_norm.append(grad_norm.item()) + tokens += config.context_size * config.batch_size * world_size + + if (it + 1) % config.steps_per_report == 0: + toc = time.perf_counter() + metrics.step = it + 1 + metrics.tokens += tokens + metrics.its_per_sec = config.steps_per_report * world_size / (toc - tic) + metrics.toks_per_sec = tokens / (toc - tic) + tokens = 0 + + if (it + 1) % config.steps_per_eval == 0: + # Do the evaluation in the final precision, but only cast the model once + model.update(params) + model.eval() + model.set_dtype(dtype) + eval_params = model.parameters() + mx.eval(eval_params) + loss = eval_fn(eval_params, valid_set) + loss = mx.distributed.all_sum(loss) / world_size + metrics.valid_loss = loss.item() + metrics.valid_ppl = math.exp(metrics.valid_loss) + model.train() + model.update(params) + + if rank == 0: + if (it + 1) % config.steps_per_checkpoint == 0: + utils.save_checkpoint(save_dir, it, params, optimizer, config) + utils.save_checkpoint(save_dir, None, params, optimizer, config) + utils.log_metrics(metrics) + metrics = utils.Metrics(tokens=metrics.tokens) + tic = time.perf_counter() + + if rank == 0: + utils.save_checkpoint(save_dir, None, params, optimizer, config) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Train an LM.") + parser.add_argument( + "--config", + default="configs/base_600m.yaml", + type=str, + help="Experiment config", + ) + parser.add_argument( + "--save-dir", + default="checkpoints", + type=str, + help="Location to save the model and checkpoints", + ) + parser.add_argument( + "--wandb-name", + default=None, + type=str, + help="Name of experiment for wandb", + ) + args = parser.parse_args() + config = utils.load_config(args.config) + if mx.distributed.init().rank() == 0: + wandb_kwargs = dict( + project="flash_lm", + name=args.wandb_name, + tags=["pretrain", config.model["model_type"]], + ) + if args.wandb_name is None: + wandb_kwargs["mode"] = "disabled" + run = wandb.init(**wandb_kwargs) + main(config, args.save_dir) diff --git a/flash_lm/requirements.txt b/flash_lm/requirements.txt new file mode 100644 index 000000000..721921726 --- /dev/null +++ b/flash_lm/requirements.txt @@ -0,0 +1,5 @@ +transformers +datasets +jinja2 +zstandard +wandb diff --git a/flash_lm/sft.py b/flash_lm/sft.py new file mode 100644 index 000000000..aabfab17e --- /dev/null +++ b/flash_lm/sft.py @@ -0,0 +1,258 @@ +import argparse +import math +import os +import random +import time +from functools import partial +from pathlib import Path + +import datasets +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +import numpy as np +import utils +import wandb +from mlx.nn.utils import average_gradients +from mlx.utils import 
tree_map, tree_reduce + + +def buffer_batches(dataset, batch_size): + + def _sort_batch_shuffle(buffer): + buffer.sort(key=lambda x: len(x[0])) + batches = [ + buffer[s : s + batch_size] for s in range(0, len(buffer), batch_size) + ] + random.shuffle(batches) + return batches + + buffer_size = batch_size * 1000 + buffer = [] + for d in dataset: + buffer.append(d) + if len(buffer) >= buffer_size: + for b in _sort_batch_shuffle(buffer): + yield b + buffer = [] + + for b in _sort_batch_shuffle(buffer): + yield b + + +def iterate_batches(dataset, batch_size, max_length=None): + """ + Add full documents into the batch with padding based on the maximum + document length. + """ + + def _collate(seqs, length, dtype): + return np.array([s + [0] * (length - len(s)) for s in seqs], dtype) + + def _round_up(n): + m = 512 + n = m * ((n + m - 1) // m) + if max_length is not None: + n = min(n, max_length) + return n + 1 + + dataset = ((d["tokens"], d["mask"]) for d in dataset) + for batch in buffer_batches(dataset, batch_size): + tokens, masks = zip(*batch) + if max_length is not None: + tokens = [t[: max_length + 1] for t in tokens] + masks = [m[: max_length + 1] for m in masks] + lengths = [len(t) for t in tokens] + length = _round_up(max(lengths)) + yield { + "data": _collate(tokens, length=length, dtype=np.int32), + "mask": _collate(masks, length=length, dtype=bool), + "lengths": lengths, + } + + +def load_data(tokenizer, data_path="allenai/Dolci-Instruct-SFT", valid_size=1000): + group = mx.distributed.init() + size = group.size() + rank = group.rank() + ds = datasets.load_dataset( + data_path, + split="train", + streaming=True, + ) + ds = ds.shard(num_shards=size, index=rank) + ds = ds.shuffle(buffer_size=10000) + + # Tokenize data so that only the assistant generations are not masked. 
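+    # For example, a single exchange is masked roughly as:
+    #   <|im_start|>user ... <|im_end|>    -> False (excluded from the loss)
+    #   <|im_start|>assistant\n (header)   -> False (excluded from the loss)
+    #   assistant reply ... <|im_end|>     -> True  (trained on)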
+    n_mask = len(tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False))
+
+    def tokenize(d):
+        tokens = []
+        mask = []
+        for m in d["messages"]:
+            role = m.get("role")
+            if role == "user":
+                local_tokens = tokenizer.apply_chat_template([m], return_dict=False)
+                mask.extend([False] * len(local_tokens))
+            elif role == "assistant":
+                local_tokens = tokenizer.apply_chat_template([m], return_dict=False)
+                mask.extend([False] * n_mask)
+                mask.extend([True] * (len(local_tokens) - n_mask))
+            else:
+                raise ValueError(f"Unknown role {role}")
+            tokens.extend(local_tokens)
+        return {"tokens": tokens, "mask": mask}
+
+    ds = ds.map(tokenize)
+    local_valid_size = valid_size // size
+    valid_ds = ds.take(local_valid_size)
+    train_ds = ds.skip(local_valid_size)
+    return train_ds, valid_ds
+
+
+def main(config, checkpoint_dir, save_dir):
+    random.seed(config.seed)
+    np.random.seed(config.seed)
+    mx.random.seed(config.seed)
+
+    rank, world_size = utils.init_distributed()
+    batch_size = config.batch_size
+    max_length = config.context_size
+
+    optimizer = utils.load_optimizer(config)
+    tokenizer = utils.load_tokenizer()
+    train_set, valid_set = load_data(tokenizer)
+
+    model = utils.load_model(config.model)
+    model.load_weights(str(Path(checkpoint_dir) / "model.safetensors"))
+
+    dtype = getattr(mx, config.data_type)
+
+    def to_mlx(sample):
+        return {k: mx.array(v) for k, v in sample.items()}
+
+    def loss_fn(params, sample):
+        model.update(tree_map(lambda x: x.astype(dtype), params))
+        inputs = sample["data"][:, :-1]
+        targets = sample["data"][:, 1:]
+
+        logits = model(inputs).astype(mx.float32)
+        losses = nn.losses.cross_entropy(logits, targets, reduction="none")
+        mask = sample["mask"][:, 1:]
+        # Avoid divide by 0 (loss should be 0)
+        ntoks = mx.maximum(mask.sum(), 1)
+        loss = (losses * mask).sum() / ntoks
+        return loss, ntoks
+
+    state = [optimizer.state, mx.random.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(sample, params):
+        (loss, ntoks), grads = mx.value_and_grad(loss_fn)(params, sample)
+        grads = average_gradients(grads, all_reduce_size=4e9)
+        grads, grad_norm = optim.clip_grad_norm(grads, max_norm=config.max_grad_norm)
+        params = optimizer.apply_gradients(grads, params)
+        return loss, ntoks, grad_norm, params
+
+    def eval_fn(params):
+        data_it = iterate_batches(
+            valid_set,
+            max_length=max_length,
+            batch_size=batch_size,
+        )
+        losses = 0
+        num_toks = 0
+        for sample in data_it:
+            loss, ntoks = loss_fn(params, to_mlx(sample))
+            # Weight each batch by its loss-token count so the result is a per-token average
+            losses += loss * ntoks
+            num_toks += ntoks
+            mx.eval(losses, num_toks)
+        return losses / num_toks
+
+    params = model.trainable_parameters()
+    nparams = tree_reduce(lambda acc, p: acc + p.size, params, 0)
+    if rank == 0:
+        print(f"Model has {nparams} parameters.")
+
+    mx.eval(params)
+
+    train_iterator = iterate_batches(
+        train_set,
+        max_length=max_length,
+        batch_size=config.batch_size,
+    )
+
+    metrics = utils.Metrics()
+    tokens = 0
+    tic = time.perf_counter()
+    for it, sample in zip(range(0, config.num_steps), train_iterator):
+        sample = to_mlx(sample)
+        loss, _, grad_norm, params = step(sample, params)
+        loss = mx.distributed.all_sum(loss) / world_size
+        grad_norm = mx.distributed.all_sum(grad_norm) / world_size
+        # Count all tokens processed without padding (not just loss tokens)
+        num_tokens = mx.distributed.all_sum((sample["lengths"] - 1).sum())
+        mx.eval(loss, num_tokens, grad_norm, params, state)
+        metrics.train_loss.append(loss.item())
+        metrics.grad_norm.append(grad_norm.item())
+        tokens += num_tokens.item()
+
+        if (it + 1) % config.steps_per_report == 0:
+            toc = time.perf_counter()
+            metrics.step = it + 1
+            metrics.tokens += tokens
+            metrics.its_per_sec = config.steps_per_report * world_size / (toc - tic)
+            metrics.toks_per_sec = tokens / (toc - tic)
+            tokens = 0
+
+            if (it + 1) % config.steps_per_eval == 0:
+                loss = eval_fn(params)
+                loss = mx.distributed.all_sum(loss) / world_size
+                metrics.valid_loss = loss.item()
+                metrics.valid_ppl = math.exp(metrics.valid_loss)
+
+            if rank == 0:
+                if (it + 1) % config.steps_per_checkpoint == 0:
+                    utils.save_checkpoint(save_dir, it, params, optimizer, config)
+                    utils.save_checkpoint(save_dir, None, params, optimizer, config)
+                utils.log_metrics(metrics)
+            metrics = utils.Metrics(tokens=metrics.tokens)
+            tic = time.perf_counter()
+
+    if rank == 0:
+        utils.save_checkpoint(save_dir, None, params, optimizer, config)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Supervised fine-tuning of an LM.")
+    parser.add_argument(
+        "--checkpoint-dir",
+        default="checkpoints",
+        type=str,
+        help="Location to load the pretrained model checkpoint",
+    )
+    parser.add_argument(
+        "--save-dir",
+        default="sft_checkpoints",
+        type=str,
+        help="Location to save the model and checkpoints",
+    )
+    parser.add_argument(
+        "--wandb-name",
+        default=None,
+        type=str,
+        help="Name of experiment for wandb",
+    )
+    args = parser.parse_args()
+    config = utils.load_config(args.checkpoint_dir)
+    if mx.distributed.init().rank() == 0:
+        wandb_kwargs = dict(
+            project="flash_lm",
+            name=args.wandb_name,
+            tags=["sft", config.model["model_type"]],
+        )
+        if args.wandb_name is None:
+            wandb_kwargs["mode"] = "disabled"
+        run = wandb.init(**wandb_kwargs)
+    main(config, args.checkpoint_dir, args.save_dir)
diff --git a/flash_lm/utils.py b/flash_lm/utils.py
new file mode 100644
index 000000000..f1c23aaa5
--- /dev/null
+++ b/flash_lm/utils.py
@@ -0,0 +1,178 @@
+import dataclasses
+import importlib
+import inspect
+import json
+import os
+import re
+import types
+from pathlib import Path
+
+import mlx.core as mx
+import mlx.optimizers as optim
+import numpy as np
+import wandb
+
+os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
+import transformers
+import yaml
+from mlx.utils import tree_flatten
+
+yaml_loader = yaml.SafeLoader
+yaml_loader.add_implicit_resolver(
+    "tag:yaml.org,2002:float",
+    re.compile(
+        """^(?:
+     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
+    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
+    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
+    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
+    |[-+]?\\.(?:inf|Inf|INF)
+    |\\.(?:nan|NaN|NAN))$""",
+        re.X,
+    ),
+    list("-+0123456789."),
+)
+
+
+def load_config(config_path):
+    config_path = Path(config_path)
+    if config_path.suffix != ".yaml":
+        config_path = config_path / "config.yaml"
+    with open(config_path, "r") as fid:
+        config = yaml.load(fid, yaml_loader)
+    return types.SimpleNamespace(**config)
+
+
+def save_config(dirname, config):
+    with open(Path(dirname) / "config.yaml", "w") as fid:
+        yaml.safe_dump(config.__dict__, fid)
+
+
+def load_tokenizer(path="awni/lmx"):
+    os.environ["TOKENIZERS_PARALLELISM"] = "true"
+    return transformers.AutoTokenizer.from_pretrained(path)
+
+
+def save_checkpoint(save_dir, it, params, optimizer, config):
+    checkpoint_dir = Path(save_dir)
+    if it is not None:
+        checkpoint_dir /= f"{it:012d}"
+    checkpoint_dir.mkdir(exist_ok=True, parents=True)
+    save_config(checkpoint_dir, config)
+
+    mx.save_safetensors(
+        str(checkpoint_dir / "model.safetensors"),
+        dict(tree_flatten(params)),
+    )
+    mx.save_safetensors(
+        str(checkpoint_dir / "opt_state.safetensors"),
+        dict(tree_flatten(optimizer.state)),
+    )
+
+
+@dataclasses.dataclass
+class Metrics:
+    step: int = 0
+    tokens: int = 0
+    train_loss: list = dataclasses.field(default_factory=list)
+    grad_norm: list = dataclasses.field(default_factory=list)
+    its_per_sec: float = 0.0
+    toks_per_sec: float = 0.0
+    valid_loss: float = None
+    valid_ppl: float = None
+    eval_core: float = None
+
+    def to_list(self):
+        metrics = [
+            ("step", self.step),
+            ("train_loss", np.mean(self.train_loss).item()),
+            ("grad_norm", np.mean(self.grad_norm).item()),
+            ("its_per_sec", self.its_per_sec),
+            ("toks_per_sec", self.toks_per_sec),
+            ("tokens", self.tokens),
+        ]
+
+        if self.valid_loss is not None:
+            metrics.append(("valid_loss", self.valid_loss))
+            metrics.append(("valid_ppl", self.valid_ppl))
+
+        if self.eval_core is not None:
+            metrics.append(("eval_core", self.eval_core))
+
+        return metrics
+
+
+def log_metrics(metrics):
+    if isinstance(metrics, list):
+        list_metrics = metrics
+    else:
+        list_metrics = metrics.to_list()
+
+    def to_str(val):
+        if isinstance(val, float):
+            return f"{val:.4f}"
+        else:
+            return repr(val)
+
+    print(", ".join(f"{n}: {to_str(v)}" for n, v in list_metrics))
+    metrics = dict(list_metrics)
+    step = metrics.pop("step")
+    wandb.log(metrics, step=step)
+
+
+def init_distributed():
+
+    rank = int(os.environ.get("MLX_RANK", "0"))
+    world_size = int(os.environ.get("MLX_WORLD_SIZE", "1"))
+
+    if world_size > 1:
+        if rank == 0:
+            print(f"Master host: {os.environ.get('NCCL_HOST_IP')}")
+        print(f"Rank {rank} of {world_size} initialized.")
+        mx.distributed.init(backend="nccl")
+    return rank, world_size
+
+
+def load_optimizer(config):
+    warmup = optim.linear_schedule(
+        0,
+        config.learning_rate,
+        config.warmup_steps,
+    )
+    decay = optim.linear_schedule(
+        config.learning_rate,
+        0,
+        config.decay_steps,
+    )
+    lr_schedule = optim.join_schedules(
+        [warmup, decay], [config.num_steps - config.decay_steps]
+    )
+    if config.optim == "adam":
+        optimizer = optim.Adam(learning_rate=lr_schedule)
+    elif config.optim == "adamw":
+        optimizer = optim.AdamW(
+            learning_rate=lr_schedule,
+            weight_decay=getattr(config, "weight_decay", 0),
+        )
+    elif config.optim == "sgd":
+        optimizer = optim.SGD(learning_rate=lr_schedule)
+    return optimizer
+
+
+def load_model(config):
+    model_type = config["model_type"]
+    arch = importlib.import_module(f"models.{model_type}")
+    model_args = arch.ModelArgs(
+        **{
+            k: v
+            for k, v in
config.items() + if k in inspect.signature(arch.ModelArgs).parameters + } + ) + return arch.Model(model_args) + + +if __name__ == "__main__": + + tokenizer = load_tokenizer() + print(f"Vocab size {len(tokenizer.vocab)}")
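For quick sanity checks without converting to `mlx-lm`, the pieces above are already enough to sample from a saved checkpoint. Below is a minimal greedy-decoding sketch (a hypothetical `sample.py`, not part of this diff), assuming a checkpoint directory written by `pretrain.py` or `sft.py`:

```
# sample.py (hypothetical): greedy decoding with the repo's own SimpleCache.
from pathlib import Path

import mlx.core as mx

import utils
from models.transformer import SimpleCache

checkpoint_dir = Path("checkpoints")
config = utils.load_config(checkpoint_dir)
tokenizer = utils.load_tokenizer()

model = utils.load_model(config.model)
model.load_weights(str(checkpoint_dir / "model.safetensors"))

tokens = tokenizer.encode("Hi")
cache = [SimpleCache() for _ in model.layers]

# Prefill the prompt, then decode one token at a time from the last logits.
y = model(mx.array(tokens)[None], cache)
for _ in range(64):
    next_token = mx.argmax(y[:, -1, :], axis=-1)
    tokens.append(next_token.item())
    if tokens[-1] == tokenizer.eos_token_id:
        break
    y = model(next_token[None], cache)

print(tokenizer.decode(tokens))
```

This mirrors the `__main__` check in `models/transformer.py`, which exercises the same prefill-then-step pattern with `SimpleCache`.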