# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from collections.abc import AsyncIterable, Iterable
from typing import Any, Union

from datasets import load_dataset
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from torch.utils.data import Dataset

from ..config.data_args import DataArguments
from ..extras.types import DatasetInfo, HFDataset, Sample


class DataEngine(Dataset):
    """Data engine that loads, indexes, and serves samples from one or more datasets."""

    def __init__(self, data_args: DataArguments) -> None:
self.args = data_args
"""Data arguments."""
self.datasets: dict[str, HFDataset] = {}
"""Dict of (dataset_name, dataset)"""
self.dataset_infos: dict[str, DatasetInfo] = {}
"""Dict of (dataset_name, dataset_info)"""
self.data_index: list[tuple[str, int]] = []
"""List of (dataset_name, sample_index)"""
self.streaming: bool = False
"""Whether dataset is streaming."""
self.get_dataset_info()
self.load_dataset()
self.build_data_index()

    def get_dataset_info(self) -> None:
"""Get dataset info from data arguments."""
if self.args.dataset.endswith(".yaml") and os.path.isfile(
os.path.join(self.args.dataset_dir, self.args.dataset)
): # local file
self.dataset_infos = OmegaConf.load(os.path.join(self.args.dataset_dir, self.args.dataset))
elif self.args.dataset.endswith(".yaml"): # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
repo_id, filename = os.path.split(self.args.dataset)
filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
self.dataset_infos = OmegaConf.load(filepath)
elif os.path.exists(os.path.join(self.args.dataset_dir, self.args.dataset)): # local file(s)
self.dataset_infos = {"default": {"file_name": self.args.dataset}}
else: # hf hub dataset, e.g. llamafactory/v1-sft-demo
self.dataset_infos = {"default": {"hf_hub_url": self.args.dataset}}

    def load_dataset(self) -> None:
"""Load datasets according to dataset info."""
for key, value in self.dataset_infos.items():
split = value.get("split", "train")
streaming = value.get("streaming", False)
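            # A single streaming dataset switches the whole engine to streaming mode.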
self.streaming |= streaming
if "hf_hub_url" in value:
self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
else: # data loader plugin
from ..plugins.data_plugins.loader import DataLoaderPlugin
self.datasets[key] = DataLoaderPlugin(args=self.args).auto_load_data(value)

    def build_data_index(self) -> None:
"""Build dataset index."""
for dataset_name, dataset in self.datasets.items():
size = self.dataset_infos[dataset_name].get("size")
weight = self.dataset_infos[dataset_name].get("weight")
if self.streaming:
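                # Streaming data has no known length or random access; reserve
                # a fixed number of placeholder entries with -1 as the sample index.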
data_index = [(dataset_name, -1) for _ in range(1000)]
else:
data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
if size or weight: # data index plugin
from ..plugins.data_plugins.loader import DataIndexPlugin
data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)
self.data_index.extend(data_index)

    def _convert_data_sample(self, raw_sample: dict[str, Any], dataset_name: str) -> Sample:
        """Convert a raw dataset sample to the engine's sample format.

        Args:
            raw_sample (dict[str, Any]): Raw dataset sample.
            dataset_name (str): Dataset name.

        Returns:
            Sample: Converted sample, tagged with its dataset name.
        """
converter = self.dataset_infos[dataset_name].get("converter")
if converter is not None:
from ..plugins.data_plugins.converter import get_converter
return {"_dataset_name": dataset_name, **get_converter(converter)(raw_sample)}
else:
return {"_dataset_name": dataset_name, **raw_sample}

    def __len__(self) -> int:
        """Get dataset length.

        Returns:
            int: Number of indexed samples, or -1 if the dataset is streaming.
        """
if self.streaming:
return -1
else:
return len(self.data_index)

    def __getitem__(self, index: Union[int, Any]) -> Union[Sample, list[Sample]]:
        """Get dataset item(s).

        Args:
            index (Union[int, Any]): Integer index, or any selector understood
                by the data selector plugin.

        Returns:
            Union[Sample, list[Sample]]: A single sample, or a list of samples
                when the selector resolves to multiple indices.
        """
if self.streaming:
raise ValueError("Streaming dataset does not support index access.")
if isinstance(index, int):
dataset_name, sample_index = self.data_index[index]
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
else:
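            # Non-integer indices are delegated to the selector plugin, which may
            # resolve to one (dataset_name, sample_index) pair or a list of pairs.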
from ..plugins.data_plugins.loader import DataSelectorPlugin
selected_index = DataSelectorPlugin(data_index=self.data_index).select(index)
if isinstance(selected_index, list):
return [
self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
for dataset_name, sample_index in selected_index
]
else:
dataset_name, sample_index = selected_index
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)

    def __iter__(self) -> Iterable:
        """Get dataset iterator.

        Returns:
            Iterable: Dataset iterator.
        """
if self.streaming:
pass
else:
# TODO: add shuffle here
pass
raise NotImplementedError()

    def __aiter__(self) -> AsyncIterable:
        """Get dataset async iterator.

        Returns:
            AsyncIterable: Dataset async iterator.
        """
if self.streaming:
pass
else:
# TODO: add shuffle here
pass
raise NotImplementedError()
if __name__ == "__main__":
from ..config.parser import get_args
data_args, *_ = get_args()
data_engine = DataEngine(data_args=data_args)
print(data_engine[0])
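
    # A minimal usage sketch: report the indexed length. This is only
    # meaningful for map-style data, since __len__ returns -1 when streaming.
    if not data_engine.streaming:
        print(f"total samples: {len(data_engine)}")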