Dataset: roneneldan/TinyStories
This is an autoregressive decoder-only transformer model trained on the TinyStories dataset using JAX and Flax NNX.
- Hidden Size: 512
- Number of Layers: 8
- Attention Heads: 8
- Intermediate Size: 2048
- Max Position Embeddings: 256
- Vocab Size: 50257
- Rotary Position Embeddings: True
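The hyperparameters above can be turned into a rough parameter count, which gives a sense of model size. This is a sketch under assumptions that this card does not confirm: bias-free linear layers, a two-matrix (non-gated) MLP, one learned scale vector per normalization layer, and input/output embeddings that may or may not be tied. RoPE adds no learned parameters, so there is no position-embedding table. `ModelConfig` and `estimate_params` are hypothetical names for illustration, not part of the repository.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModelConfig:
    # Values copied from the hyperparameter list above
    hidden_size: int = 512
    num_layers: int = 8
    num_attention_heads: int = 8
    intermediate_size: int = 2048
    max_position_embeddings: int = 256
    vocab_size: int = 50257

    @property
    def head_dim(self) -> int:
        # Each attention head gets an equal slice of the hidden dimension
        assert self.hidden_size % self.num_attention_heads == 0
        return self.hidden_size // self.num_attention_heads


def estimate_params(cfg: ModelConfig, tied_embeddings: bool = True) -> int:
    """Rough parameter count. ASSUMPTIONS (not confirmed by the model card):
    no biases, a 2-matrix MLP, one scale vector per norm, two norms per layer
    plus a final norm, and RoPE (no learned position table)."""
    d, f = cfg.hidden_size, cfg.intermediate_size
    embed = cfg.vocab_size * d          # token embedding table
    attn = 4 * d * d                    # Q, K, V and output projections
    mlp = 2 * d * f                     # up- and down-projections
    norms = 2 * d                       # two norm scale vectors per layer
    per_layer = attn + mlp + norms
    total = embed + cfg.num_layers * per_layer + d  # + final norm scale
    if not tied_embeddings:
        total += cfg.vocab_size * d     # separate output (LM head) matrix
    return total


cfg = ModelConfig()
print(cfg.head_dim)                                  # 64
print(estimate_params(cfg))                          # 50906112
print(estimate_params(cfg, tied_embeddings=False))   # 76637696
```

Under these assumptions the model has roughly 51M parameters with tied embeddings, or about 77M with a separate output layer. Half or more of that is the 50257 x 512 embedding table.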
Usage:

```python
# This model was trained with JAX/Flax and requires the custom transformer implementation
# to load and use. See the repository for implementation details.
from transformers import AutoTokenizer
import jax.numpy as jnp

# Load the GPT-2 tokenizer; GPT-2 has no pad token, so reuse EOS for padding
tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Example text generation (requires custom model loading)
prompt = "Once upon a time, there was a little"
# Tokenize the prompt into a (1, seq_len) array of token ids
input_ids = jnp.asarray(tokenizer(prompt, return_tensors="np")["input_ids"])
# ... (model loading and generation code)
```
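The elided generation step is usually an autoregressive loop: run the model on the current context, pick the next token from the final position's logits, append it, and repeat. The sketch below shows greedy decoding in plain Python. `logits_fn` is a hypothetical stand-in for the custom model's forward pass, not an API from this repository. The context is cropped to the 256-token `max_position_embeddings` limit, and decoding stops at GPT-2's end-of-text token (id 50256).

```python
from typing import Callable, List, Sequence

MAX_POSITIONS = 256  # max_position_embeddings from the model config
EOS_ID = 50256       # GPT-2 "<|endoftext|>" token id


def greedy_generate(
    logits_fn: Callable[[List[int]], Sequence[float]],
    prompt_ids: List[int],
    max_new_tokens: int = 50,
    eos_id: int = EOS_ID,
    max_positions: int = MAX_POSITIONS,
) -> List[int]:
    """Greedy autoregressive decoding.

    `logits_fn` is hypothetical: it takes the current context (token ids)
    and returns next-token logits for the last position.
    """
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        context = ids[-max_positions:]  # the model only sees 256 positions
        logits = logits_fn(context)
        # Greedy choice: index of the largest logit
        next_id = max(range(len(logits)), key=logits.__getitem__)
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids


def toy_logits_fn(context: List[int]) -> List[float]:
    # Toy "model" over a 5-token vocabulary: always predicts last token + 1
    logits = [0.0] * 5
    logits[(context[-1] + 1) % 5] = 1.0
    return logits


print(greedy_generate(toy_logits_fn, [0], eos_id=4))  # [0, 1, 2, 3, 4]
```

With the tokenizer from the usage snippet, a real call would look like `tokenizer.decode(greedy_generate(logits_fn, tokenizer.encode(prompt)))`. When `logits_fn` returns a JAX array, `jnp.argmax(logits)` is the idiomatic replacement for the Python `max` call. Sampling (temperature or top-k) is a common alternative to greedy decoding for story generation.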
Training configuration:

```yaml
model:
  hidden_size: 512
  num_layers: 8
  num_attention_heads: 8
  intermediate_size: 2048
  max_position_embeddings: 256
training:
  learning_rate: 0.0003
  batch_size: 32
  epochs: 10
  warmup_ratio: 0.1
```
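The `training` block gives a peak learning rate and a warmup ratio but not the decay shape or the total number of steps. The following is a minimal sketch assuming linear warmup over the first 10% of optimizer steps followed by cosine decay to zero. The cosine decay is an assumption and is not stated in the config. The dataset size passed to the hypothetical `total_steps` helper is also an illustrative input. If the training code uses Optax, `optax.warmup_cosine_decay_schedule` provides a similar warmup-then-cosine schedule.

```python
import math


def total_steps(num_examples: int, batch_size: int = 32, epochs: int = 10) -> int:
    """Optimizer steps for the run; num_examples is a hypothetical input."""
    return epochs * math.ceil(num_examples / batch_size)


def lr_at_step(step: int, num_steps: int, peak_lr: float = 3e-4,
               warmup_ratio: float = 0.1) -> float:
    """Linear warmup to peak_lr, then (assumed) cosine decay to 0."""
    warmup_steps = max(1, int(warmup_ratio * num_steps))
    if step < warmup_steps:
        # Ramp linearly so the last warmup step reaches peak_lr
        return peak_lr * (step + 1) / warmup_steps
    # Fraction of the post-warmup phase completed, clamped to [0, 1]
    progress = (step - warmup_steps) / max(1, num_steps - warmup_steps)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * min(1.0, progress)))


steps = total_steps(num_examples=10_000)  # 10 epochs of 313 batches
print(steps, int(0.1 * steps))            # 3130 313
```

Expressing warmup as a ratio rather than a fixed step count keeps the schedule's shape the same if the dataset size or batch size changes.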
Repository files:
- `config.json`: Model configuration
- `train_history.json`: Training metrics and duration
- `tokenizer/`: GPT-2 tokenizer files
- `model_checkpoint/`: Best model checkpoint
- `tensorboard_logs/`: Training logs for TensorBoard

MIT License. See the LICENSE file for details.