musictimer committed
Commit 052d6f4 · Parent: 5a63400

Fix initial bugs

app.py CHANGED
@@ -28,7 +28,7 @@ import io
 import json
 import logging
 from pathlib import Path
-from typing import Dict, List, Optional, Set
+from typing import Dict, List, Optional, Set, Tuple
 
 import cv2
 import numpy as np
@@ -101,8 +101,13 @@ class WebGameEngine:
         self.obs = None
         self.running = False
         self.game_started = False
-        self.fps = 30  # Display FPS
-        self.ai_fps = 10  # AI inference FPS (slower than display for efficiency)
+        # Allow runtime tuning via environment variables
+        import os
+        self.fps = int(os.getenv("DISPLAY_FPS", "30"))  # Display FPS
+        # Increase default AI inference FPS; can be overridden with AI_FPS env var
+        self.ai_fps = int(os.getenv("AI_FPS", "15"))
+        # Send every Nth frame to the browser (1 = send all frames)
+        self.send_every = int(os.getenv("DISPLAY_SKIP", "1"))
         self.frame_count = 0
         self.ai_frame_count = 0
         self.last_ai_time = 0
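One caveat with this hunk: `int(os.getenv(...))` raises `ValueError` when the variable is set to a non-numeric value, which would crash `__init__`. A more defensive variant of the same tuning pattern (a sketch; `env_int` is not part of the commit):

import os

def env_int(name: str, default: int, minimum: int = 1) -> int:
    # Read an integer env var; fall back to `default` on a missing or malformed value
    try:
        return max(int(os.getenv(name, default)), minimum)
    except ValueError:
        return default

display_fps = env_int("DISPLAY_FPS", 30)  # browser-facing frame rate
ai_fps = env_int("AI_FPS", 15)            # world-model inference rate
send_every = env_int("DISPLAY_SKIP", 1)   # send every Nth frame to clients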
@@ -121,6 +126,13 @@ class WebGameEngine:
         self.actor_critic_loaded = False  # Track if actor_critic was loaded with trained weights
         import time
         self.time_module = time
+
+        # Async inference queues to decouple GPU work from websocket I/O
+        import asyncio
+        self._in_queue: asyncio.Queue = asyncio.Queue(maxsize=1)
+        self._out_queue: asyncio.Queue = asyncio.Queue(maxsize=1)
+        # Flag to start worker once models are ready
+        self._worker_started = False
 
     async def _load_model_from_url_async(self, agent, device):
         """Load model from URL using torch.hub (HF Spaces compatible)"""
@@ -295,11 +307,10 @@ class WebGameEngine:
             self.play_env.is_human_player = True
             logger.info("WebPlayEnv set to human control mode (fallback)")
 
-        # Model compilation causes 10-30s delay on first inference, so make it optional
-        # You can enable it by setting ENABLE_TORCH_COMPILE=1 environment variable
+        # Enable torch.compile by default like play.py does (can disable with DISABLE_TORCH_COMPILE=1)
         import os
-        if device.type == "cuda" and os.getenv("ENABLE_TORCH_COMPILE", "0") == "1":
-            logger.info("Compiling models for faster inference (will cause delay on first inference)...")
+        if device.type == "cuda" and os.getenv("DISABLE_TORCH_COMPILE", "0") != "1":
+            logger.info("Compiling models for faster inference (like play.py --compile)...")
             try:
                 wm_env.predict_next_obs = torch.compile(wm_env.predict_next_obs, mode="reduce-overhead")
                 if wm_env.upsample_next_obs is not None:
@@ -308,7 +319,7 @@ class WebGameEngine:
             except Exception as e:
                 logger.warning(f"Model compilation failed: {e}")
         else:
-            logger.info("Model compilation disabled (faster startup). Set ENABLE_TORCH_COMPILE=1 to enable.")
+            logger.info("Model compilation disabled. Set DISABLE_TORCH_COMPILE=0 to enable.")
 
         # Reset environment
         self.obs, _ = self.play_env.reset()
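The new gating flips the default: compilation is on for CUDA unless explicitly disabled. A compact sketch of the same opt-out pattern with a graceful eager fallback (a hedged illustration, not the commit's exact code):

import os
import torch

def maybe_compile(fn, device: torch.device):
    # Skip compilation off-GPU or when DISABLE_TORCH_COMPILE=1 is set
    if device.type != "cuda" or os.getenv("DISABLE_TORCH_COMPILE", "0") == "1":
        return fn
    try:
        # "reduce-overhead" targets small, frequently repeated calls (CUDA graphs)
        return torch.compile(fn, mode="reduce-overhead")
    except Exception:
        return fn  # fall back to eager execution if compilation fails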
@@ -377,27 +388,27 @@ class WebGameEngine:
             self.last_ai_time = self.time_module.time()  # Reset AI timer
             return self.obs, 0.0, False, False, {"reset": True}
 
-        # Intelligent frame skipping: only run AI inference at target FPS
         current_time = self.time_module.time()
+
+        # Push task to inference queue if needed
         time_since_last_ai = current_time - self.last_ai_time
         should_run_ai = time_since_last_ai >= (1.0 / self.ai_fps)
-
-        if should_run_ai:
-            # Show loading indicator for first inference (can be slow)
-            if not self.first_inference_done:
-                logger.info("Running first AI inference (may take 5-15 seconds)...")
-
-            # Run AI inference
-            inference_start = self.time_module.time()
-            next_obs, reward, done, truncated, info = self.play_env.step_from_web_input(
-                pressed_keys=self.pressed_keys,
+
+        if should_run_ai and self._in_queue.empty():
+            # Snapshot web input state
+            web_state = dict(
+                pressed_keys=set(self.pressed_keys),
                 mouse_x=self.mouse_x,
                 mouse_y=self.mouse_y,
                 l_click=self.l_click,
-                r_click=self.r_click
+                r_click=self.r_click,
             )
-            inference_time = self.time_module.time() - inference_start
-
+            asyncio.create_task(self._in_queue.put((self.obs, web_state)))
+
+        # Check for completed inference
+        if not self._out_queue.empty():
+            (next_obs, reward, done, truncated, info, inference_time) = self._out_queue.get_nowait()
+
             # Log first inference completion
             if not self.first_inference_done:
                 self.first_inference_done = True
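Note that checking `self._in_queue.empty()` and then putting via `asyncio.create_task` is not atomic; with a single producer it works, but `put_nowait`/`get_nowait` express the same non-blocking intent directly. A sketch of that alternative (not what the commit uses):

import asyncio

def try_submit(in_q: asyncio.Queue, item) -> bool:
    # Non-blocking submit: drop the tick if the worker is still busy
    try:
        in_q.put_nowait(item)
        return True
    except asyncio.QueueFull:
        return False

def try_collect(out_q: asyncio.Queue):
    # Non-blocking poll: return a finished result, or None if not ready yet
    try:
        return out_q.get_nowait()
    except asyncio.QueueEmpty:
        return None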
@@ -489,23 +500,80 @@ class WebGameEngine:
             img = Image.fromarray(img_array)
 
             # Resize for web display to match canvas size (optimized)
-            img = img.resize((600, 150), Image.NEAREST)  # NEAREST is faster than BICUBIC
-
-            # Optimized base64 conversion with JPEG for better compression/speed
-            buffer = io.BytesIO()
-            img.save(buffer, format='JPEG', quality=85, optimize=True)  # JPEG is faster than PNG
-            img_str = base64.b64encode(buffer.getvalue()).decode()
-            return f"data:image/jpeg;base64,{img_str}"
+            img = img.resize((600, 150), Image.NEAREST)
+
+            # Choose codec via env var for flexibility (jpeg|png)
+            codec = os.getenv("IMG_CODEC", "jpeg").lower()
+            img_np = np.array(img)[:, :, ::-1]  # RGB -> BGR
+            if codec == "png":
+                success, encoded_img = cv2.imencode('.png', img_np, [cv2.IMWRITE_PNG_COMPRESSION, 1])
+                mime = 'png'
+            else:
+                # JPEG with quality 70 for speed/size balance
+                success, encoded_img = cv2.imencode('.jpg', img_np, [cv2.IMWRITE_JPEG_QUALITY, 70])
+                mime = 'jpeg'
+            if not success:
+                return ""
+            img_str = base64.b64encode(encoded_img).decode()
+            return f"data:image/{mime};base64,{img_str}"
 
         except Exception as e:
             logger.error(f"Error converting observation to base64: {e}")
             return ""
+
+    # ------------------------------------------------------------------
+    # Faster binary encoder (JPEG/PNG) with OpenCV – no Pillow involved
+    # ------------------------------------------------------------------
+    def obs_to_bytes(self, obs: torch.Tensor) -> Tuple[bytes, str]:
+        """Return encoded image bytes and MIME (image/jpeg or image/png)."""
+        if obs is None:
+            return b"", "image/jpeg"
+
+        try:
+            # Keep operations on GPU as long as possible (like play.py)
+            if obs.ndim == 4 and obs.size(0) == 1:
+                img_tensor = obs[0]
+            else:
+                img_tensor = obs
+
+            # Resize on GPU first (faster than CPU resize)
+            img_tensor = torch.nn.functional.interpolate(
+                img_tensor.unsqueeze(0), size=(75, 300), mode='nearest'
+            ).squeeze(0)
+
+            # Convert to uint8 on GPU, then transfer to CPU once
+            img_np = (img_tensor.add(1).mul(127.5).clamp(0, 255).byte()
+                      .permute(1, 2, 0).contiguous().cpu().numpy())  # HWC uint8
+
+            # Encode with OpenCV
+            import os
+            codec = os.getenv("IMG_CODEC", "jpeg").lower()
+            if codec == "png":
+                ok, enc = cv2.imencode('.png', img_np, [cv2.IMWRITE_PNG_COMPRESSION, 1])
+                mime = "image/png"
+            else:
+                ok, enc = cv2.imencode('.jpg', img_np, [cv2.IMWRITE_JPEG_QUALITY, 75])
+                mime = "image/jpeg"
+            if not ok:
+                return b"", mime
+            return enc.tobytes(), mime
+        except Exception as e:
+            logger.error(f"obs_to_bytes error: {e}")
+            return b"", "image/jpeg"
 
     async def game_loop(self):
         """Main game loop that runs continuously"""
         self.running = True
-
+        # Start inference worker once, when models are ready
+
         while self.running:
+            loop_start_time = self.time_module.time()
+
+            # Spawn worker lazily after models initialized
+            if self.models_ready and not self._worker_started:
+                asyncio.create_task(self._inference_worker())
+                self._worker_started = True
+
             try:
                 # Check if models are ready
                 if not self.models_ready:
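The encoder change swaps Pillow's in-memory save for cv2.imencode, which avoids the BytesIO round-trip. The conversion in isolation, as a hedged standalone sketch (the function name and the random test frame are illustrative):

import base64
import cv2
import numpy as np

def rgb_to_data_url(rgb: np.ndarray, quality: int = 70) -> str:
    # OpenCV expects BGR channel order and a contiguous array
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    ok, buf = cv2.imencode('.jpg', bgr, [cv2.IMWRITE_JPEG_QUALITY, quality])
    if not ok:
        return ""
    return "data:image/jpeg;base64," + base64.b64encode(buf).decode()

frame = np.random.randint(0, 256, (150, 600, 3), dtype=np.uint8)
print(rgb_to_data_url(frame)[:60])  # "data:image/jpeg;base64,..."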
@@ -554,50 +622,104 @@ class WebGameEngine:
                 self.obs = next_obs
 
                 # Send frame to all connected clients (regardless of game state)
-                if should_send_frame and connected_clients and self.obs is not None:
+                if should_send_frame and connected_clients and self.obs is not None and (self.frame_count % self.send_every == 0):
                     # Set default values for when game isn't running
                     if not self.game_started:
                         reward = 0.0
                         info = {"waiting": True}
                     # If game is started, reward and info should be set above
 
-                    # Convert observation to base64
-                    image_data = self.obs_to_base64(self.obs)
-
-                    # Debug logging for first few frames
-                    if self.frame_count < 5:
-                        logger.info(f"Frame {self.frame_count}: obs shape={self.obs.shape if self.obs is not None else 'None'}, "
-                                    f"image_data_length={len(image_data) if image_data else 0}, "
-                                    f"game_started={self.game_started}")
-
-                    frame_data = {
-                        'type': 'frame',
-                        'image': image_data,
-                        'frame_count': self.frame_count,
-                        'reward': float(reward.item()) if hasattr(reward, 'item') else float(reward) if reward is not None else 0.0,
-                        'info': str(info) if info else "",
-                        'ai_fps': info.get('ai_fps', 0) if isinstance(info, dict) else 0,
-                        'is_ai_frame': info.get('ai_inference', False) if isinstance(info, dict) else False
-                    }
-
-                    # Send to all connected clients
-                    disconnected = set()
-                    for client in connected_clients.copy():
-                        try:
-                            await client.send_text(json.dumps(frame_data))
-                        except:
-                            disconnected.add(client)
-
-                    # Remove disconnected clients
-                    connected_clients.difference_update(disconnected)
+                    # Prefer binary frames if client agrees (feature flag)
+                    use_binary = os.getenv("BINARY_WS", "0") == "1"
+
+                    if use_binary:
+                        img_bytes, mime = self.obs_to_bytes(self.obs)
+                        meta = {
+                            'type': 'frame_meta',
+                            'mime': mime,
+                            'frame_count': self.frame_count,
+                            'reward': float(reward.item()) if hasattr(reward, 'item') else float(reward) if reward is not None else 0.0,
+                            'info': str(info) if info else "",
+                            'ai_fps': info.get('ai_fps', 0) if isinstance(info, dict) else 0,
+                            'is_ai_frame': info.get('ai_inference', False) if isinstance(info, dict) else False
+                        }
+                        disconnected = set()
+                        for client in connected_clients.copy():
+                            try:
+                                await client.send_text(json.dumps(meta))
+                                await client.send_bytes(img_bytes)
+                            except:
+                                disconnected.add(client)
+                        connected_clients.difference_update(disconnected)
+                    else:
+                        # Fallback to base64 JSON
+                        image_data = self.obs_to_base64(self.obs)
+
+                        if self.frame_count < 5:
+                            logger.info(
+                                f"Frame {self.frame_count}: base64_len={len(image_data)} ai={info.get('ai_fps',0):.1f}")
+
+                        frame_data = {
+                            'type': 'frame',
+                            'image': image_data,
+                            'frame_count': self.frame_count,
+                            'reward': float(reward.item()) if hasattr(reward, 'item') else float(reward) if reward is not None else 0.0,
+                            'info': str(info) if info else "",
+                            'ai_fps': info.get('ai_fps', 0) if isinstance(info, dict) else 0,
+                            'is_ai_frame': info.get('ai_inference', False) if isinstance(info, dict) else False
+                        }
+
+                        disconnected = set()
+                        for client in connected_clients.copy():
+                            try:
+                                await client.send_text(json.dumps(frame_data))
+                            except:
+                                disconnected.add(client)
+                        connected_clients.difference_update(disconnected)
 
                 self.frame_count += 1
-                await asyncio.sleep(1.0 / self.fps)  # Control FPS
+
+                # Adaptive sleep so we don't waste idle time when GPU faster than display FPS
+                loop_elapsed = self.time_module.time() - loop_start_time
+                sleep_for = max((1.0 / self.fps) - loop_elapsed, 0)
+                if sleep_for:
+                    await asyncio.sleep(sleep_for)
 
             except Exception as e:
                 logger.error(f"Error in game loop: {e}")
                 await asyncio.sleep(0.1)
 
+    async def _inference_worker(self):
+        """Runs AI inference in background to avoid blocking I/O."""
+        logger.info("Inference worker started")
+        next_inference_time = self.time_module.time()
+
+        while True:
+            obs, web_state = await self._in_queue.get()
+
+            # Timing control: maintain steady AI_FPS like play.py's clock.tick()
+            now = self.time_module.time()
+            if now < next_inference_time:
+                await asyncio.sleep(next_inference_time - now)
+            next_inference_time += 1.0 / self.ai_fps
+
+            # Run inference directly in asyncio (not thread pool) with autocast for speed
+            try:
+                start = self.time_module.time()
+
+                # Use FP16 autocast for faster inference (like play.py can do with modern GPUs)
+                from torch.cuda.amp import autocast
+                with autocast(dtype=torch.float16, enabled=torch.cuda.is_available()):
+                    res = self.play_env.step_from_web_input(**web_state)
+
+                infer_t = self.time_module.time() - start
+                await self._out_queue.put((*res, infer_t))
+            except Exception as e:
+                logger.error(f"Inference worker error: {e}")
+                # Put a dummy result to avoid hanging
+                dummy_obs = self.obs if self.obs is not None else torch.zeros(3, 150, 600)
+                await self._out_queue.put((dummy_obs, 0.0, False, False, {"error": str(e)}, 0.0))
+
 # Global game engine instance
 game_engine = WebGameEngine()
 
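In binary mode each frame arrives as a text message (frame_meta JSON) followed by one binary message carrying the encoded image. A minimal client sketch of that two-message protocol using the third-party websockets package (the URL and endpoint path are assumptions, not confirmed by the commit):

import asyncio
import json
import websockets  # pip install websockets

async def receive_frames(url: str = "ws://localhost:7860/ws") -> None:
    async with websockets.connect(url) as ws:
        meta = None
        async for message in ws:
            if isinstance(message, str):  # JSON metadata frame arrives first
                msg = json.loads(message)
                if msg.get("type") == "frame_meta":
                    meta = msg
            elif meta is not None:        # then the raw encoded image bytes
                print(f"frame {meta['frame_count']}: {len(message)} bytes ({meta['mime']})")
                meta = None

asyncio.run(receive_frames())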
 
src/csgo/spawn/0/act.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4170e93000886d1dfd379ed48f1c360897c80700db20996ea4cd5ba1464423eb
+size 208
src/csgo/spawn/0/full_res.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4170e93000886d1dfd379ed48f1c360897c80700db20996ea4cd5ba1464423eb
+size 208
src/csgo/spawn/0/info.json ADDED
@@ -0,0 +1 @@
+{"dummy": true}
src/csgo/spawn/0/low_res.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4170e93000886d1dfd379ed48f1c360897c80700db20996ea4cd5ba1464423eb
+size 208
src/csgo/spawn/0/next_act.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4170e93000886d1dfd379ed48f1c360897c80700db20996ea4cd5ba1464423eb
+size 208
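Each of the .npy files above is a Git LFS pointer: three "key value" lines standing in for the real binary, which lives in LFS storage. A tiny parser sketch for the pointer format (the path is illustrative):

from pathlib import Path

def parse_lfs_pointer(path: str) -> dict:
    # Split each "key value" line of a Git LFS pointer file
    fields = {}
    for line in Path(path).read_text().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

# parse_lfs_pointer("src/csgo/spawn/0/act.npy")
# -> {'version': 'https://git-lfs.github.com/spec/v1', 'oid': 'sha256:4170...', 'size': '208'}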
src/envs/__pycache__/world_model_env.cpython-310.pyc CHANGED
Binary files a/src/envs/__pycache__/world_model_env.cpython-310.pyc and b/src/envs/__pycache__/world_model_env.cpython-310.pyc differ
 
src/game/__pycache__/web_play_env.cpython-310.pyc CHANGED
Binary files a/src/game/__pycache__/web_play_env.cpython-310.pyc and b/src/game/__pycache__/web_play_env.cpython-310.pyc differ