Gemma-3-1B-IT Math GRPO
SFT โ RS-SFT โ GRPO 3๋จ๊ณ ํ์ดํ๋ผ์ธ์ผ๋ก ํ์ตํ ํ๊ตญ์ด ์ํ ๋ชจ๋ธ.
์ฑ๋ฅ
| Benchmark | Score |
|---|---|
| HRM8K eval GSM8K (264๋ฌธ์ , Korean) | ~46.2% |
| HRM8K eval MATH (577๋ฌธ์ , Korean) | ~16.5% |
โ ๏ธ GRPO๋ base (RS-SFT, ~46.6%)๋๋น ์ ์๋ฏธํ ๊ฐ์ ์์. ์ด๋ฏธ SFT+RS-SFT๋ก ์ต์ ํ๋ 1B ๋ชจ๋ธ์์ RL ์ถ๊ฐ ์ฌ์ง๊ฐ ๊ทนํ ์ ํ์ .
๋ฐ์ดํฐ & ํ์ต ํ์ดํ๋ผ์ธ
Stage 1-2: SFT โ RS-SFT
์ RS-SFT ๋ชจ๋ธ๊ณผ ๋์ผ. (SFT ๊ต์ฌ ์ฆ๋ฅ โ RS ์ํ๋ง โ RS-SFT with 5x replay)
Stage 3: GRPO
Reward ํจ์
ํ์ต ๋ฐ์ดํฐ
- ํ๋กฌํํธ๋ง ํ์ (GRPO๋ ํ์ต ์ค ์์ฒด ์์ฑ)
- GSM8K train 6,871๊ฐ unique ํ๊ตญ์ด ๋ฌธ์
- ๊ฐ ๋ฌธ์ ์ ground truth ๋ต ํฌํจ (reward ๊ณ์ฐ์ฉ)
ํ์ต ์ค์
DPO vs GRPO ๋น๊ต (์คํ ๊ฒฐ๊ณผ)
DPO ์คํจ ๋ถ์ (10ํ)
| ๋ฐ์ดํฐ ์ ๋ต | GSM8K | ๋ถ์ |
|---|---|---|
| ๊ธฐ์กด ๋ฐฉ์ (shortest correct + longest incorrect) | 48.1% | ๊ธธ์ด ํธํฅ๋ง ํ์ต |
| Length-matched (55์ ์ฐจ์ด) | 46.2% | ์ ํธ ์์ (DPO accโ0.50) |
| Teacher-chosen (30B ๊ต์ฌ ํ์ด=chosen) | 47.3% | OOD ๋ฌธ์ |
| Multi-pair (์ง๋ฌธ๋น 3์, ๋์ด๋ ๊ฐ์ค) | 46.6% | ์ ์ฆ๊ฐ๋ ๋ฌดํจ |
| base ๋๋น | ยฑ0-2%p | ๋ชจ๋ variance ๋ฒ์ |
DPO ๊ทผ๋ณธ ๋ฌธ์ : 1B ๋ชจ๋ธ์ด ์ ๋ต/์ค๋ต ํ์ด์ ๋ฏธ๋ฌํ ์ฐจ์ด๋ฅผ ๋ด๋ถ์ ์ผ๋ก ๊ตฌ๋ถํ capacity ๋ถ์กฑ.
GRPO ๊ฒฐ๊ณผ (2ํ)
| beta | steps | GSM8K | ๋น๊ณ |
|---|---|---|---|
| 0.001 | 200 | 43.9% | format ํด๋ณด (boxedโ) |
| 0.04 | 500 | 46.2% | base์ ๋์ผ ์์ค |
ํ๊ฒฝ
- GPU: H100 NVL 95GB
- Framework: trl 0.29.0, transformers 4.57.3, vllm 0.11.0
- GRPO ํ์ต: ~55๋ถ (vLLM colocate ์ฌ์ฉ)
์ฌํ ๋ฐฉ๋ฒ
INFO 03-19 14:53:37 [init.py:216] Automatically detected platform cuda. [1;36m(APIServer pid=3429210)[0;0m INFO 03-19 14:53:43 [api_server.py:1839] vLLM API server version 0.11.0 [1;36m(APIServer pid=3429210)[0;0m INFO 03-19 14:53:43 [utils.py:233] non-default args: {'model_tag': './grpo_model', 'model': './grpo_model', 'dtype': 'bfloat16', 'max_model_len': 4096, 'gpu_memory_utilization': 0.85}
ํ์ผ
- : Stage 1 SFT
- : RS ์ํ๋ง
- : Stage 2 RS-SFT
- : Stage 3 GRPO
- : HRM8K ํ๊ฐ
- Downloads last month
- 67