# huggingface_ai_final / vision_tools.py
"""
Vision tools: image captioning using Hugging Face Inference API with a local fallback.
Functions:
- `caption_image(path)`: returns a short caption for an image file.
"""
from typing import Optional  # NOTE(review): currently unused; kept for API stability
import os
import logging
# Module-level logger named after the module (standard logging convention).
logger = logging.getLogger(__name__)
# Optional dependencies: each import is guarded so the module still loads
# when a library is missing; callers test the bound name against None.
try:
    from huggingface_hub import InferenceApi
except Exception:
    InferenceApi = None
# transformers is only needed for the local captioning fallback.
try:
    from transformers import pipeline
except Exception:
    pipeline = None
# smolagents supplies the @tool decorator used to export an agent tool.
try:
    from smolagents import tool
except Exception:
    tool = None
def caption_image(path: str, model: str = "nlpconnect/vit-gpt2-image-captioning") -> str:
    """Caption an image file.

    Tries the Hugging Face Inference API first (requires the ``HF_TOKEN``
    environment variable and ``huggingface_hub``), then falls back to a
    local ``transformers`` pipeline if one is installed.

    Args:
        path: Filesystem path to the image.
        model: Hugging Face model id for image-to-text captioning.

    Returns:
        A short caption string, or a parenthesized status string such as
        "(file not found)" / "(image captioning unavailable)".
    """
    if not os.path.exists(path):
        return "(file not found)"
    hf_token = os.environ.get("HF_TOKEN")
    # Try Inference API first.
    # NOTE(review): InferenceApi is deprecated in recent huggingface_hub
    # releases in favor of InferenceClient; kept to avoid changing imports.
    if hf_token and InferenceApi is not None:
        try:
            client = InferenceApi(repo_id=model, token=hf_token)
            with open(path, "rb") as f:
                # BUGFIX: binary payloads must be sent as raw bytes via
                # `data=`. Passing the file object as `inputs=` gets
                # JSON-serialized and raises, so the remote path never worked.
                out = client(data=f.read())
            # InferenceApi for image-to-text may return a dict, a list of
            # dicts, or plain text — normalize all shapes to a string.
            if isinstance(out, dict) and "generated_text" in out:
                return out["generated_text"].strip()
            if isinstance(out, list) and out:
                first = out[0]
                if isinstance(first, dict) and "generated_text" in first:
                    return first["generated_text"].strip()
                return str(first)
            return str(out)
        except Exception as e:
            logger.warning("HF Inference image captioning failed: %s", e)
    # Local pipeline fallback (may not be installed or suitable for large models)
    if pipeline is not None:
        try:
            pipe = pipeline("image-to-text", model=model)
            res = pipe(path)
            if isinstance(res, list) and res:
                first = res[0]
                # Guard: only dicts have .get(); other shapes are stringified
                # instead of raising AttributeError into the warning path.
                if isinstance(first, dict):
                    return first.get("generated_text", str(first))
                return str(first)
            return str(res)
        except Exception as e:
            logger.warning("Local pipeline image captioning failed: %s", e)
    return "(image captioning unavailable)"
# Export a smolagents-wrapped tool if possible
if tool is not None:
try:
@tool
def caption_image_tool(path: str, model: str = "nlpconnect/vit-gpt2-image-captioning") -> str:
return caption_image(path, model=model)
except Exception:
caption_image_tool = caption_image
else:
caption_image_tool = caption_image
__all__ = ["caption_image", "caption_image_tool"]