File size: 3,464 Bytes
c64c726
 
 
 
 
 
 
8ff38d6
 
 
 
 
c64c726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7deb5ff
 
 
 
 
 
 
 
 
 
 
c64c726
7deb5ff
 
 
c64c726
7deb5ff
c64c726
7deb5ff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union

import torch
import torch.nn as nn

from .envs import TorchEnv, WorldModelEnv
from .models.actor_critic import ActorCritic, ActorCriticConfig, ActorCriticLossConfig
from .models.diffusion import Denoiser, DenoiserConfig, SigmaDistributionConfig
from .models.rew_end_model import RewEndModel, RewEndModelConfig
from .utils import extract_state_dict


@dataclass
class AgentConfig:
    """Top-level configuration from which an ``Agent`` is built.

    ``num_actions`` is the single source of truth for the action-space size:
    once the dataclass is constructed it is written into every sub-config
    that is present (``upsampler``, ``rew_end_model`` and ``actor_critic``
    may be ``None`` and are then skipped).
    """

    denoiser: DenoiserConfig
    upsampler: Optional[DenoiserConfig]
    rew_end_model: Optional[RewEndModelConfig]
    actor_critic: Optional[ActorCriticConfig]
    num_actions: int

    def __post_init__(self) -> None:
        # Targets are produced lazily, so each assignment happens before the
        # next target is looked up (same order as a straight-line version).
        for target in self._num_actions_targets():
            target.num_actions = self.num_actions

    def _num_actions_targets(self):
        """Yield, in order, every config object that must carry ``num_actions``."""
        # Denoiser-style configs keep ``num_actions`` on their ``inner_model``.
        yield self.denoiser.inner_model
        if self.upsampler is not None:
            yield self.upsampler.inner_model
        # The remaining sub-configs hold ``num_actions`` at their top level.
        if self.rew_end_model is not None:
            yield self.rew_end_model
        if self.actor_critic is not None:
            yield self.actor_critic


class Agent(nn.Module):
    """Single ``nn.Module`` bundling the agent's sub-networks.

    Always holds ``denoiser`` (a ``Denoiser``). Depending on the config it
    may also hold:

    - ``upsampler``: a second ``Denoiser``.
    - ``rew_end_model``: a ``RewEndModel``.
    - ``actor_critic``: an ``ActorCritic``.

    Each of these optional components is ``None`` when its config is ``None``.
    """

    def __init__(self, cfg: AgentConfig) -> None:
        super().__init__()
        self.denoiser = Denoiser(cfg.denoiser)
        self.upsampler = Denoiser(cfg.upsampler) if cfg.upsampler is not None else None
        self.rew_end_model = RewEndModel(cfg.rew_end_model) if cfg.rew_end_model is not None else None
        self.actor_critic = ActorCritic(cfg.actor_critic) if cfg.actor_critic is not None else None

    @property
    def device(self):
        """Device reported by the denoiser; used as the agent's device."""
        return self.denoiser.device

    def setup_training(
        self,
        sigma_distribution_cfg: SigmaDistributionConfig,
        sigma_distribution_cfg_upsampler: Optional[SigmaDistributionConfig],
        actor_critic_loss_cfg: Optional[ActorCriticLossConfig],
        rl_env: Optional[Union[TorchEnv, WorldModelEnv]],
    ) -> None:
        """Forward training configuration to each present sub-network.

        Args:
            sigma_distribution_cfg: Passed to ``self.denoiser.setup_training``.
            sigma_distribution_cfg_upsampler: Passed to the upsampler's
                ``setup_training``. Used only when an upsampler exists; it is
                forwarded as-is, even if ``None``.
            actor_critic_loss_cfg: Passed, together with ``rl_env``, to the
                actor-critic's ``setup_training``. Used only when an
                actor-critic exists.
            rl_env: Environment passed to the actor-critic's ``setup_training``.

        ``rew_end_model`` receives no setup call here.
        """
        self.denoiser.setup_training(sigma_distribution_cfg)
        if self.upsampler is not None:
            self.upsampler.setup_training(sigma_distribution_cfg_upsampler)
        if self.actor_critic is not None:
            self.actor_critic.setup_training(rl_env, actor_critic_loss_cfg)

    def load(
        self,
        path_to_ckpt: Path,
        load_denoiser: bool = True,
        load_upsampler: bool = True,
        load_rew_end_model: bool = True,
        load_actor_critic: bool = True,
        *,
        strict: bool = True,
    ) -> None:
        """Load weights from the checkpoint file at ``path_to_ckpt``.

        Tensors are mapped onto ``self.device``. Loading is then delegated to
        :meth:`load_state_dict`, so the ``load_*`` flags and ``strict`` behave
        as documented there.
        """
        # NOTE(review): torch.load unpickles the file. Only load trusted
        # checkpoints; torch >= 2.6 defaults to weights_only=True, but older
        # releases do not.
        sd = torch.load(Path(path_to_ckpt), map_location=self.device)
        self.load_state_dict(
            sd,
            load_denoiser=load_denoiser,
            load_upsampler=load_upsampler,
            load_rew_end_model=load_rew_end_model,
            load_actor_critic=load_actor_critic,
            strict=strict,
        )

    def load_state_dict(
        self,
        state_dict: dict,
        load_denoiser: bool = True,
        load_upsampler: bool = True,
        load_rew_end_model: bool = True,
        load_actor_critic: bool = True,
        *,
        strict: bool = True,
    ) -> None:
        """Load state dict directly without file I/O.

        For each selected component, the entries for that component are taken
        from ``state_dict`` via ``extract_state_dict(state_dict, <name>)``.
        ``<name>`` is one of ``"denoiser"``, ``"upsampler"``,
        ``"rew_end_model"`` or ``"actor_critic"``. The result is loaded into
        the submodule. Presumably keys are stored with that name as a prefix;
        confirm against ``extract_state_dict``.

        Optional components that are ``None`` on this agent are skipped even
        when their flag is ``True``.

        Args:
            state_dict: Combined state dict for the whole agent.
            load_denoiser: Whether to load ``denoiser`` weights.
            load_upsampler: Whether to load ``upsampler`` weights.
            load_rew_end_model: Whether to load ``rew_end_model`` weights.
            load_actor_critic: Whether to load ``actor_critic`` weights.
            strict: Forwarded to each submodule's ``load_state_dict``.
                Keyword-only, so that ``agent.load_state_dict(sd, strict=False)``
                works as with a plain ``nn.Module``.

        Note:
            This override changes the positional signature of
            ``nn.Module.load_state_dict``. A second *positional* argument means
            ``load_denoiser``, not ``strict``.
        """
        if load_denoiser:
            self.denoiser.load_state_dict(extract_state_dict(state_dict, "denoiser"), strict=strict)
        if load_upsampler and self.upsampler is not None:
            self.upsampler.load_state_dict(extract_state_dict(state_dict, "upsampler"), strict=strict)
        if load_rew_end_model and self.rew_end_model is not None:
            self.rew_end_model.load_state_dict(extract_state_dict(state_dict, "rew_end_model"), strict=strict)
        if load_actor_critic and self.actor_critic is not None:
            self.actor_critic.load_state_dict(extract_state_dict(state_dict, "actor_critic"), strict=strict)