Spaces:
Sleeping
Sleeping
Commit
·
b8159f9
1
Parent(s):
f1594be
Fix bug 10
Browse files- app.py +12 -1
- src/game/__pycache__/web_play_env.cpython-310.pyc +0 -0
- src/game/dataset_env.py +9 -9
- src/game/web_play_env.py +32 -0
app.py
CHANGED
|
@@ -727,6 +727,10 @@ async def get_homepage():
|
|
| 727 |
ws.onopen = function(event) {
|
| 728 |
statusEl.textContent = 'Connected';
|
| 729 |
statusEl.style.color = '#00ff00';
|
|
|
|
|
|
|
|
|
|
|
|
|
| 730 |
};
|
| 731 |
|
| 732 |
ws.onmessage = function(event) {
|
|
@@ -857,7 +861,14 @@ async def get_homepage():
|
|
| 857 |
canvas.addEventListener('click', () => {
|
| 858 |
canvas.focus();
|
| 859 |
if (!gameStarted) {
|
| 860 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 861 |
}
|
| 862 |
});
|
| 863 |
|
|
|
|
| 727 |
ws.onopen = function(event) {
|
| 728 |
statusEl.textContent = 'Connected';
|
| 729 |
statusEl.style.color = '#00ff00';
|
| 730 |
+
// If user already clicked to start before WS was ready, send start now
|
| 731 |
+
if (gameStarted) {
|
| 732 |
+
ws.send(JSON.stringify({ type: 'start' }));
|
| 733 |
+
}
|
| 734 |
};
|
| 735 |
|
| 736 |
ws.onmessage = function(event) {
|
|
|
|
| 861 |
canvas.addEventListener('click', () => {
|
| 862 |
canvas.focus();
|
| 863 |
if (!gameStarted) {
|
| 864 |
+
// Queue start locally and send immediately if WS is open
|
| 865 |
+
gameStarted = true;
|
| 866 |
+
gameStatusEl.textContent = 'Starting AI...';
|
| 867 |
+
gameStatusEl.style.color = '#ffff00';
|
| 868 |
+
loadingEl.style.display = 'block';
|
| 869 |
+
if (ws && ws.readyState === WebSocket.OPEN) {
|
| 870 |
+
ws.send(JSON.stringify({ type: 'start' }));
|
| 871 |
+
}
|
| 872 |
}
|
| 873 |
});
|
| 874 |
|
src/game/__pycache__/web_play_env.cpython-310.pyc
CHANGED
|
Binary files a/src/game/__pycache__/web_play_env.cpython-310.pyc and b/src/game/__pycache__/web_play_env.cpython-310.pyc differ
|
|
|
src/game/dataset_env.py
CHANGED
|
@@ -77,15 +77,15 @@ class DatasetEnv:
|
|
| 77 |
|
| 78 |
@torch.no_grad()
|
| 79 |
def step(self, act: int) -> Tuple[Tensor, Tensor, bool, bool, Dict[str, Any]]:
|
| 80 |
-
match
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
|
| 90 |
n_digits = len(str(self.ep_length))
|
| 91 |
|
|
|
|
| 77 |
|
| 78 |
@torch.no_grad()
|
| 79 |
def step(self, act: int) -> Tuple[Tensor, Tensor, bool, bool, Dict[str, Any]]:
|
| 80 |
+
# Replaced Python 3.10 `match` statement with if/elif chain for Python 3.8/3.9 compatibility
|
| 81 |
+
if act == 1:
|
| 82 |
+
self.set_timestep(self.t - 1)
|
| 83 |
+
elif act == 2:
|
| 84 |
+
self.set_timestep(self.t + 1)
|
| 85 |
+
elif act == 3:
|
| 86 |
+
self.set_timestep(self.t - 10)
|
| 87 |
+
elif act == 4:
|
| 88 |
+
self.set_timestep(self.t + 10)
|
| 89 |
|
| 90 |
n_digits = len(str(self.ep_length))
|
| 91 |
|
src/game/web_play_env.py
CHANGED
|
@@ -6,6 +6,8 @@ from typing import Any, Dict, List, Set, Tuple
|
|
| 6 |
import torch
|
| 7 |
from torch import Tensor
|
| 8 |
from torch.distributions.categorical import Categorical
|
|
|
|
|
|
|
| 9 |
|
| 10 |
from ..agent import Agent
|
| 11 |
from ..envs import WorldModelEnv
|
|
@@ -71,6 +73,14 @@ class WebPlayEnv(PlayEnv):
|
|
| 71 |
r_click=r_click
|
| 72 |
)
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
# If we have human input override or in human mode, use human input
|
| 75 |
if self.human_input_override or self.is_human_player:
|
| 76 |
# Encode the web action to tensor format
|
|
@@ -83,11 +93,33 @@ class WebPlayEnv(PlayEnv):
|
|
| 83 |
obs = self.obs
|
| 84 |
if obs.ndim == 3: # CHW -> BCHW
|
| 85 |
obs = obs.unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
# Detach hidden states to prevent gradient tracking
|
| 88 |
self.hx = self.hx.detach()
|
| 89 |
self.cx = self.cx.detach()
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
# Get action logits and value from actor-critic
|
| 92 |
logits_act, value, (self.hx, self.cx) = self.agent.actor_critic.predict_act_value(obs, (self.hx, self.cx))
|
| 93 |
|
|
|
|
| 6 |
import torch
|
| 7 |
from torch import Tensor
|
| 8 |
from torch.distributions.categorical import Categorical
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
|
| 12 |
from ..agent import Agent
|
| 13 |
from ..envs import WorldModelEnv
|
|
|
|
| 73 |
r_click=r_click
|
| 74 |
)
|
| 75 |
|
| 76 |
+
# Ensure we have a valid observation; if not, reset the environment
|
| 77 |
+
if self.obs is None:
|
| 78 |
+
try:
|
| 79 |
+
self.obs, _ = self.reset()
|
| 80 |
+
except Exception:
|
| 81 |
+
# If reset fails, fall back to human input below
|
| 82 |
+
pass
|
| 83 |
+
|
| 84 |
# If we have human input override or in human mode, use human input
|
| 85 |
if self.human_input_override or self.is_human_player:
|
| 86 |
# Encode the web action to tensor format
|
|
|
|
| 93 |
obs = self.obs
|
| 94 |
if obs.ndim == 3: # CHW -> BCHW
|
| 95 |
obs = obs.unsqueeze(0)
|
| 96 |
+
# Ensure obs is on the same device as the models
|
| 97 |
+
if obs.device != self.agent.device:
|
| 98 |
+
obs = obs.to(self.agent.device, non_blocking=True)
|
| 99 |
|
| 100 |
# Detach hidden states to prevent gradient tracking
|
| 101 |
self.hx = self.hx.detach()
|
| 102 |
self.cx = self.cx.detach()
|
| 103 |
|
| 104 |
+
# Resize observation to match actor-critic expected encoder/LSTM input
|
| 105 |
+
# Count how many MaxPool2d layers are in the encoder to infer downsampling factor
|
| 106 |
+
if hasattr(self.agent, "actor_critic") and self.agent.actor_critic is not None:
|
| 107 |
+
try:
|
| 108 |
+
n_pools = sum(
|
| 109 |
+
1 for m in self.agent.actor_critic.encoder.encoder if isinstance(m, nn.MaxPool2d)
|
| 110 |
+
)
|
| 111 |
+
# We want the spatial size after the encoder to be 1x1 so that
|
| 112 |
+
# flattening matches the LSTM input size configured at init time.
|
| 113 |
+
# With n_pools halvings, input size must be 2**n_pools.
|
| 114 |
+
target_hw = 2 ** n_pools if n_pools > 0 else min(int(obs.size(-2)), int(obs.size(-1)))
|
| 115 |
+
if obs.size(-2) != target_hw or obs.size(-1) != target_hw:
|
| 116 |
+
obs = F.interpolate(
|
| 117 |
+
obs, size=(target_hw, target_hw), mode="bilinear", align_corners=False
|
| 118 |
+
)
|
| 119 |
+
except Exception:
|
| 120 |
+
# If anything goes wrong in the shape logic, fall back without resizing
|
| 121 |
+
pass
|
| 122 |
+
|
| 123 |
# Get action logits and value from actor-critic
|
| 124 |
logits_act, value, (self.hx, self.cx) = self.agent.actor_critic.predict_act_value(obs, (self.hx, self.cx))
|
| 125 |
|