musictimer committed
Commit ded2bd6 · 1 Parent(s): 7deb5ff
app.py CHANGED
@@ -25,6 +25,14 @@ from omegaconf import DictConfig, OmegaConf
 from PIL import Image
 
 # Import your modules
+import sys
+from pathlib import Path
+
+# Add project root to path for src package imports
+project_root = Path(__file__).parent
+if str(project_root) not in sys.path:
+    sys.path.insert(0, str(project_root))
+
 from src.agent import Agent
 from src.csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names
 from src.envs import WorldModelEnv
@@ -96,62 +104,24 @@ class WebGameEngine:
     def load_model_weights():
         """Load model weights in thread pool to avoid blocking"""
         try:
-            # Direct download without any caching to avoid permission issues on HF Spaces
-            logger.info("Downloading model directly without caching...")
-            self.loading_status = "Downloading model without caching..."
+            logger.info("Loading model using torch.hub.load_state_dict_from_url...")
+            self.loading_status = "Downloading model..."
             self.download_progress = 10
 
             model_url = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"
 
-            # Use requests to download directly into memory
-            import requests
-            import io
-
-            logger.info(f"Starting direct download from {model_url}")
-            response = requests.get(model_url, stream=True)
-            response.raise_for_status()
-
-            # Get the total file size for progress tracking
-            total_size = int(response.headers.get('content-length', 0))
-            logger.info(f"Model file size: {total_size / (1024*1024):.1f} MB")
-
-            # Download with progress tracking
-            downloaded_data = io.BytesIO()
-            downloaded_size = 0
-
-            for chunk in response.iter_content(chunk_size=8192):
-                if chunk:
-                    downloaded_data.write(chunk)
-                    downloaded_size += len(chunk)
-
-                    # Update progress
-                    if total_size > 0:
-                        progress = min(50, int((downloaded_size / total_size) * 40) + 10)  # 10-50%
-                        if progress != self.download_progress:
-                            self.download_progress = progress
-                            logger.info(f"Download progress: {progress}%")
-
-            self.download_progress = 50
-            self.loading_status = "Download complete, loading model..."
-            logger.info("Download completed, loading state dict...")
-
-            # Reset to beginning of buffer and load
-            downloaded_data.seek(0)
-            state_dict = torch.load(downloaded_data, map_location=device)
-            logger.info("Successfully loaded model using direct download")
+            # Use torch.hub to download and load state dict
+            logger.info(f"Loading state dict from {model_url}")
+            state_dict = torch.hub.load_state_dict_from_url(model_url, map_location=device)
 
-        except Exception as e:
-            logger.error(f"Failed to download model directly: {e}")
-            raise Exception(f"Direct download failed: {str(e)}")
-
-        # Load state dict into agent using the new load_state_dict method
-        try:
-            logger.info("Model download completed, loading weights...")
             self.download_progress = 60
             self.loading_status = "Loading model weights into agent..."
+            logger.info("State dict loaded, applying to agent...")
 
-            # Use the agent's new load_state_dict method
-            agent.load_state_dict(state_dict)
+            # Load state dict into agent, but skip actor_critic if not present
+            has_actor_critic = any(k.startswith('actor_critic.') for k in state_dict.keys())
+            logger.info(f"Model has actor_critic weights: {has_actor_critic}")
+            agent.load_state_dict(state_dict, load_actor_critic=has_actor_critic)
 
             self.download_progress = 100
             self.loading_status = "Model loaded successfully!"
@@ -159,7 +129,7 @@ class WebGameEngine:
             return True
 
         except Exception as e:
-            logger.error(f"Failed to load state dict into agent: {e}")
+            logger.error(f"Failed to load model: {e}")
            import traceback
            traceback.print_exc()
            return False
@@ -210,6 +180,15 @@ class WebGameEngine:
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         logger.info(f"Using device: {device}")
 
+        # Log GPU availability and CUDA info for debugging HF Spaces
+        if torch.cuda.is_available():
+            logger.info(f"CUDA available: {torch.cuda.is_available()}")
+            logger.info(f"GPU device count: {torch.cuda.device_count()}")
+            logger.info(f"Current GPU: {torch.cuda.get_device_name(0)}")
+            logger.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
+        else:
+            logger.info("CUDA not available, using CPU - this is normal for HF Spaces free tier")
+
         # Initialize agent first
         num_actions = cfg.env.num_actions
         agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(device).eval()
@@ -228,7 +207,7 @@ class WebGameEngine:
             logger.info("Successfully loaded checkpoint from HF Hub")
         else:
             # Fallback to local checkpoint if available
-            logger.warning("Failed to load from HF Hub, trying local checkpoint...")
+            logger.error("Failed to load from HF Hub! Check the detailed error above.")
             checkpoint_path = web_config.get_checkpoint_path()
             if checkpoint_path.exists():
                 logger.info(f"Loading local checkpoint: {checkpoint_path}")
@@ -236,6 +215,7 @@ class WebGameEngine:
                 agent.load(checkpoint_path)
                 logger.info(f"Successfully loaded local checkpoint: {checkpoint_path}")
             else:
+                logger.error(f"No local checkpoint found at: {checkpoint_path}")
                 raise FileNotFoundError("No model checkpoint available (local or remote)")
 
         except Exception as e:
@@ -255,6 +235,18 @@ class WebGameEngine:
         # Create play environment
         self.play_env = WebPlayEnv(agent, wm_env, False, False, False)
 
+        # Verify actor-critic is loaded and ready for inference
+        if agent.actor_critic is not None:
+            logger.info(f"Actor-critic model loaded with {agent.actor_critic.lstm_dim} LSTM dimensions")
+            logger.info(f"Actor-critic device: {agent.actor_critic.device}")
+            # Force AI control for web demo
+            self.play_env.is_human_player = False
+            logger.info("WebPlayEnv set to AI control mode")
+        else:
+            logger.warning("No actor-critic model found - AI inference will not work!")
+            self.play_env.is_human_player = True
+            logger.info("WebPlayEnv set to human control mode (fallback)")
+
         # Model compilation causes 10-30s delay on first inference, so make it optional
         # You can enable it by setting ENABLE_TORCH_COMPILE=1 environment variable
         import os
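
Note on the download change above: unlike the removed requests-based in-memory download, torch.hub.load_state_dict_from_url caches the checkpoint on disk (by default under $TORCH_HOME/hub/checkpoints, typically ~/.cache/torch), so a Space restart can reuse the file instead of downloading it again. A minimal sketch of pinning that cache to a known-writable directory; the /tmp path and helper name are assumptions, not part of this commit:

import torch

MODEL_URL = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"

def fetch_state_dict(device: torch.device) -> dict:
    # model_dir overrides the default cache location, which may not be writable
    # on a Hugging Face Space; /tmp/torch_checkpoints is a hypothetical choice.
    return torch.hub.load_state_dict_from_url(
        MODEL_URL,
        model_dir="/tmp/torch_checkpoints",
        map_location=device,
        progress=True,
    )
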
config/agent/csgo.yaml CHANGED
@@ -31,4 +31,10 @@ upsampler:
     attn_depths: [0, 0, 0, 1]
 
 rew_end_model: null
-actor_critic: null
+actor_critic:
+  _target_: src.models.actor_critic.ActorCriticConfig
+  lstm_dim: 512
+  img_channels: 3
+  img_size: 64
+  channels: [32, 64, 128]
+  down: [2, 2, 2]
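
Note on the config change above: `actor_critic: null` previously meant no actor-critic head was built; the new block gives Hydra a `_target_` to instantiate, which the actor-critic checks added in app.py rely on. A minimal sketch of reading this group file on its own with OmegaConf; in app.py the node is built recursively by instantiate(cfg.agent, num_actions=num_actions), and only the fields declared in the YAML are shown here:

from omegaconf import OmegaConf

agent_cfg = OmegaConf.load("config/agent/csgo.yaml")
assert agent_cfg.rew_end_model is None                 # unchanged by this commit
print(agent_cfg.actor_critic["_target_"])              # src.models.actor_critic.ActorCriticConfig
print(agent_cfg.actor_critic.lstm_dim)                 # 512 - must match the checkpoint's LSTM size
print(list(agent_cfg.actor_critic.channels))           # [32, 64, 128]
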
src/__init__.pyc ADDED
Binary file (102 Bytes).
 
src/__pycache__/agent.cpython-310.pyc CHANGED
Binary files a/src/__pycache__/agent.cpython-310.pyc and b/src/__pycache__/agent.cpython-310.pyc differ
 
src/game/__init__.pyc ADDED
Binary file (366 Bytes).
 
src/game/__pycache__/dataset_env.cpython-310.pyc CHANGED
Binary files a/src/game/__pycache__/dataset_env.cpython-310.pyc and b/src/game/__pycache__/dataset_env.cpython-310.pyc differ
 
src/game/__pycache__/game.cpython-310.pyc CHANGED
Binary files a/src/game/__pycache__/game.cpython-310.pyc and b/src/game/__pycache__/game.cpython-310.pyc differ
 
src/game/__pycache__/play_env.cpython-310.pyc CHANGED
Binary files a/src/game/__pycache__/play_env.cpython-310.pyc and b/src/game/__pycache__/play_env.cpython-310.pyc differ
 
src/game/__pycache__/web_play_env.cpython-310.pyc CHANGED
Binary files a/src/game/__pycache__/web_play_env.cpython-310.pyc and b/src/game/__pycache__/web_play_env.cpython-310.pyc differ
 
src/game/play_env.py CHANGED
@@ -7,11 +7,11 @@ import pygame
 import torch
 from torch import Tensor
 
-from agent import Agent
-from csgo.action_processing import CSGOAction, decode_csgo_action, encode_csgo_action, print_csgo_action
-from csgo.keymap import CSGO_KEYMAP
-from data import Dataset, Episode
-from envs import WorldModelEnv
+from ..agent import Agent
+from ..csgo.action_processing import CSGOAction, decode_csgo_action, encode_csgo_action, print_csgo_action
+from ..csgo.keymap import CSGO_KEYMAP
+from ..data import Dataset, Episode
+from ..envs import WorldModelEnv
 
 
 NamedEnv = namedtuple("NamedEnv", "name env")
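
Note on the import change above: these relative imports only resolve when play_env is loaded as part of the src.game package, which is why app.py now puts the project root (the directory containing src/) on sys.path rather than src/ itself. A small sketch of the intended pattern from a hypothetical caller at the project root:

# hypothetical_caller.py, sitting next to the src/ directory
import sys
from pathlib import Path

# Same idea as the app.py change: make src importable as a package so that
# "from ..agent import Agent" inside src/game/play_env.py can resolve.
project_root = Path(__file__).parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from src.game.play_env import PlayEnv   # package import: the ".." imports resolve
# Importing play_env.py as a top-level module instead would break them.
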
src/game/web_play_env.py CHANGED
@@ -1,82 +1,68 @@
 """
-Web-compatible PlayEnv that works without pygame
+Web-compatible PlayEnv that handles web input and AI inference
 """
 
-from collections import defaultdict, namedtuple
-from pathlib import Path
-from typing import Any, Dict, List, Tuple, Set
-
+from typing import Any, Dict, List, Set, Tuple
 import torch
 from torch import Tensor
+from torch.distributions.categorical import Categorical
 
 from ..agent import Agent
-from ..csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names, encode_web_csgo_action
-from ..data import Dataset, Episode
 from ..envs import WorldModelEnv
+from ..csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names, encode_web_csgo_action
+from .play_env import PlayEnv
 
-OneStepData = namedtuple("OneStepData", "obs act rew end trunc")
 
-class WebPlayEnv:
-    """Web-compatible version of PlayEnv without pygame dependencies"""
+class WebPlayEnv(PlayEnv):
+    """Web-compatible version of PlayEnv that handles web input and AI inference"""
 
     def __init__(
         self,
         agent: Agent,
         wm_env: WorldModelEnv,
-        recording_mode: bool = False,
-        store_denoising_trajectory: bool = False,
-        store_original_obs: bool = False,
+        recording_mode: bool,
+        store_denoising_trajectory: bool,
+        store_original_obs: bool,
     ) -> None:
-        self.agent = agent
-        self.recording_mode = recording_mode
-        self.store_denoising_trajectory = store_denoising_trajectory
-        self.store_original_obs = store_original_obs
-        self.is_human_player = True
-        self.env_id = 0
-        self.env_name = "world model"
-        self.env = wm_env
-        self.obs, self.t, self.buffer, self.rec_dataset = (None,) * 4
-
-    def print_controls(self) -> None:
-        """Print available controls for web interface"""
-        print("\nWeb Environment Controls:\n")
-        controls = {
-            "W": "Move Forward",
-            "A": "Move Left",
-            "S": "Move Back",
-            "D": "Move Right",
-            "Space": "Jump",
-            "Ctrl": "Crouch",
-            "Shift": "Walk",
-            "1": "Weapon 1",
-            "2": "Weapon 2",
-            "3": "Weapon 3",
-            "R": "Reload",
-            "Arrow Keys": "Camera Movement",
-            "Left Click": "Primary Fire",
-            "Right Click": "Secondary Fire"
-        }
+        super().__init__(agent, wm_env, recording_mode, store_denoising_trajectory, store_original_obs)
 
-        for key, action in controls.items():
-            print(f"{key}: {action}")
-
-    def step_from_web_input(self, pressed_keys: Set[str], mouse_x: float = 0, mouse_y: float = 0,
-                            l_click: bool = False, r_click: bool = False) -> Tuple[Tensor, Tensor, bool, bool, Dict]:
-        """
-        Step the environment using web input
+        # For web demo, we want AI control by default
+        self.is_human_player = False  # AI controls the actions
+        self.human_input_override = False  # Can be set to True to allow human input
 
-        Args:
-            pressed_keys: Set of currently pressed key codes (e.g., {'KeyW', 'KeyA'})
-            mouse_x, mouse_y: Mouse movement deltas
-            l_click, r_click: Mouse button states
-
-        Returns:
-            Tuple of (observation, reward, done, truncated, info)
+        # Initialize LSTM hidden states for actor-critic
+        self.hx = torch.zeros(1, agent.actor_critic.lstm_dim, device=agent.device)
+        self.cx = torch.zeros(1, agent.actor_critic.lstm_dim, device=agent.device)
+
+    def switch_controller(self) -> None:
+        """Switch between AI and human control"""
+        self.is_human_player = not self.is_human_player
+        print(f"Switched to {'human' if self.is_human_player else 'AI'} control")
+
+    def str_control(self) -> str:
+        """Return control mode string"""
+        if self.human_input_override:
+            return "Human Override"
+        return "Human" if self.is_human_player else "AI"
+
+    @torch.no_grad()
+    def step_from_web_input(
+        self,
+        pressed_keys: Set[str],
+        mouse_x: float,
+        mouse_y: float,
+        l_click: bool,
+        r_click: bool,
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Dict[str, Any]]:
+        """
+        Step the environment with web input.
+        If AI mode is enabled, use AI inference. If human mode or override, use human input.
         """
+
         # Convert web keys to action names
         action_names = web_keys_to_csgo_action_names(pressed_keys)
 
-        # Create WebCSGOAction
+        # Create web CSGO action from input
        web_action = WebCSGOAction(
             key_names=action_names,
             mouse_x=mouse_x,
@@ -85,84 +71,52 @@ class WebPlayEnv:
             r_click=r_click
         )
 
-        # Convert to tensor format for the model
-        action_tensor = encode_web_csgo_action(web_action, self.agent.device)
-
-        # Step the environment with the action tensor
-        return self.step_with_tensor(action_tensor)
-
-    def step_with_tensor(self, action_tensor: Tensor) -> Tuple[Tensor, Tensor, bool, bool, Dict]:
-        """Step environment with pre-encoded action tensor"""
-        if self.is_human_player:
-            # Use human action
-            act = action_tensor.unsqueeze(0)  # Add batch dimension
+        # If we have human input override or in human mode, use human input
+        if self.human_input_override or self.is_human_player:
+            # Encode the web action to tensor format
+            action = encode_web_csgo_action(web_action, device=self.agent.device)
+
         else:
-            # Use AI agent action
-            with torch.no_grad():
-                act_logits, _ = self.agent.actor_critic.predict_act_value(self.obs.unsqueeze(0), self.hx_cx)
-                act = torch.distributions.Categorical(logits=act_logits).sample()
+            # AI mode - use the agent's actor-critic to predict the action
+            try:
+                # Get current observation (ensure it has batch dimension)
+                obs = self.obs
+                if obs.ndim == 3:  # CHW -> BCHW
+                    obs = obs.unsqueeze(0)
+
+                # Detach hidden states to prevent gradient tracking
+                self.hx = self.hx.detach()
+                self.cx = self.cx.detach()
+
+                # Get action logits and value from actor-critic
+                logits_act, value, (self.hx, self.cx) = self.agent.actor_critic.predict_act_value(obs, (self.hx, self.cx))
+
+                # Sample action from logits
+                action_dist = Categorical(logits=logits_act)
+                action = action_dist.sample()
+
+                # Convert to proper shape (remove batch dimension if needed)
+                if action.ndim > 0 and action.size(0) == 1:
+                    action = action.squeeze(0)
+
+            except Exception as e:
+                print(f"AI inference failed: {e}")
+                import traceback
+                traceback.print_exc()
+                # Fallback to human input if AI fails
+                action = encode_web_csgo_action(web_action, device=self.agent.device)
 
-        # Step environment
-        next_obs, rew, end, trunc, info = self.env.step(act)
+        # Step the environment with the chosen action
+        next_obs, rew, end, trunc, env_info = self.env.step(action)
 
-        # Handle episode completion
-        if end or trunc:
-            if self.recording_mode and self.rec_dataset is not None:
-                self.rec_dataset.save_episode()
-                print(f"Episode saved! Length: {len(self.buffer)}")
+        # Update internal state
+        self.obs = next_obs
+        self.t += 1
+
+        # Reset hidden states on episode end
+        if end.any() or trunc.any():
+            self.hx.zero_()
+            self.cx.zero_()
 
-        return next_obs[0], rew[0], end[0], trunc[0], info
-
-    def reset(self) -> Tuple[Tensor, Dict]:
-        """Reset the environment"""
-        self.obs = self.env.reset()[0]  # Get first observation from batch
-        self.t = 0
-        self.buffer = []
-
-        # Initialize actor-critic hidden state if using AI player
-        if hasattr(self.agent, 'actor_critic') and self.agent.actor_critic is not None:
-            self.hx_cx = (
-                torch.zeros(1, self.agent.actor_critic.lstm_dim, device=self.agent.device),
-                torch.zeros(1, self.agent.actor_critic.lstm_dim, device=self.agent.device)
-            )
-        else:
-            self.hx_cx = None
-
-        info = {"step": 0, "episode_return": 0}
-        return self.obs, info
-
-    def switch_controller(self) -> bool:
-        """Switch between human and AI control"""
-        self.is_human_player = not self.is_human_player
-        controller = "Human" if self.is_human_player else "AI"
-        print(f"Switched to {controller} control")
-        return True
-
-    def next_mode(self) -> bool:
-        """Switch control mode"""
-        return self.switch_controller()
-
-    def next_axis_1(self) -> bool:
-        """Placeholder for axis control"""
-        return False
-
-    def prev_axis_1(self) -> bool:
-        """Placeholder for axis control"""
-        return False
-
-    def next_axis_2(self) -> bool:
-        """Placeholder for axis control"""
-        return False
-
-    def prev_axis_2(self) -> bool:
-        """Placeholder for axis control"""
-        return False
-
-    def print_env(self) -> None:
-        """Print current environment info"""
-        print(f"> Environment: {self.env_name}")
-        print(f"> Controller: {'Human' if self.is_human_player else 'AI'}")
-
-    def str_control(self) -> str:
-        """Get control mode string"""
-        return "Human" if self.is_human_player else "AI"
+        # Return the step results
+        return next_obs, rew, end, trunc, env_info
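
Usage note for the rewritten class: a sketch of how a web handler might drive WebPlayEnv.step_from_web_input with browser-style key codes; the handler name and payload shape are assumptions, only the method signature and the tensor-valued end/trunc come from the diff above:

# hypothetical websocket/frame handler on the app.py side
def on_client_frame(play_env, payload: dict):
    obs, rew, end, trunc, info = play_env.step_from_web_input(
        pressed_keys=set(payload.get("keys", [])),   # e.g. {"KeyW", "KeyA"}
        mouse_x=float(payload.get("mouse_x", 0.0)),
        mouse_y=float(payload.get("mouse_y", 0.0)),
        l_click=bool(payload.get("l_click", False)),
        r_click=bool(payload.get("r_click", False)),
    )
    # end/trunc are tensors in the new API (hence the .any() checks in the class),
    # so reduce them before reporting episode state back to the client.
    return obs, bool(end.any() or trunc.any())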