MedSLM: Medical Small Language Model (~381M Parameters)

A 381M parameter transformer language model pre-trained on curated medical text from PubMed abstracts, PMC full-text articles, and clinical guidelines.

Architecture

MedSLM uses a modern GPT-style transformer with several architectural improvements over the standard GPT-2 design:

| Component | Detail |
|---|---|
| Normalization | RMSNorm (faster than LayerNorm; used in LLaMA/Mistral) |
| Positional Encoding | Rotary Positional Embeddings (RoPE) for better length generalization |
| Feed-Forward | SwiGLU activation (gated FFN; outperforms GELU) |
| Attention | Grouped-Query Attention (GQA) with shared KV heads for efficiency |
| Layers | 24 transformer blocks |
| Attention Heads | 16 query heads, 8 KV heads |
| Embedding Dim | 1024 |
| Context Length | 1024 tokens |
| Vocab Size | 50,257 (GPT-2 BPE tokenizer) |
| Parameters | 381,373,440 (~381M) |
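As a sanity check, the parameter total can be reproduced from the dimensions in the table. Two details are assumptions, not stated in the card: an untied LM head and a SwiGLU hidden size of 2752, chosen because they make the count match the reported total exactly.

```python
# Back-of-the-envelope parameter count from the architecture table.
vocab, d_model, n_layers = 50_257, 1024, 24
n_q_heads, n_kv_heads = 16, 8
head_dim = d_model // n_q_heads                 # 64
ffn_hidden = 2752                               # assumed SwiGLU hidden size

embedding = vocab * d_model                     # token embeddings (RoPE adds no params)
attn = (d_model * d_model                       # Q projection (16 heads)
        + 2 * d_model * n_kv_heads * head_dim   # K and V projections (GQA: 8 KV heads)
        + d_model * d_model)                    # output projection
ffn = 3 * d_model * ffn_hidden                  # SwiGLU: gate, up, and down projections
norms = 2 * d_model                             # two RMSNorms per block
per_layer = attn + ffn + norms

# blocks + final RMSNorm + assumed untied LM head
total = embedding + n_layers * per_layer + d_model + vocab * d_model
print(total)  # 381373440
```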

Training

  • Dataset: Saminx22/medical_data_for_slm (~44M tokens)
  • Sources: PubMed abstracts, PMC Open Access full-text, Clinical Guidelines
  • Tokenizer: GPT-2 BPE tokenizer (50,257 vocab)
  • Optimizer: AdamW (betas=0.9/0.95, weight_decay=0.1)
  • LR Schedule: Linear warmup (1000 steps) + Cosine decay
  • Peak LR: 0.0003
  • Precision: bfloat16
  • Effective Batch Size: 256
  • Max Steps: 20,000
  • Best Val Loss: 3.2198 (at step 19500)
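The schedule above (1,000 linear warmup steps into cosine decay at peak LR 3e-4) can be sketched as follows. The minimum LR, here 10% of peak, is an assumption; the card does not state the decay floor.

```python
import math

PEAK_LR, WARMUP, MAX_STEPS = 3e-4, 1_000, 20_000
MIN_LR = PEAK_LR * 0.1  # assumed floor; not stated in the card

def lr_at(step: int) -> float:
    """Linear warmup to PEAK_LR, then cosine decay to MIN_LR."""
    if step < WARMUP:
        return PEAK_LR * (step + 1) / WARMUP
    progress = (step - WARMUP) / (MAX_STEPS - WARMUP)  # 0 -> 1 over decay phase
    return MIN_LR + 0.5 * (PEAK_LR - MIN_LR) * (1 + math.cos(math.pi * progress))

print(lr_at(999))     # peak (3e-4)
print(lr_at(20_000))  # floor (3e-5)
```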

Usage

Loading the Model

import torch
import json
from safetensors.torch import load_file
from transformers import AutoTokenizer

# Load config
with open("config.json") as f:
    config_dict = json.load(f)

# Reconstruct model (requires the MedSLM class definition)
config = MedSLMConfig(**{k: v for k, v in config_dict.items()
                         if k in MedSLMConfig.__dataclass_fields__})
model = MedSLM(config)

# Load weights
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)
model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("tokenizer/")

Generating Text

prompt = "The patient presented with acute myocardial infarction"
input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)

with torch.no_grad():  # inference only; skip gradient tracking
    output = model.generate(input_ids, max_new_tokens=200, temperature=0.8, top_k=50, top_p=0.9)
print(tokenizer.decode(output.squeeze().tolist()))
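The `top_k` and `top_p` arguments combine two truncation strategies on the next-token distribution. A pure-Python sketch of the filtering for a single decoding step (illustrative only; the model's actual `generate()` implementation may differ):

```python
import math

def filter_logits(logits, top_k=50, top_p=0.9):
    """Return the token indices that survive top-k, then top-p, filtering."""
    # Sort token indices by logit, highest first, and keep the top k.
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    order = order[:top_k]
    # Softmax over the surviving logits.
    exps = [math.exp(logits[i]) for i in order]
    total = sum(exps)
    # Keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cum = [], 0.0
    for i, e in zip(order, exps):
        kept.append(i)
        cum += e / total
        if cum >= top_p:
            break
    return kept

print(filter_logits([5.0, 1.0, 4.0, 0.0], top_k=3, top_p=0.9))  # [0, 2]
```

Sampling then draws from the renormalized distribution over the kept indices; lowering `temperature` sharpens that distribution before filtering.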

Resuming Training

# Recreate the optimizer with the training hyperparameters, then restore its state
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4,
                              betas=(0.9, 0.95), weight_decay=0.1)
optimizer.load_state_dict(torch.load("optimizer.pt"))

Files

| File | Description |
|---|---|
| model.safetensors | Model weights (safetensors format) |
| optimizer.pt | Optimizer state dict for resuming training |
| config.json | Model architecture configuration |
| training_config.json | Training hyperparameters and loss history |
| tokenizer/ | GPT-2 tokenizer files |
| loss_curves.png | Training/validation loss plot |

Intended Use

This model is intended for research purposes in medical NLP. It can be used as:

  • A foundation model for downstream medical NLP tasks (NER, classification, QA)
  • A starting point for medical instruction tuning
  • A baseline for comparing medical language model architectures

Limitations

  • Not for clinical use: This model should NOT be used for clinical decision-making
  • Small scale: ~381M parameters is relatively small; larger models will perform better
  • Limited data: Trained on ~44M tokens (production models use trillions)
  • No alignment: This is a base model without instruction tuning or RLHF
  • English only: Trained exclusively on English medical text
  • Potential biases: May reflect biases present in the medical literature

License

Apache 2.0
