musictimer committed
Commit: b8159f9
Parent(s): f1594be

Fix bug 10

app.py CHANGED

@@ -727,6 +727,10 @@ async def get_homepage():
             ws.onopen = function(event) {
                 statusEl.textContent = 'Connected';
                 statusEl.style.color = '#00ff00';
+                // If user already clicked to start before WS was ready, send start now
+                if (gameStarted) {
+                    ws.send(JSON.stringify({ type: 'start' }));
+                }
             };
 
             ws.onmessage = function(event) {
@@ -857,7 +861,14 @@ async def get_homepage():
             canvas.addEventListener('click', () => {
                 canvas.focus();
                 if (!gameStarted) {
-                    startGame();
+                    // Queue start locally and send immediately if WS is open
+                    gameStarted = true;
+                    gameStatusEl.textContent = 'Starting AI...';
+                    gameStatusEl.style.color = '#ffff00';
+                    loadingEl.style.display = 'block';
+                    if (ws && ws.readyState === WebSocket.OPEN) {
+                        ws.send(JSON.stringify({ type: 'start' }));
+                    }
                 }
             });
 
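This change fixes a race between the canvas click handler and the WebSocket handshake: the click now only marks gameStarted and updates the UI, and whichever side becomes ready last (the click handler if the socket is already open, otherwise ws.onopen) sends the single { type: 'start' } message. The server-side handler is not part of this diff; the snippet below is a minimal sketch, assuming a FastAPI WebSocket endpoint at a hypothetical /ws route, of how such a message could be consumed.

import json
from fastapi import FastAPI, WebSocket, WebSocketDisconnect

app = FastAPI()

# Sketch only: the route name and response payload are assumptions, not taken from this commit.
@app.websocket("/ws")
async def game_ws(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            msg = json.loads(await websocket.receive_text())
            if msg.get("type") == "start":
                # The server sees the same message whether the client sent it from the
                # click handler or deferred it to ws.onopen.
                await websocket.send_text(json.dumps({"type": "status", "text": "AI starting"}))
    except WebSocketDisconnect:
        pass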
src/game/__pycache__/web_play_env.cpython-310.pyc CHANGED
Binary files a/src/game/__pycache__/web_play_env.cpython-310.pyc and b/src/game/__pycache__/web_play_env.cpython-310.pyc differ
 
src/game/dataset_env.py CHANGED

@@ -77,15 +77,15 @@ class DatasetEnv:
 
     @torch.no_grad()
     def step(self, act: int) -> Tuple[Tensor, Tensor, bool, bool, Dict[str, Any]]:
-        match act:
-            case 1:
-                self.set_timestep(self.t - 1)
-            case 2:
-                self.set_timestep(self.t + 1)
-            case 3:
-                self.set_timestep(self.t - 10)
-            case 4:
-                self.set_timestep(self.t + 10)
+        # Replaced Python 3.10 `match` statement with if/elif chain for Python 3.8/3.9 compatibility
+        if act == 1:
+            self.set_timestep(self.t - 1)
+        elif act == 2:
+            self.set_timestep(self.t + 1)
+        elif act == 3:
+            self.set_timestep(self.t - 10)
+        elif act == 4:
+            self.set_timestep(self.t + 10)
 
         n_digits = len(str(self.ep_length))
 
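The match/case syntax is only valid on Python 3.10+, so merely importing dataset_env.py on 3.8 or 3.9 raises a SyntaxError before step() ever runs; the if/elif chain is equivalent and parses on every supported version. A slightly more compact alternative (a sketch, not what this commit does) is a table-driven dispatch from action id to timestep delta:

# Alternative sketch only; set_timestep and t mirror the existing DatasetEnv attributes.
ACTION_DELTAS = {1: -1, 2: +1, 3: -10, 4: +10}

def apply_action(env, act: int) -> None:
    delta = ACTION_DELTAS.get(act)
    if delta is not None:
        env.set_timestep(env.t + delta)  # same effect as the if/elif chain above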
src/game/web_play_env.py CHANGED

@@ -6,6 +6,8 @@ from typing import Any, Dict, List, Set, Tuple
 import torch
 from torch import Tensor
 from torch.distributions.categorical import Categorical
+import torch.nn as nn
+import torch.nn.functional as F
 
 from ..agent import Agent
 from ..envs import WorldModelEnv
@@ -71,6 +73,14 @@ class WebPlayEnv(PlayEnv):
             r_click=r_click
         )
 
+        # Ensure we have a valid observation; if not, reset the environment
+        if self.obs is None:
+            try:
+                self.obs, _ = self.reset()
+            except Exception:
+                # If reset fails, fall back to human input below
+                pass
+
         # If we have human input override or in human mode, use human input
         if self.human_input_override or self.is_human_player:
             # Encode the web action to tensor format
@@ -83,11 +93,33 @@ class WebPlayEnv(PlayEnv):
             obs = self.obs
             if obs.ndim == 3:  # CHW -> BCHW
                 obs = obs.unsqueeze(0)
+            # Ensure obs is on the same device as the models
+            if obs.device != self.agent.device:
+                obs = obs.to(self.agent.device, non_blocking=True)
 
             # Detach hidden states to prevent gradient tracking
             self.hx = self.hx.detach()
             self.cx = self.cx.detach()
 
+            # Resize observation to match actor-critic expected encoder/LSTM input
+            # Count how many MaxPool2d layers are in the encoder to infer downsampling factor
+            if hasattr(self.agent, "actor_critic") and self.agent.actor_critic is not None:
+                try:
+                    n_pools = sum(
+                        1 for m in self.agent.actor_critic.encoder.encoder if isinstance(m, nn.MaxPool2d)
+                    )
+                    # We want the spatial size after the encoder to be 1x1 so that
+                    # flattening matches the LSTM input size configured at init time.
+                    # With n_pools halvings, input size must be 2**n_pools.
+                    target_hw = 2 ** n_pools if n_pools > 0 else min(int(obs.size(-2)), int(obs.size(-1)))
+                    if obs.size(-2) != target_hw or obs.size(-1) != target_hw:
+                        obs = F.interpolate(
+                            obs, size=(target_hw, target_hw), mode="bilinear", align_corners=False
+                        )
+                except Exception:
+                    # If anything goes wrong in the shape logic, fall back without resizing
+                    pass
+
             # Get action logits and value from actor-critic
             logits_act, value, (self.hx, self.cx) = self.agent.actor_critic.predict_act_value(obs, (self.hx, self.cx))
 
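The resize block relies on each MaxPool2d halving the spatial resolution, so an encoder containing n pooling layers needs a 2**n input for its output to reach 1x1 before flattening into the LSTM. The standalone sketch below illustrates that arithmetic with a toy three-pool encoder (not the project's actual network, which lives in agent.actor_critic.encoder):

# Illustrative toy encoder only, to show the arithmetic the commit relies on.
import torch
import torch.nn as nn
import torch.nn.functional as F

encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
)

n_pools = sum(1 for m in encoder if isinstance(m, nn.MaxPool2d))
target_hw = 2 ** n_pools  # 8 for three pooling layers

obs = torch.rand(1, 3, 64, 64)  # a frame arriving at the "wrong" resolution
obs = F.interpolate(obs, size=(target_hw, target_hw), mode="bilinear", align_corners=False)
print(encoder(obs).shape)  # torch.Size([1, 64, 1, 1]) -> flattens to 64 features for the LSTM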