import os
import gc
import json
import shutil
import subprocess
import tempfile
from datetime import datetime
from io import BytesIO

import torch
import streamlit as st
import soundfile as sf
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM
from parler_tts import ParlerTTSForConditionalGeneration
from diffusers import StableDiffusionPipeline

# --- Config ---
st.set_page_config(layout="wide", page_title="⚡ POV Generator Pro")

LLM_MODEL_ID = "openai-community/gpt2-medium"  # Slightly larger GPT-2 model
IMG_MODEL_ID = "CompVis/stable-diffusion-v1-4"
TTS_MODEL_ID = "parler-tts/parler-tts-mini-v1.1"  # Make sure this matches your desired ParlerTTS model version

# Using Streamlit's native caching for Hugging Face Hub downloads if possible;
# otherwise, this explicit cache dir is fine.
# For HF Spaces, /tmp is ephemeral but fine for a session.
CACHE_DIR = os.path.join(tempfile.gettempdir(), "hf_cache_pov_generator")
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ['HUGGINGFACE_HUB_CACHE'] = CACHE_DIR
os.environ['HF_HOME'] = CACHE_DIR  # Also sets the general Hugging Face home
os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
os.environ['DIFFUSERS_CACHE'] = CACHE_DIR

# --- Session State Initialization ---
if 'run_id' not in st.session_state:
    st.session_state.run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
if 'story_data' not in st.session_state:
    st.session_state.story_data = None
if 'pil_images' not in st.session_state:
    st.session_state.pil_images = None
if 'image_paths_for_video' not in st.session_state:
    st.session_state.image_paths_for_video = None
if 'audio_paths' not in st.session_state:
    st.session_state.audio_paths = None
if 'video_path' not in st.session_state:
    st.session_state.video_path = None
if 'temp_base_dir' not in st.session_state:
    st.session_state.temp_base_dir = None


# --- Utility ---
def get_session_temp_dir():
    if st.session_state.temp_base_dir and os.path.exists(st.session_state.temp_base_dir):
        return st.session_state.temp_base_dir
    # Define a base directory for all temporary files for this session run.
    # This helps in cleaning up everything related to one generation run.
    base_dir = os.path.join(tempfile.gettempdir(), f"pov_generator_run_{st.session_state.run_id}")
    os.makedirs(base_dir, exist_ok=True)
    st.session_state.temp_base_dir = base_dir
    return base_dir


def cleanup_temp_files(specific_dir=None):
    """Cleans up temporary files."""
    path_to_clean = specific_dir or st.session_state.get("temp_base_dir")
    if path_to_clean and os.path.exists(path_to_clean):
        try:
            shutil.rmtree(path_to_clean)
            if specific_dir is None:  # Only reset if cleaning the main session temp dir
                st.session_state.temp_base_dir = None
            print(f"Cleaned up temp directory: {path_to_clean}")
        except Exception as e:
            print(f"Error cleaning up temp directory {path_to_clean}: {e}")
    # Individual files stored outside temp_base_dir (legacy/direct use) are not handled here;
    # in this version, all temp files should live within temp_base_dir.


def clear_torch_cache():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# --- Model Loading (Cached) ---
@st.cache_resource
def load_llm_model_and_tokenizer(model_id):
    tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=CACHE_DIR)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        cache_dir=CACHE_DIR
    )
    if tokenizer.pad_token_id is None:  # GPT-2 does not define a pad token by default
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id
    return model, tokenizer


@st.cache_resource
def load_sd_pipeline(model_id):
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        cache_dir=CACHE_DIR
    )
    if torch.cuda.is_available():
        pipe = pipe.to("cuda")
    return pipe


@st.cache_resource
def load_tts_model_and_tokenizers(model_id):
    tts_model = ParlerTTSForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        cache_dir=CACHE_DIR
    )
    prompt_tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=CACHE_DIR)
    # Ensure the text_encoder config attribute is accessed correctly across versions
    desc_tokenizer_path = (
        tts_model.config.text_encoder.name_or_path
        if hasattr(tts_model.config.text_encoder, 'name_or_path')
        else tts_model.config.text_encoder._name_or_path
    )
    desc_tokenizer = AutoTokenizer.from_pretrained(desc_tokenizer_path, cache_dir=CACHE_DIR)
    return tts_model, prompt_tokenizer, desc_tokenizer


# --- Step 1: Generate JSON Story ---
def generate_story(prompt: str, num_scenes: int):
    model, tokenizer = load_llm_model_and_tokenizer(LLM_MODEL_ID)
    # Refined prompt for better scene separation and count
    story_prompt = (
        f"Generate a compelling short POV story based on the following prompt: '{prompt}'. "
        f"The story should consist of exactly {num_scenes} distinct scenes. "
        f"Clearly separate each scene with the delimiter '###'. "
        f"Do not include any introductory or concluding text outside of the scenes and their separators. "
        f"Each scene should be a paragraph of 2-4 sentences."
    )
    input_ids = tokenizer.encode(story_prompt, return_tensors="pt").to(model.device)

    # Calculate max_new_tokens, ensuring it doesn't exceed model capacity.
    # All GPT-2 variants (including gpt2-medium) have a 1024-token context window;
    # n_positions may be missing on other architectures, so fall back to a common value.
    max_model_tokens = getattr(model.config, 'n_positions', 1024)
    max_possible_new_tokens = max_model_tokens - input_ids.shape[1] - 20  # Safety buffer
    desired_tokens_per_scene = 75  # Avg tokens per scene
    desired_total_tokens = num_scenes * desired_tokens_per_scene
    # Cap generated tokens to prevent overly long outputs and stay within model limits
    max_new_tokens_val = min(desired_total_tokens, 700, max_possible_new_tokens)
    if max_new_tokens_val <= 0:
        st.error("Prompt is too long (or the max-token calculation failed); no room left to generate new tokens.")
        return None

    output = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens_val,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id
    )
    full_result = tokenizer.decode(output[0], skip_special_tokens=True)

    # Remove the input prompt from the beginning of the result
    if full_result.startswith(story_prompt):
        generated_text = full_result[len(story_prompt):].strip()
    else:
        # Fallback: sometimes the model does not echo the prompt verbatim, so the prefix
        # check above fails. Heuristics (e.g., splitting after the last prompt keyword)
        # are fragile; good prompt engineering is the real fix. Keep the whole output and
        # rely on the '###' split below to separate the scenes.
        generated_text = full_result  # If unsure, process the whole output.

    scenes_raw = generated_text.split("###")
    processed_scenes = []
    for s in scenes_raw:
        s_clean = s.strip()
        if s_clean:  # Skip empty scenes
            processed_scenes.append(s_clean)

    final_scenes = processed_scenes
    # If more scenes than requested, take the first N. If fewer, use what's available.
    if len(final_scenes) > num_scenes:
        final_scenes = final_scenes[:num_scenes]
        st.warning(f"LLM generated more scenes than requested. Using the first {num_scenes}.")
    elif len(final_scenes) < num_scenes:
        st.warning(f"LLM generated {len(final_scenes)} scenes, but {num_scenes} were requested. Using available scenes.")

    if not final_scenes:
        st.error("Failed to parse scenes from LLM output. The output was: " + generated_text)
        return None

    clear_torch_cache()
    return {"title": prompt[:60].capitalize(), "scenes": final_scenes}


# --- Step 2: Generate Images ---
def generate_images_for_scenes(scenes):
    pipe = load_sd_pipeline(IMG_MODEL_ID)
    pil_images = []
    # Create a directory for storing frame images for the video
    frames_dir = os.path.join(get_session_temp_dir(), "frames_for_video")
    os.makedirs(frames_dir, exist_ok=True)
    image_paths_for_video = []

    cols = st.columns(3)  # Adjust number of columns as preferred
    col_idx = 0
    for i, scene_text in enumerate(scenes):
        with st.spinner(f"Generating image for scene {i+1}..."):
            try:
                # Add a style modifier for better visual appeal; could be user-configurable
                styled_prompt = f"{scene_text}, cinematic lighting, detailed, high quality"
                image = pipe(styled_prompt, num_inference_steps=30).images[0]  # Reduced steps for speed
                pil_images.append(image)

                # Save image for video creation
                img_path = os.path.join(frames_dir, f"frame_{i:03d}.png")
                image.save(img_path)
                image_paths_for_video.append(img_path)

                with cols[col_idx % len(cols)]:
                    st.image(image, caption=f"Scene {i+1}: {scene_text[:100]}...")
                    # Download button for individual image
                    img_byte_arr = BytesIO()
                    image.save(img_byte_arr, format='PNG')
                    st.download_button(
                        label=f"Download Scene {i+1} Image",
                        data=img_byte_arr.getvalue(),
                        file_name=f"scene_{i+1}_image.png",
                        mime="image/png",
                        key=f"download_img_{i}"
                    )
                col_idx += 1
            except Exception as e:
                st.error(f"Error generating image for scene {i+1}: {e}")
                pil_images.append(None)  # Placeholder for failed image
                image_paths_for_video.append(None)  # Placeholder

    clear_torch_cache()
    return pil_images, image_paths_for_video


# --- Step 3: Generate TTS ---
def generate_audios_for_scenes(scenes):
    tts_model, prompt_tokenizer, desc_tokenizer = load_tts_model_and_tokenizers(TTS_MODEL_ID)
    audio_dir = os.path.join(get_session_temp_dir(), "audio_files")
    os.makedirs(audio_dir, exist_ok=True)
    audio_paths = []

    cols = st.columns(3)  # Adjust number of columns
    col_idx = 0
    # User-configurable description, or keep it fixed
    tts_description = "A neutral and clear narrator voice."
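    # Hedged sketch (not in the original code): the narrator description could be made
    # user-configurable instead of hard-coded. The "tts_voice_description" key below is a
    # hypothetical session-state name; it would need a matching sidebar widget to be set.
    # tts_description = st.session_state.get("tts_voice_description", tts_description)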
    for i, scene_text in enumerate(scenes):
        with st.spinner(f"Generating audio for scene {i+1}..."):
            try:
                desc_ids = desc_tokenizer(tts_description, return_tensors="pt").input_ids.to(tts_model.device)
                prompt_ids = prompt_tokenizer(scene_text, return_tensors="pt").input_ids.to(tts_model.device)

                # Generate audio. For parler-tts, extra generation kwargs can be useful, e.g.:
                # generation_output = tts_model.generate(input_ids=desc_ids, prompt_input_ids=prompt_ids, temperature=0.7)
                generation_output = tts_model.generate(input_ids=desc_ids, prompt_input_ids=prompt_ids)
                # Cast to float32 for soundfile and squeeze the batch dim to get a 1-D waveform
                audio_waveform = generation_output.to(torch.float32).cpu().numpy().squeeze()

                file_path = os.path.join(audio_dir, f"audio_scene_{i+1}.wav")
                sf.write(file_path, audio_waveform, tts_model.config.sampling_rate)  # Use model's sampling rate
                audio_paths.append(file_path)

                with cols[col_idx % len(cols)]:
                    st.markdown(f"**Audio for Scene {i+1}**")
                    st.audio(file_path)
                    with open(file_path, "rb") as f_audio:
                        st.download_button(
                            label=f"Download Scene {i+1} Audio",
                            data=f_audio.read(),  # Read bytes for download
                            file_name=f"scene_{i+1}_audio.wav",
                            mime="audio/wav",
                            key=f"download_audio_{i}"
                        )
                col_idx += 1
            except Exception as e:
                st.error(f"Error generating audio for scene {i+1}: {e}")
                audio_paths.append(None)  # Placeholder

    clear_torch_cache()
    return audio_paths


# --- Step 4: Create Video ---
def create_video_from_scenes(image_file_paths, audio_file_paths, output_filename="final_pov_video.mp4"):
    if not image_file_paths or not audio_file_paths or len(image_file_paths) != len(audio_file_paths):
        st.error("Mismatch in number of images and audio files, or missing assets. Cannot create video.")
        return None

    # Ensure ffmpeg is installed and accessible
    try:
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        st.error("FFMPEG is not installed or not found in PATH. Video creation is not possible.")
        st.markdown("Please install FFMPEG: `sudo apt update && sudo apt install ffmpeg` (Linux) or `brew install ffmpeg` (macOS).")
        return None

    temp_clips_dir = os.path.join(get_session_temp_dir(), "temp_video_clips")
    os.makedirs(temp_clips_dir, exist_ok=True)
    video_clips_paths = []
    valid_scene_count = 0

    for i, (img_path, audio_path) in enumerate(zip(image_file_paths, audio_file_paths)):
        if img_path is None or audio_path is None:
            st.warning(f"Skipping scene {i+1} in video due to missing image or audio.")
            continue
        try:
            audio_info = sf.info(audio_path)
            audio_duration = audio_info.duration
            if audio_duration <= 0.1:  # Minimum practical duration
                st.warning(f"Audio for scene {i+1} is too short ({audio_duration:.2f}s). Using a minimum duration of 1s.")
                audio_duration = 1.0  # Enforce a minimum duration

            clip_path = os.path.join(temp_clips_dir, f"clip_{i:03d}.mp4")
            # Create individual clip: loop image, add audio, set duration to audio length
            command = [
                "ffmpeg", "-y",
                "-loop", "1",
                "-i", img_path,    # Loop the image
                "-i", audio_path,  # Input audio
                "-c:v", "libx264", "-preset", "medium", "-tune", "stillimage",
                "-c:a", "aac", "-b:a", "192k",
                "-pix_fmt", "yuv420p",
                "-t", str(audio_duration),  # Duration of this clip
                "-shortest",  # End when shortest input (audio) ends
                clip_path
            ]
            process = subprocess.run(command, capture_output=True, text=True)
            if process.returncode != 0:
                st.error(f"FFMPEG error creating clip for scene {i+1}:\n{process.stderr}")
                continue  # Skip this clip
            video_clips_paths.append(clip_path)
            valid_scene_count += 1
        except Exception as e:
            st.error(f"Error processing scene {i+1} for video: {e}")
            continue

    if not video_clips_paths or valid_scene_count == 0:
        st.error("No valid video clips were generated. Cannot create final video.")
        cleanup_temp_files(temp_clips_dir)  # Clean up partial clips
        return None

    # Create a file list for ffmpeg concat
    concat_list_file = os.path.join(temp_clips_dir, "concat_list.txt")
    with open(concat_list_file, "w") as f:
        for clip_p in video_clips_paths:
            # Paths in the concat file may be relative or absolute; absolute would be safer if
            # concat_list.txt lived in a different dir than the clips. Since they share a
            # directory (and ffmpeg runs from it below), relative names are fine.
            f.write(f"file '{os.path.basename(clip_p)}'\n")

    final_video_path = os.path.join(get_session_temp_dir(), output_filename)
    concat_command = [
        "ffmpeg", "-y",
        "-f", "concat",
        "-safe", "0",
        "-i", concat_list_file,
        "-c", "copy",  # Re-mux, don't re-encode if codecs are compatible
        final_video_path
    ]
    process = subprocess.run(concat_command, capture_output=True, text=True, cwd=temp_clips_dir)  # Run from clips dir
    if process.returncode != 0:
        st.error(f"FFMPEG error concatenating video clips:\n{process.stderr}")
        cleanup_temp_files(temp_clips_dir)  # Clean up partial clips
        return None

    st.success("Video created successfully!")
    # cleanup_temp_files(temp_clips_dir)  # Optionally clean up intermediate clips here;
    # better to clean up everything at session end or via the sidebar button.
    return final_video_path


# --- Main App UI ---
st.title("⚡ POV Story Generator Pro")
st.markdown("Create engaging POV stories with AI-generated text, images, audio, and video.")
st.markdown("---")

# Sidebar for inputs
with st.sidebar:
    st.header("📝 Story Configuration")
    prompt = st.text_area(
        "Enter your POV story prompt:",
        st.session_state.get("user_prompt", "POV: You are a detective solving a mystery in a futuristic city."),
        height=100,
        key="user_prompt_input"
    )
    num_scenes = st.slider(
        "Number of Scenes:",
        min_value=2,
        max_value=10,
        value=st.session_state.get("num_scenes_val", 3),
        key="num_scenes_slider"
    )
    st.markdown("---")

    if st.button("🚀 Generate Full Story & Assets", type="primary", use_container_width=True):
        # Reset states for a new generation run
        st.session_state.run_id = datetime.now().strftime("%Y%m%d_%H%M%S")  # New unique ID for this run
        cleanup_temp_files()  # Clean up any previous run's temp files
        st.session_state.story_data = None
        st.session_state.pil_images = None
        st.session_state.image_paths_for_video = None
        st.session_state.audio_paths = None
        st.session_state.video_path = None
        st.session_state.user_prompt = prompt  # Save current input values
        st.session_state.num_scenes_val = num_scenes
        # Trigger generation flag (optional; direct execution would also work)
        st.session_state.generate_all = True

    st.markdown("---")
    st.header("🛠️ Utilities")
    if st.button("🧹 Clear Cache & Temp Files & Restart", use_container_width=True):
        # Clear model caches
        st.cache_resource.clear()
        # Clear session state related to generated artifacts
        keys_to_clear = ['story_data', 'pil_images', 'image_paths_for_video',
                         'audio_paths', 'video_path', 'temp_base_dir', 'generate_all']
        for key in keys_to_clear:
            if key in st.session_state:
                del st.session_state[key]
        cleanup_temp_files()  # Ensure physical temp files are deleted
        st.session_state.run_id = datetime.now().strftime("%Y%m%d_%H%M%S")  # New ID after clear
        st.success("Caches and temporary files cleared. App will restart.")
        st.rerun()

# Main content area
if st.session_state.get("generate_all"):
    # --- 1. Generate Story ---
    with st.status("🧠 Generating story...", expanded=True) as status_story:
        try:
            st.session_state.story_data = generate_story(st.session_state.user_prompt, st.session_state.num_scenes_val)
            if st.session_state.story_data:
                status_story.update(label="Story generated successfully!", state="complete")
            else:
                status_story.update(label="Story generation failed.", state="error")
                st.session_state.generate_all = False  # Stop further processing
        except Exception as e:
            st.error(f"An unexpected error occurred during story generation: {e}")
            status_story.update(label="Story generation error.", state="error")
            st.session_state.generate_all = False

    # --- Display Story ---
    if st.session_state.story_data:
        st.subheader(f"🎬 Story: {st.session_state.story_data['title']}")
        for i, scene_text in enumerate(st.session_state.story_data['scenes']):
            st.markdown(f"**Scene {i+1}:** {scene_text}")
        story_json = json.dumps(st.session_state.story_data, indent=2)
        st.download_button(
            label="Download Story (JSON)",
            data=story_json,
            file_name=f"{st.session_state.story_data['title'].replace(' ', '_').lower()}_story.json",
            mime="application/json"
        )
        st.markdown("---")

    # --- 2. Generate Images (if story succeeded) ---
    if st.session_state.get("generate_all") and st.session_state.story_data:
        with st.status("🎨 Generating images for scenes...", expanded=True) as status_images:
            try:
                st.session_state.pil_images, st.session_state.image_paths_for_video = generate_images_for_scenes(
                    st.session_state.story_data['scenes']
                )
                if all(img is not None for img in st.session_state.pil_images):  # Basic check
                    status_images.update(label="Images generated successfully!", state="complete")
                elif any(img is not None for img in st.session_state.pil_images):
                    # st.status has no "warning" state; mark complete but flag partial failures in the label
                    status_images.update(label="Some images generated. Check for errors.", state="complete")
                else:
                    status_images.update(label="Image generation failed for all scenes.", state="error")
                    st.session_state.generate_all = False  # Stop further processing
            except Exception as e:
                st.error(f"An unexpected error occurred during image generation: {e}")
                status_images.update(label="Image generation error.", state="error")
                st.session_state.generate_all = False
        st.markdown("---")

    # --- 3. Generate Audio (if images succeeded or partially) ---
    if st.session_state.get("generate_all") and st.session_state.story_data and st.session_state.pil_images:
        with st.status("🔊 Generating audio for scenes...", expanded=True) as status_audio:
            try:
                st.session_state.audio_paths = generate_audios_for_scenes(st.session_state.story_data['scenes'])
                if all(p is not None for p in st.session_state.audio_paths):  # Basic check
                    status_audio.update(label="Audio generated successfully!", state="complete")
                elif any(p is not None for p in st.session_state.audio_paths):
                    # st.status has no "warning" state; mark complete but flag partial failures in the label
                    status_audio.update(label="Some audio files generated. Check for errors.", state="complete")
                else:
                    status_audio.update(label="Audio generation failed for all scenes.", state="error")
                    st.session_state.generate_all = False  # Stop further processing
            except Exception as e:
                st.error(f"An unexpected error occurred during audio generation: {e}")
                status_audio.update(label="Audio generation error.", state="error")
                st.session_state.generate_all = False
        st.markdown("---")

    # --- 4. Create Video (if audio succeeded or partially) ---
    if st.session_state.get("generate_all") and st.session_state.image_paths_for_video and st.session_state.audio_paths:
        # Ensure there's at least one valid pair of image and audio
        valid_assets = sum(
            1 for img, aud in zip(st.session_state.image_paths_for_video, st.session_state.audio_paths) if img and aud
        )
        if valid_assets > 0:
            with st.status("📹 Creating final video...", expanded=True) as status_video:
                try:
                    st.session_state.video_path = create_video_from_scenes(
                        st.session_state.image_paths_for_video,
                        st.session_state.audio_paths
                    )
                    if st.session_state.video_path:
                        status_video.update(label="Video created successfully!", state="complete")
                    else:
                        status_video.update(label="Video creation failed.", state="error")
                except Exception as e:
                    st.error(f"An unexpected error occurred during video creation: {e}")
                    status_video.update(label="Video creation error.", state="error")

            if st.session_state.video_path:
                st.subheader("🎞️ Final Video Presentation")
                st.video(st.session_state.video_path)
                with open(st.session_state.video_path, "rb") as f_video:
                    st.download_button(
                        label="Download Final Video",
                        data=f_video.read(),
                        file_name=os.path.basename(st.session_state.video_path),
                        mime="video/mp4"
                    )
                st.markdown("---")
        else:
            st.warning("Not enough valid image/audio pairs to create a video.")

    # Reset generation trigger
    if "generate_all" in st.session_state:  # Check if key exists before deleting
        del st.session_state.generate_all

elif not st.session_state.get("user_prompt"):  # Show initial message if no prompt yet
    st.info("Configure your story in the sidebar and click 'Generate Full Story & Assets' to begin!")

# --- Final Cleanup Note (Optional: could be tied to session end if the platform supports it) ---
# For Streamlit, manual cleanup via the sidebar button or at the start of a new run is common;
# cleanup_temp_files() is called at the start of each new generation.
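# --- Dependency note (added; a hedged summary, not from the original source) ---
# Based on the imports and subprocess calls above, the app assumes roughly these packages:
#   streamlit, torch, transformers, diffusers, parler_tts, soundfile, Pillow, huggingface_hub
# plus a system-level ffmpeg binary for the video step (on HF Spaces, apt packages are
# typically listed in packages.txt). Exact versions are not specified by the source.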