Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| import subprocess | |
| import numpy as np | |
| import torch | |
| import imageio | |
| from skimage.transform import resize | |
| from skimage import img_as_ubyte | |
| import gradio as gr | |
| from PIL import Image | |
| import tempfile | |
| import requests | |
| from io import BytesIO | |
| # Đảm bảo cài đặt các thư viện cần thiết | |
| subprocess.check_call([sys.executable, "-m", "pip", "install", "scikit-learn"]) | |
| subprocess.check_call([sys.executable, "-m", "pip", "install", "scikit-image==0.19.3"]) | |
| subprocess.check_call([sys.executable, "-m", "pip", "install", "face-alignment==1.3.5"]) | |
| subprocess.check_call([sys.executable, "-m", "pip", "install", "PyYAML==5.3.1"]) | |
| subprocess.check_call([sys.executable, "-m", "pip", "install", "imageio-ffmpeg==0.4.5"]) | |
| subprocess.check_call([sys.executable, "-m", "pip", "install", "requests"]) | |
| # Cài đặt ffmpeg trong môi trường Ubuntu | |
| os.system("apt-get update && apt-get install -y ffmpeg") | |
| # Clone repo nếu chưa có | |
| if not os.path.exists('first_order_model'): | |
| subprocess.call(['git', 'clone', 'https://github.com/AliaksandrSiarohin/first-order-model.git']) | |
| if os.path.exists('first-order-model'): | |
| os.rename('first-order-model', 'first_order_model') | |
| # Thêm đường dẫn vào PYTHONPATH | |
| sys.path.append('.') | |
| sys.path.append('first_order_model') | |
| # Tạo file helper với hàm load_checkpoints | |
| with open('load_helper.py', 'w') as f: | |
| f.write(""" | |
| import yaml | |
| import torch | |
| from first_order_model.modules.generator import OcclusionAwareGenerator | |
| from first_order_model.modules.keypoint_detector import KPDetector | |
| def load_checkpoints(config_path, checkpoint_path, device='cpu'): | |
| with open(config_path) as f: | |
| config = yaml.safe_load(f) | |
| generator = OcclusionAwareGenerator(**config['model_params']['generator_params'], | |
| **config['model_params']['common_params']) | |
| generator.to(device) | |
| kp_detector = KPDetector(**config['model_params']['kp_detector_params'], | |
| **config['model_params']['common_params']) | |
| kp_detector.to(device) | |
| checkpoint = torch.load(checkpoint_path, map_location=device) | |
| generator.load_state_dict(checkpoint['generator']) | |
| kp_detector.load_state_dict(checkpoint['kp_detector']) | |
| generator.eval() | |
| kp_detector.eval() | |
| return generator, kp_detector | |
| def normalize_kp(kp_source, kp_driving, kp_driving_initial, | |
| use_relative_movement=True, use_relative_jacobian=True, adapt_movement_scale=True): | |
| from first_order_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d | |
| kp_new = {k: v for k, v in kp_driving.items()} | |
| if use_relative_movement: | |
| kp_value_diff = (kp_driving['value'] - kp_driving_initial['value']) | |
| kp_value_diff_abs = torch.abs(kp_value_diff) | |
| if adapt_movement_scale: | |
| distance = torch.max(kp_value_diff_abs, dim=2, keepdim=True)[0] | |
| distance = torch.max(distance, dim=1, keepdim=True)[0] | |
| kp_source_diff = torch.abs(kp_source['value']) | |
| kp_source_max = torch.max(kp_source_diff, dim=2, keepdim=True)[0] | |
| kp_source_max = torch.max(kp_source_max, dim=1, keepdim=True)[0] | |
| movement_scale = kp_source_max / (distance + 1e-6) | |
| kp_new['value'] = kp_source['value'] + movement_scale * kp_value_diff | |
| else: | |
| kp_new['value'] = kp_source['value'] + kp_value_diff | |
| if use_relative_jacobian: | |
| jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian'])) | |
| kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian']) | |
| return kp_new | |
| """) | |
| # Import hàm load_checkpoints từ file helper | |
| from load_helper import load_checkpoints, normalize_kp | |
| # Tải mô hình từ GitHub hoặc mirrors của first-order-model | |
| def download_model(): | |
| # URLs trực tiếp từ sources khác | |
| checkpoint_urls = [ | |
| "https://github.com/AliaksandrSiarohin/first-order-model/releases/download/v1.0.0/vox-cpk.pth.tar", | |
| "https://raw.githubusercontent.com/jiupinjia/stylized-neural-painting/main/checkpoints/vox-cpk.pth.tar", | |
| "https://github.com/snap-research/articulated-animation/raw/master/checkpoints/vox.pth.tar" | |
| ] | |
| config_urls = [ | |
| "https://raw.githubusercontent.com/AliaksandrSiarohin/first-order-model/master/config/vox-256.yaml", | |
| "https://gist.githubusercontent.com/anonymous/raw/vox-256.yaml" | |
| ] | |
| # Tạo thư mục | |
| model_path = 'checkpoints/vox-cpk.pth.tar' | |
| if not os.path.exists('checkpoints'): | |
| os.makedirs('checkpoints', exist_ok=True) | |
| config_path = 'first_order_model/config/vox-256.yaml' | |
| if not os.path.exists('first_order_model/config'): | |
| os.makedirs('first_order_model/config', exist_ok=True) | |
| # Tải model checkpoint | |
| success = False | |
| for url in checkpoint_urls: | |
| try: | |
| print(f"Đang thử tải mô hình từ: {url}") | |
| response = requests.get(url, stream=True, timeout=30) | |
| if response.status_code == 200: | |
| with open(model_path, 'wb') as f: | |
| for chunk in response.iter_content(chunk_size=8192): | |
| f.write(chunk) | |
| # Kiểm tra kích thước file (checkpoint mô hình thường > 100MB) | |
| if os.path.getsize(model_path) > 100000000: | |
| success = True | |
| break | |
| except Exception as e: | |
| print(f"Lỗi khi tải từ {url}: {str(e)}") | |
| if not success: | |
| raise Exception("Không thể tải mô hình checkpoint từ bất kỳ nguồn nào") | |
| # Tải file cấu hình | |
| config_success = False | |
| for url in config_urls: | |
| try: | |
| print(f"Đang thử tải file cấu hình từ: {url}") | |
| response = requests.get(url, timeout=30) | |
| if response.status_code == 200: | |
| with open(config_path, 'wb') as f: | |
| f.write(response.content) | |
| if os.path.getsize(config_path) > 1000: | |
| config_success = True | |
| break | |
| except Exception as e: | |
| print(f"Lỗi khi tải cấu hình từ {url}: {str(e)}") | |
| if not config_success: | |
| # Tạo file cấu hình đơn giản nếu không tải được | |
| create_simple_config(config_path) | |
| return config_path, model_path | |
| # Tạo file cấu hình đơn giản nếu không tải được | |
| def create_simple_config(config_path): | |
| with open(config_path, 'w') as f: | |
| f.write(""" | |
| model_params: | |
| common_params: | |
| num_kp: 10 | |
| num_channels: 3 | |
| estimate_jacobian: true | |
| kp_detector_params: | |
| temperature: 0.1 | |
| block_expansion: 32 | |
| max_features: 1024 | |
| scale_factor: 0.25 | |
| num_blocks: 5 | |
| generator_params: | |
| block_expansion: 64 | |
| max_features: 512 | |
| num_down_blocks: 2 | |
| num_bottleneck_blocks: 6 | |
| estimate_occlusion_map: true | |
| dense_motion_params: | |
| block_expansion: 64 | |
| max_features: 1024 | |
| num_blocks: 5 | |
| scale_factor: 0.25 | |
| """) | |
| print("Đã tạo file cấu hình đơn giản") | |
| # Hàm tạo animation | |
| def make_animation(source_image, driving_video, relative=True, adapt_movement_scale=True): | |
| config_path, checkpoint_path = download_model() | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| print(f"Using device: {device}") | |
| # Tải mô hình và cấu hình | |
| generator, kp_detector = load_checkpoints(config_path, checkpoint_path, device=device) | |
| # Đọc source_image và driving_video | |
| source = imageio.imread(source_image) | |
| reader = imageio.get_reader(driving_video) | |
| fps = reader.get_meta_data()['fps'] | |
| driving = [] | |
| try: | |
| for im in reader: | |
| driving.append(im) | |
| except RuntimeError: | |
| pass | |
| reader.close() | |
| # Tiền xử lý | |
| source = resize(source, (256, 256))[..., :3] | |
| driving = [resize(frame, (256, 256))[..., :3] for frame in driving] | |
| # Chuyển đổi thành tensor | |
| source = torch.tensor(source[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(device) | |
| driving = torch.tensor(np.array(driving).astype(np.float32)).permute(0, 3, 1, 2).to(device) | |
| # Trích xuất keypoints | |
| kp_source = kp_detector(source) | |
| kp_driving_initial = kp_detector(driving[0:1]) | |
| # Tạo animation | |
| with torch.no_grad(): | |
| predictions = [] | |
| for frame_idx in range(driving.shape[0]): | |
| driving_frame = driving[frame_idx:frame_idx+1] | |
| kp_driving = kp_detector(driving_frame) | |
| # Chuẩn hóa keypoints | |
| kp_norm = normalize_kp( | |
| kp_source=kp_source, | |
| kp_driving=kp_driving, | |
| kp_driving_initial=kp_driving_initial, | |
| use_relative_movement=relative, | |
| use_relative_jacobian=relative, | |
| adapt_movement_scale=adapt_movement_scale | |
| ) | |
| # Tạo frame | |
| out = generator(source, kp_source=kp_source, kp_driving=kp_norm) | |
| predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]) | |
| # Lưu video kết quả | |
| output_path = f'result_{int(np.random.rand() * 10000)}.mp4' | |
| if os.path.exists(output_path): | |
| os.remove(output_path) # Xóa video nếu tồn tại | |
| # Lưu frames thành video sử dụng imageio | |
| frames = [img_as_ubyte(frame) for frame in predictions] | |
| imageio.mimsave(output_path, frames, fps=fps) | |
| return output_path | |
| # Tải video mẫu | |
| def download_sample_video(): | |
| sample_urls = [ | |
| "https://github.com/AliaksandrSiarohin/first-order-model/raw/master/driving.mp4", | |
| "https://raw.githubusercontent.com/jiupinjia/stylized-neural-painting/main/sample/driving.mp4" | |
| ] | |
| sample_path = "sample_driving.mp4" | |
| for url in sample_urls: | |
| try: | |
| print(f"Đang thử tải video mẫu từ: {url}") | |
| response = requests.get(url, timeout=30) | |
| if response.status_code == 200: | |
| with open(sample_path, 'wb') as f: | |
| f.write(response.content) | |
| if os.path.getsize(sample_path) > 10000: # Kiểm tra kích thước file | |
| return sample_path | |
| except Exception as e: | |
| print(f"Lỗi khi tải video mẫu từ {url}: {str(e)}") | |
| # Nếu không tải được, tạo video đơn giản | |
| create_simple_video(sample_path) | |
| return sample_path | |
| # Tạo video đơn giản nếu không tải được video mẫu | |
| def create_simple_video(output_path): | |
| import cv2 | |
| out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 10, (256, 256)) | |
| # Tạo 100 khung hình với chuyển động đơn giản | |
| for i in range(100): | |
| frame = np.zeros((256, 256, 3), dtype=np.uint8) | |
| # Vẽ khuôn mặt đơn giản chuyển động | |
| x_center = 128 + int(np.sin(i/10) * 20) | |
| y_center = 128 + int(np.cos(i/20) * 10) | |
| # Vẽ khuôn mặt | |
| cv2.circle(frame, (x_center, y_center), 60, (200, 200, 200), -1) # Mặt | |
| cv2.circle(frame, (x_center - 20, y_center - 15), 10, (0, 0, 0), -1) # Mắt trái | |
| cv2.circle(frame, (x_center + 20, y_center - 15), 10, (0, 0, 0), -1) # Mắt phải | |
| # Vẽ miệng | |
| mouth_y = y_center + 20 + int(np.sin(i/5) * 5) | |
| cv2.ellipse(frame, (x_center, mouth_y), (20, 10), 0, 0, 180, (0, 0, 0), -1) | |
| out.write(frame) | |
| out.release() | |
| print("Đã tạo video đơn giản") | |
| # Định nghĩa giao diện Gradio | |
| def animate_fomm(source_image, driving_video_file, relative=True, adapt_scale=True): | |
| if source_image is None: | |
| return None, "Vui lòng tải lên ảnh nguồn." | |
| try: | |
| # Lưu tạm ảnh nguồn | |
| source_path = f"source_image_{int(np.random.rand() * 10000)}.jpg" | |
| source_image.save(source_path) | |
| # Xử lý video tham chiếu | |
| print(f"Type of driving_video: {type(driving_video_file)}") | |
| # Tạo file tạm cho video | |
| driving_path = f"driving_video_{int(np.random.rand() * 10000)}.mp4" | |
| # Kiểm tra nếu đã chọn sử dụng video mẫu | |
| if driving_video_file is None: | |
| # Tải và sử dụng video mẫu | |
| driving_path = download_sample_video() | |
| else: | |
| # Xử lý video được tải lên | |
| if isinstance(driving_video_file, str): | |
| # Nếu là đường dẫn, copy file | |
| if os.path.exists(driving_video_file): | |
| import shutil | |
| shutil.copyfile(driving_video_file, driving_path) | |
| else: | |
| return None, f"Không tìm thấy file video tại đường dẫn: {driving_video_file}" | |
| else: | |
| # Ghi dữ liệu nhị phân vào file | |
| with open(driving_path, 'wb') as f: | |
| f.write(driving_video_file) | |
| # Tạo animation | |
| result_path = make_animation( | |
| source_path, | |
| driving_path, | |
| relative=relative, | |
| adapt_movement_scale=adapt_scale | |
| ) | |
| # Xóa file tạm nếu cần | |
| if os.path.exists(source_path) and source_path != "source_image.jpg": | |
| os.remove(source_path) | |
| if os.path.exists(driving_path) and driving_path != "sample_driving.mp4" and driving_path != "driving_video.mp4": | |
| os.remove(driving_path) | |
| return result_path, "Video được tạo thành công!" | |
| except Exception as e: | |
| import traceback | |
| return None, f"Lỗi: {str(e)}\n{traceback.format_exc()}" | |
| # Tạo giao diện Gradio | |
| with gr.Blocks(title="First Order Motion Model - Tạo video người chuyển động") as demo: | |
| gr.Markdown("# First Order Motion Model") | |
| gr.Markdown("Tạo video người chuyển động từ một ảnh tĩnh và video tham chiếu") | |
| with gr.Row(): | |
| with gr.Column(): | |
| source_image = gr.Image(type="pil", label="Tải lên ảnh nguồn") | |
| # Thêm tùy chọn sử dụng video mẫu | |
| use_sample = gr.Checkbox(label="Sử dụng video mẫu có sẵn", value=True) | |
| # Thay đổi từ gr.Video sang gr.File để xử lý lỗi binary | |
| driving_video_file = gr.File(label="Tải lên video tham chiếu (.mp4)", visible=False) | |
| with gr.Row(): | |
| relative = gr.Checkbox(value=True, label="Chuyển động tương đối") | |
| adapt_scale = gr.Checkbox(value=True, label="Điều chỉnh tỷ lệ chuyển động") | |
| submit_btn = gr.Button("Tạo video") | |
| with gr.Column(): | |
| output_video = gr.Video(label="Video kết quả") | |
| output_message = gr.Textbox(label="Thông báo", lines=5) | |
| # Xử lý sự kiện khi checkbox được chọn | |
| def toggle_video_upload(use_sample_video): | |
| return gr.update(visible=not use_sample_video) | |
| use_sample.change(fn=toggle_video_upload, inputs=[use_sample], outputs=[driving_video_file]) | |
| # Cập nhật hàm xử lý khi nhấn nút | |
| def process_inputs(source_img, use_sample_vid, driving_vid, rel, adapt): | |
| if use_sample_vid: | |
| return animate_fomm(source_img, None, rel, adapt) | |
| else: | |
| return animate_fomm(source_img, driving_vid, rel, adapt) | |
| submit_btn.click( | |
| fn=process_inputs, | |
| inputs=[source_image, use_sample, driving_video_file, relative, adapt_scale], | |
| outputs=[output_video, output_message] | |
| ) | |
| gr.Markdown("### Cách sử dụng") | |
| gr.Markdown("1. Tải lên **ảnh nguồn** - ảnh chứa người/đối tượng bạn muốn làm chuyển động") | |
| gr.Markdown("2. Chọn sử dụng video mẫu có sẵn hoặc tải lên video tham chiếu của riêng bạn") | |
| gr.Markdown("3. Nhấn **Tạo video** và chờ kết quả") | |
| gr.Markdown("### Lưu ý") | |
| gr.Markdown("- Ảnh nguồn và video tham chiếu nên có đối tượng tương tự (người với người, mặt với mặt)") | |
| gr.Markdown("- Đối tượng nên ở vị trí tương tự trong cả ảnh nguồn và khung đầu tiên của video tham chiếu") | |
| gr.Markdown("- Quá trình tạo video có thể mất vài phút") | |
| gr.Markdown("- Nếu gặp vấn đề với việc tải lên video, hãy sử dụng video mẫu có sẵn") | |
| demo.launch() |