"""Build a LangGraph ReAct agent and stream its output as Gradio chat messages."""

from __future__ import annotations

from typing import TYPE_CHECKING, AsyncIterator

import gradio as gr
from langchain.chat_models import init_chat_model
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langgraph.prebuilt import create_react_agent

if TYPE_CHECKING:
    from langchain_core.language_models.chat_models import BaseChatModel
    from langgraph.graph.graph import CompiledGraph

# Message representations accepted by this module: a LangChain message, a
# Gradio ChatMessage, or a plain dict with "role"/"content" keys.
# NOTE(review): dict values may not all be str (e.g. extra metadata keys) —
# the annotation is kept as-is for compatibility.
MESSAGE_TYPE = BaseMessage | gr.ChatMessage | dict[str, str]

# UI provider label -> (init_chat_model provider prefix, API-key kwarg name).
_PROVIDERS: dict[str, tuple[str, str]] = {
    "Anthropic": ("anthropic", "anthropic_api_key"),
    "Mistral": ("mistralai", "mistral_api_key"),
    "OpenAI": ("openai", "openai_api_key"),
}

# Displayed for an agent turn that contains no text (e.g. a pure tool call).
_NO_TEXT_PLACEHOLDER = "Calling tool(s)"


def create_agent(
    model_name: str, provider: str, api_key: str, tools: list
) -> CompiledGraph:
    """Create a ReAct agent backed by the requested chat model.

    Args:
        model_name: Provider-specific model identifier.
        provider: One of "Anthropic", "Mistral" or "OpenAI".
        api_key: API key forwarded to the provider's chat model.
        tools: Tools the agent may call.

    Returns:
        The compiled LangGraph ReAct agent.

    Raises:
        ValueError: If ``provider`` is not supported.
    """
    model = _create_model(model_name, provider, api_key)
    return create_react_agent(model, tools=tools)


async def call_agent(
    agent: CompiledGraph, messages: list[MESSAGE_TYPE], prompt: HumanMessage
) -> AsyncIterator[list[MESSAGE_TYPE]]:
    """Stream the agent's response to ``prompt``, appending it to ``messages``.

    The model input is every entry of ``messages`` except the last (converted
    to LangChain messages), followed by ``prompt``. The dropped last entry is
    presumably the chat UI's own rendering of ``prompt`` — confirm with callers.

    ``messages`` is mutated in place. Each tool result and each agent turn is
    appended as an assistant ``gr.ChatMessage``, and the same list object is
    yielded after every append so the UI can re-render incrementally.

    Args:
        agent: Compiled agent returned by :func:`create_agent`.
        messages: Chat history; extended in place with the agent's output.
        prompt: The new user message sent to the agent.

    Yields:
        ``messages`` itself, after each appended assistant message.
    """
    # Convert the history before streaming starts, i.e. before any appends.
    history = [_convert_to_langchain_message(msg) for msg in messages[:-1]]
    async for chunk in agent.astream({"messages": [*history, prompt]}):
        if "tools" in chunk:
            for step in chunk["tools"]["messages"]:
                messages.append(
                    gr.ChatMessage(
                        role="assistant",
                        content=step.content,
                        metadata={"title": f"🛠️ Used tool {step.name}"},
                    )
                )
                yield messages
        if "agent" in chunk:
            messages.append(
                gr.ChatMessage(
                    role="assistant",
                    content=_get_chunk_message_content(chunk),
                )
            )
            yield messages


def _create_model(model_name: str, provider: str, api_key: str) -> BaseChatModel:
    """Instantiate the chat model for ``provider`` and ``model_name``.

    Raises:
        ValueError: If ``provider`` is not a key of ``_PROVIDERS``.
    """
    try:
        prefix, key_kwarg = _PROVIDERS[provider]
    except KeyError:
        raise ValueError(f"Unsupported model provider: {provider}") from None
    # Each provider's model class takes the key under its own kwarg name.
    return init_chat_model(f"{prefix}:{model_name}", **{key_kwarg: api_key})


def _is_ai_message(message: MESSAGE_TYPE) -> bool:
    """Return True if ``message`` was authored by the assistant."""
    if isinstance(message, AIMessage):
        return True
    if isinstance(message, gr.ChatMessage):
        return message.role == "assistant"
    if isinstance(message, dict):
        return message.get("role") == "assistant"
    return False


def _convert_to_langchain_message(message: MESSAGE_TYPE) -> BaseMessage:
    """Convert a supported message representation into a LangChain message.

    LangChain messages are returned unchanged. Gradio ChatMessages and dicts
    become an ``AIMessage`` when their role is "assistant", and a
    ``HumanMessage`` otherwise. A dict without "content" yields empty content.

    Raises:
        ValueError: If ``message`` is of an unsupported type.
    """
    if isinstance(message, BaseMessage):
        return message
    if isinstance(message, gr.ChatMessage):
        content = message.content
    elif isinstance(message, dict):
        content = message.get("content", "")
    else:
        raise ValueError(f"Unsupported message type: {type(message)}")
    message_cls = AIMessage if _is_ai_message(message) else HumanMessage
    return message_cls(content=content)


def _get_chunk_message_content(chunk: dict) -> str:
    """Extract displayable text from an ``"agent"`` stream chunk.

    Reads ``chunk["agent"]["messages"][0].content``, which is either a string
    or a list of content blocks (strings, or dicts that may carry a "text"
    key). All text parts are concatenated in order. Text that follows a
    non-text block (e.g. a tool-use or reasoning block listed first) is
    therefore kept rather than dropped.

    Returns:
        The combined text, or "Calling tool(s)" when there is none.
    """
    content = chunk["agent"]["messages"][0].content
    if isinstance(content, list):
        parts: list[str] = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict):
                text = block.get("text")
                if isinstance(text, str):
                    parts.append(text)
        content = "".join(parts)
    return content or _NO_TEXT_PLACEHOLDER