import os
import time
from os.path import join as pjoin

import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from torch.distributions.categorical import Categorical

from models.mask_transformer.transformer import MaskTransformer, ResidualTransformer
from models.vq.model import RVQVAE, LengthEstimator
from utils.get_opt import get_opt
from utils.fixseed import fixseed
from visualization.joints2bvh import Joint2BVHConvertor
from utils.motion_process import recover_from_ric
from utils.plot_script import plot_3d_motion
from utils.paramUtil import t2m_kinematic_chain

clip_version = 'ViT-B/32'


class MotionGenerator:
    def __init__(self, checkpoints_dir, dataset_name, model_name, res_name, vq_name, device='auto'):
        # Resolve 'auto' to CUDA when available, otherwise fall back to CPU
        if device == 'auto':
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.dataset_name = dataset_name
        self.dim_pose = 251 if dataset_name == 'kit' else 263
        self.nb_joints = 21 if dataset_name == 'kit' else 22

        # Load models
        print("Loading models...")
        self.vq_model, self.vq_opt = self._load_vq_model(checkpoints_dir, dataset_name, vq_name)
        self.t2m_transformer = self._load_trans_model(checkpoints_dir, dataset_name, model_name)
        self.res_model = self._load_res_model(checkpoints_dir, dataset_name, res_name, self.vq_opt)
        self.length_estimator = self._load_len_estimator(checkpoints_dir, dataset_name)

        # Set to eval mode
        self.vq_model.eval()
        self.t2m_transformer.eval()
        self.res_model.eval()
        self.length_estimator.eval()

        # Load normalization stats
        meta_dir = pjoin(checkpoints_dir, dataset_name, vq_name, 'meta')
        self.mean = np.load(pjoin(meta_dir, 'mean.npy'))
        self.std = np.load(pjoin(meta_dir, 'std.npy'))

        self.kinematic_chain = t2m_kinematic_chain
        self.converter = Joint2BVHConvertor()
        print("Models loaded successfully!")

    def _load_vq_model(self, checkpoints_dir, dataset_name, vq_name):
        vq_opt_path = pjoin(checkpoints_dir, dataset_name, vq_name, 'opt.txt')
        vq_opt = get_opt(vq_opt_path, device=self.device)
        vq_opt.dim_pose = self.dim_pose
        vq_model = RVQVAE(vq_opt,
                          vq_opt.dim_pose,
                          vq_opt.nb_code,
                          vq_opt.code_dim,
                          vq_opt.output_emb_width,
                          vq_opt.down_t,
                          vq_opt.stride_t,
                          vq_opt.width,
                          vq_opt.depth,
                          vq_opt.dilation_growth_rate,
                          vq_opt.vq_act,
                          vq_opt.vq_norm)
        ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, vq_name, 'model', 'net_best_fid.tar'),
                          map_location=self.device)
        model_key = 'vq_model' if 'vq_model' in ckpt else 'net'
        vq_model.load_state_dict(ckpt[model_key])
        vq_model.to(self.device)
        return vq_model, vq_opt

    def _load_trans_model(self, checkpoints_dir, dataset_name, model_name):
        model_opt_path = pjoin(checkpoints_dir, dataset_name, model_name, 'opt.txt')
        model_opt = get_opt(model_opt_path, device=self.device)
        model_opt.num_tokens = self.vq_opt.nb_code
        model_opt.num_quantizers = self.vq_opt.num_quantizers
        model_opt.code_dim = self.vq_opt.code_dim

        # Set default values for missing attributes
        if not hasattr(model_opt, 'latent_dim'):
            model_opt.latent_dim = 384
        if not hasattr(model_opt, 'ff_size'):
            model_opt.ff_size = 1024
        if not hasattr(model_opt, 'n_layers'):
            model_opt.n_layers = 8
        if not hasattr(model_opt, 'n_heads'):
            model_opt.n_heads = 6
        if not hasattr(model_opt, 'dropout'):
            model_opt.dropout = 0.1
        if not hasattr(model_opt, 'cond_drop_prob'):
            model_opt.cond_drop_prob = 0.1

        t2m_transformer = MaskTransformer(code_dim=model_opt.code_dim,
                                          cond_mode='text',
                                          latent_dim=model_opt.latent_dim,
                                          ff_size=model_opt.ff_size,
                                          num_layers=model_opt.n_layers,
                                          num_heads=model_opt.n_heads,
                                          dropout=model_opt.dropout,
                                          clip_dim=512,
                                          cond_drop_prob=model_opt.cond_drop_prob,
                                          clip_version=clip_version,
                                          opt=model_opt)
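        # Checkpoint key names have varied across training runs ('t2m_transformer' vs 'trans'),
        # so the fallback below tries both; strict=False tolerates key mismatches (e.g., weights
        # of the frozen CLIP text encoder may not be stored in the checkpoint).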
        ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, model_name, 'model', 'latest.tar'),
                          map_location=self.device)
        model_key = 't2m_transformer' if 't2m_transformer' in ckpt else 'trans'
        t2m_transformer.load_state_dict(ckpt[model_key], strict=False)
        t2m_transformer.to(self.device)
        return t2m_transformer

    def _load_res_model(self, checkpoints_dir, dataset_name, res_name, vq_opt):
        res_opt_path = pjoin(checkpoints_dir, dataset_name, res_name, 'opt.txt')
        res_opt = get_opt(res_opt_path, device=self.device)

        # The res_name appears to be the same as vq_name, so res_opt is actually vq_opt;
        # we need to use proper model architecture parameters.
        res_opt.num_quantizers = vq_opt.num_quantizers
        res_opt.num_tokens = vq_opt.nb_code

        # Set architecture parameters for ResidualTransformer.
        # These should match the main transformer architecture.
        res_opt.latent_dim = 384  # Match with main transformer
        res_opt.ff_size = 1024
        res_opt.n_layers = 9  # Typically slightly more layers for residual
        res_opt.n_heads = 6
        res_opt.dropout = 0.1
        res_opt.cond_drop_prob = 0.1
        res_opt.share_weight = False

        print(f"ResidualTransformer config - latent_dim: {res_opt.latent_dim}, ff_size: {res_opt.ff_size}, "
              f"nlayers: {res_opt.n_layers}, nheads: {res_opt.n_heads}, dropout: {res_opt.dropout}")

        res_transformer = ResidualTransformer(code_dim=vq_opt.code_dim,
                                              cond_mode='text',
                                              latent_dim=res_opt.latent_dim,
                                              ff_size=res_opt.ff_size,
                                              num_layers=res_opt.n_layers,
                                              num_heads=res_opt.n_heads,
                                              dropout=res_opt.dropout,
                                              clip_dim=512,
                                              shared_codebook=vq_opt.shared_codebook,
                                              cond_drop_prob=res_opt.cond_drop_prob,
                                              share_weight=res_opt.share_weight,
                                              clip_version=clip_version,
                                              opt=res_opt)

        ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, res_name, 'model', 'net_best_fid.tar'),
                          map_location=self.device)

        # Debug: check available keys
        print(f"Available checkpoint keys: {ckpt.keys()}")

        # Try different possible keys for the model state dict
        model_key = None
        for key in ['res_transformer', 'trans', 'net', 'model', 'state_dict']:
            if key in ckpt:
                model_key = key
                break

        if model_key:
            print(f"Loading ResidualTransformer from key: {model_key}")
            res_transformer.load_state_dict(ckpt[model_key], strict=False)
        else:
            print("Warning: Could not find model weights in checkpoint. Available keys:", list(ckpt.keys()))
            # If this is actually a VQ model checkpoint, we might need to skip loading or handle it differently
            if 'vq_model' in ckpt or 'net' in ckpt:
                print("This appears to be a VQ model checkpoint, not a ResidualTransformer checkpoint.")
                print("Skipping weight loading - using randomly initialized ResidualTransformer.")

        res_transformer.to(self.device)
        return res_transformer

    def _load_len_estimator(self, checkpoints_dir, dataset_name):
        model = LengthEstimator(512, 50)
        ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, 'length_estimator', 'model', 'finest.tar'),
                          map_location=self.device)
        model.load_state_dict(ckpt['estimator'])
        model.to(self.device)
        return model

    def inv_transform(self, data):
        return data * self.std + self.mean

    @torch.no_grad()
    def generate(self, text_prompt, motion_length=0, time_steps=18, cond_scale=4,
                 temperature=1, topkr=0.9, gumbel_sample=True, seed=42):
        """
        Generate motion from a text prompt.

        Args:
            text_prompt: Text description of the motion
            motion_length: Desired motion length in frames (0 for auto-estimation)
            time_steps: Number of denoising steps
            cond_scale: Classifier-free guidance scale
            temperature: Sampling temperature
            topkr: Top-k filtering threshold
            gumbel_sample: Whether to use Gumbel sampling
            seed: Random seed
        """
        fixseed(seed)

        # Convert motion_length to int if needed
        if isinstance(motion_length, float):
            motion_length = int(motion_length)

        # Estimate length if not provided
        if motion_length == 0:
            text_embedding = self.t2m_transformer.encode_text([text_prompt])
            pred_dis = self.length_estimator(text_embedding)
            probs = F.softmax(pred_dis, dim=-1)
            token_lens = Categorical(probs).sample()
        else:
            token_lens = torch.LongTensor([motion_length // 4]).to(self.device)
        m_length = token_lens * 4

        # Generate motion tokens
        mids = self.t2m_transformer.generate([text_prompt], token_lens,
                                             timesteps=int(time_steps),
                                             cond_scale=float(cond_scale),
                                             temperature=float(temperature),
                                             topk_filter_thres=float(topkr),
                                             gsample=gumbel_sample)

        # Refine with residual transformer
        mids = self.res_model.generate(mids, [text_prompt], token_lens, temperature=1, cond_scale=5)

        # Decode to motion
        pred_motions = self.vq_model.forward_decoder(mids)
        pred_motions = pred_motions.detach().cpu().numpy()

        # Denormalize
        data = self.inv_transform(pred_motions)
        joint_data = data[0, :m_length[0]]

        # Recover 3D joints
        joint = recover_from_ric(torch.from_numpy(joint_data).float(), self.nb_joints).numpy()

        return joint, int(m_length[0].item())


def create_gradio_interface(generator, output_dir='./gradio_outputs'):
    os.makedirs(output_dir, exist_ok=True)

    def generate_motion(text_prompt):
        try:
            print(f"\nGenerating motion for: '{text_prompt}'")
            print(f"Device: {generator.device}")

            # Use default parameters for simplicity
            motion_length = 0  # Auto-estimate
            time_steps = 18
            cond_scale = 4.0
            temperature = 1.0
            topkr = 0.9
            use_gumbel = True
            seed = 42
            use_ik = True

            # Generate motion
            joint, actual_length = generator.generate(
                text_prompt, motion_length, time_steps, cond_scale,
                temperature, topkr, use_gumbel, seed
            )

            # Build the output path for the rendered video.
            # Use a wall-clock timestamp: np.random would return the same value on every
            # call because fixseed() reseeds the generator with a fixed seed.
            timestamp = str(int(time.time()))
            video_path = pjoin(output_dir, f'motion_{timestamp}.mp4')

            # Convert to BVH with foot IK
            _, joint_processed = generator.converter.convert(
                joint, filename=None, iterations=100, foot_ik=use_ik
            )

            # Create video
            plot_3d_motion(video_path, generator.kinematic_chain, joint_processed,
                           title=text_prompt, fps=20)
            print(f"Video saved: {video_path}")

            return video_path
        except Exception as e:
            import traceback
            error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
            print(error_msg)
            return None
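    # generate_motion returns a video filepath (or None on failure); gr.Video accepts a
    # filepath directly as its output value, so the function is wired to the button below.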
    # Create Gradio interface with Blocks for custom layout
    with gr.Blocks(theme=gr.themes.Base(
        primary_hue="blue",
        secondary_hue="gray",
    ).set(
        body_background_fill="*neutral_950",
        body_background_fill_dark="*neutral_950",
        background_fill_primary="*neutral_900",
        background_fill_primary_dark="*neutral_900",
        background_fill_secondary="*neutral_800",
        background_fill_secondary_dark="*neutral_800",
        block_background_fill="*neutral_900",
        block_background_fill_dark="*neutral_900",
        input_background_fill="*neutral_800",
        input_background_fill_dark="*neutral_800",
        button_primary_background_fill="*primary_600",
        button_primary_background_fill_dark="*primary_600",
        button_primary_text_color="white",
        button_primary_text_color_dark="white",
        block_label_text_color="*neutral_200",
        block_label_text_color_dark="*neutral_200",
        body_text_color="*neutral_200",
        body_text_color_dark="*neutral_200",
        input_placeholder_color="*neutral_500",
        input_placeholder_color_dark="*neutral_500",
    ), css="""
        footer {display: none !important;}
        .video-fixed-height { height: 600px !important; }
        .video-fixed-height video { max-height: 600px !important; object-fit: contain !important; }
    """) as demo:
        gr.Markdown("# 🎭 Text-to-Motion Generator")
        gr.Markdown("Generate 3D human motion animations from text descriptions")

        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label="Describe the motion you want to generate",
                    placeholder="e.g., 'a person walks forward and waves'",
                    lines=3
                )
                submit_btn = gr.Button("Generate Motion", variant="primary")
                gr.Examples(
                    examples=[
                        ["a person walks forward"],
                        ["a person jumps in place"],
                        ["someone performs a dance move"],
                        ["a person sits down on a chair"],
                        ["a person runs and then stops"],
                    ],
                    inputs=text_input,
                    label="Try these examples"
                )
            with gr.Column():
                video_output = gr.Video(label="Generated Motion", elem_classes="video-fixed-height")

        submit_btn.click(
            fn=generate_motion,
            inputs=text_input,
            outputs=video_output
        )

    return demo


if __name__ == '__main__':
    # Configuration
    CHECKPOINTS_DIR = './checkpoints'
    DATASET_NAME = 't2m'  # or 'kit'
    MODEL_NAME = 't2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns'
    RES_NAME = 'rvq_nq6_dc512_nc512_noshare_qdp0.2'
    VQ_NAME = 'rvq_nq6_dc512_nc512_noshare_qdp0.2'

    # Initialize generator
    generator = MotionGenerator(
        checkpoints_dir=CHECKPOINTS_DIR,
        dataset_name=DATASET_NAME,
        model_name=MODEL_NAME,
        res_name=RES_NAME,
        vq_name=VQ_NAME,
        device='auto'
    )

    # Create and launch Gradio interface
    demo = create_gradio_interface(generator)
    demo.launch(server_name="0.0.0.0", server_port=7860)
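# Expected checkpoint layout, inferred from the paths used above (names are the __main__ defaults):
#   ./checkpoints/t2m/<MODEL_NAME>/opt.txt and model/latest.tar
#   ./checkpoints/t2m/<VQ_NAME>/opt.txt, meta/mean.npy, meta/std.npy, model/net_best_fid.tar
#   ./checkpoints/t2m/<RES_NAME>/opt.txt and model/net_best_fid.tar
#   ./checkpoints/t2m/length_estimator/model/finest.tar
# The UI is served at port 7860 on all interfaces (server_name="0.0.0.0").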