# PIWM / download_models.py
# Author: musictimer
# Commit message: "Fix initial bugs"
# Revision: 02c6351
#!/usr/bin/env python3
"""
Script to download models from Hugging Face Hub if not present locally
"""
import logging
import os
from pathlib import Path
# Module-scoped logger used by every function in this script.
logger = logging.getLogger(__name__)
# Install a root handler at INFO so the setup progress messages are shown.
logging.basicConfig(level=logging.INFO)
def download_checkpoint_if_needed(candidates=None):
    """Report whether a model checkpoint is available, trying a download if not.

    Each path in ``candidates`` is checked in order. The first one that is an
    existing regular file counts as a local checkpoint. If none is found, the
    function checks that ``huggingface_hub`` can be imported. The Hub download
    itself is not implemented yet, so that path always ends in ``False``.

    Args:
        candidates: Optional iterable of paths (``str`` or ``os.PathLike``) to
            probe. Defaults to ``agent_epoch_00206.pt`` and
            ``agent_epoch_00003.pt``, first in the working directory and then
            in ``checkpoints/``.

    Returns:
        True if a local checkpoint file was found, False otherwise.
    """
    if candidates is None:
        candidates = (
            "agent_epoch_00206.pt",
            "agent_epoch_00003.pt",
            "checkpoints/agent_epoch_00206.pt",
            "checkpoints/agent_epoch_00003.pt",
        )

    for candidate in candidates:
        ckpt_path = Path(candidate)
        # is_file() rather than exists(): a directory with a checkpoint-like
        # name must not be mistaken for a loadable checkpoint.
        if ckpt_path.is_file():
            logger.info("Found local checkpoint: %s", ckpt_path)
            return True

    logger.info("No local checkpoint found, attempting to download from Hugging Face Hub...")
    try:
        # Availability probe only: hf_hub_download is the intended entry point
        # once the checkpoints are published to the Hub.
        from huggingface_hub import hf_hub_download  # noqa: F401
    except ImportError:
        logger.error("huggingface_hub not installed. Cannot download models.")
        return False
    except Exception as e:
        # Best effort: a broken huggingface_hub install must not abort setup,
        # but keep the traceback so the failure can be diagnosed.
        logger.exception("Failed to download models: %s", e)
        return False

    # TODO: implement the download once the models are uploaded, e.g.
    #   hf_hub_download(repo_id="your-username/diamond-csgo-model",
    #                   filename="agent_epoch_00206.pt",
    #                   cache_dir="./checkpoints")
    # and map any exception it raises to a logged ``False``.
    logger.warning("Model download not implemented yet.")
    logger.warning("Please ensure you have model checkpoints available locally.")
    return False
def setup_demo_data(spawn_dir="csgo/spawn/0"):
    """Create placeholder spawn files for the demo.

    Writes zero-filled ``.npy`` arrays and an ``info.json`` into ``spawn_dir``,
    creating the directory if needed. Files that already exist are left
    untouched, so re-running is cheap and does not overwrite real data.

    Args:
        spawn_dir: Directory to populate (``str`` or ``os.PathLike``).
            Defaults to ``csgo/spawn/0`` relative to the working directory.
    """
    import json

    import numpy as np

    spawn_dir = Path(spawn_dir)
    spawn_dir.mkdir(parents=True, exist_ok=True)

    # Shared length so the arrays and info.json's episode_length always agree.
    num_steps = 100
    # Only shapes are stored here; each array is allocated when its file is
    # actually missing. At numpy's default float64 the two frame arrays are
    # ~216 MB and ~864 MB, so building them all up front cost ~1 GB of RAM
    # even when every file already existed.
    # NOTE(review): if consumers expect uint8 frames, dtype=np.uint8 would
    # shrink these files 8x -- confirm against the loader before changing.
    file_shapes = {
        "act.npy": (num_steps, 51),  # 100 timesteps, 51 actions
        "low_res.npy": (num_steps, 3, 150, 600),  # 100 frames
        "full_res.npy": (num_steps, 3, 300, 1200),  # 100 high-res frames
        "next_act.npy": (num_steps, 51),
    }
    for filename, shape in file_shapes.items():
        file_path = spawn_dir / filename
        if not file_path.exists():
            np.save(file_path, np.zeros(shape))
            logger.info("Created dummy file: %s", file_path)

    info_path = spawn_dir / "info.json"
    if not info_path.exists():
        info_data = {
            "episode_length": num_steps,
            "total_reward": 0.0,
            "demo": True,
        }
        with open(info_path, "w", encoding="utf-8") as f:
            json.dump(info_data, f)
        logger.info("Created info file: %s", info_path)
if __name__ == "__main__":
    logger.info("Setting up Diamond CSGO demo...")

    # Look for (or try to fetch) a checkpoint; demo data is created either way.
    models_available = download_checkpoint_if_needed()
    setup_demo_data()

    if not models_available:
        demo_mode_warnings = (
            "Running in demo mode without trained models.",
            "The AI agent will not function properly without model checkpoints.",
        )
        for message in demo_mode_warnings:
            logger.warning(message)

    logger.info("Setup complete!")