Gemma-3-1B-IT Math GRPO

A Korean-language math model trained with a three-stage SFT → RS-SFT → GRPO pipeline.

Performance

| Benchmark | Score |
| --- | --- |
| HRM8K eval GSM8K (264 problems, Korean) | ~46.2% |
| HRM8K eval MATH (577 problems, Korean) | ~16.5% |

โš ๏ธ GRPO๋Š” base (RS-SFT, ~46.6%)๋Œ€๋น„ ์œ ์˜๋ฏธํ•œ ๊ฐœ์„  ์—†์Œ. ์ด๋ฏธ SFT+RS-SFT๋กœ ์ตœ์ ํ™”๋œ 1B ๋ชจ๋ธ์—์„œ RL ์ถ”๊ฐ€ ์—ฌ์ง€๊ฐ€ ๊ทนํžˆ ์ œํ•œ์ .

๋ฐ์ดํ„ฐ & ํ•™์Šต ํŒŒ์ดํ”„๋ผ์ธ

Stage 1-2: SFT → RS-SFT

Identical to the RS-SFT model above (SFT teacher distillation → RS sampling → RS-SFT with 5x replay).

Stage 3: GRPO

Reward function

ํ•™์Šต ๋ฐ์ดํ„ฐ

  • Only prompts are required (GRPO generates its own completions during training)
  • 6,871 unique Korean problems from the GSM8K train split
  • Each problem includes the ground-truth answer (used for reward computation)

Training configuration
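The exact hyperparameters are not listed in this card. A wiring sketch with trl's `GRPOTrainer`; every value is an illustrative guess except `beta` (0.001 and 0.04 were the two settings actually tried) and colocated vLLM generation (mentioned under Environment):

```python
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

def reward_exact(completions, ground_truth, **kwargs):
    # trl forwards extra dataset columns (here, ground_truth) to reward
    # functions; return one score per completion. Toy check only.
    return [1.0 if gt in c else 0.0 for c, gt in zip(completions, ground_truth)]

# Prompts plus gold answers; GRPO samples its own completions.
train_dataset = Dataset.from_list(
    [{"prompt": "What is 12 + 7?", "ground_truth": "19"}]
)

config = GRPOConfig(
    output_dir="grpo_model",
    beta=0.04,                   # KL coefficient; 0.001 and 0.04 were tried
    num_generations=8,           # group size per prompt (illustrative)
    max_completion_length=1024,  # illustrative
    bf16=True,
    use_vllm=True,
    vllm_mode="colocate",        # vLLM colocate, as in the ~55 min run
)

trainer = GRPOTrainer(
    model="./rs_sft_model",      # hypothetical path to the Stage 2 checkpoint
    reward_funcs=reward_exact,
    args=config,
    train_dataset=train_dataset,
)
trainer.train()
```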

DPO vs GRPO Comparison (Experimental Results)

DPO failure analysis (10 runs)

๋ฐ์ดํ„ฐ ์ „๋žต GSM8K ๋ถ„์„
๊ธฐ์กด ๋ฐฉ์‹ (shortest correct + longest incorrect) 48.1% ๊ธธ์ด ํŽธํ–ฅ๋งŒ ํ•™์Šต
Length-matched (55์ž ์ฐจ์ด) 46.2% ์‹ ํ˜ธ ์—†์Œ (DPO accโ‰ˆ0.50)
Teacher-chosen (30B ๊ต์‚ฌ ํ’€์ด=chosen) 47.3% OOD ๋ฌธ์ œ
Multi-pair (์งˆ๋ฌธ๋‹น 3์Œ, ๋‚œ์ด๋„ ๊ฐ€์ค‘) 46.6% ์–‘ ์ฆ๊ฐ€๋„ ๋ฌดํšจ
base ๋Œ€๋น„ ยฑ0-2%p ๋ชจ๋‘ variance ๋ฒ”์œ„

Fundamental DPO problem: a 1B model lacks the capacity to internally distinguish the subtle differences between correct and incorrect solutions.
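The length-matched strategy above can be sketched as a simple pair filter: given correct and incorrect solutions for a question, keep only (chosen, rejected) pairs whose lengths differ by at most a threshold, so the model cannot win on length alone (function name and threshold default are illustrative; the experiment used a 55-character bound):

```python
def length_matched_pairs(correct, incorrect, max_diff=55):
    """Build DPO (chosen, rejected) pairs whose character lengths differ
    by at most max_diff, removing the length cue DPO otherwise latches onto."""
    pairs = []
    for c in correct:
        for r in incorrect:
            if abs(len(c) - len(r)) <= max_diff:
                pairs.append({"chosen": c, "rejected": r})
    return pairs

# Toy example: same-length correct/incorrect solutions survive the filter
pairs = length_matched_pairs(
    correct=["... so the answer is 12."],
    incorrect=["... so the answer is 15."],
)
```

As the table shows, even with this cue removed the 1B model extracted no preference signal (DPO accuracy ≈ 0.50).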

GRPO results (2 runs)

| beta | steps | GSM8K | Notes |
| --- | --- | --- | --- |
| 0.001 | 200 | 43.9% | format regression (boxed↓) |
| 0.04 | 500 | 46.2% | on par with base |

Environment

  • GPU: H100 NVL 95GB
  • Framework: trl 0.29.0, transformers 4.57.3, vllm 0.11.0
  • GRPO training: ~55 min (with vLLM colocate)

Reproduction

```
INFO 03-19 14:53:37 [__init__.py:216] Automatically detected platform cuda.
(APIServer pid=3429210) INFO 03-19 14:53:43 [api_server.py:1839] vLLM API server version 0.11.0
(APIServer pid=3429210) INFO 03-19 14:53:43 [utils.py:233] non-default args: {'model_tag': './grpo_model', 'model': './grpo_model', 'dtype': 'bfloat16', 'max_model_len': 4096, 'gpu_memory_utilization': 0.85}
```
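The startup log above corresponds to serving the GRPO checkpoint with vLLM; the command below is reconstructed from the non-default args the log reports (standard `vllm serve` flags, assuming the checkpoint sits at `./grpo_model` as logged):

```shell
vllm serve ./grpo_model \
  --dtype bfloat16 \
  --max-model-len 4096 \
  --gpu-memory-utilization 0.85
```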

ํŒŒ์ผ

  • : Stage 1 SFT
  • : RS sampling
  • : Stage 2 RS-SFT
  • : Stage 3 GRPO
  • : HRM8K evaluation