import os
import time
from os.path import join as pjoin
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from torch.distributions.categorical import Categorical
from models.mask_transformer.transformer import MaskTransformer, ResidualTransformer
from models.vq.model import RVQVAE, LengthEstimator
from utils.get_opt import get_opt
from utils.fixseed import fixseed
from visualization.joints2bvh import Joint2BVHConvertor
from utils.motion_process import recover_from_ric
from utils.plot_script import plot_3d_motion
from utils.paramUtil import t2m_kinematic_chain
clip_version = 'ViT-B/32'
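
# Expected checkpoint layout (inferred from the paths used in MotionGenerator below;
# adjust if your directory structure differs):
#   <checkpoints_dir>/<dataset_name>/<vq_name>/opt.txt
#   <checkpoints_dir>/<dataset_name>/<vq_name>/model/net_best_fid.tar
#   <checkpoints_dir>/<dataset_name>/<vq_name>/meta/{mean,std}.npy
#   <checkpoints_dir>/<dataset_name>/<model_name>/opt.txt, model/latest.tar
#   <checkpoints_dir>/<dataset_name>/<res_name>/opt.txt, model/net_best_fid.tar
#   <checkpoints_dir>/<dataset_name>/length_estimator/model/finest.tar
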
class MotionGenerator:
    def __init__(self, checkpoints_dir, dataset_name, model_name, res_name, vq_name, device='auto'):
        # 'auto' selects CUDA when available and falls back to CPU; torch.device()
        # itself does not accept the string 'auto'.
        if device == 'auto':
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
self.dataset_name = dataset_name
self.dim_pose = 251 if dataset_name == 'kit' else 263
self.nb_joints = 21 if dataset_name == 'kit' else 22
# Load models
print("Loading models...")
self.vq_model, self.vq_opt = self._load_vq_model(checkpoints_dir, dataset_name, vq_name)
self.t2m_transformer = self._load_trans_model(checkpoints_dir, dataset_name, model_name)
self.res_model = self._load_res_model(checkpoints_dir, dataset_name, res_name, self.vq_opt)
self.length_estimator = self._load_len_estimator(checkpoints_dir, dataset_name)
# Set to eval mode
self.vq_model.eval()
self.t2m_transformer.eval()
self.res_model.eval()
self.length_estimator.eval()
# Load normalization stats
meta_dir = pjoin(checkpoints_dir, dataset_name, vq_name, 'meta')
self.mean = np.load(pjoin(meta_dir, 'mean.npy'))
self.std = np.load(pjoin(meta_dir, 'std.npy'))
self.kinematic_chain = t2m_kinematic_chain
self.converter = Joint2BVHConvertor()
print("Models loaded successfully!")
def _load_vq_model(self, checkpoints_dir, dataset_name, vq_name):
vq_opt_path = pjoin(checkpoints_dir, dataset_name, vq_name, 'opt.txt')
vq_opt = get_opt(vq_opt_path, device=self.device)
vq_opt.dim_pose = self.dim_pose
vq_model = RVQVAE(vq_opt,
vq_opt.dim_pose,
vq_opt.nb_code,
vq_opt.code_dim,
vq_opt.output_emb_width,
vq_opt.down_t,
vq_opt.stride_t,
vq_opt.width,
vq_opt.depth,
vq_opt.dilation_growth_rate,
vq_opt.vq_act,
vq_opt.vq_norm)
ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, vq_name, 'model', 'net_best_fid.tar'),
map_location=self.device)
model_key = 'vq_model' if 'vq_model' in ckpt else 'net'
vq_model.load_state_dict(ckpt[model_key])
vq_model.to(self.device)
return vq_model, vq_opt
def _load_trans_model(self, checkpoints_dir, dataset_name, model_name):
model_opt_path = pjoin(checkpoints_dir, dataset_name, model_name, 'opt.txt')
model_opt = get_opt(model_opt_path, device=self.device)
model_opt.num_tokens = self.vq_opt.nb_code
model_opt.num_quantizers = self.vq_opt.num_quantizers
model_opt.code_dim = self.vq_opt.code_dim
# Set default values for missing attributes
if not hasattr(model_opt, 'latent_dim'):
model_opt.latent_dim = 384
if not hasattr(model_opt, 'ff_size'):
model_opt.ff_size = 1024
if not hasattr(model_opt, 'n_layers'):
model_opt.n_layers = 8
if not hasattr(model_opt, 'n_heads'):
model_opt.n_heads = 6
if not hasattr(model_opt, 'dropout'):
model_opt.dropout = 0.1
if not hasattr(model_opt, 'cond_drop_prob'):
model_opt.cond_drop_prob = 0.1
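        # These defaults mirror the configuration encoded in the default MODEL_NAME in
        # __main__ below (nlayer8, nhead6, ld384, ff1024, cdp0.1); they are only used
        # when the corresponding values are missing from opt.txt.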
t2m_transformer = MaskTransformer(code_dim=model_opt.code_dim,
cond_mode='text',
latent_dim=model_opt.latent_dim,
ff_size=model_opt.ff_size,
num_layers=model_opt.n_layers,
num_heads=model_opt.n_heads,
dropout=model_opt.dropout,
clip_dim=512,
cond_drop_prob=model_opt.cond_drop_prob,
clip_version=clip_version,
opt=model_opt)
ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, model_name, 'model', 'latest.tar'),
map_location=self.device)
model_key = 't2m_transformer' if 't2m_transformer' in ckpt else 'trans'
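        # strict=False: the checkpoint may not cover every module registered on the
        # transformer (e.g. a frozen CLIP text encoder), so missing keys are tolerated.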
t2m_transformer.load_state_dict(ckpt[model_key], strict=False)
t2m_transformer.to(self.device)
return t2m_transformer
def _load_res_model(self, checkpoints_dir, dataset_name, res_name, vq_opt):
res_opt_path = pjoin(checkpoints_dir, dataset_name, res_name, 'opt.txt')
res_opt = get_opt(res_opt_path, device=self.device)
        # res_name may point to the same folder as vq_name, in which case res_opt is
        # really the VQ config and lacks transformer hyperparameters.
res_opt.num_quantizers = vq_opt.num_quantizers
res_opt.num_tokens = vq_opt.nb_code
# Set architecture parameters for ResidualTransformer
# These should match the main transformer architecture
res_opt.latent_dim = 384 # Match with main transformer
res_opt.ff_size = 1024
res_opt.n_layers = 9 # Typically slightly more layers for residual
res_opt.n_heads = 6
res_opt.dropout = 0.1
res_opt.cond_drop_prob = 0.1
res_opt.share_weight = False
print(f"ResidualTransformer config - latent_dim: {res_opt.latent_dim}, ff_size: {res_opt.ff_size}, nlayers: {res_opt.n_layers}, nheads: {res_opt.n_heads}, dropout: {res_opt.dropout}")
res_transformer = ResidualTransformer(code_dim=vq_opt.code_dim,
cond_mode='text',
latent_dim=res_opt.latent_dim,
ff_size=res_opt.ff_size,
num_layers=res_opt.n_layers,
num_heads=res_opt.n_heads,
dropout=res_opt.dropout,
clip_dim=512,
shared_codebook=vq_opt.shared_codebook,
cond_drop_prob=res_opt.cond_drop_prob,
share_weight=res_opt.share_weight,
clip_version=clip_version,
opt=res_opt)
ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, res_name, 'model', 'net_best_fid.tar'),
map_location=self.device)
# Debug: check available keys
print(f"Available checkpoint keys: {ckpt.keys()}")
        # Try the likely keys for the model state dict
        model_key = None
        for key in ['res_transformer', 'trans', 'net', 'model', 'state_dict']:
            if key in ckpt:
                model_key = key
                break
        if model_key:
            print(f"Loading ResidualTransformer weights from key: {model_key}")
            res_transformer.load_state_dict(ckpt[model_key], strict=False)
        else:
            print("Warning: could not find model weights in checkpoint. Available keys:", list(ckpt.keys()))
            if 'vq_model' in ckpt:
                print("This appears to be a VQ model checkpoint, not a ResidualTransformer checkpoint.")
            print("Skipping weight loading - using a randomly initialized ResidualTransformer.")
res_transformer.to(self.device)
return res_transformer
def _load_len_estimator(self, checkpoints_dir, dataset_name):
model = LengthEstimator(512, 50)
ckpt = torch.load(pjoin(checkpoints_dir, dataset_name, 'length_estimator', 'model', 'finest.tar'),
map_location=self.device)
model.load_state_dict(ckpt['estimator'])
model.to(self.device)
return model
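    # The motion features are z-normalized with the mean/std loaded in __init__;
    # inv_transform maps decoded output back to the original feature scale.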
def inv_transform(self, data):
return data * self.std + self.mean
@torch.no_grad()
def generate(self, text_prompt, motion_length=0, time_steps=18, cond_scale=4,
temperature=1, topkr=0.9, gumbel_sample=True, seed=42):
"""
Generate motion from text prompt
Args:
text_prompt: Text description of the motion
motion_length: Desired motion length (0 for auto-estimation)
time_steps: Number of denoising steps
cond_scale: Classifier-free guidance scale
temperature: Sampling temperature
topkr: Top-k filtering threshold
gumbel_sample: Whether to use Gumbel sampling
seed: Random seed
"""
fixseed(seed)
# Convert motion_length to int if needed
if isinstance(motion_length, float):
motion_length = int(motion_length)
# Estimate length if not provided
if motion_length == 0:
text_embedding = self.t2m_transformer.encode_text([text_prompt])
pred_dis = self.length_estimator(text_embedding)
probs = F.softmax(pred_dis, dim=-1)
token_lens = Categorical(probs).sample()
else:
token_lens = torch.LongTensor([motion_length // 4]).to(self.device)
m_length = token_lens * 4
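        # One motion token covers 4 raw frames (matching the //4 and *4 conversions
        # here), which appears to be the VQ encoder's temporal downsampling factor.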
# Generate motion tokens
mids = self.t2m_transformer.generate([text_prompt], token_lens,
timesteps=int(time_steps),
cond_scale=float(cond_scale),
temperature=float(temperature),
topk_filter_thres=float(topkr),
gsample=gumbel_sample)
# Refine with residual transformer
mids = self.res_model.generate(mids, [text_prompt], token_lens,
temperature=1, cond_scale=5)
# Decode to motion
pred_motions = self.vq_model.forward_decoder(mids)
pred_motions = pred_motions.detach().cpu().numpy()
# Denormalize
data = self.inv_transform(pred_motions)
joint_data = data[0, :m_length[0]]
# Recover 3D joints
joint = recover_from_ric(torch.from_numpy(joint_data).float(), self.nb_joints).numpy()
return joint, int(m_length[0].item())
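
# Minimal sketch of using MotionGenerator without the Gradio UI (assumes the same
# checkpoints referenced in the __main__ block at the bottom of this file):
#
#   gen = MotionGenerator('./checkpoints', 't2m',
#                         't2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns',
#                         'rvq_nq6_dc512_nc512_noshare_qdp0.2',
#                         'rvq_nq6_dc512_nc512_noshare_qdp0.2')
#   joints, n_frames = gen.generate("a person walks forward")
#   # joints: (n_frames, 22, 3) array of 3D joint positions for the 't2m' skeleton
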
def create_gradio_interface(generator, output_dir='./gradio_outputs'):
os.makedirs(output_dir, exist_ok=True)
def generate_motion(text_prompt):
try:
print(f"\nGenerating motion for: '{text_prompt}'")
print(f"Device: {generator.device}")
# Use default parameters for simplicity
motion_length = 0 # Auto-estimate
time_steps = 18
cond_scale = 4.0
temperature = 1.0
topkr = 0.9
use_gumbel = True
seed = 42
use_ik = True
# Generate motion
joint, actual_length = generator.generate(
text_prompt,
motion_length,
time_steps,
cond_scale,
temperature,
topkr,
use_gumbel,
seed
)
# Save BVH and video
            # Use a wall-clock timestamp for the output name; np.random is seeded inside
            # generate(), so a "random" suffix could repeat and overwrite earlier files.
            timestamp = str(int(time.time() * 1000))
            video_path = pjoin(output_dir, f'motion_{timestamp}.mp4')
            # Apply foot-contact IK via the BVH converter; filename=None, so no BVH file
            # is written and only the processed joints are kept.
            _, joint_processed = generator.converter.convert(
                joint, filename=None, iterations=100, foot_ik=use_ik
            )
# Create video
plot_3d_motion(video_path, generator.kinematic_chain, joint_processed,
title=text_prompt, fps=20)
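            # The fps=20 above matches the HumanML3D ('t2m') frame rate; the KIT-ML
            # dataset is typically 12.5 fps.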
print(f"Video saved: {video_path}")
return video_path
except Exception as e:
import traceback
error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
print(error_msg)
return None
# Create Gradio interface with Blocks for custom layout
with gr.Blocks(theme=gr.themes.Base(
primary_hue="blue",
secondary_hue="gray",
).set(
body_background_fill="*neutral_950",
body_background_fill_dark="*neutral_950",
background_fill_primary="*neutral_900",
background_fill_primary_dark="*neutral_900",
background_fill_secondary="*neutral_800",
background_fill_secondary_dark="*neutral_800",
block_background_fill="*neutral_900",
block_background_fill_dark="*neutral_900",
input_background_fill="*neutral_800",
input_background_fill_dark="*neutral_800",
button_primary_background_fill="*primary_600",
button_primary_background_fill_dark="*primary_600",
button_primary_text_color="white",
button_primary_text_color_dark="white",
block_label_text_color="*neutral_200",
block_label_text_color_dark="*neutral_200",
body_text_color="*neutral_200",
body_text_color_dark="*neutral_200",
input_placeholder_color="*neutral_500",
input_placeholder_color_dark="*neutral_500",
),
css="""
footer {display: none !important;}
.video-fixed-height {
height: 600px !important;
}
.video-fixed-height video {
max-height: 600px !important;
object-fit: contain !important;
}
""") as demo:
gr.Markdown("# 🎭 Text-to-Motion Generator")
gr.Markdown("Generate 3D human motion animations from text descriptions")
with gr.Row():
with gr.Column():
text_input = gr.Textbox(
label="Describe the motion you want to generate",
placeholder="e.g., 'a person walks forward and waves'",
lines=3
)
submit_btn = gr.Button("Generate Motion", variant="primary")
gr.Examples(
examples=[
["a person walks forward"],
["a person jumps in place"],
["someone performs a dance move"],
["a person sits down on a chair"],
["a person runs and then stops"],
],
inputs=text_input,
label="Try these examples"
)
with gr.Column():
video_output = gr.Video(label="Generated Motion", elem_classes="video-fixed-height")
submit_btn.click(
fn=generate_motion,
inputs=text_input,
outputs=video_output
)
return demo
if __name__ == '__main__':
# Configuration
CHECKPOINTS_DIR = './checkpoints'
DATASET_NAME = 't2m' # or 'kit'
MODEL_NAME = 't2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns'
RES_NAME = 'rvq_nq6_dc512_nc512_noshare_qdp0.2'
VQ_NAME = 'rvq_nq6_dc512_nc512_noshare_qdp0.2'
# Initialize generator
generator = MotionGenerator(
checkpoints_dir=CHECKPOINTS_DIR,
dataset_name=DATASET_NAME,
model_name=MODEL_NAME,
res_name=RES_NAME,
vq_name=VQ_NAME,
device='auto'
)
# Create and launch Gradio interface
demo = create_gradio_interface(generator)
demo.launch(server_name="0.0.0.0", server_port=7860)
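
    # Note: when serving multiple users, demo.queue() can be called before launch() to
    # queue GPU-bound requests (a standard Gradio option; not required for local use).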