Spaces:
Sleeping
Sleeping
"""
Pre-deployment validation script for Hugging Face Spaces.
Checks the Python version, dependencies, and required files before deployment.
"""
| import os | |
| import sys | |
| import subprocess | |
| from pathlib import Path | |
| import importlib.util | |
# Fix Windows encoding issues: the default console encoding (e.g. cp1252)
# cannot encode the ✓/✗/⚠/ℹ status symbols printed by this script.
if sys.platform == 'win32':
    for _stream in (sys.stdout, sys.stderr):
        # TextIOWrapper.reconfigure() (Python 3.7+) switches the encoding in
        # place. Unlike re-wrapping `.buffer`, it keeps the stream's existing
        # buffering settings. The hasattr guard also skips streams that are
        # None (pythonw) or were replaced by objects without reconfigure().
        if hasattr(_stream, 'reconfigure'):
            _stream.reconfigure(encoding='utf-8')
class Colors:
    """ANSI escape sequences used to colour console output.

    Each value is an SGR (Select Graphic Rendition) control sequence; print
    RESET after coloured text to restore the terminal's default colour.
    """

    # Bright foreground colours.
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    RED = '\033[91m'
    YELLOW = '\033[93m'

    # Reset all attributes.
    RESET = '\033[0m'
def _print_tagged(color, symbol, msg):
    """Print *msg* after a coloured status *symbol*, resetting the colour."""
    print(f"{color}{symbol}{Colors.RESET} {msg}")


def print_success(msg):
    """Print *msg* prefixed with a green check mark."""
    _print_tagged(Colors.GREEN, "✓", msg)


def print_error(msg):
    """Print *msg* prefixed with a red cross."""
    _print_tagged(Colors.RED, "✗", msg)


def print_warning(msg):
    """Print *msg* prefixed with a yellow warning sign."""
    _print_tagged(Colors.YELLOW, "⚠", msg)


def print_info(msg):
    """Print *msg* prefixed with a blue information sign."""
    _print_tagged(Colors.BLUE, "ℹ", msg)
def check_python_version():
    """Check that the running interpreter is Python 3.8 or newer.

    Returns:
        bool: True if the version requirement is met, False otherwise.
    """
    print("\n[1/8] Checking Python version...")
    version = sys.version_info
    label = f"Python {version.major}.{version.minor}.{version.micro}"
    # Tuple comparison accepts any later major version instead of
    # hard-coding `major == 3`.
    if version >= (3, 8):
        print_success(label)
        return True
    print_error(f"{label} - Requires Python 3.8+")
    return False
def check_required_files():
    """Verify that the files a Space needs exist in the working directory.

    Prints one status line per file and returns True only if none is missing.
    """
    print("\n[2/8] Checking required files...")
    expected = (
        ('app.py', 'Main application file'),
        ('requirements.txt', 'Dependencies list'),
        ('README.md', 'Space documentation'),
    )
    missing_count = 0
    for filename, purpose in expected:
        if not Path(filename).exists():
            print_error(f"{filename:<20} - MISSING ({purpose})")
            missing_count += 1
            continue
        print_success(f"{filename:<20} - {purpose}")
    return missing_count == 0
def check_dependencies():
    """Import every critical runtime dependency and report the outcome.

    Each module is actually imported (not merely located), so a package that
    is installed but broken is reported as a failure too.

    Returns:
        bool: True if every module imported cleanly, False otherwise.
    """
    print("\n[3/8] Checking critical dependencies...")
    critical_deps = {
        'torch': 'PyTorch',
        'gradio': 'Gradio',
        'numpy': 'NumPy',
        'einops': 'Einops',
        'scipy': 'SciPy',
        'matplotlib': 'Matplotlib',
        'trimesh': 'Trimesh',
        'sklearn': 'Scikit-learn',
        'clip': 'OpenAI CLIP',
    }
    all_installed = True
    for module, name in critical_deps.items():
        try:
            __import__(module)
        except ImportError:
            print_error(f"{name:<20} - NOT INSTALLED")
            all_installed = False
        except Exception as exc:
            # A present-but-broken package may raise e.g. OSError or
            # RuntimeError while loading; report it and keep checking the
            # remaining modules instead of aborting the whole check.
            print_error(f"{name:<20} - IMPORT FAILED ({type(exc).__name__}: {exc})")
            all_installed = False
        else:
            print_success(f"{name:<20}")
    return all_installed
def _strip_inline_comment(line):
    """Return *line* without a pip-style trailing comment.

    pip treats '#' as a comment start only at the beginning of a line or when
    preceded by whitespace, so URL fragments such as '#egg=name' survive.
    """
    for index, char in enumerate(line):
        if char == '#' and (index == 0 or line[index - 1].isspace()):
            return line[:index].rstrip()
    return line


def _requirement_name(spec):
    """Return the lower-cased project name at the start of *spec*.

    Stops at the first character that cannot appear in a PEP 508 name
    (anything other than letters, digits, '.', '_' or '-').
    """
    for index, char in enumerate(spec):
        if not (char.isalnum() or char in '._-'):
            return spec[:index].lower()
    return spec.lower()


def _is_commented_gradio(comment_line):
    """Return True if a '#' line is a commented-out gradio requirement.

    Matches '#gradio', '# gradio' and '# gradio==4.0', but not section
    headers such as '# Gradio UI' or other packages such as 'gradio_client'.
    """
    body = comment_line.lstrip('#').strip().lower()
    name = _requirement_name(body)
    rest = body[len(name):].lstrip()
    return name == 'gradio' and (not rest or rest[0] in '=<>~![;@')


def _collect_requirement_issues(content):
    """Return the human-readable problems found in requirements.txt text."""
    issues = []

    # U+FFFD is the replacement character left behind by a lossy re-encode.
    if '\ufffd' in content or '→' in content:
        issues.append("File has encoding issues (contains weird characters)")

    # Dependencies that don't need strict versions (exact name match).
    allowed_without_version = {'ftfy', 'regex', 'wheel', 'setuptools', 'pip'}
    # '<' and '>' also match '<=' and '>='; '==' also matches '==='.
    version_operators = ('==', '~=', '!=', '<', '>')

    missing_versions = []
    has_gradio = False
    gradio_commented = False
    for raw_line in content.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith('#'):
            gradio_commented = gradio_commented or _is_commented_gradio(line)
            continue
        spec = _strip_inline_comment(line)
        if not spec or spec.startswith('-'):
            continue  # pip options: -r, -e, --extra-index-url, ...
        name = _requirement_name(spec)
        if name == 'gradio':
            has_gradio = True
        if '://' in spec or spec.startswith(('.', '/')):
            continue  # URL / VCS / local-path requirements carry no version
        requirement = spec.split(';', 1)[0]  # ignore environment markers
        is_pinned = any(op in requirement for op in version_operators)
        if not is_pinned and name not in allowed_without_version:
            missing_versions.append(spec)

    if missing_versions:
        issues.append(f"Packages without version: {', '.join(missing_versions[:5])}")
    if not has_gradio and gradio_commented:
        issues.append("gradio is commented out")
    return issues


def check_requirements_txt():
    """Validate the format of ./requirements.txt.

    Fails on: a file that is not valid UTF-8 or contains U+FFFD / '→'
    characters; package lines without any version specifier (except a small
    allow-list); and a gradio requirement that only exists commented out.
    pip option lines and URL/path requirements are not version-checked.

    Returns:
        bool: True when no issues were found, False otherwise.
    """
    print("\n[4/8] Validating requirements.txt...")
    req_path = Path('requirements.txt')
    if not req_path.exists():
        print_error("requirements.txt not found")
        return False

    try:
        # utf-8-sig also accepts (and strips) a UTF-8 BOM.
        content = req_path.read_bytes().decode('utf-8-sig')
    except UnicodeDecodeError:
        # e.g. a UTF-16 file - report it instead of crashing the check.
        print_error("File has encoding issues (not valid UTF-8)")
        return False

    issues = _collect_requirement_issues(content)
    if issues:
        for issue in issues:
            print_error(issue)
        return False
    print_success("requirements.txt is valid")
    return True
def check_model_paths():
    """Report which expected t2m model checkpoint directories are present.

    Fails only when ./checkpoints itself is absent; missing individual
    checkpoints are reported as warnings and still count as a pass.
    """
    print("\n[5/8] Checking model checkpoints...")
    root = './checkpoints'
    if not Path(root).exists():
        print_error(f"Checkpoints directory not found: {root}")
        return False

    dataset = 't2m'
    subdirs = (
        't2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns',
        'rvq_nq6_dc512_nc512_noshare_qdp0.2',
        'length_estimator',
    )
    missing = []
    for subdir in subdirs:
        candidate = f'{root}/{dataset}/{subdir}'
        if not Path(candidate).exists():
            print_warning(f"Missing: {candidate}")
            missing.append(candidate)
            continue
        print_success(f"Found: {Path(candidate).name}")

    if missing:
        print_warning("Some checkpoints are missing - you'll need to download them")
    # Deliberately pass: checkpoints may be downloaded later.
    return True
def check_readme():
    """Check that README.md carries the Hugging Face Spaces YAML metadata.

    Spaces read their configuration from a YAML front-matter block delimited
    by '---' lines at the very top of README.md. The required keys are
    therefore looked up only there, as top-level keys, not anywhere in the text.

    Returns:
        bool: True if README.md exists and its front matter defines every
        required key, False otherwise.
    """
    print("\n[6/8] Checking README.md...")
    readme = Path('README.md')
    if not readme.exists():
        print_warning("README.md not found")
        return False
    # utf-8-sig: a leading BOM must not hide the opening '---' line.
    with open(readme, 'r', encoding='utf-8-sig') as f:
        lines = f.read().splitlines()

    metadata_keys = set()
    if lines and lines[0].strip() == '---':
        for line in lines[1:]:
            if line.strip() == '---':
                break
            key, sep, _ = line.partition(':')
            # Only unindented 'key:' lines are top-level YAML keys.
            if sep and line[:1] not in (' ', '\t'):
                metadata_keys.add(key.strip().lower() + ':')
        else:
            # No closing '---': not a valid front-matter block.
            metadata_keys.clear()

    required_sections = ['title:', 'sdk:', 'sdk_version:']
    missing = [s for s in required_sections if s not in metadata_keys]
    if missing:
        print_warning(f"README.md missing metadata: {', '.join(missing)}")
        print_info("Add YAML frontmatter for Hugging Face Spaces")
        return False
    print_success("README.md has required metadata")
    return True
def check_huggingface_token():
    """Report whether the HUGGINGFACE_TOKEN environment variable is set.

    An empty value counts as unset. Returns True when a token is present.
    """
    print("\n[7/8] Checking Hugging Face token...")
    if not os.environ.get('HUGGINGFACE_TOKEN'):
        print_error("HUGGINGFACE_TOKEN not set")
        print_info("Set with: $env:HUGGINGFACE_TOKEN = 'hf_your_token'")
        return False
    print_success("HUGGINGFACE_TOKEN environment variable is set")
    return True
def check_app_syntax():
    """Check that app.py can be read and compiles as Python source.

    Returns:
        bool: True if app.py compiled without errors, False if it is
        missing/unreadable or contains a syntax error.
    """
    print("\n[8/8] Checking app.py syntax...")
    try:
        source = Path('app.py').read_bytes()
    except OSError as e:
        # Report a missing/unreadable file instead of raising out of the check.
        print_error(f"Cannot read app.py: {e}")
        return False
    try:
        # Compiling bytes lets Python honour a BOM or PEP 263 coding cookie.
        compile(source, 'app.py', 'exec')
    except (SyntaxError, ValueError) as e:
        # ValueError: raised for null bytes in source by older Python versions.
        print_error(f"Syntax error in app.py: {e}")
        return False
    print_success("app.py has valid syntax")
    return True
def main():
    """Run all pre-deployment checks and print a PASS/FAIL summary.

    A check that raises is reported under its own name and counted as
    failed, so one crashing check never hides the results of the others.

    Returns:
        int: process exit code - 0 if every check passed, 1 otherwise.
    """
    print("=" * 70)
    print(" " * 18 + "Pre-Deployment Validation")
    print("=" * 70)
    checks = [
        ("Python Version", check_python_version),
        ("Required Files", check_required_files),
        ("Dependencies", check_dependencies),
        ("Requirements.txt", check_requirements_txt),
        ("Model Paths", check_model_paths),
        ("README.md", check_readme),
        ("HF Token", check_huggingface_token),
        ("App Syntax", check_app_syntax),
    ]
    results = {}
    for name, check_func in checks:
        try:
            results[name] = check_func()
        except Exception as e:
            # Top-level boundary: say which check crashed, then keep going.
            print_error(f"{name} check failed with error: {type(e).__name__}: {e}")
            results[name] = False

    # Summary
    print("\n" + "=" * 70)
    print(" " * 25 + "SUMMARY")
    print("=" * 70)
    passed = sum(1 for v in results.values() if v)
    total = len(results)
    for name, result in results.items():
        status = f"{Colors.GREEN}PASS{Colors.RESET}" if result else f"{Colors.RED}FAIL{Colors.RESET}"
        print(f"{name:<30} {status}")
    print("=" * 70)
    print(f"\nPassed: {passed}/{total}")

    # Print blank lines separately: the print_* helpers prepend a status
    # symbol, so a leading "\n" inside the message would print a lone symbol
    # and then a second one.
    if passed == total:
        print()
        print_success("All checks passed! Ready to deploy.")
        print()
        print_info("Run: python deploy.py")
        return 0
    print()
    print_error("Some checks failed. Fix issues before deploying.")
    print()
    print_info("Fix the issues above and run this script again.")
    return 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())