musictimer committed on
Commit
a29f249
·
1 Parent(s): b0d3efd
app.py CHANGED
@@ -99,12 +99,20 @@ class WebGameEngine:
99
  def progress_hook(block_num, block_size, total_size):
100
  if total_size > 0:
101
  progress = min(100, (block_num * block_size * 100) / total_size)
102
- self.download_progress = int(progress)
103
- if progress % 10 == 0: # Log every 10%
104
- logger.info(f"Download progress: {self.download_progress}%")
 
 
 
 
 
 
 
105
 
106
  urllib.request.urlretrieve(url, filepath, reporthook=progress_hook)
107
  self.download_progress = 100
 
108
 
109
  # Run download in thread pool to avoid blocking
110
  loop = asyncio.get_event_loop()
@@ -142,21 +150,14 @@ class WebGameEngine:
142
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
143
  logger.info(f"Using device: {device}")
144
 
145
- # Load model checkpoint
146
- checkpoint_path = web_config.get_checkpoint_path()
147
- if not checkpoint_path.exists():
148
- logger.warning(f"No checkpoint found at {checkpoint_path} - using dummy mode")
149
- self._init_dummy_mode()
150
- return True
151
 
152
  # Get spawn directory
153
  spawn_dir = web_config.get_spawn_dir()
154
 
155
- # Initialize agent
156
- num_actions = cfg.env.num_actions
157
- agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(device).eval()
158
-
159
- # Try to load checkpoint (remote or local)
160
  try:
161
  # First try to download from Hugging Face Hub using direct URL
162
  try:
@@ -192,6 +193,7 @@ class WebGameEngine:
192
  logger.warning(f"Failed to download from HF Hub: {hub_error}")
193
 
194
  # Fallback to local checkpoint if available
 
195
  if checkpoint_path.exists():
196
  logger.info(f"Falling back to local checkpoint: {checkpoint_path}")
197
  agent.load(checkpoint_path)
 
99
  def progress_hook(block_num, block_size, total_size):
100
  if total_size > 0:
101
  progress = min(100, (block_num * block_size * 100) / total_size)
102
+ new_progress = int(progress)
103
+
104
+ # Update progress more frequently for smooth progress bar
105
+ if new_progress != self.download_progress:
106
+ self.download_progress = new_progress
107
+ self.loading_status = f"Downloading AI model ({self.download_progress}%)"
108
+
109
+ # Log every 5% instead of 10% for better feedback
110
+ if self.download_progress % 5 == 0:
111
+ logger.info(f"Download progress: {self.download_progress}%")
112
 
113
  urllib.request.urlretrieve(url, filepath, reporthook=progress_hook)
114
  self.download_progress = 100
115
+ self.loading_status = "Download complete! Loading model..."
116
 
117
  # Run download in thread pool to avoid blocking
118
  loop = asyncio.get_event_loop()
 
150
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
151
  logger.info(f"Using device: {device}")
152
 
153
+ # Initialize agent first
154
+ num_actions = cfg.env.num_actions
155
+ agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(device).eval()
 
 
 
156
 
157
  # Get spawn directory
158
  spawn_dir = web_config.get_spawn_dir()
159
 
160
+ # Try to load checkpoint (remote first, then local, then dummy mode)
 
 
 
 
161
  try:
162
  # First try to download from Hugging Face Hub using direct URL
163
  try:
 
193
  logger.warning(f"Failed to download from HF Hub: {hub_error}")
194
 
195
  # Fallback to local checkpoint if available
196
+ checkpoint_path = web_config.get_checkpoint_path()
197
  if checkpoint_path.exists():
198
  logger.info(f"Falling back to local checkpoint: {checkpoint_path}")
199
  agent.load(checkpoint_path)
config/agent/csgo.yaml CHANGED
@@ -1,13 +1,13 @@
1
- _target_: agent.AgentConfig
2
 
3
  denoiser:
4
- _target_: models.diffusion.DenoiserConfig
5
  sigma_data: 0.5
6
  sigma_offset_noise: 0.1
7
  noise_previous_obs: true
8
  upsampling_factor: null
9
  inner_model:
10
- _target_: models.diffusion.InnerModelConfig
11
  img_channels: 3
12
  num_steps_conditioning: 4
13
  cond_channels: 2048
@@ -16,13 +16,13 @@ denoiser:
16
  attn_depths: [0, 0, 1, 1]
17
 
18
  upsampler:
19
- _target_: models.diffusion.DenoiserConfig
20
  sigma_data: 0.5
21
  sigma_offset_noise: 0.1
22
  noise_previous_obs: false
23
  upsampling_factor: 5
24
  inner_model:
25
- _target_: models.diffusion.InnerModelConfig
26
  img_channels: 3
27
  num_steps_conditioning: 1
28
  cond_channels: 2048
 
1
+ _target_: src.agent.AgentConfig
2
 
3
  denoiser:
4
+ _target_: src.models.diffusion.DenoiserConfig
5
  sigma_data: 0.5
6
  sigma_offset_noise: 0.1
7
  noise_previous_obs: true
8
  upsampling_factor: null
9
  inner_model:
10
+ _target_: src.models.diffusion.InnerModelConfig
11
  img_channels: 3
12
  num_steps_conditioning: 4
13
  cond_channels: 2048
 
16
  attn_depths: [0, 0, 1, 1]
17
 
18
  upsampler:
19
+ _target_: src.models.diffusion.DenoiserConfig
20
  sigma_data: 0.5
21
  sigma_offset_noise: 0.1
22
  noise_previous_obs: false
23
  upsampling_factor: 5
24
  inner_model:
25
+ _target_: src.models.diffusion.InnerModelConfig
26
  img_channels: 3
27
  num_steps_conditioning: 1
28
  cond_channels: 2048
config/world_model_env/fast.yaml CHANGED
@@ -1,15 +1,15 @@
1
- _target_: envs.WorldModelEnvConfig
2
  horizon: 1000
3
  num_batches_to_preload: 1
4
  diffusion_sampler_next_obs:
5
- _target_: models.diffusion.DiffusionSamplerConfig
6
  num_steps_denoising: 6 # Balanced: better quality than 3, faster than 10
7
  sigma_min: 0.002
8
  sigma_max: 5.0
9
  rho: 7
10
  order: 1
11
  diffusion_sampler_upsampling:
12
- _target_: models.diffusion.DiffusionSamplerConfig
13
  num_steps_denoising: 4 # Balanced: better quality than 2, faster than 5
14
  sigma_min: 0.002
15
  sigma_max: 5.0
 
1
+ _target_: src.envs.WorldModelEnvConfig
2
  horizon: 1000
3
  num_batches_to_preload: 1
4
  diffusion_sampler_next_obs:
5
+ _target_: src.models.diffusion.DiffusionSamplerConfig
6
  num_steps_denoising: 6 # Balanced: better quality than 3, faster than 10
7
  sigma_min: 0.002
8
  sigma_max: 5.0
9
  rho: 7
10
  order: 1
11
  diffusion_sampler_upsampling:
12
+ _target_: src.models.diffusion.DiffusionSamplerConfig
13
  num_steps_denoising: 4 # Balanced: better quality than 2, faster than 5
14
  sigma_min: 0.002
15
  sigma_max: 5.0
requirements.txt CHANGED
@@ -2,7 +2,7 @@
2
  torch>=1.13.0
3
  torchvision>=0.14.0
4
  torchaudio>=0.13.0
5
- numpy>=1.21.0
6
 
7
  # Configuration management
8
  hydra-core>=1.2.0
@@ -28,8 +28,8 @@ h5py>=3.7.0
28
  ale_py>=0.8.0
29
  gymnasium>=0.28.0
30
 
31
- # Experiment tracking (required by utils.py)
32
- wandb>=0.13.0
33
 
34
  # Metrics (required by rew_end_model.py)
35
  torcheval>=0.0.6
 
2
  torch>=1.13.0
3
  torchvision>=0.14.0
4
  torchaudio>=0.13.0
5
+ numpy>=1.21.0,<2.0.0
6
 
7
  # Configuration management
8
  hydra-core>=1.2.0
 
28
  ale_py>=0.8.0
29
  gymnasium>=0.28.0
30
 
31
+ # Experiment tracking (optional, for training only)
32
+ # wandb>=0.13.0 # Commented out due to NumPy 2.0 compatibility issues
33
 
34
  # Metrics (required by rew_end_model.py)
35
  torcheval>=0.0.6
src/__pycache__/agent.cpython-310.pyc CHANGED
Binary files a/src/__pycache__/agent.cpython-310.pyc and b/src/__pycache__/agent.cpython-310.pyc differ
 
src/__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/src/__pycache__/utils.cpython-310.pyc and b/src/__pycache__/utils.cpython-310.pyc differ
 
src/coroutines/__pycache__/env_loop.cpython-310.pyc CHANGED
Binary files a/src/coroutines/__pycache__/env_loop.cpython-310.pyc and b/src/coroutines/__pycache__/env_loop.cpython-310.pyc differ
 
src/csgo/__pycache__/web_action_processing.cpython-310.pyc CHANGED
Binary files a/src/csgo/__pycache__/web_action_processing.cpython-310.pyc and b/src/csgo/__pycache__/web_action_processing.cpython-310.pyc differ
 
src/data/__pycache__/dataset.cpython-310.pyc CHANGED
Binary files a/src/data/__pycache__/dataset.cpython-310.pyc and b/src/data/__pycache__/dataset.cpython-310.pyc differ
 
src/envs/__pycache__/world_model_env.cpython-310.pyc CHANGED
Binary files a/src/envs/__pycache__/world_model_env.cpython-310.pyc and b/src/envs/__pycache__/world_model_env.cpython-310.pyc differ
 
src/game/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/src/game/__pycache__/__init__.cpython-310.pyc and b/src/game/__pycache__/__init__.cpython-310.pyc differ
 
src/game/__pycache__/web_play_env.cpython-310.pyc CHANGED
Binary files a/src/game/__pycache__/web_play_env.cpython-310.pyc and b/src/game/__pycache__/web_play_env.cpython-310.pyc differ
 
src/models/__pycache__/actor_critic.cpython-310.pyc CHANGED
Binary files a/src/models/__pycache__/actor_critic.cpython-310.pyc and b/src/models/__pycache__/actor_critic.cpython-310.pyc differ
 
src/models/__pycache__/rew_end_model.cpython-310.pyc CHANGED
Binary files a/src/models/__pycache__/rew_end_model.cpython-310.pyc and b/src/models/__pycache__/rew_end_model.cpython-310.pyc differ
 
src/models/diffusion/__pycache__/denoiser.cpython-310.pyc CHANGED
Binary files a/src/models/diffusion/__pycache__/denoiser.cpython-310.pyc and b/src/models/diffusion/__pycache__/denoiser.cpython-310.pyc differ
 
src/utils.py CHANGED
@@ -17,7 +17,13 @@ from torch.optim.lr_scheduler import LambdaLR
17
  import torch.nn as nn
18
  from torch.nn.parallel import DistributedDataParallel as DDP
19
  from torch.optim import AdamW
20
- import wandb
 
 
 
 
 
 
21
 
22
 
23
  ATARI_100K_GAMES = [
@@ -275,8 +281,12 @@ def prompt_atari_game():
275
 
276
  def prompt_run_name(game):
277
  cfg_file = Path("config/trainer.yaml")
278
- cfg_name = OmegaConf.load(cfg_file).wandb.name
279
- suffix = f"-{cfg_name}" if cfg_name is not None else ""
 
 
 
 
280
  name = game + suffix
281
  name_ = input(f"Confirm run name by pressing Enter (or enter a new name): {name}\n")
282
  if name_ != "":
@@ -329,5 +339,7 @@ def try_until_no_except(func: Callable) -> None:
329
 
330
 
331
  def wandb_log(logs: Logs, epoch: int):
332
- for d in logs:
333
- wandb.log({"epoch": epoch, **d})
 
 
 
17
  import torch.nn as nn
18
  from torch.nn.parallel import DistributedDataParallel as DDP
19
  from torch.optim import AdamW
20
+ try:
21
+ import wandb
22
+ WANDB_AVAILABLE = True
23
+ except ImportError:
24
+ # wandb not available, set to None for graceful fallback
25
+ wandb = None
26
+ WANDB_AVAILABLE = False
27
 
28
 
29
  ATARI_100K_GAMES = [
 
281
 
282
  def prompt_run_name(game):
283
  cfg_file = Path("config/trainer.yaml")
284
+ try:
285
+ cfg_name = OmegaConf.load(cfg_file).wandb.name
286
+ suffix = f"-{cfg_name}" if cfg_name is not None else ""
287
+ except:
288
+ # If wandb config not available, use empty suffix
289
+ suffix = ""
290
  name = game + suffix
291
  name_ = input(f"Confirm run name by pressing Enter (or enter a new name): {name}\n")
292
  if name_ != "":
 
339
 
340
 
341
  def wandb_log(logs: Logs, epoch: int):
342
+ if WANDB_AVAILABLE and wandb is not None:
343
+ for d in logs:
344
+ wandb.log({"epoch": epoch, **d})
345
+ # If wandb not available, silently skip logging