File size: 2,574 Bytes
b5b0a91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
642c210
b5b0a91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Data models for OpenSpiel Environment.

This module defines the Action, Observation, and State types for OpenSpiel games.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from openenv_core.env_server import Action, Observation, State


@dataclass
class OpenSpielAction(Action):
    """
    Action for OpenSpiel environments.

    Attributes:
        action_id: The integer action ID to take (from legal_actions).
        game_name: Name of the OpenSpiel game (e.g., "catch", "tic_tac_toe").
        game_params: Optional game-specific parameters (e.g., {"rows": 8, "columns": 6}).
    """
    action_id: int
    game_name: str = "catch"
    game_params: Dict[str, Any] = field(default_factory=dict)


@dataclass
class OpenSpielObservation(Observation):
    """
    Observation from OpenSpiel environment.

    This represents what the agent sees after taking an action.
    For single-player games, this is straightforward.
    For multi-player games, this is from the perspective of the agent player.

    Attributes:
        info_state: Information state tensor (list of floats) for the agent.
                   This contains all information available to the agent.
        legal_actions: List of legal action IDs the agent can take.
        game_phase: String describing the current phase (e.g., "playing", "terminal").
        current_player_id: ID of the current player (-1 for simultaneous, player ID otherwise).
        opponent_last_action: Last action taken by opponent (if available, None otherwise).
    """
    info_state: List[float]
    legal_actions: List[int]
    game_phase: str = "playing"
    current_player_id: int = 0
    opponent_last_action: Optional[int] = None


@dataclass
class OpenSpielState(State):
    """
    State for OpenSpiel environment.

    Attributes:
        game_name: Name of the OpenSpiel game.
        agent_player: Which player ID the agent controls (0 by default).
        opponent_policy: Name of the opponent policy ("random", "fixed", etc.).
        game_params: Game-specific parameters.
        num_players: Total number of players in the game.
    """
    game_name: str = "catch"
    agent_player: int = 0
    opponent_policy: str = "random"
    game_params: Dict[str, Any] = field(default_factory=dict)
    num_players: int = 1