Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Script to download models from Hugging Face Hub if not present locally | |
| """ | |
| import logging | |
| import os | |
| from pathlib import Path | |
# Script-wide logging: configure the root logger once at INFO (a no-op if the
# root logger already has handlers), then use a module-named logger below.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
def download_checkpoint_if_needed(checkpoint_paths=None):
    """Report whether a model checkpoint is available locally.

    Probes ``checkpoint_paths`` in order and stops at the first one that
    exists. Relative paths are resolved against the current working directory.
    If none exists, it checks that ``huggingface_hub`` can be imported. The
    download from the Hub is not implemented yet, so that path always reports
    failure.

    Args:
        checkpoint_paths: Optional iterable of paths (``str`` or ``Path``) to
            probe. Defaults to ``agent_epoch_00206.pt`` and
            ``agent_epoch_00003.pt``, both in the working directory and in
            ``checkpoints/``.

    Returns:
        True if a local checkpoint was found, False otherwise.
    """
    if checkpoint_paths is None:
        checkpoint_paths = (
            "agent_epoch_00206.pt",
            "agent_epoch_00003.pt",
            "checkpoints/agent_epoch_00206.pt",
            "checkpoints/agent_epoch_00003.pt",
        )

    for ckpt_path in map(Path, checkpoint_paths):
        if ckpt_path.exists():
            logger.info("Found local checkpoint: %s", ckpt_path)
            return True

    logger.info("No local checkpoint found, attempting to download from Hugging Face Hub...")

    # Only the import can fail, so the try body is limited to that one line.
    try:
        # Imported solely to verify the dependency is installed; the actual
        # download is still TODO (see below).
        from huggingface_hub import hf_hub_download  # noqa: F401
    except ImportError:
        logger.error("huggingface_hub not installed. Cannot download models.")
        return False
    except Exception as e:  # e.g. a broken install: stay best-effort, don't crash
        logger.error("Failed to download models: %s", e)
        return False

    # TODO: upload the checkpoints to the Hub and fetch them here, e.g.
    #   hf_hub_download(
    #       repo_id="your-username/diamond-csgo-model",
    #       filename="agent_epoch_00206.pt",
    #       cache_dir="./checkpoints",
    #   )
    logger.warning("Model download not implemented yet.")
    logger.warning("Please ensure you have model checkpoints available locally.")
    return False
def setup_demo_data(spawn_dir="csgo/spawn/0", num_steps=100):
    """Create placeholder spawn data so the demo can run without real recordings.

    Writes all-zero ``.npy`` arrays plus an ``info.json`` marked
    ``"demo": True`` into ``spawn_dir``. Files that already exist are left
    untouched, so repeated calls are safe.

    Args:
        spawn_dir: Directory to populate, created with parents if missing.
            Defaults to ``csgo/spawn/0`` relative to the working directory.
        num_steps: Leading (time) dimension of every array. It is also the
            ``episode_length`` recorded in ``info.json``. Defaults to 100.

    Note:
        Arrays use NumPy's default float64 dtype. At ``num_steps=100`` the
        two frame files alone take roughly 1.1 GB on disk.
    """
    import json

    import numpy as np

    spawn_dir = Path(spawn_dir)
    spawn_dir.mkdir(parents=True, exist_ok=True)

    # Only shapes are listed here. Each array is allocated just before it is
    # saved, so files that already exist cost no memory (the frame arrays are
    # hundreds of MB each).
    array_shapes = {
        "act.npy": (num_steps, 51),  # num_steps timesteps, 51 actions
        "low_res.npy": (num_steps, 3, 150, 600),  # low-res frames
        "full_res.npy": (num_steps, 3, 300, 1200),  # high-res frames
        "next_act.npy": (num_steps, 51),
    }
    for filename, shape in array_shapes.items():
        file_path = spawn_dir / filename
        if not file_path.exists():
            np.save(file_path, np.zeros(shape))
            logger.info("Created dummy file: %s", file_path)

    info_path = spawn_dir / "info.json"
    if not info_path.exists():
        info_data = {
            "episode_length": num_steps,
            "total_reward": 0.0,
            "demo": True,
        }
        with open(info_path, "w", encoding="utf-8") as f:
            json.dump(info_data, f)
        logger.info("Created info file: %s", info_path)
if __name__ == "__main__":
    logger.info("Setting up Diamond CSGO demo...")

    # First look for (or try to fetch) checkpoints, then make sure the
    # placeholder spawn data exists; the latter runs either way.
    models_available = download_checkpoint_if_needed()
    setup_demo_data()

    if not models_available:
        demo_mode_warnings = (
            "Running in demo mode without trained models.",
            "The AI agent will not function properly without model checkpoints.",
        )
        for message in demo_mode_warnings:
            logger.warning(message)

    logger.info("Setup complete!")