1 change: 1 addition & 0 deletions flash_lm/.gitignore
@@ -0,0 +1 @@
*checkpoints/
75 changes: 75 additions & 0 deletions flash_lm/README.md
@@ -0,0 +1,75 @@
# flash_lm

Train an on-device language model with MLX.

### Install

Install dependencies:

```
pip install -r requirements.txt
```

Install MLX on macOS:

```
pip install mlx
```

Or for CUDA:

```
pip install "mlx[cuda13]"
```

### Training

For pretraining:

```
python pretrain.py
```

For supervised fine-tuning (SFT):

```
python sft.py
```

### Generation

The model can be easily converted to a format compatible with `mlx_lm` for
generation.

Install `mlx-lm`:

```
pip install mlx-lm
```

Or for CUDA:

```
pip install "mlx-lm[cuda13]"
```

Then convert a given checkpoint:

```
python convert.py --checkpoint-dir path/to/checkpoint --save-dir path/to/mlx_lm_model
```

Then use any `mlx-lm` command or API:

```
mlx_lm.generate --model path/to/mlx_lm_model --prompt "Hi"
```

### Next Steps

To customize the model, change the default config (`configs/tiny.yaml`) or
create a new config and pass it with `--config`:

```
python pretrain.py --config my_custom_config.yaml
```
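
A custom config follows the same layout as the bundled ones. For example (the field names come from the bundled configs; the values below are illustrative, not tuned):

```
model:
  model_type: "transformer"
  hidden_size: 768
  head_dim: 64
  vocab_size: 128256
  intermediate_size: 2048
  num_attention_heads: 12
  num_key_value_heads: 4
  num_hidden_layers: 12

seed: 0
batch_size: 4
context_size: 2048
optim: "adam"
weight_decay: 0.1
learning_rate: 1e-4
num_steps: 100000
warmup_steps: 1000
decay_steps: 1000
max_grad_norm: 5
data_type: "bfloat16"
```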
25 changes: 25 additions & 0 deletions flash_lm/configs/base_600m.yaml
@@ -0,0 +1,25 @@
model:
  model_type: "transformer"
  hidden_size: 1024
  head_dim: 128
  vocab_size: 128256
  intermediate_size: 3072
  num_attention_heads: 16
  num_key_value_heads: 8
  num_hidden_layers: 28

seed: 0
batch_size: 2
context_size: 2048
optim: "adam"
weight_decay: 0.1
learning_rate: 1e-4
num_steps: 1000000
warmup_steps: 1000
decay_steps: 1000
max_grad_norm: 5
data_type: "bfloat16"

steps_per_eval: 100000
steps_per_report: 10
steps_per_checkpoint: 100000
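
As a sanity check on the `600m` in the file name, the parameter count implied by these fields can be estimated in plain Python. This sketch assumes a Llama-style gated MLP and separate input and output embeddings; both are assumptions, since the actual architecture lives in `models/transformer.py`:

```python
# Rough parameter count for the base_600m config (norms and biases ignored).
hidden = 1024
head_dim = 128
heads = 16
kv_heads = 8
layers = 28
inter = 3072
vocab = 128256

attn = hidden * heads * head_dim           # Q projection
attn += 2 * hidden * kv_heads * head_dim   # K and V projections (GQA)
attn += heads * head_dim * hidden          # output projection
mlp = 3 * hidden * inter                   # gate, up, and down projections
embed = vocab * hidden                     # input embedding table

total = layers * (attn + mlp) + 2 * embed  # 2x embed assumes an untied output head
print(f"{total / 1e6:.0f}M parameters")    # → 703M (≈570M if embeddings are tied)
```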
27 changes: 27 additions & 0 deletions flash_lm/configs/fp8_600m.yaml
@@ -0,0 +1,27 @@
model:
  model_type: "transformer"
  hidden_size: 1024
  head_dim: 128
  vocab_size: 128256
  intermediate_size: 3072
  num_attention_heads: 16
  num_key_value_heads: 8
  num_hidden_layers: 28
  quantization:
    mode: "mxfp8"

seed: 0
batch_size: 2
context_size: 2048
optim: "adam"
weight_decay: 0.1
learning_rate: 1e-4
num_steps: 1000000
warmup_steps: 1000
decay_steps: 1000
max_grad_norm: 5
data_type: "bfloat16"

steps_per_eval: 10000
steps_per_report: 10
steps_per_checkpoint: 10000
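
The only difference from `base_600m.yaml` is the `quantization` block: MXFP8 is a microscaling format in which small blocks of values share a power-of-two scale and each element is stored as an 8-bit float. A toy plain-Python illustration of the idea — this is not MLX's implementation, and the e4m3-style rounding below is simplified:

```python
import math

FP8_MAX = 448.0  # largest finite value in the e4m3 format

def round_e4m3(x):
    # Toy e4m3 rounding: keep a 3-bit mantissa, clamp to the format's range.
    if x == 0.0:
        return 0.0
    e = math.floor(math.log2(abs(x)))
    step = 2.0 ** (e - 3)  # spacing between representable values near x
    return max(-FP8_MAX, min(FP8_MAX, round(x / step) * step))

def mx_quantize(block):
    # One microscaling block: a shared power-of-two scale plus FP8 elements.
    amax = max(abs(v) for v in block) or 1.0
    scale = 2.0 ** math.floor(math.log2(FP8_MAX / amax))
    return scale, [round_e4m3(v * scale) for v in block]

def mx_dequantize(scale, elems):
    return [v / scale for v in elems]

scale, q = mx_quantize([0.1, -1.5, 0.003, 2.0])
print(mx_dequantize(scale, q))  # values close to the originals
```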
25 changes: 25 additions & 0 deletions flash_lm/configs/tiny.yaml
@@ -0,0 +1,25 @@
model:
  model_type: "transformer"
  hidden_size: 512
  head_dim: 128
  vocab_size: 128256
  intermediate_size: 512
  num_attention_heads: 4
  num_key_value_heads: 2
  num_hidden_layers: 4

seed: 0
batch_size: 8
context_size: 2048
optim: "adam"
weight_decay: 0.1
learning_rate: 1e-4
num_steps: 1000000
warmup_steps: 1000
decay_steps: 1000
max_grad_norm: 5
data_type: "bfloat16"

steps_per_eval: 100000
steps_per_report: 10
steps_per_checkpoint: 100000
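
The `warmup_steps` and `decay_steps` fields suggest a warmup-then-decay learning-rate schedule. One common choice is linear warmup into cosine decay; this is an assumption, since the actual schedule is defined in the training script:

```python
import math

def lr_at(step, base_lr=1e-4, warmup_steps=1000, decay_steps=1000):
    # Linear warmup from 0 to base_lr over warmup_steps...
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    # ...then cosine decay to 0 over decay_steps, held at 0 afterwards.
    t = min((step - warmup_steps) / decay_steps, 1.0)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * t))

print(lr_at(500))   # halfway through warmup: half the base rate
print(lr_at(1000))  # peak: the full base rate
```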
45 changes: 45 additions & 0 deletions flash_lm/convert.py
@@ -0,0 +1,45 @@
import argparse
import json
import shutil
from pathlib import Path

import utils

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert a model for use with mlx-lm.",
    )
    parser.add_argument(
        "--checkpoint-dir",
        default="checkpoints",
        type=str,
        help="Path to checkpoint",
    )
    parser.add_argument(
        "--save-dir",
        default="mlx_lm_checkpoint",
        type=str,
        help="Location to save the mlx_lm ready model",
    )
    args = parser.parse_args()

    tokenizer = utils.load_tokenizer()
    save_dir = Path(args.save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)

    # Save tokenizer
    tokenizer.save_pretrained(save_dir)

    checkpoint_dir = Path(args.checkpoint_dir)
    config = utils.load_config(checkpoint_dir).model
    config["model_file"] = f"{config['model_type']}.py"
    config = dict(sorted(config.items()))
    with open(save_dir / "config.json", "w") as fid:
        json.dump(config, fid, indent=4)

    for file in [
        "models/transformer.py",
        checkpoint_dir / "model.safetensors",
    ]:
        dst_path = save_dir / Path(file).name
        shutil.copy(file, dst_path)