Adnan commited on
Commit
8fff660
·
verified ·
1 Parent(s): 519e224

Create image_generator.py

Browse files
Files changed (1) hide show
  1. image_generator.py +284 -0
image_generator.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TimeLapseForge — Image Generator Module
3
+ Uses SDXL Turbo / SDXL Base for consistent timelapse panel generation.
4
+ Key technique: img2img chaining with low strength for natural transitions.
5
+ """
6
+
7
+ import torch
8
+ import gc
9
+ from typing import List, Dict, Optional, Callable, Tuple
10
+ from PIL import Image
11
+
12
+ # Model cache
13
+ _pipelines = {}
14
+
15
+
16
def flush_memory():
    """Release cached memory on the CPU and, when a GPU is present, on CUDA."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
21
+
22
+
23
def get_pipeline(model_id: str = "stabilityai/sdxl-turbo", pipeline_type: str = "t2i"):
    """
    Load and cache a diffusion pipeline.

    Args:
        model_id: Hugging Face model identifier.
        pipeline_type: "t2i" for text-to-image, "i2i" for image-to-image.

    Returns:
        The cached pipeline for ``(model_id, pipeline_type)``.

    Raises:
        ValueError: if ``pipeline_type`` is not "t2i" or "i2i" (previously this
            fell through and surfaced as a confusing ``KeyError``).

    The i2i pipeline reuses an already-cached t2i pipeline's components via
    ``from_pipe`` to avoid loading the model weights a second time.
    """
    global _pipelines

    # Validate before the (slow) diffusers import so bad calls fail fast.
    if pipeline_type not in ("t2i", "i2i"):
        raise ValueError(f"pipeline_type must be 't2i' or 'i2i', got {pipeline_type!r}")

    from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image

    cache_key = f"{model_id}_{pipeline_type}"
    if cache_key in _pipelines:
        return _pipelines[cache_key]

    # SDXL-family checkpoints ship half-precision weights under the "fp16" variant.
    lowered = model_id.lower()
    variant = "fp16" if ("turbo" in lowered or "sdxl" in lowered) else None

    if pipeline_type == "t2i":
        pipe = AutoPipelineForText2Image.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            variant=variant,
        )
    else:
        # Reuse the cached t2i components when possible to save memory.
        t2i_key = f"{model_id}_t2i"
        if t2i_key in _pipelines:
            pipe = AutoPipelineForImage2Image.from_pipe(_pipelines[t2i_key])
        else:
            pipe = AutoPipelineForImage2Image.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                variant=variant,
            )

    _pipelines[cache_key] = pipe
    return _pipelines[cache_key]
56
+
57
+
58
def get_model_config(model_id: str) -> Dict:
    """Return the recommended generation settings for *model_id*.

    Unknown model ids fall back to conservative generic defaults.
    """
    generic = {
        "num_inference_steps": 20,
        "guidance_scale": 7.0,
        "default_resolution": (512, 512),
        "supports_guidance": True,
    }
    known = {
        "stabilityai/sdxl-turbo": {
            "num_inference_steps": 4,
            "guidance_scale": 0.0,
            "default_resolution": (512, 512),
            "supports_guidance": False,
        },
        "stabilityai/stable-diffusion-xl-base-1.0": {
            "num_inference_steps": 25,
            "guidance_scale": 7.5,
            "default_resolution": (1024, 1024),
            "supports_guidance": True,
        },
        "runwayml/stable-diffusion-v1-5": {
            "num_inference_steps": 25,
            "guidance_scale": 7.5,
            "default_resolution": (512, 512),
            "supports_guidance": True,
        },
    }
    return known.get(model_id, generic)
87
+
88
+
89
class ImageGenerator:
    """
    Generates timelapse panels using Stable Diffusion.

    Panel 1 is produced with text-to-image; every following panel is produced
    with img2img from its predecessor at a low ``strength`` so the sequence
    stays visually consistent while still evolving.
    """

    def __init__(self, model_id: str = "stabilityai/sdxl-turbo"):
        self.model_id = model_id
        # Per-model defaults: steps, guidance scale, native resolution.
        self.config = get_model_config(model_id)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def _move_to_device(self, pipe):
        """Move the pipeline onto the GPU when available; no-op on CPU."""
        if self.device == "cuda":
            pipe.to("cuda")
        return pipe

    def _resolve_size(self, width: Optional[int], height: Optional[int]) -> Tuple[int, int]:
        """Fill any missing dimension from the model's native resolution."""
        default_w, default_h = self.config["default_resolution"]
        return width or default_w, height or default_h

    @staticmethod
    def _compose_prompt(prompt_data: Dict[str, str]) -> str:
        """Join the main prompt with its style suffix, if any."""
        main_prompt = prompt_data.get("main_prompt", "")
        style = prompt_data.get("style_suffix", "")
        return f"{main_prompt}, {style}" if style else main_prompt

    def generate_first_panel(
        self,
        prompt: str,
        negative_prompt: str = "",
        seed: int = 42,
        width: Optional[int] = None,
        height: Optional[int] = None,
        steps: Optional[int] = None,
        guidance: Optional[float] = None,
    ) -> Image.Image:
        """Generate the first panel using text-to-image.

        Args:
            prompt: Full positive prompt for the panel.
            negative_prompt: Optional negative prompt ("" disables it).
            seed: RNG seed for reproducibility.
            width, height: Output size; ``None`` uses the model's native resolution.
            steps, guidance: Overrides for the model's default sampling settings.

        Returns:
            The generated PIL image.
        """
        pipe = self._move_to_device(get_pipeline(self.model_id, "t2i"))

        w, h = self._resolve_size(width, height)
        n_steps = steps or self.config["num_inference_steps"]
        # guidance=0.0 is meaningful (SDXL Turbo), so test against None explicitly.
        cfg = guidance if guidance is not None else self.config["guidance_scale"]
        generator = torch.Generator(device=self.device).manual_seed(seed)

        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt if negative_prompt else None,
            num_inference_steps=n_steps,
            guidance_scale=cfg,
            width=w,
            height=h,
            generator=generator,
        )
        return result.images[0]

    def generate_next_panel(
        self,
        prompt: str,
        previous_image: Image.Image,
        negative_prompt: str = "",
        strength: float = 0.4,
        seed: int = 42,
        steps: Optional[int] = None,
        guidance: Optional[float] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
    ) -> Image.Image:
        """Generate the next panel using img2img from the previous panel.

        Args:
            prompt: Full positive prompt for the panel.
            previous_image: Conditioning image (the previous panel).
            negative_prompt: Optional negative prompt ("" disables it).
            strength: img2img strength; lower values stay closer to the input.
            seed: RNG seed for reproducibility.
            steps, guidance: Overrides for the model's default sampling settings.
            width, height: Output size; ``None`` keeps the previous behavior of
                the model's native resolution. Added (backward-compatibly) so a
                chained sequence can keep a caller-requested size instead of
                snapping back to the default after panel 1.

        Returns:
            The generated PIL image.
        """
        pipe = self._move_to_device(get_pipeline(self.model_id, "i2i"))

        n_steps = steps or self.config["num_inference_steps"]
        cfg = guidance if guidance is not None else self.config["guidance_scale"]
        generator = torch.Generator(device=self.device).manual_seed(seed)

        # Resize so the conditioning image matches the requested output size.
        target_w, target_h = self._resolve_size(width, height)
        prev_resized = previous_image.resize((target_w, target_h), Image.LANCZOS)

        result = pipe(
            prompt=prompt,
            image=prev_resized,
            negative_prompt=negative_prompt if negative_prompt else None,
            num_inference_steps=n_steps,
            guidance_scale=cfg,
            strength=strength,
            generator=generator,
        )
        return result.images[0]

    def generate_all_panels(
        self,
        prompts: List[Dict[str, str]],
        strength: float = 0.4,
        base_seed: int = 42,
        steps: Optional[int] = None,
        guidance: Optional[float] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
        progress_callback: Optional[Callable] = None,
        reference_image: Optional[Image.Image] = None,
    ) -> List[Image.Image]:
        """
        Generate all panels in sequence.

        Panel 1: text-to-image (or img2img from ``reference_image`` if provided).
        Panel 2+: img2img from the previous panel with controlled strength.

        Fix: ``width``/``height`` are now forwarded to every panel; previously
        they applied only to panel 1, so a custom-size run silently changed
        panel size after the first image.

        Args:
            prompts: One dict per panel with keys "main_prompt" and optionally
                "style_suffix" / "negative_prompt".
            strength: img2img strength used for panels 2+.
            base_seed: Panel *i* uses seed ``base_seed + i`` for reproducibility.
            steps, guidance, width, height: Forwarded to every panel.
            progress_callback: Called as ``callback(done, total)`` after each panel.
            reference_image: Optional anchor image for panel 1.

        Returns:
            A list of PIL images, one per prompt (failed panels are filled with
            a copy of the previous panel, or a flat gray image for panel 1).
        """
        images: List[Image.Image] = []

        for i, prompt_data in enumerate(prompts):
            full_prompt = self._compose_prompt(prompt_data)
            neg_prompt = prompt_data.get("negative_prompt", "")
            seed = base_seed + i  # distinct but reproducible seed per panel

            try:
                if i == 0 and reference_image is not None:
                    # Anchor the first panel on the user-supplied reference.
                    img = self.generate_next_panel(
                        prompt=full_prompt,
                        previous_image=reference_image,
                        negative_prompt=neg_prompt,
                        # Slightly higher strength so the reference is restyled,
                        # not merely copied.
                        strength=max(strength, 0.5),
                        seed=seed,
                        steps=steps,
                        guidance=guidance,
                        width=width,
                        height=height,
                    )
                elif i == 0:
                    img = self.generate_first_panel(
                        prompt=full_prompt,
                        negative_prompt=neg_prompt,
                        seed=seed,
                        width=width,
                        height=height,
                        steps=steps,
                        guidance=guidance,
                    )
                else:
                    img = self.generate_next_panel(
                        prompt=full_prompt,
                        previous_image=images[-1],
                        negative_prompt=neg_prompt,
                        strength=strength,
                        seed=seed,
                        steps=steps,
                        guidance=guidance,
                        width=width,
                        height=height,
                    )
                images.append(img)

            except Exception as e:
                # Best-effort: keep the sequence going with a stand-in panel
                # rather than aborting the whole run.
                print(f"Error generating panel {i + 1}: {e}")
                if images:
                    images.append(images[-1].copy())
                else:
                    fallback_size = self._resolve_size(width, height)
                    images.append(Image.new("RGB", fallback_size, (50, 50, 50)))

            if progress_callback:
                progress_callback(i + 1, len(prompts))

        flush_memory()
        return images

    def regenerate_single_panel(
        self,
        panel_index: int,
        prompts: List[Dict[str, str]],
        existing_images: List[Image.Image],
        strength: float = 0.4,
        base_seed: int = 42,
        steps: Optional[int] = None,
        guidance: Optional[float] = None,
    ) -> Tuple[Image.Image, List[Image.Image]]:
        """
        Regenerate a single panel in place.

        Panel 0 is regenerated with text-to-image; any later panel is
        regenerated with img2img from the panel before it. Panels after
        ``panel_index`` are NOT regenerated — if the caller wants the change
        to cascade forward, it must regenerate them explicitly.

        Args:
            panel_index: Index of the panel to redo.
            prompts: The full prompt list (same shape as generate_all_panels).
            existing_images: Current panel images; not mutated.
            strength, base_seed, steps, guidance: As in generate_all_panels;
                the seed scheme (``base_seed + index``) matches it exactly.

        Returns:
            ``(new_image, updated_images)`` where ``updated_images`` is a
            shallow copy of ``existing_images`` with the panel replaced.
        """
        prompt_data = prompts[panel_index]
        full_prompt = self._compose_prompt(prompt_data)
        neg_prompt = prompt_data.get("negative_prompt", "")
        seed = base_seed + panel_index  # same per-panel seed scheme as generate_all_panels

        if panel_index == 0:
            new_img = self.generate_first_panel(
                prompt=full_prompt, negative_prompt=neg_prompt, seed=seed, steps=steps, guidance=guidance
            )
        else:
            new_img = self.generate_next_panel(
                prompt=full_prompt,
                previous_image=existing_images[panel_index - 1],
                negative_prompt=neg_prompt,
                strength=strength,
                seed=seed,
                steps=steps,
                guidance=guidance,
            )

        updated_images = existing_images.copy()
        updated_images[panel_index] = new_img
        return new_img, updated_images