CRAG: Causal Reasoning for Adversomics Graphs
CRAG (Causal Reasoning for Adversomics Graphs) is a dual-encoder model for extracting adverse drug event (ADE) relationships from clinical narratives. It achieves state-of-the-art performance on ADE extraction, significantly outperforming both specialized biomedical language models and large language models.
CRAG uses a dual-encoder architecture with:

- two BioLinkBERT-base encoders, one for the drug context and one for the ADR context;
- attention pooling over each encoder's token representations;
- projection heads mapping each pooled representation into a shared fusion space;
- bilinear fusion of the two representations, followed by an MLP classifier.

The model is trained in two phases: a contrastive alignment phase using InfoNCE loss, followed by a classification phase using focal loss (see the training hyperparameters below).

Test-set performance:
| Metric | Score |
|---|---|
| F1 Score | 0.9370 |
| Precision | 0.9336 |
| Recall | 0.9405 |
| AUC-ROC | 0.9735 |
Comparison with baselines:

| Model | F1 Score | AUC-ROC | F1 Improvement |
|---|---|---|---|
| BioLinkBERT (zero-shot) | 0.215 | 0.523 | - |
| GPT-4 Turbo | 0.734 | 0.713 | - |
| Qwen2.5-1.5B-Instruct | 0.714 | 0.728 | - |
| CRAG (this model) | 0.937 | 0.974 | +27% vs GPT-4 |
The model expects two separate text inputs:

- a drug context, with the drug entity wrapped in [DRUG]...[/DRUG] tokens;
- an ADR context, with the adverse reaction wrapped in [ADR]...[/ADR] tokens.

```python
# Original text: "The patient developed aspirin-induced gastric bleeding."
drug_context = "The patient developed [DRUG] aspirin [/DRUG]-induced gastric bleeding."
adr_context = "The patient developed aspirin-induced [ADR] gastric bleeding [/ADR]."
```
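If your annotations come as character offsets, a small helper (hypothetical; not part of the released code) can produce the two marked contexts:

```python
def mark_span(text: str, start: int, end: int, open_tok: str, close_tok: str) -> str:
    """Wrap text[start:end] in marker tokens (illustrative helper)."""
    return f"{text[:start]}{open_tok} {text[start:end]} {close_tok}{text[end:]}"

text = "The patient developed aspirin-induced gastric bleeding."
drug_context = mark_span(text, 22, 29, "[DRUG]", "[/DRUG]")  # "aspirin"
adr_context = mark_span(text, 38, 54, "[ADR]", "[/ADR]")     # "gastric bleeding"
```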
The model outputs a single logit value. Apply sigmoid to get a probability:
```python
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download

# --- 1. Define model architecture (must match training) ---
class AttentionPooling(nn.Module):
    def __init__(self, hidden_dim, num_heads=4):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.query = nn.Parameter(torch.randn(1, 1, hidden_dim))

    def forward(self, hidden_states, attention_mask=None):
        batch_size = hidden_states.size(0)
        query = self.query.expand(batch_size, -1, -1)
        key_padding_mask = ~attention_mask.bool() if attention_mask is not None else None
        attn_output, _ = self.attention(query, hidden_states, hidden_states, key_padding_mask=key_padding_mask)
        return attn_output.squeeze(1)

class CRAGDualEncoder(nn.Module):
    def __init__(self, base_model="michiyasunaga/BioLinkBERT-base"):
        super().__init__()
        hidden_dim, fusion_dim, dropout = 768, 256, 0.1
        self.drug_encoder = AutoModel.from_pretrained(base_model)
        self.adr_encoder = AutoModel.from_pretrained(base_model)
        self.drug_pooler = AttentionPooling(hidden_dim, 4)
        self.adr_pooler = AttentionPooling(hidden_dim, 4)
        self.drug_projection = nn.Sequential(
            nn.Linear(hidden_dim, fusion_dim), nn.LayerNorm(fusion_dim),
            nn.GELU(), nn.Dropout(dropout), nn.Linear(fusion_dim, fusion_dim))
        self.adr_projection = nn.Sequential(
            nn.Linear(hidden_dim, fusion_dim), nn.LayerNorm(fusion_dim),
            nn.GELU(), nn.Dropout(dropout), nn.Linear(fusion_dim, fusion_dim))
        self.bilinear = nn.Bilinear(fusion_dim, fusion_dim, fusion_dim)
        self.fusion_norm = nn.LayerNorm(fusion_dim)
        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim * 4, fusion_dim), nn.LayerNorm(fusion_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(fusion_dim, fusion_dim // 2), nn.GELU(), nn.Dropout(dropout), nn.Linear(fusion_dim // 2, 1))

    def forward(self, drug_ids, drug_mask, adr_ids, adr_mask):
        drug_out = self.drug_encoder(input_ids=drug_ids, attention_mask=drug_mask)
        adr_out = self.adr_encoder(input_ids=adr_ids, attention_mask=adr_mask)
        drug_pooled = self.drug_pooler(drug_out.last_hidden_state, drug_mask)
        adr_pooled = self.adr_pooler(adr_out.last_hidden_state, adr_mask)
        drug_repr = self.drug_projection(drug_pooled)
        adr_repr = self.adr_projection(adr_pooled)
        bilinear_out = self.fusion_norm(self.bilinear(drug_repr, adr_repr))
        combined = torch.cat([bilinear_out, drug_repr, adr_repr, drug_repr * adr_repr], dim=-1)
        return self.classifier(combined).squeeze(-1)

# --- 2. Load tokenizer and model ---
tokenizer = AutoTokenizer.from_pretrained("michiyasunaga/BioLinkBERT-base")
tokenizer.add_special_tokens({'additional_special_tokens': ['[DRUG]', '[/DRUG]', '[ADR]', '[/ADR]']})

model = CRAGDualEncoder()
model.drug_encoder.resize_token_embeddings(len(tokenizer))
model.adr_encoder.resize_token_embeddings(len(tokenizer))

model_path = hf_hub_download(repo_id="chrisvoncsefalvay/CRAG-dual-encoder-mimicause", filename="pytorch_model.pt")
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()

# --- 3. Run inference ---
drug_context = "The patient developed [DRUG] aspirin [/DRUG]-induced gastric bleeding."
adr_context = "The patient developed aspirin-induced [ADR] gastric bleeding [/ADR]."

drug_enc = tokenizer(drug_context, return_tensors='pt', max_length=128, padding='max_length', truncation=True)
adr_enc = tokenizer(adr_context, return_tensors='pt', max_length=128, padding='max_length', truncation=True)

with torch.no_grad():
    logit = model(drug_enc['input_ids'], drug_enc['attention_mask'],
                  adr_enc['input_ids'], adr_enc['attention_mask'])
    probability = torch.sigmoid(logit).item()

print(f"ADE probability: {probability:.4f}")
print(f"Prediction: {'ADE Relationship' if probability > 0.5 else 'No ADE Relationship'}")
```
Architecture summary:

```
CRAGDualEncoder(
  (drug_encoder): BioLinkBERT-base (110M params)
  (adr_encoder): BioLinkBERT-base (110M params)
  (pooling): Attention pooling (learned query, 4 heads, one pooler per encoder)
  (drug_projection): Linear(768 -> 256) + LayerNorm + GELU + Linear
  (adr_projection): Linear(768 -> 256) + LayerNorm + GELU + Linear
  (bilinear): Bilinear(256, 256 -> 256)
  (classifier): Linear(1024 -> 256) + LayerNorm + GELU + Linear(256 -> 128) + GELU + Linear(128 -> 1)
)
```

Total parameters: 238,798,081
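You can verify the parameter count against the loaded model (a quick check, assuming the loading code above has been run):

```python
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")  # expected: 238,798,081
```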
The model was trained on data derived from the MIMICause dataset, split as follows:
| Split | Samples | Positive | Negative |
|---|---|---|---|
| Train | 12,978 | 6,264 | 6,714 |
| Validation | 1,667 | 812 | 855 |
| Test | 1,681 | 807 | 874 |
Training hyperparameters:

| Parameter | Phase 1 (Contrastive) | Phase 2 (Classification) |
|---|---|---|
| Epochs | 5 | 8 |
| Batch Size | 16 | 16 |
| Learning Rate | 2e-5 | 2e-5 |
| Loss Function | InfoNCE (τ=0.07) | Focal (γ=1.0, α=0.75) |
| Pooling | Attention (4 heads) | Attention (4 heads) |
| Hard Negatives | 50% | - |
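For illustration, here is a minimal sketch of the two losses as parameterised above. The actual training code is not included in this repository, so treat this as an assumption about the formulation:

```python
import torch
import torch.nn.functional as F

def info_nce(drug_repr, adr_repr, tau=0.07):
    """InfoNCE over in-batch pairs: matched drug/ADR pairs are positives."""
    drug = F.normalize(drug_repr, dim=-1)
    adr = F.normalize(adr_repr, dim=-1)
    sim = drug @ adr.t() / tau                     # (B, B) similarity matrix
    targets = torch.arange(sim.size(0), device=sim.device)
    return F.cross_entropy(sim, targets)

def focal_loss(logits, labels, gamma=1.0, alpha=0.75):
    """Binary focal loss on a single logit per pair."""
    bce = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')
    p_t = torch.exp(-bce)                          # probability of the true class
    alpha_t = alpha * labels + (1 - alpha) * (1 - labels)
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()
```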
An ONNX export is provided for deployment flexibility:
```python
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download

# Download the ONNX model
onnx_path = hf_hub_download(
    repo_id="chrisvoncsefalvay/CRAG-dual-encoder-mimicause",
    filename="crag_model.onnx"
)

# Create inference session
session = ort.InferenceSession(onnx_path)

# Prepare inputs as numpy arrays (reusing the tokenizer encodings from the
# PyTorch example above)
drug_ids = drug_enc['input_ids'].numpy()
drug_mask = drug_enc['attention_mask'].numpy()
adr_ids = adr_enc['input_ids'].numpy()
adr_mask = adr_enc['attention_mask'].numpy()

# Run inference
outputs = session.run(None, {
    "drug_input_ids": drug_ids,
    "drug_attention_mask": drug_mask,
    "adr_input_ids": adr_ids,
    "adr_attention_mask": adr_mask
})
```
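As with the PyTorch model, and assuming the export preserves the single-logit output, apply a sigmoid to the first output to obtain a probability:

```python
logit = outputs[0].squeeze()                 # raw logit from the ONNX graph
probability = 1.0 / (1.0 + np.exp(-logit))  # sigmoid
print(f"ADE probability: {probability:.4f}")
```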
If you use this model, please cite:
```bibtex
@misc{crag2025,
  author = {von Csefalvay, Chris},
  title = {CRAG: Causal Reasoning for Adversomics Graphs},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/chrisvoncsefalvay/CRAG-dual-encoder-mimicause}}
}
```
License: Apache 2.0

Base model: michiyasunaga/BioLinkBERT-base