Spaces:
Sleeping
Sleeping
Commit
·
a29f249
1
Parent(s):
b0d3efd
Fix bug 3
Browse files
- app.py +16 -14
- config/agent/csgo.yaml +5 -5
- config/world_model_env/fast.yaml +3 -3
- requirements.txt +3 -3
- src/__pycache__/agent.cpython-310.pyc +0 -0
- src/__pycache__/utils.cpython-310.pyc +0 -0
- src/coroutines/__pycache__/env_loop.cpython-310.pyc +0 -0
- src/csgo/__pycache__/web_action_processing.cpython-310.pyc +0 -0
- src/data/__pycache__/dataset.cpython-310.pyc +0 -0
- src/envs/__pycache__/world_model_env.cpython-310.pyc +0 -0
- src/game/__pycache__/__init__.cpython-310.pyc +0 -0
- src/game/__pycache__/web_play_env.cpython-310.pyc +0 -0
- src/models/__pycache__/actor_critic.cpython-310.pyc +0 -0
- src/models/__pycache__/rew_end_model.cpython-310.pyc +0 -0
- src/models/diffusion/__pycache__/denoiser.cpython-310.pyc +0 -0
- src/utils.py +17 -5
app.py
CHANGED
|
@@ -99,12 +99,20 @@ class WebGameEngine:
|
|
| 99 |
def progress_hook(block_num, block_size, total_size):
|
| 100 |
if total_size > 0:
|
| 101 |
progress = min(100, (block_num * block_size * 100) / total_size)
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
urllib.request.urlretrieve(url, filepath, reporthook=progress_hook)
|
| 107 |
self.download_progress = 100
|
|
|
|
| 108 |
|
| 109 |
# Run download in thread pool to avoid blocking
|
| 110 |
loop = asyncio.get_event_loop()
|
|
@@ -142,21 +150,14 @@ class WebGameEngine:
|
|
| 142 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 143 |
logger.info(f"Using device: {device}")
|
| 144 |
|
| 145 |
-
#
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
logger.warning(f"No checkpoint found at {checkpoint_path} - using dummy mode")
|
| 149 |
-
self._init_dummy_mode()
|
| 150 |
-
return True
|
| 151 |
|
| 152 |
# Get spawn directory
|
| 153 |
spawn_dir = web_config.get_spawn_dir()
|
| 154 |
|
| 155 |
-
#
|
| 156 |
-
num_actions = cfg.env.num_actions
|
| 157 |
-
agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(device).eval()
|
| 158 |
-
|
| 159 |
-
# Try to load checkpoint (remote or local)
|
| 160 |
try:
|
| 161 |
# First try to download from Hugging Face Hub using direct URL
|
| 162 |
try:
|
|
@@ -192,6 +193,7 @@ class WebGameEngine:
|
|
| 192 |
logger.warning(f"Failed to download from HF Hub: {hub_error}")
|
| 193 |
|
| 194 |
# Fallback to local checkpoint if available
|
|
|
|
| 195 |
if checkpoint_path.exists():
|
| 196 |
logger.info(f"Falling back to local checkpoint: {checkpoint_path}")
|
| 197 |
agent.load(checkpoint_path)
|
|
|
|
| 99 |
def progress_hook(block_num, block_size, total_size):
|
| 100 |
if total_size > 0:
|
| 101 |
progress = min(100, (block_num * block_size * 100) / total_size)
|
| 102 |
+
new_progress = int(progress)
|
| 103 |
+
|
| 104 |
+
# Update progress more frequently for smooth progress bar
|
| 105 |
+
if new_progress != self.download_progress:
|
| 106 |
+
self.download_progress = new_progress
|
| 107 |
+
self.loading_status = f"Downloading AI model ({self.download_progress}%)"
|
| 108 |
+
|
| 109 |
+
# Log every 5% instead of 10% for better feedback
|
| 110 |
+
if self.download_progress % 5 == 0:
|
| 111 |
+
logger.info(f"Download progress: {self.download_progress}%")
|
| 112 |
|
| 113 |
urllib.request.urlretrieve(url, filepath, reporthook=progress_hook)
|
| 114 |
self.download_progress = 100
|
| 115 |
+
self.loading_status = "Download complete! Loading model..."
|
| 116 |
|
| 117 |
# Run download in thread pool to avoid blocking
|
| 118 |
loop = asyncio.get_event_loop()
|
|
|
|
| 150 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 151 |
logger.info(f"Using device: {device}")
|
| 152 |
|
| 153 |
+
# Initialize agent first
|
| 154 |
+
num_actions = cfg.env.num_actions
|
| 155 |
+
agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(device).eval()
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
# Get spawn directory
|
| 158 |
spawn_dir = web_config.get_spawn_dir()
|
| 159 |
|
| 160 |
+
# Try to load checkpoint (remote first, then local, then dummy mode)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
try:
|
| 162 |
# First try to download from Hugging Face Hub using direct URL
|
| 163 |
try:
|
|
|
|
| 193 |
logger.warning(f"Failed to download from HF Hub: {hub_error}")
|
| 194 |
|
| 195 |
# Fallback to local checkpoint if available
|
| 196 |
+
checkpoint_path = web_config.get_checkpoint_path()
|
| 197 |
if checkpoint_path.exists():
|
| 198 |
logger.info(f"Falling back to local checkpoint: {checkpoint_path}")
|
| 199 |
agent.load(checkpoint_path)
|
config/agent/csgo.yaml
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
-
_target_: agent.AgentConfig
|
| 2 |
|
| 3 |
denoiser:
|
| 4 |
-
_target_: models.diffusion.DenoiserConfig
|
| 5 |
sigma_data: 0.5
|
| 6 |
sigma_offset_noise: 0.1
|
| 7 |
noise_previous_obs: true
|
| 8 |
upsampling_factor: null
|
| 9 |
inner_model:
|
| 10 |
-
_target_: models.diffusion.InnerModelConfig
|
| 11 |
img_channels: 3
|
| 12 |
num_steps_conditioning: 4
|
| 13 |
cond_channels: 2048
|
|
@@ -16,13 +16,13 @@ denoiser:
|
|
| 16 |
attn_depths: [0, 0, 1, 1]
|
| 17 |
|
| 18 |
upsampler:
|
| 19 |
-
_target_: models.diffusion.DenoiserConfig
|
| 20 |
sigma_data: 0.5
|
| 21 |
sigma_offset_noise: 0.1
|
| 22 |
noise_previous_obs: false
|
| 23 |
upsampling_factor: 5
|
| 24 |
inner_model:
|
| 25 |
-
_target_: models.diffusion.InnerModelConfig
|
| 26 |
img_channels: 3
|
| 27 |
num_steps_conditioning: 1
|
| 28 |
cond_channels: 2048
|
|
|
|
| 1 |
+
_target_: src.agent.AgentConfig
|
| 2 |
|
| 3 |
denoiser:
|
| 4 |
+
_target_: src.models.diffusion.DenoiserConfig
|
| 5 |
sigma_data: 0.5
|
| 6 |
sigma_offset_noise: 0.1
|
| 7 |
noise_previous_obs: true
|
| 8 |
upsampling_factor: null
|
| 9 |
inner_model:
|
| 10 |
+
_target_: src.models.diffusion.InnerModelConfig
|
| 11 |
img_channels: 3
|
| 12 |
num_steps_conditioning: 4
|
| 13 |
cond_channels: 2048
|
|
|
|
| 16 |
attn_depths: [0, 0, 1, 1]
|
| 17 |
|
| 18 |
upsampler:
|
| 19 |
+
_target_: src.models.diffusion.DenoiserConfig
|
| 20 |
sigma_data: 0.5
|
| 21 |
sigma_offset_noise: 0.1
|
| 22 |
noise_previous_obs: false
|
| 23 |
upsampling_factor: 5
|
| 24 |
inner_model:
|
| 25 |
+
_target_: src.models.diffusion.InnerModelConfig
|
| 26 |
img_channels: 3
|
| 27 |
num_steps_conditioning: 1
|
| 28 |
cond_channels: 2048
|
config/world_model_env/fast.yaml
CHANGED
|
@@ -1,15 +1,15 @@
|
|
| 1 |
-
_target_: envs.WorldModelEnvConfig
|
| 2 |
horizon: 1000
|
| 3 |
num_batches_to_preload: 1
|
| 4 |
diffusion_sampler_next_obs:
|
| 5 |
-
_target_: models.diffusion.DiffusionSamplerConfig
|
| 6 |
num_steps_denoising: 6 # Balanced: better quality than 3, faster than 10
|
| 7 |
sigma_min: 0.002
|
| 8 |
sigma_max: 5.0
|
| 9 |
rho: 7
|
| 10 |
order: 1
|
| 11 |
diffusion_sampler_upsampling:
|
| 12 |
-
_target_: models.diffusion.DiffusionSamplerConfig
|
| 13 |
num_steps_denoising: 4 # Balanced: better quality than 2, faster than 5
|
| 14 |
sigma_min: 0.002
|
| 15 |
sigma_max: 5.0
|
|
|
|
| 1 |
+
_target_: src.envs.WorldModelEnvConfig
|
| 2 |
horizon: 1000
|
| 3 |
num_batches_to_preload: 1
|
| 4 |
diffusion_sampler_next_obs:
|
| 5 |
+
_target_: src.models.diffusion.DiffusionSamplerConfig
|
| 6 |
num_steps_denoising: 6 # Balanced: better quality than 3, faster than 10
|
| 7 |
sigma_min: 0.002
|
| 8 |
sigma_max: 5.0
|
| 9 |
rho: 7
|
| 10 |
order: 1
|
| 11 |
diffusion_sampler_upsampling:
|
| 12 |
+
_target_: src.models.diffusion.DiffusionSamplerConfig
|
| 13 |
num_steps_denoising: 4 # Balanced: better quality than 2, faster than 5
|
| 14 |
sigma_min: 0.002
|
| 15 |
sigma_max: 5.0
|
requirements.txt
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
torch>=1.13.0
|
| 3 |
torchvision>=0.14.0
|
| 4 |
torchaudio>=0.13.0
|
| 5 |
-
numpy>=1.21.0
|
| 6 |
|
| 7 |
# Configuration management
|
| 8 |
hydra-core>=1.2.0
|
|
@@ -28,8 +28,8 @@ h5py>=3.7.0
|
|
| 28 |
ale_py>=0.8.0
|
| 29 |
gymnasium>=0.28.0
|
| 30 |
|
| 31 |
-
# Experiment tracking (
|
| 32 |
-
wandb>=0.13.0
|
| 33 |
|
| 34 |
# Metrics (required by rew_end_model.py)
|
| 35 |
torcheval>=0.0.6
|
|
|
|
| 2 |
torch>=1.13.0
|
| 3 |
torchvision>=0.14.0
|
| 4 |
torchaudio>=0.13.0
|
| 5 |
+
numpy>=1.21.0,<2.0.0
|
| 6 |
|
| 7 |
# Configuration management
|
| 8 |
hydra-core>=1.2.0
|
|
|
|
| 28 |
ale_py>=0.8.0
|
| 29 |
gymnasium>=0.28.0
|
| 30 |
|
| 31 |
+
# Experiment tracking (optional, for training only)
|
| 32 |
+
# wandb>=0.13.0 # Commented out due to NumPy 2.0 compatibility issues
|
| 33 |
|
| 34 |
# Metrics (required by rew_end_model.py)
|
| 35 |
torcheval>=0.0.6
|
src/__pycache__/agent.cpython-310.pyc
CHANGED
|
Binary files a/src/__pycache__/agent.cpython-310.pyc and b/src/__pycache__/agent.cpython-310.pyc differ
|
|
|
src/__pycache__/utils.cpython-310.pyc
CHANGED
|
Binary files a/src/__pycache__/utils.cpython-310.pyc and b/src/__pycache__/utils.cpython-310.pyc differ
|
|
|
src/coroutines/__pycache__/env_loop.cpython-310.pyc
CHANGED
|
Binary files a/src/coroutines/__pycache__/env_loop.cpython-310.pyc and b/src/coroutines/__pycache__/env_loop.cpython-310.pyc differ
|
|
|
src/csgo/__pycache__/web_action_processing.cpython-310.pyc
CHANGED
|
Binary files a/src/csgo/__pycache__/web_action_processing.cpython-310.pyc and b/src/csgo/__pycache__/web_action_processing.cpython-310.pyc differ
|
|
|
src/data/__pycache__/dataset.cpython-310.pyc
CHANGED
|
Binary files a/src/data/__pycache__/dataset.cpython-310.pyc and b/src/data/__pycache__/dataset.cpython-310.pyc differ
|
|
|
src/envs/__pycache__/world_model_env.cpython-310.pyc
CHANGED
|
Binary files a/src/envs/__pycache__/world_model_env.cpython-310.pyc and b/src/envs/__pycache__/world_model_env.cpython-310.pyc differ
|
|
|
src/game/__pycache__/__init__.cpython-310.pyc
CHANGED
|
Binary files a/src/game/__pycache__/__init__.cpython-310.pyc and b/src/game/__pycache__/__init__.cpython-310.pyc differ
|
|
|
src/game/__pycache__/web_play_env.cpython-310.pyc
CHANGED
|
Binary files a/src/game/__pycache__/web_play_env.cpython-310.pyc and b/src/game/__pycache__/web_play_env.cpython-310.pyc differ
|
|
|
src/models/__pycache__/actor_critic.cpython-310.pyc
CHANGED
|
Binary files a/src/models/__pycache__/actor_critic.cpython-310.pyc and b/src/models/__pycache__/actor_critic.cpython-310.pyc differ
|
|
|
src/models/__pycache__/rew_end_model.cpython-310.pyc
CHANGED
|
Binary files a/src/models/__pycache__/rew_end_model.cpython-310.pyc and b/src/models/__pycache__/rew_end_model.cpython-310.pyc differ
|
|
|
src/models/diffusion/__pycache__/denoiser.cpython-310.pyc
CHANGED
|
Binary files a/src/models/diffusion/__pycache__/denoiser.cpython-310.pyc and b/src/models/diffusion/__pycache__/denoiser.cpython-310.pyc differ
|
|
|
src/utils.py
CHANGED
|
@@ -17,7 +17,13 @@ from torch.optim.lr_scheduler import LambdaLR
|
|
| 17 |
import torch.nn as nn
|
| 18 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 19 |
from torch.optim import AdamW
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
ATARI_100K_GAMES = [
|
|
@@ -275,8 +281,12 @@ def prompt_atari_game():
|
|
| 275 |
|
| 276 |
def prompt_run_name(game):
|
| 277 |
cfg_file = Path("config/trainer.yaml")
|
| 278 |
-
|
| 279 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
name = game + suffix
|
| 281 |
name_ = input(f"Confirm run name by pressing Enter (or enter a new name): {name}\n")
|
| 282 |
if name_ != "":
|
|
@@ -329,5 +339,7 @@ def try_until_no_except(func: Callable) -> None:
|
|
| 329 |
|
| 330 |
|
| 331 |
def wandb_log(logs: Logs, epoch: int):
|
| 332 |
-
|
| 333 |
-
|
|
|
|
|
|
|
|
|
| 17 |
import torch.nn as nn
|
| 18 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 19 |
from torch.optim import AdamW
|
| 20 |
+
try:
|
| 21 |
+
import wandb
|
| 22 |
+
WANDB_AVAILABLE = True
|
| 23 |
+
except ImportError:
|
| 24 |
+
# wandb could not be imported; set it to None so callers can fall back gracefully
|
| 25 |
+
wandb = None
|
| 26 |
+
WANDB_AVAILABLE = False
|
| 27 |
|
| 28 |
|
| 29 |
ATARI_100K_GAMES = [
|
|
|
|
| 281 |
|
| 282 |
def prompt_run_name(game):
|
| 283 |
cfg_file = Path("config/trainer.yaml")
|
| 284 |
+
try:
|
| 285 |
+
cfg_name = OmegaConf.load(cfg_file).wandb.name
|
| 286 |
+
suffix = f"-{cfg_name}" if cfg_name is not None else ""
|
| 287 |
+
except:
|
| 288 |
+
# If the wandb config is not available, use an empty suffix
|
| 289 |
+
suffix = ""
|
| 290 |
name = game + suffix
|
| 291 |
name_ = input(f"Confirm run name by pressing Enter (or enter a new name): {name}\n")
|
| 292 |
if name_ != "":
|
|
|
|
| 339 |
|
| 340 |
|
| 341 |
def wandb_log(logs: Logs, epoch: int):
|
| 342 |
+
if WANDB_AVAILABLE and wandb is not None:
|
| 343 |
+
for d in logs:
|
| 344 |
+
wandb.log({"epoch": epoch, **d})
|
| 345 |
+
# If wandb is not available, silently skip logging
|