Add files using upload-large-folder tool
- README.md +512 -0
- torchtitan/components/__pycache__/ft.cpython-311.pyc +0 -0
- torchtitan/components/optimizer.py +303 -0
- torchtitan/distributed/__pycache__/pipeline.cpython-311.pyc +0 -0
- torchtitan/distributed/pipeline.py +201 -0
- torchtitan/experiments/deepseek_v3/checkpoint.py +154 -0
- torchtitan/experiments/deepseek_v3/download.py +70 -0
- torchtitan/experiments/deepseek_v3/generate.py +308 -0
- torchtitan/experiments/deepseek_v3/inference.sh +15 -0
- torchtitan/experiments/deepseek_v3/model.py +1325 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py +159 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py +260 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py +63 -0
- torchtitan/experiments/deepseek_v3/train.py +142 -0
- torchtitan/experiments/flux/README.md +23 -0
- torchtitan/experiments/flux/dataset/flux_dataset.py +267 -0
- torchtitan/experiments/flux/model/autoencoder.py +388 -0
- torchtitan/experiments/flux/model/hf_embedder.py +40 -0
- torchtitan/experiments/flux/model/layers.py +286 -0
- torchtitan/experiments/flux/model/model.py +177 -0
- torchtitan/experiments/flux/parallelize_flux.py +26 -0
- torchtitan/experiments/flux/requirements.txt +2 -0
- torchtitan/experiments/flux/tests/test_flux_dataloader.py +103 -0
- torchtitan/experiments/flux/tests/test_generate_image.py +252 -0
- torchtitan/experiments/flux/train_configs/debug_model.toml +68 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py +630 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/__init__.py +13 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py +240 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py +82 -0
- torchtitan/experiments/llama4/__pycache__/__init__.cpython-311.pyc +0 -0
- torchtitan/experiments/llama4/infra/expert_parallel.py +145 -0
- torchtitan/experiments/llama4/infra/parallelize_llama.py +159 -0
- torchtitan/experiments/llama4/model/__pycache__/model.cpython-311.pyc +0 -0
- torchtitan/experiments/llama4/model/model.py +466 -0
- torchtitan/experiments/llama4/model/moe.py +228 -0
- torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh +25 -0
- torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml +65 -0
- torchtitan/experiments/multimodal/mm_collator.py +227 -0
- torchtitan/experiments/multimodal/mm_dataset.py +268 -0
- torchtitan/experiments/multimodal/tests/test_multimodal_model.py +128 -0
- torchtitan/experiments/multimodal/utils.py +437 -0
- torchtitan/experiments/simple_fsdp/__init__.py +33 -0
- torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-311.pyc +0 -0
- torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-311.pyc +0 -0
- torchtitan/experiments/simple_fsdp/model.py +18 -0
- torchtitan/experiments/simple_fsdp/parallelize_llama.py +98 -0
- torchtitan/models/llama3/train_configs/llama3_8b.toml +63 -0
- torchtitan/models/norms.py +35 -0
- torchtitan/protocols/__pycache__/model_converter.cpython-311.pyc +0 -0
- torchtitan/tools/utils.py +143 -0
README.md
ADDED
<div align="center">

# 🔥 Flame: Flash Linear Attention Made Easy
# This is a fork for the paper:
# Softpick: No Attention Sink, No Massive Activations with Rectified Softmax

</div>

## Instructions for Softpick Attention

This fork only works with an older commit of torchtitan and flame, so the setup looks like this:

```bash
git clone https://github.com/zaydzuhri/flame.git
cd flame
git checkout softpick-attention
git submodule update --init --recursive --remote
cd 3rdparty/torchtitan
git checkout 4f532e0
cd ../../

pip install .
pip install flash-attn --no-build-isolation
```
The flash-linear-attention submodule has been changed to link to our fork: https://github.com/zaydzuhri/flash-linear-attention/tree/softpick-attention
So there is no need to manually clone it.

Then prepare the fineweb-edu 100B sample the same way as described in the flame repo guide below.

These are the training commands used in the paper:
```bash
NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/vanilla.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine --model.config configs/vanilla_transformer_340M.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 3e-4 --lr_scheduler.warmup_steps 1000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 16 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 1 --training.steps 100000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name "zaydzuhri/vanilla-340M-4096-batch16-steps100000" --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/softpick.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine --model.config configs/softpick_transformer_340M.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 3e-4 --lr_scheduler.warmup_steps 1000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 16 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 1 --training.steps 100000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name "zaydzuhri/softpick-340M-4096-batch16-steps100000" --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
```

And the same for the extra experiments in the appendix:
```bash
NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/rectified.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine --model.config configs/rectified_transformer_340M.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 3e-4 --lr_scheduler.warmup_steps 1000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 16 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 1 --training.steps 100000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name "zaydzuhri/rectified-340M-4096-batch16-steps100000" --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/softpick.scaled.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine --model.config configs/softpick_scaled_transformer_340M.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 3e-4 --lr_scheduler.warmup_steps 1000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 16 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 1 --training.steps 100000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name "zaydzuhri/softpick-scaled-340M-4096-batch16-steps100000" --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
```

Feel free to DM @zmkzmkz on X for any questions regarding the paper or this code!

## Flame

Welcome to 🔥 `flame`, a minimal and efficient framework built on `torchtitan` for training Flash Linear Attention (FLA) models (and more broadly, arbitrary autoregressive language models) with blazing efficiency.

**Feature Highlights:**

- 🚀 Minimal, easy-to-use, extensible training framework
- 🤗 Seamless integration with `fla` and `transformers`
- 🔄 Zero-cost data preprocessing: online tokenization, dataset shuffling, and multiple datasets support
- 🔮 4D parallelism (coming soon)

## Setup

To get started, clone the `flame` repository and install the required dependencies:

```bash
git clone https://github.com/fla-org/flame.git
cd flame
pip install .
```

`flame` manages minimal dependencies, only including `fla` and `torchtitan` as submodules.
After installation, initialize and update the submodules:
```sh
git submodule update --init --recursive
```

## Dataset Preparation
To download the dataset to your local disk, create a new Python file with the following content and execute it:

```py
from datasets import load_dataset

# load fineweb-edu with parallel processing
dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", num_proc=64, cache_dir="/your/cache/path")

# or load a subset with roughly 100B tokens, suitable for small- or medium-sized experiments
dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_proc=64, cache_dir="/your/cache/path")
```
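
If you have downloaded the corpus as above, you can point `--training.dataset` at the local copy instead of streaming from the Hub; this is what the Softpick training commands at the top of this README do. The exact path depends on your `cache_dir` (the one below assumes the default cache layout used by those commands):

```sh
--training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT \
--training.dataset_split train \
```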

## Training Recipes

Here's an example of training a 340M FLA Transformer model with a LLaMA-like architecture from scratch on a 100BT subset of the Fineweb-edu corpus in streaming mode.

> [!WARNING]
> If the dataset is not downloaded beforehand, the streaming mode will attempt to fetch it from a remote server and download it on-the-fly, which can be highly unstable during training due to network issues.
> For stable training, ensure the dataset is downloaded locally (see [**Dataset Preparation**](#dataset-preparation)). Otherwise, we assume you are only testing the new corpus.

```sh
bash train.sh \
  --job.config_file flame/models/fla.toml \
  --job.dump_folder exp/transformer-340M-4K-10B/batch1.seqlen65536.context4096.warmup1024.update1.steps20480.lr3e-4.cosine \
  --model.config configs/transformer_340M.json \
  --model.tokenizer_path fla-hub/transformer-1.3B-100B \
  --optimizer.name AdamW \
  --optimizer.eps 1e-15 \
  --optimizer.lr 3e-4 \
  --lr_scheduler.warmup_steps 1024 \
  --lr_scheduler.lr_min 0.1 \
  --lr_scheduler.decay_type cosine \
  --training.batch_size 1 \
  --training.seq_len 65536 \
  --training.context_len 4096 \
  --training.varlen \
  --training.gradient_accumulation_steps 1 \
  --training.steps 20480 \
  --training.max_norm 1.0 \
  --training.skip_nan_inf \
  --training.dataset HuggingFaceFW/fineweb-edu \
  --training.dataset_name sample-100BT \
  --training.dataset_split train \
  --training.streaming \
  --training.num_workers 32 \
  --training.prefetch_factor 2 \
  --training.seed 42 \
  --training.compile \
  --checkpoint.interval 2048 \
  --checkpoint.load_step -1 \
  --checkpoint.keep_latest_k 2 \
  --metrics.log_freq 1
```

You can specify the number of GPUs by setting the environment variable `NGPU`, which defaults to 8.
**For single-GPU debugging, set `NGPU=1`.**

We provide several [config files](https://github.com/fla-org/flame/tree/main/configs) for different models.
By default, the learning rate is set to 3e-4 with a cosine scheduler. Other schedulers, such as WSD (wsd), are also supported.

**Key parameters:**
- `--lr_scheduler.decay_ratio`: The proportion of the steps allocated to the decay phase. The learning rate will remain stable after the warmup period and only start decaying during the last `decay_ratio` portion of the total training steps, which is known as the Warmup-Stable-Decay (WSD) schedule.
- `--lr_scheduler.warmup_steps`: The number of steps for the learning rate warmup phase.
- `--training.steps`: Total number of training steps.
- `--training.batch_size`: Batch size per device, must be 1 if `--training.varlen` is set.
- `--training.seq_len`: The length of each sequence in the batch, which is concatenated from multiple samples.
- `--training.context_len`: The max allowed length of a sample. For non-varlen mode, this is equivalent to `seq_len`.
- `--training.varlen`: Whether to conduct variable-length sequence training.
- `--training.gradient_accumulation_steps`: Number of gradient accumulation steps.

> [!WARNING]
> The total number of tokens processed per batch, referred to as `global_batch_size`, is calculated as `batch_size × gradient_accumulation_steps × num_gpus`.
> Each step processes `global_batch_size * seq_len` tokens.
> Monitor the value of `global_batch_size`, `warmup_steps`, and `steps` carefully when modifying any of the hyperparameters!
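
For a quick sanity check, here is a small illustrative snippet (not part of `flame`) that applies the token-accounting formula above to the example recipe:

```py
# Illustrative only: token accounting for the example recipe
# (NGPU=8, batch_size=1, gradient_accumulation_steps=1, seq_len=65536, steps=20480).
batch_size = 1
gradient_accumulation_steps = 1
num_gpus = 8
seq_len = 65536
steps = 20480

global_batch_size = batch_size * gradient_accumulation_steps * num_gpus  # 8
tokens_per_step = global_batch_size * seq_len                            # 524,288
total_tokens = tokens_per_step * steps                                   # ~10.7B tokens
print(f"{global_batch_size=}, {tokens_per_step=:,}, {total_tokens=:,}")
```
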
For a detailed explanation of all parameters, run:

```sh
bash train.sh -h
```

<details>
<summary>Usage</summary>

```py
options:
  -h, --help
        show this help message and exit
  --job.config_file JOB.CONFIG_FILE
        Job config file
  --job.dump_folder JOB.DUMP_FOLDER
        Folder to dump job outputs
  --job.description JOB.DESCRIPTION
        Description of the job
  --job.use_for_integration_test
        Add this config to the integration test suite
  --job.print_args
        Print the args to terminal
  --model.config MODEL.CONFIG
        Path to the model config
  --model.norm_type MODEL.NORM_TYPE
        Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]
  --model.tokenizer_path MODEL.TOKENIZER_PATH
        Tokenizer path
  --profiling.enable_profiling
        Whether to enable pytorch profiler
  --profiling.save_traces_folder PROFILING.SAVE_TRACES_FOLDER
        Trace files location
  --profiling.profile_freq PROFILING.PROFILE_FREQ
        How often to collect profiler traces, in iterations
  --profiling.enable_memory_snapshot
        Whether to dump memory snapshot
  --profiling.save_memory_snapshot_folder PROFILING.SAVE_MEMORY_SNAPSHOT_FOLDER
        Memory snapshot files location
  --optimizer.name OPTIMIZER.NAME
        Optimizer to use
  --optimizer.eps OPTIMIZER.EPS
        Epsilon value for the optimizer.
  --optimizer.fused
        Whether the fused implementation (CUDA only) is used.
  --optimizer.scheduler {wsd,cosine,linear}
        Scheduler to use. Currently supported: wsd, cosine, and linear.
  --optimizer.lr OPTIMIZER.LR
        Learning rate to use
  --optimizer.min_lr_ratio OPTIMIZER.MIN_LR_RATIO
        Min lr ratio for lr scheduler
  --optimizer.early_step_in_backward
        Whether to apply optimizer in the backward. Caution, optimizer_in_backward
        is not compatible with gradients clipping; users should not call
        register_post_accumulate_grad_hook after the optimizer is built.
  --training.batch_size TRAINING.BATCH_SIZE
        Batch size
  --training.seq_len TRAINING.SEQ_LEN
        Sequence length
  --training.context_len TRAINING.CONTEXT_LEN
        Max length allowed for each sequence
  --training.varlen
        Whether to take sequences of variable length as input
  --training.warmup_steps TRAINING.WARMUP_STEPS
        Steps for lr scheduler warmup, normally 1/5 of --training.steps
  --training.gradient_accumulation_steps TRAINING.GRADIENT_ACCUMULATION_STEPS
        Number of steps to accumulate gradients before updating parameters
  --training.steps TRAINING.STEPS
        How many train steps to run
  --training.max_norm TRAINING.MAX_NORM
        Max norm for gradient clipping
  --training.skip_nan_inf
        Skip batch updates when NaN or INF gradients are encountered during training
  --training.dataset TRAINING.DATASET
        Dataset to use, with comma separated values
  --training.dataset_name TRAINING.DATASET_NAME
        The name of the dataset config, with comma separated values if provided
  --training.dataset_split TRAINING.DATASET_SPLIT
        Dataset split to use, with comma separated values if provided
  --training.data_dir TRAINING.DATA_DIR
        Data dirs to use, with comma separated values if provided
  --training.data_files TRAINING.DATA_FILES
        Data files to use, with comma separated values if provided
  --training.data_probs TRAINING.DATA_PROBS
        Data sampling probabilities, with comma separated values if provided
  --training.streaming
        Whether to load dataset in streaming mode, used for huge dataset
  --training.num_workers TRAINING.NUM_WORKERS
        Number of subprocesses to use for data loading. 0 means that the data
        will be loaded in the main process.
  --training.prefetch_factor TRAINING.PREFETCH_FACTOR
        Number of batches loaded in advance by each worker. 2 means there will
        be a total of 2 * num_workers batches prefetched across all workers.
  --training.data_parallel_replicate_degree TRAINING.DATA_PARALLEL_REPLICATE_DEGREE
        The `data_parallel_replicate_degree` argument specifies the degree of
        data parallelism for weight replication. When this value is greater
        than 1, weights will be replicated across `data_parallel_replicate_degree`
        ranks. If `data_parallel_shard_degree` is also greater than 1, the
        parallelism method used is HSDP (Hybrid Sharded Data Parallelism).
        Otherwise, the parallelism method used is DDP (Distributed Data
        Parallelism). 1 means disabled.
  --training.data_parallel_shard_degree TRAINING.DATA_PARALLEL_SHARD_DEGREE
        The `data_parallel_shard_degree` argument specifies the degree of data
        parallelism for weight sharding. When this value is greater than 1,
        weights will be sharded across `data_parallel_shard_degree` ranks. If
        `data_parallel_replicate_degree` is also greater than 1, the parallelism
        method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
        parallelism method used is FSDP (Fully Sharded Data Parallelism).
        -1 means leftover ranks will be used (after DP_REPLICATE/SP/PP). Note
        that only `data_parallel_shard_degree` can be negative. 1 means disabled.
  --training.enable_cpu_offload
        Whether to apply CPU offloading of parameters, gradients, and optimizer
        states in FSDP
  --training.tensor_parallel_degree TRAINING.TENSOR_PARALLEL_DEGREE
        Tensor Parallelism degree. 1 means disabled.
  --training.disable_loss_parallel
        Whether to apply loss parallel when sequence parallel is enabled
  --training.mixed_precision_param {bfloat16,float32}
        torch dtype to use for parameters when applying mixed precision via
        FSDP. This feature only takes effect when data_parallel_shard_degree > 1
  --training.mixed_precision_reduce {float32}
        torch dtype to use for reductions when applying mixed precision via
        FSDP. This feature only takes effect when data_parallel_shard_degree > 1
  --training.compile
        Whether to compile the model
  --training.gc_freq TRAINING.GC_FREQ
        Python garbage control scheduling interval, in steps
  --training.seed TRAINING.SEED
        Choose the base RNG seed used for training
  --training.deterministic
        Use deterministic algorithms wherever possible, may be slower
  --metrics.log_freq METRICS.LOG_FREQ
        How often to log metrics to TensorBoard, in iterations
  --metrics.enable_tensorboard
        Whether to log metrics to TensorBoard
  --metrics.disable_color_printing
        Whether to disable color printing in logs
  --metrics.save_tb_folder METRICS.SAVE_TB_FOLDER
        Folder to dump TensorBoard states
  --metrics.rank_0_only
        Whether to save TensorBoard metrics only for rank 0 or for all ranks.
        When pipeline_parallel_degree is > 1, this option uses the 0th rank of
        the last stage pipeline group, which is the only stage that computes
        loss metrics.
  --metrics.enable_wandb
        Whether to log metrics to Weights & Biases
  --experimental.enable_async_tensor_parallel
        Whether to apply async tensor parallel (currently only effective when
        compile is enabled)
  --experimental.pipeline_parallel_degree EXPERIMENTAL.PIPELINE_PARALLEL_DEGREE
        Pipeline Parallelism degree, or number of ranks. 1 means disabled. If
        using looped schedules, this still specifies the number of physical
        ranks, not the number of stages. Stages per rank are inferred from
        split points, degree, and schedule.
  --experimental.pipeline_parallel_split_points EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS [EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS ...]
        Specify comma-separated names of modules to use as the beginning of a
        split point. e.g. "layers.0,layers.2" will cause the model to be split
        into 3 stages, the first containing all the layers up to layers.0, the
        second containing layers.0 and up to layers.2, the third containing
        layers.2 and all the remaining layers. Note: fully-automated splitting
        may be enabled in the future, but currently the split points must be
        specified manually.
  --experimental.pipeline_parallel_schedule EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE
        Specify the Pipeline Parallel schedule to use. The supported schedules
        are: https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/distributed/pipelining/schedules.py#L2161.
        The schedule must be compatible with the split points and
        stages_per_rank. Looped schedules (e.g. Interleaved1F1B) require
        specifying pipeline_parallel_degree = number of ranks, and
        split_points = number of stages - 1
  --experimental.pipeline_parallel_schedule_csv EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE_CSV
        Specify the path to the pipeline parallel schedule csv file to use. The
        pipeline_parallel_schedule argument must be either PipelineScheduleSingle,
        PipelineScheduleMulti, or _PipelineScheduleRuntime.
  --experimental.pipeline_parallel_microbatches EXPERIMENTAL.PIPELINE_PARALLEL_MICROBATCHES
        How many microbatches to split the global training batch into when
        using pipeline parallelism. The global training batch size must be
        evenly divisible by the number of microbatches. The default value will
        be the number of pipeline stages, if unspecified.
  --experimental.enable_compiled_autograd
        Enable CompiledAutograd to compile the backward.
  --experimental.context_parallel_degree EXPERIMENTAL.CONTEXT_PARALLEL_DEGREE
        Context parallelism degree. 1 means disabled.
  --experimental.context_parallel_rotate_method EXPERIMENTAL.CONTEXT_PARALLEL_ROTATE_METHOD
        The collective to use in context parallel SDPA for kv shards exchange.
        'allgather' means to all-gather all kv shards on ranks after the first
        sub-SDPA computation, 'alltoall' means to all-to-all shuffle the kv
        shards. The default value is 'allgather'.
  --checkpoint.enable_checkpoint
        Whether to enable checkpoint
  --checkpoint.folder CHECKPOINT.FOLDER
        The folder to store the checkpoints. When enable_checkpoint is set to
        true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
  --checkpoint.interval_type CHECKPOINT.INTERVAL_TYPE
        Checkpointing interval unit of measurement ['step', 'seconds']
  --checkpoint.interval CHECKPOINT.INTERVAL
        Checkpointing interval, in steps or seconds depending on
        --checkpoint.interval_type
  --checkpoint.model_weights_only
        When model_weights_only=True, only model weights will be saved at the
        end of training. With this, checkpoints can be loaded using
        `torch.load(..., weights_only=True)` after conversion. When
        model_weights_only=False, the full checkpoint will be saved. A full
        checkpoint includes model, optimizer and train_state, which can be used
        to resume training. The default value is false.
  --checkpoint.export_dtype {float16,bfloat16,float32}
        Converts to the specified precision when training completes and
        model_weights_only=true. Currently supports float32, float16, and
        bfloat16. The default value is float32.
  --checkpoint.create_seed_checkpoint
        Initializes the full model without applying parallelisms, and then
        saves it as a seed checkpoint. Note: requires user to call train.py
        without specifying any parallelisms, e.g. NGPU=1. Could be implemented
        as a separate script, but this way shares more code.
  --checkpoint.async_mode CHECKPOINT.ASYNC_MODE
        Which async checkpoint mode to use. Currently there are 3 different
        modes. 1. "disabled": synchronized checkpointing will be used.
        2. "async": torch.distributed.checkpoint.async_save will be used.
        3. "async_with_pinned_mem": this option utilizes a dedicated pinned
        memory space and creates a separate process for faster GPU->CPU
        transfer performance and eliminating GIL contention. The cost is
        increased CPU memory usage. If insufficient CPU memory is available,
        performance may degrade due to memory paging. For most users, "async"
        should suffice as the performance overhead is typically small (on the
        order of tens of seconds) compared to checkpointing frequency. This
        mode can be employed to pursue near-zero checkpointing times
        (e.g., < 1 second) given appropriate hardware support such as ample CPU
        memory and fast PCIe. "disabled" is the default mode.
  --checkpoint.keep_latest_k CHECKPOINT.KEEP_LATEST_K
        Keeps only the latest k checkpoints, purging older ones. If 0, keep all
        checkpoints. 0 is the default value.
  --checkpoint.load_step CHECKPOINT.LOAD_STEP
        Load the checkpoint at the specified step. If -1, load the latest checkpoint.
  --float8.enable_float8_linear
        If true, swaps `torch.nn.Linear` with `Float8Linear`. This feature
        requires you to install 'torchao' which can be found here:
        https://github.com/pytorch/ao
  --float8.enable_fsdp_float8_all_gather
        Whether to enable float8 all-gather in FSDP
  --float8.precompute_float8_dynamic_scale_for_fsdp
        Whether to precompute float8 scales dynamically for FSDP
  --float8.scaling_type_input {dynamic,delayed}
        float8 scaling for input, dynamic (default) or delayed
  --float8.scaling_type_weight FLOAT8.SCALING_TYPE_WEIGHT
        float8 scaling for weight, dynamic (default) or delayed
  --float8.scaling_type_grad_output FLOAT8.SCALING_TYPE_GRAD_OUTPUT
        float8 scaling for grad_output, dynamic (default) or delayed
  --comm.init_timeout_seconds COMM.INIT_TIMEOUT_SECONDS
        Timeout for communication operations, during initialization and first
        train step.
  --comm.train_timeout_seconds COMM.TRAIN_TIMEOUT_SECONDS
        Timeout for communication operations after the first train step --
        usually a tighter bound than during initialization.
  --comm.trace_buf_size COMM.TRACE_BUF_SIZE
        Flight recorder ring buffer size, >0 means recording by default,
        0 means disabled
  --memory_estimation.enabled
        Whether to estimate memory usage for FSDP
  --memory_estimation.disable_fake_mode
        Whether to estimate memory under FakeTensorMode
```
</details>

### Training with `torch.compile`

Starting from `torch 2.0`, `torch.compile` has been introduced as a new feature to seamlessly accelerate training processes.
In `flame`, one can simply enable `torch.compile` by adding the `--training.compile` flag to your training script.

However, `fla` has integrated numerous fused kernels for acceleration, which may potentially conflict with `torch.compile`.
We are actively working on resolving these issues to make compilation transparent to users.
In the meantime, please ensure you are using the latest dependencies.

Specifically, **we recommend using `torch>=2.6` and `triton>=3.0`**.

### Training with multiple datasets

If you wish to train a model with all-round capabilities (e.g., code, math, and multilingual ability), it's necessary to train on multiple datasets.
`flame` allows training with multiple datasets easily.
For example, you can specify the following arguments to train on 6 datasets with different proportions:

```sh
--training.dataset HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus \
--training.data_probs 0.6,0.15,0.15,0.014,0.058,0.028 \
```

### ~Finalizing training~

> [!NOTE]
> We have done this conversion automatically in the training script since our latest updates.

Once training is complete, you may want to convert the distributed checkpoints (DCPs) into the 🤗 format for broader use.
To facilitate this, we provide a straightforward conversion script:

```sh
python -m flame.utils.convert_dcp_to_hf --path <path_to_model> --step <step> --config <path_to_config> --tokenizer <path_to_tokenizer>
```
After this, your model will be in the 🤗 format, ready to be shared or deployed.
You can then easily publish your model using the `huggingface_hub` for wider accessibility.
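
As a quick sanity check after conversion, the exported folder can be loaded back with the standard `transformers` API. A minimal sketch (the path is a placeholder, and it assumes the `fla` package is installed so that its custom architectures are registered with `transformers`):

```py
import fla  # noqa: F401  (assumption: importing fla registers its model classes with transformers)
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "<path_to_converted_hf_model>"  # placeholder: output folder of convert_dcp_to_hf above
model = AutoModelForCausalLM.from_pretrained(path)
tokenizer = AutoTokenizer.from_pretrained(path)
print(model.config)
```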

### Continual training

If you wish to build upon a strong pre-trained model (in 🤗 format) and continue training, we also offer a script to convert the 🤗 format model back into DCP format.
This allows you to seamlessly resume training with `flame`.
```sh
python -m flame.utils.convert_hf_to_dcp --model <path_to_hf> --checkpoint <path_to_dcp/checkpoint/step-0>
```
Here, `<path_to_dcp>` is the directory where your distributed checkpoints will be stored.
The checkpoint is intentionally saved at `<step-0>` within the checkpoint folder to ensure it is loadable by `flame` during the initial training step, similar to how a seed checkpoint is handled.

Once the conversion is complete, you can proceed with training using `flame` as usual, continuing from where the pretrained model left off.

## Multi-node training

If you have access to multi-node GPUs, consider leveraging them for optimal performance.
This process is straightforward and well-documented in the PyTorch [docs](https://pytorch.org/docs/stable/elastic/run.html).

To set up multi-node training:
* Set the environment variables `MASTER_ADDR=<ip>` and `MASTER_PORT=<port>` before running the training script across all nodes (see the sketch below).
* If you're using a job scheduler like Slurm, it will handle these variables for you.

`torchtitan` provides a [Slurm script](https://github.com/pytorch/torchtitan/blob/main/multinode_trainer.slurm) for multi-node training, which you can use as a reference or starting point.
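
For a manual (non-Slurm) launch, the sketch below shows one way to set these variables when invoking the training script on every node. The address, port, and GPU count are placeholders, and how the per-node rank is picked up depends on how `train.sh` wraps `torchrun` (an assumption here):

```sh
# Run on every node; values are placeholders.
export MASTER_ADDR=<ip-of-the-rank-0-node>
export MASTER_PORT=<a-free-port>

NGPU=8 bash train.sh --job.config_file flame/models/fla.toml  # ...same remaining arguments as the single-node recipe above
```
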
torchtitan/components/__pycache__/ft.cpython-311.pyc
ADDED
Binary file (7.05 kB).
torchtitan/components/optimizer.py
ADDED
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
from typing import Any, Generic, Iterator, TypeVar

import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    get_optimizer_state_dict,
    set_optimizer_state_dict,
    StateDictOptions,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.optim import Optimizer

from torchtitan.components.ft import FTManager, has_torchft
from torchtitan.config_manager import JobConfig

__all__ = [
    "OptimizersContainer",
    "build_optimizers",
]


if has_torchft:
    import torchft as ft


T = TypeVar("T", bound=Optimizer)


class OptimizersContainer(Optimizer, Stateful, Generic[T]):
    """A container for multiple optimizers.

    This class is used to wrap multiple optimizers into a single object that can be
    used to reduce the complexity of the training loop. This mimics the behavior of
    ``torch.optim.Optimizer``. This class currently only supports ``Adam`` and ``AdamW``.

    **Note**
    Users who want to customize the optimizer behavior can inherit from this class and
    extend the functionality as needed. The following methods must follow the same signature
    as ``torch.optim.Optimizer`` class: ``step()``, ``zero_grad()``, ``state_dict()``,
    ``load_state_dict()``.

    **Limitations**
    This class assumes that all the optimizers are the same type and have the same
    configurations. With this assumption, TorchTitan can support lr scheduler resharding
    (e.g., loading a checkpoint with a different number of GPUs and/or different
    parallelization strategy). Note that ``get_optimizer_state_dict`` already enables the
    resharding for the optimizer state but not for the lr scheduler state, hence the limitation.

    Args:
        model_parts (List[nn.Module]): List of model parts to be optimized.
        optimizer_kwargs (Dict[str, Any]): Keyword arguments for the optimizers.
        name (str): Name of the optimizers.
    """

    optimizers: list[T]
    model_parts: list[nn.Module]

    def __init__(
        self,
        model_parts: list[nn.Module],
        optimizer_cls: type[T],
        optimizer_kwargs: dict[str, Any],
    ) -> None:
        all_params = []
        self.optimizers = []
        self.model_parts = model_parts
        for model in self.model_parts:
            params = [p for p in model.parameters() if p.requires_grad]
            self.optimizers.append(optimizer_cls(params, **optimizer_kwargs))
            all_params.extend(params)
        self._validate_length(len(self.model_parts))
        self._post_init(all_params, optimizer_kwargs)

    def __iter__(self) -> Iterator[T]:
        return iter(self.optimizers)

    def __len__(self) -> int:
        return len(self.optimizers)

    def step(self, *args, **kwargs) -> None:
        for optimizer in self.optimizers:
            optimizer.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs) -> None:
        for optimizer in self.optimizers:
            optimizer.zero_grad(*args, **kwargs)

    def state_dict(self) -> dict[str, Any]:
        func = functools.partial(
            get_optimizer_state_dict,
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )
        return {
            k: v
            for sd in map(func, self.model_parts, self.optimizers)
            for k, v in sd.items()
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        func = functools.partial(
            set_optimizer_state_dict,
            optim_state_dict=state_dict,
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )
        list(map(func, self.model_parts, self.optimizers))

    def _validate_length(self, expected_length: int) -> None:
        assert expected_length == len(self.optimizers), (
            "Must pass one optimizer per model part or per param if "
            "using OptimizersInBackwardContainer."
        )

    def _post_init(
        self, all_params: list[nn.Parameter], optimizer_kwargs: dict[str, Any]
    ) -> None:
        # We need to call Optimizer.__init__() to initialize some necessary optimizer
        # functionality such as hooks.
        Optimizer.__init__(self, all_params, optimizer_kwargs)


class OptimizersInBackwardContainer(OptimizersContainer):
    """OptimizersContainer for executing ``optim.step()`` in backward pass.

    This class extends ``OptimizersContainer`` to support optimizer step in
    backward pass. ``step()`` and ``zero_grad()`` are no-op in this class.
    Instead, ``register_post_accumulate_grad_hook`` is used to register a hook to
    execute these methods when the gradient is accumulated.
    """

    def __init__(
        self,
        model_parts: list[nn.Module],
        optimizer_cls: type[T],
        optimizer_kwargs: dict[str, Any],
    ) -> None:
        all_params = []
        self.model_parts = model_parts

        optim_dict = {}
        for model in self.model_parts:
            for p in model.parameters():
                if p.requires_grad:
                    optim_dict[p] = optimizer_cls([p], **optimizer_kwargs)
                    all_params.append(p)

        def optim_hook(param) -> None:
            optim_dict[param].step()
            optim_dict[param].zero_grad()

        for model in self.model_parts:
            for param in model.parameters():
                if param.requires_grad:
                    param.register_post_accumulate_grad_hook(optim_hook)

        self.optimizers = list(optim_dict.values())

        self._validate_length(
            sum(len(list(model.parameters())) for model in self.model_parts)
        )
        self._post_init(all_params, optimizer_kwargs)

    def step(self) -> None:
        pass

    def zero_grad(self) -> None:
        pass


class FTOptimizersContainer(OptimizersContainer):
    def __init__(
        self,
        model_parts: list[nn.Module],
        optimizer_cls: type[T],
        optimizer_kwargs: dict[str, Any],
        ft_manager: "ft.Manager",
    ) -> None:
        super().__init__(model_parts, optimizer_cls, optimizer_kwargs)

        # Force to initialize the optimizer state so that `optim.step()`
        # won't be called by state_dict() and load_state_dict().
        _ = {
            k: v
            for sd in map(get_optimizer_state_dict, model_parts, self.optimizers)
            for k, v in sd.items()
        }
        self.cache_state_dict: dict[str, Any] = {}
        self._ft_optimizer = ft.Optimizer(ft_manager, self)
        self._call_from_ft: bool = False

    def init_cache_state_dict(self) -> None:
        self.cache_state_dict = super().state_dict()

    def state_dict(self) -> dict[str, Any]:
        return self.cache_state_dict

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # We have to invalidate the `cache_state_dict` because optimizer uses
        # assign instead of copy when doing `load_state_dict()`. Without
        # invalidating the `cache_state_dict`, there will be memory leakage.
        self.cache_state_dict = {}
        super().load_state_dict(state_dict)
        self.init_cache_state_dict()

    def step(self, *args, **kwargs) -> None:
        """Calling the correct step() depending on the caller.

        TorchFT's OptimizerWrapper.step() is designed to be called only once
        per train step per ft.Manager regardless of how many optimizers are used.
        Hence we will need to appropriately dispatch the call.
        """
        if self._call_from_ft:
            super().step(*args, **kwargs)
        else:
            self._call_from_ft = True
            self._ft_optimizer.step(*args, **kwargs)
            self._call_from_ft = False

    def zero_grad(self, *args, **kwargs) -> None:
        """Calling the correct zero_grad() depending on the caller.

        Check the comment in ``step()``.
        """
        if self._call_from_ft:
            super().zero_grad(*args, **kwargs)
        else:
            self._call_from_ft = True
            self._ft_optimizer.zero_grad(*args, **kwargs)
            self._call_from_ft = False


def build_optimizers(
    model_parts: list[nn.Module],
    job_config: JobConfig,
    ft_manager: FTManager,
) -> OptimizersContainer:
    """Create an OptimizersContainer for the given model parts and job config.

    This function creates an ``OptimizersContainer`` for the given model parts.
    ``job_config`` should define the correct optimizer name and parameters.
    This function currently supports creating ``OptimizersContainer`` and
    ``OptimizersInBackwardContainer``.

    **Note**
    Users who want to customize the optimizer behavior can create their own
    ``OptimizersContainer`` subclass and ``build_optimizers``. Passing the
    customized ``build_optimizers`` to ``TrainSpec`` will create the customized
    ``OptimizersContainer``.

    Args:
        model_parts (List[nn.Module]): List of model parts to be optimized.
        job_config (JobConfig): Job config containing the optimizer name and parameters.
    """
    optim_in_bwd = job_config.optimizer.early_step_in_backward
    if optim_in_bwd and job_config.parallelism.pipeline_parallel_degree > 1:
        raise NotImplementedError(
            "Optimizers in backward is not supported with pipeline parallelism."
        )
    name = job_config.optimizer.name
    lr = job_config.optimizer.lr
    eps = job_config.optimizer.eps

    optim_implementation = job_config.optimizer.implementation
    assert optim_implementation in ["fused", "foreach", "for-loop"]

    fused = optim_implementation == "fused"
    foreach = optim_implementation == "foreach"

    optimizer_kwargs = {
        "lr": lr,
        "eps": eps,
        "betas": (0.9, 0.95),
        "weight_decay": 0.1,
        "fused": fused,
        "foreach": foreach,
    }

    optimizer_classes = {
        "Adam": torch.optim.Adam,
        "AdamW": torch.optim.AdamW,
    }
    if name not in optimizer_classes:
        raise NotImplementedError(f"Optimizer {name} not added.")
    optimizer_cls = optimizer_classes[name]

    if optim_in_bwd and ft_manager.enabled:
        raise ValueError("TorchFT is not supported with optimizers in backward.")
    elif optim_in_bwd:
        return OptimizersInBackwardContainer(
            model_parts, optimizer_cls, optimizer_kwargs
        )
    elif ft_manager.enabled:
        return FTOptimizersContainer(
            model_parts, optimizer_cls, optimizer_kwargs, ft_manager.manager
        )
    else:
        return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
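
The docstrings above describe customizing the container and the builder. A minimal sketch of that pattern, runnable on its own inside a torchtitan/flame environment (the subclass name and toy model are hypothetical; the `TrainSpec` registration mentioned in the docstring is not shown):

```py
import torch
import torch.nn as nn
from torch.optim import AdamW

from torchtitan.components.optimizer import OptimizersContainer


class MyOptimizersContainer(OptimizersContainer):
    """Hypothetical subclass: logs how many optimizers are stepped."""

    def step(self, *args, **kwargs) -> None:
        print(f"stepping {len(self)} optimizer(s)")
        super().step(*args, **kwargs)


# Stand-alone usage with a single toy model part.
model_parts = [nn.Linear(16, 16)]
optimizers = MyOptimizersContainer(
    model_parts,
    AdamW,
    {"lr": 3e-4, "eps": 1e-15, "betas": (0.9, 0.95), "weight_decay": 0.1},
)

loss = model_parts[0](torch.randn(4, 16)).sum()
loss.backward()
optimizers.step()
optimizers.zero_grad()
```
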
torchtitan/distributed/__pycache__/pipeline.cpython-311.pyc
ADDED
Binary file (8.48 kB).
torchtitan/distributed/pipeline.py
ADDED
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
import os
from typing import Callable, Optional

from torch.distributed.pipelining.schedules import (
    _PipelineSchedule,
    _PipelineScheduleRuntime,
    get_schedule_class,
    PipelineScheduleMulti,
    PipelineScheduleSingle,
)
from torch.distributed.pipelining.stage import PipelineStage

from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger


__all__ = ["build_pipeline_schedule", "generate_split_points", "stage_ids_this_rank"]


# TODO: It's unclear if this API is general enough to be used by other models.
# If not, we should move it to a Transformer-specific directory.
def generate_split_points(
    schedule_str: str,
    layers_per_stage: Optional[int],
    pp_dim: int,
    num_layers: int,
    input_weight: int = 1,
    output_weight: int = 1,
) -> list[str]:
    """
    Generate a list of split points based on the number of layers and
    pipeline parallel dimension, ensuring the first and last stages have the least layers.

    Args:
        schedule_str (str): The string of the schedule name.
        layers_per_stage (int): The number of layers per stage.
        pp_dim (int): The pipeline parallel dimension.
        num_layers (int): The number of layers in the model.
        input_weight (int): The number of layers the input module counts as in the layer calculation.
        output_weight (int): The number of layers the output module counts as in the layer calculation.

    Returns:
        list[str]: A list of split point FQNs.
    """

    schedule_class = get_schedule_class(schedule_str)
    is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)
    num_stages_per_rank = 1 if is_single_stage_schedule else 2

    if layers_per_stage is not None:
        total_stages = math.ceil(num_layers / layers_per_stage)
        if total_stages % pp_dim != 0:
            raise ValueError(
                f"Number of stages ({total_stages}) must be divisible by the pipeline parallel dimension ({pp_dim}). "
                f"Each rank should have the same number of stages. "
            )
        num_stages_per_rank = total_stages // pp_dim

        if is_single_stage_schedule and num_stages_per_rank != 1:
            raise ValueError(
                f"Number of stages per rank ({num_stages_per_rank}) must be 1 for single stage schedules."
            )
        elif not is_single_stage_schedule and num_stages_per_rank < 2:
            raise ValueError(
                f"Number of stages per rank ({num_stages_per_rank}) must be >= 2 for multi stage schedules."
            )
    else:
        total_stages = pp_dim * num_stages_per_rank
        if total_stages > num_layers:
            raise ValueError("Total stages cannot be greater than the number of layers")

    # Calculate effective number of layers including input and output weights
    effective_num_layers = num_layers + input_weight + output_weight
    base_layers_per_stage = effective_num_layers // total_stages

    splits = [""] * (total_stages - 1)
    current_layer_index = 0

    # First stage
    layers_on_first_stage = max(0, base_layers_per_stage - input_weight)
    current_layer_index += layers_on_first_stage
    splits[0] = "layers." + str(current_layer_index)

    # Last stage
    layers_on_last_stage = max(0, base_layers_per_stage - output_weight)
    splits[-1] = "layers." + str(num_layers - layers_on_last_stage)

    # Middle stages
    remaining_layers = num_layers - layers_on_first_stage - layers_on_last_stage - 1
    middle_stages = len(splits) - 2
    layers_per_middle_stage = remaining_layers // middle_stages
    # split remainder evenly across middle stages
    remainder = remaining_layers % middle_stages

    for i in range(1, middle_stages + 1):
        current_layer_index += layers_per_middle_stage
        if remainder > 0:
            current_layer_index += 1
            remainder -= 1
        splits[i] = "layers." + str(current_layer_index)

    logger.info(
        f"No 'pipeline_parallel_split_points' provided so the generated splits are: {splits} "
        "This may be sub-optimal as the number of layers per stage may be unbalanced."
    )
    return splits


def build_pipeline_schedule(
    job_config: JobConfig, stages: list[PipelineStage], loss_fn: Callable
) -> _PipelineSchedule:
    """Builds a pipeline schedule for the given job configuration and stages.

    Args:
        job_config (JobConfig): The job configuration.
        stages (list[PipelineStage]): The stages to be scheduled.
        loss_fn (Callable): The loss function.

    Returns:
        _PipelineSchedule: The pipeline schedule for the given stages.
    """
    pp_schedule_csv = job_config.parallelism.pipeline_parallel_schedule_csv
|
| 128 |
+
|
| 129 |
+
# Validate that pp_schedule_csv is a valid path
|
| 130 |
+
if pp_schedule_csv:
|
| 131 |
+
if not os.path.isfile(pp_schedule_csv):
|
| 132 |
+
raise FileNotFoundError(
|
| 133 |
+
f"The specified path {pp_schedule_csv} does not exist or is not a file."
|
| 134 |
+
)
|
| 135 |
+
schedule_class = _PipelineScheduleRuntime
|
| 136 |
+
else:
|
| 137 |
+
schedule_class = get_schedule_class(
|
| 138 |
+
job_config.parallelism.pipeline_parallel_schedule
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
looped_schedule = issubclass(schedule_class, PipelineScheduleMulti)
|
| 142 |
+
microbatch_size = job_config.parallelism.pipeline_parallel_microbatch_size
|
| 143 |
+
batch_size = job_config.training.batch_size
|
| 144 |
+
# validate that the batch size is divisible by the microbatch_size otherwise we'll hang or error during training
|
| 145 |
+
if batch_size % microbatch_size != 0:
|
| 146 |
+
raise ValueError(
|
| 147 |
+
f"Batch size {job_config.training.batch_size} must be divisible by number of microbatches {n_microbatches}. "
|
| 148 |
+
"Update the config arguments for either batch_size or pipeline_parallel_microbatch_size."
|
| 149 |
+
)
|
| 150 |
+
n_microbatches = batch_size // microbatch_size
|
| 151 |
+
# We expect that the number of local stages (`len(stages)`) is the same across all ranks
|
| 152 |
+
num_total_stages = job_config.parallelism.pipeline_parallel_degree * len(stages)
|
| 153 |
+
if n_microbatches < num_total_stages:
|
| 154 |
+
logger.warning(
|
| 155 |
+
f"Number of microbatches ({n_microbatches}) is less than the total number "
|
| 156 |
+
f"of stages ({num_total_stages}) which may result in a bubble in the pipeline."
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
schedule = schedule_class(
|
| 160 |
+
stages if looped_schedule else stages[0],
|
| 161 |
+
n_microbatches=n_microbatches,
|
| 162 |
+
loss_fn=loss_fn,
|
| 163 |
+
)
|
| 164 |
+
logger.info(
|
| 165 |
+
f"Using pipeline schedule {job_config.parallelism.pipeline_parallel_schedule} "
|
| 166 |
+
f"with {n_microbatches} microbatches and {num_total_stages} stages."
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
if pp_schedule_csv:
|
| 170 |
+
assert schedule_class in [
|
| 171 |
+
PipelineScheduleSingle,
|
| 172 |
+
PipelineScheduleMulti,
|
| 173 |
+
_PipelineScheduleRuntime,
|
| 174 |
+
], (
|
| 175 |
+
"Only PipelineScheduleSingle (single stage), PipelineScheduleMulti (multistage), "
|
| 176 |
+
"and _PipelineScheduleRuntime support csv schedules"
|
| 177 |
+
)
|
| 178 |
+
schedule._load_csv(pp_schedule_csv)
|
| 179 |
+
|
| 180 |
+
return schedule
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# TODO(whc) should this be a utility inside torch.pipelining?
|
| 184 |
+
def stage_ids_this_rank(
|
| 185 |
+
pp_rank: int, pp_size: int, num_stages: int, style: str = "loop"
|
| 186 |
+
) -> tuple[int]:
|
| 187 |
+
"""Compute the stage ids for the stages that will run on this pp rank for either a looped or V style schedule"""
|
| 188 |
+
assert (
|
| 189 |
+
num_stages % pp_size == 0
|
| 190 |
+
), f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size}"
|
| 191 |
+
stages_per_rank = num_stages // pp_size
|
| 192 |
+
if style == "loop":
|
| 193 |
+
return tuple(pp_rank + s * pp_size for s in range(stages_per_rank))
|
| 194 |
+
elif style == "v":
|
| 195 |
+
assert (
|
| 196 |
+
stages_per_rank == 2
|
| 197 |
+
), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
|
| 198 |
+
stage_v_pairs = list(
|
| 199 |
+
zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1))
|
| 200 |
+
)
|
| 201 |
+
return stage_v_pairs[pp_rank]
|
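For reference, a minimal sketch (not part of the diff) of how `stage_ids_this_rank` places stage ids on pipeline ranks; the 4-rank / 8-stage sizes below are illustrative assumptions:

# Illustrative only: 4 pipeline ranks, 8 total stages (2 stages per rank).
from torchtitan.distributed.pipeline import stage_ids_this_rank

for rank in range(4):
    # "loop" interleaves stages across ranks: rank 0 -> (0, 4), rank 1 -> (1, 5), ...
    print("loop", rank, stage_ids_this_rank(rank, pp_size=4, num_stages=8, style="loop"))
    # "v" pairs an early and a late stage on each rank: rank 0 -> (0, 7), rank 1 -> (1, 6), ...
    print("v", rank, stage_ids_this_rank(rank, pp_size=4, num_stages=8, style="v"))
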
torchtitan/experiments/deepseek_v3/checkpoint.py
ADDED
@@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
from typing import Dict, Optional, Set, Tuple

import torch
from safetensors import safe_open

from transformers.utils import cached_file


logger = logging.getLogger(__name__)

_DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"


def read_weights_from_json(file_path: str) -> Optional[Dict[str, str]]:
    try:
        with open(file_path, "r") as file:
            data = json.load(file)

        if "weight_map" in data and isinstance(data["weight_map"], dict):
            return data["weight_map"]
        else:
            logger.info("No 'weight_map' dictionary found in the JSON file.")
            return None
    except (json.JSONDecodeError, Exception) as e:
        logger.info(f"An error occurred while reading the JSON file: {str(e)}")
        return None


def get_hf_weight_map_and_path(
    model_id: str,
) -> Tuple[Dict[str, str], str]:
    """Get the weight map for a given HF model id and also the cache path for loading the weights"""
    try:
        index_file = cached_file(model_id, _DEFAULT_SAFETENSOR_FILE_NAME)
    except Exception as e:
        logger.error(
            f"Model `{model_id}` not found in HF cache. "
            f"You can download the model using `python download.py {model_id}`"
        )
        raise e

    weight_map = read_weights_from_json(index_file)
    weight_path = os.path.dirname(index_file)
    logger.info(f"Loading weights from: {weight_path}")
    return weight_map, weight_path


def get_needed_files(
    state_dict: Dict[str, torch.Tensor], weight_map: Dict[str, str]
) -> Set[str]:
    needed_files = set()
    for param in state_dict.keys():
        file = weight_map.get(param)
        if file:
            needed_files.add(file)
        elif param.endswith("weight"):
            raise ValueError(
                f"Parameter {param} not found in weight map, please check..."
            )
    logger.info(f"Needed files: {needed_files}")
    return needed_files


def load_safetensor_file(
    full_path: str, device: torch.device
) -> Dict[str, torch.Tensor]:
    tensors = {}
    with safe_open(full_path, framework="pt", device=device) as f:
        for k in f.keys():
            tensors[k] = f.get_tensor(k)
    logger.info(f"Loaded {len(tensors)} tensors from {full_path}")
    return tensors


def load_safetensor_weights(
    model: torch.nn.Module,
    weight_map: Dict[str, str],
    file_location: str,
    device: torch.device,
):
    """
    Load safetensor weights into a `nn.Module`.

    Args:
        model (Module): The PyTorch module to load weights into. It may be a
            model chunk or a full model.
        weight_map (Dict[str, str]): Mapping of model parameters to file names.
        file_location (str): Directory containing the weight files.
        device (torch.device): The device to load tensors onto.
    """
    model_state_dict = model.state_dict()
    needed_files = get_needed_files(model_state_dict, weight_map)
    updated_states: Set[str] = set()

    for file in needed_files:
        full_path = os.path.join(file_location, file)
        try:
            checkpoint = load_safetensor_file(full_path, "cpu")
        except FileNotFoundError:
            logger.error(f"File not found: {full_path}")
            continue
        except Exception as e:
            logger.error(f"Error during checkpoint processing of {full_path}: {str(e)}")
            continue

        matched_keys = set(checkpoint.keys()) & set(model_state_dict.keys())
        for key in matched_keys:
            # Check shape
            if model_state_dict[key].shape != checkpoint[key].shape:
                raise ValueError(
                    f"Shape mismatch for {key}: "
                    f"model needs {model_state_dict[key].shape}, but "
                    f"checkpoint has {checkpoint[key].shape}"
                )
            model_state_dict[key] = checkpoint[key].to(device)

        updated_states.update(matched_keys)

    missing_keys = set(model_state_dict.keys()) - updated_states
    if missing_keys:
        raise RuntimeError(
            f"Partially updated state dict. Missing parameters: {missing_keys}"
        )

    model.load_state_dict(model_state_dict, strict=False, assign=True)
    logger.info(f"Successfully loaded {len(updated_states)} weights into model")


def load_weights_from_hf(
    model: torch.nn.Module,
    distribution: str,
    device: torch.device,
):
    """
    Load the weights from Hugging Face format (index file + multiple safetensor
    files), and fill into `model`. Model config is needed b/c we permute
    wq and wk weights based on attn heads.
    """

    weight_map, weight_path = get_hf_weight_map_and_path(distribution)

    load_safetensor_weights(
        model,
        weight_map,
        weight_path,
        device,
    )

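As a usage sketch (illustrative, not part of the diff): once one of the models listed in `download.py` is present in the local Hugging Face cache, the index file can be resolved and inspected before loading weights into a model chunk:

# Assumes the snapshot was downloaded first (see download.py below).
from checkpoint import get_hf_weight_map_and_path

weight_map, weight_path = get_hf_weight_map_and_path("deepseek-ai/DeepSeek-V2-Lite-Chat")
print(
    f"{len(weight_map)} parameters across "
    f"{len(set(weight_map.values()))} safetensor files under {weight_path}"
)
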
torchtitan/experiments/deepseek_v3/download.py
ADDED
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Usage:
# Downloads a given model to the HF Cache. Pass in a listed option ala "v3" or your own custom model path.
# python download.py {model_id} [custom_model_path]
# Examples:
# python download.py v2                              # Use predefined model: deepseek-ai/DeepSeek-V2
# python download.py custom "deepseek-ai/new-model"  # Download a custom model path

# Available models:
# "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat",
# "v2-lite": "deepseek-ai/DeepSeek-V2-Lite",
# "v2": "deepseek-ai/DeepSeek-V2",
# "v3": "deepseek-ai/deepseek-v3",
# "v3-0324": "deepseek-ai/DeepSeek-V3-0324",
# "custom": None,  # Placeholder for custom models


import sys

from transformers import AutoModelForCausalLM


MODELS = {
    "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat",
    "v2-lite": "deepseek-ai/DeepSeek-V2-Lite",
    "v2": "deepseek-ai/DeepSeek-V2",
    "v3": "deepseek-ai/deepseek-v3",
    "v3-0324": "deepseek-ai/DeepSeek-V3-0324",
    "custom": None,  # For custom (any) models
}


def print_usage():
    print("Usage:")
    print("  python download.py [model_version]")
    print("  python download.py custom [custom_model_path]")
    print("\nAvailable predefined models:")
    for key, model in MODELS.items():
        if key != "custom":  # Skip the custom placeholder
            print(f"  {key}: {model}")
    print("\nFor custom models:")
    print("  custom: Specify your own model path")
    print('  Example: python download.py custom "organization/model-name"')
    sys.exit(1)


# Process command line arguments
if len(sys.argv) < 2 or sys.argv[1] not in MODELS:
    print_usage()

if sys.argv[1] == "custom":
    if len(sys.argv) != 3:
        print("Error: Custom model requires a model path")
        print_usage()
    model_id = sys.argv[2]
    print(f"Using custom model: {model_id}")
else:
    model_id = MODELS[sys.argv[1]]
    print(f"Downloading model: {model_id}")

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,
)

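To confirm what the script placed in the cache, one option is the `huggingface_hub` cache-scanning API (which `transformers` already depends on); this is a sketch, not part of the diff:

from huggingface_hub import scan_cache_dir

# Lists every cached repo and its on-disk size after running download.py.
for repo in scan_cache_dir().repos:
    print(repo.repo_id, f"{repo.size_on_disk / 1e9:.2f} GB")
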
torchtitan/experiments/deepseek_v3/generate.py
ADDED
@@ -0,0 +1,308 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# torchrun --standalone --nproc-per-node 4 generate.py

# use inference.sh "Your Question Here?" to run inference with a single prompt.

import sys
from dataclasses import dataclass

import torch
import torch.distributed as dist

from checkpoint import load_weights_from_hf
from model import DeepseekForCausalLM
from model_config import deepseek_config_registry
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
from torchtitan.tools.utils import Color
from transformers import AutoTokenizer

# Uncomment the model you want to run.
model_id, mesh_shape = "deepseek-ai/DeepSeek-V2-Lite-Chat", (1, 4)
# model_id, mesh_shape = "deepseek-ai/deepseek-v3", (8, 4)


def colorize_chat(text, user_color=None, assistant_color=None, output_color=None):
    """Parse and colorize chat output with optional colors for each role."""
    lines = text.split("\n")
    result = []

    current_role = None
    current_content = []

    def _process_current_content():
        if not current_role or not current_content:
            return None

        content = "\n".join(current_content)
        if current_role == "output":
            return (
                f"Output: {output_color}{content}{color.reset}"
                if output_color
                else f"Output: {content}"
            )
        else:
            try:
                prefix, rest = current_content[0].split(":", 1)
                role_color = user_color if current_role == "user" else assistant_color
                if role_color:
                    formatted = f"{prefix}:{role_color}{rest}{color.reset}"
                    if len(current_content) > 1:
                        formatted += (
                            f"{role_color}\n"
                            + "\n".join(current_content[1:])
                            + f"{color.reset}"
                        )
                    return formatted
            except ValueError:
                pass
            return content

    for line in lines:
        if line.startswith("Output:"):
            if processed := _process_current_content():
                result.append(processed)
            current_role = "output"
            content = line[len("Output:") :].strip()
            if output_color:
                content = f"Output: {output_color}{content}{color.reset}"
            else:
                content = f"Output: {content}"
            result.append(content)
            current_content = []

        elif line.startswith("User:"):
            if processed := _process_current_content():
                result.append(processed)
            current_role = "user"
            current_content = [line]

        elif line.startswith("Assistant:"):
            if processed := _process_current_content():
                result.append(processed)
            current_role = "assistant"
            current_content = [line]

        else:
            if current_content:
                current_content.append(line)
            elif line.strip() and current_role is None:
                # Handle system message at the beginning
                current_role = "output"
                if output_color:
                    result.append(f"Output: {output_color}{line.strip()}{color.reset}")
                else:
                    result.append(f"Output: {line.strip()}")

    # Process the last segment
    if processed := _process_current_content():
        result.append(processed)

    return "\n".join(result)


color = Color()


@dataclass
class DistConfig:
    mesh: DeviceMesh
    pp_mesh: DeviceMesh
    ep_mesh: DeviceMesh
    pp_size: int
    ep_size: int
    ep_rank: int
    pp_rank: int
    device: torch.device


def create_model(dist_config: DistConfig):
    model_args = deepseek_config_registry[model_id]
    model_args.ep_size = dist_config.ep_size
    model_args.num_stages = dist_config.pp_size
    model_args.stage_idx = dist_config.pp_rank
    model_args.max_seq_len = 16384

    with dist_config.device, dist_config.mesh:
        model = DeepseekForCausalLM(model_args)
        load_weights_from_hf(model, model_id, dist_config.device)
        model.eval()
        model.setup_symm_mem(torch.bfloat16, dist_config.device)

    stage = PipelineStage(
        model,
        dist_config.pp_rank,
        dist_config.pp_size,
        dist_config.device,
        group=dist_config.pp_mesh.get_group(),
    )
    pp_schedule = ScheduleGPipe(stage, dist_config.pp_size)
    return model, pp_schedule


def create_dist_config(mesh: DeviceMesh):
    rank = dist.get_rank()
    device_count = torch.cuda.device_count()
    device = torch.device("cuda", rank % device_count)

    dist_config = DistConfig(
        mesh=mesh,
        pp_mesh=mesh["pp"],
        ep_mesh=mesh["ep"],
        pp_rank=mesh["pp"].get_local_rank(),
        pp_size=mesh["pp"].size(),
        ep_size=mesh["ep"].size(),
        ep_rank=mesh["ep"].get_local_rank(),
        device=device,
    )
    return dist_config


def decode(tokenizer, x):
    output = tokenizer.decode(x[0])
    # Clean up the output by removing special tokens
    bos = tokenizer.bos_token
    output = output.replace(bos, "")
    # Truncate at end of sentence token
    eos_token = tokenizer.eos_token
    if eos_token and eos_token in output:
        output = output.split(eos_token)[0]
    colored_output = colorize_chat(
        output,
        user_color=color.green,
        assistant_color=color.cyan,
        output_color=color.blue,
    )
    return colored_output


@torch.inference_mode()
def generate(
    model,
    pp_schedule,
    tokenizer,
    dist_config,
    messages: list[dict],
    n_tokens: int = 50,
):
    rank = dist.get_rank()
    device = dist_config.device
    x = tokenizer.apply_chat_template(
        [messages] * dist_config.pp_size,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    next_idx = x.shape[-1]
    x = torch.cat([x, torch.zeros(x.shape[0], n_tokens, dtype=torch.int64)], dim=-1)
    x = x.to(device)

    for _ in range(n_tokens):
        if dist_config.pp_size > 1:
            if dist_config.pp_rank == 0:
                pp_schedule.step(x)
                torch.distributed.broadcast(
                    x,
                    group=dist_config.pp_mesh.get_group(),
                    group_src=dist_config.pp_size - 1,
                )
            elif dist_config.pp_rank == dist_config.pp_size - 1:
                preds = pp_schedule.step()
                next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
                x[:, next_idx] = next_token
                torch.distributed.broadcast(
                    x,
                    group=dist_config.pp_mesh.get_group(),
                    group_src=dist_config.pp_size - 1,
                )
            else:
                pp_schedule.step()
                torch.distributed.broadcast(
                    x,
                    group=dist_config.pp_mesh.get_group(),
                    group_src=dist_config.pp_size - 1,
                )

            next_idx += 1
        else:
            preds = model(x)
            next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
            x[:, next_idx] = next_token
            next_idx += 1

    if rank == 0:
        colored_output = decode(tokenizer, x)
        print(f"Without CUDA Graph:\n{colored_output}")


@torch.inference_mode()
def generate_with_cuda_graph(
    model,
    tokenizer,
    dist_config,
    messages: list[dict],
    n_tokens: int = 10,
):
    rank = dist.get_rank()
    device = dist_config.device
    x = tokenizer.apply_chat_template(
        [messages] * dist_config.pp_size,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    next_idx = x.shape[-1]
    x = torch.cat([x, torch.zeros(x.shape[0], n_tokens, dtype=torch.int64)], dim=-1)
    x = x.to(device)

    torch.cuda.synchronize()

    # Create CUDA graph
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        preds = model(x)

    # Run CUDA graph
    for _ in range(n_tokens):
        g.replay()
        next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
        x[:, next_idx] = next_token
        next_idx += 1

    if rank == 0:
        colored_output = decode(tokenizer, x)
        print(f"With CUDA Graph:\n{colored_output}")


if __name__ == "__main__":
    # Get user prompt from command line arguments
    user_prompt = "What is 2+2?"  # Default prompt
    if len(sys.argv) > 1:
        user_prompt = sys.argv[1]

    mesh = dist.init_device_mesh("cuda", mesh_shape, mesh_dim_names=("pp", "ep"))
    rank = dist.get_rank()
    if rank == 0:
        print(
            f"{color.yellow}Running inference with {model_id} on {mesh_shape} mesh{color.reset}"
        )

    dist_config = create_dist_config(mesh)
    model, pp_schedule = create_model(dist_config)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_prompt},
    ]

    generate(model, pp_schedule, tokenizer, dist_config, messages)
    generate_with_cuda_graph(model, tokenizer, dist_config, messages)

    if rank == 0:
        print(f"\n{color.yellow}Closing inference mesh...{color.reset}")

    dist.destroy_process_group()

torchtitan/experiments/deepseek_v3/inference.sh
ADDED
@@ -0,0 +1,15 @@
#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

NGPU=${NGPU:-"4"}

# Get the prompt from command line argument or use a default
prompt="${1:-What is 2+2?}"

# Run the model with the prompt
torchrun --standalone --nproc-per-node ${NGPU} generate.py "$prompt"

torchtitan/experiments/deepseek_v3/model.py
ADDED
@@ -0,0 +1,1325 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# This code is based on model definition of `deepseek-ai/DeepSeek-V3-Base` on
|
| 8 |
+
# Hugging Face Model Hub. Url:
|
| 9 |
+
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py
|
| 10 |
+
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/resolve/main/configuration_deepseek.py
|
| 11 |
+
#
|
| 12 |
+
# It has been modified from its original forms to accommodate naming convention
|
| 13 |
+
# and usage patterns of the TorchTitan project.
|
| 14 |
+
|
| 15 |
+
# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
|
| 16 |
+
#
|
| 17 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 18 |
+
# you may not use this file except in compliance with the License.
|
| 19 |
+
# You may obtain a copy of the License at
|
| 20 |
+
#
|
| 21 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 22 |
+
#
|
| 23 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 24 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 25 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 26 |
+
# See the License for the specific language governing permissions and
|
| 27 |
+
# limitations under the License.
|
| 28 |
+
""" PyTorch DeepSeek model."""
|
| 29 |
+
import math
|
| 30 |
+
from typing import Optional, Tuple
|
| 31 |
+
|
| 32 |
+
import torch
|
| 33 |
+
import torch.distributed as dist
|
| 34 |
+
|
| 35 |
+
import torch.distributed._symmetric_memory as symm_mem
|
| 36 |
+
import torch.nn.functional as F
|
| 37 |
+
import torch.utils.checkpoint
|
| 38 |
+
|
| 39 |
+
from attn_mask_utils import _prepare_4d_causal_attention_mask
|
| 40 |
+
from indices import generate_permute_indices
|
| 41 |
+
from model_config import ModelArgs
|
| 42 |
+
from symm_mem_recipes import OnDeviceAllToAllV
|
| 43 |
+
from torch import nn
|
| 44 |
+
from torch.distributed._functional_collectives import all_to_all_single_autograd
|
| 45 |
+
|
| 46 |
+
from torchtitan.experiments.kernels.triton_mg_group_gemm.torchao_pr import (
|
| 47 |
+
ALIGN_SIZE_M,
|
| 48 |
+
grouped_gemm_forward,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# Get model parallel subgroup by name:
|
| 52 |
+
# e.g. "pp", "ep", None
|
| 53 |
+
def get_group(dim_name: Optional[str] = None) -> dist.ProcessGroup:
|
| 54 |
+
glob = torch.distributed.device_mesh._mesh_resources.get_current_mesh()
|
| 55 |
+
return glob.get_group(dim_name)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class RMSNorm(nn.Module):
|
| 59 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 62 |
+
self.variance_epsilon = eps
|
| 63 |
+
|
| 64 |
+
def forward(self, hidden_states):
|
| 65 |
+
input_dtype = hidden_states.dtype
|
| 66 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 67 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 68 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 69 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class RotaryEmbedding(nn.Module):
|
| 73 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
| 74 |
+
super().__init__()
|
| 75 |
+
|
| 76 |
+
self.dim = dim
|
| 77 |
+
self.max_position_embeddings = max_position_embeddings
|
| 78 |
+
self.base = base
|
| 79 |
+
inv_freq = 1.0 / (
|
| 80 |
+
self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
|
| 81 |
+
)
|
| 82 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 83 |
+
|
| 84 |
+
# Build here to make `torch.jit.trace` work.
|
| 85 |
+
self._set_cos_sin_cache(
|
| 86 |
+
seq_len=max_position_embeddings,
|
| 87 |
+
device=self.inv_freq.device,
|
| 88 |
+
dtype=torch.get_default_dtype(),
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
| 92 |
+
self.max_seq_len_cached = seq_len
|
| 93 |
+
t = torch.arange(
|
| 94 |
+
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
freqs = torch.outer(t, self.inv_freq.to(t.device))
|
| 98 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
| 99 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 100 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 101 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 102 |
+
|
| 103 |
+
def forward(self, x, seq_len=None):
|
| 104 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
| 105 |
+
if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
|
| 106 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
| 107 |
+
|
| 108 |
+
return (
|
| 109 |
+
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
| 110 |
+
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class LinearScalingRotaryEmbedding(RotaryEmbedding):
|
| 115 |
+
"""RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
| 116 |
+
|
| 117 |
+
def __init__(
|
| 118 |
+
self,
|
| 119 |
+
dim,
|
| 120 |
+
max_position_embeddings=2048,
|
| 121 |
+
base=10000,
|
| 122 |
+
device=None,
|
| 123 |
+
scaling_factor=1.0,
|
| 124 |
+
):
|
| 125 |
+
self.scaling_factor = scaling_factor
|
| 126 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
| 127 |
+
|
| 128 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
| 129 |
+
self.max_seq_len_cached = seq_len
|
| 130 |
+
t = torch.arange(
|
| 131 |
+
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
|
| 132 |
+
)
|
| 133 |
+
t = t / self.scaling_factor
|
| 134 |
+
|
| 135 |
+
freqs = torch.outer(t, self.inv_freq)
|
| 136 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
| 137 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 138 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 139 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Deepseek
|
| 143 |
+
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
|
| 144 |
+
"""RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
| 145 |
+
|
| 146 |
+
def __init__(
|
| 147 |
+
self,
|
| 148 |
+
dim,
|
| 149 |
+
max_position_embeddings=2048,
|
| 150 |
+
base=10000,
|
| 151 |
+
device=None,
|
| 152 |
+
scaling_factor=1.0,
|
| 153 |
+
):
|
| 154 |
+
self.scaling_factor = scaling_factor
|
| 155 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
| 156 |
+
|
| 157 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
| 158 |
+
self.max_seq_len_cached = seq_len
|
| 159 |
+
|
| 160 |
+
if seq_len > self.max_position_embeddings:
|
| 161 |
+
base = self.base * (
|
| 162 |
+
(self.scaling_factor * seq_len / self.max_position_embeddings)
|
| 163 |
+
- (self.scaling_factor - 1)
|
| 164 |
+
) ** (self.dim / (self.dim - 2))
|
| 165 |
+
inv_freq = 1.0 / (
|
| 166 |
+
base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
|
| 167 |
+
)
|
| 168 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 169 |
+
|
| 170 |
+
t = torch.arange(
|
| 171 |
+
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
freqs = torch.outer(t, self.inv_freq)
|
| 175 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
| 176 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 177 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 178 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# Inverse dim formula to find dim based on number of rotations
|
| 182 |
+
def yarn_find_correction_dim(
|
| 183 |
+
num_rotations, dim, base=10000, max_position_embeddings=2048
|
| 184 |
+
):
|
| 185 |
+
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
|
| 186 |
+
2 * math.log(base)
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# Find dim range bounds based on rotations
|
| 191 |
+
def yarn_find_correction_range(
|
| 192 |
+
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
|
| 193 |
+
):
|
| 194 |
+
low = math.floor(
|
| 195 |
+
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
|
| 196 |
+
)
|
| 197 |
+
high = math.ceil(
|
| 198 |
+
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
|
| 199 |
+
)
|
| 200 |
+
return max(low, 0), min(high, dim - 1) # Clamp values just in case
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def yarn_get_mscale(scale=1, mscale=1):
|
| 204 |
+
if scale <= 1:
|
| 205 |
+
return 1.0
|
| 206 |
+
return 0.1 * mscale * math.log(scale) + 1.0
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def yarn_linear_ramp_mask(min, max, dim):
|
| 210 |
+
if min == max:
|
| 211 |
+
max += 0.001 # Prevent singularity
|
| 212 |
+
|
| 213 |
+
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
|
| 214 |
+
ramp_func = torch.clamp(linear_func, 0, 1)
|
| 215 |
+
return ramp_func
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class YarnRotaryEmbedding(RotaryEmbedding):
|
| 219 |
+
def __init__(
|
| 220 |
+
self,
|
| 221 |
+
dim,
|
| 222 |
+
max_position_embeddings=2048,
|
| 223 |
+
base=10000,
|
| 224 |
+
device=None,
|
| 225 |
+
scaling_factor=1.0,
|
| 226 |
+
original_max_position_embeddings=4096,
|
| 227 |
+
beta_fast=32,
|
| 228 |
+
beta_slow=1,
|
| 229 |
+
mscale=1,
|
| 230 |
+
mscale_all_dim=0,
|
| 231 |
+
):
|
| 232 |
+
self.scaling_factor = scaling_factor
|
| 233 |
+
self.original_max_position_embeddings = original_max_position_embeddings
|
| 234 |
+
self.beta_fast = beta_fast
|
| 235 |
+
self.beta_slow = beta_slow
|
| 236 |
+
self.mscale = mscale
|
| 237 |
+
self.mscale_all_dim = mscale_all_dim
|
| 238 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
| 239 |
+
|
| 240 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
| 241 |
+
self.max_seq_len_cached = seq_len
|
| 242 |
+
dim = self.dim
|
| 243 |
+
|
| 244 |
+
freq_extra = 1.0 / (
|
| 245 |
+
self.base
|
| 246 |
+
** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
|
| 247 |
+
)
|
| 248 |
+
freq_inter = 1.0 / (
|
| 249 |
+
self.scaling_factor
|
| 250 |
+
* self.base
|
| 251 |
+
** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
low, high = yarn_find_correction_range(
|
| 255 |
+
self.beta_fast,
|
| 256 |
+
self.beta_slow,
|
| 257 |
+
dim,
|
| 258 |
+
self.base,
|
| 259 |
+
self.original_max_position_embeddings,
|
| 260 |
+
)
|
| 261 |
+
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
|
| 262 |
+
device=device, dtype=torch.float32
|
| 263 |
+
)
|
| 264 |
+
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
|
| 265 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 266 |
+
|
| 267 |
+
t = torch.arange(seq_len, device=device, dtype=torch.float32)
|
| 268 |
+
|
| 269 |
+
freqs = torch.outer(t, inv_freq)
|
| 270 |
+
|
| 271 |
+
_mscale = float(
|
| 272 |
+
yarn_get_mscale(self.scaling_factor, self.mscale)
|
| 273 |
+
/ yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 277 |
+
self.register_buffer(
|
| 278 |
+
"cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False
|
| 279 |
+
)
|
| 280 |
+
self.register_buffer(
|
| 281 |
+
"sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
| 286 |
+
def rotate_half(x):
|
| 287 |
+
"""Rotates half the hidden dims of the input."""
|
| 288 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 289 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 290 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
| 294 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
| 295 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 296 |
+
|
| 297 |
+
Args:
|
| 298 |
+
q (`torch.Tensor`): The query tensor.
|
| 299 |
+
k (`torch.Tensor`): The key tensor.
|
| 300 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 301 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 302 |
+
position_ids (`torch.Tensor`):
|
| 303 |
+
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
| 304 |
+
used to pass offsetted position ids when working with a KV-cache.
|
| 305 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 306 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 307 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 308 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 309 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 310 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 311 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 312 |
+
Returns:
|
| 313 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 314 |
+
"""
|
| 315 |
+
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
| 316 |
+
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
| 317 |
+
|
| 318 |
+
b, h, s, d = q.shape
|
| 319 |
+
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
| 320 |
+
|
| 321 |
+
b, h, s, d = k.shape
|
| 322 |
+
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
| 323 |
+
|
| 324 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 325 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 326 |
+
return q_embed, k_embed
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class MLP(nn.Module):
|
| 330 |
+
act_fn = nn.SiLU()
|
| 331 |
+
|
| 332 |
+
def __init__(self, config, hidden_size=None, intermediate_size=None):
|
| 333 |
+
super().__init__()
|
| 334 |
+
self.config = config
|
| 335 |
+
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
| 336 |
+
self.intermediate_size = (
|
| 337 |
+
config.intermediate_size if intermediate_size is None else intermediate_size
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 341 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 342 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 343 |
+
|
| 344 |
+
def forward(self, x):
|
| 345 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 346 |
+
return down_proj
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
class MoEGate(nn.Module):
|
| 350 |
+
def __init__(self, config):
|
| 351 |
+
super().__init__()
|
| 352 |
+
self.config = config
|
| 353 |
+
self.top_k = config.num_experts_per_tok
|
| 354 |
+
self.n_routed_experts = config.n_routed_experts
|
| 355 |
+
self.routed_scaling_factor = config.routed_scaling_factor
|
| 356 |
+
self.scoring_func = config.scoring_func
|
| 357 |
+
self.seq_aux = config.seq_aux
|
| 358 |
+
self.topk_method = config.topk_method
|
| 359 |
+
self.n_group = config.n_group
|
| 360 |
+
self.topk_group = config.topk_group
|
| 361 |
+
|
| 362 |
+
# topk selection algorithm
|
| 363 |
+
self.norm_topk_prob = config.norm_topk_prob
|
| 364 |
+
self.gating_dim = config.hidden_size
|
| 365 |
+
self.weight = nn.Parameter(
|
| 366 |
+
torch.empty((self.n_routed_experts, self.gating_dim))
|
| 367 |
+
)
|
| 368 |
+
if self.topk_method == "noaux_tc":
|
| 369 |
+
self.e_score_correction_bias = nn.Parameter(
|
| 370 |
+
# Changed from torch.empty to torch.rand to avoid non-even
|
| 371 |
+
# distribution for runs without actual weigths
|
| 372 |
+
torch.rand((self.n_routed_experts))
|
| 373 |
+
)
|
| 374 |
+
self.reset_parameters()
|
| 375 |
+
|
| 376 |
+
def reset_parameters(self) -> None:
|
| 377 |
+
import torch.nn.init as init
|
| 378 |
+
|
| 379 |
+
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
| 380 |
+
|
| 381 |
+
def forward(self, hidden_states):
|
| 382 |
+
bsz, seq_len, h = hidden_states.shape
|
| 383 |
+
# compute gating score
|
| 384 |
+
hidden_states = hidden_states.view(-1, h)
|
| 385 |
+
logits = F.linear(
|
| 386 |
+
hidden_states.type(torch.float32), self.weight.type(torch.float32), None
|
| 387 |
+
)
|
| 388 |
+
if self.scoring_func == "sigmoid":
|
| 389 |
+
scores = logits.sigmoid()
|
| 390 |
+
elif self.scoring_func == "softmax":
|
| 391 |
+
scores = logits.softmax(dim=-1, dtype=torch.float32)
|
| 392 |
+
else:
|
| 393 |
+
raise NotImplementedError(
|
| 394 |
+
f"insupportable scoring function for MoE gating: {self.scoring_func}"
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
# select top-k experts
|
| 398 |
+
if self.topk_method == "noaux_tc":
|
| 399 |
+
scores_for_choice = scores.view(
|
| 400 |
+
bsz * seq_len, -1
|
| 401 |
+
) + self.e_score_correction_bias.unsqueeze(0)
|
| 402 |
+
group_scores = (
|
| 403 |
+
scores_for_choice.view(bsz * seq_len, self.n_group, -1)
|
| 404 |
+
.topk(2, dim=-1)[0]
|
| 405 |
+
.sum(dim=-1)
|
| 406 |
+
) # [n, n_group]
|
| 407 |
+
group_idx = torch.topk(
|
| 408 |
+
group_scores, k=self.topk_group, dim=-1, sorted=False
|
| 409 |
+
)[
|
| 410 |
+
1
|
| 411 |
+
] # [n, top_k_group]
|
| 412 |
+
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
| 413 |
+
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
| 414 |
+
score_mask = (
|
| 415 |
+
group_mask.unsqueeze(-1)
|
| 416 |
+
.expand(
|
| 417 |
+
bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
|
| 418 |
+
)
|
| 419 |
+
.reshape(bsz * seq_len, -1)
|
| 420 |
+
) # [n, e]
|
| 421 |
+
tmp_scores = scores_for_choice.masked_fill(
|
| 422 |
+
~score_mask.bool(), 0.0
|
| 423 |
+
) # [n, e]
|
| 424 |
+
_, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
|
| 425 |
+
topk_weight = scores.gather(1, topk_idx)
|
| 426 |
+
elif self.topk_method == "greedy":
|
| 427 |
+
topk_weight, topk_idx = torch.topk(
|
| 428 |
+
scores, k=self.top_k, dim=-1, sorted=False
|
| 429 |
+
)
|
| 430 |
+
else:
|
| 431 |
+
raise NotImplementedError(
|
| 432 |
+
f"insupportable TopK function for MoE gating: {self.topk_method}"
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
# norm gate to sum 1
|
| 436 |
+
if self.top_k > 1 and self.norm_topk_prob:
|
| 437 |
+
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
| 438 |
+
topk_weight = topk_weight / denominator
|
| 439 |
+
topk_weight = (
|
| 440 |
+
topk_weight * self.routed_scaling_factor
|
| 441 |
+
) # must multiply the scaling factor
|
| 442 |
+
|
| 443 |
+
return topk_idx, topk_weight
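For intuition, a simplified standalone sketch (made-up sizes; the correction bias and scaling factor from the code above are omitted) of the group-limited top-k routing used by `noaux_tc`: groups are scored by the sum of their two best experts, only the best `topk_group` groups survive, and the final `top_k` experts are picked inside those groups.

```python
# Simplified sketch: 4 tokens, 8 experts in 4 groups of 2, topk_group=2, top_k=2.
import torch

n_tokens, n_experts, n_group, topk_group, top_k = 4, 8, 4, 2, 2
scores = torch.rand(n_tokens, n_experts)

# Score each group by the sum of its two best experts, keep the best groups.
group_scores = scores.view(n_tokens, n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)

# Expand the group mask to expert granularity and pick top_k inside it.
score_mask = (
    group_mask.unsqueeze(-1)
    .expand(n_tokens, n_group, n_experts // n_group)
    .reshape(n_tokens, -1)
)
masked_scores = scores.masked_fill(~score_mask.bool(), 0.0)
topk_weight, topk_idx = torch.topk(masked_scores, k=top_k, dim=-1, sorted=False)
print(topk_idx, topk_weight)
```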
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
class MoE(nn.Module):
|
| 447 |
+
"""
|
| 448 |
+
A mixture-of-experts module containing shared experts.
|
| 449 |
+
"""
|
| 450 |
+
|
| 451 |
+
# Class attributes:
|
| 452 |
+
# Two shuffle methods are supported:
|
| 453 |
+
# 1. "torch_all_to_all"
|
| 454 |
+
# 2. "symm_mem" (see `setup_symm_mem` below)
|
| 455 |
+
shuffle_method = "torch_all_to_all"
|
| 456 |
+
|
| 457 |
+
# Symmetric memory buffers shared by all MoE instances across layers
|
| 458 |
+
token_send_buf: Optional[torch.Tensor] = None
|
| 459 |
+
token_gather_buf: Optional[torch.Tensor] = None
|
| 460 |
+
|
| 461 |
+
def __init__(self, config):
|
| 462 |
+
super().__init__()
|
| 463 |
+
self.config = config
|
| 464 |
+
self.num_experts_per_tok = config.num_experts_per_tok
|
| 465 |
+
|
| 466 |
+
# ep_size is the number of ranks in expert dimension
|
| 467 |
+
if config.ep_size <= 1:
|
| 468 |
+
raise ValueError(
|
| 469 |
+
"For code simplicity, this model only supports distributed experts, "
|
| 470 |
+
"thus EP size must be > 1, please modify your model config"
|
| 471 |
+
)
|
| 472 |
+
self.ep_group = get_group("ep")
|
| 473 |
+
assert config.ep_size == self.ep_group.size()
|
| 474 |
+
self.ep_size = config.ep_size
|
| 475 |
+
self.ep_rank = self.ep_group.rank()
|
| 476 |
+
self.experts_per_rank = config.n_routed_experts // config.ep_size
|
| 477 |
+
# Use ModuleDict instead of ModuleList to preserve absolute expert
|
| 478 |
+
# IDs while avoiding `None` experts. The absolute expert IDs match
|
| 479 |
+
# with checkpoint FQNs.
|
| 480 |
+
self.experts = nn.ModuleDict()
|
| 481 |
+
for i in range(self.experts_per_rank):
|
| 482 |
+
abs_expert_id = self.ep_rank * self.experts_per_rank + i
|
| 483 |
+
self.experts[str(abs_expert_id)] = MLP(
|
| 484 |
+
config, intermediate_size=config.moe_intermediate_size
|
| 485 |
+
)
|
| 486 |
+
self.gate = MoEGate(config)
|
| 487 |
+
if config.n_shared_experts is not None:
|
| 488 |
+
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
| 489 |
+
self.shared_experts = MLP(
|
| 490 |
+
config=config, intermediate_size=intermediate_size
|
| 491 |
+
)
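A standalone illustration (hypothetical sizes) of the absolute expert-ID layout used for the `ModuleDict` keys above:

```python
# Hypothetical: 8 routed experts sharded over ep_size=4 ranks, viewed from rank 1.
n_routed_experts, ep_size, ep_rank = 8, 4, 1
experts_per_rank = n_routed_experts // ep_size
local_ids = [str(ep_rank * experts_per_rank + i) for i in range(experts_per_rank)]
print(local_ids)  # ['2', '3'] -- ModuleDict keys match checkpoint FQNs
```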
|
| 492 |
+
|
| 493 |
+
def combine_experts(self, submod_name):
|
| 494 |
+
all_weights = []
|
| 495 |
+
for expert in self.experts.values():
|
| 496 |
+
lin = expert.get_submodule(submod_name)
|
| 497 |
+
all_weights.append(lin.weight)
|
| 498 |
+
lin.weight = None
|
| 499 |
+
|
| 500 |
+
concat_weight = torch.cat(all_weights)
|
| 501 |
+
self.register_parameter(f"{submod_name}_weight", nn.Parameter(concat_weight))
|
| 502 |
+
|
| 503 |
+
# This function is used to create a symm mem buffer for MoEs. It is for
|
| 504 |
+
# shuffling tokens fully "on-device", as compared to traditional torch
|
| 505 |
+
# all_to_all APIs which require a GPU-to-CPU sync of the splits. If a user
|
| 506 |
+
# calls this function, the `shuffle_method` would switch from
|
| 507 |
+
# `torch_all_to_all` to `symm_mem`.
|
| 508 |
+
def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
|
| 509 |
+
# Switch shuffle method
|
| 510 |
+
self.shuffle_method = "symm_mem"
|
| 511 |
+
|
| 512 |
+
# Combine expert weights
|
| 513 |
+
print("Combining expert weights for Group GEMM")
|
| 514 |
+
self.combine_experts("gate_proj")
|
| 515 |
+
self.combine_experts("up_proj")
|
| 516 |
+
self.combine_experts("down_proj")
|
| 517 |
+
|
| 518 |
+
# Assuming worst case, 2x tokens are routed to one EP rank
|
| 519 |
+
overflow = 2
|
| 520 |
+
OnDeviceAllToAllV.max_output_len = (
|
| 521 |
+
self.config.max_seq_len * self.num_experts_per_tok * overflow
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
# Symmetric memory buffers are shared by all MoE instances across
|
| 525 |
+
# layers, we only need to initialize them once
|
| 526 |
+
if MoE.token_send_buf is not None:
|
| 527 |
+
return
|
| 528 |
+
|
| 529 |
+
# Input buffer for DP-to-EP shuffle
|
| 530 |
+
MoE.token_send_buf = symm_mem.empty(
|
| 531 |
+
self.config.max_seq_len
|
| 532 |
+
* self.num_experts_per_tok, # seq len * top k (flattened)
|
| 533 |
+
self.config.hidden_size, # hidden dim
|
| 534 |
+
dtype=dtype,
|
| 535 |
+
device=device,
|
| 536 |
+
)
|
| 537 |
+
# Input buffer for EP-to-DP shuffle
|
| 538 |
+
MoE.token_gather_buf = symm_mem.empty(
|
| 539 |
+
self.config.max_seq_len
|
| 540 |
+
* self.num_experts_per_tok # seq len * top k (flattened)
|
| 541 |
+
* overflow,
|
| 542 |
+
self.config.hidden_size, # hidden dim
|
| 543 |
+
dtype=dtype,
|
| 544 |
+
device=device,
|
| 545 |
+
)
|
| 546 |
+
print(f"EP rank [{self.ep_rank}]: Created Symmetric Memory for MoE")
|
| 547 |
+
|
| 548 |
+
def get_send_buf(self):
|
| 549 |
+
# [Why detach?] During a first forward-backward step, the buffer would
|
| 550 |
+
# be included in a computational graph. In a second step, autograd will
|
| 551 |
+
# return an error saying "Trying to backward through the graph a second
|
| 552 |
+
# time (or directly access saved tensors more than once)". This is
|
| 553 |
+
# because the buffer is still in the graph, and autograd is trying to
|
| 554 |
+
# backward through the graph a second time. To avoid this, we detach the
|
| 555 |
+
# buffer from the graph. `detach()` returns a new tensor, which shares
|
| 556 |
+
# the same storage with the original one.
|
| 557 |
+
self.token_send_buf.grad = None
|
| 558 |
+
return self.token_send_buf.detach()
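A tiny standalone check of the `[Why detach?]` note above: `detach()` returns a view on the same storage, so writes through the detached handle still land in the shared buffer while dropping autograd history.

```python
# Standalone: detach() shares storage but carries no autograd graph.
import torch

buf = torch.zeros(4)
view = buf.detach()
view[0] = 1.0
assert buf[0].item() == 1.0   # same storage; `view` has no history to backward through
```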
|
| 559 |
+
|
| 560 |
+
def get_gather_buf(self):
|
| 561 |
+
# See [Why detach?] in `get_send_buf`
|
| 562 |
+
self.token_gather_buf.grad = None
|
| 563 |
+
return self.token_gather_buf.detach()
|
| 564 |
+
|
| 565 |
+
def forward(self, hidden_states):
|
| 566 |
+
identity = hidden_states
|
| 567 |
+
orig_shape = hidden_states.shape
|
| 568 |
+
# for each token, select top-k experts, and compute the weight for each expert
|
| 569 |
+
topk_idx, topk_weight = self.gate(hidden_states)
|
| 570 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 571 |
+
if self.shuffle_method == "symm_mem":
|
| 572 |
+
y = self.moe_on_device(hidden_states, topk_idx, topk_weight)
|
| 573 |
+
else: # "torch_all_to_all"
|
| 574 |
+
y = self.moe_forward(hidden_states, topk_idx, topk_weight)
|
| 575 |
+
|
| 576 |
+
y = y.view(*orig_shape)
|
| 577 |
+
if self.config.n_shared_experts is not None:
|
| 578 |
+
y = y + self.shared_experts(identity)
|
| 579 |
+
return y
|
| 580 |
+
|
| 581 |
+
def moe_forward(self, x, topk_ids, topk_weight):
|
| 582 |
+
# This part sorts the token indices so that tokens routed to the same expert reside consecutively.
|
| 583 |
+
# An implication is that tokens to the same "expert group" (i.e., device) are also consecutive.
|
| 584 |
+
# Since this is an "artificial" index creation (final outcome being
|
| 585 |
+
# `idxs`), we don't need gradients here.
|
| 586 |
+
with torch.no_grad():
|
| 587 |
+
# [seq_len, n_routed_experts]
|
| 588 |
+
cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts))
|
| 589 |
+
# Fill 1 to the selected experts
|
| 590 |
+
cnts.scatter_(1, topk_ids, 1)
|
| 591 |
+
tokens_per_expert = cnts.sum(dim=0)
|
| 592 |
+
# Token indices for each expert
|
| 593 |
+
idxs = topk_ids.view(-1).argsort()
|
| 594 |
+
sorted_tokens_shape = idxs.shape + x.shape[1:]
|
| 595 |
+
|
| 596 |
+
sorted_tokens = x[idxs // topk_ids.shape[1]]
|
| 597 |
+
assert sorted_tokens.shape == sorted_tokens_shape
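A toy, standalone illustration (made-up token features) of the sort-by-expert trick above: argsort over the flattened top-k expert ids groups token copies per expert, and integer division by `top_k` recovers the originating token.

```python
# Toy illustration of the sorting trick (3 tokens, top_k=2).
import torch

topk_ids = torch.tensor([[2, 0], [1, 2], [0, 1]])        # expert choices per token
x = torch.arange(3, dtype=torch.float32).unsqueeze(-1)   # one feature per token
idxs = topk_ids.view(-1).argsort()                        # group copies by expert id
sorted_tokens = x[idxs // topk_ids.shape[1]]              # recover source token per copy
print(idxs)                  # copies for expert 0 first, then expert 1, then expert 2
print(sorted_tokens.squeeze(-1))
```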
|
| 598 |
+
|
| 599 |
+
# This part exchanges the information about the number of tokens sent and
|
| 600 |
+
# received by each expert. We can understand this information as "side
|
| 601 |
+
# band", which is not part of the actual data. Thus no gradient is
|
| 602 |
+
# needed.
|
| 603 |
+
with torch.no_grad():
|
| 604 |
+
# Sum the tokens over local experts, then we get tokens per EP rank,
|
| 605 |
+
# which is the input splits
|
| 606 |
+
tokens_per_expert_group = tokens_per_expert.new_empty(
|
| 607 |
+
tokens_per_expert.shape[0]
|
| 608 |
+
)
|
| 609 |
+
dist.all_to_all_single(
|
| 610 |
+
tokens_per_expert_group, tokens_per_expert, group=self.ep_group
|
| 611 |
+
)
|
| 612 |
+
input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
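Standalone illustration (no process group required) of how per-expert counts collapse into per-EP-rank input splits, assuming 2 EP ranks with 2 local experts each:

```python
# Standalone: per-expert counts -> per-EP-rank splits (2 ranks, 2 experts each).
import torch

tokens_per_expert = torch.tensor([3, 1, 0, 2])   # counts for routed experts 0..3
ep_size = 2
input_splits = tokens_per_expert.view(ep_size, -1).sum(dim=1)
print(input_splits)  # tensor([4, 2]): 4 token copies go to rank 0, 2 to rank 1
```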
|
| 613 |
+
|
| 614 |
+
# DP to EP token shuffle. This part needs gradient.
|
| 615 |
+
if self.shuffle_method == "symm_mem":
|
| 616 |
+
# Move input to the `token_send_buf` symm mem
|
| 617 |
+
token_send_buf = self.get_send_buf()
|
| 618 |
+
token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
|
| 619 |
+
# Note: `out=` avoids copy, but it is not differentiable
|
| 620 |
+
# torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
|
| 621 |
+
token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
|
| 622 |
+
token_send_buf,
|
| 623 |
+
input_splits,
|
| 624 |
+
self.ep_group,
|
| 625 |
+
)
|
| 626 |
+
with torch.no_grad():
|
| 627 |
+
# Received tokens from all other ranks. TODO: use mask instead
|
| 628 |
+
received = output_splits.sum()
|
| 629 |
+
# TODO: don't use `received`
|
| 630 |
+
gathered_tokens = token_gather_buf[:received]
|
| 631 |
+
else: # "torch_all_to_all"
|
| 632 |
+
# Prepare input and output splits
|
| 633 |
+
with torch.no_grad():
|
| 634 |
+
output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(
|
| 635 |
+
dim=1
|
| 636 |
+
)
|
| 637 |
+
gathered_tokens = all_to_all_single_autograd(
|
| 638 |
+
sorted_tokens,
|
| 639 |
+
output_splits.tolist(),
|
| 640 |
+
input_splits.tolist(),
|
| 641 |
+
self.ep_group,
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
# This part prepares a 1D tensor with the same length as
|
| 645 |
+
# `gathered_tokens`. The 1D tensor is filled with local expert IDs which
|
| 646 |
+
# the tokens in `gathered_tokens` are headed for. This part doesn't need
|
| 647 |
+
# gradient.
|
| 648 |
+
with torch.no_grad():
|
| 649 |
+
gatherd_idxs = (
|
| 650 |
+
torch.arange(
|
| 651 |
+
tokens_per_expert_group.numel(),
|
| 652 |
+
device=tokens_per_expert_group.device,
|
| 653 |
+
)
|
| 654 |
+
% self.experts_per_rank
|
| 655 |
+
)
|
| 656 |
+
gatherd_idxs = gatherd_idxs.repeat_interleave(tokens_per_expert_group)
|
| 657 |
+
|
| 658 |
+
# Prepare buffer for tokens processed by experts
|
| 659 |
+
if self.shuffle_method == "symm_mem":
|
| 660 |
+
# Take necessary space from `token_gather_buf` symm mem because we are
|
| 661 |
+
# going to send them out after expert processing
|
| 662 |
+
processed_tokens = self.get_gather_buf()[: gathered_tokens.shape[0]]
|
| 663 |
+
else: # "torch_all_to_all"
|
| 664 |
+
processed_tokens = torch.empty_like(gathered_tokens)
|
| 665 |
+
|
| 666 |
+
# This part processes the tokens routed to the local experts.
|
| 667 |
+
# TODO: can we use group GEMM here?
|
| 668 |
+
for i, expert in enumerate(self.experts.values()):
|
| 669 |
+
processed_tokens[gatherd_idxs == i] = expert(
|
| 670 |
+
gathered_tokens[gatherd_idxs == i]
|
| 671 |
+
)
|
| 672 |
+
|
| 673 |
+
# Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
|
| 674 |
+
# The input/output splits are just a reverse of the previous shuffle.
|
| 675 |
+
if self.shuffle_method == "symm_mem":
|
| 676 |
+
token_return_buf, _ = OnDeviceAllToAllV.apply(
|
| 677 |
+
processed_tokens,
|
| 678 |
+
output_splits,
|
| 679 |
+
self.ep_group,
|
| 680 |
+
)
|
| 681 |
+
returned_tokens = token_return_buf[: sorted_tokens_shape[0]]
|
| 682 |
+
else: # "torch_all_to_all"
|
| 683 |
+
returned_tokens = all_to_all_single_autograd(
|
| 684 |
+
processed_tokens,
|
| 685 |
+
input_splits.tolist(),
|
| 686 |
+
output_splits.tolist(),
|
| 687 |
+
self.ep_group,
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
+
output_tokens = torch.empty_like(returned_tokens)
|
| 691 |
+
output_tokens[idxs] = returned_tokens
|
| 692 |
+
final_out = (
|
| 693 |
+
output_tokens.view(*topk_ids.shape, -1)
|
| 694 |
+
.type(topk_weight.dtype)
|
| 695 |
+
.mul_(topk_weight.unsqueeze(dim=-1))
|
| 696 |
+
.sum(dim=1)
|
| 697 |
+
.type(returned_tokens.dtype)
|
| 698 |
+
)
|
| 699 |
+
return final_out
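A toy standalone check (made-up numbers) of the recombination above: per-copy expert outputs are regrouped per token and mixed with the gating weights.

```python
# Toy check: 3 tokens, top_k=2, dim=4, all-ones expert outputs.
import torch

n_tokens, top_k, dim = 3, 2, 4
output_tokens = torch.ones(n_tokens * top_k, dim)          # per-copy expert outputs
topk_weight = torch.tensor([[0.7, 0.3], [0.5, 0.5], [1.0, 0.0]])
final = (output_tokens.view(n_tokens, top_k, dim) * topk_weight.unsqueeze(-1)).sum(dim=1)
print(final)  # every row is all ones because each token's gate weights sum to 1
```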
|
| 700 |
+
|
| 701 |
+
def moe_on_device(self, x, topk_ids, topk_weight):
|
| 702 |
+
# This part sorts the token indices so that tokens routed to the same expert reside consecutively.
|
| 703 |
+
# An implication is that tokens to the same "expert group" (i.e., device) are also consecutive.
|
| 704 |
+
# Since this is an "artificial" index creation (final outcome being
|
| 705 |
+
# `idxs`), we don't need gradients here.
|
| 706 |
+
with torch.no_grad():
|
| 707 |
+
# [seq_len, n_routed_experts]
|
| 708 |
+
cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts))
|
| 709 |
+
# Fill 1 to the selected experts
|
| 710 |
+
cnts.scatter_(1, topk_ids, 1)
|
| 711 |
+
tokens_per_expert = cnts.sum(dim=0)
|
| 712 |
+
# Token indices for each expert
|
| 713 |
+
idxs = topk_ids.view(-1).argsort()
|
| 714 |
+
sorted_tokens_shape = idxs.shape + x.shape[1:]
|
| 715 |
+
|
| 716 |
+
sorted_tokens = x[idxs // topk_ids.shape[1]]
|
| 717 |
+
assert sorted_tokens.shape == sorted_tokens_shape
|
| 718 |
+
|
| 719 |
+
# This part exchanges the information about the number of tokens sent and
|
| 720 |
+
# received by each expert. We can understand this information as "side
|
| 721 |
+
# band", which is not part of the actual data. Thus no gradient is
|
| 722 |
+
# needed.
|
| 723 |
+
with torch.no_grad():
|
| 724 |
+
# Sum the tokens over local experts, then we get tokens per EP rank,
|
| 725 |
+
# which is the input splits
|
| 726 |
+
tokens_per_expert_group = tokens_per_expert.new_empty(
|
| 727 |
+
tokens_per_expert.shape[0]
|
| 728 |
+
)
|
| 729 |
+
dist.all_to_all_single(
|
| 730 |
+
tokens_per_expert_group, tokens_per_expert, group=self.ep_group
|
| 731 |
+
)
|
| 732 |
+
input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
|
| 733 |
+
|
| 734 |
+
# Move input to the `token_send_buf` symm mem
|
| 735 |
+
token_send_buf = self.get_send_buf()
|
| 736 |
+
token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
|
| 737 |
+
# Note: `out=` avoids copy, but it is not differentiable
|
| 738 |
+
# torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
|
| 739 |
+
token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
|
| 740 |
+
token_send_buf,
|
| 741 |
+
input_splits,
|
| 742 |
+
self.ep_group,
|
| 743 |
+
)
|
| 744 |
+
|
| 745 |
+
# We need to permute the received tokens so that tokens for the same expert are contiguous.
|
| 746 |
+
# This part prepares a 1D tensor `permuted_indices` for such permutation.
|
| 747 |
+
# This part doesn't need gradient.
|
| 748 |
+
with torch.no_grad():
|
| 749 |
+
permuted_indices, m_sizes = generate_permute_indices(
|
| 750 |
+
tokens_per_expert_group,
|
| 751 |
+
self.experts_per_rank,
|
| 752 |
+
self.ep_size,
|
| 753 |
+
token_gather_buf.shape[0],
|
| 754 |
+
ALIGN_SIZE_M,
|
| 755 |
+
)
|
| 756 |
+
|
| 757 |
+
# Permute the received tokens so that tokens for the same expert are contiguous.
|
| 758 |
+
contig_tokens = token_gather_buf[permuted_indices]
|
| 759 |
+
|
| 760 |
+
# Run the first grouped GEMM
|
| 761 |
+
w1 = self.get_parameter("gate_proj_weight")
|
| 762 |
+
gate_proj = grouped_gemm_forward(contig_tokens, w1, m_sizes)
|
| 763 |
+
|
| 764 |
+
# Run the second grouped GEMM
|
| 765 |
+
w3 = self.get_parameter("up_proj_weight")
|
| 766 |
+
up_proj = grouped_gemm_forward(contig_tokens, w3, m_sizes)
|
| 767 |
+
|
| 768 |
+
# Apply activation
|
| 769 |
+
hidden_outputs = MLP.act_fn(gate_proj) * up_proj
|
| 770 |
+
|
| 771 |
+
# Run the third grouped GEMM
|
| 772 |
+
w2 = self.get_parameter("down_proj_weight")
|
| 773 |
+
hidden_outputs = grouped_gemm_forward(hidden_outputs, w2, m_sizes)
|
| 774 |
+
|
| 775 |
+
# Prepare buffer for tokens processed by experts
|
| 776 |
+
# Take necessary space from `token_gather_buf` symm mem because we are
|
| 777 |
+
# going to send them out after expert processing
|
| 778 |
+
processed_tokens = self.get_gather_buf()
|
| 779 |
+
|
| 780 |
+
# Move into Symmetric Memory for the return shuffle
|
| 781 |
+
processed_tokens[permuted_indices] = hidden_outputs
|
| 782 |
+
|
| 783 |
+
# Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
|
| 784 |
+
# The input/output splits are just a reverse of the previous shuffle.
|
| 785 |
+
token_return_buf, _ = OnDeviceAllToAllV.apply(
|
| 786 |
+
processed_tokens,
|
| 787 |
+
output_splits,
|
| 788 |
+
self.ep_group,
|
| 789 |
+
)
|
| 790 |
+
returned_tokens = token_return_buf[: sorted_tokens_shape[0]]
|
| 791 |
+
|
| 792 |
+
output_tokens = torch.empty_like(returned_tokens)
|
| 793 |
+
output_tokens[idxs] = returned_tokens
|
| 794 |
+
final_out = (
|
| 795 |
+
output_tokens.view(*topk_ids.shape, -1)
|
| 796 |
+
.type(topk_weight.dtype)
|
| 797 |
+
.mul_(topk_weight.unsqueeze(dim=-1))
|
| 798 |
+
.sum(dim=1)
|
| 799 |
+
.type(returned_tokens.dtype)
|
| 800 |
+
)
|
| 801 |
+
return final_out
|
| 802 |
+
|
| 803 |
+
|
| 804 |
+
class Attention(nn.Module):
|
| 805 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 806 |
+
|
| 807 |
+
def __init__(self, config: ModelArgs, layer_idx: Optional[int] = None):
|
| 808 |
+
super().__init__()
|
| 809 |
+
self.config = config
|
| 810 |
+
self.layer_idx = layer_idx
|
| 811 |
+
self.attention_dropout = config.attention_dropout
|
| 812 |
+
self.hidden_size = config.hidden_size
|
| 813 |
+
self.num_heads = config.num_attention_heads
|
| 814 |
+
|
| 815 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 816 |
+
self.rope_theta = config.rope_theta
|
| 817 |
+
self.q_lora_rank = config.q_lora_rank
|
| 818 |
+
self.qk_rope_head_dim = config.qk_rope_head_dim
|
| 819 |
+
self.kv_lora_rank = config.kv_lora_rank
|
| 820 |
+
self.v_head_dim = config.v_head_dim
|
| 821 |
+
self.qk_nope_head_dim = config.qk_nope_head_dim
|
| 822 |
+
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
|
| 823 |
+
|
| 824 |
+
self.is_causal = True
|
| 825 |
+
|
| 826 |
+
if self.q_lora_rank is None:
|
| 827 |
+
self.q_proj = nn.Linear(
|
| 828 |
+
self.hidden_size, self.num_heads * self.q_head_dim, bias=False
|
| 829 |
+
)
|
| 830 |
+
else:
|
| 831 |
+
self.q_a_proj = nn.Linear(
|
| 832 |
+
self.hidden_size, config.q_lora_rank, bias=config.attention_bias
|
| 833 |
+
)
|
| 834 |
+
self.q_a_layernorm = RMSNorm(config.q_lora_rank)
|
| 835 |
+
self.q_b_proj = nn.Linear(
|
| 836 |
+
config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
|
| 837 |
+
)
|
| 838 |
+
|
| 839 |
+
self.kv_a_proj_with_mqa = nn.Linear(
|
| 840 |
+
self.hidden_size,
|
| 841 |
+
config.kv_lora_rank + config.qk_rope_head_dim,
|
| 842 |
+
bias=config.attention_bias,
|
| 843 |
+
)
|
| 844 |
+
self.kv_a_layernorm = RMSNorm(config.kv_lora_rank)
|
| 845 |
+
self.kv_b_proj = nn.Linear(
|
| 846 |
+
config.kv_lora_rank,
|
| 847 |
+
self.num_heads
|
| 848 |
+
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
|
| 849 |
+
bias=False,
|
| 850 |
+
)
|
| 851 |
+
|
| 852 |
+
self.o_proj = nn.Linear(
|
| 853 |
+
self.num_heads * self.v_head_dim,
|
| 854 |
+
self.hidden_size,
|
| 855 |
+
bias=config.attention_bias,
|
| 856 |
+
)
|
| 857 |
+
self._init_rope()
|
| 858 |
+
|
| 859 |
+
self.softmax_scale = self.q_head_dim ** (-0.5)
|
| 860 |
+
if self.config.rope_scaling is not None:
|
| 861 |
+
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
|
| 862 |
+
scaling_factor = self.config.rope_scaling["factor"]
|
| 863 |
+
if mscale_all_dim:
|
| 864 |
+
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
| 865 |
+
self.softmax_scale = self.softmax_scale * mscale * mscale
|
| 866 |
+
|
| 867 |
+
def _init_rope(self):
|
| 868 |
+
if self.config.rope_scaling is None:
|
| 869 |
+
self.rotary_emb = RotaryEmbedding(
|
| 870 |
+
self.qk_rope_head_dim,
|
| 871 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 872 |
+
base=self.rope_theta,
|
| 873 |
+
)
|
| 874 |
+
else:
|
| 875 |
+
scaling_type = self.config.rope_scaling["type"]
|
| 876 |
+
scaling_factor = self.config.rope_scaling["factor"]
|
| 877 |
+
if scaling_type == "linear":
|
| 878 |
+
self.rotary_emb = LinearScalingRotaryEmbedding(
|
| 879 |
+
self.qk_rope_head_dim,
|
| 880 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 881 |
+
scaling_factor=scaling_factor,
|
| 882 |
+
base=self.rope_theta,
|
| 883 |
+
)
|
| 884 |
+
elif scaling_type == "dynamic":
|
| 885 |
+
self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
|
| 886 |
+
self.qk_rope_head_dim,
|
| 887 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 888 |
+
scaling_factor=scaling_factor,
|
| 889 |
+
base=self.rope_theta,
|
| 890 |
+
)
|
| 891 |
+
elif scaling_type == "yarn":
|
| 892 |
+
kwargs = {
|
| 893 |
+
key: self.config.rope_scaling[key]
|
| 894 |
+
for key in [
|
| 895 |
+
"original_max_position_embeddings",
|
| 896 |
+
"beta_fast",
|
| 897 |
+
"beta_slow",
|
| 898 |
+
"mscale",
|
| 899 |
+
"mscale_all_dim",
|
| 900 |
+
]
|
| 901 |
+
if key in self.config.rope_scaling
|
| 902 |
+
}
|
| 903 |
+
self.rotary_emb = YarnRotaryEmbedding(
|
| 904 |
+
self.qk_rope_head_dim,
|
| 905 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 906 |
+
scaling_factor=scaling_factor,
|
| 907 |
+
base=self.rope_theta,
|
| 908 |
+
**kwargs,
|
| 909 |
+
)
|
| 910 |
+
else:
|
| 911 |
+
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
| 912 |
+
|
| 913 |
+
def forward(
|
| 914 |
+
self,
|
| 915 |
+
hidden_states: torch.Tensor,
|
| 916 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 917 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 918 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 919 |
+
bsz, q_len, _ = hidden_states.size()
|
| 920 |
+
|
| 921 |
+
if self.q_lora_rank is None:
|
| 922 |
+
q = self.q_proj(hidden_states)
|
| 923 |
+
else:
|
| 924 |
+
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
|
| 925 |
+
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
|
| 926 |
+
q_nope, q_pe = torch.split(
|
| 927 |
+
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
| 928 |
+
)
|
| 929 |
+
|
| 930 |
+
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
| 931 |
+
compressed_kv, k_pe = torch.split(
|
| 932 |
+
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
| 933 |
+
)
|
| 934 |
+
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
|
| 935 |
+
kv = (
|
| 936 |
+
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
|
| 937 |
+
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
| 938 |
+
.transpose(1, 2)
|
| 939 |
+
)
|
| 940 |
+
|
| 941 |
+
k_nope, value_states = torch.split(
|
| 942 |
+
kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
|
| 943 |
+
)
|
| 944 |
+
kv_seq_len = value_states.shape[-2]
|
| 945 |
+
|
| 946 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 947 |
+
|
| 948 |
+
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
|
| 949 |
+
|
| 950 |
+
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
|
| 951 |
+
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
|
| 952 |
+
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
|
| 953 |
+
|
| 954 |
+
key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
|
| 955 |
+
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
|
| 956 |
+
key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
|
| 957 |
+
|
| 958 |
+
if attention_mask is not None:
|
| 959 |
+
# Attention mask was made 4D because the `attn_weights` above is 4D.
|
| 960 |
+
# We probably can make this mask smarter if we want to pack sequences
|
| 961 |
+
# together, instead of using padding. This optimization can be used in
|
| 962 |
+
# inference. For training, if we want to pack sequences, data loader
|
| 963 |
+
# will pass in a mask containing such info.
|
| 964 |
+
attention_mask = _prepare_4d_causal_attention_mask(
|
| 965 |
+
attention_mask, # None, or user provided mask in 2D
|
| 966 |
+
(bsz, q_len),
|
| 967 |
+
hidden_states,
|
| 968 |
+
0, # past_key_values_length, 0 when training
|
| 969 |
+
)
|
| 970 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
| 971 |
+
raise ValueError(
|
| 972 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
| 973 |
+
)
|
| 974 |
+
|
| 975 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 976 |
+
query=query_states,
|
| 977 |
+
key=key_states,
|
| 978 |
+
value=value_states,
|
| 979 |
+
attn_mask=attention_mask,
|
| 980 |
+
dropout_p=self.attention_dropout,
|
| 981 |
+
is_causal=attention_mask is None,
|
| 982 |
+
scale=self.softmax_scale,
|
| 983 |
+
)
|
| 984 |
+
|
| 985 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 986 |
+
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
|
| 987 |
+
attn_output = self.o_proj(attn_output)
|
| 988 |
+
|
| 989 |
+
return attn_output
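A standalone shape check (toy dimensions) of how the non-positional and rotary parts are reassembled into full query heads above; keys follow the same pattern.

```python
# Toy shape check: q_head_dim = qk_nope_head_dim + qk_rope_head_dim.
import torch

bsz, heads, q_len = 1, 2, 3
qk_nope_head_dim, qk_rope_head_dim = 4, 2
q_nope = torch.randn(bsz, heads, q_len, qk_nope_head_dim)
q_pe = torch.randn(bsz, heads, q_len, qk_rope_head_dim)
query_states = q_pe.new_empty(bsz, heads, q_len, qk_nope_head_dim + qk_rope_head_dim)
query_states[..., :qk_nope_head_dim] = q_nope
query_states[..., qk_nope_head_dim:] = q_pe
print(query_states.shape)  # torch.Size([1, 2, 3, 6])
```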
|
| 990 |
+
|
| 991 |
+
|
| 992 |
+
class DecoderLayer(nn.Module):
|
| 993 |
+
def __init__(self, config: ModelArgs, layer_idx: int):
|
| 994 |
+
super().__init__()
|
| 995 |
+
self.hidden_size = config.hidden_size
|
| 996 |
+
|
| 997 |
+
self.self_attn = Attention(config=config, layer_idx=layer_idx)
|
| 998 |
+
|
| 999 |
+
self.mlp = (
|
| 1000 |
+
MoE(config)
|
| 1001 |
+
if (
|
| 1002 |
+
config.n_routed_experts is not None
|
| 1003 |
+
and layer_idx >= config.first_k_dense_replace
|
| 1004 |
+
and layer_idx % config.moe_layer_freq == 0
|
| 1005 |
+
)
|
| 1006 |
+
else MLP(config)
|
| 1007 |
+
)
|
| 1008 |
+
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 1009 |
+
self.post_attention_layernorm = RMSNorm(
|
| 1010 |
+
config.hidden_size, eps=config.rms_norm_eps
|
| 1011 |
+
)
|
| 1012 |
+
|
| 1013 |
+
def forward(
|
| 1014 |
+
self,
|
| 1015 |
+
hidden_states: torch.Tensor,
|
| 1016 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1017 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1018 |
+
) -> torch.Tensor:
|
| 1019 |
+
"""
|
| 1020 |
+
Args:
|
| 1021 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 1022 |
+
attention_mask (`torch.FloatTensor`, *optional*):
|
| 1023 |
+
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
| 1024 |
+
query_sequence_length, key_sequence_length)` if default attention is used.
|
| 1025 |
+
"""
|
| 1026 |
+
residual = hidden_states
|
| 1027 |
+
|
| 1028 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 1029 |
+
|
| 1030 |
+
# Self Attention
|
| 1031 |
+
hidden_states = self.self_attn(
|
| 1032 |
+
hidden_states=hidden_states,
|
| 1033 |
+
attention_mask=attention_mask,
|
| 1034 |
+
position_ids=position_ids,
|
| 1035 |
+
)
|
| 1036 |
+
hidden_states = residual + hidden_states
|
| 1037 |
+
|
| 1038 |
+
# Fully Connected
|
| 1039 |
+
residual = hidden_states
|
| 1040 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 1041 |
+
hidden_states = self.mlp(hidden_states)
|
| 1042 |
+
hidden_states = residual + hidden_states
|
| 1043 |
+
|
| 1044 |
+
return hidden_states
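A minimal standalone sketch of the pre-norm residual pattern above, with `nn.LayerNorm` and `nn.Linear` standing in for RMSNorm, attention, and the MLP/MoE:

```python
# Standalone sketch of the pre-norm residual pattern (toy stand-ins).
import torch
import torch.nn as nn

norm1, norm2 = nn.LayerNorm(8), nn.LayerNorm(8)   # stand-ins for RMSNorm
attn, mlp = nn.Linear(8, 8), nn.Linear(8, 8)      # stand-ins for Attention / MoE
x = torch.randn(2, 3, 8)
h = x + attn(norm1(x))   # residual around normed attention
h = h + mlp(norm2(h))    # residual around normed MLP
print(h.shape)           # torch.Size([2, 3, 8])
```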
|
| 1045 |
+
|
| 1046 |
+
|
| 1047 |
+
Deepseek_INPUTS_DOCSTRING = r"""
|
| 1048 |
+
Args:
|
| 1049 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 1050 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 1051 |
+
it.
|
| 1052 |
+
|
| 1053 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 1054 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 1055 |
+
|
| 1056 |
+
[What are input IDs?](../glossary#input-ids)
|
| 1057 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1058 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 1059 |
+
|
| 1060 |
+
- 1 for tokens that are **not masked**,
|
| 1061 |
+
- 0 for tokens that are **masked**.
|
| 1062 |
+
|
| 1063 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 1064 |
+
|
| 1065 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 1066 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 1067 |
+
|
| 1068 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
| 1069 |
+
`past_key_values`).
|
| 1070 |
+
|
| 1071 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
| 1072 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
| 1073 |
+
information on the default strategy.
|
| 1074 |
+
|
| 1075 |
+
- 1 indicates the head is **not masked**,
|
| 1076 |
+
- 0 indicates the head is **masked**.
|
| 1077 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1078 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 1079 |
+
config.n_positions - 1]`.
|
| 1080 |
+
|
| 1081 |
+
[What are position IDs?](../glossary#position-ids)
|
| 1082 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
| 1083 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 1084 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
| 1085 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
| 1086 |
+
|
| 1087 |
+
Two formats are allowed:
|
| 1088 |
+
- a [`~cache_utils.Cache`] instance;
|
| 1089 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
| 1090 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
| 1091 |
+
cache format.
|
| 1092 |
+
|
| 1093 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
| 1094 |
+
legacy cache format will be returned.
|
| 1095 |
+
|
| 1096 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
| 1097 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
| 1098 |
+
of shape `(batch_size, sequence_length)`.
|
| 1099 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 1100 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 1101 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 1102 |
+
model's internal embedding lookup matrix.
|
| 1103 |
+
use_cache (`bool`, *optional*):
|
| 1104 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 1105 |
+
`past_key_values`).
|
| 1106 |
+
output_attentions (`bool`, *optional*):
|
| 1107 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 1108 |
+
tensors for more detail.
|
| 1109 |
+
output_hidden_states (`bool`, *optional*):
|
| 1110 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1111 |
+
more detail.
|
| 1112 |
+
return_dict (`bool`, *optional*):
|
| 1113 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1114 |
+
"""
|
| 1115 |
+
|
| 1116 |
+
|
| 1117 |
+
class DeepseekModel(torch.nn.Module):
|
| 1118 |
+
"""
|
| 1119 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DecoderLayer`]
|
| 1120 |
+
|
| 1121 |
+
Args:
|
| 1122 |
+
config: ModelArgs
|
| 1123 |
+
"""
|
| 1124 |
+
|
| 1125 |
+
def __init__(self, config: ModelArgs):
|
| 1126 |
+
super().__init__()
|
| 1127 |
+
self.config = config
|
| 1128 |
+
self.padding_idx = config.pad_token_id
|
| 1129 |
+
self.vocab_size = config.vocab_size
|
| 1130 |
+
|
| 1131 |
+
# Creating model parts related to my stage
|
| 1132 |
+
assert (
|
| 1133 |
+
config.stage_idx < config.num_stages
|
| 1134 |
+
), f"Stage {config.stage_idx} is not in the model"
|
| 1135 |
+
print(f"Creating model stage {config.stage_idx} of {config.num_stages}")
|
| 1136 |
+
|
| 1137 |
+
self.embed_tokens = (
|
| 1138 |
+
nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 1139 |
+
if config.stage_idx == 0
|
| 1140 |
+
else None
|
| 1141 |
+
)
|
| 1142 |
+
|
| 1143 |
+
self.layers = torch.nn.ModuleDict()
|
| 1144 |
+
division = config.num_hidden_layers // config.num_stages
|
| 1145 |
+
residual = config.num_hidden_layers % config.num_stages
|
| 1146 |
+
# Some earlier stages may have 1 more layer than latter stages because
|
| 1147 |
+
# the division may have residual; this is more even than giving the
|
| 1148 |
+
# entire residual to the last stage.
|
| 1149 |
+
layers_per_stage = [
|
| 1150 |
+
division + 1 if stage < residual else division
|
| 1151 |
+
for stage in range(config.num_stages)
|
| 1152 |
+
]
|
| 1153 |
+
assert sum(layers_per_stage) == config.num_hidden_layers
|
| 1154 |
+
layer_id_start = sum(layers_per_stage[: config.stage_idx])
|
| 1155 |
+
layer_id_end = layer_id_start + layers_per_stage[config.stage_idx]
|
| 1156 |
+
for layer_id in range(layer_id_start, layer_id_end):
|
| 1157 |
+
self.layers[str(layer_id)] = DecoderLayer(config, layer_id)
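Standalone illustration (hypothetical sizes) of the stage partitioning above: earlier stages absorb the remainder, one extra layer each.

```python
# Hypothetical: 13 layers over 4 pipeline stages.
num_hidden_layers, num_stages = 13, 4
division, residual = divmod(num_hidden_layers, num_stages)
layers_per_stage = [division + 1 if s < residual else division for s in range(num_stages)]
starts = [sum(layers_per_stage[:s]) for s in range(num_stages)]
assert sum(layers_per_stage) == num_hidden_layers
print(layers_per_stage, starts)  # [4, 3, 3, 3] [0, 4, 7, 10]
```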
|
| 1158 |
+
|
| 1159 |
+
self.norm = (
|
| 1160 |
+
RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 1161 |
+
if config.stage_idx == config.num_stages - 1
|
| 1162 |
+
else None
|
| 1163 |
+
)
|
| 1164 |
+
|
| 1165 |
+
# Initialize weights and apply final processing
|
| 1166 |
+
self.apply(self._init_weights)
|
| 1167 |
+
|
| 1168 |
+
def _init_weights(self, module):
|
| 1169 |
+
std = self.config.initializer_range
|
| 1170 |
+
if isinstance(module, nn.Linear):
|
| 1171 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 1172 |
+
if module.bias is not None:
|
| 1173 |
+
module.bias.data.zero_()
|
| 1174 |
+
elif isinstance(module, nn.Embedding):
|
| 1175 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 1176 |
+
if module.padding_idx is not None:
|
| 1177 |
+
module.weight.data[module.padding_idx].zero_()
|
| 1178 |
+
|
| 1179 |
+
def forward(
|
| 1180 |
+
self,
|
| 1181 |
+
tokens: torch.Tensor,
|
| 1182 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1183 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1184 |
+
) -> torch.Tensor:
|
| 1185 |
+
# Embedding
|
| 1186 |
+
hidden_states = (
|
| 1187 |
+
self.embed_tokens(tokens) if self.embed_tokens is not None else tokens
|
| 1188 |
+
)
|
| 1189 |
+
|
| 1190 |
+
# decoder layers
|
| 1191 |
+
for decoder_layer in self.layers.values():
|
| 1192 |
+
hidden_states = decoder_layer(
|
| 1193 |
+
hidden_states,
|
| 1194 |
+
attention_mask=attention_mask,
|
| 1195 |
+
position_ids=position_ids,
|
| 1196 |
+
)
|
| 1197 |
+
|
| 1198 |
+
hidden_states = (
|
| 1199 |
+
self.norm(hidden_states) if self.norm is not None else hidden_states
|
| 1200 |
+
)
|
| 1201 |
+
return hidden_states
|
| 1202 |
+
|
| 1203 |
+
|
| 1204 |
+
class DeepseekForCausalLM(torch.nn.Module):
|
| 1205 |
+
def __init__(self, config):
|
| 1206 |
+
super().__init__()
|
| 1207 |
+
self.model = DeepseekModel(config)
|
| 1208 |
+
self.lm_head = (
|
| 1209 |
+
nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1210 |
+
if config.stage_idx == config.num_stages - 1
|
| 1211 |
+
else None
|
| 1212 |
+
)
|
| 1213 |
+
|
| 1214 |
+
# Initialize weights and apply final processing
|
| 1215 |
+
# self.post_init()
|
| 1216 |
+
|
| 1217 |
+
def forward(
|
| 1218 |
+
self,
|
| 1219 |
+
tokens: torch.Tensor,
|
| 1220 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1221 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1222 |
+
) -> Tuple:
|
| 1223 |
+
r"""
|
| 1224 |
+
Example:
|
| 1225 |
+
|
| 1226 |
+
```python
|
| 1227 |
+
>>> from transformers import AutoTokenizer, DeepseekForCausalLM
|
| 1228 |
+
|
| 1229 |
+
>>> model = DeepseekForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
| 1230 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
| 1231 |
+
|
| 1232 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 1233 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 1234 |
+
|
| 1235 |
+
>>> # Generate
|
| 1236 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 1237 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 1238 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 1239 |
+
```"""
|
| 1240 |
+
hidden_states = self.model(
|
| 1241 |
+
tokens,
|
| 1242 |
+
attention_mask=attention_mask,
|
| 1243 |
+
position_ids=position_ids,
|
| 1244 |
+
)
|
| 1245 |
+
|
| 1246 |
+
logits = (
|
| 1247 |
+
self.lm_head(hidden_states) if self.lm_head is not None else hidden_states
|
| 1248 |
+
)
|
| 1249 |
+
return logits
|
| 1250 |
+
|
| 1251 |
+
def prepare_inputs_for_generation(
|
| 1252 |
+
self,
|
| 1253 |
+
input_ids,
|
| 1254 |
+
past_key_values=None,
|
| 1255 |
+
attention_mask=None,
|
| 1256 |
+
**kwargs,
|
| 1257 |
+
):
|
| 1258 |
+
if past_key_values is not None:
|
| 1259 |
+
# Assuming isinstance(past_key_values, Cache):
|
| 1260 |
+
cache_length = past_key_values.get_seq_length()
|
| 1261 |
+
past_length = past_key_values.seen_tokens
|
| 1262 |
+
max_cache_length = past_key_values.get_max_length()
|
| 1263 |
+
|
| 1264 |
+
# Keep only the unprocessed tokens:
|
| 1265 |
+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
| 1266 |
+
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
| 1267 |
+
# input)
|
| 1268 |
+
if (
|
| 1269 |
+
attention_mask is not None
|
| 1270 |
+
and attention_mask.shape[1] > input_ids.shape[1]
|
| 1271 |
+
):
|
| 1272 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
| 1273 |
+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
| 1274 |
+
# input_ids based on the past_length.
|
| 1275 |
+
elif past_length < input_ids.shape[1]:
|
| 1276 |
+
input_ids = input_ids[:, past_length:]
|
| 1277 |
+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
| 1278 |
+
|
| 1279 |
+
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
| 1280 |
+
if (
|
| 1281 |
+
max_cache_length is not None
|
| 1282 |
+
and attention_mask is not None
|
| 1283 |
+
and cache_length + input_ids.shape[1] > max_cache_length
|
| 1284 |
+
):
|
| 1285 |
+
attention_mask = attention_mask[:, -max_cache_length:]
|
| 1286 |
+
|
| 1287 |
+
position_ids = kwargs.get("position_ids", None)
|
| 1288 |
+
if attention_mask is not None and position_ids is None:
|
| 1289 |
+
# create position_ids on the fly for batch generation
|
| 1290 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 1291 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 1292 |
+
if past_key_values:
|
| 1293 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
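A standalone toy check of the position-id trick above for a left-padded batch:

```python
# Standalone toy check for a left-padded sequence.
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)  # tensor([[1, 1, 0, 1, 2]]) -- real tokens get positions 0, 1, 2
```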
|
| 1294 |
+
|
| 1295 |
+
model_inputs = {"input_ids": input_ids}
|
| 1296 |
+
|
| 1297 |
+
model_inputs.update(
|
| 1298 |
+
{
|
| 1299 |
+
"position_ids": position_ids,
|
| 1300 |
+
"past_key_values": past_key_values,
|
| 1301 |
+
"use_cache": kwargs.get("use_cache"),
|
| 1302 |
+
"attention_mask": attention_mask,
|
| 1303 |
+
}
|
| 1304 |
+
)
|
| 1305 |
+
return model_inputs
|
| 1306 |
+
|
| 1307 |
+
@staticmethod
|
| 1308 |
+
def _reorder_cache(past_key_values, beam_idx):
|
| 1309 |
+
reordered_past = ()
|
| 1310 |
+
for layer_past in past_key_values:
|
| 1311 |
+
reordered_past += (
|
| 1312 |
+
tuple(
|
| 1313 |
+
past_state.index_select(0, beam_idx.to(past_state.device))
|
| 1314 |
+
for past_state in layer_past
|
| 1315 |
+
),
|
| 1316 |
+
)
|
| 1317 |
+
return reordered_past
|
| 1318 |
+
|
| 1319 |
+
# Setup Symmetric Memory for MoE token shuffle.
|
| 1320 |
+
# Supports inference currently.
|
| 1321 |
+
def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
|
| 1322 |
+
for layer in self.model.layers.values():
|
| 1323 |
+
if not isinstance(layer.mlp, MoE):
|
| 1324 |
+
continue
|
| 1325 |
+
layer.mlp.setup_symm_mem(dtype, device)
|
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py
ADDED
|
@@ -0,0 +1,159 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import triton
|
| 8 |
+
import triton.language as tl
|
| 9 |
+
|
| 10 |
+
from .triton_utils import get_flat_bid, get_flat_tid
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@triton.jit
|
| 14 |
+
def send_signal(addrs, sem: tl.constexpr):
|
| 15 |
+
if sem == "relaxed":
|
| 16 |
+
tl.inline_asm_elementwise(
|
| 17 |
+
"""
|
| 18 |
+
{
|
| 19 |
+
.reg .u32 %tmp32_<1>;
|
| 20 |
+
.reg .pred %p<1>;
|
| 21 |
+
|
| 22 |
+
send_signal:
|
| 23 |
+
atom.global.relaxed.sys.cas.b32 %tmp32_0, [$1], 0, 1;
|
| 24 |
+
setp.eq.u32 %p0, %tmp32_0, 0;
|
| 25 |
+
@!%p0 bra send_signal;
|
| 26 |
+
}
|
| 27 |
+
""",
|
| 28 |
+
"=r, l",
|
| 29 |
+
[addrs],
|
| 30 |
+
dtype=tl.int32,
|
| 31 |
+
is_pure=False,
|
| 32 |
+
pack=1,
|
| 33 |
+
)
|
| 34 |
+
elif sem == "acq_rel":
|
| 35 |
+
tl.inline_asm_elementwise(
|
| 36 |
+
"""
|
| 37 |
+
{
|
| 38 |
+
.reg .u32 %tmp32_<1>;
|
| 39 |
+
.reg .pred %p<1>;
|
| 40 |
+
|
| 41 |
+
send_signal:
|
| 42 |
+
atom.global.release.sys.cas.b32 %tmp32_0, [$1], 0, 1;
|
| 43 |
+
setp.eq.u32 %p0, %tmp32_0, 0;
|
| 44 |
+
@!%p0 bra send_signal;
|
| 45 |
+
}
|
| 46 |
+
""",
|
| 47 |
+
"=r, l",
|
| 48 |
+
[addrs],
|
| 49 |
+
dtype=tl.int32,
|
| 50 |
+
is_pure=False,
|
| 51 |
+
pack=1,
|
| 52 |
+
)
|
| 53 |
+
else:
|
| 54 |
+
raise RuntimeError(f"Unrecognized sem: {sem}")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@triton.jit
|
| 58 |
+
def wait_signal(addrs, sem: tl.constexpr):
|
| 59 |
+
if sem == "relaxed":
|
| 60 |
+
tl.inline_asm_elementwise(
|
| 61 |
+
"""
|
| 62 |
+
{
|
| 63 |
+
.reg .u32 %tmp32_<1>;
|
| 64 |
+
.reg .pred %p<1>;
|
| 65 |
+
|
| 66 |
+
wait_signal:
|
| 67 |
+
atom.global.sys.relaxed.cas.b32 %tmp32_0, [$1], 1, 0;
|
| 68 |
+
setp.eq.u32 %p0, %tmp32_0, 1;
|
| 69 |
+
@!%p0 bra wait_signal;
|
| 70 |
+
}
|
| 71 |
+
""",
|
| 72 |
+
"=r, l",
|
| 73 |
+
[addrs],
|
| 74 |
+
dtype=tl.int32,
|
| 75 |
+
is_pure=False,
|
| 76 |
+
pack=1,
|
| 77 |
+
)
|
| 78 |
+
elif sem == "acq_rel":
|
| 79 |
+
tl.inline_asm_elementwise(
|
| 80 |
+
"""
|
| 81 |
+
{
|
| 82 |
+
.reg .u32 %tmp32_<1>;
|
| 83 |
+
.reg .pred %p<1>;
|
| 84 |
+
|
| 85 |
+
wait_signal:
|
| 86 |
+
atom.global.sys.acquire.cas.b32 %tmp32_0, [$1], 1, 0;
|
| 87 |
+
setp.eq.u32 %p0, %tmp32_0, 1;
|
| 88 |
+
@!%p0 bra wait_signal;
|
| 89 |
+
}
|
| 90 |
+
""",
|
| 91 |
+
"=r, l",
|
| 92 |
+
[addrs],
|
| 93 |
+
dtype=tl.int32,
|
| 94 |
+
is_pure=False,
|
| 95 |
+
pack=1,
|
| 96 |
+
)
|
| 97 |
+
else:
|
| 98 |
+
raise RuntimeError(f"Unrecognized sem: {sem}")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@triton.jit
|
| 102 |
+
def blockwise_barrier(
|
| 103 |
+
signal_pad_ptrs,
|
| 104 |
+
block_id,
|
| 105 |
+
rank: tl.constexpr,
|
| 106 |
+
world_size: tl.constexpr,
|
| 107 |
+
sem: tl.constexpr,
|
| 108 |
+
):
|
| 109 |
+
"""
|
| 110 |
+
Synchronizes blocks with matching block_id across participating devices.
|
| 111 |
+
|
| 112 |
+
Note: the function itself is not a system level barrier/fence. It is a
|
| 113 |
+
building block for expressing different synchronization patterns.
|
| 114 |
+
|
| 115 |
+
Pattern 0: Ensures that all writes to symm_mem buffers from previous
|
| 116 |
+
kernels across all devices are visible to the current kernel:
|
| 117 |
+
|
| 118 |
+
blockwise_barrier(..., sem="relaxed")
|
| 119 |
+
sync_threads()
|
| 120 |
+
|
| 121 |
+
Pattern 1: Ensures that all writes to symm_mem buffers from the current
|
| 122 |
+
block are visible to all remote blocks with matching blockIdx:
|
| 123 |
+
|
| 124 |
+
sync_threads()
|
| 125 |
+
blockwise_barrier(..., sem="acq_rel")
|
| 126 |
+
sync_threads()
|
| 127 |
+
|
| 128 |
+
Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
|
| 129 |
+
for writing by subsequent kernels across all devices.
|
| 130 |
+
|
| 131 |
+
sync_threads()
|
| 132 |
+
blockwise_barrier(..., sem="relaxed")
|
| 133 |
+
|
| 134 |
+
CUDA graph friendliness:
|
| 135 |
+
|
| 136 |
+
This barrier operates through atomic operations on a zero-filled signal
|
| 137 |
+
pad, which resets to a zero-filled state after each successful
|
| 138 |
+
synchronization. This design eliminates the need for incrementing a
|
| 139 |
+
flag from host.
|
| 140 |
+
"""
|
| 141 |
+
if block_id is None:
|
| 142 |
+
block_id = get_flat_bid()
|
| 143 |
+
flat_tid = get_flat_tid()
|
| 144 |
+
|
| 145 |
+
remote_ranks = tl.arange(0, world_size)
|
| 146 |
+
signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64))
|
| 147 |
+
remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to(
|
| 148 |
+
tl.pointer_type(tl.uint32)
|
| 149 |
+
)
|
| 150 |
+
send_addrs = remote_signal_pad_addrs + block_id * world_size + rank
|
| 151 |
+
|
| 152 |
+
local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to(
|
| 153 |
+
tl.pointer_type(tl.uint32)
|
| 154 |
+
)
|
| 155 |
+
wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks
|
| 156 |
+
|
| 157 |
+
if flat_tid < world_size:
|
| 158 |
+
send_signal(send_addrs, sem)
|
| 159 |
+
wait_signal(wait_addrs, sem)
|
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py
ADDED
|
@@ -0,0 +1,260 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
import torch.distributed._symmetric_memory as symm_mem
|
| 10 |
+
import triton
|
| 11 |
+
import triton.language as tl
|
| 12 |
+
|
| 13 |
+
from .triton_barrier import blockwise_barrier
|
| 14 |
+
from .triton_utils import sync_threads
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@triton.jit
|
| 18 |
+
def _exchange_row_offsets(
|
| 19 |
+
split_sizes_ptrs,
|
| 20 |
+
rank: tl.constexpr,
|
| 21 |
+
world_size: tl.constexpr,
|
| 22 |
+
BLOCKS_PER_REMOTE_RANK: tl.constexpr,
|
| 23 |
+
):
|
| 24 |
+
remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
|
| 25 |
+
|
| 26 |
+
# split_sizes_ptr for all ranks
|
| 27 |
+
# All these vector stacks into split_sizes_matrix
|
| 28 |
+
split_sizes_ptrs = split_sizes_ptrs.to(tl.pointer_type(tl.uint64))
|
| 29 |
+
|
| 30 |
+
# split_sizes_matrix[remote_rank, :]
|
| 31 |
+
input_split_sizes_ptr = tl.load(split_sizes_ptrs + remote_rank).to(
|
| 32 |
+
tl.pointer_type(tl.int64)
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
offsets_ = tl.arange(0, world_size)
|
| 36 |
+
input_split_sizes = tl.load(
|
| 37 |
+
input_split_sizes_ptr + offsets_, mask=offsets_ <= rank, other=0
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
num_rows = tl.load(input_split_sizes_ptr + rank)
|
| 41 |
+
input_row_offset = tl.sum(input_split_sizes) - num_rows
|
| 42 |
+
|
| 43 |
+
# split_sizes_matrix[:, rank]
|
| 44 |
+
output_split_sizes_ptrs = (
|
| 45 |
+
tl.load(split_sizes_ptrs + offsets_).to(tl.pointer_type(tl.int64)) + rank
|
| 46 |
+
)
|
| 47 |
+
output_split_sizes = tl.load(
|
| 48 |
+
output_split_sizes_ptrs, mask=offsets_ <= remote_rank, other=0
|
| 49 |
+
)
|
| 50 |
+
output_row_offset = tl.sum(output_split_sizes) - num_rows
|
| 51 |
+
|
| 52 |
+
return input_row_offset, output_row_offset, num_rows
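Standalone illustration (plain tensors, no Triton) of the offset bookkeeping above: for a given `(rank, remote_rank)` pair, the row offsets are prefix sums over the split-size matrix.

```python
# Plain-tensor illustration of the offset math (world_size=2, no Triton).
import torch

split_sizes = torch.tensor([[3, 1],    # rows: source rank, cols: destination rank
                            [2, 4]])
rank, remote_rank = 0, 1
num_rows = split_sizes[remote_rank, rank]                                  # 2 rows from rank 1 to rank 0
input_row_offset = split_sizes[remote_rank, : rank + 1].sum() - num_rows   # offset in rank 1's input
output_row_offset = split_sizes[: remote_rank + 1, rank].sum() - num_rows  # offset in rank 0's output
print(num_rows.item(), input_row_offset.item(), output_row_offset.item())  # 2 0 3
```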
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@triton.jit
|
| 56 |
+
def on_device_all_to_all_v_kernel(
|
| 57 |
+
output_ptr,
|
| 58 |
+
output_splits_ptr,
|
| 59 |
+
input_ptrs,
|
| 60 |
+
input_splits_ptr,
|
| 61 |
+
signal_pad_ptrs,
|
| 62 |
+
dim: tl.constexpr, # Separate dim for easier vectorization
|
| 63 |
+
rank: tl.constexpr,
|
| 64 |
+
world_size: tl.constexpr,
|
| 65 |
+
BLOCKS_PER_REMOTE_RANK: tl.constexpr,
|
| 66 |
+
UNROLL_FACTOR: tl.constexpr,
|
| 67 |
+
BLOCK_SIZE: tl.constexpr,
|
| 68 |
+
):
|
| 69 |
+
blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
|
| 70 |
+
sync_threads()
|
| 71 |
+
|
| 72 |
+
remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
|
| 73 |
+
block_offset = tl.program_id(0) % BLOCKS_PER_REMOTE_RANK
|
| 74 |
+
|
| 75 |
+
input_row_offset, output_row_offset, num_rows = _exchange_row_offsets(
|
| 76 |
+
input_splits_ptr, rank, world_size, BLOCKS_PER_REMOTE_RANK
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
output_splits_ptr = output_splits_ptr.to(tl.pointer_type(tl.uint64))
|
| 80 |
+
if block_offset == 0:
|
| 81 |
+
# Update output_splits
|
| 82 |
+
tl.store(output_splits_ptr + remote_rank, num_rows)
|
| 83 |
+
|
| 84 |
+
input_ptr = (
|
| 85 |
+
tl.load(input_ptrs.to(tl.pointer_type(tl.uint64)) + remote_rank).to(
|
| 86 |
+
tl.pointer_type(tl.bfloat16)
|
| 87 |
+
)
|
| 88 |
+
+ input_row_offset * dim
|
| 89 |
+
)
|
| 90 |
+
output_ptr = output_ptr + output_row_offset * dim
|
| 91 |
+
|
| 92 |
+
outer_loop_step = BLOCK_SIZE * UNROLL_FACTOR
|
| 93 |
+
outer_loop_iters_per_rank = tl.cdiv(
|
| 94 |
+
tl.cdiv(num_rows * dim, outer_loop_step), BLOCKS_PER_REMOTE_RANK
|
| 95 |
+
)
|
| 96 |
+
numel_per_rank = outer_loop_step * outer_loop_iters_per_rank
|
| 97 |
+
offset = numel_per_rank * block_offset
|
| 98 |
+
end = tl.minimum(numel_per_rank * (block_offset + 1), num_rows * dim)
|
| 99 |
+
|
| 100 |
+
unroll_region_size = (end - offset) // outer_loop_step * outer_loop_step
|
| 101 |
+
for i in tl.range(offset, offset + unroll_region_size, outer_loop_step):
|
| 102 |
+
datas = []
|
| 103 |
+
for j in tl.range(
|
| 104 |
+
i,
|
| 105 |
+
i + outer_loop_step,
|
| 106 |
+
BLOCK_SIZE,
|
| 107 |
+
loop_unroll_factor=UNROLL_FACTOR,
|
| 108 |
+
):
|
| 109 |
+
offsets = j + tl.arange(0, BLOCK_SIZE)
|
| 110 |
+
data = tl.load(input_ptr + offsets)
|
| 111 |
+
tl.store(output_ptr + offsets, data)
|
| 112 |
+
|
| 113 |
+
offset += unroll_region_size
|
| 114 |
+
while offset < end:
|
| 115 |
+
offsets = offset + tl.arange(0, BLOCK_SIZE)
|
| 116 |
+
mask = offsets < num_rows * dim
|
| 117 |
+
data = tl.load(input_ptr + offsets, mask=mask)
|
| 118 |
+
tl.store(output_ptr + offsets, data, mask=mask)
|
| 119 |
+
offset += BLOCK_SIZE
|
| 120 |
+
|
| 121 |
+
sync_threads()
|
| 122 |
+
blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
|
| 123 |
+
return
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _on_device_all_to_all_v(
|
| 127 |
+
output: torch.Tensor,
|
| 128 |
+
output_splits: torch.Tensor,
|
| 129 |
+
input: torch.Tensor,
|
| 130 |
+
input_splits: torch.Tensor,
|
| 131 |
+
group: dist.ProcessGroup = dist.group.WORLD,
|
| 132 |
+
BLOCKS_PER_REMOTE_RANK=8,
|
| 133 |
+
UNROLL_FACTOR: int = 8,
|
| 134 |
+
BLOCK_SIZE: int = 16384,
|
| 135 |
+
):
|
| 136 |
+
assert output.dim() == 2, f"{output.shape}"
|
| 137 |
+
assert input.dim() == 2, f"{input.shape}"
|
| 138 |
+
assert output.shape[1] == input.shape[1]
|
| 139 |
+
|
| 140 |
+
dim = output.shape[1]
|
| 141 |
+
input_hdl = symm_mem.rendezvous(input, group=group)
|
| 142 |
+
input_splits_hdl = symm_mem.rendezvous(input_splits, group=group)
|
| 143 |
+
|
| 144 |
+
num_blocks = input_hdl.world_size * BLOCKS_PER_REMOTE_RANK
|
| 145 |
+
kernel = on_device_all_to_all_v_kernel[(num_blocks, 1, 1)](
|
| 146 |
+
output,
|
| 147 |
+
output_splits,
|
| 148 |
+
input_hdl.buffer_ptrs_dev,
|
| 149 |
+
input_splits_hdl.buffer_ptrs_dev,
|
| 150 |
+
input_hdl.signal_pad_ptrs_dev,
|
| 151 |
+
dim=dim,
|
| 152 |
+
rank=input_hdl.rank,
|
| 153 |
+
world_size=input_hdl.world_size,
|
| 154 |
+
BLOCKS_PER_REMOTE_RANK=BLOCKS_PER_REMOTE_RANK,
|
| 155 |
+
UNROLL_FACTOR=UNROLL_FACTOR,
|
| 156 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
| 157 |
+
num_warps=16,
|
| 158 |
+
)
|
| 159 |
+
# log_triton_kernel(kernel)
|
| 160 |
+
return output
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class OnDeviceAllToAllV(torch.autograd.Function):
|
| 164 |
+
# A symmetric memory holding the grad_output during backward
|
| 165 |
+
grad_output_buf = None
|
| 166 |
+
# A symmetric memory for exchanging split sizes during both forward and backward
|
| 167 |
+
splits_buf = None
|
| 168 |
+
# Maximum output length (need to be set before use of OnDeviceAllToAllV)
|
| 169 |
+
max_output_len = None
|
| 170 |
+
|
| 171 |
+
@staticmethod
|
| 172 |
+
def forward(
|
| 173 |
+
ctx,
|
| 174 |
+
input: torch.Tensor,
|
| 175 |
+
input_splits: torch.Tensor,
|
| 176 |
+
group: dist.ProcessGroup = dist.group.WORLD,
|
| 177 |
+
):
|
| 178 |
+
"""
|
| 179 |
+
Args:
|
| 180 |
+
input: input tensor with data for all ranks concatenated.
|
| 181 |
+
input_splits: input splits of shape (group.world_size,)
|
| 182 |
+
group: process group to scope the collective.
|
| 183 |
+
"""
|
| 184 |
+
# Initialize input splits buffer (one time only)
|
| 185 |
+
if OnDeviceAllToAllV.splits_buf is None:
|
| 186 |
+
OnDeviceAllToAllV.splits_buf = symm_mem.empty(
|
| 187 |
+
*input_splits.shape,
|
| 188 |
+
dtype=input_splits.dtype,
|
| 189 |
+
device=input_splits.device,
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
if OnDeviceAllToAllV.max_output_len is None:
|
| 193 |
+
raise RuntimeError(
|
| 194 |
+
"Please set max output length via `OnDeviceAllToAllV.max_output_len = ...`"
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
# Allocate output buffer
|
| 198 |
+
output = input.new_empty(OnDeviceAllToAllV.max_output_len, *input.shape[1:])
|
| 199 |
+
# Allocate output splits tensor
|
| 200 |
+
output_splits = torch.empty_like(input_splits)
|
| 201 |
+
# Copy input splits to the buffer
|
| 202 |
+
OnDeviceAllToAllV.splits_buf.copy_(input_splits)
|
| 203 |
+
|
| 204 |
+
# Shuffle input to output
|
| 205 |
+
_on_device_all_to_all_v(
|
| 206 |
+
output, output_splits, input, OnDeviceAllToAllV.splits_buf, group=group
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
# Output splits in forward is the input splits in backward
|
| 210 |
+
ctx.save_for_backward(output_splits)
|
| 211 |
+
ctx.group = group
|
| 212 |
+
ctx.input_shape = input.shape
|
| 213 |
+
return output, output_splits
|
| 214 |
+
|
| 215 |
+
@staticmethod
|
| 216 |
+
def backward(ctx, grad_output, grad_splits):
|
| 217 |
+
"""
|
| 218 |
+
Backward is implemented as a shuffle of the output's gradients to the input.
|
| 219 |
+
Args:
|
| 220 |
+
`grad_output`: output's gradients passed from the downstream.
|
| 221 |
+
`grad_splits`: unused.
|
| 222 |
+
"""
|
| 223 |
+
|
| 224 |
+
# Initialize grad_output buffer (one time only)
|
| 225 |
+
if OnDeviceAllToAllV.grad_output_buf is None:
|
| 226 |
+
assert (
|
| 227 |
+
OnDeviceAllToAllV.max_output_len is not None
|
| 228 |
+
), "`max_output_len` not set"
|
| 229 |
+
OnDeviceAllToAllV.grad_output_buf = symm_mem.empty(
|
| 230 |
+
OnDeviceAllToAllV.max_output_len,
|
| 231 |
+
*grad_output.shape[1:],
|
| 232 |
+
dtype=grad_output.dtype,
|
| 233 |
+
device=grad_output.device,
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
# TODO: is there a way to tell autograd to feed grad_output directly to
|
| 237 |
+
# our symm_mem buffer?
|
| 238 |
+
OnDeviceAllToAllV.grad_output_buf.narrow(0, 0, grad_output.shape[0]).copy_(
|
| 239 |
+
grad_output
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
# Size info
|
| 243 |
+
(grad_output_splits,) = ctx.saved_tensors
|
| 244 |
+
OnDeviceAllToAllV.splits_buf.copy_(grad_output_splits)
|
| 245 |
+
grad_input_splits = torch.empty_like(grad_output_splits) # unused
|
| 246 |
+
grad_input = grad_output.new_empty(*ctx.input_shape)
|
| 247 |
+
|
| 248 |
+
# Shuffle gradients back to the input
|
| 249 |
+
_on_device_all_to_all_v(
|
| 250 |
+
grad_input,
|
| 251 |
+
grad_input_splits,
|
| 252 |
+
OnDeviceAllToAllV.grad_output_buf,
|
| 253 |
+
OnDeviceAllToAllV.splits_buf,
|
| 254 |
+
group=ctx.group,
|
| 255 |
+
)
|
| 256 |
+
return grad_input, None, None
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# Alias
|
| 260 |
+
on_device_all_to_all_v = OnDeviceAllToAllV.apply
|
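A minimal usage sketch for `OnDeviceAllToAllV` (not part of the diff): it assumes a recent PyTorch build with `torch.distributed._symmetric_memory`, that the script is launched under `torchrun` with one GPU per rank, and that the names above are importable from this file; tensor sizes are illustrative.

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

# Hypothetical import path, following the file location in this diff.
from torchtitan.experiments.deepseek_v3.symm_mem_recipes.triton_on_device_all_to_all_v import (
    OnDeviceAllToAllV,
    on_device_all_to_all_v,
)

dist.init_process_group("nccl")
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

dim = 16
# Must be set once before the first call; bounds the number of received rows.
OnDeviceAllToAllV.max_output_len = 1024

# Each rank sends 8 rows to every peer; rows for all peers are concatenated.
input_splits = torch.full((world,), 8, dtype=torch.int64, device="cuda")
inp = symm_mem.empty(int(input_splits.sum()), dim, dtype=torch.bfloat16, device="cuda")
inp.normal_()

# Returns the shuffled rows plus the per-peer split sizes actually received.
output, output_splits = on_device_all_to_all_v(inp, input_splits)
```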
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py
ADDED
@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import triton
import triton.language as tl


@triton.jit
def get_tid():
    return tl.inline_asm_elementwise(
        """
        mov.u32 $0, %tid.x;
        mov.u32 $1, %tid.y;
        mov.u32 $2, %tid.z;
        """,
        "=r,=r,=r",
        [],
        dtype=(tl.uint32, tl.uint32, tl.uint32),
        is_pure=True,
        pack=1,
    )


@triton.jit
def get_ntid():
    return tl.inline_asm_elementwise(
        """
        mov.u32 $0, %ntid.x;
        mov.u32 $1, %ntid.y;
        mov.u32 $2, %ntid.z;
        """,
        "=r,=r,=r",
        [],
        dtype=(tl.uint32, tl.uint32, tl.uint32),
        is_pure=True,
        pack=1,
    )


@triton.jit
def get_flat_tid():
    tid_x, tid_y, tid_z = get_tid()
    ntid_x, ntid_y, _ = get_ntid()
    return tid_z * ntid_y * ntid_x + tid_y * ntid_x + tid_x


@triton.jit
def get_flat_bid():
    return (
        tl.program_id(2) * tl.num_programs(1) * tl.num_programs(0)
        + tl.program_id(1) * tl.num_programs(0)
        + tl.program_id(0)
    )


@triton.jit
def sync_threads():
    tl.inline_asm_elementwise(
        "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
    )
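A small sketch (not part of the diff) of how these PTX helpers are used from another `@triton.jit` kernel: `get_flat_bid()` flattens the program (block) index and `sync_threads()` emits a `bar.sync 0` barrier between phases. The import path is assumed from the file location above; a CUDA GPU and a recent Triton are assumed.

```python
import torch
import triton
import triton.language as tl

# Hypothetical import path based on this diff's directory layout.
from torchtitan.experiments.deepseek_v3.symm_mem_recipes.triton_utils import (
    get_flat_bid,
    sync_threads,
)


@triton.jit
def write_block_ids(out_ptr):
    bid = get_flat_bid()          # flattened program (block) index
    sync_threads()                # all threads in the block reach this point
    tl.store(out_ptr + bid, bid)  # one value per block


out = torch.empty(8, dtype=torch.int32, device="cuda")
write_block_ids[(8,)](out)  # launches 8 blocks; out becomes [0, 1, ..., 7]
```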
torchtitan/experiments/deepseek_v3/train.py
ADDED
@@ -0,0 +1,142 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# torchrun --standalone --nproc-per-node 8 run.py
import torch
import torch.distributed as dist
from checkpoint import load_weights_from_hf
from model import DeepseekForCausalLM
from model_config import deepseek_config_registry

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.pipelining import PipelineStage, Schedule1F1B


# Use DeepSeek-V2-Lite as a proxy
model_id = "deepseek-ai/DeepSeek-V2-Lite"


# Run full model
def run_full_model(
    mesh: DeviceMesh,
):
    rank = dist.get_rank()
    device_count = torch.cuda.device_count()
    device = torch.device("cuda", rank % device_count)

    pp_mesh = mesh["pp"]
    ep_mesh = mesh["ep"]
    pp_rank = pp_mesh.get_local_rank()
    ep_rank = ep_mesh.get_local_rank()
    pp_size = pp_mesh.size()
    ep_size = ep_mesh.size()

    # Get model configs
    model_args = deepseek_config_registry[model_id]
    # [Note]: I am making the model smaller for testing / avoiding OOM. If you
    # have sufficient GPUs for model parallelism, you can remove this line.
    model_args.num_hidden_layers = 16

    # Apply model parallelism
    model_args.ep_size = ep_size
    model_args.num_stages = pp_size
    model_args.stage_idx = pp_rank
    print(model_args)

    # Instantiate model
    with device, mesh:
        model = DeepseekForCausalLM(model_args)

    # Load weights
    load_weights_from_hf(model, model_id, device)
    model.train()

    # Apply data parallelism
    fsdp_mesh = mesh["fsdp"]
    hsdp_mesh = mesh["ep", "fsdp"]
    # Using `reshard_after_forward=False` to implement Zero-2, i.e. sharding the
    # optimizer (Zero-1) and gradients (Zero-2), but not the model weights.
    # Reason: the MoE is "sparsely activated" compared to the dense model, thus
    # it would be uneconomical to re-gather the weights.
    for layer in model.model.layers.values():
        # Apply FSDP to experts
        if hasattr(layer.mlp, "experts"):
            for expert in layer.mlp.experts.values():
                fully_shard(expert, mesh=fsdp_mesh, reshard_after_forward=False)
        # Apply HSDP to other parts such as attention, layernorm, because they
        # are doing DDP on EP dimension
        fully_shard(layer, mesh=hsdp_mesh, reshard_after_forward=False)

    # Apply HSDP on root model (lm_head, embeddings, etc)
    fully_shard(model, mesh=hsdp_mesh, reshard_after_forward=False)

    # Synthetic setting
    microbatches = pp_size * 2

    # Use Symmetric Memory for MoE token shuffle.
    # TODO: we are rewriting `moe_on_device` function. `setup_symm_mem` is
    # currently supported for forward only. See `generate.py`.
    # model.setup_symm_mem(torch.bfloat16, device)

    # Example inputs
    torch.manual_seed(ep_rank)
    bs = 4
    seqlen = 128
    x = torch.randint(model_args.vocab_size, (microbatches * bs, seqlen), device=device)
    label = torch.rand(microbatches * bs, seqlen, model_args.vocab_size, device=device)

    # Create loss function
    loss_fn = torch.nn.functional.cross_entropy

    # Run forward and backward
    steps = 2
    for _ in range(steps):
        if pp_size > 1:
            # Create pipeline stage
            stage = PipelineStage(
                model,
                pp_rank,
                pp_size,
                device,
                group=pp_mesh.get_group(),
            )

            # Create pipeline schedule
            losses = []
            pp_schedule = Schedule1F1B(stage, microbatches, loss_fn=loss_fn)

            if pp_rank == 0:
                y = pp_schedule.step(x)
            elif pp_rank == pp_size - 1:
                y = pp_schedule.step(target=label, losses=losses)
                loss = torch.mean(torch.stack(losses))
            else:
                pp_schedule.step()
        else:
            y = model(x)
            loss = loss_fn(y, label)
            loss.backward()

        if pp_rank == pp_size - 1:
            print(f"logits: {y.shape}")
            print(f"{loss=}")

        if pp_rank == 0:
            param = model.get_parameter("model.layers.0.self_attn.q_proj.weight")
            print(f"{torch.linalg.norm(param.grad)=}")

        model.zero_grad()

    print("Backward done")


if __name__ == "__main__":
    mesh = dist.init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "ep", "fsdp"))

    run_full_model(mesh)

    dist.destroy_process_group()
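For reference, a short sketch of the size bookkeeping implied by the hard-coded mesh in `train.py` above; nothing here is new configuration, the numbers follow directly from the script.

```python
# 3-D device mesh dims from init_device_mesh(..., (2, 2, 2), ("pp", "ep", "fsdp"))
pp, ep, fsdp = 2, 2, 2
world_size = pp * ep * fsdp        # 8 ranks -> torchrun --nproc-per-node 8
microbatches = pp * 2              # 4 microbatches for the 1F1B schedule
bs, seqlen = 4, 128
global_batch = microbatches * bs   # x has shape (16, 128) on each EP rank
assert world_size == 8 and global_batch == 16
```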
torchtitan/experiments/flux/README.md
ADDED
@@ -0,0 +1,23 @@
# FLUX model in torchtitan

## Overview

## Usage
First, download the autoencoder model from HuggingFace with your own access token:
```bash
python torchtitan/experiments/flux/scripts/download_autoencoder.py --repo_id black-forest-labs/FLUX.1-dev --ae_path ae.safetensors --hf_token <your_access_token>
```
This step downloads the autoencoder model from HuggingFace and saves it to `torchtitan/experiments/flux/assets/autoencoder/ae.safetensors`.

Run the following command to train the model on a single GPU:
```bash
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --nproc_per_node=1 torchtitan/experiments/flux/train.py --job.config_file torchtitan/experiments/flux/train_configs/debug_model.toml
```

## TODO
- [ ] Support for multiple GPUs is coming soon (FSDP, etc.)
- [ ] Implement test cases in CI for the FLUX model; add more unit tests (e.g., for the preprocessor)
- [ ] More parallelism support (Tensor Parallelism, Context Parallelism, etc.)
- [ ] Support distributed checkpointing and loading
- [ ] Implement an init_weights() function to initialize the model weights
- [ ] Implement the num_flops_per_token calculation in the get_nparams_and_flops() function
torchtitan/experiments/flux/dataset/flux_dataset.py
ADDED
@@ -0,0 +1,267 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import random
from dataclasses import dataclass
from typing import Any, Callable, Optional

import numpy as np

import torch

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from PIL import Image

from torch.distributed.checkpoint.stateful import Stateful

from torch.utils.data import IterableDataset
from torchtitan.components.dataloader import ParallelAwareDataloader

from torchtitan.config_manager import JobConfig
from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
from torchtitan.tools.logging import logger


def _process_cc12m_image(
    img: Image.Image,
    output_size: int = 256,
) -> Optional[torch.Tensor]:
    """Process CC12M image to the desired size."""

    width, height = img.size
    # Skip low resolution images
    if width < output_size or height < output_size:
        return None

    if width >= height:
        # resize height to be equal to output_size, then crop
        new_width, new_height = math.ceil(output_size / height * width), output_size
        img = img.resize((new_width, new_height))
        left = random.randint(0, new_width - output_size)
        resized_img = img.crop((left, 0, left + output_size, output_size))
    else:
        # resize width to be equal to output_size, then crop
        new_width, new_height = (
            output_size,
            math.ceil(output_size / width * height),
        )
        img = img.resize((new_width, new_height))
        lower = random.randint(0, new_width - output_size)
        resized_img = img.crop((0, lower, output_size, lower + output_size))

    assert resized_img.size[0] == resized_img.size[1] == output_size

    # Skip grayscale images
    if resized_img.mode == "L":
        return None

    np_img = np.array(resized_img).transpose((2, 0, 1))
    tensor_img = torch.tensor(np_img).float() / 255.0

    # NOTE: The following commented code is an alternative way
    # img_transform = transforms.Compose(
    #     [
    #         transforms.Resize(max(output_size, output_size)),
    #         transforms.CenterCrop((output_size, output_size)),
    #         transforms.ToTensor(),
    #     ]
    # )
    # tensor_img = img_transform(img)

    return tensor_img


def _flux_data_processor(
    sample: dict[str, Any],
    t5_tokenizer: FluxTokenizer,
    clip_tokenizer: FluxTokenizer,
    output_size: int = 256,
) -> dict[str, Any]:
    """
    Preprocess a CC12M dataset sample's image and text for the Flux model.

    Args:
        sample: A sample from the dataset
        t5_tokenizer: Tokenizer for the T5 encoder
        clip_tokenizer: Tokenizer for the CLIP encoder
        output_size: The output image size
    """
    img = _process_cc12m_image(sample["jpg"], output_size=output_size)
    t5_tokens = t5_tokenizer.encode(sample["txt"])
    clip_tokens = clip_tokenizer.encode(sample["txt"])

    return {
        "image": img,
        "clip_tokens": clip_tokens,  # type: List[int]
        "t5_tokens": t5_tokens,  # type: List[int]
    }


@dataclass
class TextToImageDatasetConfig:
    path: str
    loader: Callable
    data_processor: Callable


DATASETS = {
    "cc12m": TextToImageDatasetConfig(
        path="pixparse/cc12m-wds",
        loader=lambda path: load_dataset(path, split="train", streaming=True),
        data_processor=_flux_data_processor,
    ),
}


def _validate_dataset(
    dataset_name: str, dataset_path: Optional[str] = None
) -> tuple[str, Callable, Callable]:
    """Validate dataset name and path."""
    if dataset_name not in DATASETS:
        raise ValueError(
            f"Dataset {dataset_name} is not supported. "
            f"Supported datasets are: {list(DATASETS.keys())}"
        )

    config = DATASETS[dataset_name]
    path = dataset_path or config.path
    logger.info(f"Preparing {dataset_name} dataset from {path}")
    return path, config.loader, config.data_processor


class FluxDataset(IterableDataset, Stateful):
    """Dataset for the FLUX text-to-image model.

    Args:
        dataset_name (str): Name of the dataset.
        dataset_path (str): Path to the dataset.
        t5_tokenizer (FluxTokenizer): Tokenizer for the T5 encoder.
        clip_tokenizer (FluxTokenizer): Tokenizer for the CLIP encoder.
        dp_rank (int): Data parallel rank.
        dp_world_size (int): Data parallel world size.
        infinite (bool): Whether to loop over the dataset infinitely.
    """

    def __init__(
        self,
        dataset_name: str,
        dataset_path: Optional[str],
        t5_tokenizer: FluxTokenizer,
        clip_tokenizer: FluxTokenizer,
        job_config: Optional[JobConfig] = None,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:

        # Force lowercase for consistent comparison
        dataset_name = dataset_name.lower()

        path, dataset_loader, data_processor = _validate_dataset(
            dataset_name, dataset_path
        )
        ds = dataset_loader(path)

        self.dataset_name = dataset_name
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)

        self._t5_tokenizer = t5_tokenizer
        self._clip_tokenizer = clip_tokenizer
        self._data_processor = data_processor
        self.job_config = job_config

        self.infinite = infinite

        # Variables for checkpointing
        self._sample_idx = 0
        self._all_samples: list[dict[str, Any]] = []

    def _get_data_iter(self):
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        it = iter(self._data)
        for _ in range(self._sample_idx):
            next(it)
        return it

    def __iter__(self):
        while True:
            for sample in self._get_data_iter():
                # Use the dataset-specific preprocessor
                sample_dict = self._data_processor(
                    sample, self._t5_tokenizer, self._clip_tokenizer, output_size=256
                )

                # skip low quality image or image with color channel = 1
                if sample_dict["image"] is None:
                    logger.warning(
                        f"Low quality image {sample['__key__']} is skipped in Flux Dataloader"
                    )
                    continue

                self._all_samples.extend(sample_dict)
                self._sample_idx += 1

                labels = sample_dict.pop("image")
                yield sample_dict, labels

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def load_state_dict(self, state_dict):
        self._sample_idx = state_dict["sample_idx"]
        self._all_samples = state_dict["all_samples"]

    def state_dict(self):
        return {
            "all_samples": self._all_samples,
            "sample_idx": self._sample_idx,
        }


def build_flux_dataloader(
    dp_world_size: int,
    dp_rank: int,
    job_config: JobConfig,
    # This parameter is not used, keep it for compatibility
    tokenizer: FluxTokenizer | None,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data loader for HuggingFace datasets."""
    dataset_name = job_config.training.dataset
    dataset_path = job_config.training.dataset_path
    batch_size = job_config.training.batch_size

    t5_encoder_name = job_config.encoder.t5_encoder
    clip_encoder_name = job_config.encoder.clip_encoder
    max_t5_encoding_len = job_config.encoder.max_t5_encoding_len

    ds = FluxDataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        t5_tokenizer=FluxTokenizer(t5_encoder_name, max_length=max_t5_encoding_len),
        clip_tokenizer=FluxTokenizer(
            clip_encoder_name, max_length=77
        ),  # fix max_length for CLIP
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    return ParallelAwareDataloader(
        dataset=ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )
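A minimal sketch (not part of the diff) of consuming `FluxDataset` directly: each item is a `(features, labels)` pair where `labels` is the preprocessed image tensor. The encoder names and `max_length` values below are illustrative assumptions, not values taken from the training configs.

```python
from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer

ds = FluxDataset(
    dataset_name="cc12m",
    dataset_path=None,  # falls back to "pixparse/cc12m-wds"
    t5_tokenizer=FluxTokenizer("google/t5-v1_1-small", max_length=256),
    clip_tokenizer=FluxTokenizer("openai/clip-vit-large-patch14", max_length=77),
)

features, image = next(iter(ds))
# features["t5_tokens"], features["clip_tokens"] hold the tokenized caption;
# image is a (3, 256, 256) float tensor in [0, 1].
```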
torchtitan/experiments/flux/model/autoencoder.py
ADDED
@@ -0,0 +1,388 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from dataclasses import dataclass

import torch
from einops import rearrange
from safetensors.torch import load_file as load_sft
from torch import nn, Tensor


@dataclass
class AutoEncoderParams:
    resolution: int = 256
    in_channels: int = 3
    ch: int = 128
    out_ch: int = 3
    ch_mult: tuple[int] = (1, 2, 4, 4)
    num_res_blocks: int = 2
    z_channels: int = 16
    scale_factor: float = 0.3611
    shift_factor: float = 0.1159


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = nn.GroupNorm(
            num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
        )
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, x: Tensor):
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x: Tensor):
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(
            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
        )
        self.conv_out = nn.Conv2d(
            block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(
            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
        )
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # get dtype for proper tracing
        upscale_dtype = next(self.up.parameters()).dtype

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # cast to proper dtype
        h = h.to(upscale_dtype)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class DiagonalGaussian(nn.Module):
    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
        else:
            return mean


class AutoEncoder(nn.Module):
    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.params = params
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian()

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        z = self.reg(self.encoder(x))
        z = self.scale_factor * (z - self.shift_factor)
        return z

    def decode(self, z: Tensor) -> Tensor:
        z = z / self.scale_factor + self.shift_factor
        return self.decoder(z)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))


def load_ae(
    ckpt_path: str,
    autoencoder_params: AutoEncoderParams,
    device: str | torch.device = "cuda",
    dtype=torch.bfloat16,
) -> AutoEncoder:
    """
    Load the autoencoder from the given checkpoint path.
    Args:
        ckpt_path (str): Path to the autoencoder checkpoint.
        autoencoder_params (AutoEncoderParams): Autoencoder hyperparameters.
        device (str or torch.device): The device to load the autoencoder to.
    Returns:
        AutoEncoder: The loaded autoencoder.
    """
    # Loading the autoencoder
    print("Init AE")
    with torch.device(device):
        ae = AutoEncoder(autoencoder_params)

    if not os.path.exists(ckpt_path):
        raise ValueError(
            f"Autoencoder path {ckpt_path} does not exist. Please download it first."
        )

    if ckpt_path is not None:
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
        if len(missing) > 0:
            print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
        if len(unexpected) > 0:
            print(
                f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
            )
    return ae.to(dtype=dtype)
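A minimal usage sketch for `load_ae` (not part of the diff), assuming the `ae.safetensors` checkpoint from the README has already been downloaded to the default assets path. With the default `AutoEncoderParams` (three downsampling stages, `z_channels=16`), a 256x256 RGB image maps to a 16-channel 32x32 latent.

```python
import torch

from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams, load_ae

ae = load_ae(
    ckpt_path="torchtitan/experiments/flux/assets/autoencoder/ae.safetensors",
    autoencoder_params=AutoEncoderParams(),
    device="cuda",
    dtype=torch.bfloat16,
)

images = torch.rand(2, 3, 256, 256, device="cuda", dtype=torch.bfloat16)
latents = ae.encode(images)          # (2, 16, 32, 32) sampled latents
reconstruction = ae.decode(latents)  # back to (2, 3, 256, 256)
```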
torchtitan/experiments/flux/model/hf_embedder.py
ADDED
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch import nn, Tensor
from transformers import CLIPTextModel, T5EncoderModel


class FluxEmbedder(nn.Module):
    def __init__(self, version: str, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
                version, **hf_kwargs
            )
        else:
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
                version, **hf_kwargs
            )

        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, batch_tokens: Tensor) -> Tensor:
        """
        batch_tokens: [bsz, embedding_length]

        For the T5 encoder, embedding_length is 768
        For CLIP, embedding_length is 256
        """
        outputs = self.hf_module(
            input_ids=batch_tokens.to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
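A short sketch (not part of the diff; model names are illustrative): `FluxEmbedder` loads a CLIP text model when the version string starts with `"openai"` and a T5 encoder otherwise, returning the pooled output for CLIP and the full last hidden state for T5.

```python
import torch

clip_embedder = FluxEmbedder("openai/clip-vit-large-patch14")
t5_embedder = FluxEmbedder("google/t5-v1_1-small")

clip_tokens = torch.randint(0, 1000, (2, 77))
clip_vec = clip_embedder(clip_tokens)  # (2, hidden_size) pooled CLIP output

t5_tokens = torch.randint(0, 1000, (2, 256))
t5_seq = t5_embedder(t5_tokens)        # (2, 256, hidden_size) T5 last hidden state
```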
torchtitan/experiments/flux/model/layers.py
ADDED
@@ -0,0 +1,286 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# imported from black-forest-labs/FLUX
|
| 8 |
+
import math
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
from torch import nn, Tensor
|
| 14 |
+
|
| 15 |
+
from torchtitan.experiments.flux.model.math import attention, rope
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class EmbedND(nn.Module):
|
| 19 |
+
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.dim = dim
|
| 22 |
+
self.theta = theta
|
| 23 |
+
self.axes_dim = axes_dim
|
| 24 |
+
|
| 25 |
+
def forward(self, ids: Tensor) -> Tensor:
|
| 26 |
+
n_axes = ids.shape[-1]
|
| 27 |
+
emb = torch.cat(
|
| 28 |
+
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
| 29 |
+
dim=-3,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
return emb.unsqueeze(1)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
| 36 |
+
"""
|
| 37 |
+
Create sinusoidal timestep embeddings.
|
| 38 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
| 39 |
+
These may be fractional.
|
| 40 |
+
:param dim: the dimension of the output.
|
| 41 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 42 |
+
:return: an (N, D) Tensor of positional embeddings.
|
| 43 |
+
"""
|
| 44 |
+
t = time_factor * t
|
| 45 |
+
half = dim // 2
|
| 46 |
+
freqs = torch.exp(
|
| 47 |
+
-math.log(max_period)
|
| 48 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
| 49 |
+
/ half
|
| 50 |
+
).to(t.device)
|
| 51 |
+
|
| 52 |
+
args = t[:, None].float() * freqs[None]
|
| 53 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 54 |
+
if dim % 2:
|
| 55 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 56 |
+
if torch.is_floating_point(t):
|
| 57 |
+
embedding = embedding.to(t)
|
| 58 |
+
return embedding
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class MLPEmbedder(nn.Module):
|
| 62 |
+
def __init__(self, in_dim: int, hidden_dim: int):
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
| 65 |
+
self.silu = nn.SiLU()
|
| 66 |
+
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 67 |
+
|
| 68 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 69 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class RMSNorm(torch.nn.Module):
|
| 73 |
+
def __init__(self, dim: int):
|
| 74 |
+
super().__init__()
|
| 75 |
+
self.scale = nn.Parameter(torch.ones(dim))
|
| 76 |
+
|
| 77 |
+
def forward(self, x: Tensor):
|
| 78 |
+
x_dtype = x.dtype
|
| 79 |
+
x = x.float()
|
| 80 |
+
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
| 81 |
+
return (x * rrms).to(dtype=x_dtype) * self.scale
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class QKNorm(torch.nn.Module):
|
| 85 |
+
def __init__(self, dim: int):
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.query_norm = RMSNorm(dim) # TODO(jianiw): switch to pytorch nn.RMSNorm
|
| 88 |
+
self.key_norm = RMSNorm(dim)
|
| 89 |
+
|
| 90 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
| 91 |
+
q = self.query_norm(q)
|
| 92 |
+
k = self.key_norm(k)
|
| 93 |
+
return q.to(v), k.to(v)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class SelfAttention(nn.Module):
|
| 97 |
+
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
| 98 |
+
super().__init__()
|
| 99 |
+
self.num_heads = num_heads
|
| 100 |
+
head_dim = dim // num_heads
|
| 101 |
+
|
| 102 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 103 |
+
self.norm = QKNorm(head_dim)
|
| 104 |
+
self.proj = nn.Linear(dim, dim)
|
| 105 |
+
|
| 106 |
+
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
| 107 |
+
qkv = self.qkv(x)
|
| 108 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
| 109 |
+
q, k = self.norm(q, k, v)
|
| 110 |
+
x = attention(q, k, v, pe=pe)
|
| 111 |
+
x = self.proj(x)
|
| 112 |
+
return x
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@dataclass
|
| 116 |
+
class ModulationOut:
|
| 117 |
+
shift: Tensor
|
| 118 |
+
scale: Tensor
|
| 119 |
+
gate: Tensor
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class Modulation(nn.Module):
|
| 123 |
+
def __init__(self, dim: int, double: bool):
|
| 124 |
+
super().__init__()
|
| 125 |
+
self.is_double = double
|
| 126 |
+
self.multiplier = 6 if double else 3
|
| 127 |
+
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
| 128 |
+
|
| 129 |
+
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
| 130 |
+
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
|
| 131 |
+
self.multiplier, dim=-1
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
return (
|
| 135 |
+
ModulationOut(*out[:3]),
|
| 136 |
+
ModulationOut(*out[3:]) if self.is_double else None,
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class DoubleStreamBlock(nn.Module):
|
| 141 |
+
def __init__(
|
| 142 |
+
self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
|
| 143 |
+
):
|
| 144 |
+
super().__init__()
|
| 145 |
+
|
| 146 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 147 |
+
self.num_heads = num_heads
|
| 148 |
+
self.hidden_size = hidden_size
|
| 149 |
+
self.img_mod = Modulation(hidden_size, double=True)
|
| 150 |
+
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 151 |
+
self.img_attn = SelfAttention(
|
| 152 |
+
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 156 |
+
self.img_mlp = nn.Sequential(
|
| 157 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
| 158 |
+
nn.GELU(approximate="tanh"),
|
| 159 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
self.txt_mod = Modulation(hidden_size, double=True)
|
| 163 |
+
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 164 |
+
self.txt_attn = SelfAttention(
|
| 165 |
+
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 169 |
+
self.txt_mlp = nn.Sequential(
|
| 170 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
| 171 |
+
nn.GELU(approximate="tanh"),
|
| 172 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
def forward(
|
| 176 |
+
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
|
| 177 |
+
) -> tuple[Tensor, Tensor]:
|
| 178 |
+
img_mod1, img_mod2 = self.img_mod(vec)
|
| 179 |
+
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
| 180 |
+
|
| 181 |
+
# prepare image for attention
|
| 182 |
+
img_modulated = self.img_norm1(img)
|
| 183 |
+
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
| 184 |
+
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(
            img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(
            txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp(
            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        )

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp(
            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        )
        return img, txt


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(
            self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
        )

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod.gate * output


class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
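A note on the blocks above: the einops pattern `"B L (K H D) -> K B H L D"` is what turns the fused QKV projection into separate per-head query/key/value tensors. A minimal, self-contained shape check (hypothetical sizes, not part of this diff):

```python
import torch
from einops import rearrange

B, L, H, D = 2, 16, 24, 128          # hypothetical batch, sequence length, heads, head dim
qkv = torch.randn(B, L, 3 * H * D)   # fused QKV output, like self.linear1 / img_attn.qkv above

q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=H)
print(q.shape)                        # torch.Size([2, 24, 16, 128]) for each of q, k, v
```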
torchtitan/experiments/flux/model/model.py
ADDED
@@ -0,0 +1,177 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

import torch

from torch import nn, Tensor
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig

from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
from torchtitan.experiments.flux.model.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)

from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol
from torchtitan.tools.logging import logger


@dataclass
class FluxModelArgs(BaseModelArgs):
    in_channels: int = 64
    out_channels: int = 64
    vec_in_dim: int = 768
    context_in_dim: int = 512
    hidden_size: int = 3072
    mlp_ratio: float = 4.0
    num_heads: int = 24
    depth: int = 19
    depth_single_blocks: int = 38
    axes_dim: tuple = (16, 56, 56)
    theta: int = 10_000
    qkv_bias: bool = True
    guidance_embed: bool = True
    autoencoder_params: AutoEncoderParams = field(default_factory=AutoEncoderParams)

    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        # context_in_dim is the same as the T5 embedding dimension
        self.context_in_dim = job_config.encoder.max_t5_encoding_len

    def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
        # TODO(jianiw): Add the number of flops for the autoencoder
        nparams = sum(p.numel() for p in model.parameters())
        logger.warning("FLUX model hasn't implemented the get_nparams_and_flops() function yet")
        return nparams, 1


class FluxModel(nn.Module, ModelProtocol):
    """
    Transformer model for flow matching on sequences.

    Args:
        model_args: FluxModelArgs.

    Attributes:
        model_args (FluxModelArgs): Model configuration arguments.
    """

    def __init__(self, model_args: FluxModelArgs):
        super().__init__()

        self.model_args = model_args
        self.in_channels = model_args.in_channels
        self.out_channels = model_args.out_channels
        if model_args.hidden_size % model_args.num_heads != 0:
            raise ValueError(
                f"Hidden size {model_args.hidden_size} must be divisible by num_heads {model_args.num_heads}"
            )
        pe_dim = model_args.hidden_size // model_args.num_heads
        if sum(model_args.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {model_args.axes_dim} but expected positional dim {pe_dim}"
            )
        self.hidden_size = model_args.hidden_size
        self.num_heads = model_args.num_heads
        self.pe_embedder = EmbedND(
            dim=pe_dim, theta=model_args.theta, axes_dim=model_args.axes_dim
        )
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(model_args.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
            if model_args.guidance_embed
            else nn.Identity()
        )
        self.txt_in = nn.Linear(model_args.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=model_args.mlp_ratio,
                    qkv_bias=model_args.qkv_bias,
                )
                for _ in range(model_args.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size, self.num_heads, mlp_ratio=model_args.mlp_ratio
                )
                for _ in range(model_args.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def init_weights(self, buffer_device=None):
        # TODO(jianiw): replace placeholder with real weight init
        for param in self.parameters():
            param.data.uniform_(0, 0.1)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.model_args.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

    @classmethod
    def from_model_args(cls, model_args: FluxModelArgs) -> "FluxModel":
        """
        Initialize a Flux model from a FluxModelArgs object.

        Args:
            model_args (FluxModelArgs): Model configuration arguments.

        Returns:
            FluxModel: FluxModel model.

        """
        return cls(model_args)
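For orientation, here is a rough smoke-test sketch of the `FluxModel.forward` signature defined above. The tiny configuration is hypothetical (chosen only so that `sum(axes_dim) == hidden_size // num_heads` and the model fits comfortably on CPU); it is not a configuration used by the repo:

```python
import torch
from torchtitan.experiments.flux.model.model import FluxModel, FluxModelArgs

# Hypothetical miniature config; the real defaults build a far larger model.
args = FluxModelArgs(
    in_channels=16,
    out_channels=16,
    vec_in_dim=32,
    context_in_dim=32,
    hidden_size=128,
    num_heads=2,            # pe_dim = 128 // 2 = 64
    axes_dim=(16, 24, 24),  # must sum to pe_dim
    depth=1,
    depth_single_blocks=1,
)
model = FluxModel(args)
model.init_weights()

B, img_len, txt_len = 2, 64, 8
out = model(
    img=torch.randn(B, img_len, args.in_channels),
    img_ids=torch.zeros(B, img_len, 3),
    txt=torch.randn(B, txt_len, args.context_in_dim),
    txt_ids=torch.zeros(B, txt_len, 3),
    timesteps=torch.rand(B),
    y=torch.randn(B, args.vec_in_dim),
    guidance=torch.full((B,), 3.5),
)
print(out.shape)  # expected (B, img_len, out_channels) == (2, 64, 16)
```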
torchtitan/experiments/flux/parallelize_flux.py
ADDED
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D parallelisms (except pipeline parallelism) and various
# training techniques (e.g. activation checkpointing and compile) to the Flux model.


import torch.nn as nn

from torch.distributed.device_mesh import DeviceMesh

from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims


def parallelize_flux(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    # TODO: Add model parallel strategy here
    return model
torchtitan/experiments/flux/requirements.txt
ADDED
@@ -0,0 +1,2 @@
transformers
einops
torchtitan/experiments/flux/tests/test_flux_dataloader.py
ADDED
@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys

from torchtitan.config_manager import JobConfig
from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
from torchtitan.tools.profiling import (
    maybe_enable_memory_snapshot,
    maybe_enable_profiling,
)


class TestFluxDataLoader:
    def test_flux_dataloader(self):
        dataset_name = "cc12m"
        batch_size = 32
        world_size = 4
        rank = 0

        num_steps = 10

        path = "torchtitan.experiments.flux.flux_argparser"
        sys.argv.append(f"--experimental.custom_args_module={path}")
        config = JobConfig()
        config.maybe_add_custom_args()
        config.parse_args(
            [
                # Profiling options
                # "--profiling.enable_profiling",
                # "--profiling.profile_freq",
                # "5",
                # "--profiling.enable_memory_snapshot",
                # "--profiling.save_memory_snapshot_folder",
                # "memory_snapshot_flux",
                "--training.dataset",
                dataset_name,
                "--training.batch_size",
                str(batch_size),
                "--encoder.t5_encoder",
                "google/t5-v1_1-small",
                "--encoder.clip_encoder",
                "openai/clip-vit-large-patch14",
                "--encoder.max_t5_encoding_len",
                "512",
            ]
        )

        with maybe_enable_profiling(
            config, global_step=0
        ) as torch_profiler, maybe_enable_memory_snapshot(
            config, global_step=0
        ) as memory_profiler:
            dl = self._build_dataloader(
                config,
                world_size,
                rank,
            )
            dl = iter(dl)

            for i in range(0, num_steps):
                input_data, labels = next(dl)
                print(f"Step {i} image size: {labels.shape}")
                if torch_profiler:
                    torch_profiler.step()
                if memory_profiler:
                    memory_profiler.step()

                print(len(input_data["clip_tokens"]))
                for k, v in input_data.items():
                    print(f"Step {i} {k} value: {type(v), v.shape}")

                assert len(input_data) == 2  # (clip_encodings, t5_encodings)
                assert labels.shape == (batch_size, 3, 256, 256)
                # assert input_data["clip_tokens"].shape[0] == batch_size
                # assert input_data["t5_tokens"].shape == (batch_size, 512, 512)

            if torch_profiler:
                torch_profiler.step()
            if memory_profiler:
                memory_profiler.step(exit_ctx=True)

    def test_preprocess(self):
        # TODO
        pass

    def _build_dataloader(
        self,
        job_config,
        world_size,
        rank,
    ):

        return build_flux_dataloader(
            dp_world_size=world_size,
            dp_rank=rank,
            job_config=job_config,
            tokenizer=None,
            infinite=False,
        )
torchtitan/experiments/flux/tests/test_generate_image.py
ADDED
@@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import os
import time
from typing import Callable

import torch
from einops import rearrange

from PIL import ExifTags, Image

from torch import Tensor

from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer

from torchtitan.experiments.flux.model.autoencoder import (
    AutoEncoder,
    AutoEncoderParams,
    load_ae,
)
from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder

from torchtitan.experiments.flux.model.model import FluxModel, FluxModelArgs
from torchtitan.experiments.flux.utils import (
    create_position_encoding_for_latents,
    generate_noise_latent,
    pack_latents,
    preprocess_flux_data,
    unpack_latents,
)


def time_shift(mu: float, sigma: float, t: Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu based on linear estimation between two points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()


class TestGenerateImage:
    def test_generate_image(self):
        """
        Run a forward pass of the Flux model to generate an image.
        """
        name = "flux-dev"
        img_width = 512
        img_height = 512
        seed = None
        prompt = (
            "a photo of a forest with mist swirling around the tree trunks. The word "
            '"FLUX" is painted over it in big, red brush strokes with visible texture'
        )
        device = "cuda"
        num_steps = None
        loop = False
        guidance = 3.5
        output_dir = "output"
        add_sampling_metadata = True

        prompt = prompt.split("|")
        if len(prompt) == 1:
            prompt = prompt[0]
            additional_prompts = None
        else:
            additional_prompts = prompt[1:]
            prompt = prompt[0]

        assert not (
            (additional_prompts is not None) and loop
        ), "Do not provide additional prompts and set loop to True"

        torch_device = torch.device(device)
        if num_steps is None:
            num_steps = 30

        # allow for packing and conversion to latent space
        img_height = 16 * (img_height // 16)
        img_width = 16 * (img_width // 16)

        # init all components
        model = FluxModel(FluxModelArgs()).to(device=torch_device, dtype=torch.bfloat16)

        ae = load_ae(
            ckpt_path="assets/autoencoder/ae.safetensors",
            autoencoder_params=AutoEncoderParams(),
            device=torch_device,
            dtype=torch.bfloat16,
        )
        clip_tokenizer = FluxTokenizer(
            model_path="openai/clip-vit-large-patch14", max_length=77
        )
        t5_tokenizer = FluxTokenizer(model_path="google/t5-v1_1-small", max_length=512)
        clip_encoder = FluxEmbedder(version="openai/clip-vit-large-patch14").to(
            torch_device, dtype=torch.bfloat16
        )
        t5_encoder = FluxEmbedder(version="google/t5-v1_1-small").to(
            torch_device, dtype=torch.bfloat16
        )

        rng = torch.Generator(device="cpu")

        if seed is None:
            seed = rng.seed()
        print(f"Generating with seed {seed}:\n{prompt}")
        t0 = time.perf_counter()
        output_name = os.path.join(output_dir, f"img_{seed}.jpg")

        # Tokenize the prompt, on CPU
        clip_tokens = clip_tokenizer.encode(prompt)
        t5_tokens = t5_tokenizer.encode(prompt)

        batch = preprocess_flux_data(
            device=torch_device,
            dtype=torch.bfloat16,
            autoencoder=None,
            clip_encoder=clip_encoder,
            t5_encoder=t5_encoder,
            batch={
                "clip_tokens": clip_tokens,
                "t5_tokens": t5_tokens,
            },
        )

        img = self._generate_images(
            device=torch_device,
            dtype=torch.bfloat16,
            model=model,
            decoder=ae,
            img_width=img_width,
            img_height=img_height,
            denoising_steps=num_steps,
            seed=seed,
            clip_encodings=batch["clip_encodings"],
            t5_encodings=batch["t5_encodings"],
            guidance=guidance,
        )

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t1 = time.perf_counter()

        print(f"Done in {t1 - t0:.1f}s.")

        self._save_image(name, output_name, img, add_sampling_metadata, prompt)

    def _generate_images(
        self,
        device: torch.device,
        dtype: torch.dtype,
        model: FluxModel,
        decoder: AutoEncoder,
        # image params:
        img_width: int,
        img_height: int,
        # sampling params:
        denoising_steps: int,
        seed: int,
        clip_encodings: torch.Tensor,
        t5_encodings: torch.Tensor,
        guidance: float = 4.0,
    ):

        bsz = clip_encodings.shape[0]
        latents = generate_noise_latent(bsz, img_height, img_width, device, dtype, seed)
        _, latent_channels, latent_height, latent_width = latents.shape

        # create denoising schedule
        timesteps = get_schedule(denoising_steps, latent_channels, shift=True)

        # create positional encodings
        POSITION_DIM = 3  # constant for Flux flow model
        latent_pos_enc = create_position_encoding_for_latents(
            bsz, latent_height, latent_width, POSITION_DIM
        ).to(latents)
        text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM).to(latents)

        # convert img-like latents into sequences of patches
        latents = pack_latents(latents)

        # this is ignored for schnell
        guidance_vec = torch.full((bsz,), guidance, device=device, dtype=dtype)
        for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
            t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device)
            pred = model(
                img=latents,
                img_ids=latent_pos_enc,
                txt=t5_encodings,
                txt_ids=text_pos_enc,
                y=clip_encodings,
                timesteps=t_vec,
                guidance=guidance_vec,
            )

            latents = latents + (t_prev - t_curr) * pred

        # convert sequences of patches into img-like latents
        latents = unpack_latents(latents, latent_height, latent_width)

        img = decoder.decode(latents)
        return img

    def _save_image(
        self,
        name: str,
        output_name: str,
        x: torch.Tensor,
        add_sampling_metadata: bool,
        prompt: str,
    ):
        print(f"Saving {output_name}")
        # bring into PIL format and save
        x = x.clamp(-1, 1)
        x = rearrange(x[0], "c h w -> h w c")

        img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

        exif_data = Image.Exif()
        exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
        exif_data[ExifTags.Base.Make] = "Black Forest Labs"
        exif_data[ExifTags.Base.Model] = name
        if add_sampling_metadata:
            exif_data[ExifTags.Base.ImageDescription] = prompt
        img.save(output_name, exif=exif_data, quality=95, subsampling=0)
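The denoising loop in `_generate_images` is a plain Euler integration over the timestep schedule: each step advances the packed latents by `(t_prev - t_curr) * pred`. A small standalone sketch of the shifted-schedule arithmetic (same formulas as `time_shift` / `get_lin_function` above, with a hypothetical sequence length):

```python
import math
import torch

def shifted_schedule(num_steps: int, image_seq_len: int,
                     base_shift: float = 0.5, max_shift: float = 1.15) -> list[float]:
    # mu is linearly interpolated between (256, base_shift) and (4096, max_shift)
    m = (max_shift - base_shift) / (4096 - 256)
    mu = base_shift + m * (image_seq_len - 256)
    t = torch.linspace(1, 0, num_steps + 1)
    # same warp as time_shift(mu, 1.0, t): pushes mass toward high timesteps
    t = math.exp(mu) / (math.exp(mu) + (1 / t - 1))
    return t.tolist()

# 4 denoising steps -> 5 schedule values, monotonically decreasing from 1.0 to 0.0
print(shifted_schedule(4, image_seq_len=1024))
```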
torchtitan/experiments/flux/train_configs/debug_model.toml
ADDED
@@ -0,0 +1,68 @@

[job]
dump_folder = "./outputs"
description = "Flux debug model"
print_args = false
use_for_integration_test = true

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"
enable_wandb = false

[model]
name = "flux"
flavor = "flux-debug"
norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
# test tokenizer.model, for debug purpose only
# tokenizer_path = "./tests/assets/test_tiktoken.model"
# converters = "float8"


[optimizer]
name = "AdamW"
lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.8  # lr scheduler decay ratio, 80% of the train steps
decay_type = "linear"
lr_min = 0.0

[training]
batch_size = 32
seq_len = 512
max_norm = 1.0  # grad norm clipping
steps = 10
compile = false
dataset = "cc12m"
guidance = 3.5
seed = 0

[encoder]
t5_encoder = "google/t5-v1_1-small"
clip_encoder = "openai/clip-vit-large-patch14"
max_t5_encoding_len = 512
auto_encoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors"  # Autoencoder to use for image

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = 1
fsdp_reshard_after_forward = "default"  # default / never / always
tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1

[experimental]
custom_args_module = "torchtitan.experiments.flux.flux_argparser"
torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py
ADDED
|
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 8 |
+
# All rights reserved.
|
| 9 |
+
#
|
| 10 |
+
# Benchmark comparing reference PyTorch vs optimized M*G group GEMM implementation
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import logging
|
| 14 |
+
import time
|
| 15 |
+
|
| 16 |
+
# from typing import Dict, List, Optional, Tuple
|
| 17 |
+
|
| 18 |
+
import matplotlib.pyplot as plt
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
import triton
|
| 22 |
+
|
| 23 |
+
# import triton.language as tl
|
| 24 |
+
|
| 25 |
+
# Configure logging
|
| 26 |
+
logging.basicConfig(
|
| 27 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# Try to import the optimized implementations
|
| 31 |
+
try:
|
| 32 |
+
from torchao_pr.mg_grouped_gemm import grouped_gemm_forward
|
| 33 |
+
|
| 34 |
+
except ImportError:
|
| 35 |
+
logging.error(
|
| 36 |
+
"Error importing MG grouped GEMM modules. Make sure the implementation files are in the correct path."
|
| 37 |
+
)
|
| 38 |
+
raise
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def compute_reference_forward(x, w, m_sizes):
|
| 42 |
+
"""
|
| 43 |
+
Reference PyTorch implementation of M*G grouped GEMM forward pass.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
x (torch.Tensor): Input tensor of shape (M, K)
|
| 47 |
+
w (torch.Tensor): Weight tensor of shape (N, K)
|
| 48 |
+
m_sizes (torch.Tensor): Group sizes tensor of shape (G)
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
torch.Tensor: Output tensor of shape (M, N)
|
| 52 |
+
"""
|
| 53 |
+
result = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device)
|
| 54 |
+
|
| 55 |
+
m_start = 0
|
| 56 |
+
for g in range(len(m_sizes)):
|
| 57 |
+
m_size = m_sizes[g].item()
|
| 58 |
+
if m_size > 0:
|
| 59 |
+
m_end = m_start + m_size
|
| 60 |
+
|
| 61 |
+
# Extract group input
|
| 62 |
+
x_g = x[m_start:m_end]
|
| 63 |
+
|
| 64 |
+
# Compute group output
|
| 65 |
+
y_g = torch.matmul(x_g, w.T)
|
| 66 |
+
|
| 67 |
+
# Store result
|
| 68 |
+
result[m_start:m_end] = y_g
|
| 69 |
+
|
| 70 |
+
# Update start index
|
| 71 |
+
m_start = m_end
|
| 72 |
+
|
| 73 |
+
return result
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@triton.testing.perf_report(
|
| 77 |
+
triton.testing.Benchmark(
|
| 78 |
+
x_names=["N"], # We'll vary the output dimension
|
| 79 |
+
x_vals=[1024, 2048, 4096, 8192, 16384], # Different output dimensions to test
|
| 80 |
+
# x_vals=[8192, 16384],
|
| 81 |
+
line_arg="provider", # We'll compare different providers
|
| 82 |
+
line_vals=["pytorch_reference", "M*G grouped GEMM"],
|
| 83 |
+
line_names=["PyTorch Reference", "M*G grouped Kernel"],
|
| 84 |
+
styles=[("blue", "-"), ("red", "-")],
|
| 85 |
+
ylabel="TFLOPS", # We'll measure TFLOPS
|
| 86 |
+
plot_name="mg_grouped_gemm_comparison",
|
| 87 |
+
args={
|
| 88 |
+
"M": 8192, # Batch dimension, fixed for all tests
|
| 89 |
+
"K": 7168, # Hidden dimension, fixed for all tests
|
| 90 |
+
"G": 8, # Number of groups
|
| 91 |
+
"dtype": torch.float16,
|
| 92 |
+
"device": "cuda",
|
| 93 |
+
},
|
| 94 |
+
)
|
| 95 |
+
)
|
| 96 |
+
def benchmark_forward(M, K, N, G, provider, dtype=torch.float16, device="cuda"):
|
| 97 |
+
"""
|
| 98 |
+
Benchmark the forward pass of the grouped GEMM implementation.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
M (int): Total batch size dimension
|
| 102 |
+
K (int): Hidden dimension
|
| 103 |
+
N (int): Output dimension
|
| 104 |
+
G (int): Number of groups
|
| 105 |
+
provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
|
| 106 |
+
dtype (torch.dtype): Data type to use
|
| 107 |
+
device (str): Device to use
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
float: Performance in TFLOPS
|
| 111 |
+
"""
|
| 112 |
+
# Create group sizes for M dimension (balanced across groups)
|
| 113 |
+
base_size = M // G
|
| 114 |
+
remainder = M % G
|
| 115 |
+
M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
|
| 116 |
+
m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
|
| 117 |
+
|
| 118 |
+
print(f"N: {N}, M: {M}, K: {K}, G: {G}, dtype: {dtype}, device: {device}")
|
| 119 |
+
|
| 120 |
+
# Create input and weight tensors
|
| 121 |
+
x = torch.randn(M, K, dtype=dtype, device=device)
|
| 122 |
+
w = torch.randn(N, K, dtype=dtype, device=device)
|
| 123 |
+
|
| 124 |
+
# Pre-compute for PyTorch reference to ensure fair comparison
|
| 125 |
+
if provider == "pytorch_reference":
|
| 126 |
+
# Warmup
|
| 127 |
+
torch.cuda.synchronize()
|
| 128 |
+
compute_reference_forward(x, w, m_sizes)
|
| 129 |
+
torch.cuda.synchronize()
|
| 130 |
+
|
| 131 |
+
# Benchmark
|
| 132 |
+
start_time = time.time()
|
| 133 |
+
for _ in range(10): # Average over 10 runs
|
| 134 |
+
compute_reference_forward(x, w, m_sizes)
|
| 135 |
+
torch.cuda.synchronize()
|
| 136 |
+
end_time = time.time()
|
| 137 |
+
else: # Optimized kernel
|
| 138 |
+
# Warmup
|
| 139 |
+
torch.cuda.synchronize()
|
| 140 |
+
grouped_gemm_forward(x, w, m_sizes)
|
| 141 |
+
torch.cuda.synchronize()
|
| 142 |
+
|
| 143 |
+
# Benchmark
|
| 144 |
+
start_time = time.time()
|
| 145 |
+
for _ in range(10): # Average over 10 runs
|
| 146 |
+
grouped_gemm_forward(x, w, m_sizes)
|
| 147 |
+
torch.cuda.synchronize()
|
| 148 |
+
end_time = time.time()
|
| 149 |
+
|
| 150 |
+
# Calculate FLOPs
|
| 151 |
+
# For GEMM: 2 * M * N * K FLOPs (multiply-add counts as 2 FLOPs)
|
| 152 |
+
flops = 2 * M * N * K
|
| 153 |
+
|
| 154 |
+
# Convert to TFLOPS (tera-FLOPS)
|
| 155 |
+
avg_time = (end_time - start_time) / 10 # Average time per run
|
| 156 |
+
tflops = flops / avg_time / 1e12
|
| 157 |
+
|
| 158 |
+
return tflops
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@triton.testing.perf_report(
|
| 162 |
+
triton.testing.Benchmark(
|
| 163 |
+
x_names=["G"], # We'll vary the number of groups
|
| 164 |
+
x_vals=[1, 2, 4, 8, 16], # Different numbers of groups to test
|
| 165 |
+
line_arg="provider", # We'll compare different providers
|
| 166 |
+
line_vals=["pytorch_reference", "optimized_kernel"],
|
| 167 |
+
line_names=["PyTorch Reference", "Optimized Kernel"],
|
| 168 |
+
styles=[("blue", "-"), ("red", "-")],
|
| 169 |
+
ylabel="TFLOPS", # We'll measure TFLOPS
|
| 170 |
+
plot_name="mg_grouped_gemm_group_scaling",
|
| 171 |
+
args={
|
| 172 |
+
"M": 8192, # Batch dimension, fixed for all tests
|
| 173 |
+
"K": 4096, # Hidden dimension, fixed for all tests
|
| 174 |
+
"N": 8192, # Output dimension, fixed for all tests
|
| 175 |
+
"dtype": torch.float16,
|
| 176 |
+
"device": "cuda",
|
| 177 |
+
},
|
| 178 |
+
)
|
| 179 |
+
)
|
| 180 |
+
def benchmark_forward_groups(M, K, N, G, provider, dtype=torch.float16, device="cuda"):
|
| 181 |
+
"""
|
| 182 |
+
Benchmark how performance scales with number of groups.
|
| 183 |
+
|
| 184 |
+
Args:
|
| 185 |
+
M (int): Total batch size dimension
|
| 186 |
+
K (int): Hidden dimension
|
| 187 |
+
N (int): Output dimension
|
| 188 |
+
G (int): Number of groups
|
| 189 |
+
provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
|
| 190 |
+
dtype (torch.dtype): Data type to use
|
| 191 |
+
device (str): Device to use
|
| 192 |
+
|
| 193 |
+
Returns:
|
| 194 |
+
float: Performance in TFLOPS
|
| 195 |
+
"""
|
| 196 |
+
# Create group sizes for M dimension (balanced across groups)
|
| 197 |
+
base_size = M // G
|
| 198 |
+
remainder = M % G
|
| 199 |
+
M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
|
| 200 |
+
m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
|
| 201 |
+
|
| 202 |
+
# Create input and weight tensors
|
| 203 |
+
x = torch.randn(M, K, dtype=dtype, device=device)
|
| 204 |
+
w = torch.randn(N, K, dtype=dtype, device=device)
|
| 205 |
+
|
| 206 |
+
# Benchmark logic - same as previous function
|
| 207 |
+
if provider == "pytorch_reference":
|
| 208 |
+
torch.cuda.synchronize()
|
| 209 |
+
compute_reference_forward(x, w, m_sizes)
|
| 210 |
+
torch.cuda.synchronize()
|
| 211 |
+
|
| 212 |
+
start_time = time.time()
|
| 213 |
+
for _ in range(10):
|
| 214 |
+
compute_reference_forward(x, w, m_sizes)
|
| 215 |
+
torch.cuda.synchronize()
|
| 216 |
+
end_time = time.time()
|
| 217 |
+
else:
|
| 218 |
+
torch.cuda.synchronize()
|
| 219 |
+
grouped_gemm_forward(x, w, m_sizes)
|
| 220 |
+
torch.cuda.synchronize()
|
| 221 |
+
|
| 222 |
+
start_time = time.time()
|
| 223 |
+
for _ in range(10):
|
| 224 |
+
grouped_gemm_forward(x, w, m_sizes)
|
| 225 |
+
torch.cuda.synchronize()
|
| 226 |
+
end_time = time.time()
|
| 227 |
+
|
| 228 |
+
# Calculate FLOPs and TFLOPS
|
| 229 |
+
flops = 2 * M * N * K
|
| 230 |
+
avg_time = (end_time - start_time) / 10
|
| 231 |
+
tflops = flops / avg_time / 1e12
|
| 232 |
+
|
| 233 |
+
return tflops
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
@triton.testing.perf_report(
|
| 237 |
+
triton.testing.Benchmark(
|
| 238 |
+
x_names=["group_balance"], # We'll vary the group balance factor
|
| 239 |
+
x_vals=[
|
| 240 |
+
0.0,
|
| 241 |
+
0.25,
|
| 242 |
+
0.5,
|
| 243 |
+
0.75,
|
| 244 |
+
0.9,
|
| 245 |
+
], # Different imbalance factors (0 = balanced, 1 = max imbalance)
|
| 246 |
+
line_arg="provider", # We'll compare different providers
|
| 247 |
+
line_vals=["pytorch_reference", "optimized_kernel"],
|
| 248 |
+
line_names=["PyTorch Reference", "Optimized Kernel"],
|
| 249 |
+
styles=[("blue", "-"), ("red", "-")],
|
| 250 |
+
ylabel="TFLOPS", # We'll measure TFLOPS
|
| 251 |
+
plot_name="mg_grouped_gemm_imbalance",
|
| 252 |
+
args={
|
| 253 |
+
"M": 8192, # Batch dimension, fixed for all tests
|
| 254 |
+
"K": 4096, # Hidden dimension, fixed for all tests
|
| 255 |
+
"N": 8192, # Output dimension, fixed for all tests
|
| 256 |
+
"G": 4, # Number of groups
|
| 257 |
+
"dtype": torch.float16,
|
| 258 |
+
"device": "cuda",
|
| 259 |
+
},
|
| 260 |
+
)
|
| 261 |
+
)
|
| 262 |
+
def benchmark_imbalance(
|
| 263 |
+
M, K, N, G, group_balance, provider, dtype=torch.float16, device="cuda"
|
| 264 |
+
):
|
| 265 |
+
"""
|
| 266 |
+
Benchmark how performance is affected by imbalanced group sizes.
|
| 267 |
+
|
| 268 |
+
Args:
|
| 269 |
+
M (int): Total batch size dimension
|
| 270 |
+
K (int): Hidden dimension
|
| 271 |
+
N (int): Output dimension
|
| 272 |
+
G (int): Number of groups
|
| 273 |
+
group_balance (float): Balance factor from 0 to 1 (0 = balanced, 1 = max imbalance)
|
| 274 |
+
provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
|
| 275 |
+
dtype (torch.dtype): Data type to use
|
| 276 |
+
device (str): Device to use
|
| 277 |
+
|
| 278 |
+
Returns:
|
| 279 |
+
float: Performance in TFLOPS
|
| 280 |
+
"""
|
| 281 |
+
# Create imbalanced group sizes for M dimension
|
| 282 |
+
if group_balance == 0:
|
| 283 |
+
# Balanced case
|
| 284 |
+
base_size = M // G
|
| 285 |
+
remainder = M % G
|
| 286 |
+
M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
|
| 287 |
+
else:
|
| 288 |
+
# Imbalanced case
|
| 289 |
+
# First group gets more elements, last group gets fewer
|
| 290 |
+
# The imbalance is controlled by the group_balance factor
|
| 291 |
+
remaining = M
|
| 292 |
+
M_sizes = []
|
| 293 |
+
for g in range(G):
|
| 294 |
+
# Interpolate from balanced to imbalanced based on group_balance
|
| 295 |
+
# For balanced (group_balance=0), each group gets M/G
|
| 296 |
+
# For imbalanced (group_balance=1), first group gets much more than last group
|
| 297 |
+
balanced_size = remaining // (G - g)
|
| 298 |
+
|
| 299 |
+
# Adjusting size based on position and imbalance factor
|
| 300 |
+
# First groups get more, last groups get less
|
| 301 |
+
if g < G // 2:
|
| 302 |
+
# First half of groups get more
|
| 303 |
+
adjustment = int(balanced_size * group_balance * (1 - g / (G - 1)))
|
| 304 |
+
size = balanced_size + adjustment
|
| 305 |
+
else:
|
| 306 |
+
# Second half of groups get less
|
| 307 |
+
adjustment = int(balanced_size * group_balance * ((g / (G - 1)) - 0.5))
|
| 308 |
+
size = balanced_size - adjustment
|
| 309 |
+
|
| 310 |
+
# Ensure we don't go below 1 or take more than remaining
|
| 311 |
+
size = max(1, min(size, remaining))
|
| 312 |
+
M_sizes.append(size)
|
| 313 |
+
remaining -= size
|
| 314 |
+
|
| 315 |
+
# Handle any remaining elements
|
| 316 |
+
if remaining > 0:
|
| 317 |
+
M_sizes[-1] += remaining
|
| 318 |
+
|
| 319 |
+
m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
|
| 320 |
+
|
| 321 |
+
# Create input and weight tensors
|
| 322 |
+
x = torch.randn(M, K, dtype=dtype, device=device)
|
| 323 |
+
w = torch.randn(N, K, dtype=dtype, device=device)
|
| 324 |
+
|
| 325 |
+
# Benchmark logic
|
| 326 |
+
if provider == "pytorch_reference":
|
| 327 |
+
torch.cuda.synchronize()
|
| 328 |
+
compute_reference_forward(x, w, m_sizes)
|
| 329 |
+
torch.cuda.synchronize()
|
| 330 |
+
|
| 331 |
+
start_time = time.time()
|
| 332 |
+
for _ in range(10):
|
| 333 |
+
compute_reference_forward(x, w, m_sizes)
|
| 334 |
+
torch.cuda.synchronize()
|
| 335 |
+
end_time = time.time()
|
| 336 |
+
else:
|
| 337 |
+
torch.cuda.synchronize()
|
| 338 |
+
grouped_gemm_forward(x, w, m_sizes)
|
| 339 |
+
torch.cuda.synchronize()
|
| 340 |
+
|
| 341 |
+
start_time = time.time()
|
| 342 |
+
for _ in range(10):
|
| 343 |
+
grouped_gemm_forward(x, w, m_sizes)
|
| 344 |
+
torch.cuda.synchronize()
|
| 345 |
+
end_time = time.time()
|
| 346 |
+
|
| 347 |
+
# Calculate FLOPs and TFLOPS
|
| 348 |
+
flops = 2 * M * N * K
|
| 349 |
+
avg_time = (end_time - start_time) / 10
|
| 350 |
+
tflops = flops / avg_time / 1e12
|
| 351 |
+
|
| 352 |
+
return tflops
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def benchmark_model_configs():
|
| 356 |
+
"""
|
| 357 |
+
Benchmark common model configurations used in DeepSeek-like models.
|
| 358 |
+
"""
|
| 359 |
+
# Model configurations: (M, K, N, G)
|
| 360 |
+
configs = [
|
| 361 |
+
(8192, 7168, 4096, 4), # Config 1
|
| 362 |
+
(8192, 2048, 7168, 4), # Config 2
|
| 363 |
+
(4096, 7168, 4096, 8), # Config 3
|
| 364 |
+
(4096, 2048, 7168, 8), # Config 4
|
| 365 |
+
]
|
| 366 |
+
|
| 367 |
+
results = []
|
| 368 |
+
|
| 369 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 370 |
+
dtype = torch.float16
|
| 371 |
+
|
| 372 |
+
for config_idx, (M, K, N, G) in enumerate(configs):
|
| 373 |
+
logging.info(f"\n===== Benchmarking DeepSeek Config {config_idx + 1} =====")
|
| 374 |
+
logging.info(f"M={M}, K={K}, N={N}, G={G}")
|
| 375 |
+
|
| 376 |
+
# Create group sizes for M dimension
|
| 377 |
+
base_size = M // G
|
| 378 |
+
remainder = M % G
|
| 379 |
+
M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
|
| 380 |
+
m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
|
| 381 |
+
|
| 382 |
+
# Create tensors
|
| 383 |
+
x = torch.randn(M, K, dtype=dtype, device=device)
|
| 384 |
+
w = torch.randn(N, K, dtype=dtype, device=device)
|
| 385 |
+
|
| 386 |
+
# Benchmark PyTorch reference
|
| 387 |
+
torch.cuda.synchronize()
|
| 388 |
+
compute_reference_forward(x, w, m_sizes) # Warmup
|
| 389 |
+
torch.cuda.synchronize()
|
| 390 |
+
|
| 391 |
+
logging.info("Benchmarking PyTorch reference...")
|
| 392 |
+
torch.cuda.reset_peak_memory_stats()
|
| 393 |
+
start_time = time.time()
|
| 394 |
+
for _ in range(10):
|
| 395 |
+
compute_reference_forward(x, w, m_sizes)
|
| 396 |
+
torch.cuda.synchronize()
|
| 397 |
+
end_time = time.time()
|
| 398 |
+
pt_time = (end_time - start_time) / 10
|
| 399 |
+
pt_memory = torch.cuda.max_memory_allocated() / (1024**2) # MB
|
| 400 |
+
|
| 401 |
+
# Benchmark optimized kernel
|
| 402 |
+
torch.cuda.synchronize()
|
| 403 |
+
grouped_gemm_forward(x, w, m_sizes) # Warmup
|
| 404 |
+
torch.cuda.synchronize()
|
| 405 |
+
|
| 406 |
+
logging.info("Benchmarking optimized kernel...")
|
| 407 |
+
torch.cuda.reset_peak_memory_stats()
|
| 408 |
+
start_time = time.time()
|
| 409 |
+
for _ in range(10):
|
| 410 |
+
grouped_gemm_forward(x, w, m_sizes)
|
| 411 |
+
torch.cuda.synchronize()
|
| 412 |
+
end_time = time.time()
|
| 413 |
+
opt_time = (end_time - start_time) / 10
|
| 414 |
+
opt_memory = torch.cuda.max_memory_allocated() / (1024**2) # MB
|
| 415 |
+
|
| 416 |
+
# Calculate FLOPs and speedup
|
| 417 |
+
flops = 2 * M * N * K
|
| 418 |
+
pt_tflops = flops / pt_time / 1e12
|
| 419 |
+
opt_tflops = flops / opt_time / 1e12
|
| 420 |
+
speedup = pt_time / opt_time
|
| 421 |
+
|
| 422 |
+
# Store results
|
| 423 |
+
results.append(
|
| 424 |
+
{
|
| 425 |
+
"config": f"Config {config_idx + 1}",
|
| 426 |
+
"dimensions": f"M={M}, K={K}, N={N}, G={G}",
|
| 427 |
+
"pt_time_ms": pt_time * 1000,
|
| 428 |
+
"opt_time_ms": opt_time * 1000,
|
| 429 |
+
"pt_tflops": pt_tflops,
|
| 430 |
+
"opt_tflops": opt_tflops,
|
| 431 |
+
"speedup": speedup,
|
| 432 |
+
"pt_memory_mb": pt_memory,
|
| 433 |
+
"opt_memory_mb": opt_memory,
|
| 434 |
+
"memory_savings": (
|
| 435 |
+
(pt_memory - opt_memory) / pt_memory * 100 if pt_memory > 0 else 0
|
| 436 |
+
),
|
| 437 |
+
}
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
logging.info(
|
| 441 |
+
f"PyTorch Reference: {pt_time * 1000:.2f} ms, {pt_tflops:.2f} TFLOPS, {pt_memory:.2f} MB"
|
| 442 |
+
)
|
| 443 |
+
logging.info(
|
| 444 |
+
f"Optimized Kernel: {opt_time * 1000:.2f} ms, {opt_tflops:.2f} TFLOPS, {opt_memory:.2f} MB"
|
| 445 |
+
)
|
| 446 |
+
logging.info(
|
| 447 |
+
f"Speedup: {speedup:.2f}x, Memory savings: {results[-1]['memory_savings']:.2f}%"
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
# Print summary table
|
| 451 |
+
logging.info("\n===== Benchmark Results Summary =====")
|
| 452 |
+
logging.info(
|
| 453 |
+
f"{'Config':<10} | {'Time (ms)':<20} | {'TFLOPS':<20} | {'Speedup':<10} | {'Memory (MB)':<20} | {'Memory Saved':<12}"
|
| 454 |
+
)
|
| 455 |
+
logging.info(
|
| 456 |
+
f"{'':<10} | {'PyTorch':<9} {'Kernel':<9} | {'PyTorch':<9} {'Kernel':<9} | {'':<10} | "
|
| 457 |
+
f"{'PyTorch':<9} {'Kernel':<9} | {'':<12}"
|
| 458 |
+
)
|
| 459 |
+
logging.info("-" * 100)
|
| 460 |
+
|
| 461 |
+
for result in results:
|
| 462 |
+
logging.info(
|
| 463 |
+
f"{result['config']:<10} | "
|
| 464 |
+
f"{result['pt_time_ms']:<9.2f} {result['opt_time_ms']:<9.2f} | "
|
| 465 |
+
f"{result['pt_tflops']:<9.2f} {result['opt_tflops']:<9.2f} | "
|
| 466 |
+
f"{result['speedup']:<10.2f} | "
|
| 467 |
+
f"{result['pt_memory_mb']:<9.2f} {result['opt_memory_mb']:<9.2f} | "
|
| 468 |
+
f"{result['memory_savings']:<12.2f}%"
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
return results
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def plot_benchmark_results(results):
|
| 475 |
+
"""
|
| 476 |
+
Plot benchmark results as bar charts.
|
| 477 |
+
"""
|
| 478 |
+
# Extract data
|
| 479 |
+
configs = [r["config"] for r in results]
|
| 480 |
+
pt_tflops = [r["pt_tflops"] for r in results]
|
| 481 |
+
opt_tflops = [r["opt_tflops"] for r in results]
|
| 482 |
+
speedups = [r["speedup"] for r in results]
|
| 483 |
+
|
| 484 |
+
# Create figure with subplots
|
| 485 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
|
| 486 |
+
|
| 487 |
+
# Plot TFLOPS comparison
|
| 488 |
+
x = np.arange(len(configs))
|
| 489 |
+
width = 0.35
|
| 490 |
+
ax1.bar(x - width / 2, pt_tflops, width, label="PyTorch Reference")
|
| 491 |
+
ax1.bar(x + width / 2, opt_tflops, width, label="Optimized Kernel")
|
| 492 |
+
ax1.set_xlabel("Model Configuration")
|
| 493 |
+
ax1.set_ylabel("TFLOPS")
|
| 494 |
+
ax1.set_title("Performance Comparison (Higher is Better)")
|
| 495 |
+
ax1.set_xticks(x)
|
| 496 |
+
ax1.set_xticklabels(configs)
|
| 497 |
+
ax1.legend()
|
| 498 |
+
ax1.grid(axis="y", linestyle="--", alpha=0.7)
|
| 499 |
+
|
| 500 |
+
# Plot speedup
|
| 501 |
+
ax2.bar(x, speedups, width=0.6, color="green")
|
| 502 |
+
ax2.set_xlabel("Model Configuration")
|
| 503 |
+
ax2.set_ylabel("Speedup (x)")
|
| 504 |
+
ax2.set_title("Speedup Factor (Higher is Better)")
|
| 505 |
+
ax2.set_xticks(x)
|
| 506 |
+
ax2.set_xticklabels(configs)
|
| 507 |
+
ax2.grid(axis="y", linestyle="--", alpha=0.7)
|
| 508 |
+
|
| 509 |
+
# Add speedup values on top of bars
|
| 510 |
+
for i, v in enumerate(speedups):
|
| 511 |
+
ax2.text(i, v + 0.1, f"{v:.2f}x", ha="center")
|
| 512 |
+
|
| 513 |
+
plt.tight_layout()
|
| 514 |
+
plt.savefig("mg_grouped_gemm_benchmark_results.png")
|
| 515 |
+
logging.info(
|
| 516 |
+
"Benchmark results plot saved to 'mg_grouped_gemm_benchmark_results.png'"
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def compare_mg_implementations():
|
| 521 |
+
"""
|
| 522 |
+
Combine the M*G and N*G benchmark results for comparison.
|
| 523 |
+
"""
|
| 524 |
+
# Only run this if both NG and MG benchmarks have been run
|
| 525 |
+
try:
|
| 526 |
+
import pandas as pd
|
| 527 |
+
|
| 528 |
+
# Try to load previous benchmark results
|
| 529 |
+
mg_results = pd.read_csv("mg_grouped_gemm_benchmark_results.csv")
|
| 530 |
+
ng_results = pd.read_csv("ng_grouped_gemm_benchmark_results.csv")
|
| 531 |
+
|
| 532 |
+
# Create comparison plot
|
| 533 |
+
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
|
| 534 |
+
|
| 535 |
+
# Plot speedup comparison
|
| 536 |
+
configs = mg_results["config"].unique()
|
| 537 |
+
mg_speedups = mg_results.groupby("config")["speedup"].mean()
|
| 538 |
+
ng_speedups = ng_results.groupby("config")["speedup"].mean()
|
| 539 |
+
|
| 540 |
+
x = np.arange(len(configs))
|
| 541 |
+
width = 0.35
|
| 542 |
+
|
| 543 |
+
axes[0].bar(x - width / 2, mg_speedups, width, label="M*G Grouping")
|
        axes[0].bar(x + width / 2, ng_speedups, width, label="N*G Grouping")
        axes[0].set_xlabel("Model Configuration")
        axes[0].set_ylabel("Speedup (x)")
        axes[0].set_title("Speedup Comparison: M*G vs N*G")
        axes[0].set_xticks(x)
        axes[0].set_xticklabels(configs)
        axes[0].legend()
        axes[0].grid(axis="y", linestyle="--", alpha=0.7)

        # Plot TFLOPS comparison for optimized kernels
        mg_tflops = (
            mg_results[mg_results["implementation"] == "optimized"]
            .groupby("config")["tflops"]
            .mean()
        )
        ng_tflops = (
            ng_results[ng_results["implementation"] == "optimized"]
            .groupby("config")["tflops"]
            .mean()
        )

        axes[1].bar(x - width / 2, mg_tflops, width, label="M*G Grouping")
        axes[1].bar(x + width / 2, ng_tflops, width, label="N*G Grouping")
        axes[1].set_xlabel("Model Configuration")
        axes[1].set_ylabel("TFLOPS")
        axes[1].set_title("Performance Comparison: M*G vs N*G")
        axes[1].set_xticks(x)
        axes[1].set_xticklabels(configs)
        axes[1].legend()
        axes[1].grid(axis="y", linestyle="--", alpha=0.7)

        plt.tight_layout()
        plt.savefig("mg_vs_ng_comparison.png")
        logging.info("Comparison plot saved to 'mg_vs_ng_comparison.png'")

    except Exception as e:
        logging.error(f"Could not create comparison plot: {e}")
        logging.info(
            "Run both M*G and N*G benchmarks first to generate comparison plots"
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark M*G Grouped GEMM implementations"
    )
    parser.add_argument("--run-all", action="store_true", help="Run all benchmarks")
    parser.add_argument(
        "--triton-bench", action="store_true", help="Run Triton performance reports"
    )
    parser.add_argument(
        "--model-configs", action="store_true", help="Benchmark model configurations"
    )
    parser.add_argument(
        "--compare-mg-ng",
        action="store_true",
        help="Compare M*G and N*G implementations",
    )
    args = parser.parse_args()

    # Check if CUDA is available
    if not torch.cuda.is_available():
        logging.error(
            "CUDA is not available. This benchmark requires a CUDA-capable GPU."
        )
        exit(1)

    if args.run_all or args.model_configs:
        # Benchmark model configurations
        logging.info("Running benchmark for model configurations...")
        results = benchmark_model_configs()
        plot_benchmark_results(results)

    if args.run_all or args.triton_bench:
        # Run Triton performance reports
        logging.info("Running Triton performance reports...")
        benchmark_forward.run(save_path="mg_grouped_gemm_benchmark_results")
        benchmark_forward_groups.run(save_path="mg_grouped_gemm_benchmark_results")
        benchmark_imbalance.run(save_path="mg_grouped_gemm_benchmark_results")
        logging.info(
            "Triton performance reports saved to 'mg_grouped_gemm_benchmark_results' directory"
        )

    if args.run_all or args.compare_mg_ng:
        # Compare M*G and N*G implementations
        logging.info("Comparing M*G and N*G implementations...")
        compare_mg_implementations()
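For reference, the two quantities plotted above reduce to simple formulas. The sketch below is illustrative only (not part of benchmark.py); it assumes timings in seconds and the same G/M/N/K problem sizes the benchmark sweeps.

# Minimal sketch of the metrics the comparison plot reports (illustrative, not repo code).

def gemm_tflops(G: int, M: int, N: int, K: int, seconds: float) -> float:
    # 2*M*N*K FLOPs per group, summed over G groups, reported in TFLOP/s.
    flops = 2.0 * G * M * N * K
    return flops / seconds / 1e12


def speedup(baseline_seconds: float, optimized_seconds: float) -> float:
    # >1.0 means the optimized kernel is faster than the baseline.
    return baseline_seconds / optimized_seconds


print(gemm_tflops(4, 2048, 4096, 7168, seconds=0.01))  # hypothetical timing
print(speedup(0.012, 0.008))                            # hypothetical timings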
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .mg_grouped_gemm import grouped_gemm_forward
from .tma_autotuning import ALIGN_SIZE_M

__all__ = [
    "grouped_gemm_forward",
    "ALIGN_SIZE_M",
]
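Downstream code is expected to consume only the two re-exported names. A minimal usage sketch, assuming the intermediate directories are importable as Python packages:

from torchtitan.experiments.kernels.triton_mg_group_gemm.torchao_pr import (
    ALIGN_SIZE_M,
    grouped_gemm_forward,
)

# ALIGN_SIZE_M (128) is the M-tile size the autotune configs are built around.
print(ALIGN_SIZE_M)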
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py
ADDED
|
@@ -0,0 +1,240 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# credit - TMAHelper class, AutoTuning are derived from FBGemm:
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm

# pyre-unsafe
import functools

import os
import sys
from typing import Any, Dict, Optional, Tuple

import torch

import triton
import triton.language as tl
from triton import Config as TConfig

from triton.runtime import driver  # @manual

sys.path.append(os.path.dirname(os.path.abspath(__file__)))


# ===== Supporting utils, CUDA and TMA =====


class CudaUtils:
    @staticmethod
    def is_cuda() -> bool:
        """Check if Triton is running on CUDA backend."""
        return driver.active.get_current_target().backend == "cuda"

    @staticmethod
    def verify_tma() -> bool:
        """Check if TMA is supported on the current device."""
        return (
            CudaUtils.is_cuda()
            and torch.cuda.is_available()
            and torch.cuda.get_device_capability()[0] >= 9
        )

    @staticmethod
    def get_num_sms() -> int:
        """Get the number of streaming multiprocessors on the current device."""
        if not CudaUtils.is_cuda():
            raise RuntimeError("Triton is not running on CUDA backend")
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        return torch.cuda.get_device_properties("cuda").multi_processor_count


class TmaDescriptorHelper:
    """Helper class for managing TMA descriptors in Triton kernels."""

    class KernelParamWrapper:
        """Wrapper to implement the TmaDescKernelParam interface."""

        def __init__(self, desc: torch.Tensor):
            self.desc = desc

        def tma_desc_cpu_ptr(self) -> int:
            """Return the CPU pointer to the TMA descriptor."""
            return self.desc.data_ptr()

    def __init__(self, tma_size: int = 128):
        """Initialize the TMA descriptor helper.

        Args:
            tma_size: Size of the TMA descriptor in bytes
        """
        if not CudaUtils.verify_tma():
            raise RuntimeError(
                "TMA not supported on this device (requires Hopper or newer)"
            )
        if "nv_tma_desc_type" not in dir(tl):
            raise RuntimeError(
                "TMA grid constant descriptors not supported in your Triton version"
            )

        self.tma_size = tma_size
        self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
        self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
        self.descriptors: Dict[str, torch.Tensor] = {}

    def init_tma_descriptor(self, name: str) -> None:
        """Initialize a TMA descriptor with the given name.

        Call this method outside of the lambda function for grid size.
        """
        self.descriptors[name] = torch.empty(
            self.tma_size, device="cpu", dtype=torch.int8
        )

    def fill_1d_tma_descriptor(
        self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
    ) -> None:
        """Fill a 1D TMA descriptor.

        Call this method inside the lambda function for grid size.
        """
        if name not in self.descriptors:
            raise ValueError(f"TMA descriptor '{name}' not initialized")

        desc_x = self.descriptors[name]
        if desc_x.data_ptr() % 64 != 0:
            raise ValueError("TMA descriptor must be 64-byte aligned")
        self.fill_1d_tma_descriptor_inner(
            ptr, dim, block_dim, element_size, desc_x.data_ptr()
        )

    def fill_2d_tma_descriptor(
        self,
        name: str,
        ptr: int,
        dim1: int,
        dim0: int,
        block_dim1: int,
        block_dim0: int,
        element_size: int,
    ) -> None:
        """Fill a 2D TMA descriptor.

        Call this method inside the lambda function for grid size.
        """
        if name not in self.descriptors:
            raise ValueError(f"TMA descriptor '{name}' not initialized")

        desc_x = self.descriptors[name]
        if desc_x.data_ptr() % 64 != 0:
            raise ValueError("TMA descriptor must be 64-byte aligned")
        self.fill_2d_tma_descriptor_inner(
            ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
        )

    def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
        """Get the TMA descriptor kernel parameter for the given name."""
        if name not in self.descriptors or self.descriptors[name] is None:
            raise ValueError(f"TMA descriptor '{name}' not initialized")
        return self.KernelParamWrapper(self.descriptors[name])


# ====== Autotuning utilities ======
ALIGN_SIZE_M = 128

_NV_CONFIGS = [
    triton.Config(
        {
            "BLOCK_SIZE_M": block_size_m,
            "BLOCK_SIZE_N": block_size_n,
            "BLOCK_SIZE_K": block_size_k,
        },
        num_stages=num_stages,
        num_warps=num_warps,
        num_ctas=num_ctas,
    )
    for block_size_m in [ALIGN_SIZE_M]
    for block_size_n in [64, 128, 256]
    for block_size_k in [64, 128, 256]
    for num_stages in [3, 4]
    for num_warps in [4, 8]
    for num_ctas in [1]
]


def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
    device = torch.cuda.current_device()
    # Check for all possible pointer parameter names
    if "grad_input_ptr" in named_args:
        ptr_name = "grad_input_ptr"
    elif "c_ptr" in named_args:
        ptr_name = "c_ptr"
    elif "grad_weight_ptr" in named_args:
        ptr_name = "grad_weight_ptr"
    else:
        raise KeyError("No recognized pointer parameter found in kernel arguments")

    if dtsize is None:
        dtsize = named_args[ptr_name].element_size()
    if dtype is None:
        dtype = named_args[ptr_name].dtype

    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
            kw["BLOCK_SIZE_M"],
            kw["BLOCK_SIZE_N"],
            kw["BLOCK_SIZE_K"],
            config.num_stages,
        )
        G, M, N, K = (
            named_args["G"],
            named_args["M_BUCKET"],
            named_args["N"],
            named_args["K"],
        )

        # 1. make sure we have enough smem
        max_shared_memory = driver.active.utils.get_device_properties(device)[
            "max_shared_mem"
        ]

        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory > max_shared_memory:
            continue

        M_PER_GROUP = M // G
        MIN_M_TILES = 64
        # 2. make sure we don't load M tiles that are too big
        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
            continue
        # 3. make sure we don't load N tiles that are too small
        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
            continue

        num_sm = driver.active.utils.get_device_properties(device)[
            "multiprocessor_count"
        ]
        N_TILES = N // BLOCK_N
        MIN_N_TILES = 64
        # 4. make sure we don't load N tiles that are too big
        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
            continue
        # 5. make sure we don't load N tiles that are too small
        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
            continue
        # 6. make sure K can be evenly divided
        if K % BLOCK_K != 0:
            continue

        pruned_configs.append(config)

    return pruned_configs


# ======== End Autotuning utilities ========
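The docstrings above prescribe a specific call order for the TMA helper: allocate each descriptor once, then fill it inside the grid lambda so the launch captures the current pointer and block shape. The following is a hypothetical launch wrapper sketching that pattern; the kernel, shapes, and block sizes are placeholders, not code from this repo.

# Hypothetical usage pattern for TmaDescriptorHelper (sketch only, Hopper+ required).
import torch
import triton

from tma_autotuning import TmaDescriptorHelper


def launch_with_tma(kernel, a: torch.Tensor, BLOCK_M: int = 128, BLOCK_K: int = 64):
    desc_helper = TmaDescriptorHelper(tma_size=128)
    desc_helper.init_tma_descriptor("a")  # outside the grid lambda

    def grid(META):
        # inside the grid lambda: fill with the current pointer and block shape
        desc_helper.fill_2d_tma_descriptor(
            "a",
            a.data_ptr(),
            a.shape[0],
            a.shape[1],
            BLOCK_M,
            BLOCK_K,
            a.element_size(),
        )
        return (triton.cdiv(a.shape[0], BLOCK_M),)

    a_desc = desc_helper.get_tma_descriptor_kernel_param("a")
    # `kernel` stands in for a Triton kernel that accepts a TMA descriptor argument.
    kernel[grid](a_desc, a.shape[0], a.shape[1])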
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py
ADDED
|
@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import logging
import unittest
from typing import Tuple

import torch
import torch.nn as nn

from mg_grouped_gemm import grouped_gemm_forward


class TestMG_GroupedGEMM(unittest.TestCase):
    def setUp(self) -> None:
        torch.manual_seed(2020)

    def _run_grouped_gemm_test(
        self,
        shape: Tuple[int, int, int, int],
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        atol: float = 1e-5,
        rtol: float = 1.6e-2,
    ) -> None:
        G, M, N, K = shape
        # In M*G grouping, input is [M*G, K] and weights are [N*G, K]
        a = torch.randn(M * G, K, dtype=dtype, device=device)
        b = torch.randn(N * G, K, dtype=dtype, device=device)

        # Create equal-sized groups for simplicity
        m_size = M
        m_sizes = torch.full((G,), m_size, device=device, dtype=torch.int32)

        result = grouped_gemm_forward(a, b, m_sizes)
        self.assertTrue(result.shape == (M * G, N))

        expected_result = torch.zeros(M * G, N, dtype=dtype, device=device)
        m_start = 0
        for g in range(G):
            m_end = m_start + m_sizes[g]
            b_slice = b[N * g : N * (g + 1), :]
            expected_result[m_start:m_end, :] = a[m_start:m_end, :] @ b_slice.T
            m_start = m_end

        # Convert result to match input dtype if needed
        result = result.to(dtype)
        torch.testing.assert_close(result, expected_result, atol=atol, rtol=rtol)

    def test_MG_grouped_gemm_bf16(self) -> None:
        for G in (1, 4, 16):
            for M in (128, 512, 1024):
                print(f"Testing BF16 M*G GroupGeMM with G={G}, M={M}")
                self._run_grouped_gemm_test(
                    (G, M, 1024, 1024),
                    torch.device("cuda"),
                    dtype=torch.bfloat16,
                    atol=1e-5,
                    rtol=1.6e-2,
                )

    def test_MG_grouped_gemm_deepseek_shapes(self) -> None:
        """Test with shapes from Deepseek model."""
        deepseek_shapes = [
            (4, 2048, 4096, 7168),  # G, M, N, K
            (4, 2048, 7168, 2048),
            (8, 512, 4096, 7168),
            (8, 512, 7168, 2048),
        ]

        device = torch.device("cuda")

        for shape in deepseek_shapes:
            G, M, N, K = shape
            print(f"Testing BF16 M*G Deepseek shape: G={G}, M={M}, N={N}, K={K}")
            self._run_grouped_gemm_test(
                shape, device, dtype=torch.bfloat16, atol=1e-5, rtol=1.6e-2
            )
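The expected-result loop in the test above doubles as the reference semantics of M*G grouping. Pulled out as a standalone pure-PyTorch sketch (runs on CPU, independent of the Triton kernel; names here are illustrative, not repo APIs):

import torch


def grouped_gemm_reference(
    a: torch.Tensor,        # [sum(m_sizes), K]  -- inputs stacked group by group (M*G)
    b: torch.Tensor,        # [N * G, K]         -- per-group weights stacked on dim 0
    m_sizes: torch.Tensor,  # [G] rows of `a` owned by each group
) -> torch.Tensor:
    G = m_sizes.numel()
    N = b.shape[0] // G
    out = torch.zeros(a.shape[0], N, dtype=a.dtype, device=a.device)
    m_start = 0
    for g in range(G):
        m_end = m_start + int(m_sizes[g])
        # each group multiplies its rows of `a` by its own [N, K] weight slice
        out[m_start:m_end] = a[m_start:m_end] @ b[N * g : N * (g + 1)].T
        m_start = m_end
    return out


# tiny CPU smoke check of the reference itself
a = torch.randn(6, 8)
b = torch.randn(4 * 2, 8)  # G=2 groups, N=4
m_sizes = torch.tensor([2, 4], dtype=torch.int32)
print(grouped_gemm_reference(a, b, m_sizes).shape)  # torch.Size([6, 4])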
torchtitan/experiments/llama4/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (1.96 kB)
torchtitan/experiments/llama4/infra/expert_parallel.py
ADDED
|
@@ -0,0 +1,145 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from functools import partial
from typing import Optional, Tuple

import torch.nn as nn
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    DTensor,
    Partial,
    Replicate,
    Shard,
)
from torch.distributed.tensor.parallel import ParallelStyle
from torch.distributed.tensor.placement_types import Placement


# implementation of Tensor Parallel on the non-shared experts in MoE
class TensorParallel(ParallelStyle):
    def __init__(
        self,
        *,
        input_layouts: Optional[Tuple[Optional[Placement]]] = None,
        output_layout: Optional[Placement] = None,
        use_local_output: bool = True,
    ):
        super().__init__()
        self.input_layouts = input_layouts or (Replicate(), None)
        self.output_layout = output_layout or Partial()
        self.desired_input_layouts = (Replicate(), None)
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(
        input_layouts, desired_input_layouts, mod, inputs, device_mesh
    ):
        # TODO: figure out dynamo support for instance method and switch this to instance method

        # annotate module input placements/sharding with input_layouts
        input_tensor, input_layout, desired_input_layout = (
            inputs[0],
            input_layouts[0],
            desired_input_layouts[0],
        )
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(
                input_tensor, device_mesh, (input_layout,), run_check=False
            )

        if input_layouts != desired_input_layouts:
            input_tensor = input_tensor.redistribute(
                placements=(desired_input_layout,), async_op=True
            )
        return (input_tensor, *inputs[1:])

    def _partition_fn(self, name, module, device_mesh):
        module.register_parameter(
            "w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(2)]))
        )  # Column-wise sharding
        module.register_parameter(
            "w2",
            nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(1)])),
        )  # Row-wise sharding
        module.register_parameter(
            "w3",
            nn.Parameter(distribute_tensor(module.w3, device_mesh, [Shard(2)])),
        )  # Column-wise sharding

    @staticmethod
    def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
        if outputs.placements != (output_layout,):
            outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
        # back to local tensor
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            self._partition_fn,
            partial(
                self._prepare_input_fn, self.input_layouts, self.desired_input_layouts
            ),
            partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
        )


# NOTE: This is to achieve replicate computation on the gate module in the MoE router.
# It does nothing other than (1) setting the module parameters as DTensors on the given mesh
# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back.
# TODO: The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh,
# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation.
class NoParallel(ParallelStyle):
    def __init__(
        self,
        *,
        input_layout: Optional[Placement] = None,
        output_layout: Optional[Placement] = None,
        use_local_output: bool = True,
    ):
        super().__init__()
        self.input_layout = input_layout or Replicate()
        self.output_layout = output_layout or Replicate()
        self.desired_input_layout = Replicate()
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh):
        # annotate module input placements/sharding with input_layouts
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(
                input_tensor, device_mesh, (input_layout,), run_check=False
            )

        if input_layout != desired_input_layout:
            input_tensor = input_tensor.redistribute(
                placements=(desired_input_layout,), async_op=True
            )
        return (input_tensor, *inputs[1:])

    @staticmethod
    def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
        if outputs.placements != (output_layout,):
            outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
        # back to local tensor
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            None,
            partial(
                self._prepare_input_fn, self.input_layout, self.desired_input_layout
            ),
            partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
        )
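The Shard(2) / Shard(1) placements in `_partition_fn` correspond to column-wise sharding of w1/w3 and row-wise sharding of w2 for the grouped expert weights. A single-process illustration of which dimension each weight is split on, using plain `torch.chunk` and made-up sizes (no device mesh involved):

import torch

num_experts, dim, hidden, tp_degree = 4, 16, 32, 2

w1 = torch.randn(num_experts, dim, hidden)  # split on dim 2: columns of the up-projection
w2 = torch.randn(num_experts, hidden, dim)  # split on dim 1: rows of the down-projection
w3 = torch.randn(num_experts, dim, hidden)  # split on dim 2, same as w1

w1_shard = torch.chunk(w1, tp_degree, dim=2)[0]
w2_shard = torch.chunk(w2, tp_degree, dim=1)[0]
w3_shard = torch.chunk(w3, tp_degree, dim=2)[0]

# each rank holds half of `hidden`, so per-expert matmuls stay shape-compatible
print(w1_shard.shape, w2_shard.shape, w3_shard.shape)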
torchtitan/experiments/llama4/infra/parallelize_llama.py
ADDED
|
@@ -0,0 +1,159 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims

from torchtitan.models.llama3.parallelize_llama import (
    apply_ac,
    apply_compile,
    apply_ddp,
    apply_fsdp,
    apply_tp,
)
from torchtitan.tools.logging import logger


def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

    if parallel_dims.tp_enabled:
        if (
            job_config.parallelism.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")

        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )

        # For now, float8 all-gather with TP is only supported for tensorwise
        # float8 scaling recipes. For rowwise recipes, we use regular TP and
        # all-gather happens in high precision.
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise

        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
        )

        apply_moe_tp(model, world_mesh["tp"])

    if job_config.activation_checkpoint.mode != "none":
        if (
            job_config.activation_checkpoint.mode == "selective"
            and job_config.model.use_flex_attn
        ):
            raise ValueError(
                "FlexAttention is not compatible with selective AC yet. "
                "See https://github.com/pytorch/pytorch/issues/147879"
            )
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-TransformerBlock compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

    # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
    torch._dynamo.config.capture_scalar_outputs = True

    if (
        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
    ):  # apply FSDP or HSDP, potentially with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
        )

    return model


def apply_moe_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
):
    from torch.distributed.tensor import Partial, Replicate, Shard
    from torch.distributed.tensor.parallel import (
        parallelize_module,
        PrepareModuleInputOutput,
    )

    from .expert_parallel import NoParallel, TensorParallel

    for _, transformer_block in model.layers.items():
        moe_layer_plan = {
            # input / output sharding on the seqlen dim
            # all-gather for input, reduce-scatter for output
            "moe": PrepareModuleInputOutput(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
                use_local_input=True,
                output_layouts=(Partial(),),
                desired_output_layouts=(Shard(1),),
            ),
            # replicate computation for the router
            "moe.router.gate": NoParallel(),
            # input Replicate, output Partial
            "moe.experts": TensorParallel(),
            "moe.shared_expert": TensorParallel(),
        }
        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=moe_layer_plan,
        )
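The `capture_scalar_outputs` note above exists because token-choice MoE reads data-dependent split sizes off a tensor before calling `torch.split`. A minimal, self-contained illustration of that pattern under `torch.compile` (sizes are made up; whether the split stays in one graph depends on the PyTorch version):

import torch

torch._dynamo.config.capture_scalar_outputs = True


@torch.compile
def split_by_counts(x: torch.Tensor, counts: torch.Tensor):
    # counts.tolist() turns tensor values into Python ints, which dynamo can only
    # capture (rather than graph-break on) when capture_scalar_outputs is enabled.
    return torch.split(x, counts.tolist(), dim=0)


x = torch.randn(6, 4)
counts = torch.tensor([2, 4])
print([t.shape for t in split_by_counts(x, counts)])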
torchtitan/experiments/llama4/model/__pycache__/model.cpython-311.pyc
ADDED
Binary file (23.8 kB)
torchtitan/experiments/llama4/model/model.py
ADDED
|
@@ -0,0 +1,466 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch
import torch.nn.functional as F
from torch import nn

from torchtitan.models.attention import build_attention, init_attention_mask
from torchtitan.models.norms import build_norm
from torchtitan.protocols.train_spec import ModelProtocol

from .args import TransformerModelArgs
from .moe import MoE


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
    and the first seqlen elements will be sliced, but dim must match x.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.
    """
    ndim = x.ndim
    assert ndim > 1
    seqlen = x.shape[1]
    freqs_cis = freqs_cis[0:seqlen]
    assert freqs_cis.shape == (seqlen, x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        torch.unsqueeze(x, dim=3)
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


class Attention(nn.Module):
    """
    Multi-head attention module.

    Args:
        model_args (TransformerModelArgs): Model configuration arguments.

    Attributes:
        n_kv_heads (int): Number of key and value heads.
        n_heads (int): Number of query heads.
        n_rep (int): Number of repetitions for local heads.
        head_dim (int): Dimension size of each attention head.
        wq (Linear): Linear transformation for queries.
        wk (Linear): Linear transformation for keys.
        wv (Linear): Linear transformation for values.
        wo (Linear): Linear transformation for output.
    """

    def __init__(self, model_args: TransformerModelArgs):
        super().__init__()
        self.n_heads = model_args.n_heads
        self.n_kv_heads = (
            model_args.n_heads
            if model_args.n_kv_heads is None
            else model_args.n_kv_heads
        )
        self.n_rep = self.n_heads // self.n_kv_heads
        self.head_dim = model_args.dim // model_args.n_heads

        self.wq = nn.Linear(
            model_args.dim, model_args.n_heads * self.head_dim, bias=False
        )
        self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(
            model_args.n_heads * self.head_dim, model_args.dim, bias=False
        )
        self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type)

    def init_weights(self, init_std: float):
        for linear in (self.wq, self.wk, self.wv):
            nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
    ):
        """
        Forward pass of the attention module.

        Args:
            x (torch.Tensor): Input tensor.
            freqs_cis (torch.Tensor): Precomputed frequency tensor.

        Returns:
            torch.Tensor: Output tensor after attention.
        """

        bs, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual
        # local heads from sizes of xq, xk, and xv as TP may have sharded them
        # after the above linear ops.
        xq = xq.view(bs, seqlen, -1, self.head_dim)
        xk = xk.view(bs, seqlen, -1, self.head_dim)
        xv = xv.view(bs, seqlen, -1, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        values = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)

        output = self.sdpa(xq, xk, xv)

        output = output.transpose(
            1, 2
        ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
        output = output.view(bs, seqlen, -1)
        return self.wo(output)


class FeedForward(nn.Module):
    """
    FeedForward module

    Args:
        dim (int): Input dimension.
        hidden_dim (int): Hidden dimension of the feedforward layer.
        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
        ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None.

    Attributes:
        w1 (Linear): Linear transformation for the first layer.
        w2 (Linear): Linear transformation for the second layer.
        w3 (Linear): Linear transformation for the third layer.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: float | None,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
        for linear in (self.w2, self.w3):
            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)


class TransformerBlock(nn.Module):
    """
    TransformerBlock Module

    Args:
        layer_id (int): Identifier for the layer.
        model_args (TransformerModelArgs): Model configuration arguments.

    Attributes:
        n_heads (int): Number of attention heads.
        dim (int): Dimension size of the model.
        head_dim (int): Dimension size of each attention head.
        attention (Attention): Attention module.
        feed_forward (FeedForward): FeedForward module.
        layer_id (int): Identifier for the layer.
        attention_norm (RMSNorm): Layer normalization for attention output.
        ffn_norm (RMSNorm): Layer normalization for feedforward output.
    """

    def __init__(self, layer_id: int, model_args: TransformerModelArgs):
        super().__init__()
        self.n_heads = model_args.n_heads
        self.dim = model_args.dim
        self.attention = Attention(model_args)

        # use MoE layer for every interleave_moe_layer_step FFN layers
        self.moe_enabled = (
            model_args.moe_enabled
            and (layer_id + 1) % model_args.interleave_moe_layer_step == 0
        )
        if self.moe_enabled:
            self.moe = MoE(model_args)
        else:
            self.feed_forward = FeedForward(
                dim=model_args.dim,
                hidden_dim=4 * model_args.dim,
                multiple_of=model_args.multiple_of,
                ffn_dim_multiplier=model_args.ffn_dim_multiplier,
            )

        self.layer_id = layer_id
        self.num_layers = model_args.n_layers

        self.attention_norm = build_norm(
            model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
        )
        self.ffn_norm = build_norm(
            model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
        )

        if model_args.depth_init:
            self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
        else:
            self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.

        Returns:
            torch.Tensor: Output tensor after applying attention and feedforward layers.
        """
        h = x + self.attention(self.attention_norm(x), freqs_cis)
        if self.moe_enabled:
            out = h + self.moe(self.ffn_norm(h))
        else:
            out = h + self.feed_forward(self.ffn_norm(h))
        return out

    def init_weights(self):
        for norm in (self.attention_norm, self.ffn_norm):
            norm.reset_parameters()
        self.attention.init_weights(self.weight_init_std)
        if self.moe_enabled:
            self.moe.init_weights(self.weight_init_std)
        else:
            self.feed_forward.init_weights(self.weight_init_std)


class Transformer(nn.Module, ModelProtocol):
    """
    Transformer Module

    Args:
        model_args (TransformerModelArgs): Model configuration arguments.

    Attributes:
        model_args (TransformerModelArgs): Model configuration arguments.
        vocab_size (int): Vocabulary size.
        n_layers (int): Number of layers in the model.
        tok_embeddings (ParallelEmbedding): Token embeddings.
        layers (torch.nn.ModuleList): List of Transformer blocks.
        norm (RMSNorm): Layer normalization for the model output.
        output (ColumnParallelLinear): Linear layer for final output.
        freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
    """

    def __init__(self, model_args: TransformerModelArgs):
        super().__init__()
        self.model_args = model_args
        self.vocab_size = model_args.vocab_size
        self.n_layers = model_args.n_layers
        self.eos_id = model_args.eos_id

        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)

        # TODO persistent should be set to false, since this buffer can be recomputed.
        # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411,
        # compile or pipeline-tracer will not correctly handle non-persistent buffers,
        # so we need to fix that. (2) if we initialize pipeline-parallel models from
        # a seed checkpoint rather than calling init_weights, we need freqs_cis to be
        # initialized by the checkpoint, or we need to add a separate initializer for
        # just the non-persistent buffers that is called after loading checkpoints.
        self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)

        self.layers = torch.nn.ModuleDict()
        for layer_id in range(model_args.n_layers):
            self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)

        self.norm = build_norm(
            model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
        )

        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
        self.init_weights()

    def init_weights(
        self,
        buffer_device: torch.device | None = None,
    ):
        """
        [Note: On ``init_weights`` vs. ``reset_parameters``]
        Modules may define ``reset_parameters`` to initialize parameter values.
        ``reset_parameters`` is meant to only initialize directly owned
        parameters/buffers, not those of their child modules, and it can be
        used to give the initial values for these tensors.
        Separately, users may want custom initialization for their modules,
        different from that in ``reset_parameters``. For this, we define
        ``init_weights``. We only call it in the constructor of this
        ``Transformer`` root module to avoid reinitializing tensors.
        """
        buffer_device = buffer_device or self.freqs_cis.device
        with torch.device(buffer_device):
            self.freqs_cis = self._precompute_freqs_cis()
        if self.tok_embeddings is not None:
            nn.init.normal_(self.tok_embeddings.weight)
        for layer in self.layers.values():
            if layer is not None:
                layer.init_weights()
        if self.norm is not None:
            self.norm.reset_parameters()
        final_out_std = self.model_args.dim**-0.5
        cutoff_factor = 3
        if self.output is not None:
            nn.init.trunc_normal_(
                self.output.weight,
                mean=0.0,
                std=final_out_std,
                a=-cutoff_factor * final_out_std,
                b=cutoff_factor * final_out_std,
            )

    def _precompute_freqs_cis(self) -> torch.Tensor:
        return precompute_freqs_cis(
            self.model_args.dim // self.model_args.n_heads,
            # Need to compute until at least the max token limit for generation
            # TODO: explain in docs/composability.md why we removed the 2x
            # relaxing in our CP enablement PR
            self.model_args.max_seq_len,
            self.model_args.rope_theta,
        )

    def forward(self, tokens: torch.Tensor):
        """
        Perform a forward pass through the Transformer model.

        Args:
            tokens (torch.Tensor): Input token indices.

        Returns:
            torch.Tensor: Output logits after applying the Transformer model.
        """
        # TODO: We will to change forward() signature to allow tokens to
        # be always passed in.
        if self.model_args.use_flex_attn:
            init_attention_mask(tokens, eos_id=self.eos_id)

        # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

        for layer in self.layers.values():
            h = layer(h, self.freqs_cis)

        h = self.norm(h) if self.norm else h
        output = self.output(h) if self.output else h
        return output

    @classmethod
    def from_model_args(cls, model_args: TransformerModelArgs) -> "Transformer":
        """
        Initialize a Transformer model from a TransformerModelArgs object.

        Args:
            model_args (TransformerModelArgs): Model configuration arguments.

        Returns:
            Transformer: Transformer model.
        """
        return cls(model_args)
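A quick shape-level sanity check of the rotary-embedding helpers defined in this file; the import path simply mirrors the file added in this commit and assumes the repo is importable on PYTHONPATH:

import torch

from torchtitan.experiments.llama4.model.model import (
    apply_rotary_emb,
    precompute_freqs_cis,
)

bs, seqlen, n_heads, head_dim = 2, 16, 4, 8
freqs_cis = precompute_freqs_cis(head_dim, end=32)  # complex64, shape [32, head_dim // 2]
xq = torch.randn(bs, seqlen, n_heads, head_dim)
xk = torch.randn(bs, seqlen, n_heads, head_dim)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)
print(xq_rot.shape, xk_rot.shape)  # shapes are preserved: [2, 16, 4, 8]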
torchtitan/experiments/llama4/model/moe.py
ADDED
|
@@ -0,0 +1,228 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
from torch import nn

from .args import TransformerModelArgs


class GroupedExperts(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        num_experts: int,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
        self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))

    def forward(
        self,
        x: torch.Tensor,
        num_local_tokens_per_expert: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if num_local_tokens_per_expert is not None:
            # a tuple of tensors indexed by experts
            # each with shape (tokens_per_expert(varying), dim)
            x = torch.split(
                x,
                split_size_or_sections=num_local_tokens_per_expert.tolist(),
                dim=0,
            )
            out_experts_splits = []
            for expert_idx, x_expert in enumerate(x):
                w1, w2, w3 = (
                    self.w1[expert_idx],
                    self.w2[expert_idx],
                    self.w3[expert_idx],
                )
                h = F.silu(torch.matmul(x_expert, w1))
                h = h * torch.matmul(x_expert, w3)
                h = torch.matmul(h, w2)
                # h shape (tokens_per_expert(varying), dim)
                out_experts_splits.append(h)
            out = torch.cat(out_experts_splits, dim=0)

            # TODO: optimize with GroupedGEMM
            # https://github.com/pytorch/pytorch/pull/150374
            # _grouped_mm requires shapes to be multiple of 8
            # offsets = torch.cumsum(num_local_tokens_per_expert, dim=0, dtype=torch.int32)
            # h = F.silu(torch._grouped_mm(x, self.w1.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16))
            # h = h * torch._grouped_mm(x, self.w3.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16)
            # out = torch._grouped_mm(h, self.w2.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16)
        else:
            # x shape (num_experts, tokens_per_expert, dim)
            h = F.silu(torch.bmm(x, self.w1))
            h = h * torch.bmm(x, self.w3)
            # out shape (num_experts, tokens_per_expert, dim)
            out = torch.bmm(h, self.w2)
        return out

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
        nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)


class TokenChoiceTopKRouter(nn.Module):
    """This class implements token-choice routing. In token-choice top-K routing, each token is
    routed to top K experts based on the router scores.

    Args:
        gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts).
        dim (int): Dimension of input tokens.
        num_experts (int): Number of experts in each moe layer.
        top_k (int): Number of experts each token will be routed to in token-choice routing.
        use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is False.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int,
        top_k: int,
        use_sigmoid: bool = False,
    ):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.num_experts = num_experts
        self.top_k = top_k
        self.use_sigmoid = use_sigmoid

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``.

        Returns:
            routed_input (torch.Tensor):
                Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``.
            token_indices (torch.Tensor):
                Token indices for routed_input with shape ``(bs*slen*top_k,)``.
            num_local_tokens_per_expert (torch.Tensor):
                Number of tokens assigned to each expert with shape ``(num_experts,)``.
        """
        # scores shape (bs*slen, num_experts)
        scores = self.gate(x)

        # By default, sigmoid or softmax is performed in float32 to avoid loss explosion
        if self.use_sigmoid:
            scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
        else:
            scores = F.softmax(scores.to(torch.float32), dim=1).to(x.dtype)

        # top scores shape (bs*slen, top_k)
        top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=1)
        # top_scores /= top_scores.sum(dim=-1, keepdim=True).to(x.dtype)

        # group tokens together by expert indices from 0 to num_experts and pass that to experts forward
        num_local_tokens_per_expert = torch.histc(
            selected_experts_indices.view(-1),
            bins=self.num_experts,
            min=0,
            max=self.num_experts,
        )
        # token_indices_experts_sorted shape (bs*slen*top_k,)
        token_indices_experts_sorted = torch.argsort(
            selected_experts_indices.view(-1), stable=True
        )
        top_scores = top_scores.view(-1)[token_indices_experts_sorted]
        token_indices_experts_sorted = token_indices_experts_sorted // self.top_k

        return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)


# TODO: implement load balancing auxiliary loss for token-choice routing
class MoE(nn.Module):
    def __init__(self, model_args: TransformerModelArgs):
        super().__init__()
        dim = model_args.dim
        hidden_dim = 4 * model_args.dim
        ffn_dim_multiplier = model_args.ffn_dim_multiplier
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)

        num_experts = model_args.num_experts

        hidden_dim_denom = 1
        if model_args.auto_scale_hidden_dim:
            hidden_dim_denom = model_args.top_k + int(model_args.use_shared_expert)

        if model_args.auto_scale_hidden_dim:
            hidden_dim = int(hidden_dim / hidden_dim_denom)
            hidden_dim += -hidden_dim % model_args.multiple_of

        self.experts = GroupedExperts(
            dim=dim, hidden_dim=hidden_dim, num_experts=num_experts
        )
        self.router = TokenChoiceTopKRouter(
            dim=dim, num_experts=num_experts, top_k=model_args.top_k
        )
        self.shared_expert = (
            GroupedExperts(dim=dim, hidden_dim=hidden_dim, num_experts=1)
            if model_args.use_shared_expert
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``.

        Returns:
            out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
        """
        bs, slen, dim = x.shape
        # top_scores and selected_indices shape (bs*slen*top_k,)
        # num_local_tokens_per_expert shape (num_experts,)
        (
            top_scores,
            token_indices,
            num_local_tokens_per_expert,
        ) = self.router(x.reshape(bs * slen, dim))

        # shape (bs*slen*top_k, dim)
        token_indices = token_indices.reshape(-1, 1).expand(-1, dim)

        # shape (bs*slen*top_k, dim)
        routed_input = torch.gather(
            x.view(-1, dim),
            dim=0,
            index=token_indices,
        )
        routed_input = routed_input * top_scores.reshape(-1, 1)

        # shape (bs*slen*top_k, dim)
        routed_output = self.experts(routed_input, num_local_tokens_per_expert)

        # shared expert
        if self.shared_expert is not None:
            out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape(
                bs * slen, dim
            )
        else:
            out = torch.zeros_like(x.reshape(bs * slen, dim))

        out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
        out = out.reshape(bs, slen, dim)
        return out

    def init_weights(self, init_std: float):
        self.experts.init_weights(init_std)
        self.router.init_weights(init_std)
        if self.shared_expert is not None:
            self.shared_expert.init_weights(init_std)
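For readers skimming the diff, the routing bookkeeping above (histc to count tokens per expert, a stable argsort to group routed slots by expert, then gather/scatter_add to move tokens out and back) is easier to see on toy sizes. The snippet below is a minimal standalone sketch with made-up shapes and a dummy expert; it mirrors the tensor names in the file but is not part of the torchtitan sources.

import torch

# Minimal sketch of token-choice top-k routing (made-up sizes, dummy experts).
num_tokens, dim, num_experts, top_k = 6, 4, 3, 2
x = torch.randn(num_tokens, dim)
scores = torch.softmax(torch.randn(num_tokens, num_experts), dim=1)

# Each token picks its top_k experts.
top_scores, expert_idx = torch.topk(scores, k=top_k, dim=1)

# Count tokens per expert, then sort routed slots so same-expert tokens are contiguous.
tokens_per_expert = torch.histc(
    expert_idx.view(-1).float(), bins=num_experts, min=0, max=num_experts
)
order = torch.argsort(expert_idx.view(-1), stable=True)
token_ids = order // top_k  # original token each routed slot came from

# Gather routed inputs, weight by the router score, run the (dummy) experts,
# then scatter-add the results back to the original token positions.
routed_in = x[token_ids] * top_scores.view(-1)[order].unsqueeze(-1)
routed_out = routed_in  # a real MoE would apply the per-expert FFNs here
out = torch.zeros_like(x).scatter_add(
    0, token_ids.unsqueeze(-1).expand(-1, dim), routed_out
)
print(tokens_per_expert, out.shape)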
torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh
ADDED
@@ -0,0 +1,25 @@
#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -ex

# use envs as local overrides for convenience
# e.g.
# LOG_RANK=0,1 NGPU=4 ./convert_meta_to_dcp_with_gpus.sh
NGPU=${NGPU:-"8"}
LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7}
CONFIG_FILE=${CONFIG_FILE:-"../train_configs/llama4_17bx16e.toml"}

overrides=""
if [ $# -ne 0 ]; then
    overrides="$*"
fi

PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
convert_meta_to_dcp_with_gpus_meta.py --job.config_file ${CONFIG_FILE} $overrides
torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml
ADDED
@@ -0,0 +1,65 @@
# TODO: this toml config is still under development

[job]
dump_folder = "./outputs"
description = "Llama 4 Maverick 17Bx128E training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama4"
flavor = "17bx128e"
norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
tokenizer_path = "./assets/tokenizer/tokenizer.model"
# converters = "float8"

[optimizer]
name = "AdamW"
lr = 4e-3
eps = 1e-15

[lr_scheduler]
warmup_steps = 600
lr_min = 0.1

[training]
batch_size = 1
seq_len = 8192
max_norm = 1.0  # grad norm clipping
steps = 3000
compile = false
dataset = "c4"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 8
enable_async_tensor_parallel = false
pipeline_parallel_degree = 4
# pipeline_parallel_schedule = "interleaved1f1b"
# pipeline_parallel_microbatches = 2
context_parallel_degree = 1

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'full'  # ['none', 'selective', 'full']

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = "output,router.gate"
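To get a feel for what these parallelism settings imply, here is a back-of-envelope sketch of tokens per optimizer step. It assumes a hypothetical 256-GPU job and the usual torchtitan convention that `data_parallel_shard_degree = -1` means "use whatever ranks remain after TP and PP"; the real degree inference lives elsewhere in torchtitan and is not part of this diff.

# Rough arithmetic only; world_size is a made-up example.
world_size = 256
tp, pp = 8, 4                     # tensor_parallel_degree, pipeline_parallel_degree
dp = world_size // (tp * pp)      # shard degree -1 -> fill the remaining ranks
batch_size, seq_len = 1, 8192     # per data-parallel rank, from [training]
tokens_per_step = dp * batch_size * seq_len
print(dp, tokens_per_step)        # 8 data-parallel ranks -> 65,536 tokens/step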
torchtitan/experiments/multimodal/mm_collator.py
ADDED
@@ -0,0 +1,227 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F

from tokenizer.tiktoken import IGNORE_INDEX

from torch.nn.utils.rnn import pad_sequence


def padded_collate(
    batch: List[Dict[str, List[int]]],
    padding_idx: int = 0,
    ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
    """Pad a batch of sequences to the longest sequence length in the batch, and
    convert integer lists to tensors.

    Args:
        batch (List[Dict[str, List[int]]]): A list of dictionaries containing input, label pairs.
        padding_idx (int): Padding index for input ids. Defaults to 0.
        ignore_idx (int): Padding index for labels. Defaults to -100.

    Returns:
        Dict[str, torch.Tensor]: Collated input and label tensors.

    Example:
        >>> token_pairs = [
        >>>    {"input_ids": [1, 2, 3], "labels": [4, 5, 6]},
        >>>    {"input_ids": [7,], "labels": [10,]},
        >>> ]
        >>> collated = padded_collate(
        >>>    batch=token_pairs,
        >>>    padding_idx=padding_idx,
        >>>    ignore_idx=ignore_idx,
        >>> )
        >>> collated["input_ids"]
        >>> tensor([[1, 2, 3], [7, 0, 0]])
        >>> collated["labels"]
        >>> tensor([[4, 5, 6], [10, -100, -100]])
    """
    input_ids = pad_sequence(
        [x["input_ids"] for x in batch],
        batch_first=True,
        padding_value=padding_idx,
    )
    labels = pad_sequence(
        [x["labels"] for x in batch],
        batch_first=True,
        padding_value=ignore_idx,
    )

    input_ids_seq_len = input_ids.shape[-1]
    labels_seq_len = labels.shape[-1]

    # Hack to pad correctly and not use max_seq_len, which is costly
    if input_ids_seq_len > labels_seq_len:
        labels = F.pad(
            labels, (0, input_ids_seq_len - labels_seq_len), value=ignore_idx
        )
    elif labels_seq_len > input_ids_seq_len:
        input_ids = F.pad(
            input_ids,
            (0, labels_seq_len - input_ids_seq_len),
            value=padding_idx,
        )
    return {"input_ids": input_ids, "labels": labels}


# NOTE Inspired from torchtune.data._collate.py
@dataclass
class MultiModalCollator:
    padding_idx: int = 128004
    ignore_idx: int = IGNORE_INDEX
    pad_max_tiles: Optional[int] = None
    pad_max_images: Optional[int] = None

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Pad a batch of text sequences, tiled image tensors, aspect ratios,
        and cross attention masks. This can be used for both training and inference.

        ``batch`` is expected to be a list of sample dicts containing the following::
            - "input_ids": List[int] of length text_seq_len, varies across samples
            - "labels": List[int] of length text_seq_len, varies across samples
            - "encoder_input": Dict[str, List[torch.Tensor]]
                - "images": List[torch.Tensor], each with shape (n_tiles, c, h, w)
                - "aspect_ratio": List[torch.Tensor], each with shape (2, ) to indicate h_ratio, w_ratio

        Shape notation:
            - c = channel dim
            - h = height dim
            - w = width dim

        Note:
            For each element in the batch, ``len(images) == len(aspect_ratio)``.

        This collator does the following:
            (1) Pad text sequence and encoder mask to the longest sequence length in the batch
            (2) Pad image tensors in the tile dimension with zeros to the largest number
                of tiles in the batch
            (3) Add empty images of zeros to samples up to max number of images in the batch
            (4) Pad aspect ratios with (1,1) for all added padding images

        Args:
            batch (List[Dict[str, Any]]): A list of sample dicts containing input_ids,
                labels, images, and aspect_ratio.
            padding_idx (int): Padding index for input token ids. Defaults to 0.
            ignore_idx (int): Padding index for labels. Defaults to -100.
            pad_max_tiles (Optional[int]): Maximum number of tiles to pad to. If None, will pad to the largest number of tiles
                in the batch. Defaults to None.
            pad_max_images (Optional[int]): Maximum number of images to pad to. If None, will pad to the largest number of images
                in the batch. Defaults to None.

        Returns:
            Dict[str, Tensor]: Collated tokens, labels, images, aspect_ratio tensors.
                - tokens: Tensor of shape (bsz, max_seq_len)
                - labels: Tensor of shape (bsz, max_seq_len)
                - images: Tensor of shape (bsz, max_num_images, max_num_tiles, c, h, w)
                - aspect_ratio: Tensor of shape (bsz, max_num_images, 2)

        Example:
            >>> image_id = 1
            >>> tokens_per_tile = 5
            >>> c, h, w = 1, 1, 1
            >>> batch = [
            ...     {
            ...         "input_ids": [1, 2, 1, 3], "labels": [4, 5, 6, 7],
            ...         "encoder_input": {
            ...             # One image with two tiles, one image with three tiles
            ...             "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)],
            ...             "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])],
            ...         },
            ...     },
            ...     {
            ...         "input_ids": [1, 4], "labels": [8, 9],
            ...         "encoder_input": {
            ...             # One image with four tiles
            ...             "images": [torch.ones(4, c, h, w)],
            ...             "aspect_ratio": [torch.tensor([2, 2])],
            ...         },
            ...     },
            ... ]
            ... collator = MultiModalCollator(pad_max_tiles=4)
            >>> model_inputs = collator(batch=batch)
            >>> print(model_inputs["input_ids"])
            tensor([[1, 2, 1, 3],
                    [1, 4, 0, 0]])
            >>> print(model_inputs["labels"])
            tensor([[4, 5, 6, 7],
                    [8, 9, -100, -100]])
            >>> print(model_inputs["encoder_input"]["images"].shape)  # (bsz, max_num_images, max_num_tiles, c, h, w)
            torch.Size([2, 2, 4, 1, 1, 1])
            >>> print(model_inputs["encoder_input"]["aspect_ratio"].shape)  # (bsz, max_num_images, 2)
            torch.Size([2, 2, 2])
            >>> print(model_inputs["encoder_input"]["images"][0, 0, ...])  # Image with two tiles got padded to four
            tensor([[[[1.]]], [[[1.]]], [[[0.]]], [[[0.]]]])
            >>> print(model_inputs["encoder_input"]["images"][0, 1, ...])  # Image with three tiles got padded to four
            tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[0.]]]])
            >>> print(model_inputs["encoder_input"]["images"][1, 0, ...])  # Image with four tiles did not get padded
            tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[1.]]]])
            >>> print(model_inputs["encoder_input"]["images"][1, 1, ...])  # Extra padding image was added to second sample
            tensor([[[[0.]]], [[[0.]]], [[[0.]]], [[[0.]]]])
        """
        # Text tokens can be handled independently by existing collators
        text_only = [
            {"input_ids": sample["input_ids"], "labels": sample["labels"]}
            for sample in batch
        ]
        collated_text = padded_collate(text_only, self.padding_idx, self.ignore_idx)

        if self.pad_max_tiles is None:
            # Get max number of tiles in batch
            max_num_tiles = max(sample["images_tiles"].shape[0] for sample in batch)
        else:
            max_num_tiles = self.pad_max_tiles

        # Pad images and aspect ratios to max number of tiles
        batch_images = []
        batch_aspect_ratios = []

        for sample in batch:
            sample_images = []
            for image in sample["encoder_input"]["images"]:
                # Single image in each sample has shape (n_tiles, c, h, w)
                n_tiles = image.shape[0]
                # Single mask in each sample corresponds to a single image and has shape (text_seq_len, image_seq_len)
                # where image_seq_len = n_tiles * tokens_per_tile
                padding_tiles = max_num_tiles - n_tiles

                # Image should now have shape (max_num_tiles, c, h, w)
                padded_image = F.pad(
                    image, (0, 0, 0, 0, 0, 0, 0, padding_tiles), value=0
                )

                sample_images.append(padded_image)
            # Stack multiple images and masks per sample in num_images dimension
            batch_images.append(torch.stack(sample_images))
            batch_aspect_ratios.append(
                torch.stack(sample["encoder_input"]["aspect_ratio"])
            )
        # Finally, pad images, masks, aspect ratios to max number of images in batch
        # (bsz, max_num_images, max_num_tiles, c, h, w)
        collated_images = pad_sequence(batch_images, batch_first=True, padding_value=0)
        # (bsz, max_num_images, 2)
        collated_aspect_ratios = pad_sequence(
            batch_aspect_ratios, batch_first=True, padding_value=1
        )

        batch_dict = {
            "input_ids": collated_text["input_ids"],
            "labels": collated_text["labels"],
            "encoder_input": {
                "images": collated_images,
                "aspect_ratio": collated_aspect_ratios,
            },
        }

        return batch_dict
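The heart of the collator above is two nested pads: each image is padded along the tile dimension with `F.pad`, and then `pad_sequence` pads the per-sample image lists to a common number of images. The following is a self-contained sketch of just that part, with toy shapes and none of the tokenizer imports the file itself needs.

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

# Toy shapes: two samples, c = h = w = 1, pad every image to 4 tiles.
max_num_tiles = 4
samples = [
    [torch.ones(2, 1, 1, 1), torch.ones(3, 1, 1, 1)],  # sample 0: two images
    [torch.ones(4, 1, 1, 1)],                          # sample 1: one image
]

batch_images = []
for images in samples:
    padded = [
        # pad tuple runs from the last dim backwards, so the final pair pads the tile dim
        F.pad(img, (0, 0, 0, 0, 0, 0, 0, max_num_tiles - img.shape[0]))
        for img in images
    ]
    batch_images.append(torch.stack(padded))  # (num_images, max_num_tiles, c, h, w)

# pad_sequence then pads the num_images dimension across the batch.
collated = pad_sequence(batch_images, batch_first=True, padding_value=0)
print(collated.shape)  # torch.Size([2, 2, 4, 1, 1, 1])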
torchtitan/experiments/multimodal/mm_dataset.py
ADDED
@@ -0,0 +1,268 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union

import torch

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node

from mm_collator import MultiModalCollator
from tokenizer.tiktoken import IGNORE_INDEX, Tokenizer
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from transform import CLIPTransform
from utils import load_image

from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger


def _load_obelics_dataset(dataset_path: str):
    """Load the OBELICS dataset with default configuration."""
    return load_dataset(dataset_path, split="train", streaming=True)


def _process_obelics_sample(
    sample: dict[str, Any], image_token: str = "<|image|>"
) -> Dict[str, List[Union[str, "PIL.Image.Image"]]]:
    """
    This function formats samples from the OBELICS dataset.
    Returns:
        Dict[str, Any]: The transformed sample with the following fields:
            - images: List[PIL.Image.Image] with the loaded images
            - text: str with the text of the sample ready to be tokenized including the image tokens
    Example:
        >>> formatted_sample = format_obelics(sample, image_token="<|image|>")
        >>> print(formatted_sample["text"])
        ... "<|image|><|image|><|image|> The elephant looks cute!<|image|><|image|> The cats are sad :("
    """
    sample_images = [image for image in sample["images"] if image is not None]
    sample_text = [
        text if text is not None else image_token for text in sample["texts"]
    ]
    return {
        "images": [load_image(image) for image in sample_images],
        "text": "".join(map(str, sample_text)),
    }


@dataclass
class DatasetConfig:
    path: str
    loader: Callable
    sample_processor: Callable


# Add your dataset here - more information at docs/datasets.md
MM_DATASETS = {
    "obelics": DatasetConfig(
        path="HuggingFaceM4/OBELICS",
        loader=_load_obelics_dataset,
        sample_processor=_process_obelics_sample,
    ),
}


def _validate_mm_dataset(
    dataset_name: str, dataset_path: str = None
) -> tuple[str, Callable, Callable]:
    """Validate dataset name and path."""
    if dataset_name not in MM_DATASETS:
        raise ValueError(
            f"Dataset {dataset_name} is not supported. "
            f"Supported datasets are: {list(MM_DATASETS.keys())}"
        )

    config = MM_DATASETS[dataset_name]
    path = dataset_path or config.path
    logger.info(f"Preparing {dataset_name} dataset from {path}")
    return path, config.loader, config.sample_processor


class MultiModalDataset(IterableDataset, Stateful):
    """PyTorch MultiModal Dataset.

    Args:
        dataset_name (str): name of the dataset to load
        tokenizer (Tokenizer):
            Tokenizer used to encode data. The tokenizer must implement an `encode` and a `decode` method.
        world_size (int): number of data parallel processes participating in training
        rank (int): rank of the current data parallel process
        infinite (bool): whether to loop infinitely over the dataset

    We currently ONLY support the OBELICS dataset.

    Example use:
        >>> ds = MultiModalDataset(dataset_name="OBELICS", tokenizer=tokenizer)
        >>> for batch in Dataloader(ds, batch_size=8):
                print(f"Batch size: {len(batch)}")
            Batch size: 8
    """

    def __init__(
        self,
        dataset_name: str,
        dataset_path: Optional[str],
        tokenizer: Tokenizer,
        image_token: str = "<|image|>",
        tile_size: int = 448,
        max_num_tiles: int = 4,
        seq_len: int = 2048,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:
        # Force lowercase for consistent comparison
        dataset_name = dataset_name.lower()

        path, dataset_loader, sample_processor = _validate_mm_dataset(
            dataset_name, dataset_path
        )
        ds = dataset_loader(path)

        # TODO: support shuffling
        self.dataset_name = dataset_name
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
        self._tokenizer = tokenizer
        self.seq_len = seq_len
        self.infinite = infinite
        self._sample_processor = sample_processor
        self.image_token = (
            image_token  # TODO(tj.solergibert) Add `image_token` to JobConfig
        )
        # TODO(tj.solergibert) Add `tile_size` & `max_num_tiles` to JobConfig
        self.transform_image = CLIPTransform(
            image_mean=(
                0.48145466,
                0.4578275,
                0.40821073,
            ),  # TODO(tj.solergibert) What should we do with `image_mean` & `image_std`?
            image_std=(0.26862954, 0.26130258, 0.27577711),
            tile_size=tile_size,
            possible_resolutions=None,
            max_num_tiles=max_num_tiles,
            resample="bilinear",
            resize_to_max_canvas=False,
        )

        # variables for checkpointing
        self._sample_idx = 0

    def __iter__(self):

        while True:
            for sample in self._get_data_iter():
                try:
                    sample = self._sample_processor(
                        sample, image_token=self.image_token
                    )
                except Exception:
                    continue
                self._sample_idx += 1

                # CLIP Transform
                encoder_input = {"images": [], "aspect_ratio": []}
                for image in sample["images"]:
                    out = self.transform_image(image)
                    encoder_input["images"].append(out["image"])
                    encoder_input["aspect_ratio"].append(out["aspect_ratio"])
                sample["encoder_input"] = encoder_input

                # Tokenize
                tokens = self._tokenizer.encode(
                    sample["text"],
                    bos=True,
                    eos=True,
                    allowed_special=set(["<|image|>"]),
                )
                sample["input_ids"] = torch.LongTensor(tokens[:-1])
                sample["labels"] = torch.LongTensor(tokens[1:])
                # Mask BOS, EOS & image tokens from the loss
                sample["labels"] = torch.where(
                    torch.isin(
                        sample["labels"],
                        torch.LongTensor(
                            [
                                self._tokenizer.bos_id,
                                self._tokenizer.eos_id,
                                self._tokenizer.image_id,
                            ]
                        ),
                    ),
                    IGNORE_INDEX,
                    sample["labels"],
                )
                # Truncate
                sample["input_ids"], sample["labels"] = (
                    sample["input_ids"][: self.seq_len],
                    sample["labels"][: self.seq_len],
                )
                yield sample

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def _get_data_iter(self):
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        it = iter(self._data)
        for _ in range(self._sample_idx):
            next(it)
        return it

    def load_state_dict(self, state_dict):
        self._sample_idx = state_dict["sample_idx"]

    def state_dict(self):
        return {"sample_idx": self._sample_idx}


def build_mm_dataloader(
    dp_world_size: int,
    dp_rank: int,
    tokenizer: Tokenizer,
    job_config: JobConfig,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data loader for HuggingFace datasets."""
    dataset_name = job_config.training.dataset
    dataset_path = job_config.training.dataset_path
    batch_size = job_config.training.batch_size
    seq_len = job_config.training.seq_len
    pad_max_tiles = 4  # TODO(tj.solergibert) Add `pad_max_tiles` to JobConfig
    padding_idx = 128004  # TODO(tj.solergibert) Add `padding_idx` to JobConfig

    hf_ds = MultiModalDataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer=tokenizer,
        seq_len=seq_len,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    collate_fn = MultiModalCollator(
        padding_idx=padding_idx, pad_max_tiles=pad_max_tiles
    )

    return ParallelAwareDataloader(
        dataset=hf_ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
        collate_fn=collate_fn,
    )
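The one slightly subtle tokenization detail in `__iter__` above is the label masking: after shifting the token stream into (input_ids, labels) pairs, special tokens in the labels are replaced with IGNORE_INDEX so the loss skips them. Here is the same `torch.isin` / `torch.where` pattern in isolation, with made-up token ids rather than the real tokenizer's:

import torch

IGNORE_INDEX = -100
bos_id, eos_id, image_id = 0, 1, 9  # made-up special token ids

tokens = torch.LongTensor([0, 9, 9, 5, 6, 7, 1])  # bos, two image tokens, text, eos
input_ids, labels = tokens[:-1], tokens[1:].clone()

# Replace special tokens in the shifted labels so they don't contribute to the loss.
labels = torch.where(
    torch.isin(labels, torch.LongTensor([bos_id, eos_id, image_id])),
    IGNORE_INDEX,
    labels,
)
print(input_ids.tolist())  # [0, 9, 9, 5, 6, 7]
print(labels.tolist())     # [-100, -100, 5, 6, 7, -100]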
torchtitan/experiments/multimodal/tests/test_multimodal_model.py
ADDED
@@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from torchtitan.experiments.llama_multimodal import (
    ModelArgs,
    MultimodalDecoder,
    VisionEncoder,
)

from .test_utils import fixed_init_model, fixed_init_tensor


@pytest.fixture
def encoder_config():
    return ModelArgs(
        encoder_embed_dim=32,
        encoder_num_layers=2,
        encoder_num_heads=4,
        tile_size=49,
        patch_size=9,
        max_num_tiles=4,
        in_channels=3,
        return_intermediates=[0, 1],
        num_layers_projection=2,
        decoder_embed_dim=128,
    )


@pytest.fixture
def decoder_config():
    return ModelArgs(
        decoder_embed_dim=512,
        vocab_size=10000,
        fusion_interval=2,
        num_special_tokens=3,
        decoder_num_layers=6,
        decoder_num_heads=8,
        decoder_num_kv_heads=4,
        max_seq_len=512,
        rope_theta=50000.0,
    )


class TestMultimodalModelVisionEncoder:
    @pytest.fixture(autouse=True)
    def setup_class(self, encoder_config):
        self.model_args = encoder_config
        self.batch_size = 1
        self.num_imgs = 2
        self.num_tiles = 4
        self.aspect_ratio = torch.tensor([[1, 3], [2, 2]]).reshape(
            self.batch_size, self.num_imgs, 2
        )
        image = torch.rand(
            (
                self.batch_size,
                self.num_imgs,
                self.num_tiles,
                self.model_args.in_channels,
                self.model_args.tile_size,
                self.model_args.tile_size,
            )
        )
        self.image = fixed_init_tensor(image.shape, min_val=-1, max_val=1)

    def test_llama_mm_vision_encoder(self):
        model = VisionEncoder(self.model_args)
        fixed_init_model(model, min_val=-1, max_val=1)
        output = model(self.image, self.aspect_ratio)
        expected_shape = (
            self.batch_size,
            self.num_imgs * self.num_tiles * (model.vit.patches_per_tile + 1),
            self.model_args.decoder_embed_dim,
        )
        assert (
            output.shape == expected_shape
        ), f"Expected shape {expected_shape}, but got {output.shape}"

        # TODO: Need to ensure numerical stability before doing convergence test.
        # output.mean() = 3.994, we need to debug why it is not close to 5.28800, which is
        # the test value from the original torchtune test
        # assert torch.allclose(
        #     output.mean(), torch.tensor(5.28800), atol=1e-3, rtol=1e-3
        # )


class TestMultimodalModelDecoder:
    @pytest.fixture(autouse=True)
    def setup_class(self, decoder_config):
        self.model_args = decoder_config
        self.batch_size = 1
        self.decoder_embed_dim = self.model_args.decoder_embed_dim
        self.vocab_size = self.model_args.vocab_size
        self.seq_len = 128
        self.input = {
            "tokens": torch.arange(self.batch_size * self.seq_len).reshape(
                self.batch_size, self.seq_len
            ),
            "encoder_input": fixed_init_tensor(
                (self.batch_size, self.seq_len, self.decoder_embed_dim),
                min_val=-1,
                max_val=1,
            ),
            "encoder_mask": None,
        }

    @torch.no_grad()
    def test_llama_mm_decoder(self):
        model = MultimodalDecoder(self.model_args)
        fixed_init_model(model, min_val=-1, max_val=1)
        output = model(**self.input)
        expected_shape = (self.batch_size, self.seq_len, self.vocab_size)
        assert (
            output.shape == expected_shape
        ), f"Expected shape {expected_shape}, but got {output.shape}"

        # TODO: Need to ensure numerical stability before doing convergence test.
        # output.mean() = -0.0134, we need to debug why it is not close to -9.47548e-5, which is
        # the test value from the original torchtune test
        # assert torch.allclose(
        #     output.mean(), torch.tensor(-9.47548e-5), atol=1e-3, rtol=1e-3
        # )
torchtitan/experiments/multimodal/utils.py
ADDED
|
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
from collections import defaultdict
|
| 10 |
+
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import List, Optional, Set, Tuple, Union
|
| 13 |
+
from urllib import request
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torchvision
|
| 17 |
+
from torchvision.transforms.v2 import functional as F
|
| 18 |
+
|
| 19 |
+
# NOTE Copied from torchtune.modules.transforms.vision_utils.tile_crop.py
|
| 20 |
+
def tile_crop(image: torch.Tensor, tile_size: int) -> torch.Tensor:
|
| 21 |
+
"""
|
| 22 |
+
Divides a tensor into equally sized tiles. The tensor should be divisible by tile_size.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
image (torch.Tensor): Input image to crop into tiles.
|
| 26 |
+
tile_size (int): Size of each tile.
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
torch.Tensor: torch.Tensor of shape [num_tiles, channel_size, tile_size, tile_size]
|
| 30 |
+
|
| 31 |
+
Examples:
|
| 32 |
+
>>> image = torch.rand(3, 200, 300)
|
| 33 |
+
>>> tiles = tile_crop(image, tile_size=50)
|
| 34 |
+
>>> tiles.shape # 4x6 = 24 tiles
|
| 35 |
+
torch.Size([24, 3, 50, 50])
|
| 36 |
+
|
| 37 |
+
>>> image = torch.rand(3, 400, 600)
|
| 38 |
+
>>> tiles = tile_crop(image, tile_size=200)
|
| 39 |
+
>>> tiles.shape # 2x3 = 6 tiles
|
| 40 |
+
torch.Size([6, 3, 200, 200])
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
channel_size, height, width = image.shape
|
| 44 |
+
|
| 45 |
+
# assert sizes are divisible
|
| 46 |
+
assert (
|
| 47 |
+
height % tile_size == 0 and width % tile_size == 0
|
| 48 |
+
), f"Image size {height}x{width} is not divisible by tile size {tile_size}"
|
| 49 |
+
|
| 50 |
+
# Reshape to split height and width into tile_size blocks
|
| 51 |
+
tiles_height = height // tile_size
|
| 52 |
+
tiles_width = width // tile_size
|
| 53 |
+
|
| 54 |
+
reshaped = image.view(channel_size, tiles_height, tile_size, tiles_width, tile_size)
|
| 55 |
+
|
| 56 |
+
# Transpose to bring tiles together
|
| 57 |
+
# We want [tiles_height, tiles_width, channel_size, tile_size, tile_size]
|
| 58 |
+
transposed = reshaped.permute(1, 3, 0, 2, 4)
|
| 59 |
+
|
| 60 |
+
# Flatten the tiles
|
| 61 |
+
tiles = transposed.contiguous().view(
|
| 62 |
+
tiles_height * tiles_width, channel_size, tile_size, tile_size
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
return tiles
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py
|
| 69 |
+
def resize_with_pad(
|
| 70 |
+
image: torch.Tensor,
|
| 71 |
+
target_size: Tuple[int, int],
|
| 72 |
+
resample: torchvision.transforms.InterpolationMode,
|
| 73 |
+
max_size: Optional[int] = None,
|
| 74 |
+
) -> torch.Tensor:
|
| 75 |
+
"""
|
| 76 |
+
Resizes and pads an image to target_size without causing distortion.
|
| 77 |
+
The user can set max_size to limit upscaling when target_size exceeds image_size.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
image (torch.Tensor): The input image tensor in the format [..., H, W].
|
| 81 |
+
target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width].
|
| 82 |
+
resample (torchvision.transforms.InterpolationMode): Resampling method used when resizing images.
|
| 83 |
+
Supports torchvision.transforms.InterpolationMode.NEAREST, InterpolationMode.NEAREST_EXACT,
|
| 84 |
+
InterpolationMode.BILINEAR and InterpolationMode.BICUBIC.
|
| 85 |
+
max_size (Optional[int]): The maximum size to upscale the image to.
|
| 86 |
+
If None, will upscale up to target_size.
|
| 87 |
+
|
| 88 |
+
Returns:
|
| 89 |
+
torch.Tensor: The resized and padded image tensor in the format [..., H, W].
|
| 90 |
+
|
| 91 |
+
Examples:
|
| 92 |
+
|
| 93 |
+
Example 1: The image will be upscaled from (300, 800) to (448, 1194), since 448 is the limiting side,
|
| 94 |
+
and then padded from (448, 1194) to (448, 1344).
|
| 95 |
+
|
| 96 |
+
>>> max_size = None
|
| 97 |
+
>>> image = torch.rand([3, 300, 800])
|
| 98 |
+
>>> target_size = (448, 1344)
|
| 99 |
+
>>> resample = torchvision.transforms.InterpolationMode.BILINEAR
|
| 100 |
+
>>> output = resize_with_pad(image, target_size, resample, max_size)
|
| 101 |
+
|
| 102 |
+
Example 2: The image will stay as is, since 800 > 600, and then padded from (300, 800) to (448, 1344).
|
| 103 |
+
|
| 104 |
+
>>> max_size = 600
|
| 105 |
+
>>> image = torch.rand([3, 300, 800])
|
| 106 |
+
>>> target_size = (448, 1344)
|
| 107 |
+
>>> resample = torchvision.transforms.InterpolationMode.BILINEAR
|
| 108 |
+
>>> output = resize_with_pad(image, target_size, resample, max_size)
|
| 109 |
+
|
| 110 |
+
Example 3: The image will be downscaled from (500, 1000) to (224, 448),
|
| 111 |
+
and padded from (224, 448) to (448, 448).
|
| 112 |
+
|
| 113 |
+
>>> max_size = 600
|
| 114 |
+
>>> image = torch.rand([3, 500, 1000])
|
| 115 |
+
>>> target_size = (448, 488)
|
| 116 |
+
>>> resample = torchvision.transforms.InterpolationMode.BILINEAR
|
| 117 |
+
>>> output = resize_with_pad(image, target_size, resample, max_size)
|
| 118 |
+
|
| 119 |
+
"""
|
| 120 |
+
|
| 121 |
+
image_height, image_width = image.shape[-2:]
|
| 122 |
+
image_size = (image_height, image_width)
|
| 123 |
+
|
| 124 |
+
# If target_size requires upscaling, we might want to limit the upscaling to max_size
|
| 125 |
+
if max_size is not None:
|
| 126 |
+
new_target_height = min(max(image_height, max_size), target_size[0])
|
| 127 |
+
new_target_width = min(max(image_width, max_size), target_size[1])
|
| 128 |
+
target_size_resize = (new_target_height, new_target_width)
|
| 129 |
+
else:
|
| 130 |
+
target_size_resize = target_size
|
| 131 |
+
|
| 132 |
+
# resize to target_size while preserving aspect ratio
|
| 133 |
+
new_size_preserving_aspect_ratio = _get_max_res_without_distortion(
|
| 134 |
+
image_size=image_size,
|
| 135 |
+
target_size=target_size_resize,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
image = F.resize(
|
| 139 |
+
inpt=image,
|
| 140 |
+
size=list(new_size_preserving_aspect_ratio),
|
| 141 |
+
interpolation=resample,
|
| 142 |
+
antialias=True,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
image = _pad_image_top_left(image=image, target_size=target_size)
|
| 146 |
+
|
| 147 |
+
return image
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py
|
| 151 |
+
def _pad_image_top_left(
|
| 152 |
+
image: torch.Tensor,
|
| 153 |
+
target_size: Tuple[int, int],
|
| 154 |
+
) -> torch.Tensor:
|
| 155 |
+
"""
|
| 156 |
+
Places the image at the top left of the canvas and pads with 0 the right and bottom
|
| 157 |
+
to fit to the target resolution. If target_size < image_size, it will crop the image.
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
image (torch.Tensor): The input image tensor in the format [..., H, W].
|
| 161 |
+
target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width].
|
| 162 |
+
|
| 163 |
+
Returns:
|
| 164 |
+
torch.Tensor: The padded image tensor in the format [..., H, W].
|
| 165 |
+
"""
|
| 166 |
+
|
| 167 |
+
image_size = image.shape[-2:]
|
| 168 |
+
|
| 169 |
+
height, width = image_size
|
| 170 |
+
target_height, target_width = target_size
|
| 171 |
+
|
| 172 |
+
pad_x = target_width - width
|
| 173 |
+
pad_y = target_height - height
|
| 174 |
+
|
| 175 |
+
padding = [0, 0, pad_x, pad_y]
|
| 176 |
+
return F.pad(inpt=image, padding=padding)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py
|
| 180 |
+
def _get_max_res_without_distortion(
|
| 181 |
+
image_size: Tuple[int, int],
|
| 182 |
+
target_size: Tuple[int, int],
|
| 183 |
+
) -> Tuple[int, int]:
|
| 184 |
+
"""
|
| 185 |
+
Determines the maximum resolution to which an image can be resized to without distorting its
|
| 186 |
+
aspect ratio, based on the target resolution.
|
| 187 |
+
|
| 188 |
+
For example, if image_size = (200,400) and target_size = (600,800),
|
| 189 |
+
scale_h = 600/200 = 3
|
| 190 |
+
scale_w = 800/400 = 2
|
| 191 |
+
So the maximum that we can upscale without distortion is min(scale_h, scale_w) = 2
|
| 192 |
+
|
| 193 |
+
Since scale_w is the limiting side, then new_w = target_w, and new_h = old_h*scale_w
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
image_size (Tuple[int, int]): The original resolution of the image.
|
| 197 |
+
target_size (Tuple[int, int]): The desired resolution to fit the image into.
|
| 198 |
+
Returns:
|
| 199 |
+
Tuple[int, int]: The optimal dimensions to which the image should be resized.
|
| 200 |
+
Examples:
|
| 201 |
+
>>> _get_max_res_without_distortion([200, 300], target_size = (450, 200))
|
| 202 |
+
(133, 200)
|
| 203 |
+
>>> _get_max_res_without_distortion([800, 600], target_size = (450, 1300))
|
| 204 |
+
(450, 337)
|
| 205 |
+
"""
|
| 206 |
+
|
| 207 |
+
original_height, original_width = image_size
|
| 208 |
+
target_height, target_width = target_size
|
| 209 |
+
|
| 210 |
+
scale_w = target_width / original_width
|
| 211 |
+
scale_h = target_height / original_height
|
| 212 |
+
|
| 213 |
+
if scale_w < scale_h:
|
| 214 |
+
new_width = target_width
|
| 215 |
+
new_height = min(math.floor(original_height * scale_w), target_height)
|
| 216 |
+
else:
|
| 217 |
+
new_height = target_height
|
| 218 |
+
new_width = min(math.floor(original_width * scale_h), target_width)
|
| 219 |
+
|
| 220 |
+
return new_height, new_width
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py
|
| 224 |
+
def _get_factors(n: int) -> Set[int]:
|
| 225 |
+
"""
|
| 226 |
+
Calculate all factors of a given number, i.e. a divisor that leaves no remainder.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
n (int): The number to find factors for.
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
set: A set containing all factors of the number.
|
| 233 |
+
|
| 234 |
+
Examples:
|
| 235 |
+
>>> _get_factors(n=12)
|
| 236 |
+
{1, 2, 3, 4, 6, 12}
|
| 237 |
+
"""
|
| 238 |
+
factors_set = set()
|
| 239 |
+
|
| 240 |
+
for i in range(1, int(n**0.5) + 1):
|
| 241 |
+
if n % i == 0:
|
| 242 |
+
factors_set.add(i)
|
| 243 |
+
factors_set.add(n // i)
|
| 244 |
+
return factors_set
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py
def get_canvas_best_fit(
    image: torch.Tensor, possible_resolutions: torch.Tensor, resize_to_max_canvas: bool
) -> Tuple[int, int]:
    """
    Determines the best canvas possible from a list of possible resolutions to
    resize an image to, without distortion.

    For each possible resolution, calculates the scaling factors for
    width and height, and selects the smallest one, which is the limiting side.
    E.g. if to match a canvas shape you have to upscale an image's height by 2x and width by 1.5x,
    then the maximum upscaling without distortion is min(2, 1.5) = 1.5.

    If there are multiple canvases that satisfy the conditions,
    we pick the one with the lowest area to minimize padding.

    Args:
        image (torch.Tensor): The image we want to fit into a canvas.
        possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
            row represents a possible canvas.
        resize_to_max_canvas (bool): If True, pick the canvas that allows maximum scaling.
            If False, pick the canvas that minimizes downscaling, including no downscaling at all.

    Returns:
        Tuple[int, int]: The best resolution to fit the image into.

    Examples:
        >>> image = torch.rand(3, 200, 300)
        >>> possible_resolutions = torch.tensor([
        ...     [224, 672],
        ...     [672, 224],
        ...     [224, 448],
        ...     [448, 224],
        ...     [224, 224]
        ... ])
        >>> get_canvas_best_fit(image, possible_resolutions, resize_to_max_canvas=False)
        (224, 448)

        In the example above, we calculate the scaling factors for each possible resolution

        >>> scale_height = torch.tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
        >>> scale_width = torch.tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
        >>> scales = torch.tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])

        Two options have scaling_factor > 1; since resize_to_max_canvas is False, we pick the smallest

        >>> upscaling_options = torch.tensor([1.1200, 1.1200])
        >>> selected_scale = torch.tensor(1.1200)

        There are two possible options, so we pick the one with the smallest area

        >>> areas = torch.tensor([150528, 100352])  # for resolutions [224, 672] and [224, 448], respectively
        >>> optimal_canvas = torch.tensor([224, 448])  # resolution with the smallest area
    """

    original_height, original_width = image.shape[-2:]

    # possible resolutions heights/widths
    target_heights, target_widths = (
        possible_resolutions[:, 0],
        possible_resolutions[:, 1],
    )

    # scaling factors to resize the image without distortion
    scale_w = target_widths / original_width
    scale_h = target_heights / original_height

    # get limiting side scaling -> no distortion
    scales = torch.where(scale_w > scale_h, scale_h, scale_w)

    # filter only scales that allow upscaling
    upscaling_options = scales[scales >= 1]
    if len(upscaling_options) > 0:
        if resize_to_max_canvas:
            selected_scale = torch.max(upscaling_options)
        else:
            selected_scale = torch.min(upscaling_options)
    else:
        # no upscaling possible,
        # get the minimum downscaling (max scale for scales < 1)
        downscaling_options = scales[scales < 1]
        selected_scale = torch.max(downscaling_options)

    # get all resolutions that support this scaling factor,
    # e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
    chosen_canvas = possible_resolutions[scales == selected_scale]

    # if there are multiple resolutions,
    # get the one with minimum area to reduce padding
    if len(chosen_canvas) > 1:
        areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
        optimal_idx = torch.argmin(areas)
        optimal_canvas = chosen_canvas[optimal_idx]
    else:
        optimal_canvas = chosen_canvas[0]

    return tuple(optimal_canvas.tolist())

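As a quick usage sketch, this re-runs the docstring example above as a standalone snippet (it assumes `get_canvas_best_fit` from this file is in scope):

```python
# Re-runs the docstring example above; expects (224, 448) back.
import torch

image = torch.rand(3, 200, 300)  # (C, H, W)
possible_resolutions = torch.tensor(
    [[224, 672], [672, 224], [224, 448], [448, 224], [224, 224]]
)
best_fit = get_canvas_best_fit(image, possible_resolutions, resize_to_max_canvas=False)
print(best_fit)  # (224, 448): smallest-area canvas among those needing the least upscaling
```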
# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py
def find_supported_resolutions(
    max_num_tiles: int, tile_size: int
) -> List[Tuple[int, int]]:
    """
    Computes all combinations of resolutions, multiple of tile_size,
    that contain up to max_num_tiles. Useful for when dividing an image into tiles.

    For example, if we want at most 2 tiles per image, then we can support the
    following resolutions: (1x1, 1x2, 2x1) * tile_size

    Args:
        max_num_tiles (int): Maximum number of tiles.
        tile_size (int): Size of the side of the tile.

    Returns:
        List[Tuple[int, int]]: List of possible resolutions as tuples (height, width).

    Examples:

        >>> max_num_tiles = 4
        >>> tile_size = 224
        >>> find_supported_resolutions(max_num_tiles, tile_size)
        [(224, 896), (448, 448), (224, 224), (896, 224), (224, 672), (672, 224), (224, 448), (448, 224)]
    """

    # create dictionary {aspect_ratio: [resolution1, ..., resolution n]}
    # example {0.25: [(1,4)], 1.0: [(2,2), (1,1)], 4.0: [(4,1)]}
    asp_dict = defaultdict(list)
    for _tile_size in range(max_num_tiles, 0, -1):
        factors = sorted(_get_factors(_tile_size))
        asp_ratios = [(factor, _tile_size // factor) for factor in factors]
        for height, width in asp_ratios:
            ratio_float = height / width
            asp_dict[ratio_float].append((height, width))

    # get the resolutions multiplied by the tile_size
    possible_resolutions = []
    for ar, resolution in asp_dict.items():
        for height, width in resolution:
            possible_resolutions.append((height * tile_size, width * tile_size))

    return possible_resolutions

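A short usage sketch mirroring the docstring example (assumes `find_supported_resolutions` and `_get_factors` from this file are in scope):

```python
# Enumerate all tile-grid canvases with at most 4 tiles of 224x224.
resolutions = find_supported_resolutions(max_num_tiles=4, tile_size=224)
# Each entry is (height, width): (448, 448) is a 2x2 grid, (224, 896) a 1x4 grid, etc.
assert (448, 448) in resolutions
assert (224, 896) in resolutions
```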
# NOTE Copied from torchtune.data._utils.py
def load_image(image_loc: Union[Path, str]) -> torch.Tensor:
    """
    Convenience method to load an image in torch.Tensor format from a local file path or remote source.

    Args:
        image_loc (Union[Path, str]): Local file path or remote source pointing to the image
            which will be loaded in PIL format.

    Note:
        If loading an image from a remote source, the function expects the URL provided in ``image_loc``
        to start with "http" or "https" e.g. "https://www.wikipedia.org/en/bird.jpg".

    Raises:
        ValueError: If the image cannot be loaded from remote source, **or**
            if the image cannot be opened as a :class:`~torch.Tensor`.

    Examples:
        >>> # Load from remote source
        >>> image = load_image("https://www.wikipedia.org/en/bird.jpg")

        >>> # Load from local file path
        >>> image = load_image(Path("/home/user/bird.jpg"))

    Returns:
        torch.Tensor: The loaded image.
    """

    # If pointing to remote source, try to load to local
    if isinstance(image_loc, str) and image_loc.startswith("http"):
        try:
            image_loc = request.urlopen(image_loc).read()
            image = torchvision.io.decode_image(
                torch.frombuffer(image_loc, dtype=torch.uint8),
                mode="RGB",
            )
        except Exception as e:
            raise ValueError("Failed to load remote image as torch.Tensor") from e

    # Open the local image as a Tensor image
    else:
        try:
            image = torchvision.io.decode_image(image_loc, mode="RGB")
        except Exception as e:
            raise ValueError("Failed to load local image as torch.Tensor") from e

    return image

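A round-trip sketch for the local-path branch; it assumes a torchvision version whose `decode_image` accepts a file path (the same assumption the helper itself makes) and uses a throwaway path:

```python
# Write a random uint8 image to a temporary file, then load it back via load_image.
import torch
import torchvision

fake = (torch.rand(3, 64, 64) * 255).to(torch.uint8)
torchvision.io.write_png(fake, "/tmp/fake_bird.png")  # illustrative path

img = load_image("/tmp/fake_bird.png")
print(img.shape, img.dtype)  # torch.Size([3, 64, 64]) torch.uint8
```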
torchtitan/experiments/simple_fsdp/__init__.py
ADDED
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
from torchtitan.models.llama3 import llama3_configs, pipeline_llama
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .model import SimpleFSDPTransformer
from .parallelize_llama import parallelize_llama

register_train_spec(
    TrainSpec(
        name="llama3_simple_fsdp",
        cls=SimpleFSDPTransformer,
        config=llama3_configs,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_tiktoken_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)
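Once this package is imported, the spec should be retrievable from the registry by name. A sketch, assuming the registry exposes a lookup such as `get_train_spec` (the accessor name is an assumption here):

```python
# Importing the package runs register_train_spec(...) above.
import torchtitan.experiments.simple_fsdp  # noqa: F401

# `get_train_spec` is an assumed accessor on the train-spec registry.
from torchtitan.protocols.train_spec import get_train_spec

spec = get_train_spec("llama3_simple_fsdp")
print(spec.cls.__name__)  # SimpleFSDPTransformer
```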
torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-311.pyc
ADDED
Binary file (2.67 kB)

torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-311.pyc
ADDED
Binary file (7.22 kB)

torchtitan/experiments/simple_fsdp/model.py
ADDED
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.models.llama3 import Transformer, TransformerModelArgs

from .simple_fsdp import disable_data_parallel


class SimpleFSDPTransformer(Transformer):
    def __init__(self, model_args: TransformerModelArgs):
        super().__init__(model_args)
        self.init_weights()

    def init_weights(self, *args, **kwargs):
        with disable_data_parallel():
            super().init_weights(*args, **kwargs)
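A minimal construction sketch, assuming one of the registered llama3 flavors (the `"debugmodel"` name is an assumption):

```python
# Build the SimpleFSDP variant from a registered llama3 config.
from torchtitan.models.llama3 import llama3_configs
from torchtitan.experiments.simple_fsdp.model import SimpleFSDPTransformer

model_args = llama3_configs["debugmodel"]  # assumed flavor name; any registered flavor works
model = SimpleFSDPTransformer(model_args)  # __init__ calls init_weights() with data parallel disabled
```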
torchtitan/experiments/simple_fsdp/parallelize_llama.py
ADDED
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from torch.distributed import DeviceMesh

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.models.llama3.parallelize_llama import apply_ac
from torchtitan.tools.logging import logger

from .simple_fsdp import data_parallel, MixedPrecisionPolicy


def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """
    # TODO(ruisizhang123): Add support for TP (on-going)
    # if parallel_dims.tp_enabled:
    #     if (
    #         job_config.parallelism.enable_async_tensor_parallel
    #         and not job_config.training.compile
    #     ):
    #         raise RuntimeError("Async TP requires --training.compile")

    #     enable_float8_linear = "float8" in job_config.model.converters
    #     float8_is_rowwise = job_config.float8.recipe_name in (
    #         "rowwise",
    #         "rowwise_with_gw_hp",
    #     )

    #     # For now, float8 all-gather with TP is only supported for tensorwise
    #     # float8 scaling recipes. For rowwise recipes, we use regular TP and
    #     # all-gather happens in high precision.
    #     enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise

    #     apply_tp(
    #         model,
    #         world_mesh["tp"],
    #         loss_parallel=parallel_dims.loss_parallel_enabled,
    #         enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
    #         enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
    #     )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    # apply data parallel
    if (
        parallel_dims.dp_replicate_enabled
        or parallel_dims.dp_shard_enabled
        or parallel_dims.cp_enabled
    ):
        if parallel_dims.dp_replicate_enabled:
            if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
                dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
                dp_mode = "hybrid_shard"
            else:
                dp_mesh_dim_names = ("dp_replicate",)
                dp_mode = "replicate"
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)
            dp_mode = "fully_shard"

        mp_policy = MixedPrecisionPolicy(
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
        )

        model = data_parallel(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            mode=dp_mode,
            ac_mode=job_config.activation_checkpoint.mode,
            mp_policy=mp_policy,
        )
        logger.info("Applied Data Parallel (dp mode=%s) to the model", dp_mode)

    if job_config.training.compile:
        torch._inductor.config.reorder_for_peak_memory = False
        model = torch.compile(model, fullgraph=True)

    return model
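The data-parallel branch above boils down to a three-way decision table; here is a standalone sketch of just that selection logic, with plain booleans standing in for `ParallelDims` (names illustrative):

```python
# Mirrors the mesh-dim / dp-mode selection above; only reached when at least one
# of the three flags is enabled in the real code.
def pick_dp_mode(dp_replicate: bool, dp_shard: bool, cp: bool):
    if dp_replicate:
        if dp_shard or cp:
            return ("dp_replicate", "dp_shard_cp"), "hybrid_shard"  # HSDP-style
        return ("dp_replicate",), "replicate"                       # pure replication
    return ("dp_shard_cp",), "fully_shard"                          # FSDP-style

assert pick_dp_mode(True, True, False)[1] == "hybrid_shard"
assert pick_dp_mode(True, False, False)[1] == "replicate"
assert pick_dp_mode(False, True, True)[1] == "fully_shard"
```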
torchtitan/models/llama3/train_configs/llama3_8b.toml
ADDED
@@ -0,0 +1,63 @@
# torchtitan Config.toml
# NOTE: this toml config is a preset for 64 A100 GPUs.

[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
tokenizer_path = "./assets/tokenizer/original/tokenizer.model"
# converters = "float8"

[optimizer]
name = "AdamW"
lr = 3e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 200  # lr scheduler warm up

[training]
batch_size = 1
seq_len = 8192
max_norm = 1.0  # grad norm clipping
steps = 1000
compile = false
dataset = "c4"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
pipeline_parallel_degree = 1
context_parallel_degree = 1

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = "output"
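The config is plain TOML, so it can be inspected with the standard library; a small sketch (the file path is illustrative):

```python
# Read a few fields from a config like the one above.
import tomllib

with open("torchtitan/models/llama3/train_configs/llama3_8b.toml", "rb") as f:
    cfg = tomllib.load(f)

print(cfg["model"]["flavor"])                            # "8B"
print(cfg["training"]["seq_len"])                        # 8192
print(cfg["parallelism"]["data_parallel_shard_degree"])  # -1 (typically: infer from world size)
```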
torchtitan/models/norms.py
ADDED
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn


def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
    """
    Builds the specified normalization layer based on the norm_type.

    Args:
        norm_type (str): The type of normalization layer to build.
            Supported types: layernorm, np_layernorm, rmsnorm
        dim (int): The dimension of the normalization layer.
        eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.

    Returns:
        The built normalization layer.

    Raises:
        NotImplementedError: If an unknown norm_type is provided.
    """
    norm_type = norm_type.lower()  # Normalize to lowercase

    if norm_type == "layernorm":
        return nn.LayerNorm(dim, eps=eps, bias=False)
    elif norm_type == "np_layernorm":
        return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
    elif norm_type == "rmsnorm":
        return nn.RMSNorm(dim, eps=eps)
    else:
        raise NotImplementedError(f"Unknown norm_type: '{norm_type}'")
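A short usage sketch for `build_norm` (dimensions are illustrative):

```python
# Build an RMSNorm layer and apply it to a dummy activation.
import torch
from torchtitan.models.norms import build_norm

norm = build_norm("rmsnorm", dim=4096, eps=1e-6)
x = torch.randn(2, 16, 4096)
print(type(norm).__name__, norm(x).shape)  # RMSNorm torch.Size([2, 16, 4096])
```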
torchtitan/protocols/__pycache__/model_converter.cpython-311.pyc
ADDED
Binary file (5.05 kB)

torchtitan/tools/utils.py
ADDED
@@ -0,0 +1,143 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import gc
import subprocess
import time
from dataclasses import dataclass
from typing import Optional

import torch
from torch._utils import _get_available_device_type, _get_device_module

from torchtitan.tools.logging import logger


def get_device_info():
    device_type = _get_available_device_type()
    if device_type is None:
        device_type = "cuda"  # default device_type: cuda
    device_module = _get_device_module(device_type)  # default device_module: torch.cuda
    return device_type, device_module


device_type, device_module = get_device_info()


# used to avoid stragglers in garbage collection
class GarbageCollection:
    def __init__(self, gc_freq=1000):
        assert gc_freq > 0, "gc_freq must be a positive integer"
        self.gc_freq = gc_freq
        gc.disable()
        self.collect("Initial GC collection.")

    def run(self, step_count):
        if step_count > 1 and step_count % self.gc_freq == 0:
            self.collect("Performing periodical GC collection.")

    @staticmethod
    def collect(reason: str):
        begin = time.monotonic()
        gc.collect(1)
        logger.info("[GC] %s %.2f seconds.", reason, time.monotonic() - begin)

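A sketch of how the helper above is typically driven from a training loop (the loop body is a placeholder):

```python
# Disable automatic GC up front, then collect explicitly every gc_freq steps,
# so all ranks pay the collection cost at the same step (the "stragglers" note above).
gc_handler = GarbageCollection(gc_freq=1000)

for step in range(1, 10_001):
    # ... forward / backward / optimizer.step() ...
    gc_handler.run(step)
```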
# hardcoded BF16 type peak flops for NVIDIA A100, H100, H200 GPU and AMD MI250, MI300X, AMD MI325X and Intel PVC
def get_peak_flops(device_name: str) -> int:
    try:
        # Run the lspci command and capture the output
        result = subprocess.run(["lspci"], stdout=subprocess.PIPE, text=True)
        # Filter the output for lines containing both "NVIDIA" and "H100"
        filtered_lines = [
            line
            for line in result.stdout.splitlines()
            if "NVIDIA" in line and "H100" in line
        ]
        # Join all filtered lines into a single string
        device_name = " ".join(filtered_lines) or device_name
    except FileNotFoundError as e:
        logger.warning(f"Error running lspci: {e}, fallback to use device_name")
    if "A100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/a100/
        return 312e12
    elif "H100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/h100/
        # NOTE: Specifications are one-half lower without sparsity.
        if "NVL" in device_name:
            return 835e12
        elif "PCIe" in device_name:
            return 756e12
        else:  # for H100 SXM and other variants
            return 989e12
    elif "H200" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/h200/
        return 989e12
    elif "MI300X" in device_name or "MI325X" in device_name:
        # MI300X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html
        # MI325X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi325x.html
        return 1300e12
    elif "MI250X" in device_name:
        # data from https://www.amd.com/en/products/accelerators/instinct/mi200/mi250x.html (per GCD)
        return 191.5e12
    elif "Data Center GPU Max 1550" in device_name:
        # Also known as Ponte Vecchio (PVC).
        # data from https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
        # Dot Product Accumulate Systolic (DPAS):
        # - Freq: 1300MHz
        # - #ops: 512
        # Full EU mode (i.e. 512 max compute units): 340.8 TFLOPS (BF16)
        # Standard EU mode (i.e. 448 max compute units): 298.2 TFLOPS (BF16)
        max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
        return 512 * max_comp_units * 1300 * 10**6
    else:  # for other GPU types, assume A100
        logger.warning(f"Peak flops undefined for: {device_name}, fallback to A100")
        return 312e12
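`get_peak_flops` is the denominator of MFU-style throughput metrics; a minimal sketch of that calculation, where the flops-per-token and throughput figures are placeholders rather than measured values:

```python
# Rough MFU calculation using the peak-flops lookup above (all numbers illustrative).
peak_flops = get_peak_flops(torch.cuda.get_device_name())  # e.g. 989e12 on H100 SXM

num_flop_per_token = 6 * 8e9   # ~6 * params for fwd+bwd; attention term omitted (assumption)
tokens_per_second = 5000.0     # measured by the caller in practice

mfu = 100 * num_flop_per_token * tokens_per_second / peak_flops
print(f"MFU: {mfu:.2f}%")
```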

@dataclass(frozen=True)
class Color:
    black = "\033[30m"
    red = "\033[31m"
    green = "\033[32m"
    yellow = "\033[33m"
    blue = "\033[34m"
    magenta = "\033[35m"
    cyan = "\033[36m"
    white = "\033[37m"
    reset = "\033[39m"


@dataclass(frozen=True)
class NoColor:
    black = ""
    red = ""
    green = ""
    yellow = ""
    blue = ""
    magenta = ""
    cyan = ""
    white = ""
    reset = ""


def check_if_feature_in_pytorch(
    feature_name: str,
    pull_request: str,
    min_nightly_version: Optional[str] = None,
) -> None:
    if "git" in torch.__version__:  # pytorch is built from source
        # notify users to check if the pull request is included in their pytorch
        logger.warning(
            "detected that the pytorch is built from source. Please make sure the PR "
            f"({pull_request}) is included in pytorch for correct {feature_name}."
        )
    elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
        logger.warning(
            f"detected that the pytorch version {torch.__version__} is older than "
            f"{min_nightly_version}. Please upgrade to a newer version to include the "
            f"change in ({pull_request}) for correct {feature_name}."
        )
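A usage sketch for the version check above (the feature name, PR link, and nightly version string are illustrative):

```python
# Warn if the running PyTorch may predate a required upstream change.
check_if_feature_in_pytorch(
    feature_name="async tensor parallel",
    pull_request="https://github.com/pytorch/pytorch/pull/123456",  # illustrative link
    min_nightly_version="2.5.0.dev20240601",                        # illustrative version
)
# Logs a warning for source builds, or when torch.__version__ < min_nightly_version.
```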