Spaces:
Sleeping
Sleeping
File size: 3,464 Bytes
c64c726 8ff38d6 c64c726 7deb5ff c64c726 7deb5ff c64c726 7deb5ff c64c726 7deb5ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union
import torch
import torch.nn as nn
from .envs import TorchEnv, WorldModelEnv
from .models.actor_critic import ActorCritic, ActorCriticConfig, ActorCriticLossConfig
from .models.diffusion import Denoiser, DenoiserConfig, SigmaDistributionConfig
from .models.rew_end_model import RewEndModel, RewEndModelConfig
from .utils import extract_state_dict
@dataclass
class AgentConfig:
    """Configuration bundle consumed by ``Agent``.

    ``denoiser`` is mandatory; ``upsampler``, ``rew_end_model`` and
    ``actor_critic`` may each be ``None``. Once the dataclass has been
    initialised, ``num_actions`` is copied into every sub-config that is
    present, so the value only needs to be given once.
    """

    denoiser: DenoiserConfig
    upsampler: Optional[DenoiserConfig]
    rew_end_model: Optional[RewEndModelConfig]
    actor_critic: Optional[ActorCriticConfig]
    num_actions: int

    def __post_init__(self) -> None:
        """Copy ``num_actions`` into each sub-config that is not ``None``."""
        n_actions = self.num_actions

        # Denoiser-type configs carry the value on their nested `inner_model`.
        self.denoiser.inner_model.num_actions = n_actions
        upsampler_cfg = self.upsampler
        if upsampler_cfg is not None:
            upsampler_cfg.inner_model.num_actions = n_actions

        # The remaining configs hold `num_actions` directly.
        for sub_cfg in (self.rew_end_model, self.actor_critic):
            if sub_cfg is not None:
                sub_cfg.num_actions = n_actions
class Agent(nn.Module):
    """Module grouping a ``Denoiser`` with optional companion models.

    Always holds ``denoiser``. For each of ``upsampler`` (a second
    ``Denoiser``), ``rew_end_model`` (``RewEndModel``) and ``actor_critic``
    (``ActorCritic``), a sub-module is built only when the matching
    ``AgentConfig`` entry is not ``None``; otherwise the attribute is ``None``.
    """

    def __init__(self, cfg: AgentConfig) -> None:
        """Build every sub-model for which ``cfg`` supplies a config."""
        super().__init__()
        self.denoiser = Denoiser(cfg.denoiser)
        self.upsampler = None if cfg.upsampler is None else Denoiser(cfg.upsampler)
        self.rew_end_model = None if cfg.rew_end_model is None else RewEndModel(cfg.rew_end_model)
        self.actor_critic = None if cfg.actor_critic is None else ActorCritic(cfg.actor_critic)

    @property
    def device(self):
        """Device of the agent, as reported by ``self.denoiser.device``."""
        return self.denoiser.device

    def setup_training(
        self,
        sigma_distribution_cfg: SigmaDistributionConfig,
        sigma_distribution_cfg_upsampler: Optional[SigmaDistributionConfig],
        actor_critic_loss_cfg: Optional[ActorCriticLossConfig],
        rl_env: Optional[Union[TorchEnv, WorldModelEnv]],
    ) -> None:
        """Forward training settings to the sub-models that are present.

        Args:
            sigma_distribution_cfg: Passed to ``denoiser.setup_training``.
            sigma_distribution_cfg_upsampler: Passed to
                ``upsampler.setup_training`` when an upsampler exists.
            actor_critic_loss_cfg: Passed, together with ``rl_env``, to
                ``actor_critic.setup_training`` when an actor-critic exists.
            rl_env: Environment handed to the actor-critic.

        ``rew_end_model`` receives no setup call here.
        """
        self.denoiser.setup_training(sigma_distribution_cfg)

        upsampler = self.upsampler
        if upsampler is not None:
            upsampler.setup_training(sigma_distribution_cfg_upsampler)

        actor_critic = self.actor_critic
        if actor_critic is not None:
            actor_critic.setup_training(rl_env, actor_critic_loss_cfg)

    def load(
        self,
        path_to_ckpt: Path,
        load_denoiser: bool = True,
        load_upsampler: bool = True,
        load_rew_end_model: bool = True,
        load_actor_critic: bool = True,
    ) -> None:
        """Read a checkpoint file and load the selected sub-models from it.

        Args:
            path_to_ckpt: Location of the checkpoint on disk.
            load_denoiser: Whether to restore ``denoiser``.
            load_upsampler: Whether to restore ``upsampler`` (if present).
            load_rew_end_model: Whether to restore ``rew_end_model`` (if present).
            load_actor_critic: Whether to restore ``actor_critic`` (if present).
        """
        ckpt_path = Path(path_to_ckpt)
        # NOTE(review): torch.load unpickles the file and `weights_only` is not
        # set here, so only checkpoints from trusted sources should be loaded.
        checkpoint = torch.load(ckpt_path, map_location=self.device)
        self.load_state_dict(
            checkpoint,
            load_denoiser=load_denoiser,
            load_upsampler=load_upsampler,
            load_rew_end_model=load_rew_end_model,
            load_actor_critic=load_actor_critic,
        )

    def load_state_dict(
        self,
        state_dict: dict,
        load_denoiser: bool = True,
        load_upsampler: bool = True,
        load_rew_end_model: bool = True,
        load_actor_critic: bool = True,
    ) -> None:
        """Load the selected sub-models from an in-memory state dict.

        Each sub-model receives ``extract_state_dict(state_dict, <name>)``,
        where ``<name>`` is its attribute name on the agent (presumably the
        entries stored under that prefix -- confirm in ``.utils``). Optional
        sub-models that are ``None`` are skipped even when their flag is set.

        NOTE: this shadows ``nn.Module.load_state_dict`` with a different
        signature (no ``strict`` argument) and returns ``None``.
        """
        # The denoiser always exists, so only its flag is consulted.
        if load_denoiser:
            self.denoiser.load_state_dict(extract_state_dict(state_dict, "denoiser"))

        optional_parts = (
            ("upsampler", load_upsampler),
            ("rew_end_model", load_rew_end_model),
            ("actor_critic", load_actor_critic),
        )
        for name, requested in optional_parts:
            if not requested:
                continue
            component = getattr(self, name)
            if component is not None:
                component.load_state_dict(extract_state_dict(state_dict, name))
|