|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
OpenSpielEnv HTTP Client. |
|
|
|
|
|
This module provides the client for connecting to an OpenSpiel Environment server |
|
|
over HTTP. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from typing import Any, Dict, Optional, TYPE_CHECKING |
|
|
|
|
|
from openenv_core.client_types import StepResult |
|
|
|
|
|
from openenv_core.http_env_client import HTTPEnvClient |
|
|
|
|
|
from .models import OpenSpielAction, OpenSpielObservation, OpenSpielState |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from openenv_core.containers.runtime import ContainerProvider |
|
|
|
|
|
|
|
|
class OpenSpielEnv(HTTPEnvClient[OpenSpielAction, OpenSpielObservation]): |
|
|
""" |
|
|
HTTP client for OpenSpiel Environment. |
|
|
|
|
|
This client connects to an OpenSpielEnvironment HTTP server and provides |
|
|
methods to interact with it: reset(), step(), and state access. |
|
|
|
|
|
Example: |
|
|
>>> # Connect to a running server |
|
|
>>> client = OpenSpielEnv(base_url="http://localhost:8000") |
|
|
>>> result = client.reset() |
|
|
>>> print(result.observation.info_state) |
|
|
>>> |
|
|
>>> # Take an action |
|
|
>>> result = client.step(OpenSpielAction(action_id=1, game_name="catch")) |
|
|
>>> print(result.observation.reward) |
|
|
|
|
|
Example with Docker: |
|
|
>>> # Automatically start container and connect |
|
|
>>> client = OpenSpielEnv.from_docker_image("openspiel-env:latest") |
|
|
>>> result = client.reset() |
|
|
>>> result = client.step(OpenSpielAction(action_id=0)) |
|
|
""" |
|
|
|
|
|
def _step_payload(self, action: OpenSpielAction) -> Dict[str, Any]: |
|
|
""" |
|
|
Convert OpenSpielAction to JSON payload for step request. |
|
|
|
|
|
Args: |
|
|
action: OpenSpielAction instance. |
|
|
|
|
|
Returns: |
|
|
Dictionary representation suitable for JSON encoding. |
|
|
""" |
|
|
return { |
|
|
"action_id": action.action_id, |
|
|
"game_name": action.game_name, |
|
|
"game_params": action.game_params, |
|
|
} |
|
|
|
|
|
def _parse_result( |
|
|
self, payload: Dict[str, Any] |
|
|
) -> StepResult[OpenSpielObservation]: |
|
|
""" |
|
|
Parse server response into StepResult[OpenSpielObservation]. |
|
|
|
|
|
Args: |
|
|
payload: JSON response from server. |
|
|
|
|
|
Returns: |
|
|
StepResult with OpenSpielObservation. |
|
|
""" |
|
|
obs_data = payload.get("observation", {}) |
|
|
|
|
|
observation = OpenSpielObservation( |
|
|
info_state=obs_data.get("info_state", []), |
|
|
legal_actions=obs_data.get("legal_actions", []), |
|
|
game_phase=obs_data.get("game_phase", "playing"), |
|
|
current_player_id=obs_data.get("current_player_id", 0), |
|
|
opponent_last_action=obs_data.get("opponent_last_action"), |
|
|
done=payload.get("done", False), |
|
|
reward=payload.get("reward"), |
|
|
metadata=obs_data.get("metadata", {}), |
|
|
) |
|
|
|
|
|
return StepResult( |
|
|
observation=observation, |
|
|
reward=payload.get("reward"), |
|
|
done=payload.get("done", False), |
|
|
) |
|
|
|
|
|
def _parse_state(self, payload: Dict[str, Any]) -> OpenSpielState: |
|
|
""" |
|
|
Parse server response into OpenSpielState object. |
|
|
|
|
|
Args: |
|
|
payload: JSON response from /state endpoint. |
|
|
|
|
|
Returns: |
|
|
OpenSpielState object with environment state information. |
|
|
""" |
|
|
return OpenSpielState( |
|
|
episode_id=payload.get("episode_id"), |
|
|
step_count=payload.get("step_count", 0), |
|
|
game_name=payload.get("game_name", "unknown"), |
|
|
agent_player=payload.get("agent_player", 0), |
|
|
opponent_policy=payload.get("opponent_policy", "random"), |
|
|
game_params=payload.get("game_params", {}), |
|
|
num_players=payload.get("num_players", 1), |
|
|
) |
|
|
|