Vertex AI training and gated model access

Oh… It seems like you have to explicitly pass the YAML path as an argument… :sweat_smile:


Below is a focused, Vertex-AI-specific view of how TRL + YAML config works, what can go wrong, and how to set it up cleanly.


1. Architecture: what actually runs in “Vertex + TRL”

In the standard pattern that Google and Hugging Face document:

  1. You use a Hugging Face PyTorch Training DLC as the container image.

  2. On Vertex AI, you create a CustomContainerTrainingJob whose command runs the TRL CLI, typically:

    command=[
        "sh",
        "-c",
        'exec trl sft "$@"',
        "--",
    ]
    

    so Vertex executes trl sft inside the DLC; the trailing "--" is consumed as $0, so the job's args are forwarded verbatim to trl sft via "$@". (Hugging Face)

  3. You pass TRL CLI arguments via the job’s args (e.g. --config=..., --model_name_or_path=...).

  4. Vertex automatically mounts your GCS bucket into the container at /gcs/<BUCKET>. (Hugging Face Forums)

From TRL’s perspective, once the container starts, this is just:

trl sft --config /gcs/my-bucket/configs/sft_config.yaml

running on a Linux machine. Everything about YAML parsing, dataset handling, etc., is exactly the same as on your local machine.

So the key pieces are:

  • The Python that launches the Vertex job
  • The YAML config consumed by TRL
  • The environment variables (HF token, HF cache, etc.)

2. How the YAML is parsed by TRL inside the Vertex container

The TRL CLI (trl sft, trl dpo, etc.) uses TrlParser, a thin extension of HfArgumentParser. (Hugging Face)

Mechanism:

  1. You call:

    trl sft --config /path/to/sft_config.yaml
    
  2. Internally:

    • TrlParser.parse_args_and_config():

      • Loads the YAML file.
      • Applies any env: block at the top level (sets environment variables).
      • Maps other top-level keys into dataclasses like SFTConfig, ScriptArguments, etc.
    • CLI flags (like --num_train_epochs) override values from the YAML. (Hugging Face)

A minimal example from the docs, with the YAML and the script that consumes it shown separately:

# config.yaml
env:
  VAR1: value1
arg1: 23
arg2: alpha

# main.py
from trl import TrlParser

# MyArguments is your dataclass of expected config fields
parser = TrlParser(dataclass_types=[MyArguments])
training_args = parser.parse_args_and_config()
# python main.py --config config.yaml  -> arg1=23, arg2='alpha', VAR1 set in env

If you call python main.py --arg1 5 --arg2 beta, the CLI args override the YAML. (Hugging Face)

So:

  • YAML is authoritative unless overridden on the CLI.
  • All keys must match known dataclass fields (model_name_or_path, dataset_name, datasets, etc.).
  • env: must be at the top level.

This behaviour is identical in Vertex and locally.


3. Typical “Vertex + TRL + YAML” layout

3.1. YAML config on GCS

You place your config in GCS, for example:

gs://my-bucket/configs/sft_config.yaml

Inside the Vertex container, this is:

/gcs/my-bucket/configs/sft_config.yaml

You then reference this path in your job’s args:

args = [
    "--config=/gcs/my-bucket/configs/sft_config.yaml",
]
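
If you keep paths as gs:// URIs elsewhere in your launcher, a tiny helper (hypothetical, not part of TRL or the Vertex SDK) keeps the two forms straight:

def gcs_to_fuse(uri: str) -> str:
    """Map gs://bucket/path to the /gcs/bucket/path mount seen inside the container."""
    assert uri.startswith("gs://"), f"not a GCS URI: {uri}"
    return "/gcs/" + uri[len("gs://"):]

args = [f"--config={gcs_to_fuse('gs://my-bucket/configs/sft_config.yaml')}"]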

3.2. Example config for Vertex + TRL

A typical sft_config.yaml for Gemma/TxGemma on Vertex looks like:

# ---------------------------
# 1. Environment variables
# ---------------------------
env:
  HF_HOME: /root/.cache/huggingface

# You *can* put HF_TOKEN here, but on Vertex it is safer
# to inject it from the job's environment_variables (see below).

# ---------------------------
# 2. Model + training params
# ---------------------------
model_name_or_path: google/txgemma-2b-predict
output_dir: /gcs/my-bucket/outputs/txgemma-herg
overwrite_output_dir: true

max_seq_length: 1024
per_device_train_batch_size: 2
per_device_eval_batch_size: 2
gradient_accumulation_steps: 8
num_train_epochs: 3
learning_rate: 5e-5
warmup_ratio: 0.05
weight_decay: 0.01
bf16: true

# ---------------------------
# 3. LoRA / PEFT
# ---------------------------
use_peft: true
lora_r: 8
lora_alpha: 16
lora_dropout: 0.1
lora_target_modules:
  - q_proj
  - v_proj
  - o_proj

# ---------------------------
# 4. Dataset (local/GCS)
# ---------------------------
datasets:
  - path: json
    data_files:
      train: /gcs/my-bucket/drug-herg/train.jsonl
      validation: /gcs/my-bucket/drug-herg/eval.jsonl
    split: train
    columns: [prompt, completion]

dataset_name: null       # ignored when datasets: is present
dataset_text_field: null # ignored

# ---------------------------
# 5. SFT-specific options
# ---------------------------
completion_only_loss: true

This structure matches TRL’s script utilities docs: datasets: is a mixture, with each entry mapping to datasets.load_dataset(path, data_files, ...). (Hugging Face)

On Vertex, GCS paths under /gcs/... look like local files to datasets.load_dataset. (Hugging Face Forums)
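
Conceptually (a sketch of the mapping, not TRL's exact internals), the datasets: entry above resolves to something like:

from datasets import load_dataset

# roughly what the mixture entry in the YAML does
train = load_dataset(
    "json",
    data_files={
        "train": "/gcs/my-bucket/drug-herg/train.jsonl",
        "validation": "/gcs/my-bucket/drug-herg/eval.jsonl",
    },
    split="train",
).select_columns(["prompt", "completion"])

Running that snippet locally against copies of the files is a quick way to validate the dataset section before submitting a job.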


4. Vertex job definition: how YAML and environment tie together

In the Vertex notebook/script that launches training, you typically do:

from google.cloud import aiplatform
import os

aiplatform.init(
    project=os.getenv("PROJECT_ID"),
    location=os.getenv("LOCATION"),
    staging_bucket=os.getenv("BUCKET_URI"),
)

HF_TOKEN = "hf_..."  # ideally loaded from a secret or env var

job = aiplatform.CustomContainerTrainingJob(
    display_name="txgemma-2b-sft",
    container_uri=os.getenv("CONTAINER_URI"),  # HF PyTorch Training DLC
    command=[
        "sh",
        "-c",
        'exec trl sft "$@"',
        "--",
    ],
)

args = [
    "--config=/gcs/my-bucket/configs/sft_config.yaml",
]

job.run(
    args=args,
    replica_count=1,
    machine_type="g2-standard-12",
    accelerator_type="NVIDIA_L4",
    accelerator_count=1,
    base_output_dir=f"{os.getenv('BUCKET_URI')}/outputs/txgemma-2b-sft",
    environment_variables={
        "HF_TOKEN": HF_TOKEN,
        "HF_HOME": "/root/.cache/huggingface",
        "TRL_USE_RICH": "0",
        "ACCELERATE_LOG_LEVEL": "INFO",
        "TRANSFORMERS_LOG_LEVEL": "INFO",
    },
)

This pattern matches the official Mistral/Gemma TRL examples almost exactly; they simply swap in the appropriate model ID and dataset. (Hugging Face)

Key points:

  • YAML is loaded inside the container when trl sft --config ... runs.
  • environment_variables are set by Vertex before your command runs.
  • HF_TOKEN in environment_variables is the recommended way to authenticate to gated Hugging Face models from Vertex. (Hugging Face)

If the YAML also has an env: block, those variables are set in addition to whatever Vertex provided. In practice, prefer injecting secrets like HF_TOKEN via Vertex and reserve the YAML env: block for non-sensitive defaults.


5. Common “Vertex + TRL + YAML” failure modes

Given you suspect a parsing issue, here are the most common pitfalls in this exact setup.

5.1. Wrong path for --config in Vertex

  • You must pass the path as seen inside the container.
  • gs://... is not a valid local path inside the container; inside, it becomes /gcs/<BUCKET>/.... (Hugging Face Forums)
  • If you pass --config=gs://my-bucket/..., TRL cannot open the file; depending on the version this surfaces as a file-not-found error or as the config being silently ignored, with the run falling back to CLI/default values.

Check that:

args = ["--config=/gcs/my-bucket/configs/sft_config.yaml"]

and that this file exists (you can test with a tiny debug job that runs ls -R /gcs/my-bucket).
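
A minimal sketch of such a debug job, reusing the aiplatform/os setup from section 4 (display name and machine type here are placeholders):

debug_job = aiplatform.CustomContainerTrainingJob(
    display_name="debug-ls-gcs",
    container_uri=os.getenv("CONTAINER_URI"),
    # just list the mounted bucket, then exit
    command=["sh", "-c", "ls -R /gcs/my-bucket"],
)
debug_job.run(replica_count=1, machine_type="n1-standard-4")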

5.2. YAML keys not matching TRL’s expected names

For your version of TRL, the script utilities docs at that version are the source of truth. (Hugging Face)

Common mistakes:

  • Using train_dataset instead of datasets / dataset_name.
  • Misspelling keys like completion_only_loss, per_device_train_batch_size.
  • Nesting keys under sub-dicts that TRL doesn’t know (e.g. training: {max_seq_length: 1024} instead of max_seq_length: 1024 at top level).

If a key is unknown, TrlParser may:

  • Throw an “unrecognized argument” error (best case), or
  • Simply not set that field, leaving the default (worst case, more confusing).
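
One way to catch unknown keys before submitting a job is to diff the YAML against the TRL dataclasses locally, using the same TRL version as the container. A sketch (env: and the datasets: mixture are handled specially by TrlParser, so they are excluded here):

import dataclasses
import yaml  # pip install pyyaml
from trl import ModelConfig, ScriptArguments, SFTConfig

with open("sft_config.yaml") as f:
    cfg = yaml.safe_load(f)

# collect every field name TRL's CLI dataclasses know about
known = set()
for cls in (ScriptArguments, SFTConfig, ModelConfig):
    known |= {f.name for f in dataclasses.fields(cls)}

unknown = set(cfg) - known - {"env", "datasets"}
print("unknown keys:", sorted(unknown) or "none")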

5.3. Mixing datasets: and dataset_name incorrectly

In current TRL, when you provide datasets: (mixture), it is supposed to be used instead of dataset_name. (Hugging Face)

However, some earlier CLI versions had issues when:

  • YAML defined datasets:, but
  • The CLI still had --dataset_name or default assumptions, leading to errors like “the following arguments are required: --dataset_name”.

Safer patterns:

  • If you use datasets: in YAML, do not pass --dataset_name on the CLI.
  • In YAML, explicitly set dataset_name: null (or omit it), and rely only on datasets:.

5.4. env: block mis-indented or unused

To have TRL set env vars from YAML, env: must be at the top level. (Hugging Face)

Correct:

env:
  HF_HOME: /root/.cache/huggingface
  SOME_FLAG: "1"
model_name_or_path: ...

Incorrect (and ignored):

training:
  env:
    HF_HOME: /root/.cache/huggingface

For HF_TOKEN, it is usually better to inject it via Vertex environment_variables rather than YAML, because:

  • You can use secrets / runtime injection.
  • You do not need to put secrets into a file stored in GCS.
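
For example, the launcher can fetch the token at submit time instead of hardcoding it. A sketch assuming the google-cloud-secret-manager client and a secret named hf-token (both assumptions, not part of the TRL/Vertex examples):

import os
from google.cloud import secretmanager

# hypothetical secret id: hf-token
client = secretmanager.SecretManagerServiceClient()
secret_name = f"projects/{os.getenv('PROJECT_ID')}/secrets/hf-token/versions/latest"
HF_TOKEN = client.access_secret_version(request={"name": secret_name}).payload.data.decode("utf-8")
# pass HF_TOKEN via environment_variables, as in section 4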

5.5. TRL version mismatch inside the DLC

It’s easy to end up with:

  • HF PyTorch DLC’s pre-installed TRL version, plus
  • Another TRL version pulled at runtime (pip install trl -U in a startup script).

If you follow docs for TRL 0.15 but the container uses 0.8 or a dev version, YAML keys or semantics may differ. Issues like “TrlParser not working with --config” have been reported around version changes. (GitHub)

Check inside the container (in a small debug job):

python -c "import trl, inspect; print('TRL version:', trl.__version__)"

and then consult the docs for that exact version.
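
If the version is not what you expected, one option is to pin TRL in the job command itself. A sketch: <known-good-version> is a placeholder for a release you have actually tested against the docs you are following:

command=[
    "sh",
    "-c",
    # pin TRL before launching so the DLC's pre-installed version cannot drift
    'pip install --quiet "trl==<known-good-version>" && exec trl sft "$@"',
    "--",
]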


6. Practical debugging loop for your case

Given where you are now, a targeted debugging sequence in Vertex would be:

  1. Debug job 1: Verify config path and TRL version

    Submit a tiny job whose command is:

    sh -c 'ls -R /gcs && python -c "import trl; print(trl.__version__)"'
    

    to ensure:

    • /gcs/my-bucket/configs/sft_config.yaml exists.
    • You know the TRL version.
  2. Debug job 2: Minimal TRL + YAML

    Put a minimal YAML on GCS:

    model_name_or_path: google/txgemma-2b-predict
    dataset_name: stanfordnlp/imdb
    output_dir: /gcs/my-bucket/debug-out
    num_train_epochs: 1
    per_device_train_batch_size: 1
    

    Run:

    args = ["--config=/gcs/my-bucket/configs/minimal.yaml"]
    

    and confirm the job starts and uses those values. If this works, you know YAML parsing is fine in principle.

  3. Debug job 3: Add datasets: and LoRA step by step

    Incrementally extend the YAML with:

    • datasets: with a local JSONL dataset.
    • LoRA settings.
    • completion_only_loss.

    If at any step something breaks, you can isolate exactly which section is mis-specified.


Summary for “Vertex + TRL” specifically

  • In Vertex custom training with HF PyTorch DLC, trl sft runs inside the container exactly as it would locally.

  • You pass a YAML config via --config=/gcs/<BUCKET>/.../sft_config.yaml.

  • TRL’s TrlParser reads that YAML, sets env: variables, and maps keys into SFTConfig etc. CLI args override YAML. (Hugging Face)

  • Common Vertex-specific pitfalls are:

    • Using gs://... instead of /gcs/... for --config and data_files. (Hugging Face Forums)
    • Mixing dataset_name and datasets: incorrectly.
    • Mis-indented or wrong YAML keys.
    • TRL version mismatch inside the DLC.
  • For gated models (TxGemma), pass your HF token via Vertex environment_variables (HF_TOKEN) and let TRL/Transformers pick it up in the container.