| |
| """ |
| Configuration class for GPT model |
| Author: Shilpaj Bhalerao |
| Date: 2025-01-19 |
| """ |
| |
from dataclasses import dataclass, field
from typing import Optional
|
|
|
|
@dataclass
class RoPEConfig:
    """
    Default hyper-parameters for Rotary Position Embeddings (RoPE).

    Every field carries a default, so ``RoPEConfig()`` is a valid instance.
    This class only stores the values; how they are consumed is decided by
    the RoPE implementation, which is not part of this class.
    """
    # NOTE(review): presumably the rotary frequency base -- confirm in the model code.
    base: int = 10_000
    scaling_factor: float = 1.0  # 1.0 is the neutral (no-scaling) multiplier
    # 0.3125 == 5/16.
    # NOTE(review): presumably the share of each head's dims that is rotated -- confirm.
    head_dim_fraction: float = 0.3125
    # NOTE(review): presumably the rotary dim is rounded to a multiple of this -- confirm.
    round_multiple: int = 8
|
|
|
|
@dataclass
class SmollmConfig:
    """
    Configuration for Smollm training setup

    Bundles the model-architecture, data-loading, optimisation, run-control
    and text-generation defaults into a single object. Every field has a
    default, so ``SmollmConfig()`` yields the reference setup and individual
    values can be overridden by keyword.

    NOTE(review): several names here (``batch_size``, ``num_workers``,
    ``shuffle_buffer_size``, ``max_length``, ``learning_rate``,
    ``weight_decay``) are also declared by other config classes in this
    file; which copy is authoritative depends on the caller -- confirm.
    """
    # Model architecture
    block_size: int = 2048
    vocab_size: int = 49152
    n_layer: int = 30
    n_head: int = 9
    n_embd: int = 576  # divisible by n_head: 576 / 9 == 64
    # Fractional value, so it must be annotated float (was mis-annotated int).
    mlp_ratio: float = 2.67
    dropout: float = 0.0

    # Data loading and optimisation
    batch_size: int = 1
    num_workers: int = 0
    shuffle_buffer_size: int = 1000
    max_length: int = 2048
    learning_rate: float = 3e-5
    weight_decay: float = 1e-4

    # Generation length
    max_new_tokens: int = 100

    # Run control
    seed: int = 1337
    max_steps: int = 5000
    clear_cache_every: int = 1000  # NOTE(review): presumably measured in steps -- confirm

    # Sampling
    context_length: int = 10
    temperature: float = 1.0
    top_k: int = 50
|
|
|
|
@dataclass
class CheckpointConfig:
    """
    Default settings that control when and where checkpoints are written.

    Plain value container: all fields have defaults and no validation is
    performed here.

    NOTE(review): ``save_last`` .. ``save_on_train_epoch_end`` share their
    names with arguments of Lightning's ``ModelCheckpoint`` callback;
    presumably they are forwarded to it -- confirm at the call site.
    """
    checkpoint_dir: str = "checkpoints"  # relative path
    checkpoint_every: int = 500  # NOTE(review): presumably counted in training steps -- confirm
    save_last: bool = True
    save_top_k: int = 1
    save_weights_only: bool = True
    monitor: str = "train_loss"  # name of the tracked metric
    mode: str = "min"
    save_on_train_epoch_end: bool = False
|
|
|
|
@dataclass
class LoggingConfig:
    """
    Default settings for training-time logging.

    Holds two integer intervals and three on/off switches; every field has
    a default, so ``LoggingConfig()`` needs no arguments.
    """
    # Intervals. NOTE(review): presumably counted in training steps -- confirm.
    log_every: int = 50
    generate_every: int = 500

    # Feature switches, all enabled by default.
    log_metrics: bool = True
    log_progress_bar: bool = True
    log_model_summary: bool = True
|
|
|
|
@dataclass
class OptimizerConfig:
    """
    Default optimizer and learning-rate-schedule settings.

    ``optimizer_kwargs`` is built by a ``default_factory`` so that each
    instance gets its own dict rather than sharing one mutable default.

    NOTE(review): ``max_lr``, ``div_factor``, ``final_div_factor``,
    ``pct_start``, ``three_phase`` and ``anneal_strategy`` share their names
    with arguments of ``torch.optim.lr_scheduler.OneCycleLR``; presumably
    they are passed to it -- confirm at the call site.
    """
    # Optimizer selection and base hyper-parameters
    optimizer: str = "AdamW"
    learning_rate: float = 3e-5
    weight_decay: float = 1e-4

    # Learning-rate schedule shape
    max_lr: float = 3e-4
    div_factor: float = 25.0
    final_div_factor: float = 100.0
    pct_start: float = 0.2

    # Extra keyword arguments; a fresh dict is created per instance.
    optimizer_kwargs: dict = field(
        default_factory=lambda: {"betas": (0.9, 0.95), "eps": 1e-8}
    )
    three_phase: bool = False
    anneal_strategy: str = "linear"
|
|
|
|
@dataclass
class DataConfig:
    """
    Default dataset, tokenizer and data-loading settings.

    Plain value container with a default for every field; nothing is
    downloaded or opened by instantiating it.
    """
    # Dataset identifiers.
    # NOTE(review): these look like Hugging Face Hub ids -- confirm in the loader.
    dataset_path: str = "HuggingFaceTB/smollm-corpus"
    dataset_name: str = "cosmopedia-v2"

    # Tokenizer identifier
    tokenizer_path: str = "HuggingFaceTB/cosmo2-tokenizer"

    # Loader parameters
    batch_size: int = 32
    num_workers: int = 4
    shuffle_buffer_size: int = 1000
    max_length: int = 512  # NOTE(review): presumably a per-sample token limit -- confirm

    # Splitting and loader behaviour
    validation_split: float = 0.1
    pin_memory: bool = True
    streaming: bool = True
|
|
|
|
@dataclass
class TrainerConfig:
    """
    Configuration for PyTorch Lightning Trainer

    Field names match keyword arguments of ``lightning.Trainer``. Every
    field has a default, so ``TrainerConfig()`` needs no arguments.
    NOTE(review): presumably the values are forwarded unchanged to the
    Trainer constructor -- confirm at the call site.
    """
    # Hardware and numerics
    accelerator: str = 'auto'
    devices: int = 1
    precision: str = '16-mixed'
    log_every_n_steps: int = 10
    strategy: str = 'auto'
    deterministic: bool = False
    benchmark: bool = True

    # Console output and profiling
    enable_progress_bar: bool = True
    enable_model_summary: bool = True
    profiler: str = 'simple'

    # Optimisation-loop controls
    gradient_clip_val: float = 1.0
    accumulate_grad_batches: int = 2

    # Validation scheduling
    val_check_interval: int = 1000
    # Annotated Optional[int] (was ``None``) so an epoch count may be
    # supplied; the default stays None.
    # NOTE(review): per the Lightning Trainer docs, None makes an integer
    # val_check_interval count training batches across epochs -- confirm
    # against the installed Lightning version.
    check_val_every_n_epoch: Optional[int] = None
|
|