# websearch-exp/server/websearch_tool.py
# Copyright 2025 Yuan He. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Inspired by: https://github.com/THUDM/slime/tree/main/examples/search-r1
from __future__ import annotations

import asyncio
import random
import traceback

import aiohttp
import chardet

from models import WebContent, WebSearchAction, WebSearchObservation


class WebSearchTool:
    """A tool for searching the web using Google Search API (via Serper.dev)."""

    def __init__(
self,
api_key: str | None = None,
top_k: int = 5,
timeout: int = 60,
snippet_only: bool = False,
proxy: str | None = None,
):
self.api_key = api_key
self.top_k = top_k
self.timeout = timeout
self.snippet_only = snippet_only
self.proxy = proxy

    async def execute(self, web_search_action: WebSearchAction) -> WebSearchObservation:
        """
        Execute a web search based on the action's query.

        Returns a `WebSearchObservation` whose `content` holds the formatted
        results on success, or an "[ERROR] ..." message (with details in
        `metadata`) when the search fails or returns nothing.
        """
query = web_search_action.query.strip()
try:
web_contents = await self.google_search(
api_key=self.api_key,
query=query,
top_k=self.top_k,
timeout=self.timeout,
snippet_only=self.snippet_only,
)
if web_contents:
return WebSearchObservation(
content=self.format_web_contents(web_contents, query),
web_contents=web_contents,
done=False,
metadata={"query": query},
)
else:
return WebSearchObservation(
content=f"[ERROR] No search results found for query: {query}",
web_contents=[],
done=False,
metadata={"query": query, "error": "No search results found"},
)
        except Exception as e:
            tb_str = traceback.format_exc()
return WebSearchObservation(
content=f"[ERROR] Search failed due to: {str(e)}\nTraceback:\n{tb_str}",
web_contents=[],
done=False,
metadata={"query": query, "error": str(e), "traceback": tb_str},
)

    async def google_search(
        self,
        api_key: str | None,
        query: str,
        top_k: int = 5,
        timeout: int = 60,
        snippet_only: bool = False,
    ) -> list[WebContent]:
        """
        Perform a Google search using the Serper.dev API.

        Args:
            api_key: Serper.dev API key.
            query: Search query string.
            top_k: Number of results to return.
            timeout: Request timeout in seconds.
            snippet_only: If `True`, return only snippets; if `False`, fetch
                full webpage content.

        Returns:
            list[WebContent]: Search results with titles, content, and URLs.
        """
        if not api_key:
            raise ValueError("A Serper.dev API key is required to perform a web search.")
timeout_obj = aiohttp.ClientTimeout(total=timeout)
        session_kwargs = {}
        if self.proxy:
            # Session-level `proxy` was added in aiohttp 3.10; on older
            # versions, pass `proxy=` to each request instead.
            session_kwargs["proxy"] = self.proxy
        async with aiohttp.ClientSession(**session_kwargs) as session:
async with session.post(
"https://google.serper.dev/search",
json={
"q": query,
"num": top_k,
"gl": "us",
"hl": "en",
},
headers={
"Content-Type": "application/json",
"X-API-KEY": api_key,
},
timeout=timeout_obj,
) as resp:
resp.raise_for_status()
response = await resp.json()
items = response.get("organic", [])
web_contents = []
if snippet_only:
# Quick mode: just use snippets
for item in items:
title = item.get("title", "")
snippet = item.get("snippet", "")
context = " ".join(self.parse_search_snippet(snippet))
if title or context:
title = title or "No title."
context = context or "No snippet available."
web_contents.append(WebContent(title=title, content=context, url=item.get("link", "")))
else:
# Deep mode: fetch full page content
links = [item.get("link", "") for item in items if "link" in item]
raw_contents = await self.fetch_web_contents(links)
for i, item in enumerate(items):
title = item.get("title", "")
snippet = item.get("snippet", "")
# Extract relevant context from the full page
context = self.expand_search_snippet(snippet, raw_contents[i]) if i < len(raw_contents) and raw_contents[i] else snippet
if title or context:
title = title or "No title."
context = context or "No content available."
web_contents.append(WebContent(title=title, content=context, url=item.get("link", "")))
return web_contents
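
    # For reference, a Serper.dev /search response nests results under an
    # "organic" key, shaped roughly like
    #   {"organic": [{"title": "...", "link": "https://...", "snippet": "..."}, ...]};
    # the parsing above reads exactly the "title", "snippet", and "link" fields.
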
    @staticmethod
    async def fetch_web_contents(urls: list[str], limit: int = 8) -> list[str]:
        """
        Fetch multiple web pages concurrently, capped by a semaphore.

        Args:
            urls (list[str]): List of URLs to fetch.
            limit (int): Maximum number of concurrent requests.

        Returns:
            list[str]: List of page contents (empty string for failed requests).
        """
async def _fetch(url: str, session: aiohttp.ClientSession, semaphore: asyncio.Semaphore) -> str:
if url == "":
return ""
user_agents = [
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
"Mozilla/5.0 (compatible; Googlebot/2.1; +https://www.google.com/bot.html)",
]
headers = {"User-Agent": random.choice(user_agents)}
async with semaphore:
try:
async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=10)) as response:
raw = await response.read()
detected = chardet.detect(raw)
encoding = detected.get("encoding") or "utf-8"
return raw.decode(encoding, errors="ignore")
                except Exception:
                    # Silently skip individual pages that fail to download or decode.
                    return ""

        semaphore = asyncio.Semaphore(limit)
timeout = aiohttp.ClientTimeout(total=10)
connector = aiohttp.TCPConnector(limit_per_host=limit, force_close=True)
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
tasks = [_fetch(url, session, semaphore) for url in urls]
return await asyncio.gather(*tasks)
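
    # e.g. `htmls = await WebSearchTool.fetch_web_contents(urls)`: results come
    # back in input order, with "" at the position of any blank or failed URL.
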
    @staticmethod
    def parse_search_snippet(snippet: str) -> list[str]:
        """
        Parse a search snippet into meaningful segments.

        Args:
            snippet: The snippet text with ellipsis ("...") separators.

        Returns:
            List of text segments containing more than five words.
        """
        segments = snippet.split("...")
        return [s.strip() for s in segments if len(s.strip().split()) > 5]
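
    # e.g. parse_search_snippet("Apollo 11 landed on the Moon in July 1969 ... read more")
    # -> ["Apollo 11 landed on the Moon in July 1969"]; the trailing segment is
    # dropped because it has five or fewer words.
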
    @staticmethod
    def expand_search_snippet(snippet: str, web_content: str) -> str:
        """
        Find snippet segments in the web content and expand them to full paragraphs.

        Args:
            snippet (str): The search snippet with key phrases.
            web_content (str): The full web content text.

        Returns:
            str: The expanded full context of the snippet, one paragraph per segment.
        """
snippets = WebSearchTool.parse_search_snippet(snippet)
ctx_paras = []
for s in snippets:
            # Locate the segment in the document. Replacing "\n" with " "
            # preserves string length, so `pos` is also valid as an index into
            # the original `web_content`.
            pos = web_content.replace("\n", " ").find(s)
if pos == -1:
continue
# Expand to paragraph boundaries
sta = pos
while sta > 0 and web_content[sta] != "\n":
sta -= 1
end = pos + len(s)
while end < len(web_content) and web_content[end] != "\n":
end += 1
para = web_content[sta:end].strip()
if para and para not in ctx_paras:
ctx_paras.append(para)
return "\n".join(ctx_paras)
    @staticmethod
    def format_web_contents(web_contents: list[WebContent], query: str) -> str:
        """
        Format search results into a readable string.

        Args:
            web_contents (list[WebContent]): List of search results.
            query (str): Original search query.

        Returns:
            str: Formatted string representation of the results.
        """
lines = [f"Search results for: {query}\n"]
for i, result in enumerate(web_contents, 1):
lines.append(f"[{i}] {result.title}")
lines.append(f" URL: {result.url or 'N/A'}")
lines.append(f" {result.content[:500]}{'...' if len(result.content) > 500 else ''}")
lines.append("")
return "\n".join(lines)