musictimer commited on
Commit
745484d
·
1 Parent(s): 1d96a61

Fix initial bugs

Browse files
Files changed (2) hide show
  1. Dockerfile +9 -2
  2. app.py +29 -3
Dockerfile CHANGED
@@ -19,12 +19,19 @@ RUN pip install --no-cache-dir -r requirements.txt
19
  COPY . .
20
 
21
  # Create necessary directories
22
- RUN mkdir -p csgo/spawn config checkpoints cache
 
23
 
24
  # Set environment variables
25
  ENV PYTHONPATH=/app/src:/app
26
  ENV CUDA_VISIBLE_DEVICES=""
27
- ENV OMP_NUM_THREADS=4
 
 
 
 
 
 
28
 
29
  # Expose port
30
  EXPOSE 7860
 
19
  COPY . .
20
 
21
  # Create necessary directories
22
+ RUN mkdir -p csgo/spawn config checkpoints cache && \
23
+ mkdir -p /tmp/torch /tmp/huggingface /tmp/transformers
24
 
25
  # Set environment variables
26
  ENV PYTHONPATH=/app/src:/app
27
  ENV CUDA_VISIBLE_DEVICES=""
28
+ ENV OMP_NUM_THREADS=2
29
+ ENV MKL_NUM_THREADS=2
30
+
31
+ # Set cache directories to writable locations
32
+ ENV TORCH_HOME=/tmp/torch
33
+ ENV HF_HOME=/tmp/huggingface
34
+ ENV TRANSFORMERS_CACHE=/tmp/transformers
35
 
36
  # Expose port
37
  EXPOSE 7860
app.py CHANGED
@@ -9,6 +9,7 @@ import io
9
  import json
10
  import logging
11
  import os
 
12
  from pathlib import Path
13
  from typing import Dict, List, Optional, Set
14
 
@@ -51,7 +52,21 @@ app = FastAPI(title="Diamond CSGO AI Player")
51
  os.environ.setdefault("SDL_VIDEODRIVER", "dummy")
52
  os.environ.setdefault("SDL_AUDIODRIVER", "dummy")
53
  os.environ.setdefault("PYGAME_HIDE_SUPPORT_PROMPT", "1")
54
- os.environ.setdefault("OMP_NUM_THREADS", "1")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  connected_clients: Set[WebSocket] = set()
56
 
57
  class WebKeyMap:
@@ -117,9 +132,20 @@ class WebGameEngine:
117
 
118
  model_url = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"
119
 
120
- # Use torch.hub to download and load state dict
121
  logger.info(f"Loading state dict from {model_url}")
122
- state_dict = torch.hub.load_state_dict_from_url(model_url, map_location=device)
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  self.download_progress = 60
125
  self.loading_status = "Loading model weights into agent..."
 
9
  import json
10
  import logging
11
  import os
12
+ import tempfile
13
  from pathlib import Path
14
  from typing import Dict, List, Optional, Set
15
 
 
52
  os.environ.setdefault("SDL_VIDEODRIVER", "dummy")
53
  os.environ.setdefault("SDL_AUDIODRIVER", "dummy")
54
  os.environ.setdefault("PYGAME_HIDE_SUPPORT_PROMPT", "1")
55
+
56
+ # Fix OMP_NUM_THREADS for HF Spaces (must be positive integer)
57
+ if "OMP_NUM_THREADS" not in os.environ or not os.environ["OMP_NUM_THREADS"].isdigit():
58
+ os.environ["OMP_NUM_THREADS"] = "2"
59
+
60
+ # Set up proper cache directories for HF Spaces
61
+ temp_dir = tempfile.gettempdir()
62
+ os.environ.setdefault("TORCH_HOME", os.path.join(temp_dir, "torch"))
63
+ os.environ.setdefault("HF_HOME", os.path.join(temp_dir, "huggingface"))
64
+ os.environ.setdefault("TRANSFORMERS_CACHE", os.path.join(temp_dir, "transformers"))
65
+
66
+ # Create cache directories
67
+ for cache_var in ["TORCH_HOME", "HF_HOME", "TRANSFORMERS_CACHE"]:
68
+ cache_path = os.environ[cache_var]
69
+ os.makedirs(cache_path, exist_ok=True)
70
  connected_clients: Set[WebSocket] = set()
71
 
72
  class WebKeyMap:
 
132
 
133
  model_url = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"
134
 
135
+ # Use torch.hub to download and load state dict with custom cache dir
136
  logger.info(f"Loading state dict from {model_url}")
137
+
138
+ # Set custom cache directory that we have write permissions for
139
+ cache_dir = os.path.join(tempfile.gettempdir(), "torch_cache")
140
+ os.makedirs(cache_dir, exist_ok=True)
141
+
142
+ # Use torch.hub with custom cache directory
143
+ state_dict = torch.hub.load_state_dict_from_url(
144
+ model_url,
145
+ map_location=device,
146
+ model_dir=cache_dir,
147
+ check_hash=False # Skip hash check for faster loading
148
+ )
149
 
150
  self.download_progress = 60
151
  self.loading_status = "Loading model weights into agent..."