r3gm committed on
Commit 05035da · verified · 1 Parent(s): 638b624

Upload inference_video_w.py

Files changed (1):
  inference_video_w.py (+316, -0)
inference_video_w.py ADDED
@@ -0,0 +1,316 @@
import os
import shutil
import tempfile
import time
import warnings
import _thread
from queue import Queue

import cv2
import numpy as np
import skvideo.io
import torch
from torch.nn import functional as F
from tqdm import tqdm

from model.pytorch_msssim import ssim_matlab

warnings.filterwarnings("ignore")


class Args:
    """Utility class that mimics an argparse.Namespace object."""
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


def transferAudio(sourceVideo, targetVideo):
    # Generate a unique temp directory so concurrent runs do not collide.
    unique_temp_dir = tempfile.mkdtemp()
    tempAudioFileName = os.path.join(unique_temp_dir, "audio.mkv")

    # Extract the audio stream from the source video without re-encoding.
    os.system('ffmpeg -hide_banner -loglevel error -y -i "{}" -c:a copy -vn "{}"'.format(sourceVideo, tempAudioFileName))

    targetNoAudio = os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1]
    os.rename(targetVideo, targetNoAudio)
    # Mux the extracted audio with the newly interpolated video.
    os.system('ffmpeg -hide_banner -loglevel error -y -i "{}" -i "{}" -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))

    if os.path.getsize(targetVideo) == 0:
        # ffmpeg failed to merge the streams; retry with the audio transcoded to AAC.
        tempAudioFileName = os.path.join(unique_temp_dir, "audio.m4a")
        os.system('ffmpeg -hide_banner -loglevel error -y -i "{}" -c:a aac -b:a 160k -vn "{}"'.format(sourceVideo, tempAudioFileName))
        os.system('ffmpeg -hide_banner -loglevel error -y -i "{}" -i "{}" -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
        if os.path.getsize(targetVideo) == 0:
            # AAC is not supported by the selected container either; keep the silent video.
            os.rename(targetNoAudio, targetVideo)
            print("Audio transfer failed. Interpolated video will have no audio.")
        else:
            print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.")
            # Remove the audio-less video.
            os.remove(targetNoAudio)
    else:
        os.remove(targetNoAudio)

    # Remove the temp directory.
    shutil.rmtree(unique_temp_dir)

def inference(
    video=None,
    output=None,
    img=None,
    montage=False,
    modelDir='train_log',
    fp16=False,
    UHD=False,
    scale=1.0,
    skip=False,
    fps=None,
    png=False,
    ext='mp4',
    exp=1,
    multi=2
):
    # Initialize the arguments object.
    args = Args(
        video=video, output=output, img=img, montage=montage,
        modelDir=modelDir, fp16=fp16, UHD=UHD, scale=scale,
        skip=skip, fps=fps, png=png, ext=ext, exp=exp, multi=multi
    )

    # An explicit exp overrides multi: interpolate 2**exp output frames per input frame.
    if args.exp != 1:
        args.multi = 2 ** args.exp

    # Sanity checks.
    assert args.video is not None or args.img is not None
    if args.skip:
        print("skip flag is abandoned, please refer to issue #207.")
    if args.UHD and args.scale == 1.0:
        args.scale = 0.5
    assert args.scale in [0.25, 0.5, 1.0, 2.0, 4.0]
    if args.img is not None:
        args.png = True

    # Device setup.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.set_grad_enabled(False)
    if torch.cuda.is_available():
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        if args.fp16:
            torch.set_default_tensor_type(torch.cuda.HalfTensor)

    # Load the model.
    from train_log.RIFE_HDv3 import Model
    model = Model()
    if not hasattr(model, 'version'):
        model.version = 0
    model.load_model(args.modelDir, -1)
    print("Loaded 3.x/4.x HD model.")
    model.eval()
    model.device()

    # Video or image-sequence setup.
    if args.video is not None:
        videoCapture = cv2.VideoCapture(args.video)
        original_fps = videoCapture.get(cv2.CAP_PROP_FPS)
        tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
        videoCapture.release()

        if args.fps is None or args.fps == 0:
            fpsNotAssigned = True
            args.fps = original_fps * args.multi
        else:
            fpsNotAssigned = False

        videogen = skvideo.io.vreader(args.video)
        lastframe = next(videogen)
        video_path_wo_ext, ext = os.path.splitext(args.video)
        print('{}{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, ext, tot_frame, original_fps, args.fps))

        if not args.png and fpsNotAssigned:
            print("The audio will be merged after the interpolation process.")
        else:
            print("Will not merge audio because the png or fps flag is in use!")
    else:
        videogen = []
        for f in os.listdir(args.img):
            if 'png' in f:
                videogen.append(f)
        tot_frame = len(videogen)
        videogen.sort(key=lambda x: int(x[:-4]))
        lastframe = cv2.imread(os.path.join(args.img, videogen[0]), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
        videogen = videogen[1:]

    h, w, _ = lastframe.shape
    vid_out_name = None
    vid_out = None

    if args.png:
        if not os.path.exists('vid_out'):
            os.mkdir('vid_out')
    else:
        if args.output is not None:
            vid_out_name = args.output
        else:
            vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, args.multi, int(np.round(args.fps)), args.ext)

        outputdict = {
            '-c:v': 'libx264',
            '-crf': '17',
            '-preset': 'slow',
            '-pix_fmt': 'yuv420p'
        }
        vid_out = skvideo.io.FFmpegWriter(vid_out_name, inputdict={'-r': str(args.fps)}, outputdict=outputdict)

    # --- Nested helper functions that capture 'args', 'model', and 'vid_out' ---

    def clear_write_buffer(write_buffer):
        # Consumer thread: drains finished frames to PNGs or to the video writer.
        cnt = 0
        while True:
            item = write_buffer.get()
            if item is None:
                break
            if args.png:
                cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item[:, :, ::-1])
                cnt += 1
            else:
                vid_out.writeFrame(item)

    def build_read_buffer(read_buffer, videogen):
        # Producer thread: decodes frames ahead of the GPU and queues them.
        try:
            for frame in videogen:
                if args.img is not None:
                    frame = cv2.imread(os.path.join(args.img, frame), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
                if args.montage:
                    frame = frame[:, left: left + w]
                read_buffer.put(frame)
        except Exception:
            pass
        read_buffer.put(None)

    def make_inference(I0, I1, n):
        # v3.9+ models accept an arbitrary timestep, so the n intermediate frames
        # are sampled directly; older models only predict the midpoint, so the
        # sequence is built by recursive bisection.
        if model.version >= 3.9:
            res = []
            for i in range(n):
                res.append(model.inference(I0, I1, (i + 1) * 1. / (n + 1), args.scale))
            return res
        else:
            middle = model.inference(I0, I1, args.scale)
            if n == 1:
                return [middle]
            first_half = make_inference(I0, middle, n=n // 2)
            second_half = make_inference(middle, I1, n=n // 2)
            if n % 2:
                return [*first_half, middle, *second_half]
            else:
                return [*first_half, *second_half]

    def pad_image(img):
        if args.fp16:
            return F.pad(img, padding).half()
        else:
            return F.pad(img, padding)

    # --- Pre-loop setup ---

    left = 0  # default; only used in montage mode
    if args.montage:
        left = w // 4
        w = w // 2

    # Pad frame dimensions up to multiples of 128/scale, as the model requires.
    tmp = max(128, int(128 / args.scale))
    ph = ((h - 1) // tmp + 1) * tmp
    pw = ((w - 1) // tmp + 1) * tmp
    padding = (0, pw - w, 0, ph - h)

    pbar = tqdm(total=tot_frame)
    if args.montage:
        lastframe = lastframe[:, left: left + w]

    write_buffer = Queue(maxsize=500)
    read_buffer = Queue(maxsize=500)

    # Start the reader and writer threads.
    _thread.start_new_thread(build_read_buffer, (read_buffer, videogen))
    _thread.start_new_thread(clear_write_buffer, (write_buffer,))

    I1 = torch.from_numpy(np.transpose(lastframe, (2, 0, 1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
    I1 = pad_image(I1)
    temp = None

    # --- Main loop ---

    while True:
        if temp is not None:
            frame = temp
            temp = None
        else:
            frame = read_buffer.get()
        if frame is None:
            break
        I0 = I1
        I1 = torch.from_numpy(np.transpose(frame, (2, 0, 1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
        I1 = pad_image(I1)
        I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
        I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
        ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])

        break_flag = False
        if ssim > 0.996:
            # Near-duplicate frames: skip ahead and synthesize a replacement frame.
            frame = read_buffer.get()  # read a new frame
            if frame is None:
                break_flag = True
                frame = lastframe
            else:
                temp = frame
            I1 = torch.from_numpy(np.transpose(frame, (2, 0, 1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
            I1 = pad_image(I1)
            I1 = model.inference(I0, I1, scale=args.scale)
            I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
            ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
            frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]

        if ssim < 0.2:
            # Scene change: duplicating I0 avoids ghosting from blending unrelated frames.
            output_frames = []
            for i in range(args.multi - 1):
                output_frames.append(I0)
        else:
            output_frames = make_inference(I0, I1, args.multi - 1)

        if args.montage:
            write_buffer.put(np.concatenate((lastframe, lastframe), 1))
            for mid in output_frames:
                mid = (mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)
                write_buffer.put(np.concatenate((lastframe, mid[:h, :w]), 1))
        else:
            write_buffer.put(lastframe)
            for mid in output_frames:
                mid = (mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)
                write_buffer.put(mid[:h, :w])
        pbar.update(1)
        lastframe = frame
        if break_flag:
            break

    if args.montage:
        write_buffer.put(np.concatenate((lastframe, lastframe), 1))
    else:
        write_buffer.put(lastframe)

    write_buffer.put(None)

    # Wait for the writer thread to drain its queue.
    while not write_buffer.empty():
        time.sleep(0.1)
    pbar.close()

    if vid_out is not None:
        vid_out.close()

    # Audio transfer: only when writing a video at the auto-derived fps.
    if not args.png and fpsNotAssigned and args.video is not None:
        try:
            transferAudio(args.video, vid_out_name)
        except Exception:
            print("Audio transfer failed. Interpolated video will have no audio.")
            targetNoAudio = os.path.splitext(vid_out_name)[0] + "_noaudio" + os.path.splitext(vid_out_name)[1]
            os.rename(targetNoAudio, vid_out_name)

    return vid_out_name
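
For reference, a minimal usage sketch of the function above. This assumes the RIFE HDv3 weights are present under train_log/ and that an input clip exists; the file name input.mp4 is hypothetical.

# Minimal usage sketch: double the frame rate of a local clip.
# Assumes train_log/ holds the RIFE HDv3 weights; "input.mp4" is a placeholder.
if __name__ == "__main__":
    out_path = inference(video="input.mp4", multi=2)
    # With fps left unset, audio is merged back and the output is named
    # like input_2X_<fps>fps.mp4 next to the source file.
    print("Interpolated video written to:", out_path)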