r3gm commited on
Commit
940e9ee
·
verified ·
1 Parent(s): 05035da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -68
app.py CHANGED
@@ -2,26 +2,24 @@ import gradio as gr
2
  import subprocess
3
  import os
4
  import spaces
 
 
5
 
6
  # Download the file
7
  subprocess.run([
8
  "wget",
9
- "--no-check-certificate",
10
- "https://drive.google.com/uc?id=1mj9lH6Be7ztYtHAr1xUUGT3hRtWJBy_5",
11
  "-O",
12
- "RIFE_trained_model_v4.13.2.zip"
13
  ], check=True)
14
 
15
  # Unzip the downloaded file
16
  subprocess.run([
17
  "unzip",
18
  "-o",
19
- "RIFE_trained_model_v4.13.2.zip"
20
  ], check=True)
21
 
22
- # The name of your script
23
- SCRIPT_NAME = "inference_video.py"
24
-
25
  @spaces.GPU(duration=120)
26
  def run_rife(
27
  input_video,
@@ -29,7 +27,6 @@ def run_rife(
29
  time_exponent,
30
  fixed_fps,
31
  video_scale,
32
- enable_uhd_mode,
33
  remove_duplicate_frames,
34
  create_montage,
35
  progress=gr.Progress(track_tqdm=True),
@@ -40,58 +37,37 @@ def run_rife(
40
  ext = "mp4"
41
  model_dir = "train_log"
42
 
 
43
  video_path_wo_ext = os.path.splitext(os.path.basename(input_video))[0]
44
- output_base_name = "{}_{}X_{}fps.{}".format(video_path_wo_ext, int(frame_multiplier), int(fixed_fps), ext)
45
-
46
- cmd = ["python3", SCRIPT_NAME]
47
- cmd.extend(["--video", input_video])
48
- cmd.extend(["--output", output_base_name])
49
- cmd.extend(["--multi", str(int(frame_multiplier))])
50
 
51
- if time_exponent != 1:
52
- cmd.extend(["--exp", str(int(time_exponent))])
53
-
54
  if fixed_fps > 0:
55
- cmd.extend(["--fps", str(int(fixed_fps))])
56
  gr.Warning("Will not merge audio because using fps flag!")
57
-
58
- if video_scale != 1.0:
59
- cmd.extend(["--scale", str(video_scale)])
60
-
61
- cmd.extend(["--ext", ext])
62
- cmd.extend(["--model", model_dir])
63
- cmd.append("--fp16")
64
 
65
- if enable_uhd_mode:
66
- cmd.append("--UHD")
67
-
68
- if remove_duplicate_frames:
69
- cmd.append("--skip")
70
-
71
- if create_montage:
72
- cmd.append("--montage")
73
-
74
- print(f"Executing command: {' '.join(cmd)}")
75
 
76
  try:
77
- process = subprocess.run(cmd, capture_output=True, text=True)
78
-
79
- if process.stdout:
80
- print("STDOUT:", process.stdout)
81
- if process.stderr:
82
- print("STDERR:", process.stderr)
83
-
84
- if process.returncode != 0:
85
- raise gr.Error(f"Inference failed. Error: {process.stderr}")
86
-
87
- final_output_file = f"{output_base_name}.{ext}"
88
-
89
- if os.path.exists(final_output_file):
90
- return final_output_file
 
 
 
 
91
  else:
92
- if os.path.exists(output_base_name):
93
- return output_base_name
94
- raise gr.Error(f"Output file not found. Expected: {final_output_file}")
95
 
96
  except Exception as e:
97
  raise gr.Error(f"An error occurred: {str(e)}")
@@ -99,9 +75,9 @@ def run_rife(
99
 
100
  # --- Gradio UI Layout ---
101
 
102
- with gr.Blocks(title="Professional FPS Booster") as app:
103
- gr.Markdown("# ⚡ RIFE: High-Performance FPS Booster")
104
- gr.Markdown("Maximize video fluidity and frame rate using deep flow estimation.")
105
  gr.Markdown("⚠️ **Notice:** Keep input videos under 60 seconds for frame interpolation to prevent GPU task aborts.")
106
 
107
  with gr.Row():
@@ -109,13 +85,11 @@ with gr.Blocks(title="Professional FPS Booster") as app:
109
  with gr.Column(scale=1):
110
  input_vid = gr.Video(label="🎬 Input Source Video", sources=["upload"])
111
 
112
- with gr.Group():
113
- gr.Markdown("### 🚀 Performance Settings")
114
-
115
  multi_param = gr.Dropdown(
116
  choices=["2", "3", "4", "5", "6"],
117
  value="2",
118
- label=" Frame Multiplier",
119
  info="2X = Double FPS (e.g. 30 -> 60). Higher multipliers create more intermediate frames."
120
  )
121
 
@@ -141,15 +115,10 @@ with gr.Blocks(title="Professional FPS Booster") as app:
141
  )
142
 
143
  with gr.Row():
144
- uhd_chk = gr.Checkbox(
145
- label="💎 UHD/4K Optimization",
146
- value=False,
147
- info="Enable for high-resolution sources."
148
- )
149
  skip_chk = gr.Checkbox(
150
  label="⏩ Skip Static Frames",
151
  value=False,
152
- info="Bypass processing for non-moving scenes to save time."
153
  )
154
  montage_chk = gr.Checkbox(
155
  label="🆚 Split-Screen Comparison",
@@ -157,11 +126,11 @@ with gr.Blocks(title="Professional FPS Booster") as app:
157
  info="Output video showing Original vs. Processed."
158
  )
159
 
160
- btn_run = gr.Button("🚀 RENDER HIGH FPS VIDEO", variant="primary", size="lg")
161
 
162
  # --- Right Column: Output ---
163
  with gr.Column(scale=1):
164
- output_vid = gr.Video(label=" High FPS Result")
165
  gr.Markdown("**Status:** Rendering time depends on input resolution and duration.")
166
 
167
  # --- Bind Logic ---
@@ -173,7 +142,6 @@ with gr.Blocks(title="Professional FPS Booster") as app:
173
  exp_param,
174
  fps_param,
175
  scale_param,
176
- uhd_chk,
177
  skip_chk,
178
  montage_chk
179
  ],
 
2
  import subprocess
3
  import os
4
  import spaces
5
+ import inference_video_w
6
+ import torch
7
 
8
  # Download the file
9
  subprocess.run([
10
  "wget",
11
+ "https://huggingface.co/r3gm/RIFE/resolve/main/RIFEv4.26_0921.zip",
 
12
  "-O",
13
+ "RIFEv4.26_0921.zip"
14
  ], check=True)
15
 
16
  # Unzip the downloaded file
17
  subprocess.run([
18
  "unzip",
19
  "-o",
20
+ "RIFEv4.26_0921.zip"
21
  ], check=True)
22
 
 
 
 
23
  @spaces.GPU(duration=120)
24
  def run_rife(
25
  input_video,
 
27
  time_exponent,
28
  fixed_fps,
29
  video_scale,
 
30
  remove_duplicate_frames,
31
  create_montage,
32
  progress=gr.Progress(track_tqdm=True),
 
37
  ext = "mp4"
38
  model_dir = "train_log"
39
 
40
+ # Construct output filename pattern to match what inference_video.py expects/generates
41
  video_path_wo_ext = os.path.splitext(os.path.basename(input_video))[0]
42
+ # We pass the desired output name, though the function logic tries to stick to this pattern anyway
43
+ output_base_name = "{}_{}X_fps.{}".format(video_path_wo_ext, int(frame_multiplier), ext)
 
 
 
 
44
 
 
 
 
45
  if fixed_fps > 0:
 
46
  gr.Warning("Will not merge audio because using fps flag!")
 
 
 
 
 
 
 
47
 
48
+ print(f"Starting Inference for: {input_video}")
 
 
 
 
 
 
 
 
 
49
 
50
  try:
51
+ # Call the imported function directly
52
+ result_path = inference_video_w.inference(
53
+ video=input_video,
54
+ output=output_base_name,
55
+ modelDir=model_dir,
56
+ fp16=(True if torch.cuda.is_available() else False),
57
+ UHD=False,
58
+ scale=video_scale,
59
+ skip=remove_duplicate_frames,
60
+ fps=int(fixed_fps) if fixed_fps > 0 else None,
61
+ ext=ext,
62
+ exp=int(time_exponent),
63
+ multi=int(frame_multiplier),
64
+ montage=create_montage
65
+ )
66
+
67
+ if result_path and os.path.exists(result_path):
68
+ return result_path
69
  else:
70
+ raise gr.Error(f"Output file not found. Expected: {result_path}")
 
 
71
 
72
  except Exception as e:
73
  raise gr.Error(f"An error occurred: {str(e)}")
 
75
 
76
  # --- Gradio UI Layout ---
77
 
78
+ with gr.Blocks(title="Frame Rate Enhancer") as app:
79
+ gr.Markdown("# ⚡ RIFE: Frame Rate Enhancer")
80
+ gr.Markdown("Creates extra frames between the original ones to make motion in your videos smoother and more fluid.")
81
  gr.Markdown("⚠️ **Notice:** Keep input videos under 60 seconds for frame interpolation to prevent GPU task aborts.")
82
 
83
  with gr.Row():
 
85
  with gr.Column(scale=1):
86
  input_vid = gr.Video(label="🎬 Input Source Video", sources=["upload"])
87
 
88
+ with gr.Group():
 
 
89
  multi_param = gr.Dropdown(
90
  choices=["2", "3", "4", "5", "6"],
91
  value="2",
92
+ label="🗃️ Frame Multiplier",
93
  info="2X = Double FPS (e.g. 30 -> 60). Higher multipliers create more intermediate frames."
94
  )
95
 
 
115
  )
116
 
117
  with gr.Row():
 
 
 
 
 
118
  skip_chk = gr.Checkbox(
119
  label="⏩ Skip Static Frames",
120
  value=False,
121
+ info="Bypass processing for static frames to save time."
122
  )
123
  montage_chk = gr.Checkbox(
124
  label="🆚 Split-Screen Comparison",
 
126
  info="Output video showing Original vs. Processed."
127
  )
128
 
129
+ btn_run = gr.Button("GENERATE INTERMEDIATE FRAMES", variant="primary", size="lg")
130
 
131
  # --- Right Column: Output ---
132
  with gr.Column(scale=1):
133
+ output_vid = gr.Video(label="INTERPOLATED RESULT")
134
  gr.Markdown("**Status:** Rendering time depends on input resolution and duration.")
135
 
136
  # --- Bind Logic ---
 
142
  exp_param,
143
  fps_param,
144
  scale_param,
 
145
  skip_chk,
146
  montage_chk
147
  ],