# NOTE: removed hosting-page scrape residue ("Spaces: Running Running") — not part of this module.
| """ | |||
| TimeLapseForge - Universal API Provider Layer v2.2 | |||
| All API client imports are LAZY. | |||
| Smart prompt truncation per provider. | |||
| """ | |||
| import os | |||
| import io | |||
| import time | |||
| import base64 | |||
| import tempfile | |||
| import requests | |||
| from PIL import Image | |||
| from typing import Optional, Dict, List, Any, Tuple | |||
| from abc import ABC, abstractmethod | |||
| def _safe_import(package_name): | |||
| try: | |||
| import importlib | |||
| return importlib.import_module(package_name) | |||
| except ImportError: | |||
| return None | |||
def _require_import(package_name, pip_name=None):
    """Import *package_name*; raise a helpful ImportError when it is missing.

    *pip_name* overrides the package name shown in the error (for packages
    whose import name differs from their PyPI name, e.g. fal_client/fal-client).
    """
    module = _safe_import(package_name)
    if module is not None:
        return module
    pip = pip_name or package_name
    raise ImportError(
        "Package '" + pip + "' is not installed. "
        "Add it to requirements.txt or use a different provider."
    )
| # ============================================ | |||
| # SMART PROMPT TRUNCATOR | |||
| # ============================================ | |||
def smart_truncate(text, max_length, preserve_end=True):
    """
    Shorten *text* to at most *max_length* characters for an API call.

    When *preserve_end* is true, try to keep both the beginning (the subject
    description) and the final comma-separated section (style keywords),
    joined with " ... ". Otherwise fall back to a plain truncation that
    prefers a word boundary and ends in "...".
    """
    if not text:
        return text
    if len(text) <= max_length:
        return text

    if preserve_end:
        head, sep, tail = text.rpartition(", ")
        # Only keep the tail when it looks like a short style suffix.
        if sep and len(tail) < max_length // 3:
            suffix = ", " + tail
            room = max_length - len(suffix) - 5  # 5 chars for " ... "
            if room > 100:
                return text[:room] + " ... " + suffix

    # Plain truncation, preferring to cut at a word boundary.
    clipped = text[:max_length - 3]
    cut = clipped.rfind(" ")
    if cut > max_length // 2:
        clipped = clipped[:cut]
    return clipped + "..."
def split_prompt_parts(full_prompt):
    """
    Split a long prompt into (core, style).

    *core* is the subject description; *style* is the trailing run of
    comma-separated style modifiers (e.g. "photorealistic, 8K") that can be
    reused across prompts. Returns (full_prompt, "") when no style section
    is detected.
    """
    # Keywords that typically begin the style section at the end of a prompt.
    style_markers = [
        "photorealistic", "cinematic", "4K", "8K", "detailed",
        "shot on", "lens", "lighting", "consistent", "camera",
        "high quality", "professional", "dramatic",
    ]
    lower = full_prompt.lower()
    best_split = len(full_prompt)
    for marker in style_markers:
        # BUG FIX: lowercase the marker too — "4K"/"8K" contain uppercase
        # and could never match the lowercased prompt before.
        idx = lower.rfind(marker.lower())
        # Only treat a marker as style when it sits in the back half.
        if idx > len(full_prompt) // 2:
            # Split at the comma that introduces this style section.
            comma_idx = full_prompt.rfind(", ", 0, idx)
            if comma_idx > len(full_prompt) // 3:
                best_split = min(best_split, comma_idx)
    if best_split < len(full_prompt):
        core = full_prompt[:best_split].strip().rstrip(",")
        style = full_prompt[best_split:].strip().lstrip(",").strip()
        return core, style
    return full_prompt, ""
| # ============================================ | |||
| # BASE PROVIDER CLASS | |||
| # ============================================ | |||
class BaseProvider(ABC):
    """Common interface for all image-generation API providers."""

    name = "base"                    # registry key
    display_name = "Base Provider"   # human-readable name for UIs
    website = ""                     # where to obtain an API key
    supports_img2img = False
    supports_negative_prompt = True
    default_model = ""
    available_models = []
    requires_package = ""            # pip package needed; "" = plain HTTP only
    max_prompt_length = 10000        # default generous limit

    def __init__(self, api_key=""):
        self.api_key = api_key.strip()

    def _truncate(self, prompt, max_len=None):
        """Truncate *prompt* to *max_len* (default: this provider's limit)."""
        limit = max_len or self.max_prompt_length
        return smart_truncate(prompt, limit)

    def generate_image(
        self, prompt, negative_prompt="",
        width=1024, height=1024,
        seed=None, model=None, **kwargs,
    ):
        """Generate a PIL image. Subclasses must override.

        BUG FIX: the base implementation used to be `pass`, silently
        returning None; now it fails loudly so a provider that forgot to
        implement generation is caught immediately.
        """
        raise NotImplementedError(
            type(self).__name__ + " must implement generate_image()"
        )

    def img2img(
        self, prompt, image, strength=0.4,
        negative_prompt="", seed=None,
        model=None, **kwargs,
    ):
        """Default img2img fallback: plain text-to-image at the input's size."""
        return self.generate_image(
            prompt=prompt, negative_prompt=negative_prompt,
            width=image.width, height=image.height,
            seed=seed, model=model, **kwargs,
        )

    # --- shared conversion helpers -----------------------------------------
    # BUG FIX: these were defined without self/@staticmethod, yet every
    # provider invokes them as self._base64_to_image(...) etc., which would
    # mis-bind the first argument. Static methods match all call sites.

    @staticmethod
    def _image_to_base64(img, fmt="PNG"):
        """Encode a PIL image as a base64 string."""
        buf = io.BytesIO()
        img.save(buf, format=fmt)
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    @staticmethod
    def _base64_to_image(b64):
        """Decode a base64 string into an RGB PIL image."""
        data = base64.b64decode(b64)
        return Image.open(io.BytesIO(data)).convert("RGB")

    @staticmethod
    def _url_to_image(url):
        """Download *url* and return it as an RGB PIL image."""
        resp = requests.get(url, timeout=120)
        resp.raise_for_status()
        return Image.open(io.BytesIO(resp.content)).convert("RGB")

    @staticmethod
    def _bytes_to_image(data):
        """Wrap raw image bytes as an RGB PIL image."""
        return Image.open(io.BytesIO(data)).convert("RGB")
| # ============================================ | |||
| # 1. OPENAI (DALL-E 3 / gpt-image-1) | |||
| # ============================================ | |||
class OpenAIProvider(BaseProvider):
    """OpenAI image generation (DALL-E 2/3 and gpt-image-1)."""

    name = "openai"
    display_name = "OpenAI (DALL-E 3 / gpt-image-1)"
    website = "https://platform.openai.com/api-keys"
    supports_img2img = False
    supports_negative_prompt = False
    default_model = "dall-e-3"
    available_models = ["dall-e-3", "dall-e-2", "gpt-image-1"]
    requires_package = "openai"
    max_prompt_length = 3900  # DALL-E 3 limit is 4000

    # Per-model prompt caps, kept safely below the official limits
    # (DALL-E 2: 1000, gpt-image-1: much higher, DALL-E 3: 4000).
    _PROMPT_LIMITS = {"dall-e-2": 900, "gpt-image-1": 32000}

    # Resolutions the API accepts; anything else falls back to 1024x1024.
    _SIZES = {
        (1024, 1024): "1024x1024", (1792, 1024): "1792x1024",
        (1024, 1792): "1024x1792", (512, 512): "512x512",
        (256, 256): "256x256",
    }

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Generate one image; negative_prompt and seed are not supported upstream."""
        openai_mod = _require_import("openai")
        client = openai_mod.OpenAI(api_key=self.api_key)
        model = model or self.default_model
        safe_prompt = self._truncate(prompt, self._PROMPT_LIMITS.get(model, 3900))
        size = self._SIZES.get((width, height), "1024x1024")
        if model == "gpt-image-1":
            # gpt-image-1 chooses its own response format; handle both shapes.
            response = client.images.generate(
                model="gpt-image-1", prompt=safe_prompt, n=1, size=size,
            )
            first = response.data[0]
            if getattr(first, "b64_json", None):
                return self._base64_to_image(first.b64_json)
            return self._url_to_image(first.url)
        api_kwargs = dict(
            model=model, prompt=safe_prompt, n=1, size=size,
            response_format="b64_json",
        )
        if model == "dall-e-3":
            api_kwargs["quality"] = kwargs.get("quality", "hd")
            api_kwargs["style"] = kwargs.get("style", "natural")
        response = client.images.generate(**api_kwargs)
        return self._base64_to_image(response.data[0].b64_json)
| # ============================================ | |||
| # 2. STABILITY AI | |||
| # ============================================ | |||
class StabilityProvider(BaseProvider):
    """Stability AI v2beta REST API (SD3.x and Stable Image Core/Ultra)."""

    name = "stability"
    display_name = "Stability AI (SD3 / SDXL)"
    website = "https://platform.stability.ai/account/keys"
    supports_img2img = True
    supports_negative_prompt = True
    default_model = "sd3.5-large"
    available_models = [
        "sd3.5-large", "sd3.5-large-turbo", "sd3.5-medium",
        "sd3-large", "sd3-large-turbo", "sd3-medium",
        "stable-image-core", "stable-image-ultra",
    ]
    requires_package = ""  # plain HTTP via requests
    max_prompt_length = 10000
    API_BASE = "https://api.stability.ai/v2beta"

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Text-to-image via the v2beta endpoints; returns an RGB PIL image."""
        model = model or self.default_model
        safe_prompt = self._truncate(prompt)
        headers = {"Authorization": "Bearer " + self.api_key, "Accept": "image/*"}
        # NOTE(review): the v2beta endpoints size output via aspect_ratio;
        # width/height may be ignored server-side — confirm against the API docs.
        data = {"prompt": safe_prompt, "output_format": "png", "width": width, "height": height}
        if negative_prompt:
            data["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            data["seed"] = seed
        if "stable-image" in model:
            # core/ultra have dedicated endpoints with no model field.
            url = self.API_BASE + "/stable-image/generate/" + model.replace("stable-image-", "")
        else:
            url = self.API_BASE + "/stable-image/generate/sd3"
            data["model"] = model
        # The API requires multipart form encoding even with no file part.
        resp = requests.post(url, headers=headers, files={"none": ""}, data=data, timeout=120)
        resp.raise_for_status()
        return self._bytes_to_image(resp.content)

    def img2img(self, prompt, image, strength=0.4, negative_prompt="",
                seed=None, model=None, **kwargs):
        """Image-to-image on the SD3 endpoint; *strength* controls divergence from *image*."""
        safe_prompt = self._truncate(prompt)
        headers = {"Authorization": "Bearer " + self.api_key, "Accept": "image/*"}
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        buf.seek(0)
        data = {
            "prompt": safe_prompt, "strength": strength,
            "output_format": "png", "mode": "image-to-image",
        }
        # BUG FIX: the caller's model choice was silently dropped in img2img;
        # pass it through so SD3 variants are honored like in generate_image.
        if model:
            data["model"] = model
        if negative_prompt:
            data["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            data["seed"] = seed
        files = {"image": ("input.png", buf, "image/png")}
        url = self.API_BASE + "/stable-image/generate/sd3"
        resp = requests.post(url, headers=headers, files=files, data=data, timeout=120)
        resp.raise_for_status()
        return self._bytes_to_image(resp.content)
| # ============================================ | |||
| # 3. REPLICATE | |||
| # ============================================ | |||
class ReplicateProvider(BaseProvider):
    """Replicate-hosted models (Flux, SDXL, or any model id)."""

    name = "replicate"
    display_name = "Replicate (Flux / SDXL / Any)"
    website = "https://replicate.com/account/api-tokens"
    supports_img2img = True
    supports_negative_prompt = True
    default_model = "black-forest-labs/flux-1.1-pro"
    available_models = [
        "black-forest-labs/flux-1.1-pro",
        "black-forest-labs/flux-schnell",
        "black-forest-labs/flux-dev",
        "stability-ai/sdxl:latest",
        "stability-ai/stable-diffusion-3.5-large",
        "bytedance/sdxl-lightning-4step:latest",
    ]
    requires_package = "replicate"
    max_prompt_length = 10000

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Run a text-to-image model and download the resulting URL."""
        replicate_mod = _require_import("replicate")
        client = replicate_mod.Client(api_token=self.api_key)
        model_id = model or self.default_model
        payload = {"prompt": self._truncate(prompt), "width": width, "height": height}
        # Flux models do not accept a negative prompt.
        if negative_prompt and "flux" not in model_id.lower():
            payload["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            payload["seed"] = seed
        output = client.run(model_id, input=payload)
        # The client may return a list of URLs, a FileOutput-like object, or a string.
        if isinstance(output, list):
            result_url = str(output[0])
        elif hasattr(output, "url"):
            result_url = output.url
        else:
            result_url = str(output)
        return self._url_to_image(result_url)

    def img2img(self, prompt, image, strength=0.4, negative_prompt="",
                seed=None, model=None, **kwargs):
        """Image-to-image with the source uploaded as an in-memory PNG."""
        replicate_mod = _require_import("replicate")
        client = replicate_mod.Client(api_token=self.api_key)
        model_id = model or "stability-ai/sdxl:latest"
        png = io.BytesIO()
        image.save(png, format="PNG")
        png.seek(0)
        payload = {"prompt": self._truncate(prompt), "image": png, "prompt_strength": strength}
        if negative_prompt:
            payload["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            payload["seed"] = seed
        output = client.run(model_id, input=payload)
        result_url = str(output[0]) if isinstance(output, list) else str(output)
        return self._url_to_image(result_url)
| # ============================================ | |||
| # 4. TOGETHER AI | |||
| # ============================================ | |||
class TogetherProvider(BaseProvider):
    """Together AI image endpoint (Flux and SDXL families)."""

    name = "together"
    display_name = "Together AI (Flux / SDXL)"
    website = "https://api.together.xyz/settings/api-keys"
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "black-forest-labs/FLUX.1.1-pro"
    available_models = [
        "black-forest-labs/FLUX.1.1-pro",
        "black-forest-labs/FLUX.1-schnell-Free",
        "black-forest-labs/FLUX.1-dev",
        "stabilityai/stable-diffusion-xl-base-1.0",
    ]
    requires_package = "together"
    max_prompt_length = 10000

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Generate one image and decode it from the base64 response."""
        together_mod = _require_import("together")
        client = together_mod.Together(api_key=self.api_key)
        request = {
            "model": model or self.default_model,
            "prompt": self._truncate(prompt),
            "width": width,
            "height": height,
            "steps": kwargs.get("steps", 28),
            "n": 1,
            "response_format": "b64_json",
        }
        if negative_prompt:
            request["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            request["seed"] = seed
        response = client.images.generate(**request)
        return self._base64_to_image(response.data[0].b64_json)
| # ============================================ | |||
| # 5. FAL.AI | |||
| # ============================================ | |||
class FalProvider(BaseProvider):
    """Fal.ai hosted models via the fal_client subscribe API."""

    name = "fal"
    display_name = "Fal.ai (Flux Pro / Fast SDXL)"
    website = "https://fal.ai/dashboard/keys"
    supports_img2img = True
    supports_negative_prompt = True
    default_model = "fal-ai/flux-pro/v1.1"
    available_models = [
        "fal-ai/flux-pro/v1.1", "fal-ai/flux/dev", "fal-ai/flux/schnell",
        "fal-ai/flux-realism", "fal-ai/fast-sdxl",
        "fal-ai/stable-diffusion-v35-large", "fal-ai/recraft-v3",
    ]
    requires_package = "fal_client"
    max_prompt_length = 10000

    def _submit(self, model_id, arguments, failure_msg):
        """Run *model_id* synchronously on Fal and return the first image."""
        fal_client = _require_import("fal_client", "fal-client")
        # fal_client reads the credential from the environment.
        os.environ["FAL_KEY"] = self.api_key
        result = fal_client.subscribe(model_id, arguments=arguments)
        images = result.get("images", [])
        if not images:
            raise ValueError(failure_msg)
        return self._url_to_image(images[0]["url"])

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Text-to-image via fal_client.subscribe."""
        arguments = {
            "prompt": self._truncate(prompt),
            "image_size": {"width": width, "height": height},
            "num_images": 1,
        }
        if negative_prompt:
            arguments["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            arguments["seed"] = seed
        return self._submit(model or self.default_model, arguments,
                            "No image returned from Fal.ai")

    def img2img(self, prompt, image, strength=0.4, negative_prompt="",
                seed=None, model=None, **kwargs):
        """Image-to-image: the source image is inlined as a base64 data URI."""
        arguments = {
            "prompt": self._truncate(prompt),
            "image_url": "data:image/png;base64," + self._image_to_base64(image),
            "strength": strength,
            "num_images": 1,
        }
        if negative_prompt:
            arguments["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            arguments["seed"] = seed
        return self._submit(model or "fal-ai/flux/dev/image-to-image", arguments,
                            "No image from Fal.ai img2img")
| # ============================================ | |||
| # 6. GOOGLE GEMINI | |||
| # ============================================ | |||
class GoogleGeminiProvider(BaseProvider):
    """Google Imagen 3 via the google-generativeai SDK."""
    name = "google"
    display_name = "Google Gemini (Imagen 3)"
    website = "https://aistudio.google.com/apikey"
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "imagen-3.0-generate-002"
    available_models = ["imagen-3.0-generate-002", "imagen-3.0-fast-generate-001"]
    requires_package = "google.generativeai"
    max_prompt_length = 5000
    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Generate one image with Imagen and return it as an RGB PIL image.

        NOTE(review): width/height and seed are accepted for interface parity
        but are not forwarded to the API by this implementation.
        """
        genai = _require_import("google.generativeai", "google-generativeai")
        genai.configure(api_key=self.api_key)
        model_id = model or self.default_model
        safe_prompt = self._truncate(prompt)
        # NOTE(review): ImageGenerationModel / generate_images may not exist in
        # current google-generativeai releases — confirm against the SDK docs.
        imagen = genai.ImageGenerationModel(model_id)
        params = dict(prompt=safe_prompt, number_of_images=1)
        if negative_prompt:
            params["negative_prompt"] = smart_truncate(negative_prompt, 2000)
        response = imagen.generate_images(**params)
        if response.images:
            # Relies on the SDK's private _pil_image attribute — fragile.
            return response.images[0]._pil_image.convert("RGB")
        raise ValueError("No image returned from Imagen")
| # ============================================ | |||
| # 7. HUGGING FACE INFERENCE API | |||
| # ============================================ | |||
class HuggingFaceProvider(BaseProvider):
    """HuggingFace serverless Inference API (text-to-image pipelines)."""

    name = "huggingface"
    display_name = "HuggingFace Inference API"
    # BUG FIX: restored URL that had been mangled by HTML/URL-escaping in the
    # source ("huggingface.co" scrape-proxy domain).
    website = "https://huggingface.co/settings/tokens"
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "black-forest-labs/FLUX.1-schnell"
    available_models = [
        "black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-dev",
        "stabilityai/stable-diffusion-xl-base-1.0",
        "stabilityai/stable-diffusion-3.5-large",
        "runwayml/stable-diffusion-v1-5",
    ]
    requires_package = ""
    max_prompt_length = 10000
    # BUG FIX: this constant was corrupted into URL-encoded HTML residue
    # ("https%3A%2F%2Fapi-inference...%26quot%3B..."); reconstructed from it.
    API_BASE = "/static-proxy?url=https%3A%2F%2Fapi-inference.huggingface.co%2Fmodels%26quot%3B%3C%2Fspan%3E%3C%2Fcode%3E%3C%2Fpre%3E%3Cpre%20class%3D%22antml-code-block__code%20antml-code-block__code--pre-wrap%22%3E%3Ccode%3E%3Cspan%20class%3D%22antml-inline-code%22%3E%0A%20%20%20%20def%20generate_image(self%2C%20prompt%2C%20negative_prompt%3D%22%22%2C%20width%3D1024%2C%20height%3D1024%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20seed%3DNone%2C%20model%3DNone%2C%20**kwargs)%3A%0A%20%20%20%20%20%20%20%20%22%22%22POST%20to%20the%20inference%20endpoint%3B%20retries%20once%20after%20a%20503%20cold%20start.%22%22%22%0A%20%20%20%20%20%20%20%20model_id%20%3D%20model%20or%20self.default_model%0A%20%20%20%20%20%20%20%20url%20%3D%20self.API_BASE%20%2B%20%22%2F%22%20%2B%20model_id%0A%20%20%20%20%20%20%20%20headers%20%3D%20%7B%22Authorization%22%3A%20%22Bearer%20%22%20%2B%20self.api_key%7D%0A%20%20%20%20%20%20%20%20payload%20%3D%20%7B%22inputs%22%3A%20self._truncate(prompt)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22parameters%22%3A%20%7B%22width%22%3A%20width%2C%20%22height%22%3A%20height%7D%7D%0A%20%20%20%20%20%20%20%20if%20negative_prompt%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20payload%5B%22parameters%22%5D%5B%22negative_prompt%22%5D%20%3D%20smart_truncate(negative_prompt%2C%205000)%0A%20%20%20%20%20%20%20%20if%20seed%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20payload%5B%22parameters%22%5D%5B%22seed%22%5D%20%3D%20seed%0A%20%20%20%20%20%20%20%20resp%20%3D%20requests.post(url%2C%20headers%3Dheaders%2C%20json%3Dpayload%2C%20timeout%3D180)%0A%20%20%20%20%20%20%20%20if%20resp.status_code%20%3D%3D%20503%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Model%20is%20cold-loading%20on%20the%20serverless%20tier%3B%20wait%20and%20retry%20once.%0A%20%20%20%20%20%20%20%20%20%20%20%20time.sleep(20)%0A%20%20%20%20%20%20%20%20%20%20%20%20resp%20%3D%20requests.post(url%2C%20headers%3Dheaders%2C%20json%3Dpayload%2C%20timeout%3D180)%0A%20%20%20%20%20%20%20%20resp.raise_for_status()%0A%20%20%20%20%20%20%20%20return%20self._bytes_to_image(resp.content)%0A%3C%2Fspan%3E%3C%2Fcode%3E%3C%2Fpre%3E%3Cp%3E%3Cspan%20class%3D%22antml-inline-code%22%3E%3C%2Fspan%3E%3C%2Fp%3E%3C%2Fantml%3A%E5%9B%BE%E7%89%87%3E"
| # ============================================ | |||
| # 8. xAI GROK | |||
| # ============================================ | |||
class XAIProvider(BaseProvider):
    """xAI Grok image generation through the OpenAI-compatible endpoint."""

    name = "xai"
    display_name = "xAI Grok (Aurora)"
    website = "https://console.x.ai/team/default/api-keys"
    supports_img2img = False
    supports_negative_prompt = False
    default_model = "grok-2-image"
    available_models = ["grok-2-image"]
    requires_package = "openai"
    max_prompt_length = 4000

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Generate one 1024x1024 image (the API exposes a single size)."""
        openai_mod = _require_import("openai")
        client = openai_mod.OpenAI(api_key=self.api_key, base_url="https://api.x.ai/v1")
        response = client.images.generate(
            model=model or self.default_model,
            prompt=self._truncate(prompt),
            n=1,
            response_format="b64_json",
            size="1024x1024",
        )
        return self._base64_to_image(response.data[0].b64_json)
| # ============================================ | |||
| # 9. FIREWORKS AI | |||
| # ============================================ | |||
class FireworksProvider(BaseProvider):
    """Fireworks AI image generations endpoint (OpenAI-style JSON)."""

    name = "fireworks"
    display_name = "Fireworks AI (Flux / SD)"
    website = "https://fireworks.ai/account/api-keys"
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "accounts/fireworks/models/flux-1-1-pro"
    available_models = [
        "accounts/fireworks/models/flux-1-1-pro",
        "accounts/fireworks/models/flux-1-schnell-fp8",
        "accounts/fireworks/models/flux-1-dev-fp8",
        "accounts/fireworks/models/stable-diffusion-xl-1024-v1-0",
    ]
    requires_package = ""
    max_prompt_length = 10000

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """POST to the generations endpoint and decode the base64 payload."""
        endpoint = "https://api.fireworks.ai/inference/v1/images/generations"
        headers = {"Authorization": "Bearer " + self.api_key,
                   "Content-Type": "application/json"}
        body = {
            "model": model or self.default_model,
            "prompt": self._truncate(prompt),
            "n": 1,
            "size": str(width) + "x" + str(height),
            "response_format": "b64_json",
        }
        if negative_prompt:
            body["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            body["seed"] = seed
        resp = requests.post(endpoint, headers=headers, json=body, timeout=120)
        resp.raise_for_status()
        return self._base64_to_image(resp.json()["data"][0]["b64_json"])
| # ============================================ | |||
| # 10. IDEOGRAM | |||
| # ============================================ | |||
class IdeogramProvider(BaseProvider):
    """Ideogram REST API (v1/v2 models)."""

    name = "ideogram"
    display_name = "Ideogram (v2 / v2-turbo)"
    website = "https://ideogram.ai/manage-api"
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "V_2"
    available_models = ["V_2", "V_2_TURBO", "V_1", "V_1_TURBO"]
    requires_package = ""
    max_prompt_length = 10000

    # Reduced (width, height) ratios -> Ideogram aspect-ratio enum values.
    _ASPECTS = {
        (1, 1): "ASPECT_1_1", (16, 9): "ASPECT_16_9", (9, 16): "ASPECT_9_16",
        (4, 3): "ASPECT_4_3", (3, 4): "ASPECT_3_4", (3, 2): "ASPECT_3_2",
        (2, 3): "ASPECT_2_3", (16, 10): "ASPECT_16_10", (10, 16): "ASPECT_10_16",
    }

    @classmethod
    def _aspect_for(cls, width, height):
        """Map requested pixel dimensions to the matching aspect enum (square fallback)."""
        import math
        g = math.gcd(width, height) or 1
        return cls._ASPECTS.get((width // g, height // g), "ASPECT_1_1")

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Generate one image; Ideogram sizes output by aspect ratio, not pixels."""
        url = "https://api.ideogram.ai/generate"
        headers = {"Api-Key": self.api_key, "Content-Type": "application/json"}
        payload = {
            "image_request": {
                "prompt": self._truncate(prompt),
                "model": model or self.default_model,
                "magic_prompt_option": "AUTO",
                # BUG FIX: width/height used to be ignored (always ASPECT_1_1);
                # now they select the matching ratio, defaulting to square.
                "aspect_ratio": self._aspect_for(width, height),
            }
        }
        if negative_prompt:
            payload["image_request"]["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            payload["image_request"]["seed"] = seed
        resp = requests.post(url, headers=headers, json=payload, timeout=120)
        resp.raise_for_status()
        return self._url_to_image(resp.json()["data"][0]["url"])
| # ============================================ | |||
| # 11. LEONARDO AI | |||
| # ============================================ | |||
class LeonardoProvider(BaseProvider):
    """Leonardo AI REST API with asynchronous generation + polling."""

    name = "leonardo"
    display_name = "Leonardo AI"
    website = "https://app.leonardo.ai/api-access"
    supports_img2img = False
    supports_negative_prompt = True
    # Leonardo identifies models by UUID, not by name.
    default_model = "6b645e3a-d64f-4341-a6d8-7a3690fbf042"
    available_models = [
        "6b645e3a-d64f-4341-a6d8-7a3690fbf042",
        "aa77f04e-3eec-4034-9c07-d0f619684628",
        "1e60896f-3c26-4296-8ecc-53e2afecc132",
    ]
    requires_package = ""
    max_prompt_length = 10000
    API_BASE = "https://cloud.leonardo.ai/api/rest/v1"

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Submit a generation job, then poll (up to ~150 s) for the result.

        Raises RuntimeError on a FAILED job, ValueError when a COMPLETE job
        has no images, and TimeoutError when the polling budget is exhausted.
        """
        headers = {"Authorization": "Bearer " + self.api_key, "Content-Type": "application/json"}
        payload = {
            "prompt": self._truncate(prompt), "modelId": model or self.default_model,
            "width": width, "height": height, "num_images": 1,
        }
        if negative_prompt:
            payload["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            payload["seed"] = seed
        resp = requests.post(self.API_BASE + "/generations", headers=headers, json=payload, timeout=60)
        resp.raise_for_status()
        gen_id = resp.json()["sdGenerationJob"]["generationId"]
        for _ in range(30):
            time.sleep(5)
            poll = requests.get(self.API_BASE + "/generations/" + gen_id, headers=headers, timeout=30)
            poll.raise_for_status()
            # `or {}` guards against an explicit null for generations_by_pk.
            gen = poll.json().get("generations_by_pk") or {}
            status = gen.get("status")
            if status == "COMPLETE":
                images = gen.get("generated_images", [])
                if images:
                    return self._url_to_image(images[0]["url"])
                raise ValueError("Leonardo generation completed without images")
            # BUG FIX: a FAILED job used to spin for the full timeout window
            # and then surface a misleading TimeoutError; fail fast instead.
            if status == "FAILED":
                raise RuntimeError("Leonardo generation failed")
        raise TimeoutError("Leonardo generation timed out")
| # ============================================ | |||
| # 12. CUSTOM OPENAI-COMPATIBLE | |||
| # ============================================ | |||
class CustomOpenAIProvider(BaseProvider):
    """Any OpenAI-compatible images API at a user-supplied base URL."""

    name = "custom_openai"
    display_name = "Custom OpenAI-Compatible API"
    website = ""
    supports_img2img = False
    supports_negative_prompt = False
    default_model = "dall-e-3"
    available_models = ["dall-e-3", "dall-e-2", "custom"]
    requires_package = "openai"
    max_prompt_length = 3900

    def __init__(self, api_key="", base_url=""):
        super().__init__(api_key)
        # An empty base_url falls back to the official OpenAI endpoint.
        self.base_url = base_url.strip().rstrip("/") if base_url else ""

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Generate one image through the configured endpoint."""
        openai_mod = _require_import("openai")
        client_kwargs = {"api_key": self.api_key}
        if self.base_url:
            client_kwargs["base_url"] = self.base_url
        client = openai_mod.OpenAI(**client_kwargs)
        response = client.images.generate(
            model=model or self.default_model,
            prompt=self._truncate(prompt),
            n=1,
            size=str(width) + "x" + str(height),
            response_format="b64_json",
        )
        return self._base64_to_image(response.data[0].b64_json)
| # ============================================ | |||
| # 13. DIRECT URL API | |||
| # ============================================ | |||
class DirectURLProvider(BaseProvider):
    """Generic JSON-over-HTTP provider for arbitrary REST image endpoints."""

    name = "direct_url"
    display_name = "Direct URL API (Any REST Endpoint)"
    website = ""
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "custom"
    available_models = ["custom"]
    requires_package = ""
    max_prompt_length = 50000

    def __init__(self, api_key="", endpoint_url=""):
        super().__init__(api_key)
        self.endpoint_url = endpoint_url.strip()

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """POST a JSON request and parse the image out of common response shapes.

        Accepts either a raw image body or JSON carrying a URL / base64 string
        under images/data/output/result (optionally nested in b64_json/url/image).
        """
        if not self.endpoint_url:
            raise ValueError("No endpoint URL provided")
        headers = {"Authorization": "Bearer " + self.api_key, "Content-Type": "application/json"}
        body = {"prompt": self._truncate(prompt), "width": width, "height": height}
        if negative_prompt:
            body["negative_prompt"] = smart_truncate(negative_prompt, 10000)
        if seed is not None:
            body["seed"] = seed
        if model and model != "custom":
            body["model"] = model
        resp = requests.post(self.endpoint_url, headers=headers, json=body, timeout=180)
        resp.raise_for_status()
        if "image" in resp.headers.get("Content-Type", ""):
            return self._bytes_to_image(resp.content)
        parsed = resp.json()
        # Probe well-known top-level keys in priority order.
        for key in ["images", "data", "output", "result"]:
            if key not in parsed:
                continue
            item = parsed[key]
            if isinstance(item, list):
                item = item[0]
            if isinstance(item, dict):
                for subkey in ["b64_json", "url", "image"]:
                    if subkey in item:
                        val = item[subkey]
                        if isinstance(val, str) and val.startswith("http"):
                            return self._url_to_image(val)
                        return self._base64_to_image(val)
            if isinstance(item, str):
                if item.startswith("http"):
                    return self._url_to_image(item)
                return self._base64_to_image(item)
        raise ValueError("Could not parse image from API response")
| # ============================================ | |||
| # PROVIDER REGISTRY | |||
| # ============================================ | |||
# Registry mapping provider key -> provider class; keys match each class's
# `name` attribute and are what get_provider()/get_models_for_provider() expect.
PROVIDERS = {
    "openai": OpenAIProvider,
    "stability": StabilityProvider,
    "replicate": ReplicateProvider,
    "together": TogetherProvider,
    "fal": FalProvider,
    "google": GoogleGeminiProvider,
    "huggingface": HuggingFaceProvider,
    "xai": XAIProvider,
    "fireworks": FireworksProvider,
    "ideogram": IdeogramProvider,
    "leonardo": LeonardoProvider,
    "custom_openai": CustomOpenAIProvider,
    "direct_url": DirectURLProvider,
}
# Reverse lookup: human-readable display name -> registry key (for UI dropdowns).
PROVIDER_DISPLAY_NAMES = {cls.display_name: key for key, cls in PROVIDERS.items()}
def get_provider(provider_name, api_key, **kwargs):
    """Instantiate the provider registered under *provider_name*.

    custom_openai and direct_url accept extra constructor kwargs
    (base_url / endpoint_url); every other provider takes only the key.
    Raises ValueError for an unknown provider name.
    """
    cls = PROVIDERS.get(provider_name)
    if cls is None:
        raise ValueError("Unknown provider: " + str(provider_name))
    extra = {}
    if provider_name == "custom_openai":
        extra["base_url"] = kwargs.get("base_url", "")
    elif provider_name == "direct_url":
        extra["endpoint_url"] = kwargs.get("endpoint_url", "")
    return cls(api_key=api_key, **extra)
def get_provider_info():
    """Return a list of metadata dicts describing every registered provider."""
    catalog = []
    for key, cls in PROVIDERS.items():
        pkg = cls.requires_package
        # Only the top-level package matters for the install check
        # (e.g. "google.generativeai" -> "google").
        installed = True if not pkg else _safe_import(pkg.split(".")[0]) is not None
        catalog.append({
            "name": key,
            "display_name": cls.display_name,
            "website": cls.website,
            "supports_img2img": cls.supports_img2img,
            "default_model": cls.default_model,
            "available_models": cls.available_models,
            "requires_package": pkg,
            "package_installed": installed,
            "max_prompt_length": cls.max_prompt_length,
        })
    return catalog
def get_models_for_provider(provider_name):
    """Return the model id list for *provider_name*, or [] when unknown."""
    cls = PROVIDERS.get(provider_name)
    if cls is None:
        return []
    return cls.available_models