sherwin6180 committed on
Commit
a30b826
·
verified ·
1 Parent(s): b22393c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -45
app.py CHANGED
@@ -8,7 +8,7 @@ from typing import Iterable
8
  from gradio.themes import Soft
9
  from gradio.themes.utils import colors, fonts, sizes
10
 
11
- # --- Mock Spaces ---
12
  class MockSpaces:
13
  def GPU(self, duration=0):
14
  def decorator(func):
@@ -16,7 +16,7 @@ class MockSpaces:
16
  return decorator
17
  spaces = MockSpaces()
18
 
19
- # --- Theme Setup ---
20
  colors.steel_blue = colors.Color(
21
  name="steel_blue",
22
  c50="#EBF3F8",
@@ -84,47 +84,39 @@ class SteelBlueTheme(Soft):
84
  )
85
  steel_blue_theme = SteelBlueTheme()
86
 
87
- # --- Debug Info ---
88
- print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
89
- print("GPU Count:", torch.cuda.device_count())
90
-
91
  from diffusers import FlowMatchEulerDiscreteScheduler
92
  from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
93
  from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
94
  from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
95
 
96
- # --- FIX: Determine Device Strategy ---
97
- # 如果在 HF 构建环境(无 GPU),使用 CPU 防止报错
98
- # 如果在 RunPod (有 GPU),使用 "balanced" 策略 (Pipeline 不支持 "auto")
99
- if torch.cuda.device_count() > 0:
100
- device_strategy = "balanced" # Pipeline level strategy
101
- transformer_strategy = "auto" # Transformer level strategy
102
- dtype = torch.bfloat16
103
- print(f"Running on GPU with strategy: {device_strategy}")
104
- else:
105
- device_strategy = "cpu"
106
- transformer_strategy = "cpu"
107
- dtype = torch.float32 # CPU usually prefers float32
108
- print("Running on CPU (Build Environment detected)")
109
-
110
- print("Loading Transformer...")
111
- transformer_model = QwenImageTransformer2DModel.from_pretrained(
112
- "linoyts/Qwen-Image-Edit-Rapid-AIO",
113
- subfolder='transformer',
114
- torch_dtype=dtype,
115
- device_map=transformer_strategy
116
- )
117
-
118
- print("Loading Pipeline...")
119
- pipe = QwenImageEditPlusPipeline.from_pretrained(
120
- "Qwen/Qwen-Image-Edit-2509",
121
- transformer=transformer_model,
122
- torch_dtype=dtype,
123
- device_map=device_strategy # <--- 这里必须是 balanced 或 cpu,不能是 auto
124
- )
125
 
126
- # Only load LoRAs and optimization if on GPU to avoid build errors
127
- if torch.cuda.device_count() > 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  print("Loading LoRAs...")
129
  pipe.load_lora_weights("autoweeb/Qwen-Image-Edit-2509-Photo-to-Anime", weight_name="Qwen-Image-Edit-2509-Photo-to-Anime_000001000.safetensors", adapter_name="anime")
130
  pipe.load_lora_weights("dx8152/Qwen-Edit-2509-Multiple-angles", weight_name="镜头转换.safetensors", adapter_name="multiple-angles")
@@ -139,6 +131,8 @@ if torch.cuda.device_count() > 0:
139
  pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
140
  except Exception as e:
141
  print(f"Warning: FA3 set skipped: {e}")
 
 
142
 
143
  MAX_SEED = np.iinfo(np.int32).max
144
 
@@ -160,13 +154,13 @@ def update_dimensions_on_upload(image):
160
 
161
  @spaces.GPU(duration=30)
162
  def infer(input_image, prompt, lora_adapter, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
163
  if input_image is None:
164
  raise gr.Error("Please upload an image to edit.")
165
 
166
- # 如果没有 GPU (比如在 HF 预览界面),直接报错提示
167
- if torch.cuda.device_count() == 0:
168
- raise gr.Error("Running on CPU-only environment. Please run on GPU.")
169
-
170
  adapters_map = {
171
  "Photo-to-Anime": "anime",
172
  "Multiple-Angles": "multiple-angles",
@@ -204,9 +198,7 @@ def infer(input_image, prompt, lora_adapter, seed, randomize_seed, guidance_scal
204
 
205
  @spaces.GPU(duration=30)
206
  def infer_example(input_image, prompt, lora_adapter):
207
- # 如果是 HF 预览构建过程,跳过
208
- if torch.cuda.device_count() == 0:
209
- return None, 0
210
  input_pil = input_image.convert("RGB")
211
  result, seed = infer(input_pil, prompt, lora_adapter, 0, True, 1.0, 4)
212
  return result, seed
@@ -218,7 +210,7 @@ css="""
218
 
219
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
220
  with gr.Column(elem_id="col-container"):
221
- gr.Markdown("# **Qwen-Image-Edit-2509-LoRAs-Fast (2xA40 Ready)**", elem_id="main-title")
222
 
223
  with gr.Row(equal_height=True):
224
  with gr.Column():
 
8
  from gradio.themes import Soft
9
  from gradio.themes.utils import colors, fonts, sizes
10
 
11
+ # --- Mock Spaces (保持不变) ---
12
  class MockSpaces:
13
  def GPU(self, duration=0):
14
  def decorator(func):
 
16
  return decorator
17
  spaces = MockSpaces()
18
 
19
+ # --- Theme Setup (保持不变) ---
20
  colors.steel_blue = colors.Color(
21
  name="steel_blue",
22
  c50="#EBF3F8",
 
84
  )
85
  steel_blue_theme = SteelBlueTheme()
86
 
87
+ # --- 关键修改:按需加载 ---
 
 
 
88
  from diffusers import FlowMatchEulerDiscreteScheduler
89
  from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
90
  from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
91
  from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
92
 
93
+ pipe = None # 全局变量初始化为空
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
+ # 检测逻辑:只有在真正有 GPU 的时候才加载模型
96
+ # 这样 HF 的 CPU 构建服务器会直接跳过这里,瞬间完成构建
97
+ if torch.cuda.is_available():
98
+ print("GPU detected! Initializing model for 2x A40 Environment...")
99
+ dtype = torch.bfloat16
100
+
101
+ # 1. Load Transformer (device_map="auto" for multi-gpu split)
102
+ print("Loading Transformer...")
103
+ transformer_model = QwenImageTransformer2DModel.from_pretrained(
104
+ "linoyts/Qwen-Image-Edit-Rapid-AIO",
105
+ subfolder='transformer',
106
+ torch_dtype=dtype,
107
+ device_map="auto"
108
+ )
109
+
110
+ # 2. Load Pipeline (device_map="balanced" compatible with diffusers)
111
+ print("Loading Pipeline...")
112
+ pipe = QwenImageEditPlusPipeline.from_pretrained(
113
+ "Qwen/Qwen-Image-Edit-2509",
114
+ transformer=transformer_model,
115
+ torch_dtype=dtype,
116
+ device_map="balanced"
117
+ )
118
+
119
+ # 3. Load LoRAs
120
  print("Loading LoRAs...")
121
  pipe.load_lora_weights("autoweeb/Qwen-Image-Edit-2509-Photo-to-Anime", weight_name="Qwen-Image-Edit-2509-Photo-to-Anime_000001000.safetensors", adapter_name="anime")
122
  pipe.load_lora_weights("dx8152/Qwen-Edit-2509-Multiple-angles", weight_name="镜头转换.safetensors", adapter_name="multiple-angles")
 
131
  pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
132
  except Exception as e:
133
  print(f"Warning: FA3 set skipped: {e}")
134
+ else:
135
+ print("No GPU detected (likely HF Build Environment). SKIPPING MODEL LOAD to save memory.")
136
 
137
  MAX_SEED = np.iinfo(np.int32).max
138
 
 
154
 
155
  @spaces.GPU(duration=30)
156
  def infer(input_image, prompt, lora_adapter, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress(track_tqdm=True)):
157
+ # 运行时检查:如果 pipe 没加载(说明没 GPU),直接报错
158
+ if pipe is None:
159
+ raise gr.Error("Error: Model not loaded. Is a GPU available?")
160
+
161
  if input_image is None:
162
  raise gr.Error("Please upload an image to edit.")
163
 
 
 
 
 
164
  adapters_map = {
165
  "Photo-to-Anime": "anime",
166
  "Multiple-Angles": "multiple-angles",
 
198
 
199
  @spaces.GPU(duration=30)
200
  def infer_example(input_image, prompt, lora_adapter):
201
+ if pipe is None: return None, 0
 
 
202
  input_pil = input_image.convert("RGB")
203
  result, seed = infer(input_pil, prompt, lora_adapter, 0, True, 1.0, 4)
204
  return result, seed
 
210
 
211
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
212
  with gr.Column(elem_id="col-container"):
213
+ gr.Markdown("# **Qwen-Image-Edit-2509-LoRAs-Fast (RunPod Optimized)**", elem_id="main-title")
214
 
215
  with gr.Row(equal_height=True):
216
  with gr.Column():