File size: 4,322 Bytes
c64c726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import argparse
from pathlib import Path

from huggingface_hub import snapshot_download
from hydra import compose, initialize
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
import torch

from agent import Agent
from envs import WorldModelEnv
from game import Game, PlayEnv


# Register a custom OmegaConf resolver named "eval" so config values can be
# computed with ``${eval:'<python expression>'}`` interpolations
# (presumably used by the project's YAML configs — verify).
# NOTE(review): this runs arbitrary Python taken from config text; only load
# configs from trusted sources.
OmegaConf.register_new_resolver("eval", eval)


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options for interactive play.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, matching the previous behaviour; passing an
            explicit list makes the parser usable from tests and other scripts.

    Returns:
        The parsed arguments namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--record", action="store_true", help="Record episodes in PlayEnv.")
    parser.add_argument("--store-denoising-trajectory", action="store_true", help="Save denoising steps in info.")
    parser.add_argument("--store-original-obs", action="store_true", help="Save original obs (pre resizing) in info.")
    parser.add_argument("--mouse-multiplier", type=int, default=10, help="Multiplication factor for the mouse movement.")
    parser.add_argument("--size-multiplier", type=int, default=2, help="Multiplication factor for the screen size.")
    parser.add_argument("--compile", action="store_true", help="Turn on model compilation.")
    parser.add_argument("--fps", type=int, default=15, help="Frame rate.")
    # Previously undocumented: main() turns this into Game(..., verbose=False).
    parser.add_argument("--no-header", action="store_true", help="Hide the game header (runs Game with verbose=False).")
    return parser.parse_args(argv)


def check_args(args: argparse.Namespace) -> bool:
    """Sanity-check parsed CLI arguments.

    Prints a warning when a ``--store-*`` option is given without
    ``--record``; the arguments themselves are not modified.

    Args:
        args: Namespace produced by ``parse_args`` (reads ``record``,
            ``store_denoising_trajectory`` and ``store_original_obs``).

    Returns:
        ``True`` when the program should continue. Currently always ``True``;
        the caller checks this value so a fatal check can be added later.
    """
    if not args.record and (args.store_denoising_trajectory or args.store_original_obs):
        print("Warning: not in recording mode, ignoring --store* options")
    return True


def _select_device() -> torch.device:
    """Return the rendering device: CUDA GPU 0 if available, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def prepare_play_mode(
    cfg: DictConfig,
    args: argparse.Namespace,
    path_ckpt: Path = Path("/home/alienware3/Documents/diamond/agent_epoch_00003.pt"),
    spawn_dir: Path = Path("/home/alienware3/Documents/diamond/csgo/spawn"),
) -> PlayEnv:
    """Build the interactive CSGO play environment around a trained agent.

    Replaces ``cfg.agent`` and ``cfg.env`` in place with the contents of
    ``config/agent/csgo.yaml`` and ``config/env/csgo.yaml``. Those paths are
    resolved relative to the current working directory. The function then
    loads the agent weights and wraps the agent's denoiser, upsampler and
    reward/end model in a ``WorldModelEnv``.

    Args:
        cfg: Composed trainer config. ``agent`` and ``env`` are overwritten in
            place and ``world_model_env`` is instantiated.
        args: Parsed CLI arguments. Reads ``compile``, ``record``,
            ``store_denoising_trajectory`` and ``store_original_obs``.
        path_ckpt: Checkpoint file given to ``Agent.load``. The default is a
            machine-specific absolute path; pass your own path (``str`` or
            ``Path``) to run on another machine.
        spawn_dir: Directory given to ``WorldModelEnv``. Same machine-specific
            default caveat as ``path_ckpt``.

    Returns:
        A ``PlayEnv`` wrapping the agent and the world-model environment.

    Raises:
        AssertionError: If the loaded env config's ``train.id`` is not "csgo".
    """
    # An alternative source for these assets is the Hugging Face Hub:
    #   snapshot_download(repo_id="eloialonso/diamond", allow_patterns="csgo/*")
    # Pass the resulting checkpoint / spawn paths in as path_ckpt / spawn_dir.
    path_ckpt = Path(path_ckpt)
    spawn_dir = Path(spawn_dir)

    # Override config
    cfg.agent = OmegaConf.load("config/agent/csgo.yaml")
    cfg.env = OmegaConf.load("config/env/csgo.yaml")

    device = _select_device()

    print("----------------------------------------------------------------------")
    print(f"Using {device} for rendering.")
    if not torch.cuda.is_available():
        print("If you have a CUDA GPU available and it is not being used, please follow the instructions at https://pytorch.org/get-started/locally/ to reinstall torch with CUDA support and try again.")
    print("----------------------------------------------------------------------")

    assert cfg.env.train.id == "csgo"
    num_actions = cfg.env.num_actions

    # Models
    agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(device).eval()
    agent.load(path_ckpt)

    # World model environment. Its sequence length is the longest conditioning
    # window among the denoiser and, if present, the upsampler.
    sl = cfg.agent.denoiser.inner_model.num_steps_conditioning
    if agent.upsampler is not None:
        sl = max(sl, cfg.agent.upsampler.inner_model.num_steps_conditioning)
    wm_env_cfg = instantiate(cfg.world_model_env, num_batches_to_preload=1)
    # NOTE(review): the positional `1` is presumably the number of parallel
    # environments / batch size — confirm against WorldModelEnv's signature.
    wm_env = WorldModelEnv(agent.denoiser, agent.upsampler, agent.rew_end_model, spawn_dir, 1, sl, wm_env_cfg, return_denoising_trajectory=True)

    # torch.compile is only attempted on CUDA devices, and only when --compile is set.
    if device.type == "cuda" and args.compile:
        print("Compiling models...")
        wm_env.predict_next_obs = torch.compile(wm_env.predict_next_obs, mode="reduce-overhead")
        wm_env.upsample_next_obs = torch.compile(wm_env.upsample_next_obs, mode="reduce-overhead")

    return PlayEnv(
        agent,
        wm_env,
        args.record,
        args.store_denoising_trajectory,
        args.store_original_obs,
    )


@torch.no_grad()
def main():
    args = parse_args()
    ok = check_args(args)
    if not ok:
        return

    with initialize(version_base="1.3", config_path="../config"):
        cfg = compose(config_name="trainer")

    # window size
    h, w = (cfg.env.train.size,) * 2 if isinstance(cfg.env.train.size, int) else cfg.env.train.size
    size_h, size_w = h * args.size_multiplier, w * args.size_multiplier
    env = prepare_play_mode(cfg, args)
    game = Game(env, (size_h, size_w), args.mouse_multiplier, fps=args.fps, verbose=not args.no_header)
    game.run()


# Start interactive play only when executed as a script, not when imported.
if __name__ == "__main__":
    main()