import os
import time
from os.path import join as pjoin

import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from torch.distributions.categorical import Categorical

from models.mask_transformer.transformer import MaskTransformer, ResidualTransformer
from models.vq.model import RVQVAE, LengthEstimator
from utils.get_opt import get_opt
from utils.fixseed import fixseed
from visualization.joints2bvh import Joint2BVHConvertor
from utils.motion_process import recover_from_ric
from utils.plot_script import plot_3d_motion
from utils.paramUtil import t2m_kinematic_chain

clip_version = 'ViT-B/32'

class MotionGenerator:
    def __init__(self, checkpoints_dir, dataset_name, model_name, res_name, vq_name, device='auto'):
        # Resolve 'auto' to CUDA when available; torch.device() does not accept 'auto' directly.
        if device == 'auto':
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.dataset_name = dataset_name
        self.dim_pose = 251 if dataset_name == 'kit' else 263
        self.nb_joints = 21 if dataset_name == 'kit' else 22

        # Load models
        print("Loading models...")
        self.vq_model, self.vq_opt = self._load_vq_model(checkpoints_dir, dataset_name, vq_name)
        self.t2m_transformer = self._load_trans_model(checkpoints_dir, dataset_name, model_name)
        self.res_model = self._load_res_model(checkpoints_dir, dataset_name, res_name, self.vq_opt)
        self.length_estimator = self._load_len_estimator(checkpoints_dir, dataset_name)

        # Set to eval mode
        self.vq_model.eval()
        self.t2m_transformer.eval()
        self.res_model.eval()
        self.length_estimator.eval()

        # Load normalization stats
        meta_dir = pjoin(checkpoints_dir, dataset_name, vq_name, 'meta')
        self.mean = np.load(pjoin(meta_dir, 'mean.npy'))
        self.std = np.load(pjoin(meta_dir, 'std.npy'))

        self.kinematic_chain = t2m_kinematic_chain
        self.converter = Joint2BVHConvertor()
        print("Models loaded successfully!")
    def _load_vq_model(self, checkpoints_dir, dataset_name, vq_name):
        vq_opt_path = pjoin(checkpoints_dir, dataset_name, vq_name, 'opt.txt')
        vq_opt = get_opt(vq_opt_path, device=self.device)
        vq_opt.dim_pose = self.dim_pose
        vq_model = RVQVAE(vq_opt,
                          vq_opt.dim_pose,
                          vq_opt.nb_code,
                          vq_opt.code_dim,
                          vq_opt.output_emb_width,
                          vq_opt.down_t,
                          vq_opt.stride_t,
                          vq_opt.width,
                          vq_opt.depth,
                          vq_opt.dilation_growth_rate,
                          vq_opt.vq_act,
                          vq_opt.vq_norm)
        ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, vq_name, 'model', 'net_best_fid.tar'),
                          map_location=self.device)
        model_key = 'vq_model' if 'vq_model' in ckpt else 'net'
        vq_model.load_state_dict(ckpt[model_key])
        vq_model.to(self.device)
        return vq_model, vq_opt
    def _load_trans_model(self, checkpoints_dir, dataset_name, model_name):
        model_opt_path = pjoin(checkpoints_dir, dataset_name, model_name, 'opt.txt')
        model_opt = get_opt(model_opt_path, device=self.device)
        model_opt.num_tokens = self.vq_opt.nb_code
        model_opt.num_quantizers = self.vq_opt.num_quantizers
        model_opt.code_dim = self.vq_opt.code_dim

        # Fill in defaults for attributes missing from opt.txt
        defaults = {'latent_dim': 384, 'ff_size': 1024, 'n_layers': 8,
                    'n_heads': 6, 'dropout': 0.1, 'cond_drop_prob': 0.1}
        for attr, value in defaults.items():
            if not hasattr(model_opt, attr):
                setattr(model_opt, attr, value)

        t2m_transformer = MaskTransformer(code_dim=model_opt.code_dim,
                                          cond_mode='text',
                                          latent_dim=model_opt.latent_dim,
                                          ff_size=model_opt.ff_size,
                                          num_layers=model_opt.n_layers,
                                          num_heads=model_opt.n_heads,
                                          dropout=model_opt.dropout,
                                          clip_dim=512,
                                          cond_drop_prob=model_opt.cond_drop_prob,
                                          clip_version=clip_version,
                                          opt=model_opt)
        ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, model_name, 'model', 'latest.tar'),
                          map_location=self.device)
        model_key = 't2m_transformer' if 't2m_transformer' in ckpt else 'trans'
        t2m_transformer.load_state_dict(ckpt[model_key], strict=False)
        t2m_transformer.to(self.device)
        return t2m_transformer
    def _load_res_model(self, checkpoints_dir, dataset_name, res_name, vq_opt):
        res_opt_path = pjoin(checkpoints_dir, dataset_name, res_name, 'opt.txt')
        res_opt = get_opt(res_opt_path, device=self.device)
        # res_name may be the same as vq_name, in which case res_opt is really the VQ opt,
        # so set the architecture parameters explicitly below.
        res_opt.num_quantizers = vq_opt.num_quantizers
        res_opt.num_tokens = vq_opt.nb_code

        # Architecture parameters for the ResidualTransformer;
        # these should match the main transformer architecture.
        res_opt.latent_dim = 384  # match the main transformer
        res_opt.ff_size = 1024
        res_opt.n_layers = 9  # typically slightly more layers for the residual model
        res_opt.n_heads = 6
        res_opt.dropout = 0.1
        res_opt.cond_drop_prob = 0.1
        res_opt.share_weight = False

        print(f"ResidualTransformer config - latent_dim: {res_opt.latent_dim}, "
              f"ff_size: {res_opt.ff_size}, n_layers: {res_opt.n_layers}, "
              f"n_heads: {res_opt.n_heads}, dropout: {res_opt.dropout}")

        res_transformer = ResidualTransformer(code_dim=vq_opt.code_dim,
                                              cond_mode='text',
                                              latent_dim=res_opt.latent_dim,
                                              ff_size=res_opt.ff_size,
                                              num_layers=res_opt.n_layers,
                                              num_heads=res_opt.n_heads,
                                              dropout=res_opt.dropout,
                                              clip_dim=512,
                                              shared_codebook=vq_opt.shared_codebook,
                                              cond_drop_prob=res_opt.cond_drop_prob,
                                              share_weight=res_opt.share_weight,
                                              clip_version=clip_version,
                                              opt=res_opt)
        ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, res_name, 'model', 'net_best_fid.tar'),
                          map_location=self.device)

        # Debug: check available keys
        print(f"Available checkpoint keys: {list(ckpt.keys())}")

        # Try the possible keys for the model state dict
        model_key = None
        for key in ['res_transformer', 'trans', 'net', 'model', 'state_dict']:
            if key in ckpt:
                model_key = key
                break
        if model_key:
            print(f"Loading ResidualTransformer from key: {model_key}")
            res_transformer.load_state_dict(ckpt[model_key], strict=False)
        else:
            print("Warning: could not find model weights in checkpoint. Available keys:", list(ckpt.keys()))
            if 'vq_model' in ckpt:
                print("This appears to be a VQ model checkpoint, not a ResidualTransformer checkpoint.")
                print("Skipping weight loading - using a randomly initialized ResidualTransformer.")
        res_transformer.to(self.device)
        return res_transformer
    def _load_len_estimator(self, checkpoints_dir, dataset_name):
        model = LengthEstimator(512, 50)
        ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, 'length_estimator', 'model', 'finest.tar'),
                          map_location=self.device)
        model.load_state_dict(ckpt['estimator'])
        model.to(self.device)
        return model
    def inv_transform(self, data):
        # Undo the z-score normalization applied during training (x_norm = (x - mean) / std)
        return data * self.std + self.mean
    def generate(self, text_prompt, motion_length=0, time_steps=18, cond_scale=4,
                 temperature=1, topkr=0.9, gumbel_sample=True, seed=42):
        """
        Generate motion from a text prompt.

        Args:
            text_prompt: Text description of the motion
            motion_length: Desired motion length in frames (0 for auto-estimation)
            time_steps: Number of iterative decoding steps
            cond_scale: Classifier-free guidance scale
            temperature: Sampling temperature
            topkr: Top-k filtering threshold
            gumbel_sample: Whether to use Gumbel sampling
            seed: Random seed
        """
        fixseed(seed)

        # Convert motion_length to int if needed
        if isinstance(motion_length, float):
            motion_length = int(motion_length)

        # Estimate the token length from the text if no length was provided
        if motion_length == 0:
            text_embedding = self.t2m_transformer.encode_text([text_prompt])
            pred_dis = self.length_estimator(text_embedding)
            probs = F.softmax(pred_dis, dim=-1)
            token_lens = Categorical(probs).sample()
        else:
            token_lens = torch.LongTensor([motion_length // 4]).to(self.device)
        m_length = token_lens * 4

        # Generate motion tokens with the masked transformer
        mids = self.t2m_transformer.generate([text_prompt], token_lens,
                                             timesteps=int(time_steps),
                                             cond_scale=float(cond_scale),
                                             temperature=float(temperature),
                                             topk_filter_thres=float(topkr),
                                             gsample=gumbel_sample)

        # Refine with the residual transformer
        mids = self.res_model.generate(mids, [text_prompt], token_lens,
                                       temperature=1, cond_scale=5)

        # Decode tokens to motion features
        pred_motions = self.vq_model.forward_decoder(mids)
        pred_motions = pred_motions.detach().cpu().numpy()

        # Denormalize and trim to the predicted length
        data = self.inv_transform(pred_motions)
        joint_data = data[0, :int(m_length[0])]

        # Recover 3D joint positions from the RIC representation
        joint = recover_from_ric(torch.from_numpy(joint_data).float(), self.nb_joints).numpy()
        return joint, int(m_length[0].item())
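
# Illustrative sketch (not part of the original app): how MotionGenerator might be used
# outside Gradio to export the raw joints and a BVH file. The assumption here is that
# Joint2BVHConvertor.convert accepts an output path via its `filename` keyword, as
# suggested by the call in generate_motion below; the output directory name is made up.
def example_offline_export(generator, text_prompt, out_dir='./offline_outputs'):
    os.makedirs(out_dir, exist_ok=True)
    joint, length = generator.generate(text_prompt, motion_length=0, seed=42)
    # Save the recovered (length, nb_joints, 3) joint positions for later use
    np.save(pjoin(out_dir, 'joints.npy'), joint)
    # Write a BVH file; foot_ik=True applies the same foot-contact IK used in the UI path
    bvh_path = pjoin(out_dir, 'motion.bvh')
    generator.converter.convert(joint, filename=bvh_path, iterations=100, foot_ik=True)
    return bvh_path, length
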
def create_gradio_interface(generator, output_dir='./gradio_outputs'):
    os.makedirs(output_dir, exist_ok=True)

    def generate_motion(text_prompt):
        try:
            print(f"\nGenerating motion for: '{text_prompt}'")
            print(f"Device: {generator.device}")

            # Use default parameters for simplicity
            motion_length = 0  # auto-estimate
            time_steps = 18
            cond_scale = 4.0
            temperature = 1.0
            topkr = 0.9
            use_gumbel = True
            seed = 42
            use_ik = True

            # Generate motion
            joint, actual_length = generator.generate(
                text_prompt,
                motion_length,
                time_steps,
                cond_scale,
                temperature,
                topkr,
                use_gumbel,
                seed
            )

            # Use wall-clock time for a unique filename (np.random would repeat after fixseed)
            timestamp = str(int(time.time() * 1000))
            video_path = pjoin(output_dir, f'motion_{timestamp}.mp4')

            # Apply foot IK and clean up the joints before rendering
            _, joint_processed = generator.converter.convert(
                joint, filename=None, iterations=100, foot_ik=use_ik
            )

            # Render the motion to a video
            plot_3d_motion(video_path, generator.kinematic_chain, joint_processed,
                           title=text_prompt, fps=20)
            print(f"Video saved: {video_path}")
            return video_path
        except Exception as e:
            import traceback
            error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
            print(error_msg)
            return None
    # Create the Gradio interface with Blocks for a custom layout
    with gr.Blocks(theme=gr.themes.Base(
        primary_hue="blue",
        secondary_hue="gray",
    ).set(
        body_background_fill="*neutral_950",
        body_background_fill_dark="*neutral_950",
        background_fill_primary="*neutral_900",
        background_fill_primary_dark="*neutral_900",
        background_fill_secondary="*neutral_800",
        background_fill_secondary_dark="*neutral_800",
        block_background_fill="*neutral_900",
        block_background_fill_dark="*neutral_900",
        input_background_fill="*neutral_800",
        input_background_fill_dark="*neutral_800",
        button_primary_background_fill="*primary_600",
        button_primary_background_fill_dark="*primary_600",
        button_primary_text_color="white",
        button_primary_text_color_dark="white",
        block_label_text_color="*neutral_200",
        block_label_text_color_dark="*neutral_200",
        body_text_color="*neutral_200",
        body_text_color_dark="*neutral_200",
        input_placeholder_color="*neutral_500",
        input_placeholder_color_dark="*neutral_500",
    ),
    css="""
    footer {display: none !important;}
    .video-fixed-height {
        height: 600px !important;
    }
    .video-fixed-height video {
        max-height: 600px !important;
        object-fit: contain !important;
    }
    """) as demo:
| gr.Markdown("# 🎭 Text-to-Motion Generator") | |
| gr.Markdown("Generate 3D human motion animations from text descriptions") | |
| with gr.Row(): | |
| with gr.Column(): | |
| text_input = gr.Textbox( | |
| label="Describe the motion you want to generate", | |
| placeholder="e.g., 'a person walks forward and waves'", | |
| lines=3 | |
| ) | |
| submit_btn = gr.Button("Generate Motion", variant="primary") | |
| gr.Examples( | |
| examples=[ | |
| ["a person walks forward"], | |
| ["a person jumps in place"], | |
| ["someone performs a dance move"], | |
| ["a person sits down on a chair"], | |
| ["a person runs and then stops"], | |
| ], | |
| inputs=text_input, | |
| label="Try these examples" | |
| ) | |
| with gr.Column(): | |
| video_output = gr.Video(label="Generated Motion", elem_classes="video-fixed-height") | |
| submit_btn.click( | |
| fn=generate_motion, | |
| inputs=text_input, | |
| outputs=video_output | |
| ) | |
| return demo | |
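
# Illustrative sketch (an assumption, not part of the original interface): how extra
# generation parameters such as seed and motion length could be exposed as additional
# Gradio inputs. It calls generator.generate directly and only reports the resulting
# length; rendering would reuse the same video path as generate_motion above. The
# 196-frame upper bound is assumed from the usual HumanML3D maximum motion length.
def example_interface_with_controls(generator):
    with gr.Blocks() as demo:
        text_input = gr.Textbox(label="Motion description", lines=3)
        seed_input = gr.Number(value=42, precision=0, label="Seed")
        length_input = gr.Slider(0, 196, value=0, step=4, label="Motion length (frames, 0 = auto)")
        out = gr.JSON(label="Generated length (frames)")

        def run(text, seed, length):
            # Discard the joints here; only the predicted length is shown in this sketch
            _, n_frames = generator.generate(text, motion_length=int(length), seed=int(seed))
            return {"frames": n_frames}

        gr.Button("Generate").click(run, [text_input, seed_input, length_input], out)
    return demo
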
if __name__ == '__main__':
    # Configuration
    CHECKPOINTS_DIR = './checkpoints'
    DATASET_NAME = 't2m'  # or 'kit'
    MODEL_NAME = 't2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns'
    RES_NAME = 'rvq_nq6_dc512_nc512_noshare_qdp0.2'
    VQ_NAME = 'rvq_nq6_dc512_nc512_noshare_qdp0.2'

    # Initialize the generator
    generator = MotionGenerator(
        checkpoints_dir=CHECKPOINTS_DIR,
        dataset_name=DATASET_NAME,
        model_name=MODEL_NAME,
        res_name=RES_NAME,
        vq_name=VQ_NAME,
        device='auto'
    )

    # Create and launch the Gradio interface
    demo = create_gradio_interface(generator)
    demo.launch(server_name="0.0.0.0", server_port=7860)
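    # Note: server_name="0.0.0.0" and port 7860 match the usual Hugging Face Spaces setup;
    # for local use, passing share=True to demo.launch() would also create a temporary public link.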