| import sys | |
| import time | |
| import torch | |
| import unittest | |
| import basetest | |
| from greedrl import Solver | |
| from greedrl.function import * | |
# Device that TestFunction.do_test moves its inputs to for the second,
# compared run; it is read from a Solver constructed at import time.
_solver = Solver()
device = _solver.device
del _solver
class TestFunction(basetest.TestCase):
    """Tests for ``task_group_split`` / ``task_group_priority``.

    The small hand-built cases pin exact expected outputs. The ``*2`` variants
    use large random inputs and only check (via ``do_test``) that the result
    computed on the inputs' original device matches the result computed on
    the module-level ``device``.
    """

    def test_task_group_split(self):
        # Columns 0-3 are group 0, columns 4-7 are group 1 (same for every row).
        group = torch.ones((8, 8), dtype=torch.int32)
        group[:, 0:4] = 0
        # Value is uniform inside each group: True for group 0, False for group 1.
        value = torch.zeros((8, 8), dtype=torch.bool)
        value[:, 0:4] = True
        result = task_group_split(group, value)
        assert not torch.any(result)
        # Make group 0 mixed (columns 0-1 False, columns 2-3 still True).
        value[:, 0:2] = False
        result = task_group_split(group, value)
        assert torch.all(result)

    def test_task_group_split2(self):
        # 48 random groups over 1024 x 1000 tasks with a random boolean value.
        group = torch.randint(48, (1024, 1000), dtype=torch.int32)
        value = torch.randint(2, (1024, 1000), dtype=torch.int8) <= 0
        self.do_test(task_group_split, group, value)

    def test_task_group_priority(self):
        # Columns 0-3 are group 0, columns 4-7 are group 1.
        group = torch.ones((8, 8), dtype=torch.int32)
        group[:, 0:4] = 0
        # Priorities 0..3 within each group, identical for every row.
        priority = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=torch.int32)
        priority = priority[None, :].expand(8, -1).clone()
        # Only columns 4 and 5 (group 1, priorities 0 and 1) are set.
        value = torch.zeros((8, 8), dtype=torch.bool)
        value[:, 4:6] = True
        result = task_group_priority(group, priority, value)
        expected = torch.tensor([False, True, True, True, True, True, False, True])
        expected = expected[None, :].expand(8, -1)
        assert torch.all(result == expected)

    def test_task_group_priority2(self):
        group = torch.randint(48, (1024, 1000), dtype=torch.int32)
        value = torch.randint(2, (1024, 1000), dtype=torch.int8) < 1
        priority = torch.randint(2, (1024, 1000), dtype=torch.int32)
        self.do_test(task_group_priority, group, priority, value)

    def do_test(self, function, *args):
        """Check that ``function`` agrees across devices and print timings.

        ``function(*args)`` is evaluated first on whatever device ``args``
        already live on. The arguments are then moved to the module-level
        ``device``. The function is called once untimed, then a second timed
        call is compared against the first result.

        Both results must have exactly the same shape. This stops a result of
        the wrong (but broadcastable) shape from slipping through the
        element-wise comparison. Floating-point results must agree within
        1e-6 element-wise; all other results must be exactly equal.

        Args:
            function: callable under test; its ``__name__`` is printed.
            *args: tensor arguments; ``args[0].device`` is taken as the
                source device.
        """
        print("\ntest {} ...".format(function.__name__))
        src_device = args[0].device
        start = time.perf_counter()
        result1 = function(*args)
        # No-op on CPU; on CUDA it makes the elapsed time cover the whole call.
        self.sync_device(src_device)
        print("time: {:.6f}s, device: {}".format(time.perf_counter() - start, src_device))

        args = [arg.to(device) for arg in args]
        result1 = result1.to(device)
        # Untimed call first, presumably to keep one-time setup cost on the
        # target device out of the measurement below.
        function(*args)
        self.sync_device(device)
        start = time.perf_counter()
        result2 = function(*args)
        self.sync_device(device)
        print("time: {:.6f}s, device: {} ".format(time.perf_counter() - start, args[0].device))

        # Element-wise comparisons broadcast, so check shapes explicitly.
        assert result1.shape == result2.shape, (result1.shape, result2.shape)
        if result1.is_floating_point():
            assert torch.all(torch.abs(result1 - result2) < 1e-6)
        else:
            assert torch.all(result1 == result2)

    def sync_device(self, device):
        """Wait for queued work on ``device`` to finish; does nothing unless it is CUDA."""
        if device.type == 'cuda':
            torch.cuda.synchronize(device)
if __name__ == "__main__":
    # Load and run the TestCase classes defined in this script.
    unittest.main(module="__main__")