File size: 2,729 Bytes
c64c726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from typing import Generator, List, Optional

import numpy as np
import torch

from .dataset import CSGOHdf5Dataset, Dataset
from .segment import SegmentId


class BatchSampler(torch.utils.data.Sampler):
    """Endless sampler of fixed-length segment batches for one distributed rank.

    Each rank only draws episodes whose index is congruent to ``rank`` modulo
    ``world_size``. Every batch holds ``batch_size`` ``SegmentId`` entries, each
    describing a window of ``seq_length`` steps that contains a uniformly drawn
    timestep of a sampled episode.

    Args:
        dataset: Episode store; must be a ``Dataset`` or ``CSGOHdf5Dataset``. It is
            read through ``num_episodes``, ``num_steps`` and ``lengths`` (the latter
            is indexed with an integer array, so presumably a NumPy array of
            per-episode lengths -- TODO confirm).
        rank: Index of this process among ``world_size`` processes.
        world_size: Number of processes sharing the dataset.
        batch_size: Number of segments per batch.
        seq_length: Number of steps in each segment (``stop - start``).
        sample_weights: Optional probabilities for equally sized, consecutive groups
            of episodes (the last group also absorbs the remainder). Each value must
            lie in ``[0, 1]`` and they must sum to 1. Ignored when ``None`` or when
            there are fewer episodes than groups; episodes are then weighted
            proportionally to their length.
        can_sample_beyond_end: If True, segments may extend past the episode end as
            well as before its start; otherwise only before the start.
    """

    def __init__(
        self,
        dataset: Dataset,
        rank: int,
        world_size: int,
        batch_size: int,
        seq_length: int,
        sample_weights: Optional[List[float]] = None,
        can_sample_beyond_end: bool = False,
    ) -> None:
        super().__init__(dataset)
        assert isinstance(dataset, (Dataset, CSGOHdf5Dataset))
        self.dataset = dataset
        self.rank = rank
        self.world_size = world_size
        self.sample_weights = sample_weights
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.can_sample_beyond_end = can_sample_beyond_end

    def __len__(self):
        # Iteration never terminates (see ``__iter__``), so there is no finite length.
        raise NotImplementedError

    def __iter__(self) -> Generator[List[SegmentId], None, None]:
        """Yield freshly sampled batches forever."""
        while True:
            yield self.sample()

    def _episode_weights(self, num_episodes: int) -> np.ndarray:
        """Return one unnormalized sampling weight per episode (``num_episodes`` entries)."""
        if self.sample_weights is None or num_episodes < len(self.sample_weights):
            # Default: probability proportional to episode length.
            return self.dataset.lengths / self.dataset.num_steps

        group_weights = self.sample_weights
        num_groups = len(group_weights)
        # Tolerant comparison: an exact ``== 1`` rejects valid inputs such as
        # ``[0.1] * 10`` whose float sum is 0.9999999999999999.
        assert all(0 <= w <= 1 for w in group_weights) and np.isclose(sum(group_weights), 1.0)

        # Equal-sized consecutive groups; the last group takes the leftover episodes.
        base, remainder = divmod(num_episodes, num_groups)
        sizes = [base] * num_groups
        sizes[-1] += remainder
        # Spread each group's probability mass uniformly over its episodes.
        return np.array([w / s for w, s in zip(group_weights, sizes) for _ in range(s)])

    def sample(self) -> List[SegmentId]:
        """Draw one batch of ``batch_size`` segments.

        Returns:
            ``SegmentId(episode_id, start, stop)`` entries (built positionally) with
            ``stop - start == seq_length``. ``start`` may be negative, and when
            ``can_sample_beyond_end`` is set ``stop`` may exceed the episode length;
            the consumer is presumably expected to pad those positions.
        """
        num_episodes = self.dataset.num_episodes
        weights = self._episode_weights(num_episodes)

        # Round-robin sharding: this rank owns episodes rank, rank + world_size, ...
        episodes_partition = np.arange(self.rank, num_episodes, self.world_size)
        weights = np.array(weights[self.rank::self.world_size])
        # Renormalize: this rank only sees a subset of the global weights.
        episode_ids = np.random.choice(
            episodes_partition, size=self.batch_size, replace=True, p=weights / weights.sum()
        )

        # One anchor timestep per sampled episode, uniform over [0, length).
        episode_lengths = self.dataset.lengths[episode_ids]
        timesteps = np.random.randint(low=0, high=episode_lengths)
        # Random position of the anchor inside its window.
        offsets = np.random.randint(0, self.seq_length, len(timesteps))

        if self.can_sample_beyond_end:
            # Padding allowed both before start and after end: anchor lies in [start, stop).
            starts = timesteps - offsets
            stops = starts + self.seq_length
        else:
            # Padding allowed only before start: clamp stop to the episode length, so the
            # window never runs past the end; the anchor still lies in [start, stop).
            stops = np.minimum(episode_lengths, timesteps + 1 + offsets)
            starts = stops - self.seq_length

        return [SegmentId(*x) for x in zip(episode_ids, starts, stops)]