"""
IMPORTANT:
This script prepares data downloaded from the OneDrive link provided on the repo that introduced the dataset: https://github.com/TeaPearce/Counter-Strike_Behavioural_Cloning/
=> Any issue related to the download of this data should be reported on the dataset repo linked above (NOT on DIAMOND's repo)

This script should be called with exactly 2 positional arguments:

- <tar_dir>: folder containing the .tar files from `dataset_dm_scraped_dust2_tars` folder on the OneDrive
- <out_dir>: a new dir (should not exist already), the script will untar and process data there
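
Example invocation (the script filename and paths are illustrative, not prescribed):

    python process_tar_files.py highway_data_tar highway_dataset_preprocessed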
"""

import argparse
from functools import partial
from pathlib import Path
from multiprocessing import Pool
import shutil
import subprocess

import torch
import torchvision.transforms.functional as T
from tqdm import tqdm

from data.dataset import Dataset, CSGOHdf5Dataset
from data.episode import Episode
from data.segment import SegmentId


# PREFIX = "hdf5_dm_july2021_"
PREFIX = "hdf5_highwayenv_"
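# Tar filenames are expected to start with PREFIX; process_tar() below derives each
# output folder from the rest of the name, assumed to look like "<start>_to_<end>"
# (e.g. a hypothetical "hdf5_highwayenv_0_to_99.tar" would unpack into "0-99/").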


def parse_args():
    parser = argparse.ArgumentParser()
    # argparse ignores `default` on required positional arguments, so no defaults
    # are set here: both arguments must be provided explicitly.
    parser.add_argument(
        "tar_dir",
        type=Path,
        help="folder containing the .tar files from the `dataset_dm_scraped_dust2_tars` folder on the OneDrive",
    )
    parser.add_argument(
        "out_dir",
        type=Path,
        help="a new directory (must not already exist); the script will untar and process the data there",
    )
    return parser.parse_args()


def process_tar(path_tar: Path, out_dir: Path, remove_tar: bool) -> None:
    """Unpack one tar archive into its own subdirectory of `out_dir`."""
    assert path_tar.stem.startswith(PREFIX)
    d = out_dir / "-".join(path_tar.stem[len(PREFIX) :].split("_to_"))
    d.mkdir(exist_ok=False, parents=True)
    # Move the archive into its extraction directory, untar it in place, then
    # either delete it or move it back where it came from.
    shutil.move(path_tar, d)
    subprocess.run(["tar", "-xvf", path_tar.name], cwd=d, check=True)
    new_path_tar = d / path_tar.name
    if remove_tar:
        new_path_tar.unlink()
    else:
        shutil.move(new_path_tar, path_tar.parent)


def main():
    args = parse_args()

    tar_dir = args.tar_dir.absolute()
    out_dir = args.out_dir.absolute()

    if not tar_dir.exists():
        print(
            "Wrong usage: the tar directory should exist (and contain the downloaded .tar files)"
        )
        return

    if out_dir.exists():
        print(f"Wrong usage: the output directory should not exist ({args.out_dir})")
        return

    with Path("test_split_highway.txt").open("r") as f:
        test_files = f.read().split("\n")
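    # Assumed format: one episode filename per line; entries are matched against
    # `filename.name` when splitting train/test below.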

    full_res_dir = out_dir / "full_res"
    low_res_dir = out_dir / "low_res"

    tar_files = [
        x for x in tar_dir.iterdir() if x.suffix == ".tar" and x.stem.startswith(PREFIX)
    ]
    n = len(tar_files)

    if n == 0:
        print(f"No .tar files starting with '{PREFIX}' found in {tar_dir}")
        return

    if n < 28:
        answer = input(
            f"Found only {n} .tar files instead of the expected 28; it looks like the data has not been fully downloaded. Continue? [y|N] "
        )
        if answer.lower() != "y":
            return

    str_files = "\n".join(map(str, tar_files))
    print(f"Ready to untar {n} tar files:\n{str_files}")

    remove_tar = (
        input("Remove .tar files once they are processed? [y|N] ").lower() == "y"
    )

    # Untar the archives in parallel, one worker process per tar file
    f = partial(process_tar, out_dir=full_res_dir, remove_tar=remove_tar)
    with Pool(n) as p:
        p.map(f, tar_files)

    print(f"{n} .tar files unpacked in {full_res_dir}")

    #
    # Create low-res data
    #

    csgo_dataset = CSGOHdf5Dataset(full_res_dir)

    train_dataset = Dataset(low_res_dir / "train", None)
    test_dataset = Dataset(low_res_dir / "test", None)

    for i in tqdm(range(len(csgo_dataset._filenames)), desc="Creating low_res"):
        # Load the whole source file as a single 1000-step segment and rebuild an
        # Episode from it, dropping the segment-specific fields.
        episode = Episode(
            **{
                k: v
                for k, v in csgo_dataset[SegmentId(i, 0, 1000)].__dict__.items()
                if k not in ("mask_padding", "id")
            }
        )
        # Downscale observations to the low-res target (an earlier, commented-out
        # version of this call used (30, 56)).
        episode.obs = T.resize(
            episode.obs, (30, 120), interpolation=T.InterpolationMode.BICUBIC
        )
        filename = csgo_dataset._filenames[i]
        file_id = f"{filename.parent.stem}/{filename.name}"
        episode.info = {"original_file_id": file_id}
        dataset = test_dataset if filename.name in test_files else train_dataset
        dataset.add_episode(episode)

    train_dataset.save_to_default_path()
    test_dataset.save_to_default_path()

    print(
        f"Split train/test data ({train_dataset.num_episodes}/{test_dataset.num_episodes} episodes)\n"
    )

    print("You can now edit `config/env/csgo.yaml` and set:")
    print(f"path_data_low_res: {low_res_dir}")
    print(f"path_data_full_res: {full_res_dir}")


if __name__ == "__main__":
    main()