Erland committed on
Commit 7e92010 · verified · 1 Parent(s): 372fade

Add files using upload-large-folder tool

Files changed (50)
  1. README.md +512 -0
  2. torchtitan/components/__pycache__/ft.cpython-311.pyc +0 -0
  3. torchtitan/components/optimizer.py +303 -0
  4. torchtitan/distributed/__pycache__/pipeline.cpython-311.pyc +0 -0
  5. torchtitan/distributed/pipeline.py +201 -0
  6. torchtitan/experiments/deepseek_v3/checkpoint.py +154 -0
  7. torchtitan/experiments/deepseek_v3/download.py +70 -0
  8. torchtitan/experiments/deepseek_v3/generate.py +308 -0
  9. torchtitan/experiments/deepseek_v3/inference.sh +15 -0
  10. torchtitan/experiments/deepseek_v3/model.py +1325 -0
  11. torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py +159 -0
  12. torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py +260 -0
  13. torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py +63 -0
  14. torchtitan/experiments/deepseek_v3/train.py +142 -0
  15. torchtitan/experiments/flux/README.md +23 -0
  16. torchtitan/experiments/flux/dataset/flux_dataset.py +267 -0
  17. torchtitan/experiments/flux/model/autoencoder.py +388 -0
  18. torchtitan/experiments/flux/model/hf_embedder.py +40 -0
  19. torchtitan/experiments/flux/model/layers.py +286 -0
  20. torchtitan/experiments/flux/model/model.py +177 -0
  21. torchtitan/experiments/flux/parallelize_flux.py +26 -0
  22. torchtitan/experiments/flux/requirements.txt +2 -0
  23. torchtitan/experiments/flux/tests/test_flux_dataloader.py +103 -0
  24. torchtitan/experiments/flux/tests/test_generate_image.py +252 -0
  25. torchtitan/experiments/flux/train_configs/debug_model.toml +68 -0
  26. torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py +630 -0
  27. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/__init__.py +13 -0
  28. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py +240 -0
  29. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py +82 -0
  30. torchtitan/experiments/llama4/__pycache__/__init__.cpython-311.pyc +0 -0
  31. torchtitan/experiments/llama4/infra/expert_parallel.py +145 -0
  32. torchtitan/experiments/llama4/infra/parallelize_llama.py +159 -0
  33. torchtitan/experiments/llama4/model/__pycache__/model.cpython-311.pyc +0 -0
  34. torchtitan/experiments/llama4/model/model.py +466 -0
  35. torchtitan/experiments/llama4/model/moe.py +228 -0
  36. torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh +25 -0
  37. torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml +65 -0
  38. torchtitan/experiments/multimodal/mm_collator.py +227 -0
  39. torchtitan/experiments/multimodal/mm_dataset.py +268 -0
  40. torchtitan/experiments/multimodal/tests/test_multimodal_model.py +128 -0
  41. torchtitan/experiments/multimodal/utils.py +437 -0
  42. torchtitan/experiments/simple_fsdp/__init__.py +33 -0
  43. torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-311.pyc +0 -0
  44. torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-311.pyc +0 -0
  45. torchtitan/experiments/simple_fsdp/model.py +18 -0
  46. torchtitan/experiments/simple_fsdp/parallelize_llama.py +98 -0
  47. torchtitan/models/llama3/train_configs/llama3_8b.toml +63 -0
  48. torchtitan/models/norms.py +35 -0
  49. torchtitan/protocols/__pycache__/model_converter.cpython-311.pyc +0 -0
  50. torchtitan/tools/utils.py +143 -0
README.md ADDED
@@ -0,0 +1,512 @@
1
+ <div align="center">
2
+
3
+ # 🔥 Flame: Flash Linear Attention Made Easy
4
+ # This is a fork for the paper:
5
+ # Softpick: No Attention Sink, No Massive Activations with Rectified Softmax
6
+
7
+ </div>
8
+
9
+ ## Instructions for Softpick Attention
10
+
11
+ This fork only works with an older commit of torchtitan and flame, so the setup looks like this:
12
+
13
+ ```bash
14
+ git clone https://github.com/zaydzuhri/flame.git
15
+ cd flame
16
+ git checkout softpick-attention
17
+ git submodule update --init --recursive --remote
18
+ cd 3rdparty/torchtitan
19
+ git checkout 4f532e0
20
+ cd ../../
21
+
22
+ pip install .
23
+ pip install flash-attn --no-build-isolation
24
+ ```
25
+ The flash-linear-attention submodule has been changed to point to our fork (https://github.com/zaydzuhri/flash-linear-attention/tree/softpick-attention),
26
+ so there is no need to clone it manually.
27
+
28
+ Then prepare the fineweb-edu sample-100BT dataset the same way as described in the flame guide below.
29
+
30
+ These are the training commands used in the paper:
31
+ ```bash
32
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/vanilla.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine --model.config configs/vanilla_transformer_340M.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 3e-4 --lr_scheduler.warmup_steps 1000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 16 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 1 --training.steps 100000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name "zaydzuhri/vanilla-340M-4096-batch16-steps100000" --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
33
+
34
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/softpick.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine --model.config configs/softpick_transformer_340M.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 3e-4 --lr_scheduler.warmup_steps 1000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 16 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 1 --training.steps 100000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name "zaydzuhri/softpick-340M-4096-batch16-steps100000" --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
35
+ ```
36
+
37
+ The same goes for the extra experiments in the appendix:
38
+ ```bash
39
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/rectified.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine --model.config configs/rectified_transformer_340M.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 3e-4 --lr_scheduler.warmup_steps 1000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 16 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 1 --training.steps 100000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name "zaydzuhri/rectified-340M-4096-batch16-steps100000" --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
40
+
41
+ NGPU=8 bash train.sh --job.config_file flame/models/fla.toml --job.dump_folder exp/softpick.scaled.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine --model.config configs/softpick_scaled_transformer_340M.json --model.tokenizer_path fla-hub/transformer-1.3B-100B --optimizer.name AdamW --optimizer.eps 1e-15 --optimizer.lr 3e-4 --lr_scheduler.warmup_steps 1000 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine --training.batch_size 16 --training.seq_len 4096 --training.context_len 4096 --training.gradient_accumulation_steps 1 --training.steps 100000 --training.max_norm 1.0 --training.skip_nan_inf --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT --training.dataset_split train --training.num_workers 32 --training.prefetch_factor 2 --training.seed 79 --training.compile --checkpoint.interval 10000 --checkpoint.load_step -1 --metrics.log_freq 5 --checkpoint.hf_upload_enabled --checkpoint.hf_repo_base_name "zaydzuhri/softpick-scaled-340M-4096-batch16-steps100000" --comm.init_timeout_seconds 600 --comm.train_timeout_seconds 300
42
+ ```
43
+
44
+ Feel free to DM @zmkzmkz on X for any questions regarding the paper or this code!
45
+
46
+ ## Flame
47
+
48
+ Welcome to 🔥 `flame`, a minimal and efficient framework built on `torchtitan` for training Flash Linear Attention (FLA) models (and more broadly, arbitrary autoregressive language models) with blazing efficiency.
49
+
50
+ **Feature Highlights:**
51
+
52
+ - 🚀 Minimal, easy-to-use, extensible training framework
53
+ - 🤗 Seamless integration with `fla` and `transformers`
54
+ - 🔄 Zero-cost data preprocessing: online tokenization, dataset shuffling, and multiple datasets support
55
+ - 🔮 4D parallelism (coming soon)
56
+
57
+ ## Setup
58
+
59
+ To get started, clone the `flame` repository and install the required dependencies:
60
+
61
+ ```bash
62
+ git clone https://github.com/fla-org/flame.git
63
+ cd flame
64
+ pip install .
65
+ ```
66
+
67
+ `flame` manages minimal dependencies, only including `fla` and `torchtitan` as submodules.
68
+ After installation, initialize and update the submodules:
69
+ ```sh
70
+ git submodule update --init --recursive
71
+ ```
72
+
73
+ ## Dataset Preparation
74
+ To download the dataset to your local disk, create a new Python file with the following content and execute it:
75
+
76
+ ```py
77
+ from datasets import load_dataset
78
+
79
+ # load fineweb-edu with parallel processing
80
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", num_proc=64, cache_dir="/your/cache/path")
81
+
82
+ # or load a subset with roughly 100B tokens, suitable for small- or medium-sized experiments
83
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_proc=64, cache_dir="/your/cache/path")
84
+ ```
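+
+ Once the download finishes, point training at the local copy rather than streaming. The relevant flags to change look like this (the exact path depends on your `cache_dir`; the softpick commands above use the default HF cache layout):
+
+ ```sh
+ --training.dataset ~/.cache/HuggingFaceFW___fineweb-edu/sample-100BT \
+ --training.dataset_split train \
+ ```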
85
+
86
+ ## Training Recipes
87
+
88
+ Here's an example of training a 340M FLA Transformer model with a LLaMA-like architecture from scratch on a 100BT subset of the Fineweb-edu corpus in streaming mode.
89
+
90
+ > [!WARNING]
91
+ > If the dataset is not downloaded beforehand, the streaming mode will attempt to fetch it from a remote server and download it on-the-fly, which can be highly unstable during training due to network issues.
92
+ > For stable training, ensure the dataset is downloaded locally (see [**Dataset Preparation**](#dataset-preparation)); otherwise, we assume you are only experimenting with a new corpus.
93
+
94
+ ```sh
95
+ bash train.sh \
96
+ --job.config_file flame/models/fla.toml \
97
+ --job.dump_folder exp/transformer-340M-4K-10B/batch1.seqlen65536.context4096.warmup1024.update1.steps20480.lr3e-4.cosine \
98
+ --model.config configs/transformer_340M.json \
99
+ --model.tokenizer_path fla-hub/transformer-1.3B-100B \
100
+ --optimizer.name AdamW \
101
+ --optimizer.eps 1e-15 \
102
+ --optimizer.lr 3e-4 \
103
+ --lr_scheduler.warmup_steps 1024 \
104
+ --lr_scheduler.lr_min 0.1 \
105
+ --lr_scheduler.decay_type cosine \
106
+ --training.batch_size 1 \
107
+ --training.seq_len 65536 \
108
+ --training.context_len 4096 \
109
+ --training.varlen \
110
+ --training.gradient_accumulation_steps 1 \
111
+ --training.steps 20480 \
112
+ --training.max_norm 1.0 \
113
+ --training.skip_nan_inf \
114
+ --training.dataset HuggingFaceFW/fineweb-edu \
115
+ --training.dataset_name sample-100BT \
116
+ --training.dataset_split train \
117
+ --training.streaming \
118
+ --training.num_workers 32 \
119
+ --training.prefetch_factor 2 \
120
+ --training.seed 42 \
121
+ --training.compile \
122
+ --checkpoint.interval 2048 \
123
+ --checkpoint.load_step -1 \
124
+ --checkpoint.keep_latest_k 2 \
125
+ --metrics.log_freq 1
126
+ ```
127
+
128
+ You can specify the number of GPUs by setting the environment variable `NGPU`, which defaults to 8.
129
+ **For single-GPU debugging, set `NGPU=1`.**
130
+
131
+ We provide several [config files](https://github.com/fla-org/flame/tree/main/configs) for different models.
132
+ By default, the learning rate is set to 3e-4 with a cosine scheduler. Other schedulers, such as WSD (wsd), are also supported.
133
+
134
+ **Key parameters:**
135
+ - `--lr_scheduler.decay_ratio`: The proportion of the steps allocated to the decay phase. The learning rate will remain stable after the warmup period and only start decaying during the last `decay_ratio` portion of the total training steps, which is known as the Warmup-Stable-Decay (WSD) schedule.
136
+ - `--lr_scheduler.warmup_steps`: The number of steps for the learning rate warmup phase.
137
+ - `--training.steps`: Total number of training steps.
138
+ - `--training.batch_size`: Batch size per device, must be 1 if `--training.varlen` is set.
139
+ - `--training.seq_len`: The length of each sequence in the batch, which is concatenated from multiple samples.
140
+ - `--training.context_len`: The max allowed length of a sample. For non-varlen mode, this is equivalent to `seq_len`.
141
+ - `--training.varlen`: Whether to conduct variable-length sequence training.
142
+ - `--training.gradient_accumulation_steps`: Number of gradient accumulation steps.
143
+
144
+ > [!WARNING]
145
+ > The total number of sequences processed per optimizer step, referred to as `global_batch_size`, is calculated as `batch_size × gradient_accumulation_steps × num_gpus`.
146
+ > Each step processes `global_batch_size * seq_len` tokens.
147
+ > Monitor the values of `global_batch_size`, `warmup_steps`, and `steps` carefully when modifying any of these hyperparameters!
148
+
149
+ For a detailed explanation of all parameters, run:
150
+
151
+ ```sh
152
+ bash train.sh -h
153
+ ```
154
+
155
+ <details>
156
+ <summary>Usage</summary>
157
+
158
+ ```text
159
+ options:
160
+ -h, --help show this help message and exit
161
+ --job.config_file JOB.CONFIG_FILE
162
+ Job config file
163
+ --job.dump_folder JOB.DUMP_FOLDER
164
+ Folder to dump job outputs
165
+ --job.description JOB.DESCRIPTION
166
+ Description of the job
167
+ --job.use_for_integration_test
168
+ Add this config to the integration test suite
169
+ --job.print_args Print the args to terminal
170
+ --model.config MODEL.CONFIG
171
+ Path to the model config
172
+ --model.norm_type MODEL.NORM_TYPE
173
+ Type of layer normalization to use [layernorm,
174
+ np_layernorm, rmsnorm, fused_rmsnorm]
175
+ --model.tokenizer_path MODEL.TOKENIZER_PATH
176
+ Tokenizer path
177
+ --profiling.enable_profiling
178
+ Whether to enable pytorch profiler
179
+ --profiling.save_traces_folder PROFILING.SAVE_TRACES_FOLDER
180
+ Trace files location
181
+ --profiling.profile_freq PROFILING.PROFILE_FREQ
182
+ How often to collect profiler traces, in iterations
183
+ --profiling.enable_memory_snapshot
184
+ Whether to dump memory snapshot
185
+ --profiling.save_memory_snapshot_folder PROFILING.SAVE_MEMORY_SNAPSHOT_FOLDER
186
+ Memory snapshot files location
187
+ --optimizer.name OPTIMIZER.NAME
188
+ Optimizer to use
189
+ --optimizer.eps OPTIMIZER.EPS
190
+ Epsilon value for the optimizer.
191
+ --optimizer.fused Whether the fused implementation (CUDA only) is used.
192
+ --optimizer.scheduler {wsd,cosine,linear}
193
+ Scheduler to use. Currently supported: wsd, cosine,
194
+ and linear.
195
+ --optimizer.lr OPTIMIZER.LR
196
+ Learning rate to use
197
+ --optimizer.min_lr_ratio OPTIMIZER.MIN_LR_RATIO
198
+ Min lr ratio for lr scheduler
199
+ --optimizer.early_step_in_backward
200
+ Whether to apply optimizer in the backward. Caution,
201
+ optimizer_in_backward is not compatible with gradient
202
+ clipping, users should not call
203
+ register_post_accumulate_grad_hook after the optimizer
204
+ is built.
205
+ --training.batch_size TRAINING.BATCH_SIZE
206
+ Batch size
207
+ --training.seq_len TRAINING.SEQ_LEN
208
+ Sequence length
209
+ --training.context_len TRAINING.CONTEXT_LEN
210
+ Max length allowed for each sequence
211
+ --training.varlen Whether to take sequences of variable length as input
212
+ --training.warmup_steps TRAINING.WARMUP_STEPS
213
+ Steps for lr scheduler warmup, normally 1/5 of
214
+ --training.steps
215
+ --training.gradient_accumulation_steps TRAINING.GRADIENT_ACCUMULATION_STEPS
216
+ Number of steps to accumulate gradients before
217
+ updating parameters
218
+ --training.steps TRAINING.STEPS
219
+ How many train steps to run
220
+ --training.max_norm TRAINING.MAX_NORM
221
+ Max norm for gradient clipping
222
+ --training.skip_nan_inf
223
+ Skip batch updates when NaN or INF gradients are
224
+ encountered during training
225
+ --training.dataset TRAINING.DATASET
226
+ Dataset to use, with comma separated values
227
+ --training.dataset_name TRAINING.DATASET_NAME
228
+ The name of the dataset config, with comma separated
229
+ values if provided
230
+ --training.dataset_split TRAINING.DATASET_SPLIT
231
+ Dataset split to use, with comma separated values if
232
+ provided
233
+ --training.data_dir TRAINING.DATA_DIR
234
+ Data dirs to use, with comma separated values if
235
+ provided
236
+ --training.data_files TRAINING.DATA_FILES
237
+ Data files to use, with comma separated values if
238
+ provided
239
+ --training.data_probs TRAINING.DATA_PROBS
240
+ Data sampling probabilities, with comma separated
241
+ values if provided
242
+ --training.streaming Whether to load dataset in streaming mode, used for
243
+ huge dataset
244
+ --training.num_workers TRAINING.NUM_WORKERS
245
+ Number of subprocesses to use for data loading. 0
246
+ means that the data will be loaded in the main
247
+ process.
248
+ --training.prefetch_factor TRAINING.PREFETCH_FACTOR
249
+ Number of batches loaded in advance by each worker. 2
250
+ means there will be a total of 2 * num_workers batches
251
+ prefetched across all workers.
252
+ --training.data_parallel_replicate_degree TRAINING.DATA_PARALLEL_REPLICATE_DEGREE
253
+ The `data_parallel_replicate_degree` argument
254
+ specifies the degree of data parallelism for weight
255
+ replication. When this value is greater than 1,
256
+ weights will be replicated across
257
+ `data_parallel_replicate_degree` ranks. If
258
+ `data_parallel_shard_degree` is also greater than 1,
259
+ the parallelism method used is HSDP (Hybrid Sharded
260
+ Data Parallelism). Otherwise, the parallelism method
261
+ used is DDP (Distributed Data Parallelism). 1 means
262
+ disabled.
263
+ --training.data_parallel_shard_degree TRAINING.DATA_PARALLEL_SHARD_DEGREE
264
+ The `data_parallel_shard_degree` argument specifies
265
+ the degree of data parallelism for weight sharding.
266
+ When this value is greater than 1, weights will be
267
+ sharded across `data_parallel_shard_degree` ranks. If
268
+ `data_parallel_replicate_degree` is also greater than
269
+ 1, the parallelism method used is HSDP (Hybrid Sharded
270
+ Data Parallelism). Otherwise, the parallelism method
271
+ used is FSDP (Fully Sharded Data Parallelism). -1
272
+ means leftover ranks will be used (After
273
+ DP_REPLICATE/SP/PP). Note that only
274
+ `data_parallel_shard_degree` can be negative. 1 means
275
+ disabled.
276
+ --training.enable_cpu_offload
277
+ Whether to apply CPU offloading of parameters,
278
+ gradients, and optimizer states in FSDP
279
+ --training.tensor_parallel_degree TRAINING.TENSOR_PARALLEL_DEGREE
280
+ Tensor Parallelism degree. 1 means disabled.
281
+ --training.disable_loss_parallel
282
+ Whether to apply loss parallel when sequence parallel
283
+ is enabled
284
+ --training.mixed_precision_param {bfloat16,float32}
285
+ torch dtype to use for parameters when applying mixed
286
+ precision via FSDP. This feature only takes effect
287
+ when data_parallel_shard_degree > 1
288
+ --training.mixed_precision_reduce {float32}
289
+ torch dtype to use for reductions when applying mixed
290
+ precision via FSDP. This feature only takes effect
291
+ when data_parallel_shard_degree > 1
292
+ --training.compile Whether to compile the model
293
+ --training.gc_freq TRAINING.GC_FREQ
294
+ Python garbage control scheduling interval, in steps
295
+ --training.seed TRAINING.SEED
296
+ Choose the base RNG seed used for training
297
+ --training.deterministic
298
+ Use deterministic algorithms wherever possible, may be
299
+ slower
300
+ --metrics.log_freq METRICS.LOG_FREQ
301
+ How often to log metrics to TensorBoard, in iterations
302
+ --metrics.enable_tensorboard
303
+ Whether to log metrics to TensorBoard
304
+ --metrics.disable_color_printing
305
+ Whether to disable color printing in logs
306
+ --metrics.save_tb_folder METRICS.SAVE_TB_FOLDER
307
+ Folder to dump TensorBoard states
308
+ --metrics.rank_0_only
309
+ Whether to save TensorBoard metrics only for rank 0 or
310
+ for all ranks. When pipeline_parallel_degree is > 1,
311
+ this option uses the 0th rank of the last stage
312
+ pipeline group, which is the only stage that computes
313
+ loss metrics.
314
+ --metrics.enable_wandb
315
+ Whether to log metrics to Weights & Biases
316
+ --experimental.enable_async_tensor_parallel
317
+ Whether to apply async tensor parallel (currently only
318
+ effective when compile is enabled)
319
+ --experimental.pipeline_parallel_degree EXPERIMENTAL.PIPELINE_PARALLEL_DEGREE
320
+ Pipeline Parallelism degree, or number of ranks. 1
321
+ means disabled. If using looped schedules, this still
322
+ specifies the number of physical ranks, not the number
323
+ of stages. Stages per rank are inferred from split
324
+ points degree, and schedule.
325
+ --experimental.pipeline_parallel_split_points EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS [EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS ...]
326
+ Specify comma-separated names of modules to use as the
327
+ beginning of a split point. e.g. "layers.0,layers.2"
328
+ will cause the model to be split into 3 stages, the
329
+ first containing all the layers up to layers.0, the
330
+ second containing layers.0 and up to layers.2, the
331
+ third containing layers.2 and all the remaining
332
+ layers. Note: fully-automated splitting may be enabled
333
+ in the future, but currently the split points must be
334
+ specified manually.
335
+ --experimental.pipeline_parallel_schedule EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE
336
+ Specify the Pipeline Parallel schedule to use. The
337
+ supported schedules are: https://github.com/pytorch/py
338
+ torch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/to
339
+ rch/distributed/pipelining/schedules.py#L2161. The
340
+ schedule must be compatible with the split points and
341
+ stages_per_rank. Looped schedules (e.g.
342
+ Interleaved1F1B) require specifying
343
+ pipeline_parallel_degree = number of ranks, and
344
+ split_points = number of stages - 1
345
+ --experimental.pipeline_parallel_schedule_csv EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE_CSV
346
+ Specify the path to the pipeline parallel schedule csv
347
+ file to use. The pipeline_parallel_schedule argument
348
+ must be either PipelineScheduleSingle,
349
+ PipelineScheduleMulti, or _PipelineScheduleRuntime.
350
+ --experimental.pipeline_parallel_microbatches EXPERIMENTAL.PIPELINE_PARALLEL_MICROBATCHES
351
+ How many microbatches to split the global training
352
+ batch into when using pipeline parallelism. The global
353
+ training batch size must be evenly divisible by the
354
+ number of microbatches. The default value will be the
355
+ number of pipeline stages, if unspecified.
356
+ --experimental.enable_compiled_autograd
357
+ Enable CompiledAutograd to compile the backward.
358
+ --experimental.context_parallel_degree EXPERIMENTAL.CONTEXT_PARALLEL_DEGREE
359
+ Context parallelism degree. 1 means disabled.
360
+ --experimental.context_parallel_rotate_method EXPERIMENTAL.CONTEXT_PARALLEL_ROTATE_METHOD
361
+ The collective to use in context parallel SDPA for kv
362
+ shards exchange. 'allgather' means to all-gather all
363
+ kv shards on ranks after the first sub-SDPA
364
+ computation, 'alltoall' means to all-to-all shuffle
365
+ the kv shards. The default value is 'allgather'.
366
+ --checkpoint.enable_checkpoint
367
+ Whether to enable checkpoint
368
+ --checkpoint.folder CHECKPOINT.FOLDER
369
+ The folder to store the checkpoints. When
370
+ enable_checkpoint is set to true, checkpoints will be
371
+ in {--job.dump_folder}/{--checkpoint.folder}.
372
+ --checkpoint.interval_type CHECKPOINT.INTERVAL_TYPE
373
+ Checkpointing interval unit of measurement ['step',
374
+ 'seconds']
375
+ --checkpoint.interval CHECKPOINT.INTERVAL
376
+ Checkpointing interval, in steps or seconds depending
377
+ on --checkpoint.interval_type
378
+ --checkpoint.model_weights_only
379
+ When model_weights_only=True, only model weights will
380
+ be saved at the end of training. With this,
381
+ checkpoints can be loaded using `torch.load(...,
382
+ weights_only=True)` after conversion. When
383
+ model_weights_only=False, the full checkpoint will be
384
+ saved. A full checkpoint includes model, optimizer and
385
+ train_state, which can be used to resume training. The
386
+ default value is false.
387
+ --checkpoint.export_dtype {float16,bfloat16,float32}
388
+ Converts to the specified precision when training
389
+ completes and model_weights_only=true. Currently
390
+ supports float32, float16, and bfloat16. The default
391
+ value is float32.
392
+ --checkpoint.create_seed_checkpoint
393
+ Initializes the full model without applying
394
+ parallelisms, and then saves it as a seed checkpoint.
395
+ Note: requires user to call train.py without
396
+ specifying any parallelisms, e.g. NGPU=1. Could be
397
+ implemented as a separate script, but this way shares
398
+ more code.
399
+ --checkpoint.async_mode CHECKPOINT.ASYNC_MODE
400
+ Which async checkpoint mode to use. Currently there
401
+ are 3 different modes. 1. "disabled": synchronized
402
+ checkpointing will be used. 2. "async":
403
+ torch.distributed.checkpoint.async_save will be used.
404
+ 1. "async_with_pinned_mem": this option utilizes a
405
+ dedicated pinned memory space and creates a separate
406
+ process for faster GPU->CPU transfer performance and
407
+ eliminating GIL contention. The cost is increased CPU
408
+ memory usage. If insufficient CPU memory is available,
409
+ performance may degrade due to memory paging. For most
410
+ users, "async" should suffice as the performance
411
+ overhead is typically small (on the order of tens of
412
+ seconds) compared to checkpointing frequency. This
413
+ mode can be employed to pursue near-zero checkpointing
414
+ times (e.g., < 1 second) given appropriate hardware
415
+ support such as ample CPU memory and fast PCIe.
416
+ "disabled" is the default mode.
417
+ --checkpoint.keep_latest_k CHECKPOINT.KEEP_LATEST_K
418
+ Keeps only the latest k checkpoints, purging older
419
+ ones. If 0, keep all checkpoints. 0 is the default
420
+ value.
421
+ --checkpoint.load_step CHECKPOINT.LOAD_STEP
422
+ Load the checkpoint at the specified step. If -1, load
423
+ the latest checkpoint.
424
+ --float8.enable_float8_linear
425
+ If true, swaps `torch.nn.Linear` with `Float8Linear`.
426
+ This feature requires you to install 'torchao' which
427
+ can be found here: https://github.com/pytorch/ao
428
+ --float8.enable_fsdp_float8_all_gather
429
+ Whether enable float8 all-gather in FSDP
430
+ --float8.precompute_float8_dynamic_scale_for_fsdp
431
+ Whether precompute float8 scales dynamically for FSDP
432
+ --float8.scaling_type_input {dynamic,delayed}
433
+ float8 scaling for input, dynamic (default) or delayed
434
+ --float8.scaling_type_weight FLOAT8.SCALING_TYPE_WEIGHT
435
+ float8 scaling for weight, dynamic (default) or delayed
436
+ --float8.scaling_type_grad_output FLOAT8.SCALING_TYPE_GRAD_OUTPUT
437
+ float8 scaling for grad_output, dynamic (default) or delayed
438
+ --comm.init_timeout_seconds COMM.INIT_TIMEOUT_SECONDS
439
+ Timeout for communication operations, during
440
+ initialization and first train step.
441
+ --comm.train_timeout_seconds COMM.TRAIN_TIMEOUT_SECONDS
442
+ Timeout for communication operations after the first
443
+ train step -- usually a tighter bound than during
444
+ initialization.
445
+ --comm.trace_buf_size COMM.TRACE_BUF_SIZE
446
+ Flight recorder ring buffer size, >0 means recording
447
+ by default, 0 means disabled
448
+ --memory_estimation.enabled
449
+ Whether to estimate memory usage for FSDP
450
+ --memory_estimation.disable_fake_mode
451
+ Whether to estimate memory under FakeTensorMode
452
+ ```
453
+ </details>
454
+
455
+ ### Training with `torch.compile`
456
+
457
+ Starting from `torch 2.0`, `torch.compile` has been available as a feature to seamlessly accelerate training.
458
+ In `flame`, you can enable `torch.compile` simply by adding the `--training.compile` flag to your training script.
459
+
460
+ However, `fla` has integrated numerous fused kernels for acceleration, which may conflict with `torch.compile`.
461
+ We are actively working on resolving these issues to make compilation transparent to users.
462
+ In the meantime, please ensure you are using the latest dependencies.
463
+
464
+ Specifically, **we recommend using `torch>=2.6` and `triton>=3.0`**.
465
+
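+ For instance, you can quickly check which versions are installed with:
+
+ ```sh
+ python -c "import torch, triton; print(torch.__version__, triton.__version__)"
+ ```
+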
466
+ ### Training with multiple datasets
467
+
468
+ If you wish to train a model with all-round capabilities (e.g., code, math, and multilingual ability), it's necessary to train on multiple datasets.
469
+ `flame` allows training with multiple datasets easily.
470
+ For example, you can specify the following arguments to train on 6 datasets with different proportions:
471
+
472
+ ```sh
473
+ --training.dataset HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus \
474
+ --training.data_probs 0.6,0.15,0.15,0.014,0.058,0.028 \
475
+ ```
476
+
477
+ ### ~Finalizing training~
478
+
479
+ > [!NOTE]
480
+ > We have done this conversion automatically in the training script since our latest updates.
481
+
482
+ Once training is complete, you may want to convert the distributed checkpoints (DCPs) into the 🤗 format for broader use.
483
+ To facilitate this, we provide a straightforward conversion script:
484
+
485
+ ```sh
486
+ python -m flame.utils.convert_dcp_to_hf --path <path_to_model> --step <step> --config <path_to_config> --tokenizer <path_to_tokenizer>
487
+ ```
488
+ After this, your model will be in the 🤗 format, ready to be shared or deployed.
489
+ You can then easily publish your model using the `huggingface_hub` for wider accessibility.
490
+
491
+ ### Continual training
492
+
493
+ If you wish to build upon a strong pre-trained model (in 🤗 format) and continue training, we also offer a script to convert the 🤗 format model back into DCP format.
494
+ This allows you to seamlessly resume training with `flame`.
495
+ ```sh
496
+ python -m flame.utils.convert_hf_to_dcp --model <path_to_hf> --checkpoint <path_to_dcp/checkpoint/step-0>
497
+ ```
498
+ Here, `<path_to_dcp>` is the directory where your distributed checkpoints will be stored.
499
+ The checkpoint is intentionally saved at `<step-0>` within the checkpoint folder to ensure it is loadable by `flame` during the initial training step, similar to how a seed checkpoint is handled.
500
+
501
+ Once the conversion is complete, you can proceed with training using `flame` as usual, continuing from where the pretrained model left off.
502
+
503
+ ## Multi-node training
504
+
505
+ If you have access to multi-node GPUs, consider leveraging them for optimal performance.
506
+ This process is straightforward and well-documented in the PyTorch [docs](https://pytorch.org/docs/stable/elastic/run.html).
507
+
508
+ To set up multi-node training:
509
+ * Set the environment variables `MASTER_ADDR=<ip>` and `MASTER_PORT=<port>` before running the training script across all nodes.
510
+ * If you're using a job scheduler like Slurm, it will handle these variables for you.
511
+
512
+ `torchtitan` provides a [Slurm script](https://github.com/pytorch/torchtitan/blob/main/multinode_trainer.slurm) for multi-node training, which you can use as a reference or starting point.
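+
+ As a minimal sketch (assuming `train.sh` forwards these variables to `torchrun`; depending on your launcher you may also need to export the node count and node rank), each node would run something like:
+
+ ```sh
+ # run the same command on every node, pointing at the rank-0 node
+ export MASTER_ADDR=10.0.0.1   # example IP of the rank-0 node
+ export MASTER_PORT=29500      # example free port
+ NGPU=8 bash train.sh \
+   --job.config_file flame/models/fla.toml \
+   --model.config configs/transformer_340M.json \
+   --model.tokenizer_path fla-hub/transformer-1.3B-100B
+ # ... plus the remaining training arguments shown above
+ ```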
torchtitan/components/__pycache__/ft.cpython-311.pyc ADDED
Binary file (7.05 kB).
 
torchtitan/components/optimizer.py ADDED
@@ -0,0 +1,303 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import functools
8
+ from typing import Any, Generic, Iterator, TypeVar
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch.distributed.checkpoint.state_dict import (
13
+ get_optimizer_state_dict,
14
+ set_optimizer_state_dict,
15
+ StateDictOptions,
16
+ )
17
+ from torch.distributed.checkpoint.stateful import Stateful
18
+ from torch.optim import Optimizer
19
+
20
+ from torchtitan.components.ft import FTManager, has_torchft
21
+ from torchtitan.config_manager import JobConfig
22
+
23
+ __all__ = [
24
+ "OptimizersContainer",
25
+ "build_optimizers",
26
+ ]
27
+
28
+
29
+ if has_torchft:
30
+ import torchft as ft
31
+
32
+
33
+ T = TypeVar("T", bound=Optimizer)
34
+
35
+
36
+ class OptimizersContainer(Optimizer, Stateful, Generic[T]):
37
+ """A container for multiple optimizers.
38
+
39
+ This class is used to wrap multiple optimizers into a single object that can be
40
+ used to reduce the complexity of the training loop. This mimics the behavior of
41
+ ``torch.optim.Optimizer``. This class currently only supports ``Adam`` and ``AdamW``.
42
+
43
+ **Note**
44
+ Users who want to customize the optimizer behavior can inherit from this class and
45
+ extend the functionality as needed. The following methods must follow the same signature
46
+ as ``torch.optim.Optimizer`` class: ``step()``, ``zero_grad()``, ``state_dict()``,
47
+ ``load_state_dict()``.
48
+
49
+ **Limitations**
50
+ This class assumes that all the optimizers are the same type and have the same
51
+ configurations. With this assumption, TorchTitan can support lr scheduler resharding
52
+ (e.g., loading a checkpoint with a different number of GPUs and/or different
53
+ parallelization strategy). Note that ``get_optimizer_state_dict`` already enables the
54
+ resharding for the optimizer state but not for the lr scheduler state, hence the limitation.
55
+
56
+ Args:
57
+ model_parts (List[nn.Module]): List of model parts to be optimized.
58
+ optimizer_kwargs (Dict[str, Any]): Keyword arguments for the optimizers.
59
+ name (str): Name of the optimizers.
60
+ """
61
+
62
+ optimizers: list[T]
63
+ model_parts: list[nn.Module]
64
+
65
+ def __init__(
66
+ self,
67
+ model_parts: list[nn.Module],
68
+ optimizer_cls: type[T],
69
+ optimizer_kwargs: dict[str, Any],
70
+ ) -> None:
71
+ all_params = []
72
+ self.optimizers = []
73
+ self.model_parts = model_parts
74
+ for model in self.model_parts:
75
+ params = [p for p in model.parameters() if p.requires_grad]
76
+ self.optimizers.append(optimizer_cls(params, **optimizer_kwargs))
77
+ all_params.extend(params)
78
+ self._validate_length(len(self.model_parts))
79
+ self._post_init(all_params, optimizer_kwargs)
80
+
81
+ def __iter__(self) -> Iterator[T]:
82
+ return iter(self.optimizers)
83
+
84
+ def __len__(self) -> int:
85
+ return len(self.optimizers)
86
+
87
+ def step(self, *args, **kwargs) -> None:
88
+ for optimizer in self.optimizers:
89
+ optimizer.step(*args, **kwargs)
90
+
91
+ def zero_grad(self, *args, **kwargs) -> None:
92
+ for optimizer in self.optimizers:
93
+ optimizer.zero_grad(*args, **kwargs)
94
+
95
+ def state_dict(self) -> dict[str, Any]:
96
+ func = functools.partial(
97
+ get_optimizer_state_dict,
98
+ options=StateDictOptions(flatten_optimizer_state_dict=True),
99
+ )
100
+ return {
101
+ k: v
102
+ for sd in map(func, self.model_parts, self.optimizers)
103
+ for k, v in sd.items()
104
+ }
105
+
106
+ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
107
+ func = functools.partial(
108
+ set_optimizer_state_dict,
109
+ optim_state_dict=state_dict,
110
+ options=StateDictOptions(flatten_optimizer_state_dict=True),
111
+ )
112
+ list(map(func, self.model_parts, self.optimizers))
113
+
114
+ def _validate_length(self, expected_length: int) -> None:
115
+ assert expected_length == len(self.optimizers), (
116
+ "Must pass one optimizer per model part or per param if "
117
+ "using OptimizersInBackwardContainer."
118
+ )
119
+
120
+ def _post_init(
121
+ self, all_params: list[nn.Parameter], optimizer_kwargs: dict[str, Any]
122
+ ) -> None:
123
+ # We need to call Optimizer.__init__() to initialize some necessary optimizer
124
+ # functionality such as hooks.
125
+ Optimizer.__init__(self, all_params, optimizer_kwargs)
126
+
127
+
128
+ class OptimizersInBackwardContainer(OptimizersContainer):
129
+ """OptimizersContainer for executing ``optim.step()`` in backward pass.
130
+
131
+ This class extend ``OptimizersContainer`` to support optimizer step in
132
+ backward pass. ``step()`` and ``zero_grad()`` are no-op in this class.
133
+ Instead, ``register_post_accumulate_grad_hook`` is used to register a hook to
134
+ execute these methods when the gradient is accumulated.
135
+ """
136
+
137
+ def __init__(
138
+ self,
139
+ model_parts: list[nn.Module],
140
+ optimizer_cls: type[T],
141
+ optimizer_kwargs: dict[str, Any],
142
+ ) -> None:
143
+ all_params = []
144
+ self.model_parts = model_parts
145
+
146
+ optim_dict = {}
147
+ for model in self.model_parts:
148
+ for p in model.parameters():
149
+ if p.requires_grad:
150
+ optim_dict[p] = optimizer_cls([p], **optimizer_kwargs)
151
+ all_params.append(p)
152
+
153
+ def optim_hook(param) -> None:
154
+ optim_dict[param].step()
155
+ optim_dict[param].zero_grad()
156
+
157
+ for model in self.model_parts:
158
+ for param in model.parameters():
159
+ if param.requires_grad:
160
+ param.register_post_accumulate_grad_hook(optim_hook)
161
+
162
+ self.optimizers = list(optim_dict.values())
163
+
164
+ self._validate_length(
165
+ sum(len(list(model.parameters())) for model in self.model_parts)
166
+ )
167
+ self._post_init(all_params, optimizer_kwargs)
168
+
169
+ def step(self) -> None:
170
+ pass
171
+
172
+ def zero_grad(self) -> None:
173
+ pass
174
+
175
+
176
+ class FTOptimizersContainer(OptimizersContainer):
177
+ def __init__(
178
+ self,
179
+ model_parts: list[nn.Module],
180
+ optimizer_cls: type[T],
181
+ optimizer_kwargs: dict[str, Any],
182
+ ft_manager: "ft.Manager",
183
+ ) -> None:
184
+ super().__init__(model_parts, optimizer_cls, optimizer_kwargs)
185
+
186
+ # Force to initialize the optimizer state so that `optim.step()`
187
+ # won't be called by state_dict() and load_state_dict().
188
+ _ = {
189
+ k: v
190
+ for sd in map(get_optimizer_state_dict, model_parts, self.optimizers)
191
+ for k, v in sd.items()
192
+ }
193
+ self.cache_state_dict: dict[str, Any] = {}
194
+ self._ft_optimizer = ft.Optimizer(ft_manager, self)
195
+ self._call_from_ft: bool = False
196
+
197
+ def init_cache_state_dict(self) -> None:
198
+ self.cache_state_dict = super().state_dict()
199
+
200
+ def state_dict(self) -> dict[str, Any]:
201
+ return self.cache_state_dict
202
+
203
+ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
204
+ # We have to invalidate the `cache_state_dict` because optimizer uses
205
+ # assign instead of copy when doing `load_state_dict()`. Without
206
+ # invalidating the `cache_state_dict`, there will be memory leakage.
207
+ self.cache_state_dict = {}
208
+ super().load_state_dict(state_dict)
209
+ self.init_cache_state_dict()
210
+
211
+ def step(self, *args, **kwargs) -> None:
212
+ """Calling the correct step() depending on the caller.
213
+
214
+ TorchFT's OptimizerWrapper.step() is designed to be callled only once
215
+ per train step per ft.Manager regardless how many optimizers are used.
216
+ Hence we will need to appropriately dispatch the call.
217
+ """
218
+ if self._call_from_ft:
219
+ super().step(*args, **kwargs)
220
+ else:
221
+ self._call_from_ft = True
222
+ self._ft_optimizer.step(*args, **kwargs)
223
+ self._call_from_ft = False
224
+
225
+ def zero_grad(self, *args, **kwargs) -> None:
226
+ """Calling the correct zero_grad() depending on the caller.
227
+
228
+ Check the comment in ``step()``.
229
+ """
230
+ if self._call_from_ft:
231
+ super().zero_grad(*args, **kwargs)
232
+ else:
233
+ self._call_from_ft = True
234
+ self._ft_optimizer.zero_grad(*args, **kwargs)
235
+ self._call_from_ft = False
236
+
237
+
238
+ def build_optimizers(
239
+ model_parts: list[nn.Module],
240
+ job_config: JobConfig,
241
+ ft_manager: FTManager,
242
+ ) -> OptimizersContainer:
243
+ """Create a OptimizersContainer for the given model parts and job config.
244
+
245
+ This function creates a ``OptimizersContainer`` for the given model parts.
246
+ ``job_config`` should define the correct optimizer name and parameters.
247
+ This function currently supports creating ``OptimizersContainer`` and
248
+ ``OptimizersInBackwardContainer``.
249
+
250
+ **Note**
251
+ Users who want to customize the optimizer behavior can create their own
252
+ ``OptimizersContainer`` subclass and ``build_optimizers``. Passing the
253
+ customized ``build_optimizers`` to ``TrainSpec`` will create the customized
254
+ ``OptimizersContainer``.
255
+
256
+ Args:
257
+ model_parts (List[nn.Module]): List of model parts to be optimized.
258
+ job_config (JobConfig): Job config containing the optimizer name and parameters.
259
+ """
260
+ optim_in_bwd = job_config.optimizer.early_step_in_backward
261
+ if optim_in_bwd and job_config.parallelism.pipeline_parallel_degree > 1:
262
+ raise NotImplementedError(
263
+ "Optimizers in backward is not supported with pipeline parallelism."
264
+ )
265
+ name = job_config.optimizer.name
266
+ lr = job_config.optimizer.lr
267
+ eps = job_config.optimizer.eps
268
+
269
+ optim_implementation = job_config.optimizer.implementation
270
+ assert optim_implementation in ["fused", "foreach", "for-loop"]
271
+
272
+ fused = optim_implementation == "fused"
273
+ foreach = optim_implementation == "foreach"
274
+
275
+ optimizer_kwargs = {
276
+ "lr": lr,
277
+ "eps": eps,
278
+ "betas": (0.9, 0.95),
279
+ "weight_decay": 0.1,
280
+ "fused": fused,
281
+ "foreach": foreach,
282
+ }
283
+
284
+ optimizer_classes = {
285
+ "Adam": torch.optim.Adam,
286
+ "AdamW": torch.optim.AdamW,
287
+ }
288
+ if name not in optimizer_classes:
289
+ raise NotImplementedError(f"Optimizer {name} not added.")
290
+ optimizer_cls = optimizer_classes[name]
291
+
292
+ if optim_in_bwd and ft_manager.enabled:
293
+ raise ValueError("TorchFT is not supported with optimizers in backward.")
294
+ elif optim_in_bwd:
295
+ return OptimizersInBackwardContainer(
296
+ model_parts, optimizer_cls, optimizer_kwargs
297
+ )
298
+ elif ft_manager.enabled:
299
+ return FTOptimizersContainer(
300
+ model_parts, optimizer_cls, optimizer_kwargs, ft_manager.manager
301
+ )
302
+ else:
303
+ return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
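
A minimal usage sketch of `OptimizersContainer`, assuming the class above is in scope; the toy model and hyperparameters are illustrative only:

```python
import torch
import torch.nn as nn

# Illustrative toy model standing in for a real model part.
model = nn.Linear(16, 16)

# Wrap it in the container defined above; it steps/zeros all wrapped optimizers.
optimizers = OptimizersContainer(
    model_parts=[model],
    optimizer_cls=torch.optim.AdamW,
    optimizer_kwargs={"lr": 3e-4, "eps": 1e-15, "betas": (0.9, 0.95),
                      "weight_decay": 0.1, "fused": False, "foreach": False},
)

loss = model(torch.randn(4, 16)).sum()
loss.backward()
optimizers.step()                  # steps every wrapped optimizer
optimizers.zero_grad()             # zeros gradients on every wrapped optimizer
state = optimizers.state_dict()    # flattened, resharding-friendly state dict
optimizers.load_state_dict(state)  # restores the same state
```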
torchtitan/distributed/__pycache__/pipeline.cpython-311.pyc ADDED
Binary file (8.48 kB).
 
torchtitan/distributed/pipeline.py ADDED
@@ -0,0 +1,201 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import math
7
+ import os
8
+ from typing import Callable, Optional
9
+
10
+ from torch.distributed.pipelining.schedules import (
11
+ _PipelineSchedule,
12
+ _PipelineScheduleRuntime,
13
+ get_schedule_class,
14
+ PipelineScheduleMulti,
15
+ PipelineScheduleSingle,
16
+ )
17
+ from torch.distributed.pipelining.stage import PipelineStage
18
+
19
+ from torchtitan.config_manager import JobConfig
20
+ from torchtitan.tools.logging import logger
21
+
22
+
23
+ __all__ = ["build_pipeline_schedule", "generate_split_points", "stage_ids_this_rank"]
24
+
25
+
26
+ # TODO: It's unclear if this API is general enough to be used by other models.
27
+ # If not, we should move it to a Transformer-specific directory.
28
+ def generate_split_points(
29
+ schedule_str: str,
30
+ layers_per_stage: Optional[int],
31
+ pp_dim: int,
32
+ num_layers: int,
33
+ input_weight: int = 1,
34
+ output_weight: int = 1,
35
+ ) -> list[str]:
36
+ """
37
+ Generate a list of split points based on the number of layers and
38
+ pipeline parallel dimension, ensuring the first and last stages have the fewest layers.
39
+
40
+ Args:
41
+ schedule_str (str): The string of the schedule name.
42
+ layers_per_stage (int): The number of layers per stage.
43
+ pp_dim (int): The pipeline parallel dimension.
44
+ num_layers (int): The number of layers in the model.
45
+ input_weight (int), output_weight (int): The number of layers' worth of weight assigned to the input/output modules in the layer calculation.
46
+
47
+ Returns:
48
+ list[str]: A list of split point FQNs.
49
+ """
50
+
51
+ schedule_class = get_schedule_class(schedule_str)
52
+ is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)
53
+ num_stages_per_rank = 1 if is_single_stage_schedule else 2
54
+
55
+ if layers_per_stage is not None:
56
+ total_stages = math.ceil(num_layers / layers_per_stage)
57
+ if total_stages % pp_dim != 0:
58
+ raise ValueError(
59
+ f"Number of stages ({total_stages}) must be divisible by the pipeline parallel dimension ({pp_dim})."
60
+ f"Each rank should have the same number of stages. "
61
+ )
62
+ num_stages_per_rank = total_stages // pp_dim
63
+
64
+ if is_single_stage_schedule and num_stages_per_rank != 1:
65
+ raise ValueError(
66
+ f"Number of stages per rank ({num_stages_per_rank}) must be 1 for single stage schedules."
67
+ )
68
+ elif not is_single_stage_schedule and num_stages_per_rank < 2:
69
+ raise ValueError(
70
+ f"Number of stages per rank ({num_stages_per_rank}) must be >= 2 for multi stage schedules."
71
+ )
72
+ else:
73
+ total_stages = pp_dim * num_stages_per_rank
74
+ if total_stages > num_layers:
75
+ raise ValueError("Total stages cannot be greater than the number of layers")
76
+
77
+ # Calculate effective number of layers including input and output weights
78
+ effective_num_layers = num_layers + input_weight + output_weight
79
+ base_layers_per_stage = effective_num_layers // total_stages
80
+
81
+ splits = [""] * (total_stages - 1)
82
+ current_layer_index = 0
83
+
84
+ # First stage
85
+ layers_on_first_stage = max(0, base_layers_per_stage - input_weight)
86
+ current_layer_index += layers_on_first_stage
87
+ splits[0] = "layers." + str(current_layer_index)
88
+
89
+ # Last stage
90
+ layers_on_last_stage = max(0, base_layers_per_stage - output_weight)
91
+ splits[-1] = "layers." + str(num_layers - layers_on_last_stage)
92
+
93
+ # Middle stages
94
+ remaining_layers = num_layers - layers_on_first_stage - layers_on_last_stage - 1
95
+ middle_stages = len(splits) - 2
96
+ layers_per_middle_stage = remaining_layers // middle_stages
97
+ # split remainder evenly across middle stages
98
+ remainder = remaining_layers % middle_stages
99
+
100
+ for i in range(1, middle_stages + 1):
101
+ current_layer_index += layers_per_middle_stage
102
+ if remainder > 0:
103
+ current_layer_index += 1
104
+ remainder -= 1
105
+ splits[i] = "layers." + str(current_layer_index)
106
+
107
+ logger.info(
108
+ f"No 'pipeline_parallel_split_points' provided so the generated splits are: {splits} "
109
+ "This may be sub-optimal as the number of layers per stage may be unbalanced."
110
+ )
111
+ return splits
112
+
113
+
114
+ def build_pipeline_schedule(
115
+ job_config: JobConfig, stages: list[PipelineStage], loss_fn: Callable
116
+ ) -> _PipelineSchedule:
117
+ """Builds a pipeline schedule for the given job configuration and stages.
118
+
119
+ Args:
120
+ job_config (JobConfig): The job configuration.
121
+ stages (list[PipelineStage]): The stages to be scheduled.
122
+ loss_fn (Callable): The loss function.
123
+
124
+ Returns:
125
+ _PipelineSchedule: The pipeline schedule for the given stages.
126
+ """
127
+ pp_schedule_csv = job_config.parallelism.pipeline_parallel_schedule_csv
128
+
129
+ # Validate that pp_schedule_csv is a valid path
130
+ if pp_schedule_csv:
131
+ if not os.path.isfile(pp_schedule_csv):
132
+ raise FileNotFoundError(
133
+ f"The specified path {pp_schedule_csv} does not exist or is not a file."
134
+ )
135
+ schedule_class = _PipelineScheduleRuntime
136
+ else:
137
+ schedule_class = get_schedule_class(
138
+ job_config.parallelism.pipeline_parallel_schedule
139
+ )
140
+
141
+ looped_schedule = issubclass(schedule_class, PipelineScheduleMulti)
142
+ microbatch_size = job_config.parallelism.pipeline_parallel_microbatch_size
143
+ batch_size = job_config.training.batch_size
144
+ # validate that the batch size is divisible by the microbatch_size otherwise we'll hang or error during training
145
+ if batch_size % microbatch_size != 0:
146
+ raise ValueError(
147
+ f"Batch size {job_config.training.batch_size} must be divisible by number of microbatches {n_microbatches}. "
148
+ "Update the config arguments for either batch_size or pipeline_parallel_microbatch_size."
149
+ )
150
+ n_microbatches = batch_size // microbatch_size
151
+ # We expect that the number of local stages (`len(stages)`) is the same across all ranks
152
+ num_total_stages = job_config.parallelism.pipeline_parallel_degree * len(stages)
153
+ if n_microbatches < num_total_stages:
154
+ logger.warning(
155
+ f"Number of microbatches ({n_microbatches}) is less than the total number "
156
+ f"of stages ({num_total_stages}) which may result in a bubble in the pipeline."
157
+ )
158
+
159
+ schedule = schedule_class(
160
+ stages if looped_schedule else stages[0],
161
+ n_microbatches=n_microbatches,
162
+ loss_fn=loss_fn,
163
+ )
164
+ logger.info(
165
+ f"Using pipeline schedule {job_config.parallelism.pipeline_parallel_schedule} "
166
+ f"with {n_microbatches} microbatches and {num_total_stages} stages."
167
+ )
168
+
169
+ if pp_schedule_csv:
170
+ assert schedule_class in [
171
+ PipelineScheduleSingle,
172
+ PipelineScheduleMulti,
173
+ _PipelineScheduleRuntime,
174
+ ], (
175
+ "Only PipelineScheduleSingle (single stage), PipelineScheduleMulti (multistage), "
176
+ "and _PipelineScheduleRuntime support csv schedules"
177
+ )
178
+ schedule._load_csv(pp_schedule_csv)
179
+
180
+ return schedule
181
+
182
+
183
+ # TODO(whc) should this be a utility inside torch.pipelining?
184
+ def stage_ids_this_rank(
185
+ pp_rank: int, pp_size: int, num_stages: int, style: str = "loop"
186
+ ) -> tuple[int]:
187
+ """Compute the stage ids for the stages that will run on this pp rank for either a looped or V style schedule"""
188
+ assert (
189
+ num_stages % pp_size == 0
190
+ ), f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size}"
191
+ stages_per_rank = num_stages // pp_size
192
+ if style == "loop":
193
+ return tuple(pp_rank + s * pp_size for s in range(stages_per_rank))
194
+ elif style == "v":
195
+ assert (
196
+ stages_per_rank == 2
197
+ ), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
198
+ stage_v_pairs = list(
199
+ zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1))
200
+ )
201
+ return stage_v_pairs[pp_rank]
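
A minimal sketch of how `stage_ids_this_rank` lays out stages across ranks for looped vs. V-style schedules, assuming the function above is in scope:

```python
# 4 pipeline ranks, 8 total stages => 2 stages per rank.
pp_size, num_stages = 4, 8

for pp_rank in range(pp_size):
    loop_ids = stage_ids_this_rank(pp_rank, pp_size, num_stages, style="loop")
    v_ids = stage_ids_this_rank(pp_rank, pp_size, num_stages, style="v")
    print(pp_rank, loop_ids, v_ids)

# rank 0 -> loop (0, 4), v (0, 7)
# rank 1 -> loop (1, 5), v (1, 6)
# rank 2 -> loop (2, 6), v (2, 5)
# rank 3 -> loop (3, 7), v (3, 4)
```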
torchtitan/experiments/deepseek_v3/checkpoint.py ADDED
@@ -0,0 +1,154 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import logging
9
+ import os
10
+ from typing import Dict, Optional, Set, Tuple
11
+
12
+ import torch
13
+ from safetensors import safe_open
14
+
15
+ from transformers.utils import cached_file
16
+
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ _DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
21
+
22
+
23
+ def read_weights_from_json(file_path: str) -> Optional[Dict[str, str]]:
24
+ try:
25
+ with open(file_path, "r") as file:
26
+ data = json.load(file)
27
+
28
+ if "weight_map" in data and isinstance(data["weight_map"], dict):
29
+ return data["weight_map"]
30
+ else:
31
+ logger.info("No 'weight_map' dictionary found in the JSON file.")
32
+ return None
33
+ except (json.JSONDecodeError, Exception) as e:
34
+ logger.info(f"An error occurred while reading the JSON file: {str(e)}")
35
+ return None
36
+
37
+
38
+ def get_hf_weight_map_and_path(
39
+ model_id: str,
40
+ ) -> Tuple[Dict[str, str], str]:
41
+ """Get the weight map for a given HF model id and also the cache path for loading the weights"""
42
+ try:
43
+ index_file = cached_file(model_id, _DEFAULT_SAFETENSOR_FILE_NAME)
44
+ except Exception as e:
45
+ logger.error(
46
+ f"Model `{model_id}` not found in HF cache. "
47
+ f"You can download the model using `python download.py {model_id}"
48
+ )
49
+ raise e
50
+
51
+ weight_map = read_weights_from_json(index_file)
52
+ weight_path = os.path.dirname(index_file)
53
+ logger.info(f"Loading weights from: {weight_path}")
54
+ return weight_map, weight_path
55
+
56
+
57
+ def get_needed_files(
58
+ state_dict: Dict[str, torch.Tensor], weight_map: Dict[str, str]
59
+ ) -> Set[str]:
60
+ needed_files = set()
61
+ for param in state_dict.keys():
62
+ file = weight_map.get(param)
63
+ if file:
64
+ needed_files.add(file)
65
+ elif param.endswith("weight"):
66
+ raise ValueError(
67
+ f"Parameter {param} not found in weight map, please check..."
68
+ )
69
+ logger.info(f"Needed files: {needed_files}")
70
+ return needed_files
71
+
72
+
73
+ def load_safetensor_file(
74
+ full_path: str, device: torch.device
75
+ ) -> Dict[str, torch.Tensor]:
76
+ tensors = {}
77
+ with safe_open(full_path, framework="pt", device=device) as f:
78
+ for k in f.keys():
79
+ tensors[k] = f.get_tensor(k)
80
+ logger.info(f"Loaded {len(tensors)} tensors from {full_path}")
81
+ return tensors
82
+
83
+
84
+ def load_safetensor_weights(
85
+ model: torch.nn.Module,
86
+ weight_map: Dict[str, str],
87
+ file_location: str,
88
+ device: torch.device,
89
+ ):
90
+ """
91
+ Load safetensor weights into a `nn.Module`.
92
+
93
+ Args:
94
+ model (Module): The PyTorch module to load weights into. It may be a
95
+ model chunk or a full model.
96
+ weight_map (Dict[str, str]): Mapping of model parameters to file names.
97
+ file_location (str): Directory containing the weight files.
98
+ device (torch.device): The device to load tensors onto.
99
+ """
100
+ model_state_dict = model.state_dict()
101
+ needed_files = get_needed_files(model_state_dict, weight_map)
102
+ updated_states: Set[str] = set()
103
+
104
+ for file in needed_files:
105
+ full_path = os.path.join(file_location, file)
106
+ try:
107
+ checkpoint = load_safetensor_file(full_path, "cpu")
108
+ except FileNotFoundError:
109
+ logger.error(f"File not found: {full_path}")
110
+ except Exception as e:
111
+ logger.error(f"Error during checkpoint processing of {full_path}: {str(e)}")
112
+
113
+ matched_keys = set(checkpoint.keys()) & set(model_state_dict.keys())
114
+ for key in matched_keys:
115
+ # Check shape
116
+ if model_state_dict[key].shape != checkpoint[key].shape:
117
+ raise ValueError(
118
+ f"Shape mismatch for {key}: "
119
+ f"model needs {model_state_dict[key].shape}, but "
120
+ f"checkpoint has {checkpoint[key].shape}"
121
+ )
122
+ model_state_dict[key] = checkpoint[key].to(device)
123
+
124
+ updated_states.update(matched_keys)
125
+
126
+ missing_keys = set(model_state_dict.keys()) - updated_states
127
+ if missing_keys:
128
+ raise RuntimeError(
129
+ f"Partially updated state dict. Missing parameters: {missing_keys}"
130
+ )
131
+
132
+ model.load_state_dict(model_state_dict, strict=False, assign=True)
133
+ logger.info(f"Successfully loaded {len(updated_states)} weights into model")
134
+
135
+
136
+ def load_weights_from_hf(
137
+ model: torch.nn.Module,
138
+ distribution: str,
139
+ device: torch.device,
140
+ ):
141
+ """
142
+ Load the weights from Hugging Face format (index file + multiple safetensor
143
+ files), and fill them into `model`. The index file maps each parameter
144
+ FQN to the safetensor shard that contains it.
145
+ """
146
+
147
+ weight_map, weight_path = get_hf_weight_map_and_path(distribution)
148
+
149
+ load_safetensor_weights(
150
+ model,
151
+ weight_map,
152
+ weight_path,
153
+ device,
154
+ )
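A minimal usage sketch for the helpers above; the model builder is a placeholder, and this assumes the safetensors index is already in the HF cache (e.g. after running download.py):

```python
# Hypothetical usage sketch -- `build_model_chunk` is a placeholder for any code
# that builds an nn.Module whose state_dict keys match the FQNs in the HF index.
import torch
from checkpoint import load_weights_from_hf

device = torch.device("cuda", 0)
model = build_model_chunk()  # placeholder, e.g. one pipeline stage of the model
load_weights_from_hf(model, "deepseek-ai/DeepSeek-V2-Lite", device)
```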
torchtitan/experiments/deepseek_v3/download.py ADDED
@@ -0,0 +1,70 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Usage:
8
+ # Downloads a given model to the HF cache. Pass in a listed option such as "v3", or your own custom model path.
9
+ # python download.py {model_id} [custom_model_path]
10
+ # Examples:
11
+ # python download.py v2 # Use predefined model: deepseek-ai/DeepSeek-V2
12
+ # python download.py custom "deepseek-ai/new-model" # Download a custom model path
13
+
14
+ # Available models:
15
+ # "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat",
16
+ # "v2-lite": "deepseek-ai/DeepSeek-V2-Lite",
17
+ # "v2": "deepseek-ai/DeepSeek-V2",
18
+ # "v3": "deepseek-ai/deepseek-v3",
19
+ # "v3-0324": "deepseek-ai/DeepSeek-V3-0324",
20
+ # "custom": None, # Placeholder for custom models
21
+
22
+
23
+ import sys
24
+
25
+ from transformers import AutoModelForCausalLM
26
+
27
+
28
+ MODELS = {
29
+ "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat",
30
+ "v2-lite": "deepseek-ai/DeepSeek-V2-Lite",
31
+ "v2": "deepseek-ai/DeepSeek-V2",
32
+ "v3": "deepseek-ai/deepseek-v3",
33
+ "v3-0324": "deepseek-ai/DeepSeek-V3-0324",
34
+ "custom": None, # For custom (any) models
35
+ }
36
+
37
+
38
+ def print_usage():
39
+ print("Usage:")
40
+ print(" python download.py [model_version]")
41
+ print(" python download.py custom [custom_model_path]")
42
+ print("\nAvailable predefined models:")
43
+ for key, model in MODELS.items():
44
+ if key != "custom": # Skip the custom placeholder
45
+ print(f" {key}: {model}")
46
+ print("\nFor custom models:")
47
+ print(" custom: Specify your own model path")
48
+ print(' Example: python download.py custom "organization/model-name"')
49
+ sys.exit(1)
50
+
51
+
52
+ # Process command line arguments
53
+ if len(sys.argv) < 2 or sys.argv[1] not in MODELS:
54
+ print_usage()
55
+
56
+ if sys.argv[1] == "custom":
57
+ if len(sys.argv) != 3:
58
+ print("Error: Custom model requires a model path")
59
+ print_usage()
60
+ model_id = sys.argv[2]
61
+ print(f"Using custom model: {model_id}")
62
+ else:
63
+ model_id = MODELS[sys.argv[1]]
64
+ print(f"Downloading model: {model_id}")
65
+
66
+ model = AutoModelForCausalLM.from_pretrained(
67
+ model_id,
68
+ device_map="auto",
69
+ trust_remote_code=True,
70
+ )
torchtitan/experiments/deepseek_v3/generate.py ADDED
@@ -0,0 +1,308 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # torchrun --standalone --nproc-per-node 4 generate.py
8
+
9
+ # use inference.sh "Your Question Here?" to run inference with a single prompt.
10
+
11
+ import sys
12
+ from dataclasses import dataclass
13
+
14
+ import torch
15
+ import torch.distributed as dist
16
+
17
+ from checkpoint import load_weights_from_hf
18
+ from model import DeepseekForCausalLM
19
+ from model_config import deepseek_config_registry
20
+ from torch.distributed.device_mesh import DeviceMesh
21
+ from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
22
+ from torchtitan.tools.utils import Color
23
+ from transformers import AutoTokenizer
24
+
25
+ # Choose the model you want to run (uncomment exactly one line below).
26
+ model_id, mesh_shape = "deepseek-ai/DeepSeek-V2-Lite-Chat", (1, 4)
27
+ # model_id, mesh_shape = "deepseek-ai/deepseek-v3", (8, 4)
28
+
29
+
30
+ def colorize_chat(text, user_color=None, assistant_color=None, output_color=None):
31
+ """Parse and colorize chat output with optional colors for each role."""
32
+ lines = text.split("\n")
33
+ result = []
34
+
35
+ current_role = None
36
+ current_content = []
37
+
38
+ def _process_current_content():
39
+ if not current_role or not current_content:
40
+ return None
41
+
42
+ content = "\n".join(current_content)
43
+ if current_role == "output":
44
+ return (
45
+ f"Output: {output_color}{content}{color.reset}"
46
+ if output_color
47
+ else f"Output: {content}"
48
+ )
49
+ else:
50
+ try:
51
+ prefix, rest = current_content[0].split(":", 1)
52
+ role_color = user_color if current_role == "user" else assistant_color
53
+ if role_color:
54
+ formatted = f"{prefix}:{role_color}{rest}{color.reset}"
55
+ if len(current_content) > 1:
56
+ formatted += (
57
+ f"{role_color}\n"
58
+ + "\n".join(current_content[1:])
59
+ + f"{color.reset}"
60
+ )
61
+ return formatted
62
+ except ValueError:
63
+ pass
64
+ return content
65
+
66
+ for line in lines:
67
+ if line.startswith("Output:"):
68
+ if processed := _process_current_content():
69
+ result.append(processed)
70
+ current_role = "output"
71
+ content = line[len("Output:") :].strip()
72
+ if output_color:
73
+ content = f"Output: {output_color}{content}{color.reset}"
74
+ else:
75
+ content = f"Output: {content}"
76
+ result.append(content)
77
+ current_content = []
78
+
79
+ elif line.startswith("User:"):
80
+ if processed := _process_current_content():
81
+ result.append(processed)
82
+ current_role = "user"
83
+ current_content = [line]
84
+
85
+ elif line.startswith("Assistant:"):
86
+ if processed := _process_current_content():
87
+ result.append(processed)
88
+ current_role = "assistant"
89
+ current_content = [line]
90
+
91
+ else:
92
+ if current_content:
93
+ current_content.append(line)
94
+ elif line.strip() and current_role is None:
95
+ # Handle system message at the beginning
96
+ current_role = "output"
97
+ if output_color:
98
+ result.append(f"Output: {output_color}{line.strip()}{color.reset}")
99
+ else:
100
+ result.append(f"Output: {line.strip()}")
101
+
102
+ # Process the last segment
103
+ if processed := _process_current_content():
104
+ result.append(processed)
105
+
106
+ return "\n".join(result)
107
+
108
+
109
+ color = Color()
110
+
111
+
112
+ @dataclass
113
+ class DistConfig:
114
+ mesh: DeviceMesh
115
+ pp_mesh: DeviceMesh
116
+ ep_mesh: DeviceMesh
117
+ pp_size: int
118
+ ep_size: int
119
+ ep_rank: int
120
+ pp_rank: int
121
+ device: torch.device
122
+
123
+
124
+ def create_model(dist_config: DistConfig):
125
+ model_args = deepseek_config_registry[model_id]
126
+ model_args.ep_size = dist_config.ep_size
127
+ model_args.num_stages = dist_config.pp_size
128
+ model_args.stage_idx = dist_config.pp_rank
129
+ model_args.max_seq_len = 16384
130
+
131
+ with dist_config.device, dist_config.mesh:
132
+ model = DeepseekForCausalLM(model_args)
133
+ load_weights_from_hf(model, model_id, dist_config.device)
134
+ model.eval()
135
+ model.setup_symm_mem(torch.bfloat16, dist_config.device)
136
+
137
+ stage = PipelineStage(
138
+ model,
139
+ dist_config.pp_rank,
140
+ dist_config.pp_size,
141
+ dist_config.device,
142
+ group=dist_config.pp_mesh.get_group(),
143
+ )
144
+ pp_schedule = ScheduleGPipe(stage, dist_config.pp_size)
145
+ return model, pp_schedule
146
+
147
+
148
+ def create_dist_config(mesh: DeviceMesh):
149
+ rank = dist.get_rank()
150
+ device_count = torch.cuda.device_count()
151
+ device = torch.device("cuda", rank % device_count)
152
+
153
+ dist_config = DistConfig(
154
+ mesh=mesh,
155
+ pp_mesh=mesh["pp"],
156
+ ep_mesh=mesh["ep"],
157
+ pp_rank=mesh["pp"].get_local_rank(),
158
+ pp_size=mesh["pp"].size(),
159
+ ep_size=mesh["ep"].size(),
160
+ ep_rank=mesh["ep"].get_local_rank(),
161
+ device=device,
162
+ )
163
+ return dist_config
164
+
165
+
166
+ def decode(tokenizer, x):
167
+ output = tokenizer.decode(x[0])
168
+ # Clean up the output by removing special tokens
169
+ bos = tokenizer.bos_token
170
+ output = output.replace(bos, "")
171
+ # Truncate at end of sentence token
172
+ eos_token = tokenizer.eos_token
173
+ if eos_token and eos_token in output:
174
+ output = output.split(eos_token)[0]
175
+ colored_output = colorize_chat(
176
+ output,
177
+ user_color=color.green,
178
+ assistant_color=color.cyan,
179
+ output_color=color.blue,
180
+ )
181
+ return colored_output
182
+
183
+
184
+ @torch.inference_mode()
185
+ def generate(
186
+ model,
187
+ pp_schedule,
188
+ tokenizer,
189
+ dist_config,
190
+ messages: list[dict],
191
+ n_tokens: int = 50,
192
+ ):
193
+ rank = dist.get_rank()
194
+ device = dist_config.device
195
+ x = tokenizer.apply_chat_template(
196
+ [messages] * dist_config.pp_size,
197
+ add_generation_prompt=True,
198
+ return_tensors="pt",
199
+ )
200
+ next_idx = x.shape[-1]
201
+ x = torch.cat([x, torch.zeros(x.shape[0], n_tokens, dtype=torch.int64)], dim=-1)
202
+ x = x.to(device)
203
+
204
+ for _ in range(n_tokens):
205
+ if dist_config.pp_size > 1:
206
+ if dist_config.pp_rank == 0:
207
+ pp_schedule.step(x)
208
+ torch.distributed.broadcast(
209
+ x,
210
+ group=dist_config.pp_mesh.get_group(),
211
+ group_src=dist_config.pp_size - 1,
212
+ )
213
+ elif dist_config.pp_rank == dist_config.pp_size - 1:
214
+ preds = pp_schedule.step()
215
+ next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
216
+ x[:, next_idx] = next_token
217
+ torch.distributed.broadcast(
218
+ x,
219
+ group=dist_config.pp_mesh.get_group(),
220
+ group_src=dist_config.pp_size - 1,
221
+ )
222
+ else:
223
+ pp_schedule.step()
224
+ torch.distributed.broadcast(
225
+ x,
226
+ group=dist_config.pp_mesh.get_group(),
227
+ group_src=dist_config.pp_size - 1,
228
+ )
229
+
230
+ next_idx += 1
231
+ else:
232
+ preds = model(x)
233
+ next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
234
+ x[:, next_idx] = next_token
235
+ next_idx += 1
236
+
237
+ if rank == 0:
238
+ colored_output = decode(tokenizer, x)
239
+ print(f"Without CUDA Graph:\n{colored_output}")
240
+
241
+
242
+ @torch.inference_mode()
243
+ def generate_with_cuda_graph(
244
+ model,
245
+ tokenizer,
246
+ dist_config,
247
+ messages: list[dict],
248
+ n_tokens: int = 10,
249
+ ):
250
+ rank = dist.get_rank()
251
+ device = dist_config.device
252
+ x = tokenizer.apply_chat_template(
253
+ [messages] * dist_config.pp_size,
254
+ add_generation_prompt=True,
255
+ return_tensors="pt",
256
+ )
257
+ next_idx = x.shape[-1]
258
+ x = torch.cat([x, torch.zeros(x.shape[0], n_tokens, dtype=torch.int64)], dim=-1)
259
+ x = x.to(device)
260
+
261
+ torch.cuda.synchronize()
262
+
263
+ # Create CUDA graph
264
+ g = torch.cuda.CUDAGraph()
265
+ with torch.cuda.graph(g):
266
+ preds = model(x)
267
+
268
+ # Run CUDA graph
269
+ for _ in range(n_tokens):
270
+ g.replay()
271
+ next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
272
+ x[:, next_idx] = next_token
273
+ next_idx += 1
274
+
275
+ if rank == 0:
276
+ colored_output = decode(tokenizer, x)
277
+ print(f"With CUDA Graph:\n{colored_output}")
278
+
279
+
280
+ if __name__ == "__main__":
281
+ # Get user prompt from command line arguments
282
+ user_prompt = "What is 2+2?" # Default prompt
283
+ if len(sys.argv) > 1:
284
+ user_prompt = sys.argv[1]
285
+
286
+ mesh = dist.init_device_mesh("cuda", mesh_shape, mesh_dim_names=("pp", "ep"))
287
+ rank = dist.get_rank()
288
+ if rank == 0:
289
+ print(
290
+ f"{color.yellow}Running inference with {model_id} on {mesh_shape} mesh{color.reset}"
291
+ )
292
+
293
+ dist_config = create_dist_config(mesh)
294
+ model, pp_schedule = create_model(dist_config)
295
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
296
+
297
+ messages = [
298
+ {"role": "system", "content": "You are a helpful assistant."},
299
+ {"role": "user", "content": user_prompt},
300
+ ]
301
+
302
+ generate(model, pp_schedule, tokenizer, dist_config, messages)
303
+ generate_with_cuda_graph(model, tokenizer, dist_config, messages)
304
+
305
+ if rank == 0:
306
+ print(f"\n{color.yellow}Closing inference mesh...{color.reset}")
307
+
308
+ dist.destroy_process_group()
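One detail worth calling out from `generate` above: every rank in the pipeline group issues the same broadcast each step, because a collective must be entered by all members of the group, not only the sender. A stripped-down sketch of that per-step pattern (assumes the process group is initialized and `x` is allocated on every rank):

```python
# Sketch of the per-step token broadcast used in `generate` above.
import torch.distributed as dist

def sync_tokens_from_last_stage(x, pp_mesh):
    # The last pipeline stage holds the freshly sampled token; every other rank
    # receives it so the next step's schedule sees a consistent input buffer.
    dist.broadcast(
        x,
        group=pp_mesh.get_group(),
        group_src=pp_mesh.size() - 1,
    )
```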
torchtitan/experiments/deepseek_v3/inference.sh ADDED
@@ -0,0 +1,15 @@
1
+ #!/usr/bin/bash
2
+
3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ # All rights reserved.
5
+
6
+ # This source code is licensed under the BSD-style license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+
9
+ NGPU=${NGPU:-"4"}
10
+
11
+ # Get the prompt from command line argument or use a default
12
+ prompt="${1:-What is 2+2?}"
13
+
14
+ # Run the model with the prompt
15
+ torchrun --standalone --nproc-per-node ${NGPU} generate.py "$prompt"
torchtitan/experiments/deepseek_v3/model.py ADDED
@@ -0,0 +1,1325 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This code is based on model definition of `deepseek-ai/DeepSeek-V3-Base` on
8
+ # Hugging Face Model Hub. Url:
9
+ # https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py
10
+ # https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/resolve/main/configuration_deepseek.py
11
+ #
12
+ # It has been modified from its original form to accommodate the naming conventions
13
+ # and usage patterns of the TorchTitan project.
14
+
15
+ # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
16
+ #
17
+ # Licensed under the Apache License, Version 2.0 (the "License");
18
+ # you may not use this file except in compliance with the License.
19
+ # You may obtain a copy of the License at
20
+ #
21
+ # http://www.apache.org/licenses/LICENSE-2.0
22
+ #
23
+ # Unless required by applicable law or agreed to in writing, software
24
+ # distributed under the License is distributed on an "AS IS" BASIS,
25
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26
+ # See the License for the specific language governing permissions and
27
+ # limitations under the License.
28
+ """ PyTorch DeepSeek model."""
29
+ import math
30
+ from typing import Optional, Tuple
31
+
32
+ import torch
33
+ import torch.distributed as dist
34
+
35
+ import torch.distributed._symmetric_memory as symm_mem
36
+ import torch.nn.functional as F
37
+ import torch.utils.checkpoint
38
+
39
+ from attn_mask_utils import _prepare_4d_causal_attention_mask
40
+ from indices import generate_permute_indices
41
+ from model_config import ModelArgs
42
+ from symm_mem_recipes import OnDeviceAllToAllV
43
+ from torch import nn
44
+ from torch.distributed._functional_collectives import all_to_all_single_autograd
45
+
46
+ from torchtitan.experiments.kernels.triton_mg_group_gemm.torchao_pr import (
47
+ ALIGN_SIZE_M,
48
+ grouped_gemm_forward,
49
+ )
50
+
51
+ # Get model parallel subgroup by name:
52
+ # e.g. "pp", "ep", None
53
+ def get_group(dim_name: Optional[str] = None) -> dist.ProcessGroup:
54
+ glob = torch.distributed.device_mesh._mesh_resources.get_current_mesh()
55
+ return glob.get_group(dim_name)
56
+
57
+
58
+ class RMSNorm(nn.Module):
59
+ def __init__(self, hidden_size, eps=1e-6):
60
+ super().__init__()
61
+ self.weight = nn.Parameter(torch.ones(hidden_size))
62
+ self.variance_epsilon = eps
63
+
64
+ def forward(self, hidden_states):
65
+ input_dtype = hidden_states.dtype
66
+ hidden_states = hidden_states.to(torch.float32)
67
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
68
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
69
+ return self.weight * hidden_states.to(input_dtype)
70
+
71
+
72
+ class RotaryEmbedding(nn.Module):
73
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
74
+ super().__init__()
75
+
76
+ self.dim = dim
77
+ self.max_position_embeddings = max_position_embeddings
78
+ self.base = base
79
+ inv_freq = 1.0 / (
80
+ self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
81
+ )
82
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
83
+
84
+ # Build here to make `torch.jit.trace` work.
85
+ self._set_cos_sin_cache(
86
+ seq_len=max_position_embeddings,
87
+ device=self.inv_freq.device,
88
+ dtype=torch.get_default_dtype(),
89
+ )
90
+
91
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
92
+ self.max_seq_len_cached = seq_len
93
+ t = torch.arange(
94
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
95
+ )
96
+
97
+ freqs = torch.outer(t, self.inv_freq.to(t.device))
98
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
99
+ emb = torch.cat((freqs, freqs), dim=-1)
100
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
101
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
102
+
103
+ def forward(self, x, seq_len=None):
104
+ # x: [bs, num_attention_heads, seq_len, head_size]
105
+ if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
106
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
107
+
108
+ return (
109
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
110
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
111
+ )
112
+
113
+
114
+ class LinearScalingRotaryEmbedding(RotaryEmbedding):
115
+ """RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
116
+
117
+ def __init__(
118
+ self,
119
+ dim,
120
+ max_position_embeddings=2048,
121
+ base=10000,
122
+ device=None,
123
+ scaling_factor=1.0,
124
+ ):
125
+ self.scaling_factor = scaling_factor
126
+ super().__init__(dim, max_position_embeddings, base, device)
127
+
128
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
129
+ self.max_seq_len_cached = seq_len
130
+ t = torch.arange(
131
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
132
+ )
133
+ t = t / self.scaling_factor
134
+
135
+ freqs = torch.outer(t, self.inv_freq)
136
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
137
+ emb = torch.cat((freqs, freqs), dim=-1)
138
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
139
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
140
+
141
+
142
+ # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Deepseek
143
+ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
144
+ """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
145
+
146
+ def __init__(
147
+ self,
148
+ dim,
149
+ max_position_embeddings=2048,
150
+ base=10000,
151
+ device=None,
152
+ scaling_factor=1.0,
153
+ ):
154
+ self.scaling_factor = scaling_factor
155
+ super().__init__(dim, max_position_embeddings, base, device)
156
+
157
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
158
+ self.max_seq_len_cached = seq_len
159
+
160
+ if seq_len > self.max_position_embeddings:
161
+ base = self.base * (
162
+ (self.scaling_factor * seq_len / self.max_position_embeddings)
163
+ - (self.scaling_factor - 1)
164
+ ) ** (self.dim / (self.dim - 2))
165
+ inv_freq = 1.0 / (
166
+ base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
167
+ )
168
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
169
+
170
+ t = torch.arange(
171
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
172
+ )
173
+
174
+ freqs = torch.outer(t, self.inv_freq)
175
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
176
+ emb = torch.cat((freqs, freqs), dim=-1)
177
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
178
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
179
+
180
+
181
+ # Inverse dim formula to find dim based on number of rotations
182
+ def yarn_find_correction_dim(
183
+ num_rotations, dim, base=10000, max_position_embeddings=2048
184
+ ):
185
+ return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
186
+ 2 * math.log(base)
187
+ )
188
+
189
+
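A short sketch of the algebra behind `yarn_find_correction_dim` above (symbols follow the code: base b, head dim D = `dim`, L = `max_position_embeddings`, r = `num_rotations`); `yarn_find_correction_range` below just floors/ceils this at the `beta_fast`/`beta_slow` rotation counts and clamps to [0, dim - 1]:

```latex
% Rotary pair $i$ has frequency $\theta_i = b^{-2i/D}$, so over $L$ positions it
% completes $r_i = \frac{L\,\theta_i}{2\pi} = \frac{L}{2\pi\, b^{2i/D}}$ rotations.
% Solving $r_i = r$ for $i$ gives the expression returned by the function:
i \;=\; \frac{D \,\ln\!\bigl(L / (2\pi r)\bigr)}{2 \ln b}
```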
190
+ # Find dim range bounds based on rotations
191
+ def yarn_find_correction_range(
192
+ low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
193
+ ):
194
+ low = math.floor(
195
+ yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
196
+ )
197
+ high = math.ceil(
198
+ yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
199
+ )
200
+ return max(low, 0), min(high, dim - 1) # Clamp values just in case
201
+
202
+
203
+ def yarn_get_mscale(scale=1, mscale=1):
204
+ if scale <= 1:
205
+ return 1.0
206
+ return 0.1 * mscale * math.log(scale) + 1.0
207
+
208
+
209
+ def yarn_linear_ramp_mask(min, max, dim):
210
+ if min == max:
211
+ max += 0.001 # Prevent singularity
212
+
213
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
214
+ ramp_func = torch.clamp(linear_func, 0, 1)
215
+ return ramp_func
216
+
217
+
218
+ class YarnRotaryEmbedding(RotaryEmbedding):
219
+ def __init__(
220
+ self,
221
+ dim,
222
+ max_position_embeddings=2048,
223
+ base=10000,
224
+ device=None,
225
+ scaling_factor=1.0,
226
+ original_max_position_embeddings=4096,
227
+ beta_fast=32,
228
+ beta_slow=1,
229
+ mscale=1,
230
+ mscale_all_dim=0,
231
+ ):
232
+ self.scaling_factor = scaling_factor
233
+ self.original_max_position_embeddings = original_max_position_embeddings
234
+ self.beta_fast = beta_fast
235
+ self.beta_slow = beta_slow
236
+ self.mscale = mscale
237
+ self.mscale_all_dim = mscale_all_dim
238
+ super().__init__(dim, max_position_embeddings, base, device)
239
+
240
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
241
+ self.max_seq_len_cached = seq_len
242
+ dim = self.dim
243
+
244
+ freq_extra = 1.0 / (
245
+ self.base
246
+ ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
247
+ )
248
+ freq_inter = 1.0 / (
249
+ self.scaling_factor
250
+ * self.base
251
+ ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
252
+ )
253
+
254
+ low, high = yarn_find_correction_range(
255
+ self.beta_fast,
256
+ self.beta_slow,
257
+ dim,
258
+ self.base,
259
+ self.original_max_position_embeddings,
260
+ )
261
+ inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
262
+ device=device, dtype=torch.float32
263
+ )
264
+ inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
265
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
266
+
267
+ t = torch.arange(seq_len, device=device, dtype=torch.float32)
268
+
269
+ freqs = torch.outer(t, inv_freq)
270
+
271
+ _mscale = float(
272
+ yarn_get_mscale(self.scaling_factor, self.mscale)
273
+ / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
274
+ )
275
+
276
+ emb = torch.cat((freqs, freqs), dim=-1)
277
+ self.register_buffer(
278
+ "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False
279
+ )
280
+ self.register_buffer(
281
+ "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False
282
+ )
283
+
284
+
285
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
286
+ def rotate_half(x):
287
+ """Rotates half the hidden dims of the input."""
288
+ x1 = x[..., : x.shape[-1] // 2]
289
+ x2 = x[..., x.shape[-1] // 2 :]
290
+ return torch.cat((-x2, x1), dim=-1)
291
+
292
+
293
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
294
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
295
+ """Applies Rotary Position Embedding to the query and key tensors.
296
+
297
+ Args:
298
+ q (`torch.Tensor`): The query tensor.
299
+ k (`torch.Tensor`): The key tensor.
300
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
301
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
302
+ position_ids (`torch.Tensor`):
303
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
304
+ used to pass offsetted position ids when working with a KV-cache.
305
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
306
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
307
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
308
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
309
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
310
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
311
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
312
+ Returns:
313
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
314
+ """
315
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
316
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
317
+
318
+ b, h, s, d = q.shape
319
+ q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
320
+
321
+ b, h, s, d = k.shape
322
+ k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
323
+
324
+ q_embed = (q * cos) + (rotate_half(q) * sin)
325
+ k_embed = (k * cos) + (rotate_half(k) * sin)
326
+ return q_embed, k_embed
327
+
328
+
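The `view(...).transpose(4, 3).reshape(...)` lines in `apply_rotary_pos_emb` de-interleave the last dimension (pairs `[x0, x1, x2, x3]` become `[x0, x2, x1, x3]`) so that the half-split `rotate_half` convention matches weights stored in the interleaved RoPE layout. A tiny sketch of just that permutation (shapes are illustrative):

```python
# Illustrative check of the de-interleave trick used in apply_rotary_pos_emb.
import torch

b, h, s, d = 1, 1, 1, 6
q = torch.arange(d).view(b, h, s, d)                                 # [0, 1, 2, 3, 4, 5]
q_deint = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
assert q_deint.flatten().tolist() == [0, 2, 4, 1, 3, 5]              # evens first, then odds
```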
329
+ class MLP(nn.Module):
330
+ act_fn = nn.SiLU()
331
+
332
+ def __init__(self, config, hidden_size=None, intermediate_size=None):
333
+ super().__init__()
334
+ self.config = config
335
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
336
+ self.intermediate_size = (
337
+ config.intermediate_size if intermediate_size is None else intermediate_size
338
+ )
339
+
340
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
341
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
342
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
343
+
344
+ def forward(self, x):
345
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
346
+ return down_proj
347
+
348
+
349
+ class MoEGate(nn.Module):
350
+ def __init__(self, config):
351
+ super().__init__()
352
+ self.config = config
353
+ self.top_k = config.num_experts_per_tok
354
+ self.n_routed_experts = config.n_routed_experts
355
+ self.routed_scaling_factor = config.routed_scaling_factor
356
+ self.scoring_func = config.scoring_func
357
+ self.seq_aux = config.seq_aux
358
+ self.topk_method = config.topk_method
359
+ self.n_group = config.n_group
360
+ self.topk_group = config.topk_group
361
+
362
+ # topk selection algorithm
363
+ self.norm_topk_prob = config.norm_topk_prob
364
+ self.gating_dim = config.hidden_size
365
+ self.weight = nn.Parameter(
366
+ torch.empty((self.n_routed_experts, self.gating_dim))
367
+ )
368
+ if self.topk_method == "noaux_tc":
369
+ self.e_score_correction_bias = nn.Parameter(
370
+ # Changed from torch.empty to torch.rand to avoid non-even
371
+ # distribution for runs without actual weights
372
+ torch.rand((self.n_routed_experts))
373
+ )
374
+ self.reset_parameters()
375
+
376
+ def reset_parameters(self) -> None:
377
+ import torch.nn.init as init
378
+
379
+ init.kaiming_uniform_(self.weight, a=math.sqrt(5))
380
+
381
+ def forward(self, hidden_states):
382
+ bsz, seq_len, h = hidden_states.shape
383
+ # compute gating score
384
+ hidden_states = hidden_states.view(-1, h)
385
+ logits = F.linear(
386
+ hidden_states.type(torch.float32), self.weight.type(torch.float32), None
387
+ )
388
+ if self.scoring_func == "sigmoid":
389
+ scores = logits.sigmoid()
390
+ elif self.scoring_func == "softmax":
391
+ scores = logits.softmax(dim=-1, dtype=torch.float32)
392
+ else:
393
+ raise NotImplementedError(
394
+ f"insupportable scoring function for MoE gating: {self.scoring_func}"
395
+ )
396
+
397
+ # select top-k experts
398
+ if self.topk_method == "noaux_tc":
399
+ scores_for_choice = scores.view(
400
+ bsz * seq_len, -1
401
+ ) + self.e_score_correction_bias.unsqueeze(0)
402
+ group_scores = (
403
+ scores_for_choice.view(bsz * seq_len, self.n_group, -1)
404
+ .topk(2, dim=-1)[0]
405
+ .sum(dim=-1)
406
+ ) # [n, n_group]
407
+ group_idx = torch.topk(
408
+ group_scores, k=self.topk_group, dim=-1, sorted=False
409
+ )[
410
+ 1
411
+ ] # [n, top_k_group]
412
+ group_mask = torch.zeros_like(group_scores) # [n, n_group]
413
+ group_mask.scatter_(1, group_idx, 1) # [n, n_group]
414
+ score_mask = (
415
+ group_mask.unsqueeze(-1)
416
+ .expand(
417
+ bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
418
+ )
419
+ .reshape(bsz * seq_len, -1)
420
+ ) # [n, e]
421
+ tmp_scores = scores_for_choice.masked_fill(
422
+ ~score_mask.bool(), 0.0
423
+ ) # [n, e]
424
+ _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
425
+ topk_weight = scores.gather(1, topk_idx)
426
+ elif self.topk_method == "greedy":
427
+ topk_weight, topk_idx = torch.topk(
428
+ scores, k=self.top_k, dim=-1, sorted=False
429
+ )
430
+ else:
431
+ raise NotImplementedError(
432
+ f"insupportable TopK function for MoE gating: {self.topk_method}"
433
+ )
434
+
435
+ # norm gate to sum 1
436
+ if self.top_k > 1 and self.norm_topk_prob:
437
+ denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
438
+ topk_weight = topk_weight / denominator
439
+ topk_weight = (
440
+ topk_weight * self.routed_scaling_factor
441
+ ) # must multiply the scaling factor
442
+
443
+ return topk_idx, topk_weight
444
+
445
+
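Whichever scoring/top-k branch the gate takes, its contract is the same: it flattens the batch and returns per-token expert ids and routing weights. A shape-only sketch of that contract (configuration numbers are made up for illustration):

```python
# Shape-only sketch of the MoEGate contract (illustrative numbers, not a real config).
import torch

bsz, seq_len, n_experts, top_k = 2, 4, 16, 2
scores = torch.rand(bsz * seq_len, n_experts)         # stand-in for the gating scores
topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1, sorted=False)

# The gate returns per-(flattened)-token expert ids and routing weights:
assert topk_idx.shape == (bsz * seq_len, top_k)       # int64 expert ids
assert topk_weight.shape == (bsz * seq_len, top_k)    # float routing weights
```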
446
+ class MoE(nn.Module):
447
+ """
448
+ A mixed expert module containing shared experts.
449
+ """
450
+
451
+ # Class attributes:
452
+ # Two shuffle methods are supported:
453
+ # 1. "torch_all_to_all"
454
+ # 2. "symm_mem" (see `setup_symm_mem` below)
455
+ shuffle_method = "torch_all_to_all"
456
+
457
+ # Symmetric memory buffers shared by all MoE instances across layers
458
+ token_send_buf: Optional[torch.Tensor] = None
459
+ token_gather_buf: Optional[torch.Tensor] = None
460
+
461
+ def __init__(self, config):
462
+ super().__init__()
463
+ self.config = config
464
+ self.num_experts_per_tok = config.num_experts_per_tok
465
+
466
+ # ep_size is the number of ranks in expert dimension
467
+ if config.ep_size <= 1:
468
+ raise ValueError(
469
+ "For code simplicity, this model only supports distributed experts, "
470
+ "thus EP size must be > 1, please modify your model config"
471
+ )
472
+ self.ep_group = get_group("ep")
473
+ assert config.ep_size == self.ep_group.size()
474
+ self.ep_size = config.ep_size
475
+ self.ep_rank = self.ep_group.rank()
476
+ self.experts_per_rank = config.n_routed_experts // config.ep_size
477
+ # Use ModuleDict instead of ModuleList to preserve absolute expert
478
+ # IDs while avoiding `None` experts. The absolute expert IDs match
479
+ # with checkpoint FQNs.
480
+ self.experts = nn.ModuleDict()
481
+ for i in range(self.experts_per_rank):
482
+ abs_expert_id = self.ep_rank * self.experts_per_rank + i
483
+ self.experts[str(abs_expert_id)] = MLP(
484
+ config, intermediate_size=config.moe_intermediate_size
485
+ )
486
+ self.gate = MoEGate(config)
487
+ if config.n_shared_experts is not None:
488
+ intermediate_size = config.moe_intermediate_size * config.n_shared_experts
489
+ self.shared_experts = MLP(
490
+ config=config, intermediate_size=intermediate_size
491
+ )
492
+
493
+ def combine_experts(self, submod_name):
494
+ all_weights = []
495
+ for expert in self.experts.values():
496
+ lin = expert.get_submodule(submod_name)
497
+ all_weights.append(lin.weight)
498
+ lin.weight = None
499
+
500
+ concat_weight = torch.cat(all_weights)
501
+ self.register_parameter(f"{submod_name}_weight", nn.Parameter(concat_weight))
502
+
503
+ # This function is used to create a symm mem buffer for MoE's. It is for
504
+ # shuffling tokens fully "on-device", as compared to traditional torch
505
+ # all_to_all APIs which require a GPU-to-CPU sync of the splits. If a user
506
+ # calls this function, the `shuffle_method` would switch from
507
+ # `torch_all_to_all` to `symm_mem`.
508
+ def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
509
+ # Switch shuffle method
510
+ self.shuffle_method = "symm_mem"
511
+
512
+ # Combine expert weights
513
+ print("Combining expert weights for Group GEMM")
514
+ self.combine_experts("gate_proj")
515
+ self.combine_experts("up_proj")
516
+ self.combine_experts("down_proj")
517
+
518
+ # Assuming worst case, 2x tokens are routed to one EP rank
519
+ overflow = 2
520
+ OnDeviceAllToAllV.max_output_len = (
521
+ self.config.max_seq_len * self.num_experts_per_tok * overflow
522
+ )
523
+
524
+ # Symmetric memory buffers are shared by all MoE instances across
525
+ # layers, we only need to initialize them once
526
+ if MoE.token_send_buf is not None:
527
+ return
528
+
529
+ # Input buffer for DP-to-EP shuffle
530
+ MoE.token_send_buf = symm_mem.empty(
531
+ self.config.max_seq_len
532
+ * self.num_experts_per_tok, # seq len * top k (flattened)
533
+ self.config.hidden_size, # hidden dim
534
+ dtype=dtype,
535
+ device=device,
536
+ )
537
+ # Input buffer for EP-to-DP shuffle
538
+ MoE.token_gather_buf = symm_mem.empty(
539
+ self.config.max_seq_len
540
+ * self.num_experts_per_tok # seq len * top k (flattened)
541
+ * overflow,
542
+ self.config.hidden_size, # hidden dim
543
+ dtype=dtype,
544
+ device=device,
545
+ )
546
+ print(f"EP rank [{self.ep_rank}]: Created Symmetric Memory for MoE")
547
+
548
+ def get_send_buf(self):
549
+ # [Why detach?] During a first forward-backward step, the buffer would
550
+ # be included in a computational graph. In a second step, autograd will
551
+ # return an error saying "Trying to backward through the graph a second
552
+ # time (or directly access saved tensors more than once)". This is
553
+ # because the buffer is still in the graph, and autograd is trying to
554
+ # backward through the graph a second time. To avoid this, we detach the
555
+ # buffer from the graph. `detach()` returns a new tensor, which shares
556
+ # the same storage with the original one.
557
+ self.token_send_buf.grad = None
558
+ return self.token_send_buf.detach()
559
+
560
+ def get_gather_buf(self):
561
+ # See [Why detach?] in `get_send_buf`
562
+ self.token_gather_buf.grad = None
563
+ return self.token_gather_buf.detach()
564
+
565
+ def forward(self, hidden_states):
566
+ identity = hidden_states
567
+ orig_shape = hidden_states.shape
568
+ # for each token, select top-k experts, and compute the weight for each expert
569
+ topk_idx, topk_weight = self.gate(hidden_states)
570
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
571
+ if self.shuffle_method == "symm_mem":
572
+ y = self.moe_on_device(hidden_states, topk_idx, topk_weight)
573
+ else: # "torch_all_to_all"
574
+ y = self.moe_forward(hidden_states, topk_idx, topk_weight)
575
+
576
+ y = y.view(*orig_shape)
577
+ if self.config.n_shared_experts is not None:
578
+ y = y + self.shared_experts(identity)
579
+ return y
580
+
581
+ def moe_forward(self, x, topk_ids, topk_weight):
582
+ # This part sorts the token indices so that tokens routed to the same expert reside consecutively.
583
+ # An implication is that tokens to the same "expert group" (i.e., device) are also consecutive.
584
+ # Since this is an "aritificial" index creation (final outcome being
585
+ # `idxs`), we don't need gradients here.
586
+ with torch.no_grad():
587
+ # [seq_len, n_routed_experts]
588
+ cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts))
589
+ # Fill 1 to the selected experts
590
+ cnts.scatter_(1, topk_ids, 1)
591
+ tokens_per_expert = cnts.sum(dim=0)
592
+ # Token indices for each expert
593
+ idxs = topk_ids.view(-1).argsort()
594
+ sorted_tokens_shape = idxs.shape + x.shape[1:]
595
+
596
+ sorted_tokens = x[idxs // topk_ids.shape[1]]
597
+ assert sorted_tokens.shape == sorted_tokens_shape
598
+
599
+ # This part exchanges information about the number of tokens sent and
600
+ # received by each expert. We can understand this information as "side
601
+ # band", which is not part of the actual data. Thus no gradient is
602
+ # needed.
603
+ with torch.no_grad():
604
+ # Sum the tokens over local experts, then we get tokens per EP rank,
605
+ # which is the input splits
606
+ tokens_per_expert_group = tokens_per_expert.new_empty(
607
+ tokens_per_expert.shape[0]
608
+ )
609
+ dist.all_to_all_single(
610
+ tokens_per_expert_group, tokens_per_expert, group=self.ep_group
611
+ )
612
+ input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
613
+
614
+ # DP to EP token shuffle. This part needs gradient.
615
+ if self.shuffle_method == "symm_mem":
616
+ # Move input to the `token_send_buf` symm mem
617
+ token_send_buf = self.get_send_buf()
618
+ token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
619
+ # Note: `out=` avoids copy, but it is not differentiable
620
+ # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
621
+ token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
622
+ token_send_buf,
623
+ input_splits,
624
+ self.ep_group,
625
+ )
626
+ with torch.no_grad():
627
+ # Received tokens from all other ranks. TODO: use mask instead
628
+ received = output_splits.sum()
629
+ # TODO: don't use `received`
630
+ gathered_tokens = token_gather_buf[:received]
631
+ else: # "torch_all_to_all"
632
+ # Prepare input and output splits
633
+ with torch.no_grad():
634
+ output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(
635
+ dim=1
636
+ )
637
+ gathered_tokens = all_to_all_single_autograd(
638
+ sorted_tokens,
639
+ output_splits.tolist(),
640
+ input_splits.tolist(),
641
+ self.ep_group,
642
+ )
643
+
644
+ # This part prepares a 1D tensor with the same length as
645
+ # `gathered_tokens`. The 1D tensor is filled with local expert IDs which
646
+ # the tokens in `gathered_tokens` are headed for. This part doesn't need
647
+ # gradient.
648
+ with torch.no_grad():
649
+ gatherd_idxs = (
650
+ torch.arange(
651
+ tokens_per_expert_group.numel(),
652
+ device=tokens_per_expert_group.device,
653
+ )
654
+ % self.experts_per_rank
655
+ )
656
+ gatherd_idxs = gatherd_idxs.repeat_interleave(tokens_per_expert_group)
657
+
658
+ # Prepare buffer for tokens processed by experts
659
+ if self.shuffle_method == "symm_mem":
660
+ # Take necessary space from `token_gather_buf` symm mem because we are
661
+ # going to send them out after expert processing
662
+ processed_tokens = self.get_gather_buf()[: gathered_tokens.shape[0]]
663
+ else: # "torch_all_to_all"
664
+ processed_tokens = torch.empty_like(gathered_tokens)
665
+
666
+ # This part processes the tokens routed to the local experts.
667
+ # TODO: can we use group GEMM here?
668
+ for i, expert in enumerate(self.experts.values()):
669
+ processed_tokens[gatherd_idxs == i] = expert(
670
+ gathered_tokens[gatherd_idxs == i]
671
+ )
672
+
673
+ # Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
674
+ # The input/output splits are just a reverse of the previous shuffle.
675
+ if self.shuffle_method == "symm_mem":
676
+ token_return_buf, _ = OnDeviceAllToAllV.apply(
677
+ processed_tokens,
678
+ output_splits,
679
+ self.ep_group,
680
+ )
681
+ returned_tokens = token_return_buf[: sorted_tokens_shape[0]]
682
+ else: # "torch_all_to_all"
683
+ returned_tokens = all_to_all_single_autograd(
684
+ processed_tokens,
685
+ input_splits.tolist(),
686
+ output_splits.tolist(),
687
+ self.ep_group,
688
+ )
689
+
690
+ output_tokens = torch.empty_like(returned_tokens)
691
+ output_tokens[idxs] = returned_tokens
692
+ final_out = (
693
+ output_tokens.view(*topk_ids.shape, -1)
694
+ .type(topk_weight.dtype)
695
+ .mul_(topk_weight.unsqueeze(dim=-1))
696
+ .sum(dim=1)
697
+ .type(returned_tokens.dtype)
698
+ )
699
+ return final_out
700
+
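A tiny worked example of the index bookkeeping shared by `moe_forward` above and `moe_on_device` below (3 tokens, top_k = 2, 4 routed experts; tie order inside `argsort` may differ, which is fine since only the grouping matters):

```python
# Worked example of the routing-index bookkeeping (illustrative values).
import torch

topk_ids = torch.tensor([[2, 0], [1, 2], [3, 1]])    # 3 tokens, top_k = 2
n_routed_experts, top_k = 4, topk_ids.shape[1]

cnts = topk_ids.new_zeros((topk_ids.shape[0], n_routed_experts))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)                   # tensor([1, 2, 2, 1])

idxs = topk_ids.view(-1).argsort()                    # e.g. [1, 2, 5, 0, 3, 4]
src_token = idxs // top_k                             # e.g. [0, 1, 2, 0, 1, 2]
# `x[src_token]` duplicates and groups the tokens by destination expert:
# expert 0 <- token 0, expert 1 <- tokens 1 and 2,
# expert 2 <- tokens 0 and 1, expert 3 <- token 2.
```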
701
+ def moe_on_device(self, x, topk_ids, topk_weight):
702
+ # This part sorts the token indices so that tokens routed to the same expert reside consecutively.
703
+ # An implication is that tokens to the same "expert group" (i.e., device) are also consecutive.
704
+ # Since this is an "aritificial" index creation (final outcome being
705
+ # `idxs`), we don't need gradients here.
706
+ with torch.no_grad():
707
+ # [seq_len, n_routed_experts]
708
+ cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts))
709
+ # Fill 1 to the selected experts
710
+ cnts.scatter_(1, topk_ids, 1)
711
+ tokens_per_expert = cnts.sum(dim=0)
712
+ # Token indices for each expert
713
+ idxs = topk_ids.view(-1).argsort()
714
+ sorted_tokens_shape = idxs.shape + x.shape[1:]
715
+
716
+ sorted_tokens = x[idxs // topk_ids.shape[1]]
717
+ assert sorted_tokens.shape == sorted_tokens_shape
718
+
719
+ # This part exchanges information about the number of tokens sent and
720
+ # received by each expert. We can understand this information as "side
721
+ # band", which is not part of the actual data. Thus no gradient is
722
+ # needed.
723
+ with torch.no_grad():
724
+ # Sum the tokens over local experts, then we get tokens per EP rank,
725
+ # which is the input splits
726
+ tokens_per_expert_group = tokens_per_expert.new_empty(
727
+ tokens_per_expert.shape[0]
728
+ )
729
+ dist.all_to_all_single(
730
+ tokens_per_expert_group, tokens_per_expert, group=self.ep_group
731
+ )
732
+ input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
733
+
734
+ # Move input to the `token_send_buf` symm mem
735
+ token_send_buf = self.get_send_buf()
736
+ token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
737
+ # Note: `out=` avoids copy, but it is not differentiable
738
+ # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
739
+ token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
740
+ token_send_buf,
741
+ input_splits,
742
+ self.ep_group,
743
+ )
744
+
745
+ # We need to permute the received tokens so that tokens for the same expert are contiguous.
746
+ # This part prepares a 1D tensor `permuted_indices` for such permutation.
747
+ # This part doesn't need gradient.
748
+ with torch.no_grad():
749
+ permuted_indices, m_sizes = generate_permute_indices(
750
+ tokens_per_expert_group,
751
+ self.experts_per_rank,
752
+ self.ep_size,
753
+ token_gather_buf.shape[0],
754
+ ALIGN_SIZE_M,
755
+ )
756
+
757
+ # Permute the received tokens so that tokens for the same expert are contiguous.
758
+ contig_tokens = token_gather_buf[permuted_indices]
759
+
760
+ # Run the first grouped GEMM
761
+ w1 = self.get_parameter("gate_proj_weight")
762
+ gate_proj = grouped_gemm_forward(contig_tokens, w1, m_sizes)
763
+
764
+ # Run the second grouped GEMM
765
+ w3 = self.get_parameter("up_proj_weight")
766
+ up_proj = grouped_gemm_forward(contig_tokens, w3, m_sizes)
767
+
768
+ # Apply activation
769
+ hidden_outputs = MLP.act_fn(gate_proj) * up_proj
770
+
771
+ # Run the third grouped GEMM
772
+ w2 = self.get_parameter("down_proj_weight")
773
+ hidden_outputs = grouped_gemm_forward(hidden_outputs, w2, m_sizes)
774
+
775
+ # Prepare buffer for tokens processed by experts
776
+ # Take necessary space from `token_gather_buf` symm mem because we are
777
+ # going to send them out after expert processing
778
+ processed_tokens = self.get_gather_buf()
779
+
780
+ # Move into Symmetric Memory for the return shuffle
781
+ processed_tokens[permuted_indices] = hidden_outputs
782
+
783
+ # Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
784
+ # The input/output splits are just a reverse of the previous shuffle.
785
+ token_return_buf, _ = OnDeviceAllToAllV.apply(
786
+ processed_tokens,
787
+ output_splits,
788
+ self.ep_group,
789
+ )
790
+ returned_tokens = token_return_buf[: sorted_tokens_shape[0]]
791
+
792
+ output_tokens = torch.empty_like(returned_tokens)
793
+ output_tokens[idxs] = returned_tokens
794
+ final_out = (
795
+ output_tokens.view(*topk_ids.shape, -1)
796
+ .type(topk_weight.dtype)
797
+ .mul_(topk_weight.unsqueeze(dim=-1))
798
+ .sum(dim=1)
799
+ .type(returned_tokens.dtype)
800
+ )
801
+ return final_out
802
+
803
+
804
+ class Attention(nn.Module):
805
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
806
+
807
+ def __init__(self, config: ModelArgs, layer_idx: Optional[int] = None):
808
+ super().__init__()
809
+ self.config = config
810
+ self.layer_idx = layer_idx
811
+ self.attention_dropout = config.attention_dropout
812
+ self.hidden_size = config.hidden_size
813
+ self.num_heads = config.num_attention_heads
814
+
815
+ self.max_position_embeddings = config.max_position_embeddings
816
+ self.rope_theta = config.rope_theta
817
+ self.q_lora_rank = config.q_lora_rank
818
+ self.qk_rope_head_dim = config.qk_rope_head_dim
819
+ self.kv_lora_rank = config.kv_lora_rank
820
+ self.v_head_dim = config.v_head_dim
821
+ self.qk_nope_head_dim = config.qk_nope_head_dim
822
+ self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
823
+
824
+ self.is_causal = True
825
+
826
+ if self.q_lora_rank is None:
827
+ self.q_proj = nn.Linear(
828
+ self.hidden_size, self.num_heads * self.q_head_dim, bias=False
829
+ )
830
+ else:
831
+ self.q_a_proj = nn.Linear(
832
+ self.hidden_size, config.q_lora_rank, bias=config.attention_bias
833
+ )
834
+ self.q_a_layernorm = RMSNorm(config.q_lora_rank)
835
+ self.q_b_proj = nn.Linear(
836
+ config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
837
+ )
838
+
839
+ self.kv_a_proj_with_mqa = nn.Linear(
840
+ self.hidden_size,
841
+ config.kv_lora_rank + config.qk_rope_head_dim,
842
+ bias=config.attention_bias,
843
+ )
844
+ self.kv_a_layernorm = RMSNorm(config.kv_lora_rank)
845
+ self.kv_b_proj = nn.Linear(
846
+ config.kv_lora_rank,
847
+ self.num_heads
848
+ * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
849
+ bias=False,
850
+ )
851
+
852
+ self.o_proj = nn.Linear(
853
+ self.num_heads * self.v_head_dim,
854
+ self.hidden_size,
855
+ bias=config.attention_bias,
856
+ )
857
+ self._init_rope()
858
+
859
+ self.softmax_scale = self.q_head_dim ** (-0.5)
860
+ if self.config.rope_scaling is not None:
861
+ mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
862
+ scaling_factor = self.config.rope_scaling["factor"]
863
+ if mscale_all_dim:
864
+ mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
865
+ self.softmax_scale = self.softmax_scale * mscale * mscale
866
+
867
+ def _init_rope(self):
868
+ if self.config.rope_scaling is None:
869
+ self.rotary_emb = RotaryEmbedding(
870
+ self.qk_rope_head_dim,
871
+ max_position_embeddings=self.max_position_embeddings,
872
+ base=self.rope_theta,
873
+ )
874
+ else:
875
+ scaling_type = self.config.rope_scaling["type"]
876
+ scaling_factor = self.config.rope_scaling["factor"]
877
+ if scaling_type == "linear":
878
+ self.rotary_emb = LinearScalingRotaryEmbedding(
879
+ self.qk_rope_head_dim,
880
+ max_position_embeddings=self.max_position_embeddings,
881
+ scaling_factor=scaling_factor,
882
+ base=self.rope_theta,
883
+ )
884
+ elif scaling_type == "dynamic":
885
+ self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
886
+ self.qk_rope_head_dim,
887
+ max_position_embeddings=self.max_position_embeddings,
888
+ scaling_factor=scaling_factor,
889
+ base=self.rope_theta,
890
+ )
891
+ elif scaling_type == "yarn":
892
+ kwargs = {
893
+ key: self.config.rope_scaling[key]
894
+ for key in [
895
+ "original_max_position_embeddings",
896
+ "beta_fast",
897
+ "beta_slow",
898
+ "mscale",
899
+ "mscale_all_dim",
900
+ ]
901
+ if key in self.config.rope_scaling
902
+ }
903
+ self.rotary_emb = YarnRotaryEmbedding(
904
+ self.qk_rope_head_dim,
905
+ max_position_embeddings=self.max_position_embeddings,
906
+ scaling_factor=scaling_factor,
907
+ base=self.rope_theta,
908
+ **kwargs,
909
+ )
910
+ else:
911
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
912
+
913
+ def forward(
914
+ self,
915
+ hidden_states: torch.Tensor,
916
+ attention_mask: Optional[torch.Tensor] = None,
917
+ position_ids: Optional[torch.LongTensor] = None,
918
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
919
+ bsz, q_len, _ = hidden_states.size()
920
+
921
+ if self.q_lora_rank is None:
922
+ q = self.q_proj(hidden_states)
923
+ else:
924
+ q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
925
+ q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
926
+ q_nope, q_pe = torch.split(
927
+ q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
928
+ )
929
+
930
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
931
+ compressed_kv, k_pe = torch.split(
932
+ compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
933
+ )
934
+ k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
935
+ kv = (
936
+ self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
937
+ .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
938
+ .transpose(1, 2)
939
+ )
940
+
941
+ k_nope, value_states = torch.split(
942
+ kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
943
+ )
944
+ kv_seq_len = value_states.shape[-2]
945
+
946
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
947
+
948
+ q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
949
+
950
+ query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
951
+ query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
952
+ query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
953
+
954
+ key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
955
+ key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
956
+ key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
957
+
958
+ if attention_mask is not None:
959
+ # Attention mask was made 4D because the `attn_weights` above is 4D.
960
+ # We probably can make this mask smarter if we want to pack sequences
961
+ # together, instead of using padding. This optimization can be used in
962
+ # inference. For training, if we want to pack sequences, data loader
963
+ # will pass in a mask containing such info.
964
+ attention_mask = _prepare_4d_causal_attention_mask(
965
+ attention_mask, # None, or user provided mask in 2D
966
+ (bsz, q_len),
967
+ hidden_states,
968
+ 0, # past_key_values_length, 0 when training
969
+ )
970
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
971
+ raise ValueError(
972
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
973
+ )
974
+
975
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
976
+ query=query_states,
977
+ key=key_states,
978
+ value=value_states,
979
+ attn_mask=attention_mask,
980
+ dropout_p=self.attention_dropout,
981
+ is_causal=attention_mask is None,
982
+ scale=self.softmax_scale,
983
+ )
984
+
985
+ attn_output = attn_output.transpose(1, 2).contiguous()
986
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
987
+ attn_output = self.o_proj(attn_output)
988
+
989
+ return attn_output
990
+
991
+
992
+ class DecoderLayer(nn.Module):
993
+ def __init__(self, config: ModelArgs, layer_idx: int):
994
+ super().__init__()
995
+ self.hidden_size = config.hidden_size
996
+
997
+ self.self_attn = Attention(config=config, layer_idx=layer_idx)
998
+
999
+ self.mlp = (
1000
+ MoE(config)
1001
+ if (
1002
+ config.n_routed_experts is not None
1003
+ and layer_idx >= config.first_k_dense_replace
1004
+ and layer_idx % config.moe_layer_freq == 0
1005
+ )
1006
+ else MLP(config)
1007
+ )
1008
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1009
+ self.post_attention_layernorm = RMSNorm(
1010
+ config.hidden_size, eps=config.rms_norm_eps
1011
+ )
1012
+
1013
+ def forward(
1014
+ self,
1015
+ hidden_states: torch.Tensor,
1016
+ attention_mask: Optional[torch.Tensor] = None,
1017
+ position_ids: Optional[torch.LongTensor] = None,
1018
+ ) -> torch.Tensor:
1019
+ """
1020
+ Args:
1021
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1022
+ attention_mask (`torch.FloatTensor`, *optional*):
1023
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
1024
+ query_sequence_length, key_sequence_length)` if default attention is used.
1025
+ """
1026
+ residual = hidden_states
1027
+
1028
+ hidden_states = self.input_layernorm(hidden_states)
1029
+
1030
+ # Self Attention
1031
+ hidden_states = self.self_attn(
1032
+ hidden_states=hidden_states,
1033
+ attention_mask=attention_mask,
1034
+ position_ids=position_ids,
1035
+ )
1036
+ hidden_states = residual + hidden_states
1037
+
1038
+ # Fully Connected
1039
+ residual = hidden_states
1040
+ hidden_states = self.post_attention_layernorm(hidden_states)
1041
+ hidden_states = self.mlp(hidden_states)
1042
+ hidden_states = residual + hidden_states
1043
+
1044
+ return hidden_states
1045
+
1046
+
1047
+ Deepseek_INPUTS_DOCSTRING = r"""
1048
+ Args:
1049
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1050
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1051
+ it.
1052
+
1053
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1054
+ [`PreTrainedTokenizer.__call__`] for details.
1055
+
1056
+ [What are input IDs?](../glossary#input-ids)
1057
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1058
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1059
+
1060
+ - 1 for tokens that are **not masked**,
1061
+ - 0 for tokens that are **masked**.
1062
+
1063
+ [What are attention masks?](../glossary#attention-mask)
1064
+
1065
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1066
+ [`PreTrainedTokenizer.__call__`] for details.
1067
+
1068
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1069
+ `past_key_values`).
1070
+
1071
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1072
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1073
+ information on the default strategy.
1074
+
1075
+ - 1 indicates the head is **not masked**,
1076
+ - 0 indicates the head is **masked**.
1077
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1078
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1079
+ config.n_positions - 1]`.
1080
+
1081
+ [What are position IDs?](../glossary#position-ids)
1082
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1083
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1084
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
1085
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1086
+
1087
+ Two formats are allowed:
1088
+ - a [`~cache_utils.Cache`] instance;
1089
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1090
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1091
+ cache format.
1092
+
1093
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1094
+ legacy cache format will be returned.
1095
+
1096
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1097
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1098
+ of shape `(batch_size, sequence_length)`.
1099
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1100
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1101
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1102
+ model's internal embedding lookup matrix.
1103
+ use_cache (`bool`, *optional*):
1104
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1105
+ `past_key_values`).
1106
+ output_attentions (`bool`, *optional*):
1107
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1108
+ tensors for more detail.
1109
+ output_hidden_states (`bool`, *optional*):
1110
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1111
+ more detail.
1112
+ return_dict (`bool`, *optional*):
1113
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1114
+ """
1115
+
1116
+
1117
+ class DeepseekModel(torch.nn.Module):
1118
+ """
1119
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DecoderLayer`]
1120
+
1121
+ Args:
1122
+ config: ModelArgs
1123
+ """
1124
+
1125
+ def __init__(self, config: ModelArgs):
1126
+ super().__init__()
1127
+ self.config = config
1128
+ self.padding_idx = config.pad_token_id
1129
+ self.vocab_size = config.vocab_size
1130
+
1131
+ # Creating model parts related to my stage
1132
+ assert (
1133
+ config.stage_idx < config.num_stages
1134
+ ), f"Stage {config.stage_idx} is not in the model"
1135
+ print(f"Creating model stage {config.stage_idx} of {config.num_stages}")
1136
+
1137
+ self.embed_tokens = (
1138
+ nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1139
+ if config.stage_idx == 0
1140
+ else None
1141
+ )
1142
+
1143
+ self.layers = torch.nn.ModuleDict()
1144
+ division = config.num_hidden_layers // config.num_stages
1145
+ residual = config.num_hidden_layers % config.num_stages
1146
+ # Some earlier stages may have 1 more layer than latter stages because
1147
+ # the division may have residual; this is more even than giving the
1148
+ # entire residual to the last stage.
1149
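+ # For example, 16 layers over 3 stages gives division=5, residual=1, i.e. [6, 5, 5].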
+ layers_per_stage = [
1150
+ division + 1 if stage < residual else division
1151
+ for stage in range(config.num_stages)
1152
+ ]
1153
+ assert sum(layers_per_stage) == config.num_hidden_layers
1154
+ layer_id_start = sum(layers_per_stage[: config.stage_idx])
1155
+ layer_id_end = layer_id_start + layers_per_stage[config.stage_idx]
1156
+ for layer_id in range(layer_id_start, layer_id_end):
1157
+ self.layers[str(layer_id)] = DecoderLayer(config, layer_id)
1158
+
1159
+ self.norm = (
1160
+ RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1161
+ if config.stage_idx == config.num_stages - 1
1162
+ else None
1163
+ )
1164
+
1165
+ # Initialize weights and apply final processing
1166
+ self.apply(self._init_weights)
1167
+
1168
+ def _init_weights(self, module):
1169
+ std = self.config.initializer_range
1170
+ if isinstance(module, nn.Linear):
1171
+ module.weight.data.normal_(mean=0.0, std=std)
1172
+ if module.bias is not None:
1173
+ module.bias.data.zero_()
1174
+ elif isinstance(module, nn.Embedding):
1175
+ module.weight.data.normal_(mean=0.0, std=std)
1176
+ if module.padding_idx is not None:
1177
+ module.weight.data[module.padding_idx].zero_()
1178
+
1179
+ def forward(
1180
+ self,
1181
+ tokens: torch.Tensor,
1182
+ attention_mask: Optional[torch.Tensor] = None,
1183
+ position_ids: Optional[torch.LongTensor] = None,
1184
+ ) -> torch.Tensor:
1185
+ # Embedding
1186
+ hidden_states = (
1187
+ self.embed_tokens(tokens) if self.embed_tokens is not None else tokens
1188
+ )
1189
+
1190
+ # decoder layers
1191
+ for decoder_layer in self.layers.values():
1192
+ hidden_states = decoder_layer(
1193
+ hidden_states,
1194
+ attention_mask=attention_mask,
1195
+ position_ids=position_ids,
1196
+ )
1197
+
1198
+ hidden_states = (
1199
+ self.norm(hidden_states) if self.norm is not None else hidden_states
1200
+ )
1201
+ return hidden_states
1202
+
1203
+
1204
+ class DeepseekForCausalLM(torch.nn.Module):
1205
+ def __init__(self, config):
1206
+ super().__init__()
1207
+ self.model = DeepseekModel(config)
1208
+ self.lm_head = (
1209
+ nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1210
+ if config.stage_idx == config.num_stages - 1
1211
+ else None
1212
+ )
1213
+
1214
+ # Initialize weights and apply final processing
1215
+ # self.post_init()
1216
+
1217
+ def forward(
1218
+ self,
1219
+ tokens: torch.Tensor,
1220
+ attention_mask: Optional[torch.Tensor] = None,
1221
+ position_ids: Optional[torch.LongTensor] = None,
1222
+ ) -> torch.Tensor:
1223
+ r"""
1224
+ Example:
1225
+
1226
+ ```python
1227
+ >>> from transformers import AutoTokenizer, DeepseekForCausalLM
1228
+
1229
+ >>> model = DeepseekForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1230
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1231
+
1232
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1233
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1234
+
1235
+ >>> # Generate
1236
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1237
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1238
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1239
+ ```"""
1240
+ hidden_states = self.model(
1241
+ tokens,
1242
+ attention_mask=attention_mask,
1243
+ position_ids=position_ids,
1244
+ )
1245
+
1246
+ logits = (
1247
+ self.lm_head(hidden_states) if self.lm_head is not None else hidden_states
1248
+ )
1249
+ return logits
1250
+
1251
+ def prepare_inputs_for_generation(
1252
+ self,
1253
+ input_ids,
1254
+ past_key_values=None,
1255
+ attention_mask=None,
1256
+ **kwargs,
1257
+ ):
1258
+ if past_key_values is not None:
1259
+ # Assuming isinstance(past_key_values, Cache):
1260
+ cache_length = past_key_values.get_seq_length()
1261
+ past_length = past_key_values.seen_tokens
1262
+ max_cache_length = past_key_values.get_max_length()
1263
+
1264
+ # Keep only the unprocessed tokens:
1265
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1266
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1267
+ # input)
1268
+ if (
1269
+ attention_mask is not None
1270
+ and attention_mask.shape[1] > input_ids.shape[1]
1271
+ ):
1272
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1273
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1274
+ # input_ids based on the past_length.
1275
+ elif past_length < input_ids.shape[1]:
1276
+ input_ids = input_ids[:, past_length:]
1277
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1278
+
1279
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1280
+ if (
1281
+ max_cache_length is not None
1282
+ and attention_mask is not None
1283
+ and cache_length + input_ids.shape[1] > max_cache_length
1284
+ ):
1285
+ attention_mask = attention_mask[:, -max_cache_length:]
1286
+
1287
+ position_ids = kwargs.get("position_ids", None)
1288
+ if attention_mask is not None and position_ids is None:
1289
+ # create position_ids on the fly for batch generation
1290
+ position_ids = attention_mask.long().cumsum(-1) - 1
1291
+ position_ids.masked_fill_(attention_mask == 0, 1)
1292
+ if past_key_values:
1293
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1294
+
1295
+ model_inputs = {"input_ids": input_ids}
1296
+
1297
+ model_inputs.update(
1298
+ {
1299
+ "position_ids": position_ids,
1300
+ "past_key_values": past_key_values,
1301
+ "use_cache": kwargs.get("use_cache"),
1302
+ "attention_mask": attention_mask,
1303
+ }
1304
+ )
1305
+ return model_inputs
1306
+
1307
+ @staticmethod
1308
+ def _reorder_cache(past_key_values, beam_idx):
1309
+ reordered_past = ()
1310
+ for layer_past in past_key_values:
1311
+ reordered_past += (
1312
+ tuple(
1313
+ past_state.index_select(0, beam_idx.to(past_state.device))
1314
+ for past_state in layer_past
1315
+ ),
1316
+ )
1317
+ return reordered_past
1318
+
1319
+ # Setup Symmetric Memory for MoE token shuffle.
1320
+ # Supports inference currently.
1321
+ def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
1322
+ for layer in self.model.layers.values():
1323
+ if not isinstance(layer.mlp, MoE):
1324
+ continue
1325
+ layer.mlp.setup_symm_mem(dtype, device)
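A hedged usage note (not part of the diff): symmetric memory for the MoE token shuffle has to be enabled on the assembled model before running inference; `train.py` below leaves the call commented out because training support is still in progress. A minimal sketch, assuming `config` and a CUDA `device` already exist:

```python
# Minimal sketch: enable symmetric-memory MoE token shuffle (inference-only for now).
model = DeepseekForCausalLM(config)
model.setup_symm_mem(torch.bfloat16, device)
```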
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from .triton_utils import get_flat_bid, get_flat_tid
11
+
12
+
13
+ @triton.jit
14
+ def send_signal(addrs, sem: tl.constexpr):
15
+ if sem == "relaxed":
16
+ tl.inline_asm_elementwise(
17
+ """
18
+ {
19
+ .reg .u32 %tmp32_<1>;
20
+ .reg .pred %p<1>;
21
+
22
+ send_signal:
23
+ atom.global.relaxed.sys.cas.b32 %tmp32_0, [$1], 0, 1;
24
+ setp.eq.u32 %p0, %tmp32_0, 0;
25
+ @!%p0 bra send_signal;
26
+ }
27
+ """,
28
+ "=r, l",
29
+ [addrs],
30
+ dtype=tl.int32,
31
+ is_pure=False,
32
+ pack=1,
33
+ )
34
+ elif sem == "acq_rel":
35
+ tl.inline_asm_elementwise(
36
+ """
37
+ {
38
+ .reg .u32 %tmp32_<1>;
39
+ .reg .pred %p<1>;
40
+
41
+ send_signal:
42
+ atom.global.release.sys.cas.b32 %tmp32_0, [$1], 0, 1;
43
+ setp.eq.u32 %p0, %tmp32_0, 0;
44
+ @!%p0 bra send_signal;
45
+ }
46
+ """,
47
+ "=r, l",
48
+ [addrs],
49
+ dtype=tl.int32,
50
+ is_pure=False,
51
+ pack=1,
52
+ )
53
+ else:
54
+ raise RuntimeError(f"Unrecognized sem: {sem}")
55
+
56
+
57
+ @triton.jit
58
+ def wait_signal(addrs, sem: tl.constexpr):
59
+ if sem == "relaxed":
60
+ tl.inline_asm_elementwise(
61
+ """
62
+ {
63
+ .reg .u32 %tmp32_<1>;
64
+ .reg .pred %p<1>;
65
+
66
+ wait_signal:
67
+ atom.global.sys.relaxed.cas.b32 %tmp32_0, [$1], 1, 0;
68
+ setp.eq.u32 %p0, %tmp32_0, 1;
69
+ @!%p0 bra wait_signal;
70
+ }
71
+ """,
72
+ "=r, l",
73
+ [addrs],
74
+ dtype=tl.int32,
75
+ is_pure=False,
76
+ pack=1,
77
+ )
78
+ elif sem == "acq_rel":
79
+ tl.inline_asm_elementwise(
80
+ """
81
+ {
82
+ .reg .u32 %tmp32_<1>;
83
+ .reg .pred %p<1>;
84
+
85
+ wait_signal:
86
+ atom.global.sys.acquire.cas.b32 %tmp32_0, [$1], 1, 0;
87
+ setp.eq.u32 %p0, %tmp32_0, 1;
88
+ @!%p0 bra wait_signal;
89
+ }
90
+ """,
91
+ "=r, l",
92
+ [addrs],
93
+ dtype=tl.int32,
94
+ is_pure=False,
95
+ pack=1,
96
+ )
97
+ else:
98
+ raise RuntimeError(f"Unrecognized sem: {sem}")
99
+
100
+
101
+ @triton.jit
102
+ def blockwise_barrier(
103
+ signal_pad_ptrs,
104
+ block_id,
105
+ rank: tl.constexpr,
106
+ world_size: tl.constexpr,
107
+ sem: tl.constexpr,
108
+ ):
109
+ """
110
+ Synchronizes blocks with matching block_id across participating devices.
111
+
112
+ Note: the function itself is not a system level barrier/fence. It is a
113
+ building block for expressing different synchronization patterns.
114
+
115
+ Pattern 0: Ensures that all writes to symm_mem buffers from previous
116
+ kernels across all devices are visible to the current kernel:
117
+
118
+ blockwise_barrier(..., sem="relaxed")
119
+ sync_threads()
120
+
121
+ Pattern 1: Ensures that all writes to symm_mem buffers from the current
122
+ block are visible to all remote blocks with matching blockIdx:
123
+
124
+ sync_threads()
125
+ blockwise_barrier(..., sem="acq_rel")
126
+ sync_threads()
127
+
128
+ Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
129
+ for writing by subsequent kernels across all devices.
130
+
131
+ sync_threads()
132
+ blockwise_barrier(..., sem="relaxed")
133
+
134
+ CUDA graph friendliness:
135
+
136
+ This barrier operates through atomic operations on a zero-filled signal
137
+ pad, which resets to a zero-filled state after each successful
138
+ synchronization. This design eliminates the need for incrementing a
139
+ flag from host.
140
+ """
141
+ if block_id is None:
142
+ block_id = get_flat_bid()
143
+ flat_tid = get_flat_tid()
144
+
145
+ remote_ranks = tl.arange(0, world_size)
146
+ signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64))
147
+ remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to(
148
+ tl.pointer_type(tl.uint32)
149
+ )
150
+ send_addrs = remote_signal_pad_addrs + block_id * world_size + rank
151
+
152
+ local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to(
153
+ tl.pointer_type(tl.uint32)
154
+ )
155
+ wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks
156
+
157
+ if flat_tid < world_size:
158
+ send_signal(send_addrs, sem)
159
+ wait_signal(wait_addrs, sem)
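A minimal sketch (not part of the diff) of "Pattern 0" from the docstring above, assuming `blockwise_barrier` and `sync_threads` are imported from this package; the kernel name and buffer arguments are hypothetical. `on_device_all_to_all_v_kernel` in the next file uses the same idiom.

```python
import triton
import triton.language as tl

@triton.jit
def read_peers_kernel(buf_ptrs, signal_pad_ptrs, rank: tl.constexpr, world_size: tl.constexpr):
    # Pattern 0: make writes from previous kernels on all devices visible to this kernel.
    blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
    sync_threads()
    # ... loads from peer buffers via buf_ptrs would go here ...
```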
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py ADDED
@@ -0,0 +1,260 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ import torch.distributed._symmetric_memory as symm_mem
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from .triton_barrier import blockwise_barrier
14
+ from .triton_utils import sync_threads
15
+
16
+
17
+ @triton.jit
18
+ def _exchange_row_offsets(
19
+ split_sizes_ptrs,
20
+ rank: tl.constexpr,
21
+ world_size: tl.constexpr,
22
+ BLOCKS_PER_REMOTE_RANK: tl.constexpr,
23
+ ):
24
+ remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
25
+
26
+ # split_sizes_ptr for all ranks
27
+ # All these vector stacks into split_sizes_matrix
28
+ split_sizes_ptrs = split_sizes_ptrs.to(tl.pointer_type(tl.uint64))
29
+
30
+ # split_sizes_matrix[remote_rank, :]
31
+ input_split_sizes_ptr = tl.load(split_sizes_ptrs + remote_rank).to(
32
+ tl.pointer_type(tl.int64)
33
+ )
34
+
35
+ offsets_ = tl.arange(0, world_size)
36
+ input_split_sizes = tl.load(
37
+ input_split_sizes_ptr + offsets_, mask=offsets_ <= rank, other=0
38
+ )
39
+
40
+ num_rows = tl.load(input_split_sizes_ptr + rank)
41
+ input_row_offset = tl.sum(input_split_sizes) - num_rows
42
+
43
+ # split_sizes_matrix[:, rank]
44
+ output_split_sizes_ptrs = (
45
+ tl.load(split_sizes_ptrs + offsets_).to(tl.pointer_type(tl.int64)) + rank
46
+ )
47
+ output_split_sizes = tl.load(
48
+ output_split_sizes_ptrs, mask=offsets_ <= remote_rank, other=0
49
+ )
50
+ output_row_offset = tl.sum(output_split_sizes) - num_rows
51
+
52
+ return input_row_offset, output_row_offset, num_rows
53
+
54
+
55
+ @triton.jit
56
+ def on_device_all_to_all_v_kernel(
57
+ output_ptr,
58
+ output_splits_ptr,
59
+ input_ptrs,
60
+ input_splits_ptr,
61
+ signal_pad_ptrs,
62
+ dim: tl.constexpr, # Separate dim for easier vectorization
63
+ rank: tl.constexpr,
64
+ world_size: tl.constexpr,
65
+ BLOCKS_PER_REMOTE_RANK: tl.constexpr,
66
+ UNROLL_FACTOR: tl.constexpr,
67
+ BLOCK_SIZE: tl.constexpr,
68
+ ):
69
+ blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
70
+ sync_threads()
71
+
72
+ remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
73
+ block_offset = tl.program_id(0) % BLOCKS_PER_REMOTE_RANK
74
+
75
+ input_row_offset, output_row_offset, num_rows = _exchange_row_offsets(
76
+ input_splits_ptr, rank, world_size, BLOCKS_PER_REMOTE_RANK
77
+ )
78
+
79
+ output_splits_ptr = output_splits_ptr.to(tl.pointer_type(tl.uint64))
80
+ if block_offset == 0:
81
+ # Update output_splits
82
+ tl.store(output_splits_ptr + remote_rank, num_rows)
83
+
84
+ input_ptr = (
85
+ tl.load(input_ptrs.to(tl.pointer_type(tl.uint64)) + remote_rank).to(
86
+ tl.pointer_type(tl.bfloat16)
87
+ )
88
+ + input_row_offset * dim
89
+ )
90
+ output_ptr = output_ptr + output_row_offset * dim
91
+
92
+ outer_loop_step = BLOCK_SIZE * UNROLL_FACTOR
93
+ outer_loop_iters_per_rank = tl.cdiv(
94
+ tl.cdiv(num_rows * dim, outer_loop_step), BLOCKS_PER_REMOTE_RANK
95
+ )
96
+ numel_per_rank = outer_loop_step * outer_loop_iters_per_rank
97
+ offset = numel_per_rank * block_offset
98
+ end = tl.minimum(numel_per_rank * (block_offset + 1), num_rows * dim)
99
+
100
+ unroll_region_size = (end - offset) // outer_loop_step * outer_loop_step
101
+ for i in tl.range(offset, offset + unroll_region_size, outer_loop_step):
102
+ datas = []
103
+ for j in tl.range(
104
+ i,
105
+ i + outer_loop_step,
106
+ BLOCK_SIZE,
107
+ loop_unroll_factor=UNROLL_FACTOR,
108
+ ):
109
+ offsets = j + tl.arange(0, BLOCK_SIZE)
110
+ data = tl.load(input_ptr + offsets)
111
+ tl.store(output_ptr + offsets, data)
112
+
113
+ offset += unroll_region_size
114
+ while offset < end:
115
+ offsets = offset + tl.arange(0, BLOCK_SIZE)
116
+ mask = offsets < num_rows * dim
117
+ data = tl.load(input_ptr + offsets, mask=mask)
118
+ tl.store(output_ptr + offsets, data, mask=mask)
119
+ offset += BLOCK_SIZE
120
+
121
+ sync_threads()
122
+ blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
123
+ return
124
+
125
+
126
+ def _on_device_all_to_all_v(
127
+ output: torch.Tensor,
128
+ output_splits: torch.Tensor,
129
+ input: torch.Tensor,
130
+ input_splits: torch.Tensor,
131
+ group: dist.ProcessGroup = dist.group.WORLD,
132
+ BLOCKS_PER_REMOTE_RANK=8,
133
+ UNROLL_FACTOR: int = 8,
134
+ BLOCK_SIZE: int = 16384,
135
+ ):
136
+ assert output.dim() == 2, f"{output.shape}"
137
+ assert input.dim() == 2, f"{input.shape}"
138
+ assert output.shape[1] == input.shape[1]
139
+
140
+ dim = output.shape[1]
141
+ input_hdl = symm_mem.rendezvous(input, group=group)
142
+ input_splits_hdl = symm_mem.rendezvous(input_splits, group=group)
143
+
144
+ num_blocks = input_hdl.world_size * BLOCKS_PER_REMOTE_RANK
145
+ kernel = on_device_all_to_all_v_kernel[(num_blocks, 1, 1)](
146
+ output,
147
+ output_splits,
148
+ input_hdl.buffer_ptrs_dev,
149
+ input_splits_hdl.buffer_ptrs_dev,
150
+ input_hdl.signal_pad_ptrs_dev,
151
+ dim=dim,
152
+ rank=input_hdl.rank,
153
+ world_size=input_hdl.world_size,
154
+ BLOCKS_PER_REMOTE_RANK=BLOCKS_PER_REMOTE_RANK,
155
+ UNROLL_FACTOR=UNROLL_FACTOR,
156
+ BLOCK_SIZE=BLOCK_SIZE,
157
+ num_warps=16,
158
+ )
159
+ # log_triton_kernel(kernel)
160
+ return output
161
+
162
+
163
+ class OnDeviceAllToAllV(torch.autograd.Function):
164
+ # A symmetric memory holding the grad_output during backward
165
+ grad_output_buf = None
166
+ # A symmetric memory for exchanging split sizes during both forward and backward
167
+ splits_buf = None
168
+ # Maximum output length (need to be set before use of OnDeviceAllToAllV)
169
+ max_output_len = None
170
+
171
+ @staticmethod
172
+ def forward(
173
+ ctx,
174
+ input: torch.Tensor,
175
+ input_splits: torch.Tensor,
176
+ group: dist.ProcessGroup = dist.group.WORLD,
177
+ ):
178
+ """
179
+ Args:
180
+ input: input tensor with data for all ranks concatenated.
181
+ input_splits: input splits of shape (group.world_size,)
182
+ group: process group to scope the collective.
183
+ """
184
+ # Initialize input splits buffer (one time only)
185
+ if OnDeviceAllToAllV.splits_buf is None:
186
+ OnDeviceAllToAllV.splits_buf = symm_mem.empty(
187
+ *input_splits.shape,
188
+ dtype=input_splits.dtype,
189
+ device=input_splits.device,
190
+ )
191
+
192
+ if OnDeviceAllToAllV.max_output_len is None:
193
+ raise RuntimeError(
194
+ "Please set max output length via `OnDeviceAllToAllV.max_output_len = ...`"
195
+ )
196
+
197
+ # Allocate output buffer
198
+ output = input.new_empty(OnDeviceAllToAllV.max_output_len, *input.shape[1:])
199
+ # Allocate output splits tensor
200
+ output_splits = torch.empty_like(input_splits)
201
+ # Copy input splits to the buffer
202
+ OnDeviceAllToAllV.splits_buf.copy_(input_splits)
203
+
204
+ # Shuffle input to output
205
+ _on_device_all_to_all_v(
206
+ output, output_splits, input, OnDeviceAllToAllV.splits_buf, group=group
207
+ )
208
+
209
+ # Output splits in forward is the input splits in backward
210
+ ctx.save_for_backward(output_splits)
211
+ ctx.group = group
212
+ ctx.input_shape = input.shape
213
+ return output, output_splits
214
+
215
+ @staticmethod
216
+ def backward(ctx, grad_output, grad_splits):
217
+ """
218
+ Backward is implemented as a shuffle of the output's gradients to the input.
219
+ Args:
220
+ `grad_output`: output's gradients passed from the downstream.
221
+ `grad_splits`: unused.
222
+ """
223
+
224
+ # Initialize grad_output buffer (one time only)
225
+ if OnDeviceAllToAllV.grad_output_buf is None:
226
+ assert (
227
+ OnDeviceAllToAllV.max_output_len is not None
228
+ ), "`max_output_len` not set"
229
+ OnDeviceAllToAllV.grad_output_buf = symm_mem.empty(
230
+ OnDeviceAllToAllV.max_output_len,
231
+ *grad_output.shape[1:],
232
+ dtype=grad_output.dtype,
233
+ device=grad_output.device,
234
+ )
235
+
236
+ # TODO: is there a way to tell autograd to feed grad_output directly to
237
+ # our symm_mem buffer?
238
+ OnDeviceAllToAllV.grad_output_buf.narrow(0, 0, grad_output.shape[0]).copy_(
239
+ grad_output
240
+ )
241
+
242
+ # Size info
243
+ (grad_output_splits,) = ctx.saved_tensors
244
+ OnDeviceAllToAllV.splits_buf.copy_(grad_output_splits)
245
+ grad_input_splits = torch.empty_like(grad_output_splits) # unused
246
+ grad_input = grad_output.new_empty(*ctx.input_shape)
247
+
248
+ # Shuffle gradients back to the input
249
+ _on_device_all_to_all_v(
250
+ grad_input,
251
+ grad_input_splits,
252
+ OnDeviceAllToAllV.grad_output_buf,
253
+ OnDeviceAllToAllV.splits_buf,
254
+ group=ctx.group,
255
+ )
256
+ return grad_input, None, None
257
+
258
+
259
+ # Alias
260
+ on_device_all_to_all_v = OnDeviceAllToAllV.apply
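A hedged usage sketch (not part of the diff): the input must be allocated in symmetric memory, since it is rendezvous'd inside `_on_device_all_to_all_v`, and `OnDeviceAllToAllV.max_output_len` must be set before the first call. All sizes and names below are illustrative.

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

# Assumes the default process group is already initialized and the CUDA device is set.
world_size = dist.get_world_size()
dim, num_rows, max_len = 1024, 4096, 8192  # illustrative sizes

OnDeviceAllToAllV.max_output_len = max_len  # required before the first call

inp = symm_mem.empty(num_rows, dim, dtype=torch.bfloat16, device="cuda")
inp.normal_()
splits = torch.full((world_size,), num_rows // world_size, dtype=torch.int64, device="cuda")

out, out_splits = on_device_all_to_all_v(inp, splits)  # rows shuffled across ranks
```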
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import triton
8
+ import triton.language as tl
9
+
10
+
11
+ @triton.jit
12
+ def get_tid():
13
+ return tl.inline_asm_elementwise(
14
+ """
15
+ mov.u32 $0, %tid.x;
16
+ mov.u32 $1, %tid.y;
17
+ mov.u32 $2, %tid.z;
18
+ """,
19
+ "=r,=r,=r",
20
+ [],
21
+ dtype=(tl.uint32, tl.uint32, tl.uint32),
22
+ is_pure=True,
23
+ pack=1,
24
+ )
25
+
26
+
27
+ @triton.jit
28
+ def get_ntid():
29
+ return tl.inline_asm_elementwise(
30
+ """
31
+ mov.u32 $0, %ntid.x;
32
+ mov.u32 $1, %ntid.y;
33
+ mov.u32 $2, %ntid.z;
34
+ """,
35
+ "=r,=r,=r",
36
+ [],
37
+ dtype=(tl.uint32, tl.uint32, tl.uint32),
38
+ is_pure=True,
39
+ pack=1,
40
+ )
41
+
42
+
43
+ @triton.jit
44
+ def get_flat_tid():
45
+ tid_x, tid_y, tid_z = get_tid()
46
+ ntid_x, ntid_y, _ = get_ntid()
47
+ return tid_z * ntid_y * ntid_x + tid_y * ntid_x + tid_x
48
+
49
+
50
+ @triton.jit
51
+ def get_flat_bid():
52
+ return (
53
+ tl.program_id(2) * tl.num_programs(1) * tl.num_programs(0)
54
+ + tl.program_id(1) * tl.num_programs(0)
55
+ + tl.program_id(0)
56
+ )
57
+
58
+
59
+ @triton.jit
60
+ def sync_threads():
61
+ tl.inline_asm_elementwise(
62
+ "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
63
+ )
torchtitan/experiments/deepseek_v3/train.py ADDED
@@ -0,0 +1,142 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # torchrun --standalone --nproc-per-node 8 run.py
8
+ import torch
9
+ import torch.distributed as dist
10
+ from checkpoint import load_weights_from_hf
11
+ from model import DeepseekForCausalLM
12
+ from model_config import deepseek_config_registry
13
+
14
+ from torch.distributed.device_mesh import DeviceMesh
15
+ from torch.distributed.fsdp import fully_shard
16
+ from torch.distributed.pipelining import PipelineStage, Schedule1F1B
17
+
18
+
19
+ # Use DeepSeek-V2-Lite as a proxy
20
+ model_id = "deepseek-ai/DeepSeek-V2-Lite"
21
+
22
+
23
+ # Run full model
24
+ def run_full_model(
25
+ mesh: DeviceMesh,
26
+ ):
27
+ rank = dist.get_rank()
28
+ device_count = torch.cuda.device_count()
29
+ device = torch.device("cuda", rank % device_count)
30
+
31
+ pp_mesh = mesh["pp"]
32
+ ep_mesh = mesh["ep"]
33
+ pp_rank = pp_mesh.get_local_rank()
34
+ ep_rank = ep_mesh.get_local_rank()
35
+ pp_size = pp_mesh.size()
36
+ ep_size = ep_mesh.size()
37
+
38
+ # Get model configs
39
+ model_args = deepseek_config_registry[model_id]
40
+ # [Note]: I am making the model smaller for testing / avoiding OOM. If you
41
+ # have sufficient GPUs for model parallelism, you can remove this line.
42
+ model_args.num_hidden_layers = 16
43
+
44
+ # Apply model parallelism
45
+ model_args.ep_size = ep_size
46
+ model_args.num_stages = pp_size
47
+ model_args.stage_idx = pp_rank
48
+ print(model_args)
49
+
50
+ # Instantiate model
51
+ with device, mesh:
52
+ model = DeepseekForCausalLM(model_args)
53
+
54
+ # Load weights
55
+ load_weights_from_hf(model, model_id, device)
56
+ model.train()
57
+
58
+ # Apply data parallelism
59
+ fsdp_mesh = mesh["fsdp"]
60
+ hsdp_mesh = mesh["ep", "fsdp"]
61
+ # Using `reshard_after_forward=False` to implement Zero-2, i.e. sharding the
62
+ # optimizer (Zero-1) and gradients (Zero-2), but not the model weights.
63
+ # Reason: the MoE is "sparsely activated" compared to the dense model, thus
64
+ # it would be uneconomical to re-gather the weights.
65
+ for layer in model.model.layers.values():
66
+ # Apply FSDP to experts
67
+ if hasattr(layer.mlp, "experts"):
68
+ for expert in layer.mlp.experts.values():
69
+ fully_shard(expert, mesh=fsdp_mesh, reshard_after_forward=False)
70
+ # Apply HSDP to other parts such as attention, layernorm, because they
71
+ # are doing DDP on EP dimension
72
+ fully_shard(layer, mesh=hsdp_mesh, reshard_after_forward=False)
73
+
74
+ # Apply HSDP on root model (lm_head, embeddings, etc)
75
+ fully_shard(model, mesh=hsdp_mesh, reshard_after_forward=False)
76
+
77
+ # Synthetic setting
78
+ microbatches = pp_size * 2
79
+
80
+ # Use Symmetric Memory for MoE token shuffle.
81
+ # TODO: we are rewriting `moe_on_device` function. `setup_symm_mem` is
82
+ # currently supported for forward only. See `generate.py`.
83
+ # model.setup_symm_mem(torch.bfloat16, device)
84
+
85
+ # Example inputs
86
+ torch.manual_seed(ep_rank)
87
+ bs = 4
88
+ seqlen = 128
89
+ x = torch.randint(model_args.vocab_size, (microbatches * bs, seqlen), device=device)
90
+ label = torch.rand(microbatches * bs, seqlen, model_args.vocab_size, device=device)
91
+
92
+ # Create loss function
93
+ loss_fn = torch.nn.functional.cross_entropy
94
+
95
+ # Run forward and backward
96
+ steps = 2
97
+ for _ in range(steps):
98
+ if pp_size > 1:
99
+ # Create pipeline stage
100
+ stage = PipelineStage(
101
+ model,
102
+ pp_rank,
103
+ pp_size,
104
+ device,
105
+ group=pp_mesh.get_group(),
106
+ )
107
+
108
+ # Create pipeline schedule
109
+ losses = []
110
+ pp_schedule = Schedule1F1B(stage, microbatches, loss_fn=loss_fn)
111
+
112
+ if pp_rank == 0:
113
+ y = pp_schedule.step(x)
114
+ elif pp_rank == pp_size - 1:
115
+ y = pp_schedule.step(target=label, losses=losses)
116
+ loss = torch.mean(torch.stack(losses))
117
+ else:
118
+ pp_schedule.step()
119
+ else:
120
+ y = model(x)
121
+ loss = loss_fn(y, label)
122
+ loss.backward()
123
+
124
+ if pp_rank == pp_size - 1:
125
+ print(f"logits: {y.shape}")
126
+ print(f"{loss=}")
127
+
128
+ if pp_rank == 0:
129
+ param = model.get_parameter("model.layers.0.self_attn.q_proj.weight")
130
+ print(f"{torch.linalg.norm(param.grad)=}")
131
+
132
+ model.zero_grad()
133
+
134
+ print("Backward done")
135
+
136
+
137
+ if __name__ == "__main__":
138
+ mesh = dist.init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "ep", "fsdp"))
139
+
140
+ run_full_model(mesh)
141
+
142
+ dist.destroy_process_group()
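A hedged launch example (not part of the diff): the hard-coded `(2, 2, 2)` mesh over `("pp", "ep", "fsdp")` needs 8 ranks, matching the torchrun hint at the top of the file (which still refers to the script as `run.py`).

```bash
torchrun --standalone --nproc-per-node 8 torchtitan/experiments/deepseek_v3/train.py
```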
torchtitan/experiments/flux/README.md ADDED
@@ -0,0 +1,23 @@
1
+ # FLUX model in torchtitan
2
+
3
+ ## Overview
4
+
5
+ ## Usage
6
+ First, download the autoencoder model from HuggingFace with your own access token:
7
+ ```bash
8
+ python torchtitan/experiments/flux/scripts/download_autoencoder.py --repo_id black-forest-labs/FLUX.1-dev --ae_path ae.safetensors --hf_token <your_access_token>
9
+ ```
10
+ This step will download the autoencoder model from HuggingFace and save it to the `torchtitan/experiments/flux/assets/autoencoder/ae.safetensors` file.
11
+
12
+ Run the following command to train the model on a single GPU:
13
+ ```bash
14
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --nproc_per_node=1 torchtitan/experiments/flux/train.py --job.config_file torchtitan/experiments/flux/train_configs/debug_model.toml
15
+ ```
16
+
17
+ ## TODO
18
+ - [ ] Support for multiple GPUs is coming soon (FSDP, etc.)
19
+ - [ ] Implement CI test cases for the FLUX model and add more unit tests (e.g., for the preprocessor)
20
+ - [ ] More parallelism support (Tensor Parallelism, Context Parallelism, etc.)
21
+ - [ ] Support for distributed checkpointing and loading
22
+ - [ ] Implement init_weights() function to initialize the model weights
23
+ - [ ] Implement the num_flops_per_token calculation in get_nparams_and_flops() function
torchtitan/experiments/flux/dataset/flux_dataset.py ADDED
@@ -0,0 +1,267 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import random
9
+ from dataclasses import dataclass
10
+ from typing import Any, Callable, Optional
11
+
12
+ import numpy as np
13
+
14
+ import torch
15
+
16
+ from datasets import Dataset, load_dataset
17
+ from datasets.distributed import split_dataset_by_node
18
+ from PIL import Image
19
+
20
+ from torch.distributed.checkpoint.stateful import Stateful
21
+
22
+ from torch.utils.data import IterableDataset
23
+ from torchtitan.components.dataloader import ParallelAwareDataloader
24
+
25
+ from torchtitan.config_manager import JobConfig
26
+ from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
27
+ from torchtitan.tools.logging import logger
28
+
29
+
30
+ def _process_cc12m_image(
31
+ img: Image.Image,
32
+ output_size: int = 256,
33
+ ) -> Optional[torch.Tensor]:
34
+ """Process CC12M image to the desired size."""
35
+
36
+ width, height = img.size
37
+ # Skip low resolution images
38
+ if width < output_size or height < output_size:
39
+ return None
40
+
41
+ if width >= height:
42
+ # resize height to be equal to output_size, then crop
43
+ new_width, new_height = math.ceil(output_size / height * width), output_size
44
+ img = img.resize((new_width, new_height))
45
+ left = random.randint(0, new_width - output_size)
46
+ resized_img = img.crop((left, 0, left + output_size, output_size))
47
+ else:
48
+ # resize width to be equal to output_size, then crop
49
+ new_width, new_height = (
50
+ output_size,
51
+ math.ceil(output_size / width * height),
52
+ )
53
+ img = img.resize((new_width, new_height))
54
+ lower = random.randint(0, new_width - output_size)
55
+ resized_img = img.crop((0, lower, output_size, lower + output_size))
56
+
57
+ assert resized_img.size[0] == resized_img.size[1] == output_size
58
+
59
+ # Skip grayscale images
60
+ if resized_img.mode == "L":
61
+ return None
62
+
63
+ np_img = np.array(resized_img).transpose((2, 0, 1))
64
+ tensor_img = torch.tensor(np_img).float() / 255.0
65
+
66
+ # NOTE: The following commented code is an alternative way
67
+ # img_transform = transforms.Compose(
68
+ # [
69
+ # transforms.Resize(max(output_size, output_size)),
70
+ # transforms.CenterCrop((output_size, output_size)),
71
+ # transforms.ToTensor(),
72
+ # ]
73
+ # )
74
+ # tensor_img = img_transform(img)
75
+
76
+ return tensor_img
77
+
78
+
79
+ def _flux_data_processor(
80
+ sample: dict[str, Any],
81
+ t5_tokenizer: FluxTokenizer,
82
+ clip_tokenizer: FluxTokenizer,
83
+ output_size: int = 256,
84
+ ) -> dict[str, Any]:
85
+ """
86
+ Preprocess CC12M dataset sample image and text for Flux model.
87
+
88
+ Args:
89
+ sample: A sample from dataset
90
+ t5_tokenizer: T5 tokenizer used to encode the caption
91
+ clip_tokenizer: CLIP tokenizer used to encode the caption
92
+ output_size: The output image size
93
+
94
+ """
95
+ img = _process_cc12m_image(sample["jpg"], output_size=output_size)
96
+ t5_tokens = t5_tokenizer.encode(sample["txt"])
97
+ clip_tokens = clip_tokenizer.encode(sample["txt"])
98
+
99
+ return {
100
+ "image": img,
101
+ "clip_tokens": clip_tokens, # type: List[int]
102
+ "t5_tokens": t5_tokens, # type: List[int]
103
+ }
104
+
105
+
106
+ @dataclass
107
+ class TextToImageDatasetConfig:
108
+ path: str
109
+ loader: Callable
110
+ data_processor: Callable
111
+
112
+
113
+ DATASETS = {
114
+ "cc12m": TextToImageDatasetConfig(
115
+ path="pixparse/cc12m-wds",
116
+ loader=lambda path: load_dataset(path, split="train", streaming=True),
117
+ data_processor=_flux_data_processor,
118
+ ),
119
+ }
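+ # A hypothetical example of registering another dataset (not an actual entry):
+ # DATASETS["my_dataset"] = TextToImageDatasetConfig(
+ #     path="org/my-dataset",
+ #     loader=lambda path: load_dataset(path, split="train", streaming=True),
+ #     data_processor=_flux_data_processor,
+ # )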
120
+
121
+
122
+ def _validate_dataset(
123
+ dataset_name: str, dataset_path: Optional[str] = None
124
+ ) -> tuple[str, Callable, Callable]:
125
+ """Validate dataset name and path."""
126
+ if dataset_name not in DATASETS:
127
+ raise ValueError(
128
+ f"Dataset {dataset_name} is not supported. "
129
+ f"Supported datasets are: {list(DATASETS.keys())}"
130
+ )
131
+
132
+ config = DATASETS[dataset_name]
133
+ path = dataset_path or config.path
134
+ logger.info(f"Preparing {dataset_name} dataset from {path}")
135
+ return path, config.loader, config.data_processor
136
+
137
+
138
+ class FluxDataset(IterableDataset, Stateful):
139
+ """Dataset for FLUX text-to-image model.
140
+
141
+ Args:
142
+ dataset_name (str): Name of the dataset.
143
+ dataset_path (str): Path to the dataset.
144
+ model_transform (Transform): Callable that applies model-specific preprocessing to the sample.
145
+ dp_rank (int): Data parallel rank.
146
+ dp_world_size (int): Data parallel world size.
147
+ infinite (bool): Whether to loop over the dataset infinitely.
148
+ """
149
+
150
+ def __init__(
151
+ self,
152
+ dataset_name: str,
153
+ dataset_path: Optional[str],
154
+ t5_tokenizer: FluxTokenizer,
155
+ clip_tokenizer: FluxTokenizer,
156
+ job_config: Optional[JobConfig] = None,
157
+ dp_rank: int = 0,
158
+ dp_world_size: int = 1,
159
+ infinite: bool = False,
160
+ ) -> None:
161
+
162
+ # Force lowercase for consistent comparison
163
+ dataset_name = dataset_name.lower()
164
+
165
+ path, dataset_loader, data_processor = _validate_dataset(
166
+ dataset_name, dataset_path
167
+ )
168
+ ds = dataset_loader(path)
169
+
170
+ self.dataset_name = dataset_name
171
+ self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
172
+
173
+ self._t5_tokenizer = t5_tokenizer
174
+ self._clip_tokenizer = clip_tokenizer
175
+ self._data_processor = data_processor
176
+ self.job_config = job_config
177
+
178
+ self.infinite = infinite
179
+
180
+ # Variables for checkpointing
181
+ self._sample_idx = 0
182
+ self._all_samples: list[dict[str, Any]] = []
183
+
184
+ def _get_data_iter(self):
185
+ if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
186
+ return iter([])
187
+
188
+ it = iter(self._data)
189
+ for _ in range(self._sample_idx):
190
+ next(it)
191
+ return it
192
+
193
+ def __iter__(self):
194
+ while True:
195
+ for sample in self._get_data_iter():
196
+ # Use the dataset-specific preprocessor
197
+ sample_dict = self._data_processor(
198
+ sample, self._t5_tokenizer, self._clip_tokenizer, output_size=256
199
+ )
200
+
201
+ # skip low quality image or image with color channel = 1
202
+ if sample_dict["image"] is None:
203
+ logger.warning(
204
+ f"Low quality image {sample['__key__']} is skipped in Flux Dataloader"
205
+ )
206
+ continue
207
+
208
+ self._all_samples.append(sample_dict)
209
+ self._sample_idx += 1
210
+
211
+ labels = sample_dict.pop("image")
212
+ yield sample_dict, labels
213
+
214
+ if not self.infinite:
215
+ logger.warning(f"Dataset {self.dataset_name} has run out of data")
216
+ break
217
+ else:
218
+ # Reset offset for the next iteration
219
+ self._sample_idx = 0
220
+ logger.warning(f"Dataset {self.dataset_name} is being re-looped")
221
+
222
+ def load_state_dict(self, state_dict):
223
+ self._sample_idx = state_dict["sample_idx"]
224
+ self._all_samples = state_dict["all_samples"]
225
+
226
+ def state_dict(self):
227
+ return {
228
+ "all_samples": self._all_samples,
229
+ "sample_idx": self._sample_idx,
230
+ }
231
+
232
+
233
+ def build_flux_dataloader(
234
+ dp_world_size: int,
235
+ dp_rank: int,
236
+ job_config: JobConfig,
237
+ # This parameter is not used, keep it for compatibility
238
+ tokenizer: FluxTokenizer | None,
239
+ infinite: bool = True,
240
+ ) -> ParallelAwareDataloader:
241
+ """Build a data loader for HuggingFace datasets."""
242
+ dataset_name = job_config.training.dataset
243
+ dataset_path = job_config.training.dataset_path
244
+ batch_size = job_config.training.batch_size
245
+
246
+ t5_encoder_name = job_config.encoder.t5_encoder
247
+ clip_encoder_name = job_config.encoder.clip_encoder
248
+ max_t5_encoding_len = job_config.encoder.max_t5_encoding_len
249
+
250
+ ds = FluxDataset(
251
+ dataset_name=dataset_name,
252
+ dataset_path=dataset_path,
253
+ t5_tokenizer=FluxTokenizer(t5_encoder_name, max_length=max_t5_encoding_len),
254
+ clip_tokenizer=FluxTokenizer(
255
+ clip_encoder_name, max_length=77
256
+ ), # fix max_length for CLIP
257
+ dp_rank=dp_rank,
258
+ dp_world_size=dp_world_size,
259
+ infinite=infinite,
260
+ )
261
+
262
+ return ParallelAwareDataloader(
263
+ dataset=ds,
264
+ dp_rank=dp_rank,
265
+ dp_world_size=dp_world_size,
266
+ batch_size=batch_size,
267
+ )
torchtitan/experiments/flux/model/autoencoder.py ADDED
@@ -0,0 +1,388 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ from dataclasses import dataclass
9
+
10
+ import torch
11
+ from einops import rearrange
12
+ from safetensors.torch import load_file as load_sft
13
+ from torch import nn, Tensor
14
+
15
+
16
+ @dataclass
17
+ class AutoEncoderParams:
18
+ resolution: int = 256
19
+ in_channels: int = 3
20
+ ch: int = 128
21
+ out_ch: int = 3
22
+ ch_mult: tuple[int] = (1, 2, 4, 4)
23
+ num_res_blocks: int = 2
24
+ z_channels: int = 16
25
+ scale_factor: float = 0.3611
26
+ shift_factor: float = 0.1159
27
+
28
+
29
+ def swish(x: Tensor) -> Tensor:
30
+ return x * torch.sigmoid(x)
31
+
32
+
33
+ class AttnBlock(nn.Module):
34
+ def __init__(self, in_channels: int):
35
+ super().__init__()
36
+ self.in_channels = in_channels
37
+
38
+ self.norm = nn.GroupNorm(
39
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
40
+ )
41
+
42
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
43
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
44
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
45
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
46
+
47
+ def attention(self, h_: Tensor) -> Tensor:
48
+ h_ = self.norm(h_)
49
+ q = self.q(h_)
50
+ k = self.k(h_)
51
+ v = self.v(h_)
52
+
53
+ b, c, h, w = q.shape
54
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
55
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
56
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
57
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
58
+
59
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
60
+
61
+ def forward(self, x: Tensor) -> Tensor:
62
+ return x + self.proj_out(self.attention(x))
63
+
64
+
65
+ class ResnetBlock(nn.Module):
66
+ def __init__(self, in_channels: int, out_channels: int):
67
+ super().__init__()
68
+ self.in_channels = in_channels
69
+ out_channels = in_channels if out_channels is None else out_channels
70
+ self.out_channels = out_channels
71
+
72
+ self.norm1 = nn.GroupNorm(
73
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
74
+ )
75
+ self.conv1 = nn.Conv2d(
76
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
77
+ )
78
+ self.norm2 = nn.GroupNorm(
79
+ num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
80
+ )
81
+ self.conv2 = nn.Conv2d(
82
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
83
+ )
84
+ if self.in_channels != self.out_channels:
85
+ self.nin_shortcut = nn.Conv2d(
86
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
87
+ )
88
+
89
+ def forward(self, x):
90
+ h = x
91
+ h = self.norm1(h)
92
+ h = swish(h)
93
+ h = self.conv1(h)
94
+
95
+ h = self.norm2(h)
96
+ h = swish(h)
97
+ h = self.conv2(h)
98
+
99
+ if self.in_channels != self.out_channels:
100
+ x = self.nin_shortcut(x)
101
+
102
+ return x + h
103
+
104
+
105
+ class Downsample(nn.Module):
106
+ def __init__(self, in_channels: int):
107
+ super().__init__()
108
+ # no asymmetric padding in torch conv, must do it ourselves
109
+ self.conv = nn.Conv2d(
110
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
111
+ )
112
+
113
+ def forward(self, x: Tensor):
114
+ pad = (0, 1, 0, 1)
115
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
116
+ x = self.conv(x)
117
+ return x
118
+
119
+
120
+ class Upsample(nn.Module):
121
+ def __init__(self, in_channels: int):
122
+ super().__init__()
123
+ self.conv = nn.Conv2d(
124
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
125
+ )
126
+
127
+ def forward(self, x: Tensor):
128
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
129
+ x = self.conv(x)
130
+ return x
131
+
132
+
133
+ class Encoder(nn.Module):
134
+ def __init__(
135
+ self,
136
+ resolution: int,
137
+ in_channels: int,
138
+ ch: int,
139
+ ch_mult: list[int],
140
+ num_res_blocks: int,
141
+ z_channels: int,
142
+ ):
143
+ super().__init__()
144
+ self.ch = ch
145
+ self.num_resolutions = len(ch_mult)
146
+ self.num_res_blocks = num_res_blocks
147
+ self.resolution = resolution
148
+ self.in_channels = in_channels
149
+ # downsampling
150
+ self.conv_in = nn.Conv2d(
151
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
152
+ )
153
+
154
+ curr_res = resolution
155
+ in_ch_mult = (1,) + tuple(ch_mult)
156
+ self.in_ch_mult = in_ch_mult
157
+ self.down = nn.ModuleList()
158
+ block_in = self.ch
159
+ for i_level in range(self.num_resolutions):
160
+ block = nn.ModuleList()
161
+ attn = nn.ModuleList()
162
+ block_in = ch * in_ch_mult[i_level]
163
+ block_out = ch * ch_mult[i_level]
164
+ for _ in range(self.num_res_blocks):
165
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
166
+ block_in = block_out
167
+ down = nn.Module()
168
+ down.block = block
169
+ down.attn = attn
170
+ if i_level != self.num_resolutions - 1:
171
+ down.downsample = Downsample(block_in)
172
+ curr_res = curr_res // 2
173
+ self.down.append(down)
174
+
175
+ # middle
176
+ self.mid = nn.Module()
177
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
178
+ self.mid.attn_1 = AttnBlock(block_in)
179
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
180
+
181
+ # end
182
+ self.norm_out = nn.GroupNorm(
183
+ num_groups=32, num_channels=block_in, eps=1e-6, affine=True
184
+ )
185
+ self.conv_out = nn.Conv2d(
186
+ block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
187
+ )
188
+
189
+ def forward(self, x: Tensor) -> Tensor:
190
+ # downsampling
191
+ hs = [self.conv_in(x)]
192
+ for i_level in range(self.num_resolutions):
193
+ for i_block in range(self.num_res_blocks):
194
+ h = self.down[i_level].block[i_block](hs[-1])
195
+ if len(self.down[i_level].attn) > 0:
196
+ h = self.down[i_level].attn[i_block](h)
197
+ hs.append(h)
198
+ if i_level != self.num_resolutions - 1:
199
+ hs.append(self.down[i_level].downsample(hs[-1]))
200
+
201
+ # middle
202
+ h = hs[-1]
203
+ h = self.mid.block_1(h)
204
+ h = self.mid.attn_1(h)
205
+ h = self.mid.block_2(h)
206
+ # end
207
+ h = self.norm_out(h)
208
+ h = swish(h)
209
+ h = self.conv_out(h)
210
+ return h
211
+
212
+
213
+ class Decoder(nn.Module):
214
+ def __init__(
215
+ self,
216
+ ch: int,
217
+ out_ch: int,
218
+ ch_mult: list[int],
219
+ num_res_blocks: int,
220
+ in_channels: int,
221
+ resolution: int,
222
+ z_channels: int,
223
+ ):
224
+ super().__init__()
225
+ self.ch = ch
226
+ self.num_resolutions = len(ch_mult)
227
+ self.num_res_blocks = num_res_blocks
228
+ self.resolution = resolution
229
+ self.in_channels = in_channels
230
+ self.ffactor = 2 ** (self.num_resolutions - 1)
231
+
232
+ # compute in_ch_mult, block_in and curr_res at lowest res
233
+ block_in = ch * ch_mult[self.num_resolutions - 1]
234
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
235
+ self.z_shape = (1, z_channels, curr_res, curr_res)
236
+
237
+ # z to block_in
238
+ self.conv_in = nn.Conv2d(
239
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
240
+ )
241
+
242
+ # middle
243
+ self.mid = nn.Module()
244
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
245
+ self.mid.attn_1 = AttnBlock(block_in)
246
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
247
+
248
+ # upsampling
249
+ self.up = nn.ModuleList()
250
+ for i_level in reversed(range(self.num_resolutions)):
251
+ block = nn.ModuleList()
252
+ attn = nn.ModuleList()
253
+ block_out = ch * ch_mult[i_level]
254
+ for _ in range(self.num_res_blocks + 1):
255
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
256
+ block_in = block_out
257
+ up = nn.Module()
258
+ up.block = block
259
+ up.attn = attn
260
+ if i_level != 0:
261
+ up.upsample = Upsample(block_in)
262
+ curr_res = curr_res * 2
263
+ self.up.insert(0, up) # prepend to get consistent order
264
+
265
+ # end
266
+ self.norm_out = nn.GroupNorm(
267
+ num_groups=32, num_channels=block_in, eps=1e-6, affine=True
268
+ )
269
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
270
+
271
+ def forward(self, z: Tensor) -> Tensor:
272
+ # get dtype for proper tracing
273
+ upscale_dtype = next(self.up.parameters()).dtype
274
+
275
+ # z to block_in
276
+ h = self.conv_in(z)
277
+
278
+ # middle
279
+ h = self.mid.block_1(h)
280
+ h = self.mid.attn_1(h)
281
+ h = self.mid.block_2(h)
282
+
283
+ # cast to proper dtype
284
+ h = h.to(upscale_dtype)
285
+ # upsampling
286
+ for i_level in reversed(range(self.num_resolutions)):
287
+ for i_block in range(self.num_res_blocks + 1):
288
+ h = self.up[i_level].block[i_block](h)
289
+ if len(self.up[i_level].attn) > 0:
290
+ h = self.up[i_level].attn[i_block](h)
291
+ if i_level != 0:
292
+ h = self.up[i_level].upsample(h)
293
+
294
+ # end
295
+ h = self.norm_out(h)
296
+ h = swish(h)
297
+ h = self.conv_out(h)
298
+ return h
299
+
300
+
301
+ class DiagonalGaussian(nn.Module):
302
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
303
+ super().__init__()
304
+ self.sample = sample
305
+ self.chunk_dim = chunk_dim
306
+
307
+ def forward(self, z: Tensor) -> Tensor:
308
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
309
+ if self.sample:
310
+ std = torch.exp(0.5 * logvar)
311
+ return mean + std * torch.randn_like(mean)
312
+ else:
313
+ return mean
314
+
315
+
316
+ class AutoEncoder(nn.Module):
317
+ def __init__(self, params: AutoEncoderParams):
318
+ super().__init__()
319
+ self.params = params
320
+ self.encoder = Encoder(
321
+ resolution=params.resolution,
322
+ in_channels=params.in_channels,
323
+ ch=params.ch,
324
+ ch_mult=params.ch_mult,
325
+ num_res_blocks=params.num_res_blocks,
326
+ z_channels=params.z_channels,
327
+ )
328
+ self.decoder = Decoder(
329
+ resolution=params.resolution,
330
+ in_channels=params.in_channels,
331
+ ch=params.ch,
332
+ out_ch=params.out_ch,
333
+ ch_mult=params.ch_mult,
334
+ num_res_blocks=params.num_res_blocks,
335
+ z_channels=params.z_channels,
336
+ )
337
+ self.reg = DiagonalGaussian()
338
+
339
+ self.scale_factor = params.scale_factor
340
+ self.shift_factor = params.shift_factor
341
+
342
+ def encode(self, x: Tensor) -> Tensor:
343
+ z = self.reg(self.encoder(x))
344
+ z = self.scale_factor * (z - self.shift_factor)
345
+ return z
346
+
347
+ def decode(self, z: Tensor) -> Tensor:
348
+ z = z / self.scale_factor + self.shift_factor
349
+ return self.decoder(z)
350
+
351
+ def forward(self, x: Tensor) -> Tensor:
352
+ return self.decode(self.encode(x))
353
+
354
+
355
+ def load_ae(
356
+ ckpt_path: str,
357
+ autoencoder_params: AutoEncoderParams,
358
+ device: str | torch.device = "cuda",
359
+ dtype=torch.bfloat16,
360
+ ) -> AutoEncoder:
361
+ """
362
+ Load the autoencoder from the given checkpoint path.
363
+ Args:
364
+ ckpt_path (str): Path to the autoencoder checkpoint (a safetensors file).
+ autoencoder_params (AutoEncoderParams): Architecture parameters for the autoencoder.
365
+ device (str or torch.device): The device to load the autoencoder to.
366
+ Returns:
367
+ AutoEncoder: The loaded autoencoder.
368
+ """
369
+ # Loading the autoencoder
370
+ print("Init AE")
371
+ with torch.device(device):
372
+ ae = AutoEncoder(autoencoder_params)
373
+
374
+ if not os.path.exists(ckpt_path):
375
+ raise ValueError(
376
+ f"Autoencoder path {ckpt_path} does not exist. Please download it first."
377
+ )
378
+
379
+ if ckpt_path is not None:
380
+ sd = load_sft(ckpt_path, device=str(device))
381
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
382
+ if len(missing) > 0:
383
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
384
+ if len(unexpected) > 0:
385
+ print(
386
+ f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
387
+ )
388
+ return ae.to(dtype=dtype)
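A minimal usage sketch for the autoencoder API above, assuming the default AutoEncoderParams and that ae.safetensors has already been downloaded (the image batch shape here is illustrative):

    # Sketch only: round-trip a batch of images through the Flux autoencoder.
    import torch

    from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams, load_ae

    ae = load_ae(
        ckpt_path="assets/autoencoder/ae.safetensors",  # assumed download location
        autoencoder_params=AutoEncoderParams(),
        device="cuda",
        dtype=torch.bfloat16,
    )
    images = torch.randn(2, 3, 256, 256, device="cuda", dtype=torch.bfloat16)
    latents = ae.encode(images)   # DiagonalGaussian sample, then scale/shift
    recon = ae.decode(latents)    # inverse shift/scale, then Decoder
    print(latents.shape, recon.shape)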
torchtitan/experiments/flux/model/hf_embedder.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from torch import nn, Tensor
8
+ from transformers import CLIPTextModel, T5EncoderModel
9
+
10
+
11
+ class FluxEmbedder(nn.Module):
12
+ def __init__(self, version: str, **hf_kwargs):
13
+ super().__init__()
14
+ self.is_clip = version.startswith("openai")
15
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
16
+
17
+ if self.is_clip:
18
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
19
+ version, **hf_kwargs
20
+ )
21
+ else:
22
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
23
+ version, **hf_kwargs
24
+ )
25
+
26
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
27
+
28
+ def forward(self, batch_tokens: Tensor) -> Tensor:
29
+ """
30
+ batch_tokens: [bsz, embedding_length]
31
+
32
+ For the T5 encoder, embedding_length is 768
33
+ For CLIP, embedding_length is 256
34
+ """
35
+ outputs = self.hf_module(
36
+ input_ids=batch_tokens.to(self.hf_module.device),
37
+ attention_mask=None,
38
+ output_hidden_states=False,
39
+ )
40
+ return outputs[self.output_key]
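A short sketch of how FluxEmbedder is used for the two text encoders; the encoder names follow the debug config, and the token ids and lengths below are dummies:

    # Sketch only: pooled CLIP vector vs. per-token T5 hidden states.
    import torch

    from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder

    clip_encoder = FluxEmbedder(version="openai/clip-vit-large-patch14").to("cuda", dtype=torch.bfloat16)
    t5_encoder = FluxEmbedder(version="google/t5-v1_1-small").to("cuda", dtype=torch.bfloat16)

    clip_tokens = torch.randint(0, 1000, (2, 77), device="cuda")  # dummy token ids
    t5_tokens = torch.randint(0, 1000, (2, 512), device="cuda")

    clip_vec = clip_encoder(clip_tokens)  # "pooler_output": one vector per prompt
    t5_seq = t5_encoder(t5_tokens)        # "last_hidden_state": one vector per token
    print(clip_vec.shape, t5_seq.shape)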
torchtitan/experiments/flux/model/layers.py ADDED
@@ -0,0 +1,286 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # imported from black-forest-labs/FLUX
8
+ import math
9
+ from dataclasses import dataclass
10
+
11
+ import torch
12
+ from einops import rearrange
13
+ from torch import nn, Tensor
14
+
15
+ from torchtitan.experiments.flux.model.math import attention, rope
16
+
17
+
18
+ class EmbedND(nn.Module):
19
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
20
+ super().__init__()
21
+ self.dim = dim
22
+ self.theta = theta
23
+ self.axes_dim = axes_dim
24
+
25
+ def forward(self, ids: Tensor) -> Tensor:
26
+ n_axes = ids.shape[-1]
27
+ emb = torch.cat(
28
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
29
+ dim=-3,
30
+ )
31
+
32
+ return emb.unsqueeze(1)
33
+
34
+
35
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
36
+ """
37
+ Create sinusoidal timestep embeddings.
38
+ :param t: a 1-D Tensor of N indices, one per batch element.
39
+ These may be fractional.
40
+ :param dim: the dimension of the output.
41
+ :param max_period: controls the minimum frequency of the embeddings.
42
+ :return: an (N, D) Tensor of positional embeddings.
43
+ """
44
+ t = time_factor * t
45
+ half = dim // 2
46
+ freqs = torch.exp(
47
+ -math.log(max_period)
48
+ * torch.arange(start=0, end=half, dtype=torch.float32)
49
+ / half
50
+ ).to(t.device)
51
+
52
+ args = t[:, None].float() * freqs[None]
53
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
54
+ if dim % 2:
55
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
56
+ if torch.is_floating_point(t):
57
+ embedding = embedding.to(t)
58
+ return embedding
59
+
60
+
61
+ class MLPEmbedder(nn.Module):
62
+ def __init__(self, in_dim: int, hidden_dim: int):
63
+ super().__init__()
64
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
65
+ self.silu = nn.SiLU()
66
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
67
+
68
+ def forward(self, x: Tensor) -> Tensor:
69
+ return self.out_layer(self.silu(self.in_layer(x)))
70
+
71
+
72
+ class RMSNorm(torch.nn.Module):
73
+ def __init__(self, dim: int):
74
+ super().__init__()
75
+ self.scale = nn.Parameter(torch.ones(dim))
76
+
77
+ def forward(self, x: Tensor):
78
+ x_dtype = x.dtype
79
+ x = x.float()
80
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
81
+ return (x * rrms).to(dtype=x_dtype) * self.scale
82
+
83
+
84
+ class QKNorm(torch.nn.Module):
85
+ def __init__(self, dim: int):
86
+ super().__init__()
87
+ self.query_norm = RMSNorm(dim) # TODO(jianiw): switch to pytorch nn.RMSNorm
88
+ self.key_norm = RMSNorm(dim)
89
+
90
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
91
+ q = self.query_norm(q)
92
+ k = self.key_norm(k)
93
+ return q.to(v), k.to(v)
94
+
95
+
96
+ class SelfAttention(nn.Module):
97
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
98
+ super().__init__()
99
+ self.num_heads = num_heads
100
+ head_dim = dim // num_heads
101
+
102
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
103
+ self.norm = QKNorm(head_dim)
104
+ self.proj = nn.Linear(dim, dim)
105
+
106
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
107
+ qkv = self.qkv(x)
108
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
109
+ q, k = self.norm(q, k, v)
110
+ x = attention(q, k, v, pe=pe)
111
+ x = self.proj(x)
112
+ return x
113
+
114
+
115
+ @dataclass
116
+ class ModulationOut:
117
+ shift: Tensor
118
+ scale: Tensor
119
+ gate: Tensor
120
+
121
+
122
+ class Modulation(nn.Module):
123
+ def __init__(self, dim: int, double: bool):
124
+ super().__init__()
125
+ self.is_double = double
126
+ self.multiplier = 6 if double else 3
127
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
128
+
129
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
130
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
131
+ self.multiplier, dim=-1
132
+ )
133
+
134
+ return (
135
+ ModulationOut(*out[:3]),
136
+ ModulationOut(*out[3:]) if self.is_double else None,
137
+ )
138
+
139
+
140
+ class DoubleStreamBlock(nn.Module):
141
+ def __init__(
142
+ self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
143
+ ):
144
+ super().__init__()
145
+
146
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
147
+ self.num_heads = num_heads
148
+ self.hidden_size = hidden_size
149
+ self.img_mod = Modulation(hidden_size, double=True)
150
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
151
+ self.img_attn = SelfAttention(
152
+ dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
153
+ )
154
+
155
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
156
+ self.img_mlp = nn.Sequential(
157
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
158
+ nn.GELU(approximate="tanh"),
159
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
160
+ )
161
+
162
+ self.txt_mod = Modulation(hidden_size, double=True)
163
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
164
+ self.txt_attn = SelfAttention(
165
+ dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
166
+ )
167
+
168
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
169
+ self.txt_mlp = nn.Sequential(
170
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
171
+ nn.GELU(approximate="tanh"),
172
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
173
+ )
174
+
175
+ def forward(
176
+ self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
177
+ ) -> tuple[Tensor, Tensor]:
178
+ img_mod1, img_mod2 = self.img_mod(vec)
179
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
180
+
181
+ # prepare image for attention
182
+ img_modulated = self.img_norm1(img)
183
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
184
+ img_qkv = self.img_attn.qkv(img_modulated)
185
+ img_q, img_k, img_v = rearrange(
186
+ img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
187
+ )
188
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
189
+
190
+ # prepare txt for attention
191
+ txt_modulated = self.txt_norm1(txt)
192
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
193
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
194
+ txt_q, txt_k, txt_v = rearrange(
195
+ txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
196
+ )
197
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
198
+
199
+ # run actual attention
200
+ q = torch.cat((txt_q, img_q), dim=2)
201
+ k = torch.cat((txt_k, img_k), dim=2)
202
+ v = torch.cat((txt_v, img_v), dim=2)
203
+
204
+ attn = attention(q, k, v, pe=pe)
205
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
206
+
207
+ # calculate the img blocks
208
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
209
+ img = img + img_mod2.gate * self.img_mlp(
210
+ (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
211
+ )
212
+
213
+ # calculate the txt blocks
214
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
215
+ txt = txt + txt_mod2.gate * self.txt_mlp(
216
+ (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
217
+ )
218
+ return img, txt
219
+
220
+
221
+ class SingleStreamBlock(nn.Module):
222
+ """
223
+ A DiT block with parallel linear layers as described in
224
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
225
+ """
226
+
227
+ def __init__(
228
+ self,
229
+ hidden_size: int,
230
+ num_heads: int,
231
+ mlp_ratio: float = 4.0,
232
+ qk_scale: float | None = None,
233
+ ):
234
+ super().__init__()
235
+ self.hidden_dim = hidden_size
236
+ self.num_heads = num_heads
237
+ head_dim = hidden_size // num_heads
238
+ self.scale = qk_scale or head_dim**-0.5
239
+
240
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
241
+ # qkv and mlp_in
242
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
243
+ # proj and mlp_out
244
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
245
+
246
+ self.norm = QKNorm(head_dim)
247
+
248
+ self.hidden_size = hidden_size
249
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
250
+
251
+ self.mlp_act = nn.GELU(approximate="tanh")
252
+ self.modulation = Modulation(hidden_size, double=False)
253
+
254
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
255
+ mod, _ = self.modulation(vec)
256
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
257
+ qkv, mlp = torch.split(
258
+ self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
259
+ )
260
+
261
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
262
+ q, k = self.norm(q, k, v)
263
+
264
+ # compute attention
265
+ attn = attention(q, k, v, pe=pe)
266
+ # compute activation in mlp stream, cat again and run second linear layer
267
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
268
+ return x + mod.gate * output
269
+
270
+
271
+ class LastLayer(nn.Module):
272
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
273
+ super().__init__()
274
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
275
+ self.linear = nn.Linear(
276
+ hidden_size, patch_size * patch_size * out_channels, bias=True
277
+ )
278
+ self.adaLN_modulation = nn.Sequential(
279
+ nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
280
+ )
281
+
282
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
283
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
284
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
285
+ x = self.linear(x)
286
+ return x
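To make the modulation plumbing above concrete, a small sketch (sizes are illustrative) of how timestep_embedding, MLPEmbedder and Modulation compose into the shift/scale/gate applied inside the blocks:

    # Sketch only: build a conditioning vector and modulate a token sequence with it.
    import torch

    from torchtitan.experiments.flux.model.layers import (
        MLPEmbedder,
        Modulation,
        timestep_embedding,
    )

    hidden = 3072
    t = torch.rand(4)                                            # one timestep per sample
    vec = MLPEmbedder(256, hidden)(timestep_embedding(t, 256))   # (4, hidden)
    mod, _ = Modulation(hidden, double=False)(vec)               # shift/scale/gate, each (4, 1, hidden)
    x = torch.randn(4, 128, hidden)                              # (batch, tokens, hidden)
    x_mod = (1 + mod.scale) * x + mod.shift                      # same shaping used by the blocks
    print(vec.shape, x_mod.shape)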
torchtitan/experiments/flux/model/model.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass, field
8
+
9
+ import torch
10
+
11
+ from torch import nn, Tensor
12
+ from torchtitan.components.tokenizer import Tokenizer
13
+ from torchtitan.config_manager import JobConfig
14
+
15
+ from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
16
+ from torchtitan.experiments.flux.model.layers import (
17
+ DoubleStreamBlock,
18
+ EmbedND,
19
+ LastLayer,
20
+ MLPEmbedder,
21
+ SingleStreamBlock,
22
+ timestep_embedding,
23
+ )
24
+
25
+ from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol
26
+ from torchtitan.tools.logging import logger
27
+
28
+
29
+ @dataclass
30
+ class FluxModelArgs(BaseModelArgs):
31
+ in_channels: int = 64
32
+ out_channels: int = 64
33
+ vec_in_dim: int = 768
34
+ context_in_dim: int = 512
35
+ hidden_size: int = 3072
36
+ mlp_ratio: float = 4.0
37
+ num_heads: int = 24
38
+ depth: int = 19
39
+ depth_single_blocks: int = 38
40
+ axes_dim: tuple = (16, 56, 56)
41
+ theta: int = 10_000
42
+ qkv_bias: bool = True
43
+ guidance_embed: bool = True
44
+ autoencoder_params: AutoEncoderParams = field(default_factory=AutoEncoderParams)
45
+
46
+ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
47
+ # context_in_dim is the same as the T5 embedding dimension
48
+ self.context_in_dim = job_config.encoder.max_t5_encoding_len
49
+
50
+ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
51
+ # TODO(jianiw): Add the number of flops for the autoencoder
52
+ nparams = sum(p.numel() for p in model.parameters())
53
+ logger.warning("FLUX model has not implemented get_nparams_and_flops() yet")
54
+ return nparams, 1
55
+
56
+
57
+ class FluxModel(nn.Module, ModelProtocol):
58
+ """
59
+ Transformer model for flow matching on sequences.
60
+
61
+ Args:
62
+ model_args: FluxModelArgs.
63
+
64
+ Attributes:
65
+ model_args (FluxModelArgs): Model configuration arguments.
66
+ """
67
+
68
+ def __init__(self, model_args: FluxModelArgs):
69
+ super().__init__()
70
+
71
+ self.model_args = model_args
72
+ self.in_channels = model_args.in_channels
73
+ self.out_channels = model_args.out_channels
74
+ if model_args.hidden_size % model_args.num_heads != 0:
75
+ raise ValueError(
76
+ f"Hidden size {model_args.hidden_size} must be divisible by num_heads {model_args.num_heads}"
77
+ )
78
+ pe_dim = model_args.hidden_size // model_args.num_heads
79
+ if sum(model_args.axes_dim) != pe_dim:
80
+ raise ValueError(
81
+ f"Got {model_args.axes_dim} but expected positional dim {pe_dim}"
82
+ )
83
+ self.hidden_size = model_args.hidden_size
84
+ self.num_heads = model_args.num_heads
85
+ self.pe_embedder = EmbedND(
86
+ dim=pe_dim, theta=model_args.theta, axes_dim=model_args.axes_dim
87
+ )
88
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
89
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
90
+ self.vector_in = MLPEmbedder(model_args.vec_in_dim, self.hidden_size)
91
+ self.guidance_in = (
92
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
93
+ if model_args.guidance_embed
94
+ else nn.Identity()
95
+ )
96
+ self.txt_in = nn.Linear(model_args.context_in_dim, self.hidden_size)
97
+
98
+ self.double_blocks = nn.ModuleList(
99
+ [
100
+ DoubleStreamBlock(
101
+ self.hidden_size,
102
+ self.num_heads,
103
+ mlp_ratio=model_args.mlp_ratio,
104
+ qkv_bias=model_args.qkv_bias,
105
+ )
106
+ for _ in range(model_args.depth)
107
+ ]
108
+ )
109
+
110
+ self.single_blocks = nn.ModuleList(
111
+ [
112
+ SingleStreamBlock(
113
+ self.hidden_size, self.num_heads, mlp_ratio=model_args.mlp_ratio
114
+ )
115
+ for _ in range(model_args.depth_single_blocks)
116
+ ]
117
+ )
118
+
119
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
120
+
121
+ def init_weights(self, buffer_device=None):
122
+ # TODO(jianiw): replace placeholder with real weight init
123
+ for param in self.parameters():
124
+ param.data.uniform_(0, 0.1)
125
+
126
+ def forward(
127
+ self,
128
+ img: Tensor,
129
+ img_ids: Tensor,
130
+ txt: Tensor,
131
+ txt_ids: Tensor,
132
+ timesteps: Tensor,
133
+ y: Tensor,
134
+ guidance: Tensor | None = None,
135
+ ) -> Tensor:
136
+ if img.ndim != 3 or txt.ndim != 3:
137
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
138
+
139
+ # running on img sequences
140
+ img = self.img_in(img)
141
+ vec = self.time_in(timestep_embedding(timesteps, 256))
142
+ if self.model_args.guidance_embed:
143
+ if guidance is None:
144
+ raise ValueError(
145
+ "Didn't get guidance strength for guidance distilled model."
146
+ )
147
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
148
+ vec = vec + self.vector_in(y)
149
+ txt = self.txt_in(txt)
150
+
151
+ ids = torch.cat((txt_ids, img_ids), dim=1)
152
+ pe = self.pe_embedder(ids)
153
+
154
+ for block in self.double_blocks:
155
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
156
+
157
+ img = torch.cat((txt, img), 1)
158
+ for block in self.single_blocks:
159
+ img = block(img, vec=vec, pe=pe)
160
+ img = img[:, txt.shape[1] :, ...]
161
+
162
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
163
+ return img
164
+
165
+ @classmethod
166
+ def from_model_args(cls, model_args: FluxModelArgs) -> "FluxModel":
167
+ """
168
+ Initialize a Flux model from a FluxModelArgs object.
169
+
170
+ Args:
171
+ model_args (FluxModelArgs): Model configuration arguments.
172
+
173
+ Returns:
174
+ FluxModel: FluxModel model.
175
+
176
+ """
177
+ return cls(model_args)
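A toy forward pass showing the expected input shapes for FluxModel; the configuration below is deliberately tiny and is not the real flux-dev setup:

    # Sketch only: run FluxModel end to end on random inputs to check shapes.
    import torch

    from torchtitan.experiments.flux.model.model import FluxModel, FluxModelArgs

    args = FluxModelArgs(
        in_channels=16, out_channels=16, vec_in_dim=8, context_in_dim=32,
        hidden_size=256, num_heads=4, depth=1, depth_single_blocks=1,
        axes_dim=(16, 24, 24),  # must sum to hidden_size // num_heads = 64
    )
    model = FluxModel(args)

    bsz, img_len, txt_len = 2, 64, 8
    out = model(
        img=torch.randn(bsz, img_len, 16),      # packed latent patches
        img_ids=torch.zeros(bsz, img_len, 3),   # 3-d position ids
        txt=torch.randn(bsz, txt_len, 32),      # T5-style context
        txt_ids=torch.zeros(bsz, txt_len, 3),
        timesteps=torch.rand(bsz),
        y=torch.randn(bsz, 8),                  # CLIP-style pooled vector
        guidance=torch.full((bsz,), 4.0),       # required since guidance_embed=True
    )
    print(out.shape)  # (bsz, img_len, out_channels)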
torchtitan/experiments/flux/parallelize_flux.py ADDED
@@ -0,0 +1,26 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file applies the PT-D parallelisms (except pipeline parallelism) and various
8
+ # training techniques (e.g. activation checkpointing and compile) to the Flux model.
9
+
10
+
11
+ import torch.nn as nn
12
+
13
+ from torch.distributed.device_mesh import DeviceMesh
14
+
15
+ from torchtitan.config_manager import JobConfig
16
+ from torchtitan.distributed import ParallelDims
17
+
18
+
19
+ def parallelize_flux(
20
+ model: nn.Module,
21
+ world_mesh: DeviceMesh,
22
+ parallel_dims: ParallelDims,
23
+ job_config: JobConfig,
24
+ ):
25
+ # TODO: Add model parallel strategy here
26
+ return model
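Since parallelize_flux is still a stub, here is a purely hypothetical sketch of one way the TODO might be filled in, using the composable FSDP2 API from PyTorch; this is not the project's actual strategy, and the helper name and dp_mesh argument are made up:

    # Hypothetical sketch only: shard each transformer block, then the root module.
    from torch.distributed._composable.fsdp import fully_shard

    def _apply_fsdp_sketch(model: nn.Module, dp_mesh: DeviceMesh) -> nn.Module:
        # dp_mesh would be the data-parallel slice of world_mesh in a real setup
        for block in list(model.double_blocks) + list(model.single_blocks):
            fully_shard(block, mesh=dp_mesh)
        fully_shard(model, mesh=dp_mesh)
        return model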
torchtitan/experiments/flux/requirements.txt ADDED
@@ -0,0 +1,2 @@
1
+ transformers
2
+ einops
torchtitan/experiments/flux/tests/test_flux_dataloader.py ADDED
@@ -0,0 +1,103 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import sys
8
+
9
+ from torchtitan.config_manager import JobConfig
10
+ from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
11
+ from torchtitan.tools.profiling import (
12
+ maybe_enable_memory_snapshot,
13
+ maybe_enable_profiling,
14
+ )
15
+
16
+
17
+ class TestFluxDataLoader:
18
+ def test_flux_dataloader(self):
19
+ dataset_name = "cc12m"
20
+ batch_size = 32
21
+ world_size = 4
22
+ rank = 0
23
+
24
+ num_steps = 10
25
+
26
+ path = "torchtitan.experiments.flux.flux_argparser"
27
+ sys.argv.append(f"--experimental.custom_args_module={path}")
28
+ config = JobConfig()
29
+ config.maybe_add_custom_args()
30
+ config.parse_args(
31
+ [
32
+ # Profiling options
33
+ # "--profiling.enable_profiling",
34
+ # "--profiling.profile_freq",
35
+ # "5",
36
+ # "--profiling.enable_memory_snapshot",
37
+ # "--profiling.save_memory_snapshot_folder",
38
+ # "memory_snapshot_flux",
39
+ "--training.dataset",
40
+ dataset_name,
41
+ "--training.batch_size",
42
+ str(batch_size),
43
+ "--encoder.t5_encoder",
44
+ "google/t5-v1_1-small",
45
+ "--encoder.clip_encoder",
46
+ "openai/clip-vit-large-patch14",
47
+ "--encoder.max_t5_encoding_len",
48
+ "512",
49
+ ]
50
+ )
51
+
52
+ with maybe_enable_profiling(
53
+ config, global_step=0
54
+ ) as torch_profiler, maybe_enable_memory_snapshot(
55
+ config, global_step=0
56
+ ) as memory_profiler:
57
+ dl = self._build_dataloader(
58
+ config,
59
+ world_size,
60
+ rank,
61
+ )
62
+ dl = iter(dl)
63
+
64
+ for i in range(0, num_steps):
65
+ input_data, labels = next(dl)
66
+ print(f"Step {i} image size: {labels.shape}")
67
+ if torch_profiler:
68
+ torch_profiler.step()
69
+ if memory_profiler:
70
+ memory_profiler.step()
71
+
72
+ print(len(input_data["clip_tokens"]))
73
+ for k, v in input_data.items():
74
+ print(f"Step {i} {k} value: {type(v), v.shape}")
75
+
76
+ assert len(input_data) == 2 # (clip_tokens, t5_tokens)
77
+ assert labels.shape == (batch_size, 3, 256, 256)
78
+ # assert input_data["clip_tokens"].shape[0] == batch_size
79
+ # assert input_data["t5_tokens"].shape == (batch_size, 512, 512)
80
+
81
+ if torch_profiler:
82
+ torch_profiler.step()
83
+ if memory_profiler:
84
+ memory_profiler.step(exit_ctx=True)
85
+
86
+ def test_preprocess(self):
87
+ # TODO
88
+ pass
89
+
90
+ def _build_dataloader(
91
+ self,
92
+ job_config,
93
+ world_size,
94
+ rank,
95
+ ):
96
+
97
+ return build_flux_dataloader(
98
+ dp_world_size=world_size,
99
+ dp_rank=rank,
100
+ job_config=job_config,
101
+ tokenizer=None,
102
+ infinite=False,
103
+ )
torchtitan/experiments/flux/tests/test_generate_image.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import os
9
+ import time
10
+ from typing import Callable
11
+
12
+ import torch
13
+ from einops import rearrange
14
+
15
+ from PIL import ExifTags, Image
16
+
17
+ from torch import Tensor
18
+
19
+ from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
20
+
21
+ from torchtitan.experiments.flux.model.autoencoder import (
22
+ AutoEncoder,
23
+ AutoEncoderParams,
24
+ load_ae,
25
+ )
26
+ from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
27
+
28
+ from torchtitan.experiments.flux.model.model import FluxModel, FluxModelArgs
29
+ from torchtitan.experiments.flux.utils import (
30
+ create_position_encoding_for_latents,
31
+ generate_noise_latent,
32
+ pack_latents,
33
+ preprocess_flux_data,
34
+ unpack_latents,
35
+ )
36
+
37
+
38
+ def time_shift(mu: float, sigma: float, t: Tensor):
39
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
40
+
41
+
42
+ def get_lin_function(
43
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
44
+ ) -> Callable[[float], float]:
45
+ m = (y2 - y1) / (x2 - x1)
46
+ b = y1 - m * x1
47
+ return lambda x: m * x + b
48
+
49
+
50
+ def get_schedule(
51
+ num_steps: int,
52
+ image_seq_len: int,
53
+ base_shift: float = 0.5,
54
+ max_shift: float = 1.15,
55
+ shift: bool = True,
56
+ ) -> list[float]:
57
+ # extra step for zero
58
+ timesteps = torch.linspace(1, 0, num_steps + 1)
59
+
60
+ # shifting the schedule to favor high timesteps for higher signal images
61
+ if shift:
62
+ # estimate mu based on linear estimation between two points
63
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
64
+ timesteps = time_shift(mu, 1.0, timesteps)
65
+
66
+ return timesteps.tolist()
67
+
68
+
69
+ class TestGenerateImage:
70
+ def test_generate_image(self):
71
+ """
72
+ Run a forward pass of flux model to generate an image.
73
+ """
74
+ name = "flux-dev"
75
+ img_width = 512
76
+ img_height = 512
77
+ seed = None
78
+ prompt = (
79
+ "a photo of a forest with mist swirling around the tree trunks. The word "
80
+ '"FLUX" is painted over it in big, red brush strokes with visible texture'
81
+ )
82
+ device = "cuda"
83
+ num_steps = None
84
+ loop = False
85
+ guidance = 3.5
86
+ output_dir = "output"
87
+ add_sampling_metadata = True
88
+
89
+ prompt = prompt.split("|")
90
+ if len(prompt) == 1:
91
+ prompt = prompt[0]
92
+ additional_prompts = None
93
+ else:
94
+ additional_prompts = prompt[1:]
95
+ prompt = prompt[0]
96
+
97
+ assert not (
98
+ (additional_prompts is not None) and loop
99
+ ), "Do not provide additional prompts and set loop to True"
100
+
101
+ torch_device = torch.device(device)
102
+ if num_steps is None:
103
+ num_steps = 30
104
+
105
+ # allow for packing and conversion to latent space
106
+ img_height = 16 * (img_height // 16)
107
+ img_width = 16 * (img_width // 16)
108
+
109
+ # init all components
110
+ model = FluxModel(FluxModelArgs()).to(device=torch_device, dtype=torch.bfloat16)
111
+
112
+ ae = load_ae(
113
+ ckpt_path="assets/autoencoder/ae.safetensors",
114
+ autoencoder_params=AutoEncoderParams(),
115
+ device=torch_device,
116
+ dtype=torch.bfloat16,
117
+ )
118
+ clip_tokenizer = FluxTokenizer(
119
+ model_path="openai/clip-vit-large-patch14", max_length=77
120
+ )
121
+ t5_tokenizer = FluxTokenizer(model_path="google/t5-v1_1-small", max_length=512)
122
+ clip_encoder = FluxEmbedder(version="openai/clip-vit-large-patch14").to(
123
+ torch_device, dtype=torch.bfloat16
124
+ )
125
+ t5_encoder = FluxEmbedder(version="google/t5-v1_1-small").to(
126
+ torch_device, dtype=torch.bfloat16
127
+ )
128
+
129
+ rng = torch.Generator(device="cpu")
130
+
131
+ if seed is None:
132
+ seed = rng.seed()
133
+ print(f"Generating with seed {seed}:\n{prompt}")
134
+ t0 = time.perf_counter()
135
+ output_name = os.path.join(output_dir, f"img_{seed}.jpg")
136
+
137
+ # Tokenize the prompt, on CPU
138
+ clip_tokens = clip_tokenizer.encode(prompt)
139
+ t5_tokens = t5_tokenizer.encode(prompt)
140
+
141
+ batch = preprocess_flux_data(
142
+ device=torch_device,
143
+ dtype=torch.bfloat16,
144
+ autoencoder=None,
145
+ clip_encoder=clip_encoder,
146
+ t5_encoder=t5_encoder,
147
+ batch={
148
+ "clip_tokens": clip_tokens,
149
+ "t5_tokens": t5_tokens,
150
+ },
151
+ )
152
+
153
+ img = self._generate_images(
154
+ device=torch_device,
155
+ dtype=torch.bfloat16,
156
+ model=model,
157
+ decoder=ae,
158
+ img_width=img_width,
159
+ img_height=img_height,
160
+ denoising_steps=num_steps,
161
+ seed=seed,
162
+ clip_encodings=batch["clip_encodings"],
163
+ t5_encodings=batch["t5_encodings"],
164
+ guidance=guidance,
165
+ )
166
+
167
+ if torch.cuda.is_available():
168
+ torch.cuda.synchronize()
169
+ t1 = time.perf_counter()
170
+
171
+ print(f"Done in {t1 - t0:.1f}s.")
172
+
173
+ self._save_image(name, output_name, img, add_sampling_metadata, prompt)
174
+
175
+ def _generate_images(
176
+ self,
177
+ device: torch.device,
178
+ dtype: torch.dtype,
179
+ model: FluxModel,
180
+ decoder: AutoEncoder,
181
+ # image params:
182
+ img_width: int,
183
+ img_height: int,
184
+ # sampling params:
185
+ denoising_steps: int,
186
+ seed: int,
187
+ clip_encodings: torch.Tensor,
188
+ t5_encodings: torch.Tensor,
189
+ guidance: float = 4.0,
190
+ ):
191
+
192
+ bsz = clip_encodings.shape[0]
193
+ latents = generate_noise_latent(bsz, img_height, img_width, device, dtype, seed)
194
+ _, latent_channels, latent_height, latent_width = latents.shape
195
+
196
+ # create denoising schedule
197
+ timesteps = get_schedule(denoising_steps, latent_channels, shift=True)
198
+
199
+ # create positional encodings
200
+ POSITION_DIM = 3 # constant for Flux flow model
201
+ latent_pos_enc = create_position_encoding_for_latents(
202
+ bsz, latent_height, latent_width, POSITION_DIM
203
+ ).to(latents)
204
+ text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM).to(latents)
205
+
206
+ # convert img-like latents into sequences of patches
207
+ latents = pack_latents(latents)
208
+
209
+ # this is ignored for schnell
210
+ guidance_vec = torch.full((bsz,), guidance, device=device, dtype=dtype)
211
+ for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
212
+ t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device)
213
+ pred = model(
214
+ img=latents,
215
+ img_ids=latent_pos_enc,
216
+ txt=t5_encodings,
217
+ txt_ids=text_pos_enc,
218
+ y=clip_encodings,
219
+ timesteps=t_vec,
220
+ guidance=guidance_vec,
221
+ )
222
+
223
+ latents = latents + (t_prev - t_curr) * pred
224
+
225
+ # convert sequences of patches into img-like latents
226
+ latents = unpack_latents(latents, latent_height, latent_width)
227
+
228
+ img = decoder.decode(latents)
229
+ return img
230
+
231
+ def _save_image(
232
+ self,
233
+ name: str,
234
+ output_name: str,
235
+ x: torch.Tensor,
236
+ add_sampling_metadata: bool,
237
+ prompt: str,
238
+ ):
239
+ print(f"Saving {output_name}")
240
+ # bring into PIL format and save
241
+ x = x.clamp(-1, 1)
242
+ x = rearrange(x[0], "c h w -> h w c")
243
+
244
+ img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
245
+
246
+ exif_data = Image.Exif()
247
+ exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
248
+ exif_data[ExifTags.Base.Make] = "Black Forest Labs"
249
+ exif_data[ExifTags.Base.Model] = name
250
+ if add_sampling_metadata:
251
+ exif_data[ExifTags.Base.ImageDescription] = prompt
252
+ img.save(output_name, exif=exif_data, quality=95, subsampling=0)
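For orientation, a tiny sketch of the schedule helpers defined at the top of this test (assuming get_schedule is in scope):

    # Sketch only: inspect the shifted schedule that drives the Euler loop above.
    schedule = get_schedule(num_steps=4, image_seq_len=256, shift=True)
    print(schedule)  # num_steps + 1 values, monotonically decreasing from 1.0 to 0.0
    # _generate_images then steps with: latents = latents + (t_prev - t_curr) * pred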
torchtitan/experiments/flux/train_configs/debug_model.toml ADDED
@@ -0,0 +1,68 @@
1
+
2
+ [job]
3
+ dump_folder = "./outputs"
4
+ description = "Flux debug model"
5
+ print_args = false
6
+ use_for_integration_test = true
7
+
8
+ [profiling]
9
+ enable_profiling = false
10
+ save_traces_folder = "profile_trace"
11
+ profile_freq = 10
12
+ enable_memory_snapshot = false
13
+ save_memory_snapshot_folder = "memory_snapshot"
14
+
15
+ [metrics]
16
+ log_freq = 1
17
+ disable_color_printing = false
18
+ enable_tensorboard = false
19
+ save_tb_folder = "tb"
20
+ enable_wandb = false
21
+
22
+ [model]
23
+ name = "flux"
24
+ flavor = "flux-debug"
25
+ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
26
+ # test tokenizer.model, for debug purpose only
27
+ # tokenizer_path = "./tests/assets/test_tiktoken.model"
28
+ # converters = "float8"
29
+
30
+
31
+ [optimizer]
32
+ name = "AdamW"
33
+ lr = 8e-4
34
+ eps = 1e-8
35
+
36
+ [lr_scheduler]
37
+ warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
38
+ decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
39
+ decay_type = "linear"
40
+ lr_min = 0.0
41
+
42
+ [training]
43
+ batch_size = 32
44
+ seq_len = 512
45
+ max_norm = 1.0 # grad norm clipping
46
+ steps = 10
47
+ compile = false
48
+ dataset = "cc12m"
49
+ guidance = 3.5
50
+ seed = 0
51
+
52
+ [encoder]
53
+ t5_encoder="google/t5-v1_1-small"
54
+ clip_encoder="openai/clip-vit-large-patch14"
55
+ max_t5_encoding_len=512
56
+ auto_encoder_path="torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image
57
+
58
+ [parallelism]
59
+ data_parallel_replicate_degree = 1
60
+ data_parallel_shard_degree = 1
61
+ fsdp_reshard_after_forward = "default" # default / never / always
62
+ tensor_parallel_degree = 1
63
+ enable_async_tensor_parallel = false
64
+ pipeline_parallel_degree = 1
65
+ context_parallel_degree = 1
66
+
67
+ [experimental]
68
+ custom_args_module = "torchtitan.experiments.flux.flux_argparser"
torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py ADDED
@@ -0,0 +1,630 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # All rights reserved.
9
+ #
10
+ # Benchmark comparing reference PyTorch vs optimized M*G group GEMM implementation
11
+
12
+ import argparse
13
+ import logging
14
+ import time
15
+
16
+ # from typing import Dict, List, Optional, Tuple
17
+
18
+ import matplotlib.pyplot as plt
19
+ import numpy as np
20
+ import torch
21
+ import triton
22
+
23
+ # import triton.language as tl
24
+
25
+ # Configure logging
26
+ logging.basicConfig(
27
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
28
+ )
29
+
30
+ # Try to import the optimized implementations
31
+ try:
32
+ from torchao_pr.mg_grouped_gemm import grouped_gemm_forward
33
+
34
+ except ImportError:
35
+ logging.error(
36
+ "Error importing MG grouped GEMM modules. Make sure the implementation files are in the correct path."
37
+ )
38
+ raise
39
+
40
+
41
+ def compute_reference_forward(x, w, m_sizes):
42
+ """
43
+ Reference PyTorch implementation of M*G grouped GEMM forward pass.
44
+
45
+ Args:
46
+ x (torch.Tensor): Input tensor of shape (M, K)
47
+ w (torch.Tensor): Weight tensor of shape (N, K)
48
+ m_sizes (torch.Tensor): Group sizes tensor of shape (G)
49
+
50
+ Returns:
51
+ torch.Tensor: Output tensor of shape (M, N)
52
+ """
53
+ result = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device)
54
+
55
+ m_start = 0
56
+ for g in range(len(m_sizes)):
57
+ m_size = m_sizes[g].item()
58
+ if m_size > 0:
59
+ m_end = m_start + m_size
60
+
61
+ # Extract group input
62
+ x_g = x[m_start:m_end]
63
+
64
+ # Compute group output
65
+ y_g = torch.matmul(x_g, w.T)
66
+
67
+ # Store result
68
+ result[m_start:m_end] = y_g
69
+
70
+ # Update start index
71
+ m_start = m_end
72
+
73
+ return result
74
+
75
+
76
+ @triton.testing.perf_report(
77
+ triton.testing.Benchmark(
78
+ x_names=["N"], # We'll vary the output dimension
79
+ x_vals=[1024, 2048, 4096, 8192, 16384], # Different output dimensions to test
80
+ # x_vals=[8192, 16384],
81
+ line_arg="provider", # We'll compare different providers
82
+ line_vals=["pytorch_reference", "M*G grouped GEMM"],
83
+ line_names=["PyTorch Reference", "M*G grouped Kernel"],
84
+ styles=[("blue", "-"), ("red", "-")],
85
+ ylabel="TFLOPS", # We'll measure TFLOPS
86
+ plot_name="mg_grouped_gemm_comparison",
87
+ args={
88
+ "M": 8192, # Batch dimension, fixed for all tests
89
+ "K": 7168, # Hidden dimension, fixed for all tests
90
+ "G": 8, # Number of groups
91
+ "dtype": torch.float16,
92
+ "device": "cuda",
93
+ },
94
+ )
95
+ )
96
+ def benchmark_forward(M, K, N, G, provider, dtype=torch.float16, device="cuda"):
97
+ """
98
+ Benchmark the forward pass of the grouped GEMM implementation.
99
+
100
+ Args:
101
+ M (int): Total batch size dimension
102
+ K (int): Hidden dimension
103
+ N (int): Output dimension
104
+ G (int): Number of groups
105
+ provider (str): Provider to use ('pytorch_reference' or 'M*G grouped GEMM')
106
+ dtype (torch.dtype): Data type to use
107
+ device (str): Device to use
108
+
109
+ Returns:
110
+ float: Performance in TFLOPS
111
+ """
112
+ # Create group sizes for M dimension (balanced across groups)
113
+ base_size = M // G
114
+ remainder = M % G
115
+ M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
116
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
117
+
118
+ print(f"N: {N}, M: {M}, K: {K}, G: {G}, dtype: {dtype}, device: {device}")
119
+
120
+ # Create input and weight tensors
121
+ x = torch.randn(M, K, dtype=dtype, device=device)
122
+ w = torch.randn(N, K, dtype=dtype, device=device)
123
+
124
+ # Pre-compute for PyTorch reference to ensure fair comparison
125
+ if provider == "pytorch_reference":
126
+ # Warmup
127
+ torch.cuda.synchronize()
128
+ compute_reference_forward(x, w, m_sizes)
129
+ torch.cuda.synchronize()
130
+
131
+ # Benchmark
132
+ start_time = time.time()
133
+ for _ in range(10): # Average over 10 runs
134
+ compute_reference_forward(x, w, m_sizes)
135
+ torch.cuda.synchronize()
136
+ end_time = time.time()
137
+ else: # Optimized kernel
138
+ # Warmup
139
+ torch.cuda.synchronize()
140
+ grouped_gemm_forward(x, w, m_sizes)
141
+ torch.cuda.synchronize()
142
+
143
+ # Benchmark
144
+ start_time = time.time()
145
+ for _ in range(10): # Average over 10 runs
146
+ grouped_gemm_forward(x, w, m_sizes)
147
+ torch.cuda.synchronize()
148
+ end_time = time.time()
149
+
150
+ # Calculate FLOPs
151
+ # For GEMM: 2 * M * N * K FLOPs (multiply-add counts as 2 FLOPs)
152
+ flops = 2 * M * N * K
153
+
154
+ # Convert to TFLOPS (tera-FLOPS)
155
+ avg_time = (end_time - start_time) / 10 # Average time per run
156
+ tflops = flops / avg_time / 1e12
157
+
158
+ return tflops
159
+
160
+
161
+ @triton.testing.perf_report(
162
+ triton.testing.Benchmark(
163
+ x_names=["G"], # We'll vary the number of groups
164
+ x_vals=[1, 2, 4, 8, 16], # Different numbers of groups to test
165
+ line_arg="provider", # We'll compare different providers
166
+ line_vals=["pytorch_reference", "optimized_kernel"],
167
+ line_names=["PyTorch Reference", "Optimized Kernel"],
168
+ styles=[("blue", "-"), ("red", "-")],
169
+ ylabel="TFLOPS", # We'll measure TFLOPS
170
+ plot_name="mg_grouped_gemm_group_scaling",
171
+ args={
172
+ "M": 8192, # Batch dimension, fixed for all tests
173
+ "K": 4096, # Hidden dimension, fixed for all tests
174
+ "N": 8192, # Output dimension, fixed for all tests
175
+ "dtype": torch.float16,
176
+ "device": "cuda",
177
+ },
178
+ )
179
+ )
180
+ def benchmark_forward_groups(M, K, N, G, provider, dtype=torch.float16, device="cuda"):
181
+ """
182
+ Benchmark how performance scales with number of groups.
183
+
184
+ Args:
185
+ M (int): Total batch size dimension
186
+ K (int): Hidden dimension
187
+ N (int): Output dimension
188
+ G (int): Number of groups
189
+ provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
190
+ dtype (torch.dtype): Data type to use
191
+ device (str): Device to use
192
+
193
+ Returns:
194
+ float: Performance in TFLOPS
195
+ """
196
+ # Create group sizes for M dimension (balanced across groups)
197
+ base_size = M // G
198
+ remainder = M % G
199
+ M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
200
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
201
+
202
+ # Create input and weight tensors
203
+ x = torch.randn(M, K, dtype=dtype, device=device)
204
+ w = torch.randn(N, K, dtype=dtype, device=device)
205
+
206
+ # Benchmark logic - same as previous function
207
+ if provider == "pytorch_reference":
208
+ torch.cuda.synchronize()
209
+ compute_reference_forward(x, w, m_sizes)
210
+ torch.cuda.synchronize()
211
+
212
+ start_time = time.time()
213
+ for _ in range(10):
214
+ compute_reference_forward(x, w, m_sizes)
215
+ torch.cuda.synchronize()
216
+ end_time = time.time()
217
+ else:
218
+ torch.cuda.synchronize()
219
+ grouped_gemm_forward(x, w, m_sizes)
220
+ torch.cuda.synchronize()
221
+
222
+ start_time = time.time()
223
+ for _ in range(10):
224
+ grouped_gemm_forward(x, w, m_sizes)
225
+ torch.cuda.synchronize()
226
+ end_time = time.time()
227
+
228
+ # Calculate FLOPs and TFLOPS
229
+ flops = 2 * M * N * K
230
+ avg_time = (end_time - start_time) / 10
231
+ tflops = flops / avg_time / 1e12
232
+
233
+ return tflops
234
+
235
+
236
+ @triton.testing.perf_report(
237
+ triton.testing.Benchmark(
238
+ x_names=["group_balance"], # We'll vary the group balance factor
239
+ x_vals=[
240
+ 0.0,
241
+ 0.25,
242
+ 0.5,
243
+ 0.75,
244
+ 0.9,
245
+ ], # Different imbalance factors (0 = balanced, 1 = max imbalance)
246
+ line_arg="provider", # We'll compare different providers
247
+ line_vals=["pytorch_reference", "optimized_kernel"],
248
+ line_names=["PyTorch Reference", "Optimized Kernel"],
249
+ styles=[("blue", "-"), ("red", "-")],
250
+ ylabel="TFLOPS", # We'll measure TFLOPS
251
+ plot_name="mg_grouped_gemm_imbalance",
252
+ args={
253
+ "M": 8192, # Batch dimension, fixed for all tests
254
+ "K": 4096, # Hidden dimension, fixed for all tests
255
+ "N": 8192, # Output dimension, fixed for all tests
256
+ "G": 4, # Number of groups
257
+ "dtype": torch.float16,
258
+ "device": "cuda",
259
+ },
260
+ )
261
+ )
262
+ def benchmark_imbalance(
263
+ M, K, N, G, group_balance, provider, dtype=torch.float16, device="cuda"
264
+ ):
265
+ """
266
+ Benchmark how performance is affected by imbalanced group sizes.
267
+
268
+ Args:
269
+ M (int): Total batch size dimension
270
+ K (int): Hidden dimension
271
+ N (int): Output dimension
272
+ G (int): Number of groups
273
+ group_balance (float): Balance factor from 0 to 1 (0 = balanced, 1 = max imbalance)
274
+ provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
275
+ dtype (torch.dtype): Data type to use
276
+ device (str): Device to use
277
+
278
+ Returns:
279
+ float: Performance in TFLOPS
280
+ """
281
+ # Create imbalanced group sizes for M dimension
282
+ if group_balance == 0:
283
+ # Balanced case
284
+ base_size = M // G
285
+ remainder = M % G
286
+ M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
287
+ else:
288
+ # Imbalanced case
289
+ # First group gets more elements, last group gets fewer
290
+ # The imbalance is controlled by the group_balance factor
291
+ remaining = M
292
+ M_sizes = []
293
+ for g in range(G):
294
+ # Interpolate from balanced to imbalanced based on group_balance
295
+ # For balanced (group_balance=0), each group gets M/G
296
+ # For imbalanced (group_balance=1), first group gets much more than last group
297
+ balanced_size = remaining // (G - g)
298
+
299
+ # Adjusting size based on position and imbalance factor
300
+ # First groups get more, last groups get less
301
+ if g < G // 2:
302
+ # First half of groups get more
303
+ adjustment = int(balanced_size * group_balance * (1 - g / (G - 1)))
304
+ size = balanced_size + adjustment
305
+ else:
306
+ # Second half of groups get less
307
+ adjustment = int(balanced_size * group_balance * ((g / (G - 1)) - 0.5))
308
+ size = balanced_size - adjustment
309
+
310
+ # Ensure we don't go below 1 or take more than remaining
311
+ size = max(1, min(size, remaining))
312
+ M_sizes.append(size)
313
+ remaining -= size
314
+
315
+ # Handle any remaining elements
316
+ if remaining > 0:
317
+ M_sizes[-1] += remaining
318
+
319
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
320
+
321
+ # Create input and weight tensors
322
+ x = torch.randn(M, K, dtype=dtype, device=device)
323
+ w = torch.randn(N, K, dtype=dtype, device=device)
324
+
325
+ # Benchmark logic
326
+ if provider == "pytorch_reference":
327
+ torch.cuda.synchronize()
328
+ compute_reference_forward(x, w, m_sizes)
329
+ torch.cuda.synchronize()
330
+
331
+ start_time = time.time()
332
+ for _ in range(10):
333
+ compute_reference_forward(x, w, m_sizes)
334
+ torch.cuda.synchronize()
335
+ end_time = time.time()
336
+ else:
337
+ torch.cuda.synchronize()
338
+ grouped_gemm_forward(x, w, m_sizes)
339
+ torch.cuda.synchronize()
340
+
341
+ start_time = time.time()
342
+ for _ in range(10):
343
+ grouped_gemm_forward(x, w, m_sizes)
344
+ torch.cuda.synchronize()
345
+ end_time = time.time()
346
+
347
+ # Calculate FLOPs and TFLOPS
348
+ flops = 2 * M * N * K
349
+ avg_time = (end_time - start_time) / 10
350
+ tflops = flops / avg_time / 1e12
351
+
352
+ return tflops
353
+
354
+
355
+ def benchmark_model_configs():
356
+ """
357
+ Benchmark common model configurations used in DeepSeek-like models.
358
+ """
359
+ # Model configurations: (M, K, N, G)
360
+ configs = [
361
+ (8192, 7168, 4096, 4), # Config 1
362
+ (8192, 2048, 7168, 4), # Config 2
363
+ (4096, 7168, 4096, 8), # Config 3
364
+ (4096, 2048, 7168, 8), # Config 4
365
+ ]
366
+
367
+ results = []
368
+
369
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
370
+ dtype = torch.float16
371
+
372
+ for config_idx, (M, K, N, G) in enumerate(configs):
373
+ logging.info(f"\n===== Benchmarking DeepSeek Config {config_idx + 1} =====")
374
+ logging.info(f"M={M}, K={K}, N={N}, G={G}")
375
+
376
+ # Create group sizes for M dimension
377
+ base_size = M // G
378
+ remainder = M % G
379
+ M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
380
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
381
+
382
+ # Create tensors
383
+ x = torch.randn(M, K, dtype=dtype, device=device)
384
+ w = torch.randn(N, K, dtype=dtype, device=device)
385
+
386
+ # Benchmark PyTorch reference
387
+ torch.cuda.synchronize()
388
+ compute_reference_forward(x, w, m_sizes) # Warmup
389
+ torch.cuda.synchronize()
390
+
391
+ logging.info("Benchmarking PyTorch reference...")
392
+ torch.cuda.reset_peak_memory_stats()
393
+ start_time = time.time()
394
+ for _ in range(10):
395
+ compute_reference_forward(x, w, m_sizes)
396
+ torch.cuda.synchronize()
397
+ end_time = time.time()
398
+ pt_time = (end_time - start_time) / 10
399
+ pt_memory = torch.cuda.max_memory_allocated() / (1024**2) # MB
400
+
401
+ # Benchmark optimized kernel
402
+ torch.cuda.synchronize()
403
+ grouped_gemm_forward(x, w, m_sizes) # Warmup
404
+ torch.cuda.synchronize()
405
+
406
+ logging.info("Benchmarking optimized kernel...")
407
+ torch.cuda.reset_peak_memory_stats()
408
+ start_time = time.time()
409
+ for _ in range(10):
410
+ grouped_gemm_forward(x, w, m_sizes)
411
+ torch.cuda.synchronize()
412
+ end_time = time.time()
413
+ opt_time = (end_time - start_time) / 10
414
+ opt_memory = torch.cuda.max_memory_allocated() / (1024**2) # MB
415
+
416
+ # Calculate FLOPs and speedup
417
+ flops = 2 * M * N * K
418
+ pt_tflops = flops / pt_time / 1e12
419
+ opt_tflops = flops / opt_time / 1e12
420
+ speedup = pt_time / opt_time
421
+
422
+ # Store results
423
+ results.append(
424
+ {
425
+ "config": f"Config {config_idx + 1}",
426
+ "dimensions": f"M={M}, K={K}, N={N}, G={G}",
427
+ "pt_time_ms": pt_time * 1000,
428
+ "opt_time_ms": opt_time * 1000,
429
+ "pt_tflops": pt_tflops,
430
+ "opt_tflops": opt_tflops,
431
+ "speedup": speedup,
432
+ "pt_memory_mb": pt_memory,
433
+ "opt_memory_mb": opt_memory,
434
+ "memory_savings": (
435
+ (pt_memory - opt_memory) / pt_memory * 100 if pt_memory > 0 else 0
436
+ ),
437
+ }
438
+ )
439
+
440
+ logging.info(
441
+ f"PyTorch Reference: {pt_time * 1000:.2f} ms, {pt_tflops:.2f} TFLOPS, {pt_memory:.2f} MB"
442
+ )
443
+ logging.info(
444
+ f"Optimized Kernel: {opt_time * 1000:.2f} ms, {opt_tflops:.2f} TFLOPS, {opt_memory:.2f} MB"
445
+ )
446
+ logging.info(
447
+ f"Speedup: {speedup:.2f}x, Memory savings: {results[-1]['memory_savings']:.2f}%"
448
+ )
449
+
450
+ # Print summary table
451
+ logging.info("\n===== Benchmark Results Summary =====")
452
+ logging.info(
453
+ f"{'Config':<10} | {'Time (ms)':<20} | {'TFLOPS':<20} | {'Speedup':<10} | {'Memory (MB)':<20} | {'Memory Saved':<12}"
454
+ )
455
+ logging.info(
456
+ f"{'':<10} | {'PyTorch':<9} {'Kernel':<9} | {'PyTorch':<9} {'Kernel':<9} | {'':<10} | "
457
+ f"{'PyTorch':<9} {'Kernel':<9} | {'':<12}"
458
+ )
459
+ logging.info("-" * 100)
460
+
461
+ for result in results:
462
+ logging.info(
463
+ f"{result['config']:<10} | "
464
+ f"{result['pt_time_ms']:<9.2f} {result['opt_time_ms']:<9.2f} | "
465
+ f"{result['pt_tflops']:<9.2f} {result['opt_tflops']:<9.2f} | "
466
+ f"{result['speedup']:<10.2f} | "
467
+ f"{result['pt_memory_mb']:<9.2f} {result['opt_memory_mb']:<9.2f} | "
468
+ f"{result['memory_savings']:<12.2f}%"
469
+ )
470
+
471
+ return results
472
+
473
+
474
+ def plot_benchmark_results(results):
475
+ """
476
+ Plot benchmark results as bar charts.
477
+ """
478
+ # Extract data
479
+ configs = [r["config"] for r in results]
480
+ pt_tflops = [r["pt_tflops"] for r in results]
481
+ opt_tflops = [r["opt_tflops"] for r in results]
482
+ speedups = [r["speedup"] for r in results]
483
+
484
+ # Create figure with subplots
485
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
486
+
487
+ # Plot TFLOPS comparison
488
+ x = np.arange(len(configs))
489
+ width = 0.35
490
+ ax1.bar(x - width / 2, pt_tflops, width, label="PyTorch Reference")
491
+ ax1.bar(x + width / 2, opt_tflops, width, label="Optimized Kernel")
492
+ ax1.set_xlabel("Model Configuration")
493
+ ax1.set_ylabel("TFLOPS")
494
+ ax1.set_title("Performance Comparison (Higher is Better)")
495
+ ax1.set_xticks(x)
496
+ ax1.set_xticklabels(configs)
497
+ ax1.legend()
498
+ ax1.grid(axis="y", linestyle="--", alpha=0.7)
499
+
500
+ # Plot speedup
501
+ ax2.bar(x, speedups, width=0.6, color="green")
502
+ ax2.set_xlabel("Model Configuration")
503
+ ax2.set_ylabel("Speedup (x)")
504
+ ax2.set_title("Speedup Factor (Higher is Better)")
505
+ ax2.set_xticks(x)
506
+ ax2.set_xticklabels(configs)
507
+ ax2.grid(axis="y", linestyle="--", alpha=0.7)
508
+
509
+ # Add speedup values on top of bars
510
+ for i, v in enumerate(speedups):
511
+ ax2.text(i, v + 0.1, f"{v:.2f}x", ha="center")
512
+
513
+ plt.tight_layout()
514
+ plt.savefig("mg_grouped_gemm_benchmark_results.png")
515
+ logging.info(
516
+ "Benchmark results plot saved to 'mg_grouped_gemm_benchmark_results.png'"
517
+ )
518
+
519
+
520
+ def compare_mg_implementations():
521
+ """
522
+ Combine the M*G and N*G benchmark results for comparison.
523
+ """
524
+ # Only run this if both NG and MG benchmarks have been run
525
+ try:
526
+ import pandas as pd
527
+
528
+ # Try to load previous benchmark results
529
+ mg_results = pd.read_csv("mg_grouped_gemm_benchmark_results.csv")
530
+ ng_results = pd.read_csv("ng_grouped_gemm_benchmark_results.csv")
531
+
532
+ # Create comparison plot
533
+ fig, axes = plt.subplots(1, 2, figsize=(14, 6))
534
+
535
+ # Plot speedup comparison
536
+ configs = mg_results["config"].unique()
537
+ mg_speedups = mg_results.groupby("config")["speedup"].mean()
538
+ ng_speedups = ng_results.groupby("config")["speedup"].mean()
539
+
540
+ x = np.arange(len(configs))
541
+ width = 0.35
542
+
543
+ axes[0].bar(x - width / 2, mg_speedups, width, label="M*G Grouping")
544
+ axes[0].bar(x + width / 2, ng_speedups, width, label="N*G Grouping")
545
+ axes[0].set_xlabel("Model Configuration")
546
+ axes[0].set_ylabel("Speedup (x)")
547
+ axes[0].set_title("Speedup Comparison: M*G vs N*G")
548
+ axes[0].set_xticks(x)
549
+ axes[0].set_xticklabels(configs)
550
+ axes[0].legend()
551
+ axes[0].grid(axis="y", linestyle="--", alpha=0.7)
552
+
553
+ # Plot TFLOPS comparison for optimized kernels
554
+ mg_tflops = (
555
+ mg_results[mg_results["implementation"] == "optimized"]
556
+ .groupby("config")["tflops"]
557
+ .mean()
558
+ )
559
+ ng_tflops = (
560
+ ng_results[ng_results["implementation"] == "optimized"]
561
+ .groupby("config")["tflops"]
562
+ .mean()
563
+ )
564
+
565
+ axes[1].bar(x - width / 2, mg_tflops, width, label="M*G Grouping")
566
+ axes[1].bar(x + width / 2, ng_tflops, width, label="N*G Grouping")
567
+ axes[1].set_xlabel("Model Configuration")
568
+ axes[1].set_ylabel("TFLOPS")
569
+ axes[1].set_title("Performance Comparison: M*G vs N*G")
570
+ axes[1].set_xticks(x)
571
+ axes[1].set_xticklabels(configs)
572
+ axes[1].legend()
573
+ axes[1].grid(axis="y", linestyle="--", alpha=0.7)
574
+
575
+ plt.tight_layout()
576
+ plt.savefig("mg_vs_ng_comparison.png")
577
+ logging.info("Comparison plot saved to 'mg_vs_ng_comparison.png'")
578
+
579
+ except Exception as e:
580
+ logging.error(f"Could not create comparison plot: {e}")
581
+ logging.info(
582
+ "Run both M*G and N*G benchmarks first to generate comparison plots"
583
+ )
584
+
585
+
586
+ if __name__ == "__main__":
587
+ parser = argparse.ArgumentParser(
588
+ description="Benchmark M*G Grouped GEMM implementations"
589
+ )
590
+ parser.add_argument("--run-all", action="store_true", help="Run all benchmarks")
591
+ parser.add_argument(
592
+ "--triton-bench", action="store_true", help="Run Triton performance reports"
593
+ )
594
+ parser.add_argument(
595
+ "--model-configs", action="store_true", help="Benchmark model configurations"
596
+ )
597
+ parser.add_argument(
598
+ "--compare-mg-ng",
599
+ action="store_true",
600
+ help="Compare M*G and N*G implementations",
601
+ )
602
+ args = parser.parse_args()
603
+
604
+ # Check if CUDA is available
605
+ if not torch.cuda.is_available():
606
+ logging.error(
607
+ "CUDA is not available. This benchmark requires a CUDA-capable GPU."
608
+ )
609
+ exit(1)
610
+
611
+ if args.run_all or args.model_configs:
612
+ # Benchmark model configurations
613
+ logging.info("Running benchmark for model configurations...")
614
+ results = benchmark_model_configs()
615
+ plot_benchmark_results(results)
616
+
617
+ if args.run_all or args.triton_bench:
618
+ # Run Triton performance reports
619
+ logging.info("Running Triton performance reports...")
620
+ benchmark_forward.run(save_path="mg_grouped_gemm_benchmark_results")
621
+ benchmark_forward_groups.run(save_path="mg_grouped_gemm_benchmark_results")
622
+ benchmark_imbalance.run(save_path="mg_grouped_gemm_benchmark_results")
623
+ logging.info(
624
+ "Triton performance reports saved to 'mg_grouped_gemm_benchmark_results' directory"
625
+ )
626
+
627
+ if args.run_all or args.compare_mg_ng:
628
+ # Compare M*G and N*G implementations
629
+ logging.info("Comparing M*G and N*G implementations...")
630
+ compare_mg_implementations()
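Before timing anything, a small correctness sketch comparing the reference path and the Triton kernel; it assumes grouped_gemm_forward accepts these small, evenly divided sizes, and the expected difference is only fp16 rounding noise:

    # Sketch only: sanity-check grouped_gemm_forward against the PyTorch reference.
    import torch

    from torchao_pr.mg_grouped_gemm import grouped_gemm_forward

    M, K, N, G = 512, 256, 128, 4
    m_sizes = torch.full((G,), M // G, device="cuda", dtype=torch.int32)
    x = torch.randn(M, K, device="cuda", dtype=torch.float16)
    w = torch.randn(N, K, device="cuda", dtype=torch.float16)

    ref = compute_reference_forward(x, w, m_sizes)  # helper defined above
    out = grouped_gemm_forward(x, w, m_sizes)
    print((ref.float() - out.float()).abs().max())  # expect a small fp16-level difference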
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .mg_grouped_gemm import grouped_gemm_forward
8
+ from .tma_autotuning import ALIGN_SIZE_M
9
+
10
+ __all__ = [
11
+ "grouped_gemm_forward",
12
+ "ALIGN_SIZE_M",
13
+ ]
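For orientation, a minimal usage sketch of the two exported names. The shapes mirror the unit tests further down (stacked M*G inputs, stacked N*G weights, int32 per-group row counts); the `torchao_pr` import path and the alignment check are assumptions based on the directory and constant names, not documented requirements.

```python
# Hedged usage sketch of the exported API (illustrative shapes and import path).
import torch
from torchao_pr import ALIGN_SIZE_M, grouped_gemm_forward

G, M, N, K = 4, 512, 4096, 7168
assert M % ALIGN_SIZE_M == 0              # assumption: per-group rows respect the M alignment

a = torch.randn(M * G, K, device="cuda", dtype=torch.bfloat16)    # stacked group inputs
b = torch.randn(N * G, K, device="cuda", dtype=torch.bfloat16)    # stacked group weights
m_sizes = torch.full((G,), M, device="cuda", dtype=torch.int32)   # rows per group

out = grouped_gemm_forward(a, b, m_sizes)                         # -> [M*G, N]
print(out.shape)
```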
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py ADDED
@@ -0,0 +1,240 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # credit - the TMA helper class and autotuning utilities are derived from FBGEMM:
8
+ # https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
9
+
10
+ # pyre-unsafe
11
+ import functools
12
+
13
+ import os
14
+ import sys
15
+ from typing import Any, Dict, Optional, Tuple
16
+
17
+ import torch
18
+
19
+ import triton
20
+ import triton.language as tl
21
+ from triton import Config as TConfig
22
+
23
+ from triton.runtime import driver # @manual
24
+
25
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
26
+
27
+
28
+ # ===== Supporting utils, CUDA and TMA =====
29
+
30
+
31
+ class CudaUtils:
32
+ @staticmethod
33
+ def is_cuda() -> bool:
34
+ """Check if Triton is running on CUDA backend."""
35
+ return driver.active.get_current_target().backend == "cuda"
36
+
37
+ @staticmethod
38
+ def verify_tma() -> bool:
39
+ """Check if TMA is supported on the current device."""
40
+ return (
41
+ CudaUtils.is_cuda()
42
+ and torch.cuda.is_available()
43
+ and torch.cuda.get_device_capability()[0] >= 9
44
+ )
45
+
46
+ @staticmethod
47
+ def get_num_sms() -> int:
48
+ """Get the number of streaming multiprocessors on the current device."""
49
+ if not CudaUtils.is_cuda():
50
+ raise RuntimeError("Triton is not running on CUDA backend")
51
+ if not torch.cuda.is_available():
52
+ raise RuntimeError("CUDA is not available")
53
+ return torch.cuda.get_device_properties("cuda").multi_processor_count
54
+
55
+
56
+ class TmaDescriptorHelper:
57
+ """Helper class for managing TMA descriptors in Triton kernels."""
58
+
59
+ class KernelParamWrapper:
60
+ """Wrapper to implement the TmaDescKernelParam interface."""
61
+
62
+ def __init__(self, desc: torch.Tensor):
63
+ self.desc = desc
64
+
65
+ def tma_desc_cpu_ptr(self) -> int:
66
+ """Return the CPU pointer to the TMA descriptor."""
67
+ return self.desc.data_ptr()
68
+
69
+ def __init__(self, tma_size: int = 128):
70
+ """Initialize the TMA descriptor helper.
71
+
72
+ Args:
73
+ tma_size: Size of the TMA descriptor in bytes
74
+ """
75
+ if not CudaUtils.verify_tma():
76
+ raise RuntimeError(
77
+ "TMA not supported on this device (requires Hopper or newer)"
78
+ )
79
+ if "nv_tma_desc_type" not in dir(tl):
80
+ raise RuntimeError(
81
+ "TMA grid constant descriptors not supported in your Triton version"
82
+ )
83
+
84
+ self.tma_size = tma_size
85
+ self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
86
+ self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
87
+ self.descriptors: Dict[str, torch.Tensor] = {}
88
+
89
+ def init_tma_descriptor(self, name: str) -> None:
90
+ """Initialize a TMA descriptor with the given name.
91
+
92
+ Call this method outside of the lambda function for grid size.
93
+ """
94
+ self.descriptors[name] = torch.empty(
95
+ self.tma_size, device="cpu", dtype=torch.int8
96
+ )
97
+
98
+ def fill_1d_tma_descriptor(
99
+ self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
100
+ ) -> None:
101
+ """Fill a 1D TMA descriptor.
102
+
103
+ Call this method inside the lambda function for grid size.
104
+ """
105
+ if name not in self.descriptors:
106
+ raise ValueError(f"TMA descriptor '{name}' not initialized")
107
+
108
+ desc_x = self.descriptors[name]
109
+ if desc_x.data_ptr() % 64 != 0:
110
+ raise ValueError("TMA descriptor must be 64-byte aligned")
111
+ self.fill_1d_tma_descriptor_inner(
112
+ ptr, dim, block_dim, element_size, desc_x.data_ptr()
113
+ )
114
+
115
+ def fill_2d_tma_descriptor(
116
+ self,
117
+ name: str,
118
+ ptr: int,
119
+ dim1: int,
120
+ dim0: int,
121
+ block_dim1: int,
122
+ block_dim0: int,
123
+ element_size: int,
124
+ ) -> None:
125
+ """Fill a 2D TMA descriptor.
126
+
127
+ Call this method inside the lambda function for grid size.
128
+ """
129
+ if name not in self.descriptors:
130
+ raise ValueError(f"TMA descriptor '{name}' not initialized")
131
+
132
+ desc_x = self.descriptors[name]
133
+ if desc_x.data_ptr() % 64 != 0:
134
+ raise ValueError("TMA descriptor must be 64-byte aligned")
135
+ self.fill_2d_tma_descriptor_inner(
136
+ ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
137
+ )
138
+
139
+ def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
140
+ """Get the TMA descriptor kernel parameter for the given name."""
141
+ if name not in self.descriptors or self.descriptors[name] is None:
142
+ raise ValueError(f"TMA descriptor '{name}' not initialized")
143
+ return self.KernelParamWrapper(self.descriptors[name])
144
+
145
+
146
+ # ====== Autotuning utilities ======
147
+ ALIGN_SIZE_M = 128
148
+
149
+ _NV_CONFIGS = [
150
+ triton.Config(
151
+ {
152
+ "BLOCK_SIZE_M": block_size_m,
153
+ "BLOCK_SIZE_N": block_size_n,
154
+ "BLOCK_SIZE_K": block_size_k,
155
+ },
156
+ num_stages=num_stages,
157
+ num_warps=num_warps,
158
+ num_ctas=num_ctas,
159
+ )
160
+ for block_size_m in [ALIGN_SIZE_M]
161
+ for block_size_n in [64, 128, 256]
162
+ for block_size_k in [64, 128, 256]
163
+ for num_stages in [3, 4]
164
+ for num_warps in [4, 8]
165
+ for num_ctas in [1]
166
+ ]
167
+
168
+
169
+ def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
170
+ device = torch.cuda.current_device()
171
+ # Check for all possible pointer parameter names
172
+ if "grad_input_ptr" in named_args:
173
+ ptr_name = "grad_input_ptr"
174
+ elif "c_ptr" in named_args:
175
+ ptr_name = "c_ptr"
176
+ elif "grad_weight_ptr" in named_args:
177
+ ptr_name = "grad_weight_ptr"
178
+ else:
179
+ raise KeyError("No recognized pointer parameter found in kernel arguments")
180
+
181
+ if dtsize is None:
182
+ dtsize = named_args[ptr_name].element_size()
183
+ if dtype is None:
184
+ dtype = named_args[ptr_name].dtype
185
+
186
+ pruned_configs = []
187
+ for config in configs:
188
+ kw = config.kwargs
189
+ BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
190
+ kw["BLOCK_SIZE_M"],
191
+ kw["BLOCK_SIZE_N"],
192
+ kw["BLOCK_SIZE_K"],
193
+ config.num_stages,
194
+ )
195
+ G, M, N, K = (
196
+ named_args["G"],
197
+ named_args["M_BUCKET"],
198
+ named_args["N"],
199
+ named_args["K"],
200
+ )
201
+
202
+ # 1. make sure we have enough smem
203
+ max_shared_memory = driver.active.utils.get_device_properties(device)[
204
+ "max_shared_mem"
205
+ ]
206
+
207
+ required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
208
+ if required_shared_memory > max_shared_memory:
209
+ continue
210
+
211
+ M_PER_GROUP = M // G
212
+ MIN_M_TILES = 64
213
+ # 2. make sure we don't load M tiles that are too big
214
+ if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
215
+ continue
216
+ # 3. make sure we don't load N tiles that are too small
217
+ if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
218
+ continue
219
+
220
+ num_sm = driver.active.utils.get_device_properties(device)[
221
+ "multiprocessor_count"
222
+ ]
223
+ N_TILES = N // BLOCK_N
224
+ MIN_N_TILES = 64
225
+ # 4. make sure we don't load N tiles that are too big
226
+ if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
227
+ continue
228
+ # 5. make sure we don't load N tiles that are too small
229
+ if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
230
+ continue
231
+ # 6. make sure K can be evenly divided
232
+ if K % BLOCK_K != 0:
233
+ continue
234
+
235
+ pruned_configs.append(config)
236
+
237
+ return pruned_configs
238
+
239
+
240
+ # ======== End Autotuning utilities ========
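The `_NV_CONFIGS` list and `early_config_prune` above are meant to be attached to the grouped-GEMM kernels (defined in `mg_grouped_gemm.py`, not shown in this commit excerpt) through Triton's autotuner. A minimal sketch of that wiring follows; the tuning `key` and the placeholder kernel signature are assumptions, only the hookup mechanism is the point.

```python
# Hedged sketch: wiring _NV_CONFIGS and early_config_prune into triton.autotune.
# The kernel body and tuning key are placeholders, not the real mg_grouped_gemm kernel.
import triton
import triton.language as tl


@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _example_grouped_kernel(
    c_ptr,                       # output tensor; name matches early_config_prune's lookup
    G, M_BUCKET, N, K,           # sizes read by early_config_prune via named_args
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    pass  # body intentionally omitted; only the autotuning hookup is illustrated
```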
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import logging
9
+ import unittest
10
+ from typing import Tuple
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+
15
+ from mg_grouped_gemm import grouped_gemm_forward
16
+
17
+
18
+ class TestMG_GroupedGEMM(unittest.TestCase):
19
+ def setUp(self) -> None:
20
+ torch.manual_seed(2020)
21
+
22
+ def _run_grouped_gemm_test(
23
+ self,
24
+ shape: Tuple[int, int, int, int],
25
+ device: torch.device,
26
+ dtype: torch.dtype = torch.bfloat16,
27
+ atol: float = 1e-5,
28
+ rtol: float = 1.6e-2,
29
+ ) -> None:
30
+ G, M, N, K = shape
31
+ # In M*G grouping, input is [M*G, K] and weights are [N*G, K]
32
+ a = torch.randn(M * G, K, dtype=dtype, device=device)
33
+ b = torch.randn(N * G, K, dtype=dtype, device=device)
34
+
35
+ # Create equal-sized groups for simplicity
36
+ m_size = M
37
+ m_sizes = torch.full((G,), m_size, device=device, dtype=torch.int32)
38
+
39
+ result = grouped_gemm_forward(a, b, m_sizes)
40
+ self.assertTrue(result.shape == (M * G, N))
41
+
42
+ expected_result = torch.zeros(M * G, N, dtype=dtype, device=device)
43
+ m_start = 0
44
+ for g in range(G):
45
+ m_end = m_start + m_sizes[g]
46
+ b_slice = b[N * g : N * (g+1), :]
47
+ expected_result[m_start:m_end, :] = a[m_start:m_end, :] @ b_slice.T
48
+ m_start = m_end
49
+
50
+ # Convert result to match input dtype if needed
51
+ result = result.to(dtype)
52
+ torch.testing.assert_close(result, expected_result, atol=atol, rtol=rtol)
53
+
54
+ def test_MG_grouped_gemm_bf16(self) -> None:
55
+ for G in (1, 4, 16):
56
+ for M in (128, 512, 1024):
57
+ print(f"Testing BF16 M*G GroupGeMM with G={G}, M={M}")
58
+ self._run_grouped_gemm_test(
59
+ (G, M, 1024, 1024),
60
+ torch.device("cuda"),
61
+ dtype=torch.bfloat16,
62
+ atol=1e-5,
63
+ rtol=1.6e-2,
64
+ )
65
+
66
+ def test_MG_grouped_gemm_deepseek_shapes(self) -> None:
67
+ """Test with shapes from Deepseek model."""
68
+ deepseek_shapes = [
69
+ (4, 2048, 4096, 7168), # G, M, N, K
70
+ (4, 2048, 7168, 2048),
71
+ (8, 512, 4096, 7168),
72
+ (8, 512, 7168, 2048),
73
+ ]
74
+
75
+ device = torch.device("cuda")
76
+
77
+ for shape in deepseek_shapes:
78
+ G, M, N, K = shape
79
+ print(f"Testing BF16 M*G Deepseek shape: G={G}, M={M}, N={N}, K={K}")
80
+ self._run_grouped_gemm_test(
81
+ shape, device, dtype=torch.bfloat16, atol=1e-5, rtol=1.6e-2
82
+ )
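The reference loop inside `_run_grouped_gemm_test` can be read as a standalone eager baseline for the M*G layout. A small sketch of that baseline, factored into a function; this is purely illustrative and not part of the test suite.

```python
# Hedged sketch: an eager PyTorch reference for M*G grouped GEMM,
# mirroring the loop used in the test above.
import torch


def grouped_gemm_reference(
    a: torch.Tensor,        # [sum(m_sizes), K]
    b: torch.Tensor,        # [N * G, K]
    m_sizes: torch.Tensor,  # [G], int32 rows per group
) -> torch.Tensor:
    G = m_sizes.numel()
    N = b.shape[0] // G
    out = torch.empty(a.shape[0], N, dtype=a.dtype, device=a.device)
    m_start = 0
    for g in range(G):
        m_end = m_start + int(m_sizes[g])
        out[m_start:m_end] = a[m_start:m_end] @ b[N * g : N * (g + 1)].T
        m_start = m_end
    return out
```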
torchtitan/experiments/llama4/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.96 kB)
torchtitan/experiments/llama4/infra/expert_parallel.py ADDED
@@ -0,0 +1,145 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ from functools import partial
9
+ from typing import Optional, Tuple
10
+
11
+ import torch.nn as nn
12
+ from torch.distributed.tensor import (
13
+ DeviceMesh,
14
+ distribute_module,
15
+ distribute_tensor,
16
+ DTensor,
17
+ Partial,
18
+ Replicate,
19
+ Shard,
20
+ )
21
+ from torch.distributed.tensor.parallel import ParallelStyle
22
+ from torch.distributed.tensor.placement_types import Placement
23
+
24
+
25
+ # implementation of Tensor Parallel on the non-shared experts in MoE
26
+ class TensorParallel(ParallelStyle):
27
+ def __init__(
28
+ self,
29
+ *,
30
+ input_layouts: Optional[Tuple[Optional[Placement]]] = None,
31
+ output_layout: Optional[Placement] = None,
32
+ use_local_output: bool = True,
33
+ ):
34
+ super().__init__()
35
+ self.input_layouts = input_layouts or (Replicate(), None)
36
+ self.output_layout = output_layout or Partial()
37
+ self.desired_input_layouts = (Replicate(), None)
38
+ self.use_local_output = use_local_output
39
+
40
+ @staticmethod
41
+ def _prepare_input_fn(
42
+ input_layouts, desired_input_layouts, mod, inputs, device_mesh
43
+ ):
44
+ # TODO: figure out dynamo support for instance method and switch this to instance method
45
+
46
+ # annotate module input placements/sharding with input_layouts
47
+ input_tensor, input_layout, desired_input_layout = (
48
+ inputs[0],
49
+ input_layouts[0],
50
+ desired_input_layouts[0],
51
+ )
52
+ if not isinstance(input_tensor, DTensor):
53
+ input_tensor = DTensor.from_local(
54
+ input_tensor, device_mesh, (input_layout,), run_check=False
55
+ )
56
+
57
+ if input_layouts != desired_input_layouts:
58
+ input_tensor = input_tensor.redistribute(
59
+ placements=(desired_input_layout,), async_op=True
60
+ )
61
+ return (input_tensor, *inputs[1:])
62
+
63
+ def _partition_fn(self, name, module, device_mesh):
64
+ module.register_parameter(
65
+ "w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(2)]))
66
+ ) # Column-wise sharding
67
+ module.register_parameter(
68
+ "w2",
69
+ nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(1)])),
70
+ ) # Row-wise sharding
71
+ module.register_parameter(
72
+ "w3",
73
+ nn.Parameter(distribute_tensor(module.w3, device_mesh, [Shard(2)])),
74
+ ) # Column-wise sharding
75
+
76
+ @staticmethod
77
+ def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
78
+ if outputs.placements != (output_layout,):
79
+ outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
80
+ # back to local tensor
81
+ return outputs.to_local() if use_local_output else outputs
82
+
83
+ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
84
+ return distribute_module(
85
+ module,
86
+ device_mesh,
87
+ self._partition_fn,
88
+ partial(
89
+ self._prepare_input_fn, self.input_layouts, self.desired_input_layouts
90
+ ),
91
+ partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
92
+ )
93
+
94
+
95
+ # NOTE: This is to achieve replicate computation on the gate module in the MoE router.
96
+ # It does nothing other than (1) setting the module parameters as DTensors on the given mesh
97
+ # and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back.
98
+ # TODO: The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh,
99
+ # which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation.
100
+ class NoParallel(ParallelStyle):
101
+ def __init__(
102
+ self,
103
+ *,
104
+ input_layout: Optional[Placement] = None,
105
+ output_layout: Optional[Placement] = None,
106
+ use_local_output: bool = True,
107
+ ):
108
+ super().__init__()
109
+ self.input_layout = input_layout or Replicate()
110
+ self.output_layout = output_layout or Replicate()
111
+ self.desired_input_layout = Replicate()
112
+ self.use_local_output = use_local_output
113
+
114
+ @staticmethod
115
+ def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh):
116
+ # annotate module input placements/sharding with input_layouts
117
+ input_tensor = inputs[0]
118
+ if not isinstance(input_tensor, DTensor):
119
+ input_tensor = DTensor.from_local(
120
+ input_tensor, device_mesh, (input_layout,), run_check=False
121
+ )
122
+
123
+ if input_layout != desired_input_layout:
124
+ input_tensor = input_tensor.redistribute(
125
+ placements=(desired_input_layout,), async_op=True
126
+ )
127
+ return (input_tensor, *inputs[1:])
128
+
129
+ @staticmethod
130
+ def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
131
+ if outputs.placements != (output_layout,):
132
+ outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
133
+ # back to local tensor
134
+ return outputs.to_local() if use_local_output else outputs
135
+
136
+ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
137
+ return distribute_module(
138
+ module,
139
+ device_mesh,
140
+ None,
141
+ partial(
142
+ self._prepare_input_fn, self.input_layout, self.desired_input_layout
143
+ ),
144
+ partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
145
+ )
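In `_partition_fn` above, `Shard(2)` on the `(num_experts, dim, hidden_dim)`-shaped `w1`/`w3` splits the hidden dimension across the TP mesh (column-wise), and `Shard(1)` on `w2` splits its hidden dimension as well (row-wise). A small standalone sketch of what those placements do to the local shards, assuming a CPU/Gloo mesh launched with `torchrun`; the shapes are illustrative only.

```python
# Hedged sketch: what Shard(2) / Shard(1) mean for the grouped expert weights.
# Run with e.g. `torchrun --nproc_per_node=2 shard_demo.py`; shapes are illustrative.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),), mesh_dim_names=("tp",))

torch.manual_seed(0)                             # same global tensor on every rank
num_experts, dim, hidden_dim = 4, 8, 16
w1 = torch.randn(num_experts, dim, hidden_dim)   # like GroupedExperts.w1
w2 = torch.randn(num_experts, hidden_dim, dim)   # like GroupedExperts.w2

w1_d = distribute_tensor(w1, mesh, [Shard(2)])   # column-wise: hidden_dim split across ranks
w2_d = distribute_tensor(w2, mesh, [Shard(1)])   # row-wise: hidden_dim split across ranks

print(w1_d.to_local().shape)  # (4, 8, 16 // world_size)
print(w2_d.to_local().shape)  # (4, 16 // world_size, 8)
dist.destroy_process_group()
```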
torchtitan/experiments/llama4/infra/parallelize_llama.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.distributed.device_mesh import DeviceMesh
11
+
12
+ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
13
+ from torchtitan.distributed import ParallelDims
14
+
15
+ from torchtitan.models.llama3.parallelize_llama import (
16
+ apply_ac,
17
+ apply_compile,
18
+ apply_ddp,
19
+ apply_fsdp,
20
+ apply_tp,
21
+ )
22
+ from torchtitan.tools.logging import logger
23
+
24
+
25
+ def parallelize_llama(
26
+ model: nn.Module,
27
+ world_mesh: DeviceMesh,
28
+ parallel_dims: ParallelDims,
29
+ job_config: JobConfig,
30
+ ):
31
+ """
32
+ Apply tensor parallelism, activation checkpointing, torch.compile, and data
33
+ parallelism to the model.
34
+
35
+ NOTE: The passed-in model preferably should be on meta device. Otherwise,
36
+ the model must fit on GPU or CPU memory.
37
+ """
38
+
39
+ if parallel_dims.tp_enabled:
40
+ if (
41
+ job_config.parallelism.enable_async_tensor_parallel
42
+ and not job_config.training.compile
43
+ ):
44
+ raise RuntimeError("Async TP requires --training.compile")
45
+
46
+ enable_float8_linear = "float8" in job_config.model.converters
47
+ float8_is_rowwise = job_config.float8.recipe_name in (
48
+ "rowwise",
49
+ "rowwise_with_gw_hp",
50
+ )
51
+
52
+ # For now, float8 all-gather with TP is only supported for tensorwise
53
+ # float8 scaling recipes. For rowwise recipes, we use regular TP and
54
+ # all-gather happens in high precision.
55
+ enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
56
+
57
+ apply_tp(
58
+ model,
59
+ world_mesh["tp"],
60
+ loss_parallel=parallel_dims.loss_parallel_enabled,
61
+ enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
62
+ enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
63
+ )
64
+
65
+ apply_moe_tp(model, world_mesh["tp"])
66
+
67
+ if job_config.activation_checkpoint.mode != "none":
68
+ if (
69
+ job_config.activation_checkpoint.mode == "selective"
70
+ and job_config.model.use_flex_attn
71
+ ):
72
+ raise ValueError(
73
+ "FlexAttention is not compatible with selective AC yet. "
74
+ "See https://github.com/pytorch/pytorch/issues/147879"
75
+ )
76
+ apply_ac(model, job_config.activation_checkpoint)
77
+
78
+ # turn on per-TransformerBlock compile after AC wrapping and before FSDP
79
+ if job_config.training.compile:
80
+ apply_compile(model)
81
+
82
+ # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
83
+ torch._dynamo.config.capture_scalar_outputs = True
84
+
85
+ if (
86
+ parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
87
+ ): # apply FSDP or HSDP, potentially with Context Parallel
88
+ if parallel_dims.dp_replicate_enabled:
89
+ dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
90
+ else:
91
+ dp_mesh_dim_names = ("dp_shard_cp",)
92
+
93
+ apply_fsdp(
94
+ model,
95
+ world_mesh[tuple(dp_mesh_dim_names)],
96
+ param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
97
+ reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
98
+ pp_enabled=parallel_dims.pp_enabled,
99
+ cpu_offload=job_config.training.enable_cpu_offload,
100
+ reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
101
+ )
102
+
103
+ if parallel_dims.dp_replicate_enabled:
104
+ logger.info("Applied HSDP to the model")
105
+ else:
106
+ logger.info("Applied FSDP to the model")
107
+
108
+ if parallel_dims.cp_enabled:
109
+ logger.info("Applied Context Parallel to the model")
110
+
111
+ if job_config.training.enable_cpu_offload:
112
+ logger.info("Applied CPU Offloading to the model")
113
+ elif parallel_dims.dp_replicate_enabled:
114
+ if world_mesh.ndim > 1:
115
+ raise RuntimeError("DDP does not support > 1D parallelism")
116
+ apply_ddp(
117
+ model,
118
+ world_mesh,
119
+ enable_compile=job_config.training.compile,
120
+ enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
121
+ )
122
+
123
+ return model
124
+
125
+
126
+ def apply_moe_tp(
127
+ model: nn.Module,
128
+ tp_mesh: DeviceMesh,
129
+ ):
130
+ from torch.distributed.tensor import Partial, Replicate, Shard
131
+ from torch.distributed.tensor.parallel import (
132
+ parallelize_module,
133
+ PrepareModuleInputOutput,
134
+ )
135
+
136
+ from .expert_parallel import NoParallel, TensorParallel
137
+
138
+ for _, transformer_block in model.layers.items():
139
+ moe_layer_plan = {
140
+ # input / output sharding on the seqlen dim
141
+ # all-gather for input, reduce-scatter for output
142
+ "moe": PrepareModuleInputOutput(
143
+ input_layouts=(Shard(1),),
144
+ desired_input_layouts=(Replicate(),),
145
+ use_local_input=True,
146
+ output_layouts=(Partial(),),
147
+ desired_output_layouts=(Shard(1),),
148
+ ),
149
+ # replicate computation for the router
150
+ "moe.router.gate": NoParallel(),
151
+ # input Replicate, output Partial
152
+ "moe.experts": TensorParallel(),
153
+ "moe.shared_expert": TensorParallel(),
154
+ }
155
+ parallelize_module(
156
+ module=transformer_block,
157
+ device_mesh=tp_mesh,
158
+ parallelize_plan=moe_layer_plan,
159
+ )
torchtitan/experiments/llama4/model/__pycache__/model.cpython-311.pyc ADDED
Binary file (23.8 kB)
torchtitan/experiments/llama4/model/model.py ADDED
@@ -0,0 +1,466 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+
12
+ from torchtitan.models.attention import build_attention, init_attention_mask
13
+ from torchtitan.models.norms import build_norm
14
+ from torchtitan.protocols.train_spec import ModelProtocol
15
+
16
+ from .args import TransformerModelArgs
17
+ from .moe import MoE
18
+
19
+
20
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
21
+ """
22
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
23
+
24
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
25
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
26
+ The returned tensor contains complex values in complex64 data type.
27
+
28
+ Args:
29
+ dim (int): Dimension of the frequency tensor.
30
+ end (int): End index for precomputing frequencies.
31
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
32
+
33
+ Returns:
34
+ torch.Tensor: Precomputed frequency tensor with complex exponentials.
35
+ """
36
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
37
+ t = torch.arange(end, device=freqs.device)
38
+ freqs = torch.outer(t, freqs).float()
39
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
40
+ return freqs_cis
41
+
42
+
43
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
44
+ """
45
+ Reshape frequency tensor for broadcasting it with another tensor.
46
+
47
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
48
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
49
+
50
+ The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
51
+ and the first seqlen elements will be sliced, but dim must match x.
52
+
53
+ Args:
54
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
55
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
56
+
57
+ Returns:
58
+ torch.Tensor: Reshaped frequency tensor.
59
+ """
60
+ ndim = x.ndim
61
+ assert ndim > 1
62
+ seqlen = x.shape[1]
63
+ freqs_cis = freqs_cis[0:seqlen]
64
+ assert freqs_cis.shape == (seqlen, x.shape[-1])
65
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
66
+ return freqs_cis.view(*shape)
67
+
68
+
69
+ def apply_rotary_emb(
70
+ xq: torch.Tensor,
71
+ xk: torch.Tensor,
72
+ freqs_cis: torch.Tensor,
73
+ ) -> tuple[torch.Tensor, torch.Tensor]:
74
+ """
75
+ Apply rotary embeddings to input tensors using the given frequency tensor.
76
+
77
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
78
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
79
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
80
+ returned as real tensors.
81
+
82
+ Args:
83
+ xq (torch.Tensor): Query tensor to apply rotary embeddings.
84
+ xk (torch.Tensor): Key tensor to apply rotary embeddings.
85
+ freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
86
+
87
+ Returns:
88
+ tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
89
+ """
90
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
91
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
92
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
93
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
94
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
95
+ return xq_out.type_as(xq), xk_out.type_as(xk)
96
+
97
+
98
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
99
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
100
+ bs, slen, n_kv_heads, head_dim = x.shape
101
+ if n_rep == 1:
102
+ return x
103
+ return (
104
+ torch.unsqueeze(x, dim=3)
105
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
106
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
107
+ )
108
+
109
+
110
+ class Attention(nn.Module):
111
+ """
112
+ Multi-head attention module.
113
+
114
+ Args:
115
+ model_args (TransformerModelArgs): Model configuration arguments.
116
+
117
+ Attributes:
118
+ n_kv_heads (int): Number of key and value heads.
119
+ n_heads (int): Number of query heads.
120
+ n_rep (int): Number of repetitions for local heads.
121
+ head_dim (int): Dimension size of each attention head.
122
+ wq (Linear): Linear transformation for queries.
123
+ wk (Linear): Linear transformation for keys.
124
+ wv (Linear): Linear transformation for values.
125
+ wo (Linear): Linear transformation for output.
126
+
127
+ """
128
+
129
+ def __init__(self, model_args: TransformerModelArgs):
130
+ super().__init__()
131
+ self.n_heads = model_args.n_heads
132
+ self.n_kv_heads = (
133
+ model_args.n_heads
134
+ if model_args.n_kv_heads is None
135
+ else model_args.n_kv_heads
136
+ )
137
+ self.n_rep = self.n_heads // self.n_kv_heads
138
+ self.head_dim = model_args.dim // model_args.n_heads
139
+
140
+ self.wq = nn.Linear(
141
+ model_args.dim, model_args.n_heads * self.head_dim, bias=False
142
+ )
143
+ self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
144
+ self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
145
+ self.wo = nn.Linear(
146
+ model_args.n_heads * self.head_dim, model_args.dim, bias=False
147
+ )
148
+ self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type)
149
+
150
+ def init_weights(self, init_std: float):
151
+ for linear in (self.wq, self.wk, self.wv):
152
+ nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
153
+ nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
154
+
155
+ def forward(
156
+ self,
157
+ x: torch.Tensor,
158
+ freqs_cis: torch.Tensor,
159
+ ):
160
+ """
161
+ Forward pass of the attention module.
162
+
163
+ Args:
164
+ x (torch.Tensor): Input tensor.
165
+ freqs_cis (torch.Tensor): Precomputed frequency tensor.
166
+
167
+ Returns:
168
+ torch.Tensor: Output tensor after attention.
169
+
170
+ """
171
+
172
+ bs, seqlen, _ = x.shape
173
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
174
+
175
+ # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual
176
+ # local heads from sizes of xq, xk, and xv as TP may have sharded them
177
+ # after the above linear ops.
178
+ xq = xq.view(bs, seqlen, -1, self.head_dim)
179
+ xk = xk.view(bs, seqlen, -1, self.head_dim)
180
+ xv = xv.view(bs, seqlen, -1, self.head_dim)
181
+
182
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
183
+
184
+ # repeat k/v heads if n_kv_heads < n_heads
185
+ keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
186
+ values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
187
+
188
+ xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
189
+ xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
190
+ xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
191
+
192
+ output = self.sdpa(xq, xk, xv)
193
+
194
+ output = output.transpose(
195
+ 1, 2
196
+ ).contiguous() # (bs, seqlen, n_local_heads, head_dim)
197
+ output = output.view(bs, seqlen, -1)
198
+ return self.wo(output)
199
+
200
+
201
+ class FeedForward(nn.Module):
202
+ """
203
+ FeedForward module
204
+
205
+ Args:
206
+ dim (int): Input dimension.
207
+ hidden_dim (int): Hidden dimension of the feedforward layer.
208
+ multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
209
+ ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None.
210
+
211
+ Attributes:
212
+ w1 (Linear): Linear transformation for the first layer.
213
+ w2 (Linear): Linear transformation for the second layer.
214
+ w3 (Linear): Linear transformation for the third layer.
215
+
216
+ """
217
+
218
+ def __init__(
219
+ self,
220
+ dim: int,
221
+ hidden_dim: int,
222
+ multiple_of: int,
223
+ ffn_dim_multiplier: float | None,
224
+ ):
225
+ super().__init__()
226
+ hidden_dim = int(2 * hidden_dim / 3)
227
+ # custom dim factor multiplier
228
+ if ffn_dim_multiplier is not None:
229
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
230
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
231
+
232
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
233
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
234
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
235
+
236
+ def forward(self, x):
237
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
238
+
239
+ def init_weights(self, init_std: float):
240
+ nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
241
+ for linear in (self.w2, self.w3):
242
+ nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
243
+
244
+
245
+ class TransformerBlock(nn.Module):
246
+ """
247
+ TransformerBlock Module
248
+
249
+ Args:
250
+ layer_id (int): Identifier for the layer.
251
+ model_args (TransformerModelArgs): Model configuration arguments.
252
+
253
+ Attributes:
254
+ n_heads (int): Number of attention heads.
255
+ dim (int): Dimension size of the model.
256
+ head_dim (int): Dimension size of each attention head.
257
+ attention (Attention): Attention module.
258
+ feed_forward (FeedForward): FeedForward module.
259
+ layer_id (int): Identifier for the layer.
260
+ attention_norm (RMSNorm): Layer normalization for attention output.
261
+ ffn_norm (RMSNorm): Layer normalization for feedforward output.
262
+
263
+ """
264
+
265
+ def __init__(self, layer_id: int, model_args: TransformerModelArgs):
266
+ super().__init__()
267
+ self.n_heads = model_args.n_heads
268
+ self.dim = model_args.dim
269
+ self.attention = Attention(model_args)
270
+
271
+ # use MoE layer for every interleave_moe_layer_step FFN layers
272
+ self.moe_enabled = (
273
+ model_args.moe_enabled
274
+ and (layer_id + 1) % model_args.interleave_moe_layer_step == 0
275
+ )
276
+ if self.moe_enabled:
277
+ self.moe = MoE(model_args)
278
+ else:
279
+ self.feed_forward = FeedForward(
280
+ dim=model_args.dim,
281
+ hidden_dim=4 * model_args.dim,
282
+ multiple_of=model_args.multiple_of,
283
+ ffn_dim_multiplier=model_args.ffn_dim_multiplier,
284
+ )
285
+
286
+ self.layer_id = layer_id
287
+ self.num_layers = model_args.n_layers
288
+
289
+ self.attention_norm = build_norm(
290
+ model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
291
+ )
292
+ self.ffn_norm = build_norm(
293
+ model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
294
+ )
295
+
296
+ if model_args.depth_init:
297
+ self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
298
+ else:
299
+ self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5
300
+
301
+ def forward(
302
+ self,
303
+ x: torch.Tensor,
304
+ freqs_cis: torch.Tensor,
305
+ ):
306
+ """
307
+ Perform a forward pass through the TransformerBlock.
308
+
309
+ Args:
310
+ x (torch.Tensor): Input tensor.
311
+ freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
312
+
313
+ Returns:
314
+ torch.Tensor: Output tensor after applying attention and feedforward layers.
315
+
316
+ """
317
+ h = x + self.attention(self.attention_norm(x), freqs_cis)
318
+ if self.moe_enabled:
319
+ out = h + self.moe(self.ffn_norm(h))
320
+ else:
321
+ out = h + self.feed_forward(self.ffn_norm(h))
322
+ return out
323
+
324
+ def init_weights(self):
325
+ for norm in (self.attention_norm, self.ffn_norm):
326
+ norm.reset_parameters()
327
+ self.attention.init_weights(self.weight_init_std)
328
+ if self.moe_enabled:
329
+ self.moe.init_weights(self.weight_init_std)
330
+ else:
331
+ self.feed_forward.init_weights(self.weight_init_std)
332
+
333
+
334
+ class Transformer(nn.Module, ModelProtocol):
335
+ """
336
+ Transformer Module
337
+
338
+ Args:
339
+ model_args (TransformerModelArgs): Model configuration arguments.
340
+
341
+ Attributes:
342
+ model_args (TransformerModelArgs): Model configuration arguments.
343
+ vocab_size (int): Vocabulary size.
344
+ n_layers (int): Number of layers in the model.
345
+ tok_embeddings (ParallelEmbedding): Token embeddings.
346
+ layers (torch.nn.ModuleList): List of Transformer blocks.
347
+ norm (RMSNorm): Layer normalization for the model output.
348
+ output (ColumnParallelLinear): Linear layer for final output.
349
+ freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
350
+
351
+ """
352
+
353
+ def __init__(self, model_args: TransformerModelArgs):
354
+ super().__init__()
355
+ self.model_args = model_args
356
+ self.vocab_size = model_args.vocab_size
357
+ self.n_layers = model_args.n_layers
358
+ self.eos_id = model_args.eos_id
359
+
360
+ self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
361
+
362
+ # TODO persistent should be set to false, since this buffer can be recomputed.
363
+ # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411,
364
+ # compile or pipeline-tracer will not correctly handle non-persistent buffers,
365
+ # so we need to fix that. (2) if we initialize pipeline-parallel models from
366
+ # a seed checkpoint rather than calling init_weights, we need freqs_cis to be
367
+ # initialized by the checkpoint, or we need to add a separate initializer for
368
+ # just the non-persistent buffers that is called after loading checkpoints.
369
+ self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
370
+
371
+ self.layers = torch.nn.ModuleDict()
372
+ for layer_id in range(model_args.n_layers):
373
+ self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
374
+
375
+ self.norm = build_norm(
376
+ model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
377
+ )
378
+
379
+ self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
380
+ self.init_weights()
381
+
382
+ def init_weights(
383
+ self,
384
+ buffer_device: torch.device | None = None,
385
+ ):
386
+ """
387
+ [Note: On ``init_weights`` vs. ``reset_parameters``]
388
+ Modules may define ``reset_parameters`` to initialize parameter values.
389
+ ``reset_parameters`` is meant to only initialize directly owned
390
+ parameters/buffers, not those of their child modules, and it can be
391
+ used to give the initial values for these tensors.
392
+ Separately, users may want custom initialization for their modules,
393
+ different from that in ``reset_parameters``. For this, we define
394
+ ``init_weights``. We only call it in the constructor of this
395
+ ``Transformer`` root module to avoid reinitializing tensors.
396
+ """
397
+ buffer_device = buffer_device or self.freqs_cis.device
398
+ with torch.device(buffer_device):
399
+ self.freqs_cis = self._precompute_freqs_cis()
400
+ if self.tok_embeddings is not None:
401
+ nn.init.normal_(self.tok_embeddings.weight)
402
+ for layer in self.layers.values():
403
+ if layer is not None:
404
+ layer.init_weights()
405
+ if self.norm is not None:
406
+ self.norm.reset_parameters()
407
+ final_out_std = self.model_args.dim**-0.5
408
+ cutoff_factor = 3
409
+ if self.output is not None:
410
+ nn.init.trunc_normal_(
411
+ self.output.weight,
412
+ mean=0.0,
413
+ std=final_out_std,
414
+ a=-cutoff_factor * final_out_std,
415
+ b=cutoff_factor * final_out_std,
416
+ )
417
+
418
+ def _precompute_freqs_cis(self) -> torch.Tensor:
419
+ return precompute_freqs_cis(
420
+ self.model_args.dim // self.model_args.n_heads,
421
+ # Need to compute until at least the max token limit for generation
422
+ # TODO: explain in docs/composability.md why we removed the 2x
423
+ # relaxing in our CP enablement PR
424
+ self.model_args.max_seq_len,
425
+ self.model_args.rope_theta,
426
+ )
427
+
428
+ def forward(self, tokens: torch.Tensor):
429
+ """
430
+ Perform a forward pass through the Transformer model.
431
+
432
+ Args:
433
+ tokens (torch.Tensor): Input token indices.
434
+
435
+ Returns:
436
+ torch.Tensor: Output logits after applying the Transformer model.
437
+
438
+ """
439
+ # TODO: We will need to change the forward() signature to allow tokens to
440
+ # always be passed in.
441
+ if self.model_args.use_flex_attn:
442
+ init_attention_mask(tokens, eos_id=self.eos_id)
443
+
444
+ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
445
+ h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
446
+
447
+ for layer in self.layers.values():
448
+ h = layer(h, self.freqs_cis)
449
+
450
+ h = self.norm(h) if self.norm else h
451
+ output = self.output(h) if self.output else h
452
+ return output
453
+
454
+ @classmethod
455
+ def from_model_args(cls, model_args: TransformerModelArgs) -> "Transformer":
456
+ """
457
+ Initialize a Transformer model from a TransformerModelArgs object.
458
+
459
+ Args:
460
+ model_args (TransformerModelArgs): Model configuration arguments.
461
+
462
+ Returns:
463
+ Transformer: Transformer model.
464
+
465
+ """
466
+ return cls(model_args)
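As a quick sanity check on the RoPE helpers defined at the top of this file, a minimal sketch exercising `precompute_freqs_cis` and `apply_rotary_emb` with toy shapes; the import path is an assumption based on the file location, and the numbers are illustrative.

```python
# Hedged sketch: shape sanity check for the RoPE helpers above.
# Assumes the module is importable at this path; adjust as needed.
import torch
from torchtitan.experiments.llama4.model.model import apply_rotary_emb, precompute_freqs_cis

bs, seqlen, n_heads, head_dim = 2, 16, 4, 8
max_seqlen = 32

freqs_cis = precompute_freqs_cis(head_dim, max_seqlen)  # (max_seqlen, head_dim // 2), complex64
xq = torch.randn(bs, seqlen, n_heads, head_dim)
xk = torch.randn(bs, seqlen, n_heads, head_dim)

xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)
assert freqs_cis.shape == (max_seqlen, head_dim // 2)
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape
```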
torchtitan/experiments/llama4/model/moe.py ADDED
@@ -0,0 +1,228 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ from .args import TransformerModelArgs
12
+
13
+
14
+ class GroupedExperts(nn.Module):
15
+ def __init__(
16
+ self,
17
+ dim: int,
18
+ hidden_dim: int,
19
+ num_experts: int,
20
+ ):
21
+ super().__init__()
22
+ self.num_experts = num_experts
23
+ self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
24
+ self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
25
+ self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
26
+
27
+ def forward(
28
+ self,
29
+ x: torch.Tensor,
30
+ num_local_tokens_per_expert: torch.Tensor | None = None,
31
+ ) -> torch.Tensor:
32
+ if num_local_tokens_per_expert is not None:
33
+ # a tuple of tensors indexed by experts
34
+ # each with shape (tokens_per_expert(varying), dim)
35
+ x = torch.split(
36
+ x,
37
+ split_size_or_sections=num_local_tokens_per_expert.tolist(),
38
+ dim=0,
39
+ )
40
+ out_experts_splits = []
41
+ for expert_idx, x_expert in enumerate(x):
42
+ w1, w2, w3 = (
43
+ self.w1[expert_idx],
44
+ self.w2[expert_idx],
45
+ self.w3[expert_idx],
46
+ )
47
+ h = F.silu(torch.matmul(x_expert, w1))
48
+ h = h * torch.matmul(x_expert, w3)
49
+ h = torch.matmul(h, w2)
50
+ # h shape (tokens_per_expert(varying), dim)
51
+ out_experts_splits.append(h)
52
+ out = torch.cat(out_experts_splits, dim=0)
53
+
54
+ # TODO: optimize with GroupedGEMM
55
+ # https://github.com/pytorch/pytorch/pull/150374
56
+ # _grouped_mm requires shapes to be a multiple of 8
57
+ # offsets = torch.cumsum(num_local_tokens_per_expert, dim=0, dtype=torch.int32)
58
+ # h = F.silu(torch._grouped_mm(x, self.w1.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16))
59
+ # h = h * torch._grouped_mm(x, self.w3.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16)
60
+ # out = torch._grouped_mm(h, self.w2.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16)
61
+ else:
62
+ # x shape (num_experts, tokens_per_expert, dim)
63
+ h = F.silu(torch.bmm(x, self.w1))
64
+ h = h * torch.bmm(x, self.w3)
65
+ # out shape (num_experts, tokens_per_expert, dim)
66
+ out = torch.bmm(h, self.w2)
67
+ return out
68
+
69
+ def init_weights(self, init_std: float):
70
+ nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
71
+ nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
72
+ nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)
73
+
74
+
75
+ class TokenChoiceTopKRouter(nn.Module):
76
+ """This class implements token-choice routing. In token-choice top-K routing, each token is
77
+ routed to top K experts based on the router scores.
78
+
79
+ Args:
80
+ gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts).
81
+ dim (int): Dimension of input tokens.
82
+ num_experts (int): Number of experts in each moe layer.
83
+ top_k (int): Number of experts each token will be routed to in token-choice routing.
84
+ use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is False.
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ dim: int,
90
+ num_experts: int,
91
+ top_k: int,
92
+ use_sigmoid: bool = False,
93
+ ):
94
+ super().__init__()
95
+ self.gate = nn.Linear(dim, num_experts, bias=False)
96
+ self.num_experts = num_experts
97
+ self.top_k = top_k
98
+ self.use_sigmoid = use_sigmoid
99
+
100
+ def forward(
101
+ self, x: torch.Tensor
102
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
103
+ """
104
+ Args:
105
+ x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``.
106
+
107
+ Returns:
108
+ routed_input (torch.Tensor):
109
+ Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``.
110
+ token_indices (torch.Tensor):
111
+ Token indices for routed_input with shape ``(bs*slen*top_k,)``.
112
+ num_local_tokens_per_expert (torch.Tensor):
113
+ Number of tokens assigned to each expert with shape ``(num_experts,)``.
114
+ """
115
+ # scores shape (bs*slen, num_experts)
116
+ scores = self.gate(x)
117
+
118
+ # By default, sigmoid or softmax is performed in float32 to avoid loss explosion
119
+ if self.use_sigmoid:
120
+ scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
121
+ else:
122
+ scores = F.softmax(scores.to(torch.float32), dim=1).to(x.dtype)
123
+
124
+ # top scores shape (bs*slen, top_k)
125
+ top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=1)
126
+ # top_scores /= top_scores.sum(dim=-1, keepdim=True).to(x.dtype)
127
+
128
+ # group tokens together by expert indices from 0 to num_experts and pass that to experts forward
129
+ num_local_tokens_per_expert = torch.histc(
130
+ selected_experts_indices.view(-1),
131
+ bins=self.num_experts,
132
+ min=0,
133
+ max=self.num_experts,
134
+ )
135
+ # token_indices_experts_sorted shape (bs*slen*top_k,)
136
+ token_indices_experts_sorted = torch.argsort(
137
+ selected_experts_indices.view(-1), stable=True
138
+ )
139
+ top_scores = top_scores.view(-1)[token_indices_experts_sorted]
140
+ token_indices_experts_sorted = token_indices_experts_sorted // self.top_k
141
+
142
+ return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert
143
+
144
+ def init_weights(self, init_std: float):
145
+ nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
146
+
147
+
148
+ # TODO: implement load balancing auxiliary loss for token-choice routing
149
+ class MoE(nn.Module):
150
+ def __init__(self, model_args: TransformerModelArgs):
151
+ super().__init__()
152
+ dim = model_args.dim
153
+ hidden_dim = 4 * model_args.dim
154
+ ffn_dim_multiplier = model_args.ffn_dim_multiplier
155
+ hidden_dim = int(2 * hidden_dim / 3)
156
+ if ffn_dim_multiplier is not None:
157
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
158
+
159
+ num_experts = model_args.num_experts
160
+
161
+ hidden_dim_denom = 1
162
+ if model_args.auto_scale_hidden_dim:
163
+ hidden_dim_denom = model_args.top_k + int(model_args.use_shared_expert)
164
+
165
+ if model_args.auto_scale_hidden_dim:
166
+ hidden_dim = int(hidden_dim / hidden_dim_denom)
167
+ hidden_dim += -hidden_dim % model_args.multiple_of
168
+
169
+ self.experts = GroupedExperts(
170
+ dim=dim, hidden_dim=hidden_dim, num_experts=num_experts
171
+ )
172
+ self.router = TokenChoiceTopKRouter(
173
+ dim=dim, num_experts=num_experts, top_k=model_args.top_k
174
+ )
175
+ self.shared_expert = (
176
+ GroupedExperts(dim=dim, hidden_dim=hidden_dim, num_experts=1)
177
+ if model_args.use_shared_expert
178
+ else None
179
+ )
180
+
181
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
182
+ """
183
+ Args:
184
+ x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``.
185
+
186
+ Returns:
187
+ out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
188
+ """
189
+ bs, slen, dim = x.shape
190
+ # top_scores and selected_indices shape (bs*slen*top_k,)
191
+ # num_local_tokens_per_expert shape (num_experts,)
192
+ (
193
+ top_scores,
194
+ token_indices,
195
+ num_local_tokens_per_expert,
196
+ ) = self.router(x.reshape(bs * slen, dim))
197
+
198
+ # shape (bs*slen*top_k, dim)
199
+ token_indices = token_indices.reshape(-1, 1).expand(-1, dim)
200
+
201
+ # shape (bs*slen*top_k, dim)
202
+ routed_input = torch.gather(
203
+ x.view(-1, dim),
204
+ dim=0,
205
+ index=token_indices,
206
+ )
207
+ routed_input = routed_input * top_scores.reshape(-1, 1)
208
+
209
+ # shape (bs*slen*top_k, dim)
210
+ routed_output = self.experts(routed_input, num_local_tokens_per_expert)
211
+
212
+ # shared expert
213
+ if self.shared_expert is not None:
214
+ out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape(
215
+ bs * slen, dim
216
+ )
217
+ else:
218
+ out = torch.zeros_like(x.reshape(bs * slen, dim))
219
+
220
+ out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
221
+ out = out.reshape(bs, slen, dim)
222
+ return out
223
+
224
+ def init_weights(self, init_std: float):
225
+ self.experts.init_weights(init_std)
226
+ self.router.init_weights(init_std)
227
+ if self.shared_expert is not None:
228
+ self.shared_expert.init_weights(init_std)
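The router's grouping step above (histc, stable argsort, integer division by `top_k`) is easiest to see on a toy example. A small trace sketch with made-up expert assignments; the `.float()` cast is only there so `torch.histc` runs on CPU.

```python
# Hedged sketch: tracing TokenChoiceTopKRouter's grouping on a toy example.
import torch

num_experts, top_k = 4, 2
# pretend 3 tokens were each routed to top_k experts
selected_experts_indices = torch.tensor(
    [[1, 3],   # token 0 -> experts 1, 3
     [0, 1],   # token 1 -> experts 0, 1
     [3, 2]],  # token 2 -> experts 3, 2
)

num_local_tokens_per_expert = torch.histc(
    selected_experts_indices.view(-1).float(),
    bins=num_experts, min=0, max=num_experts,
)
token_indices_experts_sorted = torch.argsort(
    selected_experts_indices.view(-1), stable=True
)
token_indices = token_indices_experts_sorted // top_k  # back to original token ids

print(num_local_tokens_per_expert)  # tensor([1., 2., 1., 2.]) -> tokens per expert
print(token_indices)                # tensor([1, 0, 1, 2, 0, 2]) -> token ids grouped by expert
```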
torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh ADDED
@@ -0,0 +1,25 @@
1
+ #!/usr/bin/bash
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ set -ex
9
+
10
+ # use envs as local overrides for convenience
11
+ # e.g.
12
+ # LOG_RANK=0,1 NGPU=4 ./convert_meta_to_dcp_with_gpus.sh
13
+ NGPU=${NGPU:-"8"}
14
+ LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7}
15
+ CONFIG_FILE=${CONFIG_FILE:-"../train_configs/llama4_17bx16e.toml"}
16
+
17
+ overrides=""
18
+ if [ $# -ne 0 ]; then
19
+ overrides="$*"
20
+ fi
21
+
22
+ PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
23
+ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
24
+ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
25
+ convert_meta_to_dcp_with_gpus_meta.py --job.config_file ${CONFIG_FILE} $overrides
torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml ADDED
@@ -0,0 +1,65 @@
1
+ # TODO: this toml config is still under development
2
+
3
+ [job]
4
+ dump_folder = "./outputs"
5
+ description = "Llama 4 Maverick 17Bx128E training"
6
+
7
+ [profiling]
8
+ enable_profiling = false
9
+ save_traces_folder = "profile_trace"
10
+ profile_freq = 100
11
+
12
+ [metrics]
13
+ log_freq = 10
14
+ enable_tensorboard = false
15
+ save_tb_folder = "tb"
16
+
17
+ [model]
18
+ name = "llama4"
19
+ flavor = "17bx128e"
20
+ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
21
+ tokenizer_path = "./assets/tokenizer/tokenizer.model"
22
+ # converters = "float8"
23
+
24
+ [optimizer]
25
+ name = "AdamW"
26
+ lr = 4e-3
27
+ eps = 1e-15
28
+
29
+ [lr_scheduler]
30
+ warmup_steps = 600
31
+ lr_min = 0.1
32
+
33
+ [training]
34
+ batch_size = 1
35
+ seq_len = 8192
36
+ max_norm = 1.0 # grad norm clipping
37
+ steps = 3000
38
+ compile = false
39
+ dataset = "c4"
40
+
41
+ [parallelism]
42
+ data_parallel_replicate_degree = 1
43
+ data_parallel_shard_degree = -1
44
+ tensor_parallel_degree = 8
45
+ enable_async_tensor_parallel = false
46
+ pipeline_parallel_degree = 4
47
+ # pipeline_parallel_schedule = "interleaved1f1b"
48
+ # pipeline_parallel_microbatches = 2
49
+ context_parallel_degree = 1
50
+
51
+ [checkpoint]
52
+ enable_checkpoint = false
53
+ folder = "checkpoint"
54
+ interval = 500
55
+ model_weights_only = false
56
+ export_dtype = "float32"
57
+ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
58
+
59
+ [activation_checkpoint]
60
+ mode = 'full' # ['none', 'selective', 'full']
61
+
62
+ [float8]
63
+ enable_fsdp_float8_all_gather = false
64
+ precompute_float8_dynamic_scale_for_fsdp = false
65
+ filter_fqns = "output,router.gate"
torchtitan/experiments/multimodal/mm_collator.py ADDED
@@ -0,0 +1,227 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
8
+
9
+ from dataclasses import dataclass
10
+ from typing import Any, Dict, List, Optional
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+
15
+ from tokenizer.tiktoken import IGNORE_INDEX
16
+
17
+ from torch.nn.utils.rnn import pad_sequence
18
+
19
+
20
+ def padded_collate(
21
+ batch: List[Dict[str, List[int]]],
22
+ padding_idx: int = 0,
23
+ ignore_idx: int = -100,
24
+ ) -> Dict[str, torch.Tensor]:
25
+ """Pad a batch of sequences to the longest sequence length in the batch, and
26
+ convert integer lists to tensors.
27
+
28
+ Args:
29
+ batch (List[Dict[str, List[int]]]): A list of dictionaries containing input, label pairs.
30
+ padding_idx (int): Padding index for input ids. Defaults to 0.
31
+ ignore_idx (int): Padding index for labels. Defaults to -100.
32
+
33
+ Returns:
34
+ Dict[str, torch.Tensor]: Collated input and label tensors.
35
+
36
+ Example:
37
+ >>> token_pairs = [
38
+ >>> {"input_ids": [1, 2, 3], "labels": [4, 5, 6]},
39
+ >>> {"input_ids": [7,], "labels": [10,]},
40
+ >>> ]
41
+ >>> collated = padded_collate(
42
+ >>> batch=token_pairs,
43
+ >>> padding_idx=padding_idx,
44
+ >>> ignore_idx=ignore_idx,
45
+ >>> )
46
+ >>> collated["input_ids"]
47
+ >>> tensor([[1, 2, 3], [7, 0, 0]])
48
+ >>> collated["labels"]
49
+ >>> tensor([[4, 5, 6], [10, -100, -100]])
50
+ """
51
+ input_ids = pad_sequence(
52
+ [x["input_ids"] for x in batch],
53
+ batch_first=True,
54
+ padding_value=padding_idx,
55
+ )
56
+ labels = pad_sequence(
57
+ [x["labels"] for x in batch],
58
+ batch_first=True,
59
+ padding_value=ignore_idx,
60
+ )
61
+
62
+ input_ids_seq_len = input_ids.shape[-1]
63
+ labels_seq_len = labels.shape[-1]
64
+
65
+ # Hack to pad correctly and not use max_seq_len, which is costly
66
+ if input_ids_seq_len > labels_seq_len:
67
+ labels = F.pad(
68
+ labels, (0, input_ids_seq_len - labels_seq_len), value=ignore_idx
69
+ )
70
+ elif labels_seq_len > input_ids_seq_len:
71
+ input_ids = F.pad(
72
+ input_ids,
73
+ (0, labels_seq_len - input_ids_seq_len),
74
+ value=padding_idx,
75
+ )
76
+ return {"input_ids": input_ids, "labels": labels}
77
+
78
+
79
+ # NOTE: Inspired by torchtune.data._collate.py
80
+ @dataclass
81
+ class MultiModalCollator:
82
+ padding_idx: int = 128004
83
+ ignore_idx: int = IGNORE_INDEX
84
+ pad_max_tiles: Optional[int] = None
85
+ pad_max_images: Optional[int] = None
86
+
87
+ def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
88
+ """Pad a batch of text sequences, tiled image tensors, aspect ratios,
89
+ and cross attention masks. This can be used for both training and inference.
90
+
91
+ ``batch`` is expected to be a list of sample dicts containing the following::
92
+ - "input_ids": List[int] of length text_seq_len, varies across samples
93
+ - "labels": List[int] of length text_seq_len, varies across samples
94
+ - "encoder_input": Dict[str, List[torch.Tensor]]
95
+ - "images": List[torch.Tensor], each with shape (n_tiles, c, h, w)
96
+ - "aspect_ratio": List[torch.Tensor], each with shape (2, ) to indicate h_ratio, w_ratio
97
+
98
+ Shape notation:
99
+ - c = channel dim
100
+ - h = height dim
101
+ - w = width dim
102
+
103
+ Note:
104
+ For each element in the batch, ``len(images) == len(aspect_ratio)``.
105
+
106
+ This collator does the following:
107
+ (1) Pad text sequences to the longest sequence length in the batch
108
+ (2) Pad image tensors in the tile dimension with zeros to the largest number
109
+ of tiles in the batch
110
+ (3) Add empty images of zeros to samples up to max number of images in the batch
111
+ (4) Pad aspect ratios with (1,1) for all added padding images
112
+
113
+ Args:
114
+ batch (List[Dict[str, Any]]): A list of sample dicts containing input_ids,
115
+ labels, images, and aspect_ratio.
116
+ padding_idx (int): Padding index for input token ids. Defaults to 128004.
117
+ ignore_idx (int): Padding index for labels. Defaults to -100.
118
+ pad_max_tiles (Optional[int]): Maximum number of tiles to pad to. If None, will pad to the largest number of tiles
119
+ in the batch. Defaults to None.
120
+ pad_max_images (Optional[int]): Maximum number of images to pad to. If None, will pad to the largest number of images
121
+ in the batch. Defaults to None.
122
+
123
+ Returns:
124
+ Dict[str, Tensor]: Collated tokens, labels, images, aspect_ratio tensors.
125
+ - tokens: Tensor of shape (bsz, max_seq_len)
126
+ - labels: Tensor of shape (bsz, max_seq_len)
127
+ - images: Tensor of shape (bsz, max_num_images, max_num_tiles, c, h, w)
128
+ - aspect_ratio: Tensor of shape (bsz, max_num_images, 2)
129
+
130
+ Example:
131
+ >>> image_id = 1
132
+ >>> tokens_per_tile = 5
133
+ >>> c, h, w = 1, 1, 1
134
+ >>> batch = [
135
+ ... {
136
+ ... "input_ids": [1, 2, 1, 3], "labels": [4, 5, 6, 7],
137
+ ... "encoder_input": {
138
+ ... # One image with two tiles, one image with three tiles
139
+ ... "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)],
140
+ ... "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])],
141
+ ... },
142
+ ... },
143
+ ... {
144
+ ... "input_ids": [1, 4], "labels": [8, 9],
145
+ ... "encoder_input": {
146
+ ... # One image with four tiles
147
+ ... "images": [torch.ones(4, c, h, w)],
148
+ ... "aspect_ratio": [torch.tensor([2, 2])],
149
+ ... },
150
+ ... },
151
+ ... ]
152
+ ... collator = MultiModalCollator(pad_max_tiles=4)
153
+ >>> model_inputs = collator(batch=batch)
154
+ >>> print(model_inputs["input_ids"])
155
+ tensor([[1, 2, 1, 3],
156
+ [1, 4, 0, 0]])
157
+ >>> print(model_inputs["labels"])
158
+ tensor([[4, 5, 6, 7],
159
+ [8, 9, -100, -100]])
160
+ >>> print(model_inputs["encoder_input"]["images"].shape) # (bsz, max_num_images, max_num_tiles, c, h, w)
161
+ torch.Size([2, 2, 4, 1, 1, 1])
162
+ >>> print(model_inputs["encoder_input"]["aspect_ratio"].shape) # (bsz, max_num_images, 2)
163
+ torch.Size([2, 2, 2])
164
+ >>> print(model_inputs["encoder_input"]["images"][0, 0, ...]) # Image with two tiles got padded to four
165
+ tensor([[[[1.]]], [[[1.]]], [[[0.]]], [[[0.]]]])
166
+ >>> print(model_inputs["encoder_input"]["images"][0, 1, ...]) # Image with three tiles got padded to four
167
+ tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[0.]]]])
168
+ >>> print(model_inputs["encoder_input"]["images"][1, 0, ...]) # Image with four tiles did not get padded
169
+ tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[1.]]]])
170
+ >>> print(model_inputs["encoder_input"]["images"][1, 1, ...]) # Extra padding image was added to second sample
171
+ tensor([[[[0.]]], [[[0.]]], [[[0.]]], [[[0.]]]])
172
+ """
173
+ # Text tokens can be handled independently by existing collaters
174
+ text_only = [
175
+ {"input_ids": sample["input_ids"], "labels": sample["labels"]}
176
+ for sample in batch
177
+ ]
178
+ collated_text = padded_collate(text_only, self.padding_idx, self.ignore_idx)
179
+
180
+ if self.pad_max_tiles is None:
181
+ # Get max number of tiles in batch
182
+ max_num_tiles = max(img.shape[0] for sample in batch for img in sample["encoder_input"]["images"])
183
+ else:
184
+ max_num_tiles = self.pad_max_tiles
185
+
186
+ # Pad images and aspect ratios to max number of tiles
187
+ batch_images = []
188
+ batch_aspect_ratios = []
189
+
190
+ for sample in batch:
191
+ sample_images = []
192
+ for image in sample["encoder_input"]["images"]:
193
+ # Single image in each sample has shape (n_tiles, c, h, w)
194
+ n_tiles = image.shape[0]
195
+ # Compute how many zero-padding tiles are needed to bring this image up to max_num_tiles
196
+ # (only images and aspect ratios are collated here; no cross-attention masks are built)
197
+ padding_tiles = max_num_tiles - n_tiles
198
+
199
+ # Image should now have shape (max_num_tiles, c, h, w)
200
+ padded_image = F.pad(
201
+ image, (0, 0, 0, 0, 0, 0, 0, padding_tiles), value=0
202
+ )
203
+
204
+ sample_images.append(padded_image)
205
+ # Stack multiple images and masks per sample in num_images dimension
206
+ batch_images.append(torch.stack(sample_images))
207
+ batch_aspect_ratios.append(
208
+ torch.stack(sample["encoder_input"]["aspect_ratio"])
209
+ )
210
+ # Finally, pad images and aspect ratios to the max number of images in the batch
211
+ # (bsz, max_num_images, max_num_tiles, c, h, w)
212
+ collated_images = pad_sequence(batch_images, batch_first=True, padding_value=0)
213
+ # (bsz, max_num_images, 2)
214
+ collated_aspect_ratios = pad_sequence(
215
+ batch_aspect_ratios, batch_first=True, padding_value=1
216
+ )
217
+
218
+ batch_dict = {
219
+ "input_ids": collated_text["input_ids"],
220
+ "labels": collated_text["labels"],
221
+ "encoder_input": {
222
+ "images": collated_images,
223
+ "aspect_ratio": collated_aspect_ratios,
224
+ },
225
+ }
226
+
227
+ return batch_dict
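
To make the collation steps above easier to follow in isolation, here is a minimal sketch of the tile- and image-dimension padding using plain `torch` only. It mirrors the logic of `MultiModalCollator.__call__` but is not the module itself; shapes and values are illustrative.

```python
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

c, h, w = 3, 8, 8
# Sample 1 has two images (2 and 3 tiles); sample 2 has one image (4 tiles).
batch = [
    {"images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)]},
    {"images": [torch.ones(4, c, h, w)]},
]

max_num_tiles = max(img.shape[0] for s in batch for img in s["images"])

per_sample = []
for s in batch:
    padded = [
        # pad the tile (first) dimension of each (n_tiles, c, h, w) image with zeros
        F.pad(img, (0, 0, 0, 0, 0, 0, 0, max_num_tiles - img.shape[0]), value=0)
        for img in s["images"]
    ]
    per_sample.append(torch.stack(padded))  # (n_images, max_num_tiles, c, h, w)

# pad_sequence then pads the image dimension across samples with all-zero images
images = pad_sequence(per_sample, batch_first=True, padding_value=0)
print(images.shape)  # torch.Size([2, 2, 4, 3, 8, 8])
```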
torchtitan/experiments/multimodal/mm_dataset.py ADDED
@@ -0,0 +1,268 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Any, Callable, Dict, List, Optional, Union
9
+
10
+ import torch
11
+
12
+ from datasets import Dataset, load_dataset
13
+ from datasets.distributed import split_dataset_by_node
14
+
15
+ from mm_collator import MultiModalCollator
16
+ from tokenizer.tiktoken import IGNORE_INDEX, Tokenizer
17
+ from torch.distributed.checkpoint.stateful import Stateful
18
+ from torch.utils.data import IterableDataset
19
+ from transform import CLIPTransform
20
+ from utils import load_image
21
+
22
+ from torchtitan.components.dataloader import ParallelAwareDataloader
23
+ from torchtitan.config_manager import JobConfig
24
+ from torchtitan.tools.logging import logger
25
+
26
+
27
+ def _load_obelics_dataset(dataset_path: str):
28
+ """Load the OBELICS dataset in streaming mode."""
29
+ return load_dataset(dataset_path, split="train", streaming=True)
30
+
31
+
32
+ def _process_obelics_sample(
33
+ sample: dict[str, Any], image_token: str = "<|image|>"
34
+ ) -> Dict[str, Union[str, List[torch.Tensor]]]:
35
+ """
36
+ This function formats samples from the OBELICS dataset
37
+ Returns:
38
+ Dict[str, Any]: The transformed sample with the following fields:
39
+ - images: List[torch.Tensor] with the loaded images
40
+ - text: str with the text of the sample ready to be tokenized including the image tokens
41
+ Example:
42
+ >>> formatted_sample = _process_obelics_sample(sample, image_token="<|image|>")
43
+ >>> print(formatted_sample["text"])
44
+ ... "<|image|><|image|><|image|> The elephant look cute!<|image|><|image|> The cats are sad :("
45
+ """
46
+ sample_images = [image for image in sample["images"] if image is not None]
47
+ sample_text = [
48
+ text if text is not None else image_token for text in sample["texts"]
49
+ ]
50
+ return {
51
+ "images": [load_image(image) for image in sample_images],
52
+ "text": "".join(map(str, sample_text)),
53
+ }
54
+
55
+
56
+ @dataclass
57
+ class DatasetConfig:
58
+ path: str
59
+ loader: Callable
60
+ sample_processor: Callable
61
+
62
+
63
+ # Add your dataset here - more information at docs/datasets.md
64
+ MM_DATASETS = {
65
+ "obelics": DatasetConfig(
66
+ path="HuggingFaceM4/OBELICS",
67
+ loader=_load_obelics_dataset,
68
+ sample_processor=_process_obelics_sample,
69
+ ),
70
+ }
71
+
72
+
73
+ def _validate_mm_dataset(
74
+ dataset_name: str, dataset_path: Optional[str] = None
75
+ ) -> tuple[str, Callable, Callable]:
76
+ """Validate dataset name and path."""
77
+ if dataset_name not in MM_DATASETS:
78
+ raise ValueError(
79
+ f"Dataset {dataset_name} is not supported. "
80
+ f"Supported datasets are: {list(MM_DATASETS.keys())}"
81
+ )
82
+
83
+ config = MM_DATASETS[dataset_name]
84
+ path = dataset_path or config.path
85
+ logger.info(f"Preparing {dataset_name} dataset from {path}")
86
+ return path, config.loader, config.sample_processor
87
+
88
+
89
+ class MultiModalDataset(IterableDataset, Stateful):
90
+ """PyTorch MultiModal Dataset.
91
+
92
+ Args:
93
+ dataset_name (str): name of the dataset to load
94
+ tokenizer (Tokenizer):
95
+ Tokenizer used to encode data. It must implement an `encode` and a `decode` method.
96
+ dp_world_size (int): number of data parallel processes participating in training
97
+ dp_rank (int): rank of the current data parallel process
98
+ infinite (bool): whether to loop infinitely over the dataset
99
+
100
+ We currently ONLY support the OBELICS dataset
101
+
102
+ Example use:
103
+ >>> ds = MultiModalDataset(dataset_name="OBELICS", tokenizer=tokenizer)
104
+ >>> for batch in Dataloader(ds, batch_size=8):
105
+ print(f"Batch size: {len(batch)}")
106
+ Batch size: 8
107
+ """
108
+
109
+ def __init__(
110
+ self,
111
+ dataset_name: str,
112
+ dataset_path: Optional[str],
113
+ tokenizer: Tokenizer,
114
+ image_token: str = "<|image|>",
115
+ tile_size: int = 448,
116
+ max_num_tiles: int = 4,
117
+ seq_len: int = 2048,
118
+ dp_rank: int = 0,
119
+ dp_world_size: int = 1,
120
+ infinite: bool = False,
121
+ ) -> None:
122
+ # Force lowercase for consistent comparison
123
+ dataset_name = dataset_name.lower()
124
+
125
+ path, dataset_loader, sample_processor = _validate_mm_dataset(
126
+ dataset_name, dataset_path
127
+ )
128
+ ds = dataset_loader(path)
129
+
130
+ # TODO: support shuffling
131
+ self.dataset_name = dataset_name
132
+ self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
133
+ self._tokenizer = tokenizer
134
+ self.seq_len = seq_len
135
+ self.infinite = infinite
136
+ self._sample_processor = sample_processor
137
+ self.image_token = (
138
+ image_token # TODO(tj.solergibert) Add `image_token` to JobConfig
139
+ )
140
+ # TODO(tj.solergibert) Add `tile_size` & `max_num_tiles` to JobConfig
141
+ self.transform_image = CLIPTransform(
142
+ image_mean=(
143
+ 0.48145466,
144
+ 0.4578275,
145
+ 0.40821073,
146
+ ), # TODO(tj.solergibert) What should we do with `image_mean` & `image_std`?
147
+ image_std=(0.26862954, 0.26130258, 0.27577711),
148
+ tile_size=tile_size,
149
+ possible_resolutions=None,
150
+ max_num_tiles=max_num_tiles,
151
+ resample="bilinear",
152
+ resize_to_max_canvas=False,
153
+ )
154
+
155
+ # variables for checkpointing
156
+ self._sample_idx = 0
157
+
158
+ def __iter__(self):
159
+
160
+ while True:
161
+ for sample in self._get_data_iter():
162
+ try:
163
+ sample = self._sample_processor(
164
+ sample, image_token=self.image_token
165
+ )
166
+ except Exception:
167
+ continue
168
+ self._sample_idx += 1
169
+
170
+ # CLIP Transform
171
+ encoder_input = {"images": [], "aspect_ratio": []}
172
+ for image in sample["images"]:
173
+ out = self.transform_image(image)
174
+ encoder_input["images"].append(out["image"])
175
+ encoder_input["aspect_ratio"].append(out["aspect_ratio"])
176
+ sample["encoder_input"] = encoder_input
177
+
178
+ # Tokenize
179
+ tokens = self._tokenizer.encode(
180
+ sample["text"],
181
+ bos=True,
182
+ eos=True,
183
+ allowed_special=set(["<|image|>"]),
184
+ )
185
+ sample["input_ids"] = torch.LongTensor(tokens[:-1])
186
+ sample["labels"] = torch.LongTensor(tokens[1:])
187
+ # Mask BOS, EOS & image tokens from the loss
188
+ sample["labels"] = torch.where(
189
+ torch.isin(
190
+ sample["labels"],
191
+ torch.LongTensor(
192
+ [
193
+ self._tokenizer.bos_id,
194
+ self._tokenizer.eos_id,
195
+ self._tokenizer.image_id,
196
+ ]
197
+ ),
198
+ ),
199
+ IGNORE_INDEX,
200
+ sample["labels"],
201
+ )
202
+ # Truncate
203
+ sample["input_ids"], sample["labels"] = (
204
+ sample["input_ids"][: self.seq_len],
205
+ sample["labels"][: self.seq_len],
206
+ )
207
+ yield sample
208
+
209
+ if not self.infinite:
210
+ logger.warning(f"Dataset {self.dataset_name} has run out of data")
211
+ break
212
+ else:
213
+ # Reset offset for the next iteration
214
+ self._sample_idx = 0
215
+ logger.warning(f"Dataset {self.dataset_name} is being re-looped")
216
+
217
+ def _get_data_iter(self):
218
+ if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
219
+ return iter([])
220
+
221
+ it = iter(self._data)
222
+ for _ in range(self._sample_idx):
223
+ next(it)
224
+ return it
225
+
226
+ def load_state_dict(self, state_dict):
227
+ self._sample_idx = state_dict["sample_idx"]
228
+
229
+ def state_dict(self):
230
+ return {"sample_idx": self._sample_idx}
231
+
232
+
233
+ def build_mm_dataloader(
234
+ dp_world_size: int,
235
+ dp_rank: int,
236
+ tokenizer: Tokenizer,
237
+ job_config: JobConfig,
238
+ infinite: bool = True,
239
+ ) -> ParallelAwareDataloader:
240
+ """Build a data loader for HuggingFace datasets."""
241
+ dataset_name = job_config.training.dataset
242
+ dataset_path = job_config.training.dataset_path
243
+ batch_size = job_config.training.batch_size
244
+ seq_len = job_config.training.seq_len
245
+ pad_max_tiles = 4 # TODO(tj.solergibert) Add `pad_max_tiles` to JobConfig
246
+ padding_idx = 128004 # TODO(tj.solergibert) Add `padding_idx` to JobConfig
247
+
248
+ hf_ds = MultiModalDataset(
249
+ dataset_name=dataset_name,
250
+ dataset_path=dataset_path,
251
+ tokenizer=tokenizer,
252
+ seq_len=seq_len,
253
+ dp_rank=dp_rank,
254
+ dp_world_size=dp_world_size,
255
+ infinite=infinite,
256
+ )
257
+
258
+ collate_fn = MultiModalCollator(
259
+ padding_idx=padding_idx, pad_max_tiles=pad_max_tiles
260
+ )
261
+
262
+ return ParallelAwareDataloader(
263
+ dataset=hf_ds,
264
+ dp_rank=dp_rank,
265
+ dp_world_size=dp_world_size,
266
+ batch_size=batch_size,
267
+ collate_fn=collate_fn,
268
+ )
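
The label construction in `__iter__` above (next-token shift followed by masking of BOS, EOS, and image tokens) can be illustrated with a small self-contained sketch; the token ids below are made up, and `-100` stands in for `IGNORE_INDEX`.

```python
import torch

IGNORE_INDEX = -100                 # assumed value of tokenizer.tiktoken.IGNORE_INDEX
bos_id, eos_id, image_id = 1, 2, 9  # hypothetical special-token ids

tokens = torch.LongTensor([bos_id, image_id, image_id, 5, 6, 7, eos_id])

input_ids = tokens[:-1]  # the model sees every token except the last one
labels = tokens[1:]      # and is trained to predict the next token at each position

special = torch.LongTensor([bos_id, eos_id, image_id])
labels = torch.where(torch.isin(labels, special), IGNORE_INDEX, labels)

print(input_ids.tolist())  # [1, 9, 9, 5, 6, 7]
print(labels.tolist())     # [-100, -100, 5, 6, 7, -100]
```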
torchtitan/experiments/multimodal/tests/test_multimodal_model.py ADDED
@@ -0,0 +1,128 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import pytest
8
+ import torch
9
+
10
+ from torchtitan.experiments.llama_multimodal import (
11
+ ModelArgs,
12
+ MultimodalDecoder,
13
+ VisionEncoder,
14
+ )
15
+
16
+ from .test_utils import fixed_init_model, fixed_init_tensor
17
+
18
+
19
+ @pytest.fixture
20
+ def encoder_config():
21
+ return ModelArgs(
22
+ encoder_embed_dim=32,
23
+ encoder_num_layers=2,
24
+ encoder_num_heads=4,
25
+ tile_size=49,
26
+ patch_size=9,
27
+ max_num_tiles=4,
28
+ in_channels=3,
29
+ return_intermediates=[0, 1],
30
+ num_layers_projection=2,
31
+ decoder_embed_dim=128,
32
+ )
33
+
34
+
35
+ @pytest.fixture
36
+ def decoder_config():
37
+ return ModelArgs(
38
+ decoder_embed_dim=512,
39
+ vocab_size=10000,
40
+ fusion_interval=2,
41
+ num_special_tokens=3,
42
+ decoder_num_layers=6,
43
+ decoder_num_heads=8,
44
+ decoder_num_kv_heads=4,
45
+ max_seq_len=512,
46
+ rope_theta=50000.0,
47
+ )
48
+
49
+
50
+ class TestMultimodalModelVisionEncoder:
51
+ @pytest.fixture(autouse=True)
52
+ def setup_class(self, encoder_config):
53
+ self.model_args = encoder_config
54
+ self.batch_size = 1
55
+ self.num_imgs = 2
56
+ self.num_tiles = 4
57
+ self.aspect_ratio = torch.tensor([[1, 3], [2, 2]]).reshape(
58
+ self.batch_size, self.num_imgs, 2
59
+ )
60
+ image = torch.rand(
61
+ (
62
+ self.batch_size,
63
+ self.num_imgs,
64
+ self.num_tiles,
65
+ self.model_args.in_channels,
66
+ self.model_args.tile_size,
67
+ self.model_args.tile_size,
68
+ )
69
+ )
70
+ self.image = fixed_init_tensor(image.shape, min_val=-1, max_val=1)
71
+
72
+ def test_llama_mm_vision_encoder(self):
73
+ model = VisionEncoder(self.model_args)
74
+ fixed_init_model(model, min_val=-1, max_val=1)
75
+ output = model(self.image, self.aspect_ratio)
76
+ expected_shape = (
77
+ self.batch_size,
78
+ self.num_imgs * self.num_tiles * (model.vit.patches_per_tile + 1),
79
+ self.model_args.decoder_embed_dim,
80
+ )
81
+ assert (
82
+ output.shape == expected_shape
83
+ ), f"Expected shape {expected_shape}, but got {output.shape}"
84
+
85
+ # TODO: Need to ensure numerical stability before doing convergence test.
86
+ # output.mean() = 3.994, we need to debug why it is not close to 5.28800, which is
87
+ # the test value from the original torch tune test
88
+ # assert torch.allclose(
89
+ # output.mean(), torch.tensor(5.28800), atol=1e-3, rtol=1e-3
90
+ # )
91
+
92
+
93
+ class TestMultimodalModelDecoder:
94
+ @pytest.fixture(autouse=True)
95
+ def setup_class(self, decoder_config):
96
+ self.model_args = decoder_config
97
+ self.batch_size = 1
98
+ self.decoder_embed_dim = self.model_args.decoder_embed_dim
99
+ self.vocab_size = self.model_args.vocab_size
100
+ self.seq_len = 128
101
+ self.input = {
102
+ "tokens": torch.arange(self.batch_size * self.seq_len).reshape(
103
+ self.batch_size, self.seq_len
104
+ ),
105
+ "encoder_input": fixed_init_tensor(
106
+ (self.batch_size, self.seq_len, self.decoder_embed_dim),
107
+ min_val=-1,
108
+ max_val=1,
109
+ ),
110
+ "encoder_mask": None,
111
+ }
112
+
113
+ @torch.no_grad()
114
+ def test_llama_mm_decoder(self):
115
+ model = MultimodalDecoder(self.model_args)
116
+ fixed_init_model(model, min_val=-1, max_val=1)
117
+ output = model(**self.input)
118
+ expected_shape = (self.batch_size, self.seq_len, self.vocab_size)
119
+ assert (
120
+ output.shape == expected_shape
121
+ ), f"Expected shape {expected_shape}, but got {output.shape}"
122
+
123
+ # TODO: Need to ensure numerical stability before doing convergence test.
124
+ # output.mean() = -0.0134, we need to debug why it is not close to -9.47548e-5, which is
125
+ # the test value from the original torch tune test
126
+ # assert torch.allclose(
127
+ # output.mean(), torch.tensor(-9.47548e-5), atol=1e-3, rtol=1e-3
128
+ # )
torchtitan/experiments/multimodal/utils.py ADDED
@@ -0,0 +1,437 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ from collections import defaultdict
10
+
11
+ from pathlib import Path
12
+ from typing import List, Optional, Set, Tuple, Union
13
+ from urllib import request
14
+
15
+ import torch
16
+ import torchvision
17
+ from torchvision.transforms.v2 import functional as F
18
+
19
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.tile_crop.py
20
+ def tile_crop(image: torch.Tensor, tile_size: int) -> torch.Tensor:
21
+ """
22
+ Divides a tensor into equally sized tiles. The tensor's height and width must be divisible by tile_size.
23
+
24
+ Args:
25
+ image (torch.Tensor): Input image to crop into tiles.
26
+ tile_size (int): Size of each tile.
27
+
28
+ Returns:
29
+ torch.Tensor: torch.Tensor of shape [num_tiles, channel_size, tile_size, tile_size]
30
+
31
+ Examples:
32
+ >>> image = torch.rand(3, 200, 300)
33
+ >>> tiles = tile_crop(image, tile_size=50)
34
+ >>> tiles.shape # 4x6 = 24 tiles
35
+ torch.Size([24, 3, 50, 50])
36
+
37
+ >>> image = torch.rand(3, 400, 600)
38
+ >>> tiles = tile_crop(image, tile_size=200)
39
+ >>> tiles.shape # 2x3 = 6 tiles
40
+ torch.Size([6, 3, 200, 200])
41
+ """
42
+
43
+ channel_size, height, width = image.shape
44
+
45
+ # assert sizes are divisible
46
+ assert (
47
+ height % tile_size == 0 and width % tile_size == 0
48
+ ), f"Image size {height}x{width} is not divisible by tile size {tile_size}"
49
+
50
+ # Reshape to split height and width into tile_size blocks
51
+ tiles_height = height // tile_size
52
+ tiles_width = width // tile_size
53
+
54
+ reshaped = image.view(channel_size, tiles_height, tile_size, tiles_width, tile_size)
55
+
56
+ # Transpose to bring tiles together
57
+ # We want [tiles_height, tiles_width, channel_size, tile_size, tile_size]
58
+ transposed = reshaped.permute(1, 3, 0, 2, 4)
59
+
60
+ # Flatten the tiles
61
+ tiles = transposed.contiguous().view(
62
+ tiles_height * tiles_width, channel_size, tile_size, tile_size
63
+ )
64
+
65
+ return tiles
66
+
67
+
68
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py
69
+ def resize_with_pad(
70
+ image: torch.Tensor,
71
+ target_size: Tuple[int, int],
72
+ resample: torchvision.transforms.InterpolationMode,
73
+ max_size: Optional[int] = None,
74
+ ) -> torch.Tensor:
75
+ """
76
+ Resizes and pads an image to target_size without causing distortion.
77
+ The user can set max_size to limit upscaling when target_size exceeds image_size.
78
+
79
+ Args:
80
+ image (torch.Tensor): The input image tensor in the format [..., H, W].
81
+ target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width].
82
+ resample (torchvision.transforms.InterpolationMode): Resampling method used when resizing images.
83
+ Supports torchvision.transforms.InterpolationMode.NEAREST, InterpolationMode.NEAREST_EXACT,
84
+ InterpolationMode.BILINEAR and InterpolationMode.BICUBIC.
85
+ max_size (Optional[int]): The maximum size to upscale the image to.
86
+ If None, will upscale up to target_size.
87
+
88
+ Returns:
89
+ torch.Tensor: The resized and padded image tensor in the format [..., H, W].
90
+
91
+ Examples:
92
+
93
+ Example 1: The image will be upscaled from (300, 800) to (448, 1194), since 448 is the limiting side,
94
+ and then padded from (448, 1194) to (448, 1344).
95
+
96
+ >>> max_size = None
97
+ >>> image = torch.rand([3, 300, 800])
98
+ >>> target_size = (448, 1344)
99
+ >>> resample = torchvision.transforms.InterpolationMode.BILINEAR
100
+ >>> output = resize_with_pad(image, target_size, resample, max_size)
101
+
102
+ Example 2: The image will stay as is, since 800 > 600, and then padded from (300, 800) to (448, 1344).
103
+
104
+ >>> max_size = 600
105
+ >>> image = torch.rand([3, 300, 800])
106
+ >>> target_size = (448, 1344)
107
+ >>> resample = torchvision.transforms.InterpolationMode.BILINEAR
108
+ >>> output = resize_with_pad(image, target_size, resample, max_size)
109
+
110
+ Example 3: The image will be downscaled from (500, 1000) to (224, 448),
111
+ and padded from (224, 448) to (448, 448).
112
+
113
+ >>> max_size = 600
114
+ >>> image = torch.rand([3, 500, 1000])
115
+ >>> target_size = (448, 488)
116
+ >>> resample = torchvision.transforms.InterpolationMode.BILINEAR
117
+ >>> output = resize_with_pad(image, target_size, resample, max_size)
118
+
119
+ """
120
+
121
+ image_height, image_width = image.shape[-2:]
122
+ image_size = (image_height, image_width)
123
+
124
+ # If target_size requires upscaling, we might want to limit the upscaling to max_size
125
+ if max_size is not None:
126
+ new_target_height = min(max(image_height, max_size), target_size[0])
127
+ new_target_width = min(max(image_width, max_size), target_size[1])
128
+ target_size_resize = (new_target_height, new_target_width)
129
+ else:
130
+ target_size_resize = target_size
131
+
132
+ # resize to target_size while preserving aspect ratio
133
+ new_size_preserving_aspect_ratio = _get_max_res_without_distortion(
134
+ image_size=image_size,
135
+ target_size=target_size_resize,
136
+ )
137
+
138
+ image = F.resize(
139
+ inpt=image,
140
+ size=list(new_size_preserving_aspect_ratio),
141
+ interpolation=resample,
142
+ antialias=True,
143
+ )
144
+
145
+ image = _pad_image_top_left(image=image, target_size=target_size)
146
+
147
+ return image
148
+
149
+
150
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py
151
+ def _pad_image_top_left(
152
+ image: torch.Tensor,
153
+ target_size: Tuple[int, int],
154
+ ) -> torch.Tensor:
155
+ """
156
+ Places the image at the top left of the canvas and pads with 0 the right and bottom
157
+ to fit to the target resolution. If target_size < image_size, it will crop the image.
158
+
159
+ Args:
160
+ image (torch.Tensor): The input image tensor in the format [..., H, W].
161
+ target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width].
162
+
163
+ Returns:
164
+ torch.Tensor: The padded image tensor in the format [..., H, W].
165
+ """
166
+
167
+ image_size = image.shape[-2:]
168
+
169
+ height, width = image_size
170
+ target_height, target_width = target_size
171
+
172
+ pad_x = target_width - width
173
+ pad_y = target_height - height
174
+
175
+ padding = [0, 0, pad_x, pad_y]
176
+ return F.pad(inpt=image, padding=padding)
177
+
178
+
179
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py
180
+ def _get_max_res_without_distortion(
181
+ image_size: Tuple[int, int],
182
+ target_size: Tuple[int, int],
183
+ ) -> Tuple[int, int]:
184
+ """
185
+ Determines the maximum resolution to which an image can be resized to without distorting its
186
+ aspect ratio, based on the target resolution.
187
+
188
+ For example, if image_size = (200,400) and target_size = (600,800),
189
+ scale_h = 600/200 = 3
190
+ scale_w = 800/400 = 2
191
+ So the maximum that we can upscale without distortion is min(scale_h, scale_w) = 2
192
+
193
+ Since scale_w is the limiting side, then new_w = target_w, and new_h = old_h*scale_w
194
+
195
+ Args:
196
+ image_size (Tuple[int, int]): The original resolution of the image.
197
+ target_size (Tuple[int, int]): The desired resolution to fit the image into.
198
+ Returns:
199
+ Tuple[int, int]: The optimal dimensions to which the image should be resized.
200
+ Examples:
201
+ >>> _get_max_res_without_distortion([200, 300], target_size = (450, 200))
202
+ (133, 200)
203
+ >>> _get_max_res_without_distortion([800, 600], target_size = (450, 1300))
204
+ (450, 337)
205
+ """
206
+
207
+ original_height, original_width = image_size
208
+ target_height, target_width = target_size
209
+
210
+ scale_w = target_width / original_width
211
+ scale_h = target_height / original_height
212
+
213
+ if scale_w < scale_h:
214
+ new_width = target_width
215
+ new_height = min(math.floor(original_height * scale_w), target_height)
216
+ else:
217
+ new_height = target_height
218
+ new_width = min(math.floor(original_width * scale_h), target_width)
219
+
220
+ return new_height, new_width
221
+
222
+
223
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py
224
+ def _get_factors(n: int) -> Set[int]:
225
+ """
226
+ Calculate all factors of a given number, i.e. a divisor that leaves no remainder.
227
+
228
+ Args:
229
+ n (int): The number to find factors for.
230
+
231
+ Returns:
232
+ set: A set containing all factors of the number.
233
+
234
+ Examples:
235
+ >>> _get_factors(n=12)
236
+ {1, 2, 3, 4, 6, 12}
237
+ """
238
+ factors_set = set()
239
+
240
+ for i in range(1, int(n**0.5) + 1):
241
+ if n % i == 0:
242
+ factors_set.add(i)
243
+ factors_set.add(n // i)
244
+ return factors_set
245
+
246
+
247
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py
248
+ def get_canvas_best_fit(
249
+ image: torch.Tensor, possible_resolutions: torch.Tensor, resize_to_max_canvas: bool
250
+ ) -> Tuple[int, int]:
251
+ """
252
+ Determines the best canvas possible from a list of possible resolutions to
253
+ resize an image to, without distortion.
254
+
255
+ For each possible resolution, calculates the scaling factors for
256
+ width and height, and selects the smallest one, which is the limiting side.
257
+ E.g. if to match a canvas shape you have to upscale an image's height by 2x, and width by 1.5x,
258
+ then the maximum upscaling without distortion is min(2, 1.5) = 1.5.
259
+
260
+ If there are multiple canvases that satisfy the conditions,
261
+ we pick the one with the lowest area to minimize padding.
262
+
263
+ Args:
264
+ image (torch.Tensor): The image we want to fit into a canvas.
265
+ possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
266
+ row represents a possible canvas.
267
+ resize_to_max_canvas (bool): If True, pick the canvas that allows maximum scaling.
268
+ If False, pick the canvas that minimizes downscaling, including no downscaling at all.
269
+
270
+ Returns:
271
+ Tuple[int, int]: The best resolution to fit the image into.
272
+
273
+ Examples:
274
+ >>> image = torch.rand(3, 200, 300)
275
+ >>> possible_resolutions = torch.tensor([
276
+ ... [224, 672],
277
+ ... [672, 224],
278
+ ... [224, 448],
279
+ ... [448, 224],
280
+ ... [224, 224]
281
+ ... ])
282
+ >>> get_canvas_best_fit(image, possible_resolutions, resize_to_max_canvas=False)
283
+ (224, 448)
284
+
285
+ In the example above, we calculate the scaling factors for each possible resolution
286
+
287
+ >>> scale_height = torch.tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
288
+ >>> scale_width = torch.tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
289
+ >>> scales = torch.tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
290
+
291
+ Two options have scaling_factor > 1, since resize_to_max_canvas is False, we pick the smallest
292
+
293
+ >>> upscaling_options = torch.tensor([1.1200, 1.1200])
294
+ >>> selected_scale = torch.tensor(1.1200)
295
+
296
+ There are two possible options, so we pick the one with the smallest area
297
+
298
+ >>> areas = torch.tensor([150528, 100352]) # for resolutions [672, 224] and [224, 448], respectively
299
+ >>> optimal_canvas = torch.tensor([224, 448]) # resolution with the smallest area
300
+ """
301
+
302
+ original_height, original_width = image.shape[-2:]
303
+
304
+ # possible resolutions heights/widths
305
+ target_heights, target_widths = (
306
+ possible_resolutions[:, 0],
307
+ possible_resolutions[:, 1],
308
+ )
309
+
310
+ # scaling factors to resize the image without distortion
311
+ scale_w = target_widths / original_width
312
+ scale_h = target_heights / original_height
313
+
314
+ # get limiting side scaling -> no distortion
315
+ scales = torch.where(scale_w > scale_h, scale_h, scale_w)
316
+
317
+ # filter only scales that allow upscaling
318
+ upscaling_options = scales[scales >= 1]
319
+ if len(upscaling_options) > 0:
320
+ if resize_to_max_canvas:
321
+ selected_scale = torch.max(upscaling_options)
322
+ else:
323
+ selected_scale = torch.min(upscaling_options)
324
+ else:
325
+ # no upscaling possible,
326
+ # get the minimum downscaling (max scale for scales<1)
327
+ downscaling_options = scales[scales < 1]
328
+ selected_scale = torch.max(downscaling_options)
329
+
330
+ # get all resolutions that support this scaling factor,
331
+ # e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
332
+ chosen_canvas = possible_resolutions[scales == selected_scale]
333
+
334
+ # if there are multiple resolutions,
335
+ # get the one with minimum area to reduce padding
336
+ if len(chosen_canvas) > 1:
337
+ areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
338
+ optimal_idx = torch.argmin(areas)
339
+ optimal_canvas = chosen_canvas[optimal_idx]
340
+ else:
341
+ optimal_canvas = chosen_canvas[0]
342
+
343
+ return tuple(optimal_canvas.tolist())
344
+
345
+
346
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py
347
+ def find_supported_resolutions(
348
+ max_num_tiles: int, tile_size: int
349
+ ) -> List[Tuple[int, int]]:
350
+ """
351
+ Computes all combinations of resolutions, multiple of tile_size,
352
+ that contain up to max_num_tiles. Useful for when dividing an image into tiles.
353
+
354
+ For example, if we want at most 2 tiles per image, then we can support the
355
+ following resolutions: (1x1, 1x2, 2x1) * tile_size
356
+
357
+ Args:
358
+ max_num_tiles (int): Maximum number of tiles.
359
+ tile_size (int): Size of the side of the tile.
360
+
361
+ Returns:
362
+ List[Tuple[int, int]]: List of possible resolutions as tuples (height, width).
363
+
364
+ Examples:
365
+
366
+ >>> max_num_tiles = 4
367
+ >>> tile_size = 224
368
+ >>> find_supported_resolutions(max_num_tiles, tile_size)
369
+ [(224, 896), (448, 448), (224, 224), (896, 224), (224, 672), (672, 224), (224, 448), (448, 224)]
370
+ """
371
+
372
+ # create dictionary {aspect_ratio: [resolution1, ..., resolution n]}
373
+ # example {0.25: [(1,4)], 1.0: [(2,2), (1,1)], 4.0: [(4,1)]}
374
+ asp_dict = defaultdict(list)
375
+ for _tile_size in range(max_num_tiles, 0, -1):
376
+ factors = sorted(_get_factors(_tile_size))
377
+ asp_ratios = [(factor, _tile_size // factor) for factor in factors]
378
+ for height, width in asp_ratios:
379
+ ratio_float = height / width
380
+ asp_dict[ratio_float].append((height, width))
381
+
382
+ # get the resolutions multiplied by the tile_size
383
+ possible_resolutions = []
384
+ for ar, resolution in asp_dict.items():
385
+ for height, width in resolution:
386
+ possible_resolutions.append((height * tile_size, width * tile_size))
387
+
388
+ return possible_resolutions
389
+
390
+
391
+ # NOTE Copied from torchtune.data._utils.py
392
+ def load_image(image_loc: Union[Path, str]) -> torch.Tensor:
393
+ """
394
+ Convenience method to load an image in torch.Tensor format from a local file path or remote source.
395
+
396
+ Args:
397
+ image_loc (Union[Path, str]): Local file path or remote source pointing to the image
398
+ which will be loaded in PIL format.
399
+
400
+ Note:
401
+ If loading an image from a remote source, the function expects the URL provided in ``image_loc``
402
+ to start with "http" or "https" e.g. "https://www.wikipedia.org/en/bird.jpg".
403
+
404
+ Raises:
405
+ ValueError: If the image cannot be loaded from remote source, **or**
406
+ if the image cannot be opened as a :class:`~torch.Tensor`.
407
+
408
+ Examples:
409
+ >>> # Load from remote source
410
+ >>> image = load_image("https://www.wikipedia.org/en/bird.jpg")
411
+
412
+ >>> # Load from local file path
413
+ >>> image = load_image(Path("/home/user/bird.jpg"))
414
+
415
+ Returns:
416
+ torch.Tensor: The loaded image.
417
+ """
418
+
419
+ # If pointing to remote source, try to load to local
420
+ if isinstance(image_loc, str) and image_loc.startswith("http"):
421
+ try:
422
+ image_loc = request.urlopen(image_loc).read()
423
+ image = torchvision.io.decode_image(
424
+ torch.frombuffer(image_loc, dtype=torch.uint8),
425
+ mode="RGB",
426
+ )
427
+ except Exception as e:
428
+ raise ValueError("Failed to load remote image as torch.Tensor") from e
429
+
430
+ # Open the local image as a Tensor image
431
+ else:
432
+ try:
433
+ image = torchvision.io.decode_image(image_loc, mode="RGB")
434
+ except Exception as e:
435
+ raise ValueError("Failed to load local image as torch.Tensor") from e
436
+
437
+ return image
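
The helpers above compose into the canvas-fit and tiling pipeline used by the CLIP transform. The following sketch assumes the functions defined in this file are in scope; the input size and tile size are illustrative.

```python
import torch
import torchvision

image = torch.rand(3, 300, 800)

# All tile grids with at most 4 tiles of size 224, as (height, width) resolutions
resolutions = torch.tensor(find_supported_resolutions(max_num_tiles=4, tile_size=224))
canvas = get_canvas_best_fit(image, resolutions, resize_to_max_canvas=False)

# Resize without distortion, pad top-left to the chosen canvas, then cut into tiles
resized = resize_with_pad(
    image,
    target_size=canvas,
    resample=torchvision.transforms.InterpolationMode.BILINEAR,
)
tiles = tile_crop(resized, tile_size=224)
print(canvas, tiles.shape)  # (224, 672) torch.Size([3, 3, 224, 224])
```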
torchtitan/experiments/simple_fsdp/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
8
+
9
+ from torchtitan.components.loss import build_cross_entropy_loss
10
+ from torchtitan.components.lr_scheduler import build_lr_schedulers
11
+ from torchtitan.components.optimizer import build_optimizers
12
+ from torchtitan.datasets.hf_datasets import build_hf_dataloader
13
+ from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
14
+ from torchtitan.models.llama3 import llama3_configs, pipeline_llama
15
+ from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
16
+
17
+ from .model import SimpleFSDPTransformer
18
+ from .parallelize_llama import parallelize_llama
19
+
20
+ register_train_spec(
21
+ TrainSpec(
22
+ name="llama3_simple_fsdp",
23
+ cls=SimpleFSDPTransformer,
24
+ config=llama3_configs,
25
+ parallelize_fn=parallelize_llama,
26
+ pipelining_fn=pipeline_llama,
27
+ build_optimizers_fn=build_optimizers,
28
+ build_lr_schedulers_fn=build_lr_schedulers,
29
+ build_dataloader_fn=build_hf_dataloader,
30
+ build_tokenizer_fn=build_tiktoken_tokenizer,
31
+ build_loss_fn=build_cross_entropy_loss,
32
+ )
33
+ )
torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-311.pyc ADDED
Binary file (2.67 kB). View file
 
torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-311.pyc ADDED
Binary file (7.22 kB). View file
 
torchtitan/experiments/simple_fsdp/model.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from torchtitan.models.llama3 import Transformer, TransformerModelArgs
8
+ from .simple_fsdp import disable_data_parallel
9
+
10
+
11
+ class SimpleFSDPTransformer(Transformer):
12
+ def __init__(self, model_args: TransformerModelArgs):
13
+ super().__init__(model_args)
14
+ self.init_weights()
15
+
16
+ def init_weights(self, *args, **kwargs):
17
+ with disable_data_parallel():
18
+ super().init_weights(*args, **kwargs)
torchtitan/experiments/simple_fsdp/parallelize_llama.py ADDED
@@ -0,0 +1,98 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from torch.distributed import DeviceMesh
11
+
12
+ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
13
+ from torchtitan.distributed import ParallelDims
14
+ from torchtitan.models.llama3.parallelize_llama import apply_ac
15
+ from torchtitan.tools.logging import logger
16
+
17
+ from .simple_fsdp import data_parallel, MixedPrecisionPolicy
18
+
19
+
20
+ def parallelize_llama(
21
+ model: nn.Module,
22
+ world_mesh: DeviceMesh,
23
+ parallel_dims: ParallelDims,
24
+ job_config: JobConfig,
25
+ ):
26
+ """
27
+ Apply activation checkpointing, torch.compile, and data parallelism to the
28
+ model. Tensor parallelism is not yet supported here (see the TODO below).
29
+
30
+ NOTE: The passed-in model preferably should be on meta device. Otherwise,
31
+ the model must fit on GPU or CPU memory.
32
+ """
33
+ # TODO(ruisizhang123): Add support for TP (on-going)
34
+ # if parallel_dims.tp_enabled:
35
+ # if (
36
+ # job_config.parallelism.enable_async_tensor_parallel
37
+ # and not job_config.training.compile
38
+ # ):
39
+ # raise RuntimeError("Async TP requires --training.compile")
40
+
41
+ # enable_float8_linear = "float8" in job_config.model.converters
42
+ # float8_is_rowwise = job_config.float8.recipe_name in (
43
+ # "rowwise",
44
+ # "rowwise_with_gw_hp",
45
+ # )
46
+
47
+ # # For now, float8 all-gather with TP is only supported for tensorwise
48
+ # # float8 scaling recipes. For rowwise recipes, we use regular TP and
49
+ # # all-gather happens in high precision.
50
+ # enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
51
+
52
+ # apply_tp(
53
+ # model,
54
+ # world_mesh["tp"],
55
+ # loss_parallel=parallel_dims.loss_parallel_enabled,
56
+ # enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
57
+ # enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
58
+ # )
59
+
60
+ if job_config.activation_checkpoint.mode != "none":
61
+ apply_ac(model, job_config.activation_checkpoint)
62
+
63
+ # apply data parallel
64
+ if (
65
+ parallel_dims.dp_replicate_enabled
66
+ or parallel_dims.dp_shard_enabled
67
+ or parallel_dims.cp_enabled
68
+ ):
69
+ if parallel_dims.dp_replicate_enabled:
70
+ if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
71
+ dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
72
+ dp_mode = "hybrid_shard"
73
+ else:
74
+ dp_mesh_dim_names = ("dp_replicate",)
75
+ dp_mode = "replicate"
76
+ else:
77
+ dp_mesh_dim_names = ("dp_shard_cp",)
78
+ dp_mode = "fully_shard"
79
+
80
+ mp_policy = MixedPrecisionPolicy(
81
+ param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
82
+ reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
83
+ )
84
+
85
+ model = data_parallel(
86
+ model,
87
+ world_mesh[tuple(dp_mesh_dim_names)],
88
+ mode=dp_mode,
89
+ ac_mode=job_config.activation_checkpoint.mode,
90
+ mp_policy=mp_policy,
91
+ )
92
+ logger.info("Applied Data Parallel (dp mode=%s) to the model", dp_mode)
93
+
94
+ if job_config.training.compile:
95
+ torch._inductor.config.reorder_for_peak_memory = False
96
+ model = torch.compile(model, fullgraph=True)
97
+
98
+ return model
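
For readability, the mesh-name and mode selection in the data-parallel branch above boils down to the following pure-Python restatement (a documentation sketch only, valid when at least one of the three flags is enabled; it is not part of the module):

```python
def pick_dp_mode(dp_replicate: bool, dp_shard: bool, cp: bool):
    """Return (mesh dim names, dp mode) as chosen by parallelize_llama above."""
    if dp_replicate:
        if dp_shard or cp:
            return ("dp_replicate", "dp_shard_cp"), "hybrid_shard"
        return ("dp_replicate",), "replicate"
    return ("dp_shard_cp",), "fully_shard"


assert pick_dp_mode(True, True, False) == (("dp_replicate", "dp_shard_cp"), "hybrid_shard")
assert pick_dp_mode(False, True, False) == (("dp_shard_cp",), "fully_shard")
```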
torchtitan/models/llama3/train_configs/llama3_8b.toml ADDED
@@ -0,0 +1,63 @@
1
+ # torchtitan Config.toml
2
+ # NOTE: this toml config is a preset for 64 A100 GPUs.
3
+
4
+ [job]
5
+ dump_folder = "./outputs"
6
+ description = "Llama 3 8B training"
7
+
8
+ [profiling]
9
+ enable_profiling = true
10
+ save_traces_folder = "profile_trace"
11
+ profile_freq = 100
12
+
13
+ [metrics]
14
+ log_freq = 10
15
+ enable_tensorboard = true
16
+ save_tb_folder = "tb"
17
+
18
+ [model]
19
+ name = "llama3"
20
+ flavor = "8B"
21
+ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
22
+ tokenizer_path = "./assets/tokenizer/original/tokenizer.model"
23
+ # converters = "float8"
24
+
25
+ [optimizer]
26
+ name = "AdamW"
27
+ lr = 3e-4
28
+ eps = 1e-8
29
+
30
+ [lr_scheduler]
31
+ warmup_steps = 200 # lr scheduler warm up
32
+
33
+ [training]
34
+ batch_size = 1
35
+ seq_len = 8192
36
+ max_norm = 1.0 # grad norm clipping
37
+ steps = 1000
38
+ compile = false
39
+ dataset = "c4"
40
+
41
+ [parallelism]
42
+ data_parallel_replicate_degree = 1
43
+ data_parallel_shard_degree = -1
44
+ tensor_parallel_degree = 1
45
+ pipeline_parallel_degree = 1
46
+ context_parallel_degree = 1
47
+
48
+ [checkpoint]
49
+ enable_checkpoint = false
50
+ folder = "checkpoint"
51
+ interval = 500
52
+ model_weights_only = false
53
+ export_dtype = "float32"
54
+ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
55
+
56
+ [activation_checkpoint]
57
+ mode = 'selective' # ['none', 'selective', 'full']
58
+ selective_ac_option = 'op' # a positive int = AC every n-th layer; 'op' = AC based on the ops policy
59
+
60
+ [float8]
61
+ enable_fsdp_float8_all_gather = false
62
+ precompute_float8_dynamic_scale_for_fsdp = false
63
+ filter_fqns = "output"
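
At train time torchtitan parses presets like this through its `JobConfig`; if you just want to inspect or tweak a preset programmatically, the standard library is enough. A sketch, assuming Python 3.11+ and the preset's path in this repo:

```python
import tomllib

with open("torchtitan/models/llama3/train_configs/llama3_8b.toml", "rb") as f:
    cfg = tomllib.load(f)

print(cfg["model"]["flavor"])       # 8B
print(cfg["training"]["seq_len"])   # 8192
print(cfg["parallelism"]["data_parallel_shard_degree"])  # -1 (shard degree inferred from the remaining world size)
```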
torchtitan/models/norms.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch.nn as nn
8
+
9
+
10
+ def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
11
+ """
12
+ Builds the specified normalization layer based on the norm_type.
13
+
14
+ Args:
15
+ norm_type (str): The type of normalization layer to build.
16
+ Supported types: layernorm, np_layernorm, rmsnorm
17
+ dim (int): The dimension of the normalization layer.
18
+ eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
19
+
20
+ Returns:
21
+ The built normalization layer.
22
+
23
+ Raises:
24
+ NotImplementedError: If an unknown norm_type is provided.
25
+ """
26
+ norm_type = norm_type.lower() # Normalize to lowercase
27
+
28
+ if norm_type == "layernorm":
29
+ return nn.LayerNorm(dim, eps=eps, bias=False)
30
+ elif norm_type == "np_layernorm":
31
+ return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
32
+ elif norm_type == "rmsnorm":
33
+ return nn.RMSNorm(dim, eps=eps)
34
+ else:
35
+ raise NotImplementedError(f"Unknown norm_type: '{norm_type}'")
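
A quick usage sketch for `build_norm`; the import path assumes this file is importable as `torchtitan.models.norms`, and `nn.RMSNorm` requires a reasonably recent PyTorch:

```python
import torch
from torchtitan.models.norms import build_norm  # assumed import path for this file

norm = build_norm("rmsnorm", dim=4096)
x = torch.randn(2, 16, 4096)
y = norm(x)  # normalizes over the last dimension, shape is preserved
print(type(norm).__name__, y.shape)  # RMSNorm torch.Size([2, 16, 4096])
```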
torchtitan/protocols/__pycache__/model_converter.cpython-311.pyc ADDED
Binary file (5.05 kB). View file
 
torchtitan/tools/utils.py ADDED
@@ -0,0 +1,143 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import gc
8
+ import subprocess
9
+ import time
10
+ from dataclasses import dataclass
11
+ from typing import Optional
12
+
13
+ import torch
14
+ from torch._utils import _get_available_device_type, _get_device_module
15
+
16
+ from torchtitan.tools.logging import logger
17
+
18
+
19
+ def get_device_info():
20
+ device_type = _get_available_device_type()
21
+ if device_type is None:
22
+ device_type = "cuda" # default device_type: cuda
23
+ device_module = _get_device_module(device_type) # default device_module:torch.cuda
24
+ return device_type, device_module
25
+
26
+
27
+ device_type, device_module = get_device_info()
28
+
29
+
30
+ # used to avoid stragglers in garbage collection
31
+ class GarbageCollection:
32
+ def __init__(self, gc_freq=1000):
33
+ assert gc_freq > 0, "gc_freq must be a positive integer"
34
+ self.gc_freq = gc_freq
35
+ gc.disable()
36
+ self.collect("Initial GC collection.")
37
+
38
+ def run(self, step_count):
39
+ if step_count > 1 and step_count % self.gc_freq == 0:
40
+ self.collect("Performing periodic GC collection.")
41
+
42
+ @staticmethod
43
+ def collect(reason: str):
44
+ begin = time.monotonic()
45
+ gc.collect(1)
46
+ logger.info("[GC] %s %.2f seconds.", reason, time.monotonic() - begin)
47
+
48
+
49
+ # hardcoded BF16 type peak flops for NVIDIA A100, H100, H200 GPU and AMD MI250, MI300X, AMD MI325X and Intel PVC
50
+ def get_peak_flops(device_name: str) -> int:
51
+ try:
52
+ # Run the lspci command and capture the output
53
+ result = subprocess.run(["lspci"], stdout=subprocess.PIPE, text=True)
54
+ # Filter the output for lines containing both "NVIDIA" and "H100"
55
+ filtered_lines = [
56
+ line
57
+ for line in result.stdout.splitlines()
58
+ if "NVIDIA" in line and "H100" in line
59
+ ]
60
+ # Join all filtered lines into a single string
61
+ device_name = " ".join(filtered_lines) or device_name
62
+ except FileNotFoundError as e:
63
+ logger.warning(f"Error running lspci: {e}, fallback to use device_name")
64
+ if "A100" in device_name:
65
+ # data from https://www.nvidia.com/en-us/data-center/a100/
66
+ return 312e12
67
+ elif "H100" in device_name:
68
+ # data from https://www.nvidia.com/en-us/data-center/h100/
69
+ # NOTE: the values below are dense (without sparsity); datasheet specs with sparsity are 2x higher.
70
+ if "NVL" in device_name:
71
+ return 835e12
72
+ elif "PCIe" in device_name:
73
+ return 756e12
74
+ else: # for H100 SXM and other variants
75
+ return 989e12
76
+ elif "H200" in device_name:
77
+ # data from https://www.nvidia.com/en-us/data-center/h200/
78
+ return 989e12
79
+ elif "MI300X" in device_name or "MI325X" in device_name:
80
+ # MI300X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html
81
+ # MI325X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi325x.html
82
+ return 1300e12
83
+ elif "MI250X" in device_name:
84
+ # data from https://www.amd.com/en/products/accelerators/instinct/mi200/mi250x.html (per GCD)
85
+ return 191.5e12
86
+ elif "Data Center GPU Max 1550" in device_name:
87
+ # Also known as Ponte Vecchio (PVC).
88
+ # data from https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
89
+ # Dot Product Accumulate Systolic (DPAS):
90
+ # - Freq: 1300MHz
91
+ # - #ops: 512
92
+ # Full EU mode (i.e. 512 max compute units): 340.8 TFLOPS (BF16)
93
+ # Standard EU mode (i.e. 448 max compute units): 298.2 TFLOPS (BF16)
94
+ max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
95
+ return 512 * max_comp_units * 1300 * 10**6
96
+ else: # for other GPU types, assume A100
97
+ logger.warning(f"Peak flops undefined for: {device_name}, fallback to A100")
98
+ return 312e12
99
+
100
+
101
+ @dataclass(frozen=True)
102
+ class Color:
103
+ black = "\033[30m"
104
+ red = "\033[31m"
105
+ green = "\033[32m"
106
+ yellow = "\033[33m"
107
+ blue = "\033[34m"
108
+ magenta = "\033[35m"
109
+ cyan = "\033[36m"
110
+ white = "\033[37m"
111
+ reset = "\033[39m"
112
+
113
+
114
+ @dataclass(frozen=True)
115
+ class NoColor:
116
+ black = ""
117
+ red = ""
118
+ green = ""
119
+ yellow = ""
120
+ blue = ""
121
+ magenta = ""
122
+ cyan = ""
123
+ white = ""
124
+ reset = ""
125
+
126
+
127
+ def check_if_feature_in_pytorch(
128
+ feature_name: str,
129
+ pull_request: str,
130
+ min_nightly_version: Optional[str] = None,
131
+ ) -> None:
132
+ if "git" in torch.__version__: # pytorch is built from source
133
+ # notify users to check if the pull request is included in their pytorch
134
+ logger.warning(
135
+ "detected that PyTorch is built from source. Please make sure the PR "
136
+ f"({pull_request}) is included in PyTorch for correct {feature_name}."
137
+ )
138
+ elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
139
+ logger.warning(
140
+ f"detected that the pytorch version {torch.__version__} is older than "
141
+ f"{min_nightly_version}. Please upgrade to a newer version to include the "
142
+ f"change in ({pull_request}) for correct {feature_name}."
143
+ )
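
As a closing illustration, `get_peak_flops` is typically combined with a measured FLOP rate to report model FLOPS utilization (MFU). A hedged sketch, assuming a CUDA device is visible and that the achieved rate comes from the training loop's own accounting (the number below is made up):

```python
import torch

device_name = torch.cuda.get_device_name(0)
peak_flops = get_peak_flops(device_name)  # defined above in this file

achieved_flops_per_sec = 400e12  # illustrative value, not a measurement
mfu = 100 * achieved_flops_per_sec / peak_flops
print(f"{device_name}: peak {peak_flops / 1e12:.0f} TFLOPS, MFU {mfu:.1f}%")
```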