musictimer committed
Commit 7deb5ff · 1 Parent(s): bbfa773
app.py CHANGED
@@ -95,82 +95,63 @@ class WebGameEngine:

         def load_model_weights():
             """Load model weights in thread pool to avoid blocking"""
-            state_dict = None
-
-            # Try torch.hub method first
             try:
-                logger.info("Trying to load model using torch.hub...")
-                self.loading_status = "Downloading model with torch.hub..."
+                # Direct download without any caching to avoid permission issues on HF Spaces
+                logger.info("Downloading model directly without caching...")
+                self.loading_status = "Downloading model without caching..."
                 self.download_progress = 10

                 model_url = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"
-                state_dict = torch.hub.load_state_dict_from_url(
-                    model_url,
-                    map_location=device,
-                    progress=False,
-                    check_hash=False
-                )
-                logger.info("Successfully loaded model using torch.hub")

-            except Exception as e:
-                logger.warning(f"Failed to load model with torch.hub: {e}")
+                # Use requests to download directly into memory
+                import requests
+                import io

-                # Try huggingface_hub method as fallback
-                try:
-                    logger.info("Trying to load model using huggingface_hub...")
-                    self.loading_status = "Downloading model with huggingface_hub..."
-                    self.download_progress = 10
-
-                    from huggingface_hub import hf_hub_download
-
-                    # Download the file
-                    model_path = hf_hub_download(
-                        repo_id="Etadingrui/diamond-1B",
-                        filename="agent_epoch_00003.pt",
-                        cache_dir=None  # Use default cache
-                    )
-                    self.download_progress = 40
-                    self.loading_status = "Loading downloaded model..."
-
-                    # Load the state dict
-                    state_dict = torch.load(model_path, map_location=device)
-                    logger.info("Successfully loaded model using huggingface_hub")
-
-                except Exception as e2:
-                    logger.error(f"Failed to load model with huggingface_hub: {e2}")
-                    raise Exception("All model loading methods failed")
-
-            if state_dict is None:
-                raise Exception("Failed to load model state dict")
+                logger.info(f"Starting direct download from {model_url}")
+                response = requests.get(model_url, stream=True)
+                response.raise_for_status()
+
+                # Get the total file size for progress tracking
+                total_size = int(response.headers.get('content-length', 0))
+                logger.info(f"Model file size: {total_size / (1024*1024):.1f} MB")

-            # Load state dict into agent
+                # Download with progress tracking
+                downloaded_data = io.BytesIO()
+                downloaded_size = 0
+
+                for chunk in response.iter_content(chunk_size=8192):
+                    if chunk:
+                        downloaded_data.write(chunk)
+                        downloaded_size += len(chunk)
+
+                        # Update progress
+                        if total_size > 0:
+                            progress = min(50, int((downloaded_size / total_size) * 40) + 10)  # 10-50%
+                            if progress != self.download_progress:
+                                self.download_progress = progress
+                                logger.info(f"Download progress: {progress}%")
+
+                self.download_progress = 50
+                self.loading_status = "Download complete, loading model..."
+                logger.info("Download completed, loading state dict...")
+
+                # Reset to beginning of buffer and load
+                downloaded_data.seek(0)
+                state_dict = torch.load(downloaded_data, map_location=device)
+                logger.info("Successfully loaded model using direct download")
+
+            except Exception as e:
+                logger.error(f"Failed to download model directly: {e}")
+                raise Exception(f"Direct download failed: {str(e)}")
+
+            # Load state dict into agent using the new load_state_dict method
             try:
                 logger.info("Model download completed, loading weights...")
                 self.download_progress = 60
-                self.loading_status = "Model downloaded, loading weights..."
+                self.loading_status = "Loading model weights into agent..."

-                # Load each component of the agent using extract_state_dict (same as agent.load method)
-                if any(k.startswith("denoiser") for k in state_dict.keys()):
-                    agent.denoiser.load_state_dict(extract_state_dict(state_dict, "denoiser"))
-                    logger.info("Loaded denoiser weights")
-
-                self.download_progress = 70
-                self.loading_status = "Loading upsampler..."
-                if any(k.startswith("upsampler") for k in state_dict.keys()) and agent.upsampler is not None:
-                    agent.upsampler.load_state_dict(extract_state_dict(state_dict, "upsampler"))
-                    logger.info("Loaded upsampler weights")
-
-                self.download_progress = 80
-                self.loading_status = "Loading reward model..."
-                if any(k.startswith("rew_end_model") for k in state_dict.keys()) and agent.rew_end_model is not None:
-                    agent.rew_end_model.load_state_dict(extract_state_dict(state_dict, "rew_end_model"))
-                    logger.info("Loaded reward model weights")
-
-                self.download_progress = 90
-                self.loading_status = "Loading actor critic..."
-                if any(k.startswith("actor_critic") for k in state_dict.keys()) and agent.actor_critic is not None:
-                    agent.actor_critic.load_state_dict(extract_state_dict(state_dict, "actor_critic"))
-                    logger.info("Loaded actor critic weights")
+                # Use the agent's new load_state_dict method
+                agent.load_state_dict(state_dict)

                 self.download_progress = 100
                 self.loading_status = "Model loaded successfully!"
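For reference, the core of the change in load_model_weights is a stream-to-memory pattern: fetch the checkpoint over HTTP with requests, accumulate the chunks in an io.BytesIO buffer, and hand that buffer straight to torch.load, so nothing is ever written to a (possibly read-only) cache directory. A minimal standalone sketch of that pattern, assuming only that requests and torch are installed; the helper name download_state_dict and the CPU map_location are illustrative, not part of the commit:

import io
import logging

import requests
import torch

logger = logging.getLogger(__name__)


def download_state_dict(url: str, chunk_size: int = 8192) -> dict:
    """Stream a checkpoint over HTTP into memory and deserialize it, skipping any on-disk cache."""
    response = requests.get(url, stream=True)
    response.raise_for_status()

    total = int(response.headers.get("content-length", 0))
    buffer = io.BytesIO()
    done = 0
    for chunk in response.iter_content(chunk_size=chunk_size):
        if chunk:
            buffer.write(chunk)
            done += len(chunk)
            if total:
                logger.debug("downloaded %.1f%%", 100 * done / total)

    buffer.seek(0)  # rewind before handing the buffer to torch.load
    return torch.load(buffer, map_location="cpu")


# Hypothetical usage with the checkpoint URL referenced in app.py:
# sd = download_state_dict("https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt")

Note that the whole file lives in RAM for the duration of the load, so peak memory is roughly the checkpoint size plus the deserialized tensors; that is the trade-off accepted in exchange for avoiding cache permission issues on HF Spaces.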
requirements.txt CHANGED
@@ -13,6 +13,7 @@ fastapi>=0.68.0
 uvicorn>=0.15.0
 websockets>=10.0
 python-multipart>=0.0.5
+requests>=2.25.0

 # Image processing
 opencv-python-headless>=4.5.0
src/__pycache__/agent.cpython-310.pyc CHANGED
Binary files a/src/__pycache__/agent.cpython-310.pyc and b/src/__pycache__/agent.cpython-310.pyc differ
 
src/agent.py CHANGED
@@ -64,11 +64,22 @@ class Agent(nn.Module):
         load_actor_critic: bool = True,
     ) -> None:
         sd = torch.load(Path(path_to_ckpt), map_location=self.device)
+        self.load_state_dict(sd, load_denoiser, load_upsampler, load_rew_end_model, load_actor_critic)
+
+    def load_state_dict(
+        self,
+        state_dict: dict,
+        load_denoiser: bool = True,
+        load_upsampler: bool = True,
+        load_rew_end_model: bool = True,
+        load_actor_critic: bool = True,
+    ) -> None:
+        """Load state dict directly without file I/O"""
         if load_denoiser:
-            self.denoiser.load_state_dict(extract_state_dict(sd, "denoiser"))
-        if load_upsampler:
-            self.upsampler.load_state_dict(extract_state_dict(sd, "upsampler"))
+            self.denoiser.load_state_dict(extract_state_dict(state_dict, "denoiser"))
+        if load_upsampler and self.upsampler is not None:
+            self.upsampler.load_state_dict(extract_state_dict(state_dict, "upsampler"))
         if load_rew_end_model and self.rew_end_model is not None:
-            self.rew_end_model.load_state_dict(extract_state_dict(sd, "rew_end_model"))
+            self.rew_end_model.load_state_dict(extract_state_dict(state_dict, "rew_end_model"))
         if load_actor_critic and self.actor_critic is not None:
-            self.actor_critic.load_state_dict(extract_state_dict(sd, "actor_critic"))
+            self.actor_critic.load_state_dict(extract_state_dict(state_dict, "actor_critic"))
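With this refactor, Agent.load becomes a thin wrapper that reads the checkpoint from disk and delegates to the new load_state_dict, so callers that already hold a state dict in memory (as app.py now does) can skip the file round-trip. A hypothetical usage sketch, assuming an already-constructed agent and a state_dict obtained as in app.py:

# In-memory path (what app.py now uses); no temporary file required
agent.load_state_dict(state_dict)

# The per-component flags are still available, e.g. skip the actor-critic:
agent.load_state_dict(state_dict, load_actor_critic=False)

# File-based path, unchanged for existing callers (the path here is illustrative):
agent.load("checkpoints/agent_epoch_00003.pt")

One design note: because Agent subclasses nn.Module, this override shadows the standard nn.Module.load_state_dict(state_dict, strict=True) signature; here it instead partitions the flat checkpoint by prefix via extract_state_dict and forwards each slice to the corresponding sub-module.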