Spaces:
Running
on
Zero
Running
on
Zero
| import os, shutil, sys | |
| import argparse | |
| import gdown | |
| import cv2 | |
| import numpy as np | |
| import os | |
| import sys | |
| import requests | |
| import json | |
| import torchvision | |
| import torch | |
| import psutil | |
| import time | |
# Probe for mmcv by importing one of its symbols; ConvModule itself is not
# used here, the import only checks that the package is available.
try:
    from mmcv.cnn import ConvModule
except ImportError:
    # Only a failed import should trigger the on-the-fly install. A bare
    # `except:` would also swallow KeyboardInterrupt/SystemExit and hide
    # unrelated errors behind a package installation.
    os.system("mim install mmcv")
| # Import files from the local folder | |
| root_path = os.path.abspath('.') | |
| sys.path.append(root_path) | |
| from track_anything_code.model import TrackingAnything | |
| from track_anything_code.track_anything_module import get_frames_from_video, download_checkpoint, parse_augment, sam_refine, vos_tracking_video | |
| from scripts.compress_videos import compress_video | |
def read_first_click_point(data_txt_path):
    """Return the first click point stored in a ``data.txt`` file.

    Only the first line of the file is read. It is expected to hold three
    whitespace-separated fields: ``frame_idx horizontal vertical``. The frame
    index is ignored; the two coordinates are truncated to integers.

    Args:
        data_txt_path: Path of the ``data.txt`` file to read.

    Returns:
        ``[horizontal, vertical]`` as a list of two ints.

    Raises:
        ValueError: If the first line is empty, has fewer than three fields,
            or the coordinates are not numeric.
    """
    # The context manager closes the handle even when parsing fails below.
    with open(data_txt_path, 'r') as data_file:
        first_line = data_file.readline()

    # Splitting on arbitrary whitespace copes with "\n", "\r\n" and trailing
    # blanks alike; chopping a fixed number of trailing characters would cut
    # into the last coordinate whenever the line ending is shorter than that.
    fields = first_line.split()
    if len(fields) < 3:
        raise ValueError("Expected 'frame_idx horizontal vertical' in the first line of " + str(data_txt_path))

    horizontal, vertical = fields[1], fields[2]
    return [int(float(horizontal)), int(float(vertical))]


def main():
    """Segment and track the clicked object in every sample folder of the dataset.

    For each sub-folder of the dataset directory this loads (or first
    compresses) the folder's video, reads the first click point from
    ``data.txt``, refines it into a mask with SAM through ``sam_refine``, runs
    ``vos_tracking_video`` over the video and, when ``verbose`` is set, writes
    one PNG per resulting mask into the sub-folder.
    """
    dataset_path = "Bridge_v1_TT14"
    video_name = "combined.mp4"
    verbose = True      # When True, every per-frame mask is also written out as a PNG

    ################################################## Model setup ####################################################
    # Check and download checkpoints if needed
    sam_checkpoint = "sam_vit_h_4b8939.pth"
    sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    xmem_checkpoint = "XMem-s012.pth"
    xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"

    folder = "./pretrained"
    SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
    xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)

    # Arguments
    args = parse_augment()
    args.device = "cuda"    # Any GPU is ok

    # Initialize the Track model
    track_model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
    ###################################################################################################################

    # Iterate all folders under the dataset folder
    for sub_folder_name in sorted(os.listdir(dataset_path)):

        sub_folder_path = os.path.join(dataset_path, sub_folder_name)
        if not os.path.isdir(sub_folder_path):
            # os.listdir also returns plain files; they are not sample folders
            continue

        ################################################## Setting ####################################################
        # Fresh click / interaction state for every sample
        click_state = [[], []]
        interactive_state = {
            "inference_times": 0,
            "negative_click_times": 0,
            "positive_click_times": 0,
            "mask_save": args.mask_save,
            "multi_mask": {
                "mask_names": [],
                "masks": [],
            },
            "track_end_number": None,
            "resize_ratio": 1,
        }
        ###############################################################################################################

        video_path = os.path.join(sub_folder_path, video_name)
        if not os.path.exists(video_path):
            print("We cannot find the path of the ", video_path, " and we will compress one")
            status = compress_video(sub_folder_path, video_name)
            if not status:
                print("We still cannot generate a video")
                continue

        # Read video state
        video_state = {
            "user_name": "",
            "video_name": "",
            "origin_images": None,
            "painted_images": None,
            "masks": None,
            "inpaint_masks": None,
            "logits": None,
            "select_frame_number": 0,
            "fps": 30,
        }
        video_state, _ = get_frames_from_video(video_path, video_state, track_model)

        ################################### Get the SAM point based on the data.txt ###################################
        data_txt_path = os.path.join(sub_folder_path, "data.txt")
        if not os.path.exists(data_txt_path):
            print("We cannot find data.txt in this folder")
            continue

        try:
            point_cord = read_first_click_point(data_txt_path)     # Only read the first point
        except ValueError as error:
            # One malformed data.txt should not abort the whole batch
            print("We cannot parse data.txt in this folder:", error)
            continue

        # Process by SAM
        track_model.samcontroler.sam_controler.reset_image()       # Reset the image to clean history
        _, video_state, interactive_state, _ = sam_refine(track_model, video_state, "Positive", click_state, interactive_state, point_cord)
        ###############################################################################################################

        ########################################### Get the tracking output ###########################################
        # Track the video for processing
        segment_output_path = os.path.join(sub_folder_path, "segment_output.gif")
        video_state = vos_tracking_video(track_model, segment_output_path, video_state, interactive_state, mask_dropdown=[])[0]   # mask_dropdown is empty now

        # Extract the masks needed for further point calculation
        masks = video_state["masks"]    # In the range [0, 1] according to the original author's note
        if verbose:
            for idx, mask in enumerate(masks):
                cv2.imwrite(os.path.join(sub_folder_path, "mask" + str(idx) + ".png"), mask * 255)
        ###############################################################################################################


if __name__ == "__main__":
    main()