czakop commited on
Commit
70d877e
·
1 Parent(s): 244ad23

create chess agent

Browse files
README.md CHANGED
@@ -1,6 +1,7 @@
1
  ---
 
2
  title: Chess Agent
3
- emoji: 🏆
4
  colorFrom: indigo
5
  colorTo: purple
6
  sdk: gradio
@@ -10,5 +11,3 @@ pinned: false
10
  license: apache-2.0
11
  short_description: Chess Agent
12
  ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ tags: [agents, agent-demo-track, chess, chessboard, games]
3
  title: Chess Agent
4
+ emoji: ♟️
5
  colorFrom: indigo
6
  colorTo: purple
7
  sdk: gradio
 
11
  license: apache-2.0
12
  short_description: Chess Agent
13
  ---
 
 
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+
4
+ import chess
5
+ import gradio as gr
6
+ from gradio_chessboard import Chessboard
7
+ from langchain_core.messages import HumanMessage
8
+ from utils.helpers import call_agent, create_agent
9
+ from utils.tools import create_base_tools, create_mcp_tools
10
+
11
# Default model configuration per provider, keyed by lowercase provider id.
# main() uses these to pre-fill the Gradio "Model Name" and "API Key" inputs
# when the user switches provider. API keys are read from the environment at
# import time; a missing variable falls back to "" (the user can paste a key
# into the UI instead).
MODEL_DEFAULTS = {
    "anthropic": {
        "model_name": "claude-sonnet-4-20250514",
        "provider": "anthropic",
        "api_key": os.getenv("ANTHROPIC_API_KEY", ""),
    },
    "mistral": {
        "model_name": "mistral-large-latest",
        "provider": "mistralai",
        "api_key": os.getenv("MISTRAL_API_KEY", ""),
    },
    "openai": {
        "model_name": "gpt-4o",
        "provider": "openai",
        "api_key": os.getenv("OPENAI_API_KEY", ""),
    },
}
28
+
29
+
30
async def main():
    """Build and launch the chess-agent Gradio app.

    Creates one chess board, wires up local board tools plus remote MCP
    tools, and serves a UI where a human plays against a LangGraph agent
    while seeing its tool calls and reasoning in a chat panel.
    """
    # NOTE(review): this single Board is shared by every session of the app —
    # concurrent visitors would interfere with each other's game. Confirm the
    # single-user assumption or move the board into per-session state.
    board = chess.Board()

    base_tools = create_base_tools(board)
    mcp_tools = await create_mcp_tools(
        url="https://czakop-chess-agent-mcp.hf.space/gradio_api/mcp/sse",
        transport="sse",
    )

    async def chat_entrypoint(prompt, messages, model_name, model_provider, api_key):
        """Entrypoint for the chat interaction."""
        # Echo the user's message immediately, then stream agent output.
        messages.append(gr.ChatMessage(role="user", content=prompt))
        yield messages, board.fen()

        # The agent receives the prompt augmented with the current position.
        real_prompt = HumanMessage(
            content=f"{prompt}\nCurrent board state: {board.fen()}"
        )
        agent = create_agent(
            model_name, model_provider, api_key, base_tools + mcp_tools
        )
        async for messages in call_agent(agent, messages, real_prompt):
            yield messages, board.fen()

    async def move_entrypoint(messages, fen, model_name, model_provider, api_key):
        """Entrypoint for the chess move interaction."""
        # Sync the server-side board with the position after the human's move.
        board.set_fen(fen)
        messages.append(
            gr.ChatMessage(
                role="user",
                content="Your turn! Make a move.",
            )
        )
        yield messages, board.fen()

        # Fix: original prompt read "Use can use tools..." — corrected typo.
        real_prompt = HumanMessage(
            content=f"Make a move with {'white' if board.turn == chess.WHITE else 'black'}, current board state: {fen}. You can use tools to analyze the position."
        )
        agent = create_agent(
            model_name, model_provider, api_key, base_tools + mcp_tools
        )
        async for messages in call_agent(agent, messages, real_prompt):
            yield messages, board.fen()

    with gr.Blocks(fill_height=True) as chessagent:
        gr.Markdown("# Play Chess with an AI Agent ♔ and see its thoughts 💭")

        with gr.Row():
            with gr.Column(min_width=500):
                with gr.Row():
                    model_provider = gr.Dropdown(
                        choices=["Anthropic", "Mistral", "OpenAI"],
                        value="Anthropic",
                        label="Model Provider",
                        type="value",
                        interactive=True,
                        allow_custom_value=False,
                    )
                    model_name = gr.Textbox(
                        value=MODEL_DEFAULTS["anthropic"]["model_name"],
                        label="Model Name",
                        placeholder="Enter model name",
                        interactive=True,
                    )
                    api_key = gr.Textbox(
                        value=MODEL_DEFAULTS["anthropic"]["api_key"],
                        label="API Key",
                        placeholder="Enter your API key",
                        type="password",
                        interactive=True,
                    )
                    # Keep the model-name and API-key fields in sync with the
                    # selected provider's defaults.
                    model_provider.change(
                        fn=lambda provider: MODEL_DEFAULTS[provider.lower()][
                            "model_name"
                        ],
                        inputs=model_provider,
                        outputs=model_name,
                    )
                    model_provider.change(
                        fn=lambda provider: MODEL_DEFAULTS[provider.lower()]["api_key"],
                        inputs=model_provider,
                        outputs=api_key,
                    )

                board_component = Chessboard(game_mode=True, label="Chess Board")

                # Created un-rendered here so the move handler can reference
                # it; actually rendered in the right-hand column below.
                chatbot = gr.Chatbot(
                    type="messages",
                    label="Chess Agent",
                    avatar_images=(
                        "https://chessboardjs.com/img/chesspieces/wikipedia/wK.png",
                        "https://chessboardjs.com/img/chesspieces/wikipedia/bK.png",
                    ),
                    min_height=650,
                    render=False,
                )

                # A human move on the board triggers the agent's reply move.
                board_component.move(
                    fn=move_entrypoint,
                    inputs=[
                        chatbot,
                        board_component,
                        model_name,
                        model_provider,
                        api_key,
                    ],
                    outputs=[chatbot, board_component],
                )

            with gr.Column():
                chatbot.render()
                input_box = gr.Textbox(lines=1, label="Chat Message")
                input_box.submit(
                    fn=chat_entrypoint,
                    inputs=[input_box, chatbot, model_name, model_provider, api_key],
                    outputs=[chatbot, board_component],
                )
                # Clear the input box immediately after submission.
                input_box.submit(lambda: "", None, [input_box], queue=False)

    chessagent.launch()


if __name__ == "__main__":
    asyncio.run(main())
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ gradio_chessboard
3
+ chess
4
+ langchain
5
+ langchain-mcp-adapters
6
+ langchain-anthropic
7
+ langchain-mistralai
8
+ langchain-openai
9
+ langgraph
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (154 Bytes). View file
 
utils/__pycache__/helpers.cpython-312.pyc ADDED
Binary file (4.84 kB). View file
 
utils/__pycache__/tools.cpython-312.pyc ADDED
Binary file (3 kB). View file
 
utils/helpers.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, AsyncIterator
4
+
5
+ import gradio as gr
6
+ from langchain.chat_models import init_chat_model
7
+ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
8
+ from langgraph.prebuilt import create_react_agent
9
+
10
+ if TYPE_CHECKING:
11
+ from langchain_core.language_models.chat_models import BaseChatModel
12
+ from langgraph.graph.graph import CompiledGraph
13
+
14
+ MESSAGE_TYPE = BaseMessage | gr.ChatMessage | dict[str, str]
15
+
16
+
17
def create_agent(
    model_name: str, provider: str, api_key: str, tools: list
) -> CompiledGraph:
    """Build a LangGraph React agent backed by the requested chat model.

    Args:
        model_name: Provider-specific model identifier.
        provider: One of "Anthropic", "Mistral", "OpenAI".
        api_key: API key for the chosen provider.
        tools: Tools the agent may call.

    Returns:
        CompiledGraph: A ready-to-run React agent.
    """
    chat_model = _create_model(model_name, provider, api_key)
    return create_react_agent(chat_model, tools=tools)
26
+
27
+
28
async def call_agent(
    agent: CompiledGraph, messages: list[MESSAGE_TYPE], prompt: HumanMessage
) -> AsyncIterator[list[MESSAGE_TYPE]]:
    """Stream an agent run, appending chat entries as results arrive.

    Args:
        agent: Compiled LangGraph agent to execute.
        messages: Chat history; mutated in place with new assistant entries.
        prompt: The message actually sent to the agent in place of the last
            history entry (which is its user-visible counterpart).

    Yields:
        The updated ``messages`` list after each tool/agent step.
    """
    # Fix: removed leftover debug `print(f"Chunk received: {chunk}")`, which
    # dumped every raw model chunk (including full message objects) to stdout.
    history = [_convert_to_langchain_message(msg) for msg in messages[:-1]]
    async for chunk in agent.astream({"messages": history + [prompt]}):
        if "tools" in chunk:
            # Surface each tool invocation as a titled assistant message.
            for step in chunk["tools"]["messages"]:
                messages.append(
                    gr.ChatMessage(
                        role="assistant",
                        content=step.content,
                        metadata={"title": f"🛠️ Used tool {step.name}"},
                    )
                )
            yield messages
        if "agent" in chunk:
            messages.append(
                gr.ChatMessage(
                    role="assistant",
                    content=_get_chunk_message_content(chunk),
                )
            )
            yield messages
56
+
57
+
58
def _create_model(model_name: str, provider: str, api_key: str) -> BaseChatModel:
    """Instantiate the chat model for the given provider and model name.

    Raises:
        ValueError: If *provider* is not Anthropic, Mistral, or OpenAI.
    """
    # Map each supported provider to its init_chat_model prefix and the
    # keyword used to pass the API key.
    registry = {
        "Anthropic": ("anthropic:", "anthropic_api_key"),
        "Mistral": ("mistralai:", "mistral_api_key"),
        "OpenAI": ("openai:", "openai_api_key"),
    }
    try:
        prefix, key_kwarg = registry[provider]
    except KeyError:
        raise ValueError(f"Unsupported model provider: {provider}") from None
    return init_chat_model(prefix + model_name, **{key_kwarg: api_key})
77
+
78
+
79
def _is_ai_message(message: MESSAGE_TYPE) -> bool:
    """Return True when *message* originates from the assistant side."""
    if isinstance(message, AIMessage):
        return True
    # Gradio messages carry a role attribute; plain dicts carry a role key.
    role = None
    if isinstance(message, gr.ChatMessage):
        role = message.role
    elif isinstance(message, dict):
        role = message.get("role")
    return role == "assistant"
87
+
88
+
89
def _convert_to_langchain_message(message: MESSAGE_TYPE) -> BaseMessage:
    """Normalize a chat-history entry into a LangChain message.

    Raises:
        ValueError: If *message* is none of BaseMessage, gr.ChatMessage, dict.
    """
    if isinstance(message, BaseMessage):
        return message  # already in LangChain form
    if isinstance(message, gr.ChatMessage):
        content = message.content
    elif isinstance(message, dict):
        content = message.get("content", "")
    else:
        raise ValueError(f"Unsupported message type: {type(message)}")
    message_cls = AIMessage if _is_ai_message(message) else HumanMessage
    return message_cls(content=content)
105
+
106
+
107
+ def _get_chunk_message_content(chunk: dict) -> str:
108
+ msg_object = chunk["agent"]["messages"][0]
109
+ message = msg_object.content
110
+ if isinstance(message, list):
111
+ message = message[0] if message else ""
112
+ if isinstance(message, dict):
113
+ message = message.get("text")
114
+ return message or "Calling tool(s)"
utils/tools.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chess
2
+ from langchain_core.tools import BaseTool, tool
3
+ from langchain_mcp_adapters.client import MultiServerMCPClient
4
+
5
+
6
def create_base_tools(board: chess.Board) -> list[BaseTool]:
    """Create tools for interacting with a chess board.

    The returned tools close over *board*, so every call mutates or reads
    that shared instance.

    Args:
        board (chess.Board): The chess board to interact with.

    Returns:
        list[BaseTool]: A list of tools for interacting with the chess board.
    """
    # NOTE: the inner docstrings are the LLM-visible tool descriptions —
    # keep them accurate when editing.

    @tool
    def get_fen() -> str:
        """Get the current FEN string of the chess board."""
        return board.fen()

    @tool
    def set_fen(fen: str) -> str:
        """Set the chess board to a specific FEN string.
        Don't use when you are playing a game, use the `make_move` tool instead.

        Args:
            fen (str): The FEN string to set the board to.
        """
        # Errors are returned as strings so the agent can read and recover.
        try:
            board.set_fen(fen)
            return board.fen()
        except ValueError as e:
            return str(e)

    @tool
    def make_move(move: str) -> str:
        """Make a move on the chess board and return the new FEN string.

        Args:
            move (str): The move in UCI format (e.g., "e2e4").
        """
        # Fix: catch ValueError instead of bare Exception. from_uci raises
        # chess.InvalidMoveError (a ValueError subclass) on malformed input;
        # the broad handler also masked unrelated programming errors.
        try:
            chess_move = chess.Move.from_uci(move)
        except ValueError as e:
            return str(e)
        if chess_move in board.legal_moves:
            board.push(chess_move)
            return board.fen()
        return "Illegal move"

    return [
        get_fen,
        set_fen,
        make_move,
    ]
57
+
58
+
59
async def create_mcp_tools(
    url: str = "http://localhost:7860/gradio_api/mcp/sse", transport: str = "sse"
) -> list[BaseTool]:
    """Fetch the tools exposed by a remote MCP server.

    Args:
        url: Endpoint of the MCP server.
        transport: Transport protocol understood by the MCP client.

    Returns:
        list[BaseTool]: LangChain-compatible tools served by the endpoint.
    """
    # Single-server configuration; the "chess" key is the server's alias.
    server_config = {
        "chess": {
            "url": url,
            "transport": transport,
        }
    }
    client = MultiServerMCPClient(server_config)
    return await client.get_tools()