""" IMPORTANT: This script prepares data downloaded from the OneDrive link provided on the repo that introduced the dataset: https://github.com/TeaPearce/Counter-Strike_Behavioural_Cloning/ => Any issue related to the download of this data should be reported on the dataset repo linked above (NOT on DIAMOND's repo) This script should be called with exactly 2 positional arguments: - : folder containing the .tar files from `dataset_dm_scraped_dust2_tars` folder on the OneDrive - : a new dir (should not exist already), the script will untar and process data there """ import argparse from functools import partial from pathlib import Path from multiprocessing import Pool import shutil import subprocess import torch import torchvision.transforms.functional as T from tqdm import tqdm from data.dataset import Dataset, CSGOHdf5Dataset from data.episode import Episode from data.segment import SegmentId # PREFIX = "hdf5_dm_july2021_" PREFIX = "hdf5_highwayenv_" def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( "tar_dir", type=Path, default="highway_data_tar", help="folder containing the .tar files from `dataset_dm_scraped_dust2_tars` folder on the OneDrive", ) parser.add_argument( "out_dir", type=Path, default="highway_dataset_preprocessed", help="a new directory (should not exist already), the script will untar and process data there", ) return parser.parse_args() def process_tar(path_tar: Path, out_dir: Path, remove_tar: bool) -> None: d = path_tar.stem assert path_tar.stem.startswith(PREFIX) d = out_dir / "-".join(path_tar.stem[len(PREFIX) :].split("_to_")) d.mkdir(exist_ok=False, parents=True) shutil.move(path_tar, d) subprocess.run(f"cd {d} && tar -xvf {path_tar.name}", shell=True) new_path_tar = d / path_tar.name if remove_tar: new_path_tar.unlink() else: shutil.move(new_path_tar, path_tar.parent) def main(): args = parse_args() tar_dir = args.tar_dir.absolute() out_dir = args.out_dir.absolute() if not tar_dir.exists(): print( "Wrong usage: the tar directory should exist (and contain the downloaded .tar files)" ) return if out_dir.exists(): print(f"Wrong usage: the output directory should not exist ({args.out_dir})") return with Path("test_split_highway.txt").open("r") as f: test_files = f.read().split("\n") full_res_dir = out_dir / "full_res" low_res_dir = out_dir / "low_res" tar_files = [ x for x in tar_dir.iterdir() if x.suffix == ".tar" and x.stem.startswith(PREFIX) ] n = len(tar_files) if ( n < 28 and input( f"Found only {n} .tar files instead of 28, so it looks like the data has not been entirely downloaded, continue? [y|N] " ) != "y" ): return str_files = "\n".join(map(str, tar_files)) print(f"Ready to untar {n} tar files:\n{str_files}") remove_tar = ( input("Remove .tar files once they are processed? [y|N] ").lower() == "y" ) # Untar CSGO files f = partial(process_tar, out_dir=full_res_dir, remove_tar=remove_tar) with Pool(n) as p: p.map(f, tar_files) print(f"{n} .tar files unpacked in {full_res_dir}") # # Create low-res data # csgo_dataset = CSGOHdf5Dataset(full_res_dir) train_dataset = Dataset(low_res_dir / "train", None) test_dataset = Dataset(low_res_dir / "test", None) for i in tqdm(csgo_dataset._filenames, desc="Creating low_res"): episode = Episode( **{ k: v for k, v in csgo_dataset[SegmentId(i, 0, 1000)].__dict__.items() if k not in ("mask_padding", "id") } ) # episode.obs = T.resize( # episode.obs, (30, 56), interpolation=T.InterpolationMode.BICUBIC # ) episode.obs = T.resize( episode.obs, (30, 120), interpolation=T.InterpolationMode.BICUBIC ) filename = csgo_dataset._filenames[i] file_id = f"{filename.parent.stem}/{filename.name}" episode.info = {"original_file_id": file_id} dataset = test_dataset if filename.name in test_files else train_dataset dataset.add_episode(episode) train_dataset.save_to_default_path() test_dataset.save_to_default_path() print( f"Split train/test data ({train_dataset.num_episodes}/{test_dataset.num_episodes} episodes)\n" ) print("You can now edit `config/env/csgo.yaml` and set:") print(f"path_data_low_res: {low_res_dir}") print(f"path_data_full_res: {full_res_dir}") if __name__ == "__main__": main()