Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import subprocess | |
| import os | |
| import sys | |
| import yaml | |
| from pathlib import Path | |
| import time | |
| import threading | |
| import tempfile | |
| import shutil | |
| import gc # ガベージコレクション用 | |
| # --- 定数 --- | |
| # Dockerfile内でクローンされるパスに合わせる | |
| SCRIPT_DIR = Path(__file__).parent | |
| SBV2_REPO_PATH = SCRIPT_DIR / "Style-Bert-VITS2" | |
| # ダウンロード用ファイルの一時置き場 (コンテナ内に作成) | |
| OUTPUT_DIR = SCRIPT_DIR / "outputs" | |
| # --- ヘルパー関数 --- | |
| def add_sbv2_to_path(): | |
| """Style-Bert-VITS2リポジトリのパスを sys.path に追加""" | |
| repo_path_str = str(SBV2_REPO_PATH.resolve()) | |
| if SBV2_REPO_PATH.exists() and repo_path_str not in sys.path: | |
| sys.path.insert(0, repo_path_str) | |
| print(f"Added {repo_path_str} to sys.path") | |
| elif not SBV2_REPO_PATH.exists(): | |
| print(f"Warning: Style-Bert-VITS2 repository not found at {SBV2_REPO_PATH}") | |
| def stream_process_output(process, log_list): | |
| """サブプロセスの標準出力/エラーをリアルタイムでリストに追加""" | |
| try: | |
| if process.stdout: | |
| for line in iter(process.stdout.readline, ''): | |
| log_list.append(line.strip()) # 余分な改行を削除 | |
| if process.stderr: | |
| for line in iter(process.stderr.readline, ''): | |
| processed_line = f"stderr: {line.strip()}" | |
| # 警告はそのまま、他はエラーとして強調 (任意) | |
| if "warning" not in line.lower(): | |
| processed_line = f"ERROR (stderr): {line.strip()}" | |
| log_list.append(processed_line) | |
| except Exception as e: | |
| log_list.append(f"Error reading process stream: {e}") | |
| # --- Gradio アプリのバックエンド関数 --- | |
| def convert_safetensors_to_onnx_gradio( | |
| safetensors_file_obj, | |
| config_file_obj, | |
| style_vectors_file_obj | |
| ): # gr.Progress は削除 | |
| """ | |
| アップロードされたSafetensors, config.json, style_vectors.npy を使って | |
| ONNXに変換し、結果をダウンロード可能にする。 | |
| """ | |
| log = ["Starting ONNX conversion..."] | |
| # 初期状態ではダウンロードファイルは空 | |
| yield "\n".join(log), None | |
| # --- ファイルアップロードの検証 --- | |
| if safetensors_file_obj is None: | |
| log.append("❌ Error: Safetensors file is missing. Please upload the .safetensors file.") | |
| yield "\n".join(log), None | |
| return | |
| if config_file_obj is None: | |
| log.append("❌ Error: config.json file is missing. Please upload the config.json file.") | |
| yield "\n".join(log), None | |
| return | |
| if style_vectors_file_obj is None: | |
| log.append("❌ Error: style_vectors.npy file is missing. Please upload the style_vectors.npy file.") | |
| yield "\n".join(log), None | |
| return | |
| # --- Style-Bert-VITS2 パスの確認 --- | |
| add_sbv2_to_path() | |
| if not SBV2_REPO_PATH.exists(): | |
| log.append(f"❌ Error: Style-Bert-VITS2 repository not found at {SBV2_REPO_PATH}. Check Space build logs.") | |
| yield "\n".join(log), None | |
| return | |
| # --- 出力ディレクトリ作成 --- | |
| OUTPUT_DIR.mkdir(parents=True, exist_ok=True) | |
| onnx_output_path_str = None # 最終的なONNXファイルパス (文字列) | |
| current_log = log[:] # ログリストをコピー | |
| try: | |
| # --- 一時ディレクトリを作成して処理 --- | |
| with tempfile.TemporaryDirectory() as temp_dir: | |
| temp_dir_path = Path(temp_dir) | |
| current_log.append(f"📁 Created temporary directory: {temp_dir_path}") | |
| yield "\n".join(current_log), None # UI更新 | |
| # --- SBV2が期待するディレクトリ構造を一時ディレクトリ内に作成 --- | |
| # モデル名を .safetensors ファイル名から取得 (拡張子なし) | |
| safetensors_filename = Path(safetensors_file_obj.name).name | |
| if not safetensors_filename.lower().endswith(".safetensors"): | |
| current_log.append(f"❌ Error: Invalid safetensors filename: {safetensors_filename}") | |
| yield "\n".join(current_log), None | |
| return | |
| model_name = Path(safetensors_filename).stem # 拡張子を除いた部分 | |
| # assets_root を一時ディレクトリ自体にする | |
| assets_root = temp_dir_path | |
| # assets_root の下に model_name のディレクトリを作成 | |
| model_dir_in_temp = assets_root / model_name | |
| model_dir_in_temp.mkdir(exist_ok=True) | |
| current_log.append(f" - Created model directory: {model_dir_in_temp.relative_to(assets_root)}") | |
| yield "\n".join(current_log), None | |
| # --- 3つのファイルを model_dir_in_temp にコピー --- | |
| files_to_copy = { | |
| "safetensors": safetensors_file_obj, | |
| "config.json": config_file_obj, | |
| "style_vectors.npy": style_vectors_file_obj, | |
| } | |
| copied_paths = {} | |
| for file_key, file_obj in files_to_copy.items(): | |
| original_filename = Path(file_obj.name).name | |
| # ファイル名の基本的な検証 (サニタイズはより厳密に行うことも可能) | |
| if "/" in original_filename or "\\" in original_filename or ".." in original_filename: | |
| current_log.append(f"❌ Error: Invalid characters found in filename: {original_filename}") | |
| yield "\n".join(current_log), None | |
| return # tryブロックを抜ける | |
| # 期待されるファイル名と一致しているか確認 (config と style_vectors) | |
| if file_key == "config.json" and original_filename.lower() != "config.json": | |
| current_log.append(f"⚠️ Warning: Uploaded JSON file name is '{original_filename}', expected 'config.json'. Using uploaded name.") | |
| if file_key == "style_vectors.npy" and original_filename.lower() != "style_vectors.npy": | |
| current_log.append(f"⚠️ Warning: Uploaded NPY file name is '{original_filename}', expected 'style_vectors.npy'. Using uploaded name.") | |
| destination_path = model_dir_in_temp / original_filename | |
| try: | |
| shutil.copy(file_obj.name, destination_path) | |
| current_log.append(f" - Copied '{original_filename}' to model directory.") | |
| # .safetensorsファイルのパスを保存しておく | |
| if file_key == "safetensors": | |
| copied_paths["safetensors"] = destination_path | |
| except Exception as e: | |
| current_log.append(f"❌ Error copying file '{original_filename}': {e}") | |
| yield "\n".join(current_log), None | |
| return # tryブロックを抜ける | |
| yield "\n".join(current_log), None # 各ファイルコピー後にUI更新 | |
| # safetensorsファイルがコピーされたか確認 | |
| temp_safetensors_path = copied_paths.get("safetensors") | |
| if not temp_safetensors_path: | |
| current_log.append("❌ Error: Failed to locate the copied safetensors file in the temporary directory.") | |
| yield "\n".join(current_log), None | |
| return | |
| current_log.append(f"✅ All required files copied to temporary model directory.") | |
| current_log.append(f" - Using temporary assets_root: {assets_root}") | |
| yield "\n".join(current_log), None | |
| # --- paths.yml を一時的に設定 --- | |
| config_path = SBV2_REPO_PATH / "configs" / "paths.yml" | |
| config_path.parent.mkdir(parents=True, exist_ok=True) | |
| # dataset_root は今回は使わないが設定はしておく (assets_rootと同じ場所) | |
| paths_config = {"dataset_root": str(assets_root.resolve()), "assets_root": str(assets_root.resolve())} | |
| with open(config_path, "w", encoding="utf-8") as f: | |
| yaml.dump(paths_config, f) | |
| current_log.append(f" - Saved temporary paths config to {config_path}") | |
| yield "\n".join(current_log), None | |
| # --- ONNX変換スクリプト実行 --- | |
| current_log.append(f"\n🚀 Starting ONNX conversion script for model '{model_name}'...") | |
| convert_script = SBV2_REPO_PATH / "convert_onnx.py" | |
| if not convert_script.exists(): | |
| current_log.append(f"❌ Error: convert_onnx.py not found at '{convert_script}'. Check repository setup.") | |
| yield "\n".join(current_log), None | |
| return # tryブロックを抜ける | |
| python_executable = sys.executable | |
| command = [ | |
| python_executable, | |
| str(convert_script.resolve()), | |
| "--model", | |
| str(temp_safetensors_path.resolve()) # 一時ディレクトリ内の .safetensors パス | |
| ] | |
| current_log.append(f"\n Running command: {' '.join(command)}") | |
| yield "\n".join(current_log), None | |
| process_env = os.environ.copy() | |
| process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, | |
| text=True, encoding='utf-8', errors='replace', | |
| cwd=SBV2_REPO_PATH, # スクリプトの場所で実行 | |
| env=process_env) | |
| # ログ出力用リスト (スレッドと共有) | |
| process_output_lines = ["\n--- Conversion Script Output ---"] | |
| thread = threading.Thread(target=stream_process_output, args=(process, process_output_lines)) | |
| thread.start() | |
| # 進捗表示のためのループ | |
| while thread.is_alive(): | |
| yield "\n".join(current_log + process_output_lines), None | |
| time.sleep(0.3) # 更新頻度 | |
| # スレッド終了待ちとプロセス終了待ち | |
| thread.join() | |
| try: | |
| process.wait(timeout=12000) # 20分タイムアウト (モデルサイズにより調整) | |
| except subprocess.TimeoutExpired: | |
| current_log.extend(process_output_lines) # ここまでのログを追加 | |
| current_log.append("\n❌ Error: Conversion process timed out after 20 minutes.") | |
| process.kill() | |
| yield "\n".join(current_log), None | |
| return # tryブロックを抜ける | |
| # 最終的なプロセス出力を取得 | |
| final_stdout, final_stderr = process.communicate() | |
| if final_stdout: | |
| process_output_lines.extend(final_stdout.strip().split('\n')) | |
| if final_stderr: | |
| processed_stderr = [] | |
| for line in final_stderr.strip().split('\n'): | |
| processed_line = f"stderr: {line.strip()}" | |
| if "warning" not in line.lower() and line.strip(): # 空行と警告以外 | |
| processed_line = f"ERROR (stderr): {line.strip()}" | |
| processed_stderr.append(processed_line) | |
| if any(line.startswith("ERROR") for line in processed_stderr): | |
| process_output_lines.append("--- Errors/Warnings (stderr) ---") | |
| process_output_lines.extend(processed_stderr) | |
| process_output_lines.append("-----------------------------") | |
| elif processed_stderr: # 警告のみの場合 | |
| process_output_lines.append("--- Warnings (stderr) ---") | |
| process_output_lines.extend(processed_stderr) | |
| process_output_lines.append("------------------------") | |
| # 全てのプロセスログをメインログに追加 | |
| current_log.extend(process_output_lines) | |
| current_log.append("--- End Script Output ---") | |
| current_log.append("\n-------------------------------") | |
| # --- 結果の確認と出力ファイルのコピー --- | |
| if process.returncode == 0: | |
| current_log.append("✅ ONNX conversion command finished successfully.") | |
| # 期待されるONNXファイルパス (入力と同じディレクトリ内) | |
| expected_onnx_path_in_temp = temp_safetensors_path.with_suffix(".onnx") | |
| if expected_onnx_path_in_temp.exists(): | |
| current_log.append(f" - Found converted ONNX file: {expected_onnx_path_in_temp.name}") | |
| # 一時ディレクトリから永続的な出力ディレクトリにコピー | |
| final_onnx_path = OUTPUT_DIR / expected_onnx_path_in_temp.name | |
| try: | |
| shutil.copy(expected_onnx_path_in_temp, final_onnx_path) | |
| current_log.append(f" - Copied ONNX file for download to: {final_onnx_path.relative_to(SCRIPT_DIR)}") | |
| onnx_output_path_str = str(final_onnx_path) # ダウンロード用ファイルパスを設定 | |
| except Exception as e: | |
| current_log.append(f"❌ Error copying ONNX file to output directory: {e}") | |
| else: | |
| current_log.append(f"⚠️ Warning: Expected ONNX file not found at '{expected_onnx_path_in_temp.name}'. Check script output above.") | |
| else: | |
| current_log.append(f"❌ ONNX conversion command failed with return code {process.returncode}.") | |
| current_log.append(" Please check the logs above for errors (especially lines starting with 'ERROR').") | |
| # 一時ディレクトリが自動で削除される前に最終結果をyield | |
| yield "\n".join(current_log), onnx_output_path_str | |
| except FileNotFoundError as e: | |
| current_log.append(f"\n❌ Error: A required command or file was not found: {e.filename}. Check Dockerfile setup and PATH.") | |
| current_log.append(f"{e}") | |
| yield "\n".join(current_log), None | |
| except Exception as e: | |
| current_log.append(f"\n❌ An unexpected error occurred: {e}") | |
| import traceback | |
| current_log.append(traceback.format_exc()) | |
| yield "\n".join(current_log), None | |
| finally: | |
| # ガベージコレクション | |
| gc.collect() | |
| print("Conversion function finished.") # サーバーログ用 | |
| # 最後のyieldでUIを最終状態に更新 | |
| # yield "\n".join(current_log), onnx_output_path_str # tryブロック内で既に返している | |
| # --- Gradio Interface --- | |
| with gr.Blocks(theme=gr.themes.Soft()) as demo: | |
| gr.Markdown("# Style-Bert-VITS2 Safetensors to ONNX Converter") | |
| gr.Markdown( | |
| "Upload your model's `.safetensors`, `config.json`, and `style_vectors.npy` files. " | |
| "The application will convert the model to ONNX format, and you can download the resulting `.onnx` file." | |
| ) | |
| gr.Markdown( | |
| "_(Environment setup is handled automatically when this Space starts.)_" | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("### 1. Upload Model Files") | |
| safetensors_upload = gr.File( | |
| label="Safetensors Model (.safetensors)", | |
| file_types=[".safetensors"], | |
| ) | |
| config_upload = gr.File( | |
| label="Config File (config.json)", | |
| file_types=[".json"], | |
| ) | |
| style_vectors_upload = gr.File( | |
| label="Style Vectors (style_vectors.npy)", | |
| file_types=[".npy"], | |
| ) | |
| convert_button = gr.Button("2. Convert to ONNX", variant="primary", elem_id="convert_button") | |
| gr.Markdown("---") | |
| gr.Markdown("### 3. Download Result") | |
| onnx_download = gr.File( | |
| label="ONNX Model (.onnx)", | |
| interactive=False, # 出力専用 | |
| ) | |
| gr.Markdown( | |
| "**Note:** Conversion can take **several minutes** (5-20+ min depending on model size and hardware). " | |
| "Please be patient. The log on the right shows the progress." | |
| ) | |
| with gr.Column(scale=2): | |
| output_log = gr.Textbox( | |
| label="Conversion Log", | |
| lines=30, # 高さをさらに増やす | |
| interactive=False, | |
| autoscroll=True, | |
| max_lines=2000 # ログが増える可能性 | |
| ) | |
| # ボタンクリック時のアクション設定 | |
| convert_button.click( | |
| convert_safetensors_to_onnx_gradio, | |
| inputs=[safetensors_upload, config_upload, style_vectors_upload], | |
| outputs=[output_log, onnx_download] # ログとダウンロードファイルの2つを出力 | |
| ) | |
| # --- アプリの起動 --- | |
| if __name__ == "__main__": | |
| # Style-Bert-VITS2 へのパスを追加 | |
| add_sbv2_to_path() | |
| # 出力ディレクトリ作成 (存在確認含む) | |
| OUTPUT_DIR.mkdir(parents=True, exist_ok=True) | |
| print(f"Output directory: {OUTPUT_DIR.resolve()}") | |
| # Gradioアプリを起動 | |
| demo.queue().launch() |