Spaces:
Sleeping
Sleeping
Commit
·
745484d
1
Parent(s):
1d96a61
Fix initial bugs
Browse files- Dockerfile +9 -2
- app.py +29 -3
Dockerfile
CHANGED
|
@@ -19,12 +19,19 @@ RUN pip install --no-cache-dir -r requirements.txt
|
|
| 19 |
COPY . .
|
| 20 |
|
| 21 |
# Create necessary directories
|
| 22 |
-
RUN mkdir -p csgo/spawn config checkpoints cache
|
|
|
|
| 23 |
|
| 24 |
# Set environment variables
|
| 25 |
ENV PYTHONPATH=/app/src:/app
|
| 26 |
ENV CUDA_VISIBLE_DEVICES=""
|
| 27 |
-
ENV OMP_NUM_THREADS=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
# Expose port
|
| 30 |
EXPOSE 7860
|
|
|
|
| 19 |
COPY . .
|
| 20 |
|
| 21 |
# Create necessary directories
|
| 22 |
+
RUN mkdir -p csgo/spawn config checkpoints cache && \
|
| 23 |
+
mkdir -p /tmp/torch /tmp/huggingface /tmp/transformers
|
| 24 |
|
| 25 |
# Set environment variables
|
| 26 |
ENV PYTHONPATH=/app/src:/app
|
| 27 |
ENV CUDA_VISIBLE_DEVICES=""
|
| 28 |
+
ENV OMP_NUM_THREADS=2
|
| 29 |
+
ENV MKL_NUM_THREADS=2
|
| 30 |
+
|
| 31 |
+
# Set cache directories to writable locations
|
| 32 |
+
ENV TORCH_HOME=/tmp/torch
|
| 33 |
+
ENV HF_HOME=/tmp/huggingface
|
| 34 |
+
ENV TRANSFORMERS_CACHE=/tmp/transformers
|
| 35 |
|
| 36 |
# Expose port
|
| 37 |
EXPOSE 7860
|
app.py
CHANGED
|
@@ -9,6 +9,7 @@ import io
|
|
| 9 |
import json
|
| 10 |
import logging
|
| 11 |
import os
|
|
|
|
| 12 |
from pathlib import Path
|
| 13 |
from typing import Dict, List, Optional, Set
|
| 14 |
|
|
@@ -51,7 +52,21 @@ app = FastAPI(title="Diamond CSGO AI Player")
|
|
| 51 |
os.environ.setdefault("SDL_VIDEODRIVER", "dummy")
|
| 52 |
os.environ.setdefault("SDL_AUDIODRIVER", "dummy")
|
| 53 |
os.environ.setdefault("PYGAME_HIDE_SUPPORT_PROMPT", "1")
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
connected_clients: Set[WebSocket] = set()
|
| 56 |
|
| 57 |
class WebKeyMap:
|
|
@@ -117,9 +132,20 @@ class WebGameEngine:
|
|
| 117 |
|
| 118 |
model_url = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"
|
| 119 |
|
| 120 |
-
# Use torch.hub to download and load state dict
|
| 121 |
logger.info(f"Loading state dict from {model_url}")
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
self.download_progress = 60
|
| 125 |
self.loading_status = "Loading model weights into agent..."
|
|
|
|
| 9 |
import json
|
| 10 |
import logging
|
| 11 |
import os
|
| 12 |
+
import tempfile
|
| 13 |
from pathlib import Path
|
| 14 |
from typing import Dict, List, Optional, Set
|
| 15 |
|
|
|
|
| 52 |
os.environ.setdefault("SDL_VIDEODRIVER", "dummy")
|
| 53 |
os.environ.setdefault("SDL_AUDIODRIVER", "dummy")
|
| 54 |
os.environ.setdefault("PYGAME_HIDE_SUPPORT_PROMPT", "1")
|
| 55 |
+
|
| 56 |
+
# Fix OMP_NUM_THREADS for HF Spaces (must be positive integer)
|
| 57 |
+
if "OMP_NUM_THREADS" not in os.environ or not os.environ["OMP_NUM_THREADS"].isdigit():
|
| 58 |
+
os.environ["OMP_NUM_THREADS"] = "2"
|
| 59 |
+
|
| 60 |
+
# Set up proper cache directories for HF Spaces
|
| 61 |
+
temp_dir = tempfile.gettempdir()
|
| 62 |
+
os.environ.setdefault("TORCH_HOME", os.path.join(temp_dir, "torch"))
|
| 63 |
+
os.environ.setdefault("HF_HOME", os.path.join(temp_dir, "huggingface"))
|
| 64 |
+
os.environ.setdefault("TRANSFORMERS_CACHE", os.path.join(temp_dir, "transformers"))
|
| 65 |
+
|
| 66 |
+
# Create cache directories
|
| 67 |
+
for cache_var in ["TORCH_HOME", "HF_HOME", "TRANSFORMERS_CACHE"]:
|
| 68 |
+
cache_path = os.environ[cache_var]
|
| 69 |
+
os.makedirs(cache_path, exist_ok=True)
|
| 70 |
connected_clients: Set[WebSocket] = set()
|
| 71 |
|
| 72 |
class WebKeyMap:
|
|
|
|
| 132 |
|
| 133 |
model_url = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"
|
| 134 |
|
| 135 |
+
# Use torch.hub to download and load state dict with custom cache dir
|
| 136 |
logger.info(f"Loading state dict from {model_url}")
|
| 137 |
+
|
| 138 |
+
# Set custom cache directory that we have write permissions for
|
| 139 |
+
cache_dir = os.path.join(tempfile.gettempdir(), "torch_cache")
|
| 140 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 141 |
+
|
| 142 |
+
# Use torch.hub with custom cache directory
|
| 143 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
| 144 |
+
model_url,
|
| 145 |
+
map_location=device,
|
| 146 |
+
model_dir=cache_dir,
|
| 147 |
+
check_hash=False # Skip hash check for faster loading
|
| 148 |
+
)
|
| 149 |
|
| 150 |
self.download_progress = 60
|
| 151 |
self.loading_status = "Loading model weights into agent..."
|