rework temp files
Browse files- demo.py +6 -11
- dust3r +1 -1
- mast3r/demo.py +70 -21
demo.py
CHANGED
|
@@ -8,6 +8,7 @@
|
|
| 8 |
import os
|
| 9 |
import torch
|
| 10 |
import tempfile
|
|
|
|
| 11 |
|
| 12 |
from mast3r.demo import get_args_parser, main_demo
|
| 13 |
|
|
@@ -36,17 +37,11 @@ if __name__ == '__main__':
|
|
| 36 |
model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
|
| 37 |
chkpt_tag = hash_md5(weights_path)
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
| 42 |
cache_path = os.path.join(tmpdirname, chkpt_tag)
|
| 43 |
os.makedirs(cache_path, exist_ok=True)
|
| 44 |
main_demo(cache_path, model, args.device, args.image_size, server_name, args.server_port, silent=args.silent,
|
| 45 |
-
share=args.share)
|
| 46 |
-
else:
|
| 47 |
-
with tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') as tmpdirname:
|
| 48 |
-
cache_path = os.path.join(tmpdirname, chkpt_tag)
|
| 49 |
-
os.makedirs(cache_path, exist_ok=True)
|
| 50 |
-
main_demo(tmpdirname, model, args.device, args.image_size,
|
| 51 |
-
server_name, args.server_port, silent=args.silent,
|
| 52 |
-
share=args.share)
|
|
|
|
| 8 |
import os
|
| 9 |
import torch
|
| 10 |
import tempfile
|
| 11 |
+
from contextlib import nullcontext
|
| 12 |
|
| 13 |
from mast3r.demo import get_args_parser, main_demo
|
| 14 |
|
|
|
|
| 37 |
model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
|
| 38 |
chkpt_tag = hash_md5(weights_path)
|
| 39 |
|
| 40 |
+
def get_context(tmp_dir):
|
| 41 |
+
return tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') if tmp_dir is None \
|
| 42 |
+
else nullcontext(tmp_dir)
|
| 43 |
+
with get_context(args.tmp_dir) as tmpdirname:
|
| 44 |
cache_path = os.path.join(tmpdirname, chkpt_tag)
|
| 45 |
os.makedirs(cache_path, exist_ok=True)
|
| 46 |
main_demo(cache_path, model, args.device, args.image_size, server_name, args.server_port, silent=args.silent,
|
| 47 |
+
share=args.share, gradio_delete_cache=args.gradio_delete_cache)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dust3r
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
Subproject commit
|
|
|
|
| 1 |
+
Subproject commit 8cc725dd11a9b7371bfca37994f8585ca78b42e5
|
mast3r/demo.py
CHANGED
|
@@ -13,6 +13,8 @@ import functools
|
|
| 13 |
import trimesh
|
| 14 |
import copy
|
| 15 |
from scipy.spatial.transform import Rotation
|
|
|
|
|
|
|
| 16 |
|
| 17 |
from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
|
| 18 |
from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
|
|
@@ -27,9 +29,30 @@ from dust3r.demo import get_args_parser as dust3r_get_args_parser
|
|
| 27 |
import matplotlib.pyplot as pl
|
| 28 |
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
def get_args_parser():
|
| 31 |
parser = dust3r_get_args_parser()
|
| 32 |
parser.add_argument('--share', action='store_true')
|
|
|
|
|
|
|
| 33 |
|
| 34 |
actions = parser._actions
|
| 35 |
for action in actions:
|
|
@@ -40,7 +63,7 @@ def get_args_parser():
|
|
| 40 |
return parser
|
| 41 |
|
| 42 |
|
| 43 |
-
def _convert_scene_output_to_glb(
|
| 44 |
cam_color=None, as_pointcloud=False,
|
| 45 |
transparent_cams=False, silent=False):
|
| 46 |
assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
|
|
@@ -53,14 +76,17 @@ def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world,
|
|
| 53 |
|
| 54 |
# full pointcloud
|
| 55 |
if as_pointcloud:
|
| 56 |
-
pts = np.concatenate([p[m.ravel()] for p, m in zip(pts3d, mask)])
|
| 57 |
-
col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
|
| 58 |
-
|
|
|
|
| 59 |
scene.add_geometry(pct)
|
| 60 |
else:
|
| 61 |
meshes = []
|
| 62 |
for i in range(len(imgs)):
|
| 63 |
-
|
|
|
|
|
|
|
| 64 |
mesh = trimesh.Trimesh(**cat_meshes(meshes))
|
| 65 |
scene.add_geometry(mesh)
|
| 66 |
|
|
@@ -77,20 +103,22 @@ def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world,
|
|
| 77 |
rot = np.eye(4)
|
| 78 |
rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
|
| 79 |
scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
|
| 80 |
-
outfile = os.path.join(outdir, 'scene.glb')
|
| 81 |
if not silent:
|
| 82 |
print('(exporting 3D scene to', outfile, ')')
|
| 83 |
scene.export(file_obj=outfile)
|
| 84 |
return outfile
|
| 85 |
|
| 86 |
|
| 87 |
-
def get_3D_model_from_scene(
|
| 88 |
clean_depth=False, transparent_cams=False, cam_size=0.05, TSDF_thresh=0):
|
| 89 |
"""
|
| 90 |
extract 3D_model (glb file) from a reconstructed scene
|
| 91 |
"""
|
| 92 |
if scene is None:
|
| 93 |
return None
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
# get optimized values from scene
|
| 96 |
rgbimg = scene.imgs
|
|
@@ -104,14 +132,14 @@ def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=2, as_pointcloud
|
|
| 104 |
else:
|
| 105 |
pts3d, _, confs = to_numpy(scene.get_dense_pts3d(clean_depth=clean_depth))
|
| 106 |
msk = to_numpy([c > min_conf_thr for c in confs])
|
| 107 |
-
return _convert_scene_output_to_glb(
|
| 108 |
transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
|
| 109 |
|
| 110 |
|
| 111 |
-
def get_reconstructed_scene(outdir, model, device, silent, image_size,
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
**kw):
|
| 115 |
"""
|
| 116 |
from a list of images, run mast3r inference, sparse global aligner.
|
| 117 |
then run get_3D_model_from_scene
|
|
@@ -134,11 +162,23 @@ def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist,
|
|
| 134 |
if optim_level == 'coarse':
|
| 135 |
niter2 = 0
|
| 136 |
# Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
model, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device,
|
| 139 |
opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
|
| 140 |
matching_conf_thr=matching_conf_thr, **kw)
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
clean_depth, transparent_cams, cam_size, TSDF_thresh)
|
| 143 |
return scene, outfile
|
| 144 |
|
|
@@ -169,13 +209,24 @@ def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type):
|
|
| 169 |
return win_col, winsize, win_cyclic, refid
|
| 170 |
|
| 171 |
|
| 172 |
-
def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False,
|
|
|
|
| 173 |
if not silent:
|
| 174 |
print('Outputing stuff in', tmpdirname)
|
| 175 |
|
| 176 |
-
recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device,
|
| 177 |
-
|
| 178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
# scene state is save so that you can change conf_thr, cam_size... without rerunning the inference
|
| 180 |
scene = gradio.State(None)
|
| 181 |
gradio.HTML('<h2 style="text-align: center;">MASt3R Demo</h2>')
|
|
@@ -212,7 +263,6 @@ def main_demo(tmpdirname, model, device, image_size, server_name, server_port, s
|
|
| 212 |
win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
|
| 213 |
refid = gradio.Slider(label="Scene Graph: Id", value=0,
|
| 214 |
minimum=0, maximum=0, step=1, visible=False)
|
| 215 |
-
|
| 216 |
run_btn = gradio.Button("Run")
|
| 217 |
|
| 218 |
with gradio.Row():
|
|
@@ -241,7 +291,7 @@ def main_demo(tmpdirname, model, device, image_size, server_name, server_port, s
|
|
| 241 |
inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
|
| 242 |
outputs=[win_col, winsize, win_cyclic, refid])
|
| 243 |
run_btn.click(fn=recon_fun,
|
| 244 |
-
inputs=[inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
|
| 245 |
as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
|
| 246 |
scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics],
|
| 247 |
outputs=[scene, outmodel])
|
|
@@ -274,4 +324,3 @@ def main_demo(tmpdirname, model, device, image_size, server_name, server_port, s
|
|
| 274 |
clean_depth, transparent_cams, cam_size, TSDF_thresh],
|
| 275 |
outputs=outmodel)
|
| 276 |
demo.launch(share=share, server_name=server_name, server_port=server_port)
|
| 277 |
-
|
|
|
|
| 13 |
import trimesh
|
| 14 |
import copy
|
| 15 |
from scipy.spatial.transform import Rotation
|
| 16 |
+
import tempfile
|
| 17 |
+
import shutil
|
| 18 |
|
| 19 |
from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
|
| 20 |
from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
|
|
|
|
| 29 |
import matplotlib.pyplot as pl
|
| 30 |
|
| 31 |
|
| 32 |
+
class SparseGAState():
|
| 33 |
+
def __init__(self, sparse_ga, should_delete=False, cache_dir=None, outfile_name=None):
|
| 34 |
+
self.sparse_ga = sparse_ga
|
| 35 |
+
self.cache_dir = cache_dir
|
| 36 |
+
self.outfile_name = outfile_name
|
| 37 |
+
self.should_delete = should_delete
|
| 38 |
+
|
| 39 |
+
def __getattr__(self, name):
|
| 40 |
+
return getattr(self.sparse_ga, name)
|
| 41 |
+
|
| 42 |
+
def __del__(self):
|
| 43 |
+
if self.cache_dir is not None and os.path.isdir(self.cache_dir):
|
| 44 |
+
shutil.rmtree(self.cache_dir)
|
| 45 |
+
self.cache_dir = None
|
| 46 |
+
if self.outfile_name is not None and os.path.isfile(self.outfile_name):
|
| 47 |
+
os.remove(self.outfile_name)
|
| 48 |
+
self.outfile_name = None
|
| 49 |
+
|
| 50 |
+
|
| 51 |
def get_args_parser():
|
| 52 |
parser = dust3r_get_args_parser()
|
| 53 |
parser.add_argument('--share', action='store_true')
|
| 54 |
+
parser.add_argument('--gradio_delete_cache', default=None, type=int,
|
| 55 |
+
help='age/frequency at which gradio removes the file. If >0, matching cache is purged')
|
| 56 |
|
| 57 |
actions = parser._actions
|
| 58 |
for action in actions:
|
|
|
|
| 63 |
return parser
|
| 64 |
|
| 65 |
|
| 66 |
+
def _convert_scene_output_to_glb(outfile, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
|
| 67 |
cam_color=None, as_pointcloud=False,
|
| 68 |
transparent_cams=False, silent=False):
|
| 69 |
assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
|
|
|
|
| 76 |
|
| 77 |
# full pointcloud
|
| 78 |
if as_pointcloud:
|
| 79 |
+
pts = np.concatenate([p[m.ravel()] for p, m in zip(pts3d, mask)]).reshape(-1, 3)
|
| 80 |
+
col = np.concatenate([p[m] for p, m in zip(imgs, mask)]).reshape(-1, 3)
|
| 81 |
+
valid_msk = np.isfinite(pts.sum(axis=1))
|
| 82 |
+
pct = trimesh.PointCloud(pts[valid_msk], colors=col[valid_msk])
|
| 83 |
scene.add_geometry(pct)
|
| 84 |
else:
|
| 85 |
meshes = []
|
| 86 |
for i in range(len(imgs)):
|
| 87 |
+
pts3d_i = pts3d[i].reshape(imgs[i].shape)
|
| 88 |
+
msk_i = mask[i] & np.isfinite(pts3d_i.sum(axis=-1))
|
| 89 |
+
meshes.append(pts3d_to_trimesh(imgs[i], pts3d_i, msk_i))
|
| 90 |
mesh = trimesh.Trimesh(**cat_meshes(meshes))
|
| 91 |
scene.add_geometry(mesh)
|
| 92 |
|
|
|
|
| 103 |
rot = np.eye(4)
|
| 104 |
rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
|
| 105 |
scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
|
|
|
|
| 106 |
if not silent:
|
| 107 |
print('(exporting 3D scene to', outfile, ')')
|
| 108 |
scene.export(file_obj=outfile)
|
| 109 |
return outfile
|
| 110 |
|
| 111 |
|
| 112 |
+
def get_3D_model_from_scene(silent, scene, min_conf_thr=2, as_pointcloud=False, mask_sky=False,
|
| 113 |
clean_depth=False, transparent_cams=False, cam_size=0.05, TSDF_thresh=0):
|
| 114 |
"""
|
| 115 |
extract 3D_model (glb file) from a reconstructed scene
|
| 116 |
"""
|
| 117 |
if scene is None:
|
| 118 |
return None
|
| 119 |
+
outfile = scene.outfile_name
|
| 120 |
+
if outfile is None:
|
| 121 |
+
return None
|
| 122 |
|
| 123 |
# get optimized values from scene
|
| 124 |
rgbimg = scene.imgs
|
|
|
|
| 132 |
else:
|
| 133 |
pts3d, _, confs = to_numpy(scene.get_dense_pts3d(clean_depth=clean_depth))
|
| 134 |
msk = to_numpy([c > min_conf_thr for c in confs])
|
| 135 |
+
return _convert_scene_output_to_glb(outfile, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
|
| 136 |
transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
|
| 137 |
|
| 138 |
|
| 139 |
+
def get_reconstructed_scene(outdir, gradio_delete_cache, model, device, silent, image_size, current_scene_state,
|
| 140 |
+
filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
|
| 141 |
+
as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
|
| 142 |
+
win_cyclic, refid, TSDF_thresh, shared_intrinsics, **kw):
|
| 143 |
"""
|
| 144 |
from a list of images, run mast3r inference, sparse global aligner.
|
| 145 |
then run get_3D_model_from_scene
|
|
|
|
| 162 |
if optim_level == 'coarse':
|
| 163 |
niter2 = 0
|
| 164 |
# Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
|
| 165 |
+
if current_scene_state is not None and current_scene_state.cache_dir is not None:
|
| 166 |
+
cache_dir = current_scene_state.cache_dir
|
| 167 |
+
elif gradio_delete_cache:
|
| 168 |
+
cache_dir = tempfile.mkdtemp(suffix='_cache', dir=outdir)
|
| 169 |
+
else:
|
| 170 |
+
cache_dir = os.path.join(outdir, 'cache')
|
| 171 |
+
scene = sparse_global_alignment(filelist, pairs, cache_dir,
|
| 172 |
model, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device,
|
| 173 |
opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
|
| 174 |
matching_conf_thr=matching_conf_thr, **kw)
|
| 175 |
+
if current_scene_state is not None and current_scene_state.outfile_name is not None:
|
| 176 |
+
outfile_name = current_scene_state.outfile_name
|
| 177 |
+
else:
|
| 178 |
+
outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=outdir)
|
| 179 |
+
|
| 180 |
+
scene = SparseGAState(scene, gradio_delete_cache, cache_dir, outfile_name)
|
| 181 |
+
outfile = get_3D_model_from_scene(silent, scene, min_conf_thr, as_pointcloud, mask_sky,
|
| 182 |
clean_depth, transparent_cams, cam_size, TSDF_thresh)
|
| 183 |
return scene, outfile
|
| 184 |
|
|
|
|
| 209 |
return win_col, winsize, win_cyclic, refid
|
| 210 |
|
| 211 |
|
| 212 |
+
def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False,
|
| 213 |
+
share=False, gradio_delete_cache=False):
|
| 214 |
if not silent:
|
| 215 |
print('Outputing stuff in', tmpdirname)
|
| 216 |
|
| 217 |
+
recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, gradio_delete_cache, model, device,
|
| 218 |
+
silent, image_size)
|
| 219 |
+
model_from_scene_fun = functools.partial(get_3D_model_from_scene, silent)
|
| 220 |
+
|
| 221 |
+
def get_context(delete_cache):
|
| 222 |
+
css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
|
| 223 |
+
title = "MASt3R Demo"
|
| 224 |
+
if delete_cache:
|
| 225 |
+
return gradio.Blocks(css=css, title=title, delete_cache=(delete_cache, delete_cache))
|
| 226 |
+
else:
|
| 227 |
+
return gradio.Blocks(css=css, title="MASt3R Demo") # for compatibility with older versions
|
| 228 |
+
|
| 229 |
+
with get_context(gradio_delete_cache) as demo:
|
| 230 |
# scene state is save so that you can change conf_thr, cam_size... without rerunning the inference
|
| 231 |
scene = gradio.State(None)
|
| 232 |
gradio.HTML('<h2 style="text-align: center;">MASt3R Demo</h2>')
|
|
|
|
| 263 |
win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
|
| 264 |
refid = gradio.Slider(label="Scene Graph: Id", value=0,
|
| 265 |
minimum=0, maximum=0, step=1, visible=False)
|
|
|
|
| 266 |
run_btn = gradio.Button("Run")
|
| 267 |
|
| 268 |
with gradio.Row():
|
|
|
|
| 291 |
inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
|
| 292 |
outputs=[win_col, winsize, win_cyclic, refid])
|
| 293 |
run_btn.click(fn=recon_fun,
|
| 294 |
+
inputs=[scene, inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
|
| 295 |
as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
|
| 296 |
scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics],
|
| 297 |
outputs=[scene, outmodel])
|
|
|
|
| 324 |
clean_depth, transparent_cams, cam_size, TSDF_thresh],
|
| 325 |
outputs=outmodel)
|
| 326 |
demo.launch(share=share, server_name=server_name, server_port=server_port)
|
|
|