garyuzair committed on
Commit da66c01 · verified · 1 Parent(s): d4c3da5

Update src/app_hf_space_optimized.py

Files changed (1):
  1. src/app_hf_space_optimized.py +507 -89

src/app_hf_space_optimized.py CHANGED
@@ -1,6 +1,3 @@
- # ✅ Fully Functional POV Automation App (Streamlit for HF Spaces)
- # Optimized for Free Tier: LLM (TinyLlama), SD 1.4, Parler-TTS, FFmpeg image+audio → video
-
  import os
  import gc
  import torch
@@ -9,6 +6,9 @@ import tempfile
  import json
  import subprocess
  from huggingface_hub import hf_hub_download
 
  from transformers import AutoTokenizer, AutoModelForCausalLM
  from parler_tts import ParlerTTSForConditionalGeneration
@@ -17,116 +17,534 @@ from PIL import Image
  import soundfile as sf
 
  # --- Config ---
- st.set_page_config(layout="wide", page_title="⚡ POV Generator (Lite HF Space)")
 
- LLM_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.1"
  IMG_MODEL_ID = "CompVis/stable-diffusion-v1-4"
- TTS_MODEL_ID = "parler-tts/parler-tts-mini-v1.1"
- CACHE_DIR = "/tmp/hf_cache"
  os.makedirs(CACHE_DIR, exist_ok=True)
  os.environ['HUGGINGFACE_HUB_CACHE'] = CACHE_DIR
 
- # --- Util ---
- def clear_torch():
      gc.collect()
      if torch.cuda.is_available():
          torch.cuda.empty_cache()
 
- # --- Step 1: Generate JSON Story ---
- def generate_story(prompt: str, num_scenes: int):
-     st.info("🧠 Generating story...")
-     tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, cache_dir=CACHE_DIR)
      model = AutoModelForCausalLM.from_pretrained(
-         LLM_MODEL_ID,
          torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
          device_map="auto",
          cache_dir=CACHE_DIR
      )
 
-     sys_prompt = (
-         f"You are a scriptwriter. Write a short POV story in exactly {num_scenes} scenes. "
-         f"Respond ONLY with valid JSON in this format: "
-         f"{{\"title\": \"Your Title\", \"scenes\": [\"scene 1\", \"scene 2\", ..., \"scene {num_scenes}\"]}}"
      )
 
-     full_prompt = tokenizer.apply_chat_template([
-         {"role": "system", "content": sys_prompt},
-         {"role": "user", "content": prompt}
-     ], tokenize=False)
-     input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids.to(model.device)
-     output = model.generate(input_ids, max_new_tokens=1024, do_sample=True)
-     result = tokenizer.decode(output[0], skip_special_tokens=True)
-     clear_torch()
-     try:
-         json_part = result[result.index("{"): result.rindex("}")+1]
-         return json.loads(json_part)
-     except:
-         st.error("⚠️ Failed to parse JSON. Check model output.")
-         st.code(result)
-         return None
-
- # --- Step 2: Generate Images ---
- def generate_images(scenes):
-     st.info("🎨 Generating images...")
-     pipe = StableDiffusionPipeline.from_pretrained(
-         IMG_MODEL_ID,
          torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
          cache_dir=CACHE_DIR
      )
-     pipe.to("cuda" if torch.cuda.is_available() else "cpu")
-     images = []
-     for i, scene in enumerate(scenes):
-         image = pipe(scene).images[0]
-         images.append(image)
-         st.image(image, caption=f"Scene {i+1}")
-     clear_torch()
-     return images
 
- # --- Step 3: Generate TTS ---
- def generate_audios(scenes):
-     st.info("🔊 Generating audio...")
-     tts = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL_ID, device_map="auto", cache_dir=CACHE_DIR)
-     tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_ID, cache_dir=CACHE_DIR)
-     desc_tokenizer = AutoTokenizer.from_pretrained(tts.config.text_encoder._name_or_path, cache_dir=CACHE_DIR)
 
      audio_paths = []
-     for i, scene in enumerate(scenes):
-         desc_ids = desc_tokenizer("Neutral narrator", return_tensors="pt").input_ids.to(tts.device)
-         prompt_ids = tokenizer(scene, return_tensors="pt").input_ids.to(tts.device)
-         wav = tts.generate(input_ids=desc_ids, prompt_input_ids=prompt_ids).to(torch.float32).cpu().numpy()
-         path = f"audio_{i+1}.wav"
-         sf.write(path, wav, 24000)
-         audio_paths.append(path)
-         st.audio(path)
-     clear_torch()
      return audio_paths
 
  # --- Step 4: Create Video ---
- def create_video(images, audio_paths):
-     st.info("📹 Creating video...")
-     frames_dir = tempfile.mkdtemp()
-     for idx, img in enumerate(images):
-         img.save(os.path.join(frames_dir, f"frame_{idx:03}.png"))
-
-     video_path = "final_output.mp4"
-     command = [
-         "ffmpeg", "-y", "-r", "1", "-i", f"{frames_dir}/frame_%03d.png",
-         "-i", audio_paths[0],
-         "-c:v", "libx264", "-pix_fmt", "yuv420p",
-         video_path
      ]
-     subprocess.run(command, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
-     st.video(video_path)
-
- # --- UI ---
- st.title("⚡ POV Generator – Hugging Face Free Tier Ready")
- prompt = st.text_area("Enter a POV prompt:", "POV: You wake up on Mars and can't remember Earth")
- num_scenes = st.slider("Number of Scenes", 6, 20, 6)
-
- if st.button("🚀 Generate Story"):
-     story = generate_story(prompt, num_scenes)
-     if story:
-         st.subheader(story['title'])
-         images = generate_images(story['scenes'])
-         audios = generate_audios(story['scenes'])
-         create_video(images, audios)
  import os
  import gc
  import torch
  import streamlit as st
  import tempfile
  import json
  import subprocess
  from huggingface_hub import hf_hub_download
+ import shutil
+ from datetime import datetime
+ from io import BytesIO
 
  from transformers import AutoTokenizer, AutoModelForCausalLM
  from parler_tts import ParlerTTSForConditionalGeneration
  from diffusers import StableDiffusionPipeline
  from PIL import Image
  import soundfile as sf
 
  # --- Config ---
+ st.set_page_config(layout="wide", page_title="⚡ POV Generator Pro")
 
+ LLM_MODEL_ID = "openai-community/gpt2-medium"  # Slightly larger GPT-2 model
  IMG_MODEL_ID = "CompVis/stable-diffusion-v1-4"
+ TTS_MODEL_ID = "parler-tts/parler-tts-mini-v1.1"  # Make sure this matches your desired ParlerTTS model version
+
+ # Using Streamlit's native caching for Hugging Face Hub downloads if possible;
+ # otherwise, this explicit cache dir is fine.
+ # For HF Spaces, /tmp is ephemeral but fine for a session.
+ CACHE_DIR = os.path.join(tempfile.gettempdir(), "hf_cache_pov_generator")
  os.makedirs(CACHE_DIR, exist_ok=True)
  os.environ['HUGGINGFACE_HUB_CACHE'] = CACHE_DIR
+ os.environ['HF_HOME'] = CACHE_DIR  # Also sets the general Hugging Face home
+ os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
+ os.environ['DIFFUSERS_CACHE'] = CACHE_DIR
+
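
A note on the cache setup above: transformers, diffusers, and huggingface_hub resolve their default cache locations when the libraries are first imported, so environment variables set after the imports may not affect calls that rely on library defaults. Passing `cache_dir` explicitly (as this file does everywhere) sidesteps the issue; a minimal sketch of the import-order-safe alternative, with a hypothetical entry module, would be:

    # Set cache locations BEFORE importing any Hugging Face library,
    # so their import-time default cache resolution picks these up.
    import os
    import tempfile

    CACHE_DIR = os.path.join(tempfile.gettempdir(), "hf_cache_pov_generator")
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.environ["HF_HOME"] = CACHE_DIR

    from transformers import AutoTokenizer  # imported only after the env vars are set
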
+ # --- Session State Initialization ---
+ if 'run_id' not in st.session_state:
+     st.session_state.run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
+ if 'story_data' not in st.session_state:
+     st.session_state.story_data = None
+ if 'pil_images' not in st.session_state:
+     st.session_state.pil_images = None
+ if 'image_paths_for_video' not in st.session_state:
+     st.session_state.image_paths_for_video = None
+ if 'audio_paths' not in st.session_state:
+     st.session_state.audio_paths = None
+ if 'video_path' not in st.session_state:
+     st.session_state.video_path = None
+ if 'temp_base_dir' not in st.session_state:
+     st.session_state.temp_base_dir = None
+
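
The guard pattern above matters because Streamlit re-executes the whole script on every widget interaction; `st.session_state` is the per-browser-session store that survives those reruns. A small self-contained sketch of the same pattern (the `counter` key is just for illustration):

    import streamlit as st

    # Initialize once; the value survives the reruns Streamlit performs on each interaction.
    if "counter" not in st.session_state:
        st.session_state.counter = 0

    if st.button("Increment"):
        st.session_state.counter += 1

    st.write(f"Counter: {st.session_state.counter}")
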
+ # --- Utility ---
+ def get_session_temp_dir():
+     if st.session_state.temp_base_dir and os.path.exists(st.session_state.temp_base_dir):
+         return st.session_state.temp_base_dir
+
+     # Define a base directory for all temporary files for this session run.
+     # This helps in cleaning up everything related to one generation run.
+     base_dir = os.path.join(tempfile.gettempdir(), f"pov_generator_run_{st.session_state.run_id}")
+     os.makedirs(base_dir, exist_ok=True)
+     st.session_state.temp_base_dir = base_dir
+     return base_dir
+
+ def cleanup_temp_files(specific_dir=None):
+     """Cleans up temporary files."""
+     path_to_clean = specific_dir or st.session_state.get("temp_base_dir")
+     if path_to_clean and os.path.exists(path_to_clean):
+         try:
+             shutil.rmtree(path_to_clean)
+             if specific_dir is None:  # Only reset if cleaning the main session temp dir
+                 st.session_state.temp_base_dir = None
+             print(f"Cleaned up temp directory: {path_to_clean}")
+         except Exception as e:
+             print(f"Error cleaning up temp directory {path_to_clean}: {e}")
+     # Individual files stored outside temp_base_dir would need separate cleanup,
+     # but in this improved version all temp files live within temp_base_dir.
 
+ def clear_torch_cache():
      gc.collect()
      if torch.cuda.is_available():
          torch.cuda.empty_cache()
 
+ # --- Model Loading (Cached) ---
+ @st.cache_resource
+ def load_llm_model_and_tokenizer(model_id):
+     tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=CACHE_DIR)
      model = AutoModelForCausalLM.from_pretrained(
+         model_id,
          torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
          device_map="auto",
          cache_dir=CACHE_DIR
      )
+     if tokenizer.pad_token_id is None:  # GPT-2 might not have a pad token by default
+         tokenizer.pad_token = tokenizer.eos_token
+         model.config.pad_token_id = model.config.eos_token_id
+     return model, tokenizer
 
+ @st.cache_resource
+ def load_sd_pipeline(model_id):
+     pipe = StableDiffusionPipeline.from_pretrained(
+         model_id,
+         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+         cache_dir=CACHE_DIR
      )
+     if torch.cuda.is_available():
+         pipe = pipe.to("cuda")
+     return pipe
 
+ @st.cache_resource
+ def load_tts_model_and_tokenizers(model_id):
+     tts_model = ParlerTTSForConditionalGeneration.from_pretrained(
+         model_id,
          torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+         device_map="auto",
          cache_dir=CACHE_DIR
      )
+     prompt_tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=CACHE_DIR)
+     # Ensure the text_encoder config attribute is accessed correctly across versions
+     desc_tokenizer_path = tts_model.config.text_encoder.name_or_path if hasattr(tts_model.config.text_encoder, 'name_or_path') else tts_model.config.text_encoder._name_or_path
+     desc_tokenizer = AutoTokenizer.from_pretrained(desc_tokenizer_path, cache_dir=CACHE_DIR)
+     return tts_model, prompt_tokenizer, desc_tokenizer
 
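
The `@st.cache_resource` decorator on the three loaders above memoizes their return values per argument tuple, so each model is loaded once per process rather than on every rerun; a later call with the same `model_id` returns the very same cached object. A minimal sketch with a hypothetical loader:

    import streamlit as st

    @st.cache_resource
    def load_expensive_resource(name: str):
        # Runs only on the first call per distinct `name`;
        # later calls with the same argument return the cached object.
        print(f"Loading {name}...")
        return {"name": name}

    r1 = load_expensive_resource("demo")  # prints "Loading demo..."
    r2 = load_expensive_resource("demo")  # served from cache; r2 is r1
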
+ # --- Step 1: Generate JSON Story ---
+ def generate_story(prompt: str, num_scenes: int):
+     model, tokenizer = load_llm_model_and_tokenizer(LLM_MODEL_ID)
+
+     # Refined prompt for better scene separation and count
+     story_prompt = (
+         f"Generate a compelling short POV story based on the following prompt: '{prompt}'. "
+         f"The story should consist of exactly {num_scenes} distinct scenes. "
+         f"Clearly separate each scene with the delimiter '###'. "
+         f"Do not include any introductory or concluding text outside of the scenes and their separators. "
+         f"Each scene should be a paragraph of 2-4 sentences."
+     )
+
+     input_ids = tokenizer.encode(story_prompt, return_tensors="pt").to(model.device)
+
+     # Calculate max_new_tokens, ensuring it doesn't exceed model capacity.
+     # GPT-2 models (including gpt2-medium) have a 1024-token context window;
+     # model.config.n_positions is not always present, so fall back to that common value.
+     max_model_tokens = getattr(model.config, 'n_positions', 1024)
+     max_possible_new_tokens = max_model_tokens - input_ids.shape[1] - 20  # Safety buffer
+
+     desired_tokens_per_scene = 75  # Avg tokens per scene
+     desired_total_tokens = num_scenes * desired_tokens_per_scene
+
+     # Cap generated tokens to prevent overly long outputs and stay within model limits
+     max_new_tokens_val = min(desired_total_tokens, 700, max_possible_new_tokens)
+
+     if max_new_tokens_val <= 0:
+         st.error("Prompt is too long: no room is left in the context window for new tokens.")
+         return None
+
+     output = model.generate(
+         input_ids,
+         max_new_tokens=max_new_tokens_val,
+         do_sample=True,
+         temperature=0.7,
+         top_k=50,
+         pad_token_id=tokenizer.eos_token_id
+     )
+     full_result = tokenizer.decode(output[0], skip_special_tokens=True)
+
+     # Remove the input prompt from the beginning of the result
+     if full_result.startswith(story_prompt):
+         generated_text = full_result[len(story_prompt):].strip()
+     else:
+         # Fallback: sometimes the model does not echo the input verbatim. This is
+         # fragile, and good prompt engineering is key; if unsure, process the whole
+         # output and let the '###' split handle any prompt remnants.
+         generated_text = full_result
 
+     scenes_raw = generated_text.split("###")
+     processed_scenes = []
+     for s in scenes_raw:
+         s_clean = s.strip()
+         if s_clean:  # Skip empty scenes
+             processed_scenes.append(s_clean)
+
+     final_scenes = processed_scenes
+     # If more scenes than requested, take the first N. If fewer, use what's available.
+     if len(final_scenes) > num_scenes:
+         final_scenes = final_scenes[:num_scenes]
+         st.warning(f"LLM generated more scenes than requested. Using the first {num_scenes}.")
+     elif len(final_scenes) < num_scenes:
+         st.warning(f"LLM generated {len(final_scenes)} scenes, but {num_scenes} were requested. Using available scenes.")
+
+     if not final_scenes:
+         st.error("Failed to parse scenes from LLM output. The output was: " + generated_text)
+         return None
+
+     clear_torch_cache()
+     return {"title": prompt[:60].capitalize(), "scenes": final_scenes}
+
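
Since gpt2-medium cannot reliably emit valid JSON, the function above asks for '###'-delimited plain text and splits on the delimiter instead. The parsing core, extracted as a pure function to show the expected behavior (a hypothetical helper, not part of the file):

    def split_scenes(generated_text: str, num_scenes: int) -> list[str]:
        # Split on the delimiter, drop empty fragments, cap at the requested count.
        scenes = [s.strip() for s in generated_text.split("###") if s.strip()]
        return scenes[:num_scenes]

    # Example:
    text = "You wake up. ### The corridor hums. ### A door opens. ###"
    assert split_scenes(text, 2) == ["You wake up.", "The corridor hums."]
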
+ # --- Step 2: Generate Images ---
+ def generate_images_for_scenes(scenes):
+     pipe = load_sd_pipeline(IMG_MODEL_ID)
+     pil_images = []
+
+     # Create a directory for storing frame images for the video
+     frames_dir = os.path.join(get_session_temp_dir(), "frames_for_video")
+     os.makedirs(frames_dir, exist_ok=True)
+     image_paths_for_video = []
+
+     cols = st.columns(3)  # Adjust number of columns as preferred
+     col_idx = 0
+
+     for i, scene_text in enumerate(scenes):
+         with st.spinner(f"Generating image for scene {i+1}..."):
+             try:
+                 # Add a style modifier for better visual appeal; can be user-configurable
+                 styled_prompt = f"{scene_text}, cinematic lighting, detailed, high quality"
+                 image = pipe(styled_prompt, num_inference_steps=30).images[0]  # Reduced steps for speed
+                 pil_images.append(image)
+
+                 # Save image for video creation
+                 img_path = os.path.join(frames_dir, f"frame_{i:03d}.png")
+                 image.save(img_path)
+                 image_paths_for_video.append(img_path)
+
+                 with cols[col_idx % len(cols)]:
+                     st.image(image, caption=f"Scene {i+1}: {scene_text[:100]}...")
+
+                     # Download button for individual image
+                     img_byte_arr = BytesIO()
+                     image.save(img_byte_arr, format='PNG')
+                     st.download_button(
+                         label=f"Download Scene {i+1} Image",
+                         data=img_byte_arr.getvalue(),
+                         file_name=f"scene_{i+1}_image.png",
+                         mime="image/png",
+                         key=f"download_img_{i}"
+                     )
+                 col_idx += 1
+             except Exception as e:
+                 st.error(f"Error generating image for scene {i+1}: {e}")
+                 pil_images.append(None)  # Placeholder for failed image
+                 image_paths_for_video.append(None)  # Placeholder
+
+     clear_torch_cache()
+     return pil_images, image_paths_for_video
+
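
On a free-tier Space with little VRAM, the pipeline call above can be tightened further with diffusers' standard memory savers. A hedged standalone sketch, assuming a recent diffusers version (the prompt string is illustrative):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    pipe.enable_attention_slicing()  # lowers peak VRAM at a small speed cost
    if torch.cuda.is_available():
        pipe = pipe.to("cuda")

    image = pipe("a quiet red desert at dawn", num_inference_steps=30).images[0]
    image.save("scene.png")
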
+ # --- Step 3: Generate TTS ---
+ def generate_audios_for_scenes(scenes):
+     tts_model, prompt_tokenizer, desc_tokenizer = load_tts_model_and_tokenizers(TTS_MODEL_ID)
+
+     audio_dir = os.path.join(get_session_temp_dir(), "audio_files")
+     os.makedirs(audio_dir, exist_ok=True)
      audio_paths = []
+
+     cols = st.columns(3)  # Adjust number of columns
+     col_idx = 0
+
+     # User-configurable description, or keep it fixed
+     tts_description = "A neutral and clear narrator voice."
+
+     for i, scene_text in enumerate(scenes):
+         with st.spinner(f"Generating audio for scene {i+1}..."):
+             try:
+                 desc_ids = desc_tokenizer(tts_description, return_tensors="pt").input_ids.to(tts_model.device)
+                 prompt_ids = prompt_tokenizer(scene_text, return_tensors="pt").input_ids.to(tts_model.device)
+
+                 # Generate audio. For parler-tts, extra generation kwargs
+                 # (e.g. temperature for the description) can be passed here if desired.
+                 generation_output = tts_model.generate(input_ids=desc_ids, prompt_input_ids=prompt_ids)
+
+                 audio_waveform = generation_output.to(torch.float32).cpu().numpy().squeeze()
+
+                 file_path = os.path.join(audio_dir, f"audio_scene_{i+1}.wav")
+                 sf.write(file_path, audio_waveform, tts_model.config.sampling_rate)  # Use the model's sampling rate
+                 audio_paths.append(file_path)
+
+                 with cols[col_idx % len(cols)]:
+                     st.markdown(f"**Audio for Scene {i+1}**")
+                     st.audio(file_path)
+                     with open(file_path, "rb") as f_audio:
+                         st.download_button(
+                             label=f"Download Scene {i+1} Audio",
+                             data=f_audio.read(),  # Read bytes for download
+                             file_name=f"scene_{i+1}_audio.wav",
+                             mime="audio/wav",
+                             key=f"download_audio_{i}"
+                         )
+                 col_idx += 1
+             except Exception as e:
+                 st.error(f"Error generating audio for scene {i+1}: {e}")
+                 audio_paths.append(None)  # Placeholder
+
+     clear_torch_cache()
      return audio_paths
 
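
One detail worth noting about the generate call above: Parler-TTS returns a batched tensor, while soundfile expects a 1-D array for mono output, hence the squeeze before `sf.write`. A standalone sketch following the pattern in the parler-tts README (the voice description string is illustrative):

    import soundfile as sf
    from parler_tts import ParlerTTSForConditionalGeneration
    from transformers import AutoTokenizer

    model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1.1")
    tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1.1")

    desc_ids = tokenizer("A neutral and clear narrator voice.", return_tensors="pt").input_ids
    prompt_ids = tokenizer("You wake up on Mars.", return_tensors="pt").input_ids

    audio = model.generate(input_ids=desc_ids, prompt_input_ids=prompt_ids)
    sf.write("scene.wav", audio.cpu().numpy().squeeze(), model.config.sampling_rate)
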
  # --- Step 4: Create Video ---
+ def create_video_from_scenes(image_file_paths, audio_file_paths, output_filename="final_pov_video.mp4"):
+     if not image_file_paths or not audio_file_paths or len(image_file_paths) != len(audio_file_paths):
+         st.error("Mismatch in the number of images and audio files, or missing assets. Cannot create video.")
+         return None
+
+     # Ensure ffmpeg is installed and accessible
+     try:
+         subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+     except (subprocess.CalledProcessError, FileNotFoundError):
+         st.error("FFmpeg is not installed or not found in PATH. Video creation is not possible.")
+         st.markdown("Please install FFmpeg: `sudo apt update && sudo apt install ffmpeg` (Linux) or `brew install ffmpeg` (macOS).")
+         return None
+
+     temp_clips_dir = os.path.join(get_session_temp_dir(), "temp_video_clips")
+     os.makedirs(temp_clips_dir, exist_ok=True)
+
+     video_clips_paths = []
+     valid_scene_count = 0
+
+     for i, (img_path, audio_path) in enumerate(zip(image_file_paths, audio_file_paths)):
+         if img_path is None or audio_path is None:
+             st.warning(f"Skipping scene {i+1} in video due to missing image or audio.")
+             continue
+
+         try:
+             audio_info = sf.info(audio_path)
+             audio_duration = audio_info.duration
+             if audio_duration <= 0.1:  # Minimum practical duration
+                 st.warning(f"Audio for scene {i+1} is too short ({audio_duration:.2f}s). Using a minimum duration of 1s.")
+                 audio_duration = 1.0  # Enforce a minimum duration
+
+             clip_path = os.path.join(temp_clips_dir, f"clip_{i:03d}.mp4")
+
+             # Create an individual clip: loop the image, add audio, set duration to the audio length
+             command = [
+                 "ffmpeg", "-y",
+                 "-loop", "1", "-i", img_path,  # Loop the image
+                 "-i", audio_path,  # Input audio
+                 "-c:v", "libx264", "-preset", "medium", "-tune", "stillimage",
+                 "-c:a", "aac", "-b:a", "192k",
+                 "-pix_fmt", "yuv420p",
+                 "-t", str(audio_duration),  # Duration of this clip
+                 "-shortest",  # End when the shortest input (audio) ends
+                 clip_path
+             ]
+             process = subprocess.run(command, capture_output=True, text=True)
+             if process.returncode != 0:
+                 st.error(f"FFmpeg error creating clip for scene {i+1}:\n{process.stderr}")
+                 continue  # Skip this clip
+             video_clips_paths.append(clip_path)
+             valid_scene_count += 1
+         except Exception as e:
+             st.error(f"Error processing scene {i+1} for video: {e}")
+             continue
+
+     if not video_clips_paths or valid_scene_count == 0:
+         st.error("No valid video clips were generated. Cannot create final video.")
+         cleanup_temp_files(temp_clips_dir)  # Clean up partial clips
+         return None
+
+     # Create a file list for the ffmpeg concat demuxer
+     concat_list_file = os.path.join(temp_clips_dir, "concat_list.txt")
+     with open(concat_list_file, "w") as f:
+         for clip_p in video_clips_paths:
+             # Paths in the concat file can be relative or absolute; absolute paths are
+             # safer when concat_list.txt lives in a different directory than the clips,
+             # but since they share a directory here, relative names are fine.
+             f.write(f"file '{os.path.basename(clip_p)}'\n")
+
+     final_video_path = os.path.join(get_session_temp_dir(), output_filename)
+     concat_command = [
+         "ffmpeg", "-y",
+         "-f", "concat", "-safe", "0", "-i", concat_list_file,
+         "-c", "copy",  # Re-mux; don't re-encode, since the clips share codecs
+         final_video_path
      ]
+     process = subprocess.run(concat_command, capture_output=True, text=True, cwd=temp_clips_dir)  # Run from the clips dir
+     if process.returncode != 0:
+         st.error(f"FFmpeg error concatenating video clips:\n{process.stderr}")
+         cleanup_temp_files(temp_clips_dir)  # Clean up partial clips
+         return None
+
+     st.success("Video created successfully!")
+     # Intermediate clips could be cleaned up here, but it is better to clean
+     # everything up at session end or via the sidebar button.
+     return final_video_path
+
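
The two-stage FFmpeg strategy above (one still-image clip per scene, then a lossless join) relies on the concat demuxer, which re-muxes without re-encoding when all inputs share the same codecs and parameters. The list file it consumes is simply one `file` directive per line; a sketch with hypothetical clip names:

    # concat_list.txt, as consumed by:
    #   ffmpeg -f concat -safe 0 -i concat_list.txt -c copy out.mp4
    clips = ["clip_000.mp4", "clip_001.mp4", "clip_002.mp4"]
    with open("concat_list.txt", "w") as f:
        for clip in clips:
            f.write(f"file '{clip}'\n")
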
+ # --- Main App UI ---
+ st.title("⚡ POV Story Generator Pro")
+ st.markdown("Create engaging POV stories with AI-generated text, images, audio, and video.")
+ st.markdown("---")
+
+ # Sidebar for inputs
+ with st.sidebar:
+     st.header("📝 Story Configuration")
+     prompt = st.text_area(
+         "Enter your POV story prompt:",
+         st.session_state.get("user_prompt", "POV: You are a detective solving a mystery in a futuristic city."),
+         height=100,
+         key="user_prompt_input"
+     )
+     num_scenes = st.slider("Number of Scenes:", min_value=2, max_value=10, value=st.session_state.get("num_scenes_val", 3), key="num_scenes_slider")
+
+     st.markdown("---")
+     if st.button("🚀 Generate Full Story & Assets", type="primary", use_container_width=True):
+         # Reset state for a new generation run
+         st.session_state.run_id = datetime.now().strftime("%Y%m%d_%H%M%S")  # New unique ID for this run
+         cleanup_temp_files()  # Clean up any previous run's temp files
+
+         st.session_state.story_data = None
+         st.session_state.pil_images = None
+         st.session_state.image_paths_for_video = None
+         st.session_state.audio_paths = None
+         st.session_state.video_path = None
+
+         st.session_state.user_prompt = prompt  # Save current input values
+         st.session_state.num_scenes_val = num_scenes
+
+         # Set the generation trigger flag (direct execution would also work)
+         st.session_state.generate_all = True
+
+     st.markdown("---")
+     st.header("🛠️ Utilities")
+     if st.button("🧹 Clear Cache & Temp Files & Restart", use_container_width=True):
+         # Clear model caches
+         st.cache_resource.clear()
+         # Clear session state related to generated artifacts
+         keys_to_clear = ['story_data', 'pil_images', 'image_paths_for_video',
+                          'audio_paths', 'video_path', 'temp_base_dir', 'generate_all']
+         for key in keys_to_clear:
+             if key in st.session_state:
+                 del st.session_state[key]
+         cleanup_temp_files()  # Ensure physical temp files are deleted
+         st.session_state.run_id = datetime.now().strftime("%Y%m%d_%H%M%S")  # New ID after clearing
+         st.success("Caches and temporary files cleared. The app will restart.")
+         st.rerun()
+
+ # Main content area
+ if st.session_state.get("generate_all"):
+     # --- 1. Generate Story ---
+     with st.status("🧠 Generating story...", expanded=True) as status_story:
+         try:
+             st.session_state.story_data = generate_story(st.session_state.user_prompt, st.session_state.num_scenes_val)
+             if st.session_state.story_data:
+                 status_story.update(label="Story generated successfully!", state="complete")
+             else:
+                 status_story.update(label="Story generation failed.", state="error")
+                 st.session_state.generate_all = False  # Stop further processing
+         except Exception as e:
+             st.error(f"An unexpected error occurred during story generation: {e}")
+             status_story.update(label="Story generation error.", state="error")
+             st.session_state.generate_all = False
+
+     # --- Display Story ---
+     if st.session_state.story_data:
+         st.subheader(f"🎬 Story: {st.session_state.story_data['title']}")
+         for i, scene_text in enumerate(st.session_state.story_data['scenes']):
+             st.markdown(f"**Scene {i+1}:** {scene_text}")
+
+         story_json = json.dumps(st.session_state.story_data, indent=2)
+         st.download_button(
+             label="Download Story (JSON)",
+             data=story_json,
+             file_name=f"{st.session_state.story_data['title'].replace(' ', '_').lower()}_story.json",
+             mime="application/json"
+         )
+         st.markdown("---")
+
+     # --- 2. Generate Images (if the story succeeded) ---
+     if st.session_state.get("generate_all") and st.session_state.story_data:
+         with st.status("🎨 Generating images for scenes...", expanded=True) as status_images:
+             try:
+                 st.session_state.pil_images, st.session_state.image_paths_for_video = generate_images_for_scenes(st.session_state.story_data['scenes'])
+                 if all(img is not None for img in st.session_state.pil_images):  # Basic check
+                     status_images.update(label="Images generated successfully!", state="complete")
+                 elif any(img is not None for img in st.session_state.pil_images):
+                     status_images.update(label="Some images generated. Check for errors.", state="warning")
+                 else:
+                     status_images.update(label="Image generation failed for all scenes.", state="error")
+                     st.session_state.generate_all = False  # Stop further processing
+             except Exception as e:
+                 st.error(f"An unexpected error occurred during image generation: {e}")
+                 status_images.update(label="Image generation error.", state="error")
+                 st.session_state.generate_all = False
+         st.markdown("---")
+
+     # --- 3. Generate Audio (if images succeeded, at least partially) ---
+     if st.session_state.get("generate_all") and st.session_state.story_data and st.session_state.pil_images:
+         with st.status("🔊 Generating audio for scenes...", expanded=True) as status_audio:
+             try:
+                 st.session_state.audio_paths = generate_audios_for_scenes(st.session_state.story_data['scenes'])
+                 if all(p is not None for p in st.session_state.audio_paths):  # Basic check
+                     status_audio.update(label="Audio generated successfully!", state="complete")
+                 elif any(p is not None for p in st.session_state.audio_paths):
+                     status_audio.update(label="Some audio files generated. Check for errors.", state="warning")
+                 else:
+                     status_audio.update(label="Audio generation failed for all scenes.", state="error")
+                     st.session_state.generate_all = False  # Stop further processing
+             except Exception as e:
+                 st.error(f"An unexpected error occurred during audio generation: {e}")
+                 status_audio.update(label="Audio generation error.", state="error")
+                 st.session_state.generate_all = False
+         st.markdown("---")
+
+     # --- 4. Create Video (if audio succeeded, at least partially) ---
+     if st.session_state.get("generate_all") and st.session_state.image_paths_for_video and st.session_state.audio_paths:
+         # Ensure there's at least one valid pair of image and audio
+         valid_assets = sum(1 for img, aud in zip(st.session_state.image_paths_for_video, st.session_state.audio_paths) if img and aud)
+         if valid_assets > 0:
+             with st.status("📹 Creating final video...", expanded=True) as status_video:
+                 try:
+                     st.session_state.video_path = create_video_from_scenes(
+                         st.session_state.image_paths_for_video,
+                         st.session_state.audio_paths
+                     )
+                     if st.session_state.video_path:
+                         status_video.update(label="Video created successfully!", state="complete")
+                     else:
+                         status_video.update(label="Video creation failed.", state="error")
+                 except Exception as e:
+                     st.error(f"An unexpected error occurred during video creation: {e}")
+                     status_video.update(label="Video creation error.", state="error")
+
+             if st.session_state.video_path:
+                 st.subheader("🎞️ Final Video Presentation")
+                 st.video(st.session_state.video_path)
+                 with open(st.session_state.video_path, "rb") as f_video:
+                     st.download_button(
+                         label="Download Final Video",
+                         data=f_video.read(),
+                         file_name=os.path.basename(st.session_state.video_path),
+                         mime="video/mp4"
+                     )
+                 st.markdown("---")
+         else:
+             st.warning("Not enough valid image/audio pairs to create a video.")
+
+     # Reset the generation trigger
+     if "generate_all" in st.session_state:  # Check the key exists before deleting
+         del st.session_state.generate_all
+
+ elif not st.session_state.get("user_prompt"):  # Show an initial message if no prompt yet
+     st.info("Configure your story in the sidebar and click 'Generate Full Story & Assets' to begin!")
+
+ # --- Final Cleanup (optional: could be tied to session end if the platform supports it) ---
+ # For Streamlit, manual cleanup via a button or at the start of a new run is common;
+ # cleanup_temp_files() is called at the start of each generation.
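
To try the updated file locally, it presumably needs streamlit, torch, transformers, diffusers, soundfile, and the parler-tts package installed, plus an FFmpeg binary on PATH; the entry point would then be:

    streamlit run src/app_hf_space_optimized.py
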