configs for reproducibility
Files changed:
- create_dataset.py (+90, -0)
- stage1/open-stage1.py (+3, -0)
- stage2/open-stage2.py (+3, -0)
- stage2/open-stage2.toml (+1, -1)
- stage3/open-stage3.py (+3, -0)
- stage3/open-stage3.toml (+1, -1)
create_dataset.py (ADDED)

@@ -0,0 +1,90 @@
+#!/usr/bin/env python
+import datasets
+import importlib.util
+import tqdm
+import transformers
+import typer
+
+def load_config(config_file: str):
+    spec = importlib.util.spec_from_file_location("config", config_file)
+    config_module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(config_module)
+    return config_module.sources, config_module.tokenizer_name, config_module.prefix
+
+def tokenize(batch: dict):
+    if tokenizer:
+        return {"num_tokens": tokenizer(batch["text"], padding="do_not_pad", return_length=True)["length"]}
+    return {"num_tokens": 0}
+
+def shard_indices(shard_index):
+    if not isinstance(shard_index, list):
+        shard_index = [shard_index]
+    return shard_index
+
+def preprocess_shard(ds: datasets.Dataset, num_shards: int, index: int, num_proc: int):
+    shard = ds.shard(num_shards=num_shards, index=index, contiguous=True)
+    shard = shard.flatten_indices()
+    shard = shard.map(tokenize, batched=True, batch_size=1000, num_proc=num_proc)
+    return shard
+
+def preprocess_subset(weights: dict, subsets: list, source: str, src_info: dict, dc: datasets.DownloadConfig, num_proc: int):
+    for key, frac in tqdm.tqdm(weights.items(), desc="Loading train subsets"):
+        uri_template = src_info["uri"]
+        print(f" Loading subset: {key} with fraction 1/{frac} from {uri_template.format(key=key)}")
+        ds = datasets.load_dataset(
+            src_info["format"],
+            data_files=uri_template.format(key=key),
+            split="train",
+            download_config=dc,
+        )
+        ds = ds.select_columns(["text"])
+        ds = ds.add_column("source", [source] * len(ds))
+        ds = ds.add_column("subset", [key] * len(ds))
+        ds = ds.shuffle(seed=42)
+        dss = [preprocess_shard(ds, int(src_info["shards"]/frac), i, num_proc) for i in shard_indices(src_info["shard_index"])]
+        ds = datasets.concatenate_datasets(dss)
+        ds = ds.cast_column("text", datasets.Value("large_string"))
+        print(f" Finished preprocessing subset: {key} with {sum(ds['num_tokens'])} tokens")
+        subsets.append(ds)
+
+def main(
+    config_file: str,
+    num_proc: int = 96,
+    max_retries: int = 10,
+):
+    sources, tokenizer_name, prefix = load_config(config_file)
+    global tokenizer
+    tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name) if tokenizer_name else None
+    dc = datasets.DownloadConfig(num_proc=num_proc, max_retries=max_retries)
+    train_subsets = []
+    test_subsets = []
+    file_name = f"{prefix}-"
+    for source, src_info in sources.items():
+        print(f"Processing source: {source}")
+        shard_index = src_info["shard_index"]
+        if not isinstance(shard_index, list):
+            shard_index = [shard_index]
+        file_name += f"{source}-{'_'.join(str(s) for s in shard_index)}-of-{src_info['shards']}-"
+        preprocess_subset(src_info["train"], train_subsets, source, src_info, dc, num_proc)
+        preprocess_subset(src_info["test"], test_subsets, source, src_info, dc, num_proc)
+    print("Concatenating train subsets")
+    final_train = datasets.concatenate_datasets(train_subsets)
+    print("Shuffling final train dataset")
+    final_train = final_train.shuffle(seed=42)
+    print("Flattening final train dataset")
+    final_train = final_train.flatten_indices()
+    print("Concatenating test subsets")
+    final_test = datasets.concatenate_datasets(test_subsets)
+    print("Shuffling final test dataset")
+    final_test = final_test.shuffle(seed=42)
+    print("Flattening final test dataset")
+    final_test = final_test.flatten_indices()
+    test_file = f"{file_name}test/{file_name}test.parquet"
+    print(f"Writing final test dataset with {sum(final_test['num_tokens'])} tokens to {test_file}")
+    final_test.to_parquet(test_file)
+    train_file = f"{file_name}train/{file_name}train.parquet"
+    print(f"Writing final train dataset with {sum(final_train['num_tokens'])} tokens to {train_file}")
+    final_train.to_parquet(train_file)
+
+if __name__ == "__main__":
+    typer.run(main)
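For context on how this script is driven: `typer.run(main)` turns `main` into a CLI, so it would be invoked roughly as `python create_dataset.py <config.py> --num-proc 96 --max-retries 10`, where the config is a plain Python module exposing `prefix`, `tokenizer_name`, and `sources`. The sketch below illustrates the shape `load_config` and `preprocess_subset` expect; the attribute names and per-source keys are taken from the script, but the source name, URI template, and weights are hypothetical placeholders rather than values from the stage configs.

```python
# Hypothetical config module for create_dataset.py. Only the attribute names
# (prefix, tokenizer_name, sources) and the per-source keys mirror what the
# script reads; every concrete value below is an illustrative placeholder.
prefix = "munin-open"
tokenizer_name = "common-pile/comma-v0.1-2t"

sources = {
    "dyna": {  # assumed source name
        "format": "parquet",  # first argument passed to datasets.load_dataset
        "uri": "/some/path/{key}/*.parquet",  # placeholder URI template, formatted per subset key
        "shards": 1,  # base shard count per subset (divided by the subset's weight)
        "shard_index": 0,  # which shard(s) this run keeps (int or list of ints)
        "train": {"adl": 1.0, "ai-aktindsigt": 1.0},  # subset -> fraction divisor (frac)
        "test": {"adl": 1.0},
    },
}
```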
stage1/open-stage1.py (CHANGED)

@@ -1,3 +1,6 @@
+prefix = "munin-open"
+tokenizer_name = "common-pile/comma-v0.1-2t"
+
 dyna_train = {
     "adl": 1.0,
     "ai-aktindsigt": 1.0,
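The per-subset weights in `dyna_train` presumably end up as the `frac` values that `preprocess_subset` iterates over in create_dataset.py, where each subset is split into `int(src_info["shards"] / frac)` contiguous shards and only the configured `shard_index` is kept. A minimal sketch of that arithmetic, with an assumed base shard count of 16:

```python
# Illustration of the sampling arithmetic in preprocess_subset()/preprocess_shard().
# The base shard count of 16 is an assumed example, not read from any config shown here.
shards = 16
for frac in (1.0, 2.0, 4.0):
    num_shards = int(shards / frac)
    # Each configured shard_index keeps one contiguous shard of the shuffled subset,
    # i.e. a 1/num_shards slice of it.
    print(f"frac={frac}: split into {num_shards} shards, each covering {1 / num_shards:.1%} of the subset")
```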
stage2/open-stage2.py (CHANGED)

@@ -1,3 +1,6 @@
+prefix = "munin-open"
+tokenizer_name = "common-pile/comma-v0.1-2t"
+
 dyna_train = {
     "adl": 1.0,
     "ai-aktindsigt": 1.0,
stage2/open-stage2.toml (CHANGED)

@@ -32,7 +32,7 @@ selective_ac_option = "op"
 bos_token = 2
 eos_token = 1
 data_dirs = [
-    "/work/production/data/
+    "/work/production/data/munin-open-dyna-0-of-1-cp-1-of-16-train/",
 ]
 dataset_weights = "1.0"
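The directory now listed in `data_dirs` follows the naming scheme that `main` builds in create_dataset.py: `{prefix}-` followed by `{source}-{shard_index}-of-{shards}-` for each source, then `train/`. The sketch below reconstructs the stage 2 path under the assumption that the configured sources are named "dyna" (1 shard, index 0) and "cp" (16 shards, index 1), which is inferred from the path itself rather than from any config shown here.

```python
# Reconstructing the stage 2 data_dirs entry with the naming logic from
# create_dataset.py's main(). The prefix matches the stage configs; the source
# names and shard settings are assumptions inferred from the path.
prefix = "munin-open"
source_shards = {"dyna": (0, 1), "cp": (1, 16)}  # source -> (shard_index, shards)

file_name = f"{prefix}-"
for source, (shard_index, shards) in source_shards.items():
    file_name += f"{source}-{shard_index}-of-{shards}-"

print("/work/production/data/" + file_name + "train/")
# -> /work/production/data/munin-open-dyna-0-of-1-cp-1-of-16-train/
```

The stage 3 TOML below points at the analogous cp-2-of-16 directory, i.e. the next `cp` shard index.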
stage3/open-stage3.py (CHANGED)

@@ -1,3 +1,6 @@
+prefix = "munin-open"
+tokenizer_name = "common-pile/comma-v0.1-2t"
+
 dyna_train = {
     "adl": 1.0,
     "ai-aktindsigt": 1.0,
stage3/open-stage3.toml (CHANGED)

@@ -32,7 +32,7 @@ selective_ac_option = "op"
 bos_token = 2
 eos_token = 1
 data_dirs = [
-    "/work/production/data/
+    "/work/production/data/munin-open-dyna-0-of-1-cp-2-of-16-train/",
 ]
 dataset_weights = "1.0"