SDPO
Self-Distillation Policy Optimization (SDPO) was introduced in Reinforcement Learning via Self-Distillation by Jonas Hübotter, Frederike Lübeck, Lejs Behric, Anton Baumann, Marco Bagatella, Daniel Marta, Ido Hakimi, Idan Shenfeld, Thomas Kleine Buening, Carlos Guestrin, and Andreas Krause.
Large language models are increasingly post-trained with reinforcement learning in verifiable domains such as code and math. Yet, current methods for reinforcement learning with verifiable rewards (RLVR) learn only from a scalar outcome reward per attempt, creating a severe credit-assignment bottleneck. Many verifiable environments actually provide rich textual feedback, such as runtime errors or judge evaluations, that explain why an attempt failed. We formalize this setting as reinforcement learning with rich feedback and introduce Self-Distillation Policy Optimization (SDPO), which converts tokenized feedback into a dense learning signal without any external teacher or explicit reward model. SDPO treats the current model conditioned on feedback as a self-teacher and distills its feedback-informed next-token predictions back into the policy. In this way, SDPO leverages the model’s ability to retrospectively identify its own mistakes in-context. Across scientific reasoning, tool use, and competitive programming on LiveCodeBench v6, SDPO improves sample efficiency and final accuracy over strong RLVR baselines. Notably, SDPO also outperforms baselines in standard RLVR environments that only return scalar feedback by using successful rollouts as implicit feedback for failed attempts. Finally, applying SDPO to individual questions at test time accelerates discovery on difficult binary-reward tasks, achieving the same discovery probability as best-of-k sampling or multi-turn conversations with 3x fewer attempts.
The SDPO trainer is built on TRL’s experimental shared self-distillation stack. It keeps the online rollout-and-reward training flow, then builds a teacher-conditioned view of the same completions from successful rollouts and optional environment feedback.
In the current TRL implementation:

- the default SDPO policy loss mode is `distillation_only`; a `hybrid` mode is also available to combine the base policy loss with the self-distillation loss
- supported teacher regularization modes are `ema` and `none`
- `distillation_topk` is only valid when `full_logit_distillation=True`
- when `full_logit_distillation=False`, SDPO uses token-level reverse KL and requires `distillation_alpha=1.0`
- environment feedback can be injected into teacher reprompts when the dataset exposes a `privileged_context` column
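The points above map directly onto `SDPOConfig` fields. As an illustrative sketch of the non-default `hybrid` mode (field values here are only for demonstration):

```python
from trl.experimental.sdpo import SDPOConfig

# Combine the base policy loss with the self-distillation loss,
# and regularize the teacher with an exponential moving average.
training_args = SDPOConfig(
    output_dir="sdpo-hybrid",
    sdpo_policy_loss_mode="hybrid",  # default is "distillation_only"
    teacher_regularization="ema",    # supported: "ema", "none"
    teacher_update_rate=0.05,
)
```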
Expected dataset columns
Each example must provide:

- `prompt`: the student-facing prompt
- `privileged_context`: optional privileged text, such as environment feedback, used when `include_environment_feedback=True`
Usage
```python
from datasets import Dataset
from trl.experimental.sdpo import SDPOConfig, SDPOTrainer

dataset = Dataset.from_dict(
    {
        "prompt": [[{"role": "user", "content": "Solve 2+2."}]],
        "privileged_context": ["Your earlier answer used the wrong format."],
    }
)

def reward_func(completions, **kwargs):
    # Toy verifier: full reward when the completion contains the answer.
    return [1.0 if "4" in str(completion) else 0.0 for completion in completions]

training_args = SDPOConfig(
    output_dir="sdpo-model",
    distillation_topk=100,  # Top-K logit distillation approximation
    full_logit_distillation=True,  # Required for top-K; enables non-reverse divergences
    include_environment_feedback=True,  # Use dataset privileged_context for teacher reprompts
)
trainer = SDPOTrainer(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    reward_funcs=reward_func,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```

SDPO always requires a `prompt` column. To use environment feedback, also include a `privileged_context` column and set `include_environment_feedback=True`. SDPO will use successful rollouts and, when enabled, that text to build teacher reprompts for self-distillation.
Callbacks
The trainer emits a small set of callback hooks that are useful for debugging, observability, and tests. These hooks are intended as practical integration points for experimental self-distillation workflows.
Shared self-distillation hooks:

- `on_self_distillation_batch_prepared`: fired when a self-distillation batch is ready. The payload includes `prompt_ids`, `completion_ids`, and `old_per_token_logps` when importance-sampling clipping inputs are available.
- `on_generation_batch_built`: fired when a new buffered generation batch is created. The payload includes `generate_every` and `steps_per_generation`.

SDPO-specific hook:

- `on_teacher_context_built`: fired after SDPO constructs the teacher-conditioned inputs. The payload includes `teacher_input_ids`, `teacher_attention_mask`, `completion_mask`, and `self_distillation_mask`.
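For debugging, a consumer of these hooks only needs one method per hook name. The sketch below is a plain-Python illustration that records which payload keys arrived; the method signatures are an assumption for illustration, not part of the documented API:

```python
class SDPODebugCallback:
    """Records the payload keys delivered to each self-distillation hook.

    Method names mirror the hook names above; the keyword-only payload
    signature is an assumption made for this sketch.
    """

    def __init__(self):
        self.seen = {}

    def on_teacher_context_built(self, **payload):
        self.seen["on_teacher_context_built"] = sorted(payload)

    def on_self_distillation_batch_prepared(self, **payload):
        self.seen["on_self_distillation_batch_prepared"] = sorted(payload)
```

Wiring such an object into the trainer (e.g. via the `callbacks` argument) follows the usual `transformers` callback conventions.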
SDPOConfig
class trl.experimental.sdpo.SDPOConfig
< source >( output_dir: str | None = None per_device_train_batch_size: int = 8 num_train_epochs: float = 3.0 max_steps: int = -1 learning_rate: float = 5e-05 lr_scheduler_type: transformers.trainer_utils.SchedulerType | str = 'linear' lr_scheduler_kwargs: dict | str | None = None warmup_steps: float = 0 optim: transformers.training_args.OptimizerNames | str = 'adamw_torch_fused' optim_args: str | None = None weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 optim_target_modules: None | str | list[str] = None gradient_accumulation_steps: int = 1 average_tokens_across_devices: bool = True max_grad_norm: float = 1.0 label_smoothing_factor: float = 0.0 bf16: bool | None = None fp16: bool = False bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: bool | None = None gradient_checkpointing: bool = True gradient_checkpointing_kwargs: dict[str, typing.Any] | str | None = None torch_compile: bool = False torch_compile_backend: str | None = None torch_compile_mode: str | None = None use_liger_kernel: bool = False liger_kernel_config: dict[str, bool] | None = None use_cache: bool = False neftune_noise_alpha: float | None = None torch_empty_cache_steps: int | None = None auto_find_batch_size: bool = False logging_strategy: transformers.trainer_utils.IntervalStrategy | str = 'steps' logging_steps: float = 10 logging_first_step: bool = False log_on_each_node: bool = True logging_nan_inf_filter: bool = True include_num_input_tokens_seen: str | bool = 'no' log_level: str = 'passive' log_level_replica: str = 'warning' disable_tqdm: bool | None = None report_to: None | str | list[str] = 'none' run_name: str | None = None project: str = 'huggingface' trackio_space_id: str | None = 'trackio' eval_strategy: transformers.trainer_utils.IntervalStrategy | str = 'no' eval_steps: float | None = None eval_delay: float = 0 per_device_eval_batch_size: int = 8 prediction_loss_only: bool = False eval_on_start: bool = False 
eval_do_concat_batches: bool = True eval_use_gather_object: bool = False eval_accumulation_steps: int | None = None include_for_metrics: list = <factory> batch_eval_metrics: bool = False save_only_model: bool = False save_strategy: transformers.trainer_utils.SaveStrategy | str = 'steps' save_steps: float = 500 save_on_each_node: bool = False save_total_limit: int | None = None enable_jit_checkpoint: bool = False push_to_hub: bool = False hub_token: str | None = None hub_private_repo: bool | None = None hub_model_id: str | None = None hub_strategy: transformers.trainer_utils.HubStrategy | str = 'every_save' hub_always_push: bool = False hub_revision: str | None = None load_best_model_at_end: bool = False metric_for_best_model: str | None = None greater_is_better: bool | None = None ignore_data_skip: bool = False restore_callback_states_from_checkpoint: bool = False full_determinism: bool = False seed: int = 42 data_seed: int | None = None use_cpu: bool = False accelerator_config: dict | str | None = None parallelism_config: accelerate.parallelism_config.ParallelismConfig | None = None dataloader_drop_last: bool = False dataloader_num_workers: int = 0 dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False dataloader_prefetch_factor: int | None = None remove_unused_columns: bool = False label_names: list[str] | None = None train_sampling_strategy: str = 'random' length_column_name: str = 'length' ddp_find_unused_parameters: bool | None = None ddp_bucket_cap_mb: int | None = None ddp_broadcast_buffers: bool | None = None ddp_backend: str | None = None ddp_timeout: int = 1800 fsdp: list[transformers.trainer_utils.FSDPOption] | str | None = None fsdp_config: dict[str, typing.Any] | str | None = None deepspeed: dict | str | None = None debug: str | list[transformers.debug_utils.DebugOption] = '' skip_memory_metrics: bool = True do_train: bool = False do_eval: bool = False do_predict: bool = False resume_from_checkpoint: str | None = None 
warmup_ratio: float | None = None logging_dir: str | None = None local_rank: int = -1 model_init_kwargs: dict[str, typing.Any] | None = None disable_dropout: bool = False max_prompt_length: int | None = 512 num_generations: int = 8 num_generations_eval: int | None = None max_completion_length: int | None = 256 ds3_gather_for_generation: bool = True shuffle_dataset: bool = True generation_batch_size: int | None = None steps_per_generation: int | None = None temperature: float = 1.0 top_p: float = 1.0 top_k: int = 0 min_p: float | None = None generation_kwargs: dict[str, typing.Any] | None = None chat_template_kwargs: dict[str, typing.Any] | None = None repetition_penalty: float = 1.0 use_transformers_paged: bool = False cache_implementation: str | None = None use_vllm: bool = False beta: float = 0.0 num_iterations: int = 1 epsilon: float = 0.2 epsilon_high: float | None = None importance_sampling_level: str = 'token' reward_weights: list[float] | None = None scale_rewards: str | bool = 'group' loss_type: str = 'dapo' mask_truncated_completions: bool = False sync_ref_model: bool = False ref_model_mixup_alpha: float = 0.6 ref_model_sync_steps: int = 512 top_entropy_quantile: float = 1.0 distillation_alpha: float = 1.0 distillation_topk: int | None = None full_logit_distillation: bool = False distillation_is_clip: float | None = 2.0 distillation_add_tail: bool = False distillation_weight: float = 1.0 diagnostics_warning_interval: int = 10 diagnostics_flat_tolerance: float = 1e-08 dont_reprompt_on_self_success: bool = True sdpo_policy_loss_mode: str = 'distillation_only' teacher_regularization: str = 'ema' teacher_update_rate: float | None = None ema_update_rate: float = 0.05 max_reprompt_len: int = 10240 use_successful_as_teacher: bool = True success_reward_threshold: float = 1.0 reprompt_template: str = '{prompt}{solution}{feedback}\n\nCorrectly solve the original question.\n' solution_template: str = '\nCorrect solution:\n\n{successful_previous_attempt}\n\n' 
feedback_template: str = '\nThe following is feedback from your unsuccessful earlier attempt:\n\n{feedback_raw}\n\n' include_environment_feedback: bool = False environment_feedback_only_without_solution: bool = False remove_thinking_from_demonstration: bool = False )
Parameters that control the SDPO loss
- `sdpo_policy_loss_mode` (`str`, optional, defaults to `"distillation_only"`) — How SDPO combines the online policy loss and self-distillation loss. Supported: `distillation_only`, `hybrid`.
- `distillation_alpha` (`float`, optional, defaults to `1.0`) — Divergence interpolation coefficient. Token-level SDPO requires the official reverse-KL setting `distillation_alpha=1.0`.
- `distillation_topk` (`int` or `None`, optional) — Top-k approximation for logit-level SDPO. Requires `full_logit_distillation=True`.
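When `full_logit_distillation=False`, the per-token objective is the reverse KL between the student's and the feedback-conditioned teacher's next-token distributions. The real implementation operates on logit tensors; the pure-Python sketch below only illustrates the divergence itself on toy probability vectors:

```python
import math

def reverse_kl(student_probs, teacher_probs):
    # KL(student || teacher): mode-seeking, it pulls the student toward
    # tokens the feedback-conditioned teacher considers likely.
    return sum(
        s * (math.log(s) - math.log(t))
        for s, t in zip(student_probs, teacher_probs)
        if s > 0.0
    )

# Identical distributions give zero divergence.
print(reverse_kl([0.5, 0.5], [0.5, 0.5]))  # → 0.0
```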
Parameters that control the teacher
- `teacher_regularization` (`str`, optional, defaults to `"ema"`) — Teacher update strategy. Supported: `ema`, `none`.
- `teacher_update_rate` (`float` or `None`, optional) — EMA update rate used when `teacher_regularization="ema"`.
- `ema_update_rate` (`float`, optional, defaults to `0.05`) — Deprecated alias for `teacher_update_rate`.
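With `teacher_regularization="ema"`, the teacher's weights slowly track the student's. A minimal sketch of the standard EMA update rule, with plain floats standing in for model parameters (the trainer's exact update is an assumption here):

```python
def ema_update(teacher, student, rate=0.05):
    # teacher <- (1 - rate) * teacher + rate * student
    # A small rate keeps the teacher a smoothed, lagging copy of the student.
    return [(1.0 - rate) * t + rate * s for t, s in zip(teacher, student)]

print(ema_update([0.0, 1.0], [1.0, 1.0], rate=0.5))  # → [0.5, 1.0]
```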
Parameters that control reprompting
- `use_successful_as_teacher` (`bool`, optional, defaults to `True`) — Whether successful rollouts are turned into teacher demonstrations.
- `success_reward_threshold` (`float`, optional, defaults to `1.0`) — Minimum reward for a rollout to count as successful.
- `include_environment_feedback` (`bool`, optional, defaults to `False`) — Whether `privileged_context` is injected into teacher reprompts when available.
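The reprompt itself is assembled from the `reprompt_template`, `solution_template`, and `feedback_template` strings whose defaults appear in the signature above. A rough sketch of how those defaults compose, where each section is filled only when the corresponding signal exists (the trainer's actual assembly logic may differ):

```python
# Default templates copied from the SDPOConfig signature.
REPROMPT = "{prompt}{solution}{feedback}\n\nCorrectly solve the original question.\n"
SOLUTION = "\nCorrect solution:\n\n{successful_previous_attempt}\n\n"
FEEDBACK = (
    "\nThe following is feedback from your unsuccessful earlier attempt:"
    "\n\n{feedback_raw}\n\n"
)

def build_reprompt(prompt, successful_attempt=None, feedback_raw=None):
    # A successful rollout becomes a demonstration; environment feedback
    # becomes a critique. Either (or both) may be absent.
    solution = (
        SOLUTION.format(successful_previous_attempt=successful_attempt)
        if successful_attempt
        else ""
    )
    feedback = FEEDBACK.format(feedback_raw=feedback_raw) if feedback_raw else ""
    return REPROMPT.format(prompt=prompt, solution=solution, feedback=feedback)

text = build_reprompt("Solve 2+2.", feedback_raw="Wrong format.")
```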
Configuration class for the SDPOTrainer.
This class extends experimental.self_distillation.SelfDistillationConfig with the online teacher-construction
parameters used by Self-Distillation Policy Optimization (SDPO).
SDPOTrainer
class trl.experimental.sdpo.SDPOTrainer
< source >( model: str | transformers.modeling_utils.PreTrainedModel | torch.nn.modules.module.Module reward_funcs: typing.Union[typing.Any, list[typing.Any], NoneType] = None args: trl.experimental.sdpo.sdpo_config.SDPOConfig | None = None train_dataset: datasets.arrow_dataset.Dataset | datasets.iterable_dataset.IterableDataset | None = None eval_dataset: datasets.arrow_dataset.Dataset | datasets.iterable_dataset.IterableDataset | dict[str, datasets.arrow_dataset.Dataset | datasets.iterable_dataset.IterableDataset] | None = None processing_class: transformers.tokenization_utils_base.PreTrainedTokenizerBase | transformers.processing_utils.ProcessorMixin | None = None reward_processing_classes: transformers.tokenization_utils_base.PreTrainedTokenizerBase | list[transformers.tokenization_utils_base.PreTrainedTokenizerBase] | None = None callbacks: list[transformers.trainer_callback.TrainerCallback] | None = None optimizers: tuple = (None, None) peft_config = None )
Trainer for Self-Distillation Policy Optimization (SDPO).
SDPO augments on-policy optimization with self-distillation from the model’s own high-reward trajectories. It converts tokenized feedback into a dense learning signal without any external teacher or explicit reward model. SDPO treats the current model conditioned on feedback as a self-teacher and distills its feedback-informed next-token predictions back into the policy.
train
< source >( resume_from_checkpoint: str | bool | None = None trial: optuna.Trial | dict[str, Any] | None = None ignore_keys_for_eval: list[str] | None = None ) → ~trainer_utils.TrainOutput
Parameters
- `resume_from_checkpoint` (`str` or `bool`, optional) — If a `str`, local path to a saved checkpoint as saved by a previous instance of `Trainer`. If a `bool` and equals `True`, load the last checkpoint in `args.output_dir` as saved by a previous instance of `Trainer`. If present, training will resume from the model/optimizer/scheduler states loaded here.
- `trial` (`optuna.Trial` or `dict[str, Any]`, optional) — The trial run or the hyperparameter dictionary for hyperparameter search.
- `ignore_keys_for_eval` (`list[str]`, optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during training.
Returns
~trainer_utils.TrainOutput
Object containing the global step count, training loss, and metrics.
Main training entry point.
Will save the model, so you can reload it using `from_pretrained()`.
Will only save from the main process.
push_to_hub
< source >( commit_message: str | None = 'End of training' blocking: bool = True token: str | None = None revision: str | None = None **kwargs )
Parameters
- `commit_message` (`str`, optional, defaults to `"End of training"`) — Message to commit while pushing.
- `blocking` (`bool`, optional, defaults to `True`) — Whether the function should return only when the `git push` has finished.
- `token` (`str`, optional, defaults to `None`) — Token with write permission to overwrite Trainer's original args.
- `revision` (`str`, optional) — The git revision to commit from. Defaults to the head of the "main" branch.
- `kwargs` (`dict[str, Any]`, optional) — Additional keyword arguments passed along to `~Trainer.create_model_card`.
Upload self.model and self.processing_class to the 🤗 model hub on the repo self.args.hub_model_id.