CRAG: Causal Reasoning for Adversomics Graphs

CRAG (Causal Reasoning for Adversomics Graphs) is a dual-encoder model for extracting adverse drug event (ADE) relationships from clinical narratives. It achieves state-of-the-art performance on ADE extraction, significantly outperforming both specialized biomedical language models and large language models.

Model Description

CRAG uses a dual-encoder architecture with:

  • Two separate BioLinkBERT encoders: One for drug mentions, one for adverse event mentions
  • Attention pooling: A learned query attends over token representations via multi-head attention, with padding positions masked out
  • Bilinear fusion: Captures complex drug-ADR interactions
  • Multi-view concatenation: Combines bilinear output, individual embeddings, and element-wise products
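
At the shape level, the fusion stage works as sketched below. This is a minimal illustration mirroring the full implementation in the Usage section; the random tensors stand in for the projected 256-dimensional drug and ADR embeddings.

import torch
import torch.nn as nn

fusion_dim = 256
bilinear = nn.Bilinear(fusion_dim, fusion_dim, fusion_dim)

drug_repr = torch.randn(8, fusion_dim)  # projected drug embeddings (batch of 8)
adr_repr = torch.randn(8, fusion_dim)   # projected ADR embeddings

bilinear_out = bilinear(drug_repr, adr_repr)           # (8, 256) pairwise interaction
combined = torch.cat([bilinear_out, drug_repr, adr_repr,
                      drug_repr * adr_repr], dim=-1)   # (8, 1024) multi-view feature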

Training Approach

The model is trained in two phases:

  1. Phase 1 - Contrastive Pre-training: InfoNCE loss with hard negative mining to learn discriminative drug-ADR embeddings
  2. Phase 2 - Classification Fine-tuning: Focal loss to handle class imbalance and refine the classifier
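
Neither loss implementation ships with the checkpoint; the sketch below illustrates both objectives with the hyperparameters listed under Training Configuration (τ = 0.07; γ = 1.0, α = 0.75). Hard-negative mining is omitted for brevity.

import torch
import torch.nn.functional as F

def info_nce(drug_emb, adr_emb, temperature=0.07):
    # Phase 1 sketch: matched (drug, ADR) pairs are positives; every other
    # pairing in the batch serves as an in-batch negative.
    drug_emb = F.normalize(drug_emb, dim=-1)
    adr_emb = F.normalize(adr_emb, dim=-1)
    logits = drug_emb @ adr_emb.t() / temperature
    labels = torch.arange(logits.size(0), device=logits.device)
    return F.cross_entropy(logits, labels)

def focal_loss(logits, targets, gamma=1.0, alpha=0.75):
    # Phase 2 sketch: down-weights easy examples and re-weights the positive
    # class to counter class imbalance. targets is a float tensor of 0s/1s.
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    p_t = torch.exp(-bce)  # model probability of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()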

Performance

Test Set Results

Metric     Score
F1 Score   0.9370
Precision  0.9336
Recall     0.9405
AUC-ROC    0.9735

Comparison with Baselines

Model                    F1 Score  AUC-ROC  F1 Improvement
BioLinkBERT (zero-shot)  0.215     0.523    -
GPT-4 Turbo              0.734     0.713    -
Qwen2.5-1.5B-Instruct    0.714     0.728    -
CRAG (this model)        0.937     0.974    +27% (relative) vs GPT-4 Turbo

Input Format

The model expects two separate text inputs:

  1. Drug context: The clinical text with the drug mention marked with [DRUG]...[/DRUG] tokens
  2. ADR context: The same text with the adverse event marked with [ADR]...[/ADR] tokens

Example Input

# Original text: "The patient developed aspirin-induced gastric bleeding."

drug_context = "The patient developed [DRUG] aspirin [/DRUG]-induced gastric bleeding."
adr_context = "The patient developed aspirin-induced [ADR] gastric bleeding [/ADR]."
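
If you are starting from character offsets (for example, from an upstream NER system), a small helper along the lines below produces the two contexts. mark_span is a hypothetical utility for illustration, not part of the released code; it pads the markers with spaces so the tokenizer treats them as standalone special tokens.

def mark_span(text, start, end, open_tag, close_tag):
    # Wrap the character span [start, end) in marker tokens
    return f"{text[:start]}{open_tag} {text[start:end]} {close_tag}{text[end:]}"

text = "The patient developed aspirin-induced gastric bleeding."
drug_context = mark_span(text, 22, 29, "[DRUG]", "[/DRUG]")
adr_context = mark_span(text, 38, 54, "[ADR]", "[/ADR]")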

Output

The model outputs a single logit value. Apply sigmoid to get a probability:

  • > 0.5: Positive relationship (drug causes the adverse event)
  • < 0.5: Negative/no relationship

Usage

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download

# --- 1. Define model architecture (must match training) ---
class AttentionPooling(nn.Module):
    def __init__(self, hidden_dim, num_heads=4):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.query = nn.Parameter(torch.randn(1, 1, hidden_dim))

    def forward(self, hidden_states, attention_mask=None):
        batch_size = hidden_states.size(0)
        query = self.query.expand(batch_size, -1, -1)
        key_padding_mask = ~attention_mask.bool() if attention_mask is not None else None
        attn_output, _ = self.attention(query, hidden_states, hidden_states, key_padding_mask=key_padding_mask)
        return attn_output.squeeze(1)

class CRAGDualEncoder(nn.Module):
    def __init__(self, base_model="michiyasunaga/BioLinkBERT-base"):
        super().__init__()
        hidden_dim, fusion_dim, dropout = 768, 256, 0.1

        self.drug_encoder = AutoModel.from_pretrained(base_model)
        self.adr_encoder = AutoModel.from_pretrained(base_model)
        self.drug_pooler = AttentionPooling(hidden_dim, 4)
        self.adr_pooler = AttentionPooling(hidden_dim, 4)

        self.drug_projection = nn.Sequential(
            nn.Linear(hidden_dim, fusion_dim), nn.LayerNorm(fusion_dim),
            nn.GELU(), nn.Dropout(dropout), nn.Linear(fusion_dim, fusion_dim))
        self.adr_projection = nn.Sequential(
            nn.Linear(hidden_dim, fusion_dim), nn.LayerNorm(fusion_dim),
            nn.GELU(), nn.Dropout(dropout), nn.Linear(fusion_dim, fusion_dim))

        self.bilinear = nn.Bilinear(fusion_dim, fusion_dim, fusion_dim)
        self.fusion_norm = nn.LayerNorm(fusion_dim)
        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim * 4, fusion_dim), nn.LayerNorm(fusion_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(fusion_dim, fusion_dim // 2), nn.GELU(), nn.Dropout(dropout), nn.Linear(fusion_dim // 2, 1))

    def forward(self, drug_ids, drug_mask, adr_ids, adr_mask):
        # Encode each context with its own BioLinkBERT encoder
        drug_out = self.drug_encoder(input_ids=drug_ids, attention_mask=drug_mask)
        adr_out = self.adr_encoder(input_ids=adr_ids, attention_mask=adr_mask)
        # Pool the token representations into one vector per context
        drug_pooled = self.drug_pooler(drug_out.last_hidden_state, drug_mask)
        adr_pooled = self.adr_pooler(adr_out.last_hidden_state, adr_mask)
        # Project into the shared 256-dimensional fusion space
        drug_repr = self.drug_projection(drug_pooled)
        adr_repr = self.adr_projection(adr_pooled)
        # Multi-view fusion: bilinear interaction, both embeddings, element-wise product
        bilinear_out = self.fusion_norm(self.bilinear(drug_repr, adr_repr))
        combined = torch.cat([bilinear_out, drug_repr, adr_repr, drug_repr * adr_repr], dim=-1)
        return self.classifier(combined).squeeze(-1)

# --- 2. Load tokenizer and model ---
tokenizer = AutoTokenizer.from_pretrained("michiyasunaga/BioLinkBERT-base")
tokenizer.add_special_tokens({'additional_special_tokens': ['[DRUG]', '[/DRUG]', '[ADR]', '[/ADR]']})

model = CRAGDualEncoder()
model.drug_encoder.resize_token_embeddings(len(tokenizer))
model.adr_encoder.resize_token_embeddings(len(tokenizer))

model_path = hf_hub_download(repo_id="chrisvoncsefalvay/CRAG-dual-encoder-mimicause", filename="pytorch_model.pt")
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()

# --- 3. Run inference ---
drug_context = "The patient developed [DRUG] aspirin [/DRUG]-induced gastric bleeding."
adr_context = "The patient developed aspirin-induced [ADR] gastric bleeding [/ADR]."

drug_enc = tokenizer(drug_context, return_tensors='pt', max_length=128, padding='max_length', truncation=True)
adr_enc = tokenizer(adr_context, return_tensors='pt', max_length=128, padding='max_length', truncation=True)

with torch.no_grad():
    logit = model(drug_enc['input_ids'], drug_enc['attention_mask'],
                  adr_enc['input_ids'], adr_enc['attention_mask'])
    probability = torch.sigmoid(logit).item()

print(f"ADE probability: {probability:.4f}")
print(f"Prediction: {'ADE Relationship' if probability > 0.5 else 'No ADE Relationship'}")

Model Architecture

CRAGDualEncoder(
  (drug_encoder): BioLinkBERT-base (110M params)
  (adr_encoder): BioLinkBERT-base (110M params)
  (drug_pooler / adr_pooler): Attention pooling (learned query, 4-head multi-head attention)
  (drug_projection): Linear(768 -> 256) + LayerNorm + GELU + Linear
  (adr_projection): Linear(768 -> 256) + LayerNorm + GELU + Linear
  (bilinear): Bilinear(256, 256 -> 256)
  (classifier): Linear(1024 -> 256) + LayerNorm + GELU + Linear(256 -> 128) + GELU + Linear(128 -> 1)
)

Total Parameters: 238,798,081
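
As a quick sanity check, the count can be reproduced from the CRAGDualEncoder class in the Usage section (assuming the model object from that example, with token embeddings already resized for the four marker tokens):

# Illustrative check; the exact count depends on the tokenizer's vocabulary
# size after adding the marker tokens.
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params:,}")  # should be close to the 238,798,081 reported above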

Training Data

The model was trained on a combination of:

  • ADE Corpus v2: Biomedical literature annotations for drug-adverse event pairs
  • MIMICause: Clinical notes from MIMIC-III with causal ADE annotations

Split       Samples  Positive  Negative
Train       12,978   6,264     6,714
Validation  1,667    812       855
Test        1,681    807       874

Training Configuration

Parameter       Phase 1 (Contrastive)  Phase 2 (Classification)
Epochs          5                      8
Batch Size      16                     16
Learning Rate   2e-5                   2e-5
Loss Function   InfoNCE (τ=0.07)       Focal (γ=1.0, α=0.75)
Pooling         Attention              Attention
Hard Negatives  50%                    -

ONNX Export

An ONNX export is provided for deployment flexibility:

import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download

# Download ONNX model
onnx_path = hf_hub_download(
    repo_id="chrisvoncsefalvay/CRAG-dual-encoder-mimicause",
    filename="crag_model.onnx"
)

# Create inference session
session = ort.InferenceSession(onnx_path)

# ONNX Runtime expects NumPy arrays; reuse the tokenized inputs from the
# PyTorch example above (drug_enc / adr_enc) and convert them
outputs = session.run(None, {
    "drug_input_ids": drug_enc['input_ids'].numpy(),
    "drug_attention_mask": drug_enc['attention_mask'].numpy(),
    "adr_input_ids": adr_enc['input_ids'].numpy(),
    "adr_attention_mask": adr_enc['attention_mask'].numpy()
})

# Apply a sigmoid to the raw logit to obtain the ADE probability
probability = 1.0 / (1.0 + np.exp(-outputs[0]))

Limitations

  • Trained primarily on English clinical and biomedical text
  • Requires drug and ADR spans to be pre-identified (not an end-to-end NER+RE model)
  • Performance may vary on drug/ADR pairs not seen during training
  • Best suited for binary relation classification, not relation type classification

Citation

If you use this model, please cite:

@misc{crag2025,
  author = {von Csefalvay, Chris},
  title = {CRAG: Causal Reasoning for Adversomics Graphs},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/chrisvoncsefalvay/CRAG-dual-encoder-mimicause}}
}

License

Apache 2.0
