|
|
|
|
|
""" |
|
|
Debug script to determine the input tensor format expected by the TimeSformer model.
|
|
This script tests different tensor shapes and formats to find the correct one. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import logging |
|
|
import warnings |
|
|
|
|
|
|
|
|
# Suppress every warning category so library chatter does not bury the
# debug output printed by this script.
warnings.simplefilter("ignore")

# Root logger: INFO level with a timestamp / level / message format.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
|
|
|
|
|
def create_test_frames(num_frames=8, size=(224, 224)):
    """Build solid-colour RGB PIL images to use as dummy video frames.

    Frame ``i`` is filled with entry ``i`` of an eight-colour palette. The
    palette wraps around when ``num_frames`` exceeds its length, so
    neighbouring frames are visually distinct.

    Args:
        num_frames: Number of frames to generate (non-positive gives ``[]``).
        size: ``(width, height)`` passed straight to ``PIL.Image.new``.

    Returns:
        list of ``'RGB'``-mode ``PIL.Image.Image`` objects.
    """
    palette = (
        (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
        (255, 0, 255), (0, 255, 255), (128, 128, 128), (255, 255, 255),
    )
    return [
        Image.new('RGB', size, palette[idx % len(palette)])
        for idx in range(num_frames)
    ]
|
|
|
|
|
def _load_timesformer():
    """Load the K400 TimeSformer processor and model; return ``(processor, model)`` or ``None``."""
    try:
        from transformers import AutoImageProcessor, TimesformerForVideoClassification
    except ImportError as e:
        print(f"❌ transformers is not available: {e}")
        return None

    checkpoint = "facebook/timesformer-base-finetuned-k400"
    try:
        print("Loading TimeSformer model...")
        processor = AutoImageProcessor.from_pretrained(checkpoint)
        model = TimesformerForVideoClassification.from_pretrained(checkpoint)
    except Exception as e:  # download / cache / checkpoint problems all end the run
        print(f"❌ Failed to load model: {e}")
        return None

    model.eval()
    print("✅ Model loaded successfully")
    print(f"Model config num_frames: {getattr(model.config, 'num_frames', 'Not found')}")
    print(f"Model config image_size: {getattr(model.config, 'image_size', 'Not found')}")
    return processor, model


def _try_processor_formats(processor, model, frames):
    """Test 1: let the image processor build ``pixel_values`` and run the model on each result.

    Returns the first processor-built tensor the model accepts, else ``None``.
    """
    print("\n🔍 Test 1: Using Processor")

    # Every calling convention is attempted. A processor that rejects one of
    # them (e.g. an unknown ``videos`` keyword) raises, and only that attempt
    # is reported as failed.
    attempts = [
        ("Direct frames", lambda: processor(images=frames, return_tensors="pt")),
        ("List of frames", lambda: processor(images=[frames], return_tensors="pt")),
        ("Videos parameter", lambda: processor(videos=frames, return_tensors="pt")),
        ("Videos list parameter", lambda: processor(videos=[frames], return_tensors="pt")),
    ]

    for name, build in attempts:
        try:
            result = build()
            if not result or 'pixel_values' not in result:
                print(f"  ❌ {name}: No pixel_values in result")
                continue
            tensor = result['pixel_values']
            print(f"  ✅ {name}: shape {tensor.shape}, dtype {tensor.dtype}, "
                  f"range [{tensor.min():.3f}, {tensor.max():.3f}]")
        except Exception as e:
            print(f"  ❌ {name}: {str(e)[:100]}...")
            continue

        try:
            with torch.no_grad():
                output = model(pixel_values=tensor)
        except Exception as e:
            print(f"    ❌ Inference failed: {str(e)[:100]}...")
            continue

        print(f"    🎯 Inference successful! Output shape: {output.logits.shape}")
        return tensor

    return None


def _try_manual_formats(model, frames, size=(224, 224)):
    """Test 2: build tensors by hand in several layouts and run the model on each.

    Frames are converted to RGB, resized to ``size`` if needed and scaled to
    float32 values in ``[0, 1]``. No mean/std normalisation is applied, so the
    printed prediction only shows that the forward pass runs.

    Returns the first hand-built tensor the model accepts, else ``None``.
    """
    print("\n🔍 Test 2: Manual Tensor Creation")

    frame_arrays = []
    for frame in frames:
        if frame.mode != 'RGB':
            frame = frame.convert('RGB')
        if frame.size != size:
            frame = frame.resize(size, Image.Resampling.LANCZOS)
        frame_arrays.append(np.array(frame, dtype=np.float32) / 255.0)

    if not frame_arrays:
        print("  ❌ No frames to convert")
        return None
    print(f"Frame arrays created: {len(frame_arrays)} frames of shape {frame_arrays[0].shape}")

    layouts = [
        ("NCHW format", create_nchw_tensor),
        ("NTHW format", create_nthw_tensor),
        ("CTHW format", create_cthw_tensor),
        ("TCHW format", create_tchw_tensor),
        ("Reshaped format", create_reshaped_tensor),
    ]

    for name, build in layouts:
        try:
            tensor = build(frame_arrays)
        except Exception as e:
            print(f"  ❌ {name}: Creation failed - {e}")
            continue
        print(f"  🔍 {name}: shape {tensor.shape}, dtype {tensor.dtype}")

        try:
            with torch.no_grad():
                output = model(pixel_values=tensor)
        except Exception as e:
            msg = str(e)
            # Rough classification of the failure to make the log easier to scan.
            if "channels" in msg:
                kind = "Channel dimension error"
            elif "shape" in msg:
                kind = "Shape error"
            else:
                kind = "Inference error"
            print(f"    ❌ {kind}: {msg[:150]}...")
            continue

        print(f"    ✅ Inference successful! Output logits shape: {output.logits.shape}")
        probs = torch.softmax(output.logits, dim=-1)
        top_prob, top_idx = torch.max(probs, dim=-1)
        idx = top_idx.item()
        label = model.config.id2label.get(idx, f"class {idx}")
        print(f"    🎯 Top prediction: {label} ({top_prob.item():.3f})")
        return tensor

    return None


def test_tensor_shapes():
    """Try several input layouts on the TimeSformer K400 checkpoint.

    Test 1 asks the Hugging Face image processor to build ``pixel_values``.
    If none of those runs through the model, Test 2 builds tensors by hand
    with the ``create_*_tensor`` helpers. Probing stops at the first tensor
    for which a forward pass succeeds.

    The frame count and resolution are taken from the model config (falling
    back to 8 frames at 224x224).

    Returns:
        The first ``torch.Tensor`` the model accepted, or ``None`` if loading
        failed, an unexpected error occurred, or every candidate raised.
    """
    print("🔍 Testing TimeSformer Input Formats")
    print("=" * 50)

    loaded = _load_timesformer()
    if loaded is None:
        return None
    processor, model = loaded

    num_frames = getattr(model.config, 'num_frames', 8)
    image_size = getattr(model.config, 'image_size', 224)
    size = (image_size, image_size)

    # Best-effort: report unexpected errors instead of aborting the debug run.
    try:
        frames = create_test_frames(num_frames, size)
        print(f"✅ Created {len(frames)} test frames")

        tensor = _try_processor_formats(processor, model, frames)
        if tensor is None:
            tensor = _try_manual_formats(model, frames, size)
    except Exception as e:
        print(f"❌ Unexpected error while testing formats: {e}")
        return None

    if tensor is None:
        print("\n💥 No working tensor format found!")
    return tensor
|
|
|
|
|
def create_nchw_tensor(frame_arrays):
    """Stack per-frame channel-first tensors into a 5-D, one-clip batch.

    Despite the name, the result is 5-D. Each element of ``frame_arrays`` is
    assumed to be an ``(H, W, C)`` NumPy array. It is permuted to
    ``(C, H, W)``, the frames are stacked along a new leading axis, and a
    batch axis of size 1 is prepended.

    Returns:
        Contiguous ``torch.Tensor`` of shape ``(1, T, C, H, W)`` with
        ``T = len(frame_arrays)``.
    """
    per_frame_chw = [torch.from_numpy(arr).permute(2, 0, 1) for arr in frame_arrays]
    clip = torch.stack(per_frame_chw)  # (T, C, H, W); stack copies into fresh memory
    return clip.unsqueeze(0)
|
|
|
|
|
def create_nthw_tensor(frame_arrays):
    """Build a 4-D ``(1, T*C, H, W)`` tensor with each frame's channels flattened.

    Each element of ``frame_arrays`` is assumed to be an ``(H, W, C)`` array,
    and all of them share one shape. The frames are stacked to
    ``(T, H, W, C)`` and permuted to ``(T, C, H, W)``. The frame and channel
    axes are then merged frame-major, so plane ``t*C + c`` is channel ``c``
    of frame ``t``.

    All sizes are read from the input, so any frame count or resolution
    works.

    Raises:
        ValueError: if the stacked array is not 4-D, or if ``np.stack``
            rejects the inputs.
    """
    video_array = np.stack(frame_arrays, axis=0)  # (T, H, W, C)
    num_frames, height, width, channels = video_array.shape
    video_tensor = torch.from_numpy(video_array).permute(0, 3, 1, 2)  # (T, C, H, W)
    # reshape rather than view: the permuted tensor is not contiguous, and a
    # raw view of the (T, H, W, C) memory would interleave pixels across planes.
    return video_tensor.reshape(1, num_frames * channels, height, width)
|
|
|
|
|
def create_cthw_tensor(frame_arrays):
    """Return the clip channel-first as a 5-D tensor ``(1, C, T, H, W)``.

    ``frame_arrays`` is assumed to hold equally sized ``(H, W, C)`` arrays.
    They are stacked on a new leading frame axis, giving ``(T, H, W, C)``.
    The channel axis is then moved to the front and a batch axis of size 1
    is prepended.

    The result is a permuted view of the stacked array rather than a copy.
    """
    channel_first_order = (3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)
    stacked = torch.from_numpy(np.stack(frame_arrays, axis=0))
    return stacked.permute(*channel_first_order).unsqueeze(0)
|
|
|
|
|
def create_tchw_tensor(frame_arrays):
    """Return the clip frame-first as a 5-D tensor ``(1, T, C, H, W)``.

    ``frame_arrays`` is assumed to hold equally sized ``(H, W, C)`` arrays.
    They are stacked to ``(T, H, W, C)``, and the channel axis is moved in
    front of the spatial axes. A batch axis of size 1 is then prepended.

    The result is a permuted view of the stacked array rather than a copy.
    NOTE(review): this appears to be the ``(batch, frames, channels, H, W)``
    layout the transformers TimeSformer docs describe — confirm.
    """
    frames_first = np.stack(frame_arrays, axis=0)
    as_tensor = torch.from_numpy(frames_first)
    return as_tensor.permute(0, 3, 1, 2).unsqueeze(0)
|
|
|
|
|
def create_reshaped_tensor(frame_arrays):
    """Flatten the clip channel-major into a 4-D ``(1, C*T, H, W)`` tensor.

    Each element of ``frame_arrays`` is assumed to be an ``(H, W, C)`` array,
    and all of them share one shape. They are stacked to ``(T, H, W, C)``
    and permuted to ``(C, T, H, W)``. The channel and frame axes are then
    merged, so plane ``c*T + t`` is channel ``c`` of frame ``t``.

    Sizes are taken from the input, so any frame count or resolution works.
    """
    video_array = np.stack(frame_arrays, axis=0)  # (T, H, W, C)
    num_frames, height, width, channels = video_array.shape
    video_tensor = torch.from_numpy(video_array)
    # contiguous() is required before view() because permute() only changes strides.
    return (
        video_tensor.permute(3, 0, 1, 2)
        .contiguous()
        .view(1, channels * num_frames, height, width)
    )
|
|
|
|
|
def _frames_to_video_tensor(frames, size=(224, 224)):
    """Convert PIL frames to a float32 ``(1, T, C, H, W)`` tensor with values in ``[0, 1]``.

    Frames are converted to RGB and resized to ``size`` when needed. No
    mean/std normalisation is applied, so model predictions are only indicative.
    """
    frame_tensors = []
    for frame in frames:
        if frame.mode != 'RGB':
            frame = frame.convert('RGB')
        if frame.size != size:
            frame = frame.resize(size, Image.Resampling.LANCZOS)
        frame_array = np.array(frame, dtype=np.float32) / 255.0  # (H, W, C)
        frame_tensors.append(torch.from_numpy(frame_array).permute(2, 0, 1))  # (C, H, W)
    return torch.stack(frame_tensors).unsqueeze(0)


def _print_top_predictions(model, logits, k=3):
    """Print the ``k`` most probable labels (from ``model.config.id2label``) for the first clip."""
    probs = torch.softmax(logits, dim=-1)
    top_probs, top_indices = torch.topk(probs, k=min(k, probs.shape[-1]), dim=-1)
    print(f"🎯 Top {top_probs.shape[-1]} predictions:")
    pairs = zip(top_indices[0].tolist(), top_probs[0].tolist())
    for rank, (idx, prob) in enumerate(pairs, start=1):
        label = model.config.id2label.get(idx, f"class {idx}")
        print(f"   {rank}. {label}: {prob:.3f}")


def test_working_examples():
    """Smoke-test TimeSformer with a random tensor and with real frames.

    Both inputs use the ``(batch, num_frames, num_channels, height, width)``
    layout documented for ``TimesformerForVideoClassification``. The random
    tensor checks that a forward pass runs at all. The solid-colour frames
    from ``create_test_frames`` exercise the PIL -> tensor conversion, and the
    top-3 labels are printed.

    Returns:
        The real-data ``torch.Tensor`` on success, otherwise ``None``.
    """
    print("\n🧪 Testing Known Working Examples")
    print("=" * 40)

    try:
        from transformers import TimesformerForVideoClassification

        model = TimesformerForVideoClassification.from_pretrained(
            "facebook/timesformer-base-finetuned-k400")
        model.eval()
    except Exception as e:
        print(f"❌ Known examples test failed: {e}")
        return None

    num_frames = getattr(model.config, 'num_frames', 8)
    image_size = getattr(model.config, 'image_size', 224)
    num_channels = getattr(model.config, 'num_channels', 3)
    size = (image_size, image_size)

    # Frames must come before channels. A (1, 3, 8, H, W) tensor would be
    # read as 3 frames of 8 channels each and rejected by the patch embedding.
    test_tensor = torch.randn(1, num_frames, num_channels, image_size, image_size)
    print(f"Random tensor shape: {test_tensor.shape}")
    try:
        with torch.no_grad():
            output = model(pixel_values=test_tensor)
    except Exception as e:
        print(f"❌ Even random tensor failed: {e}")
        return None
    print(f"✅ Random tensor inference successful! Output shape: {output.logits.shape}")

    try:
        frames = create_test_frames(num_frames, size)
        final_tensor = _frames_to_video_tensor(frames, size)
        print(f"Real data tensor shape: {final_tensor.shape}")
        with torch.no_grad():
            output = model(pixel_values=final_tensor)
    except Exception as e:
        print(f"❌ Real data inference failed: {e}")
        return None

    print("✅ Real data inference successful!")
    _print_top_predictions(model, output.logits, k=3)
    return final_tensor
|
|
|
|
|
def main():
    """Run the debug probes in order and report the first tensor layout that works.

    ``test_tensor_shapes`` runs first. ``test_working_examples`` runs only
    if the first probe found nothing.

    Returns:
        Process exit status: ``0`` if some layout ran through the model,
        ``1`` otherwise.
    """
    print("🔍 TimeSformer Input Format Debug")
    print("=" * 60)

    for probe in (test_tensor_shapes, test_working_examples):
        working_tensor = probe()
        if working_tensor is not None:
            print(f"\n🎉 Found working tensor format: {working_tensor.shape}")
            return 0

    print("\n💥 No working tensor format found. This suggests a deeper compatibility issue.")
    print("\n🔧 Recommendations:")
    print("1. Check if the model version is compatible with your transformers version")
    print("2. Try using the exact same environment as the original TimeSformer paper")
    print("3. Check if there are any preprocessing requirements we're missing")
    return 1
|
|
|
|
|
if __name__ == "__main__":
    # Raise SystemExit directly. The bare ``exit`` helper is injected by the
    # ``site`` module and is missing under ``python -S`` or in frozen builds.
    raise SystemExit(main())
|
|
|