rhordoancc commited on
Commit
a820eb7
·
verified ·
1 Parent(s): 7b7682b

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +118 -0
README.md ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - knowledge-graph
5
+ - question-answering
6
+ - retrieval-augmented-generation
7
+ - gnn
8
+ - freebase
9
+ ---
10
+
11
+ # D-RAG Phase 1 Checkpoints
12
+
13
+ Pre-trained retriever checkpoints and heuristics for [D-RAG: Differentiable Retrieval-Augmented Generation](https://github.com/rhordoan/drag-improved).
14
+
15
+ ## 📦 Contents
16
+
17
+ | File | Size | Description |
18
+ |------|------|-------------|
19
+ | `checkpoints_cwq_subgraph/phase1_best.pt` | 288 MB | CWQ retriever (27,613 samples, 10 epochs) |
20
+ | `checkpoints_webqsp_subgraph/phase1_best.pt` | 288 MB | WebQSP retriever (2,826 samples, 10 epochs) |
21
+ | `data/train_heuristics_cwq.jsonl` | 111 MB | CWQ heuristics with per-question subgraphs |
22
+ | `data/train_heuristics_webqsp_subgraph.jsonl` | 12 MB | WebQSP heuristics with per-question subgraphs |
23
+
24
+ **Total size:** ~700 MB
25
+
26
+ ## 🚀 Usage
27
+
28
+ ### Automatic Download (Recommended)
29
+
30
+ ```bash
31
+ git clone https://github.com/rhordoan/drag-improved.git
32
+ cd drag-improved
33
+ ./scripts/setup_environment.sh # Downloads checkpoints automatically
34
+ ```
35
+
36
+ ### Manual Download
37
+
38
+ ```bash
39
+ # Using Python script
40
+ python scripts/download_checkpoints.py
41
+
42
+ # Or using huggingface-cli
43
+ huggingface-cli download rhordoan/drag-improved-checkpoints \
44
+ checkpoints_cwq_subgraph/phase1_best.pt \
45
+ --local-dir .
46
+ ```
47
+
48
+ ### Load in Python
49
+
50
+ ```python
51
+ import torch
52
+ from src.model.retriever import DRAGRetriever
53
+
54
+ # Load checkpoint
55
+ checkpoint = torch.load('checkpoints_cwq_subgraph/phase1_best.pt')
56
+
57
+ # Initialize retriever
58
+ retriever = DRAGRetriever(
59
+ node_dim=256,
60
+ edge_dim=256,
61
+ hidden_dim=256,
62
+ instruction_dim=384,
63
+ relation_dim=256,
64
+ num_reasoning_steps=3
65
+ )
66
+
67
+ # Load weights
68
+ retriever.load_state_dict(checkpoint['model_state_dict'])
69
+ ```
70
+
71
+ ## 📊 Dataset Details
72
+
73
+ ### CWQ (ComplexWebQuestions)
74
+ - **Source:** `rmanluo/RoG-cwq` (Hugging Face)
75
+ - **Samples:** 27,613
76
+ - **Training time:** ~3.5 minutes on A100
77
+ - **Final loss:** 0.2616 (BCE: 0.092, Ranking: 0.656)
78
+
79
+ ### WebQSP (WebQuestions Semantic Parses)
80
+ - **Source:** `rmanluo/RoG-webqsp` (Hugging Face)
81
+ - **Samples:** 2,826
82
+ - **Training time:** ~30 seconds on A100
83
+ - **Final loss:** ~0.25
84
+
85
+ ## 🏗️ Model Architecture
86
+
87
+ **DRAGRetriever** (GNN-based fact retriever):
88
+ - **Instruction Module:** Sentence-BERT encoder
89
+ - **Graph Reasoning:** 3 layers of instruction-conditioned message passing
90
+ - **Instruction Update:** Iterative refinement
91
+ - **Fact Scorer:** Binary selection per edge
92
+
93
+ **Training config:**
94
+ - Optimizer: AdamW (lr=5e-5, weight_decay=0.001)
95
+ - Loss: ρ × BCE + (1-ρ) × Ranking (ρ=0.7)
96
+ - Batch size: 16
97
+ - Gradient clipping: 1.0
98
+
99
+ ## 📜 Citation
100
+
101
+ ```bibtex
102
+ @article{drag2024,
103
+ title={D-RAG: Differentiable Retrieval-Augmented Generation},
104
+ journal={arXiv preprint},
105
+ year={2024}
106
+ }
107
+ ```
108
+
109
+ ## 🔗 Links
110
+
111
+ - **GitHub:** https://github.com/rhordoan/drag-improved
112
+ - **Paper:** [D-RAG Paper](https://arxiv.org/...)
113
+ - **Datasets:** [RoG-CWQ](https://huggingface.co/datasets/rmanluo/RoG-cwq) | [RoG-WebQSP](https://huggingface.co/datasets/rmanluo/RoG-webqsp)
114
+
115
+ ## 📄 License
116
+
117
+ MIT License - See repository for details.
118
+