| #include <cstdint> |
| #include <c10/util/Half.h> |
| |
| #include <c10/cuda/CUDAStream.h> |
|
|
// Evaluates a CUDA runtime call and raises a Torch error carrying the CUDA
// error string if it did not return cudaSuccess.
//
// Fixes vs. the previous version:
//  - `(code)` is parenthesized so compound expressions expand safely.
//  - cudaGetErrorString is passed directly to TORCH_CHECK; TORCH_CHECK only
//    evaluates its message arguments on failure, so the success path no
//    longer builds a std::string per call.
// The do/while(0) wrapper makes the macro a single statement (safe in
// unbraced if/else).
#define CUDA_CALL(code)                                                 \
  do {                                                                  \
    cudaError_t status = (code);                                        \
    TORCH_CHECK(status == cudaSuccess, cudaGetErrorString(status));     \
  } while (0)
|
|
| namespace megablocks { |
| namespace construct_indices { |
|
|
| |
| |
| |
// Threads per block for ConstructIndicesKernel: one full warp (32 lanes)
// cooperatively writes the columns of each output row.
const int kThreadsPerBlock = 32;
|
|
// Fills `indices` so that every element of a row belonging to bin b holds
// that element's global block-column index (bid + b * num_columns).
//
// Expected launch layout (see ConstructIndices below):
//   gridDim.x  = number of bins; blockIdx.x selects the bin.
//   gridDim.y  strides over the rows of each bin (grid-stride in y).
//   blockDim.x = kThreadsPerBlock threads stride over the columns of a row.
//
// `padded_bins` holds cumulative end offsets: bin b covers
// [padded_bins[b-1], padded_bins[b]) in element units, with bin 0 starting
// at zero. Dividing by `block_size` converts element offsets to block rows.
// NOTE(review): this assumes every padded_bins entry is a multiple of
// block_size — confirm with the producer of padded_bins.
__global__ void __launch_bounds__(kThreadsPerBlock)
ConstructIndicesKernel(short * __restrict__ indices,
                       int num_columns,
                       int block_size,
                       const int * __restrict__ padded_bins) {
  // Load this bin's [start, end) element offsets via the read-only cache.
  int start = 0;
  if (blockIdx.x > 0) start = __ldg(padded_bins + blockIdx.x - 1);
  int end = __ldg(padded_bins + blockIdx.x);

  // Convert element offsets to block-row offsets.
  start /= block_size;
  end /= block_size;

  // Point this thread at its first output element: row (start + blockIdx.y),
  // column threadIdx.x.
  indices += (start + blockIdx.y) * num_columns + threadIdx.x;

  // Each y-block handles every gridDim.y-th row of the bin. The loop keeps
  // bin_offset fixed and shrinks num_rows by gridDim.y per iteration, which
  // is equivalent to advancing bin_offset by gridDim.y against a fixed bound.
  int bin_offset = blockIdx.y;
  int num_rows = end - start;
  for (; bin_offset < num_rows; num_rows -= gridDim.y) {
    short *out = indices;
    // Warp lanes stride across the columns; every element of this row gets
    // its global block-column index within the full output.
    for (int bid = threadIdx.x; bid < num_columns; bid += kThreadsPerBlock) {
      *out = bid + (blockIdx.x * num_columns);
      out += kThreadsPerBlock;
    }
    // Advance to this y-block's next assigned row of the bin.
    indices += gridDim.y * num_columns;
  }
}
|
|
// Launches ConstructIndicesKernel on `stream` and returns the launch status.
//
// indices              output buffer of output_block_rows * output_block_columns
//                      int16 entries (device memory).
// output_block_rows    total number of block rows across all bins.
// output_block_columns number of block columns per row.
// block_size           divisor converting padded_bins element offsets to rows.
// padded_bins          cumulative bin end offsets (device memory), num_bins long.
// num_bins             number of bins; must be > 0 (grid.x may not be zero).
// stream               CUDA stream to launch on.
cudaError_t ConstructIndices(short * __restrict__ indices,
                             int output_block_rows,
                             int output_block_columns,
                             int block_size,
                             const int * __restrict__ padded_bins,
                             int num_bins,
                             cudaStream_t stream) {
  // One x-block per bin; y-blocks grid-stride over each bin's rows. The
  // ceil-division is done in integer arithmetic: the previous
  // std::ceil((float)...) form depended on <cmath>, which this file never
  // included, and round-trips through float for no benefit.
  dim3 block_dim(kThreadsPerBlock);
  dim3 grid_dim(num_bins, (output_block_rows + num_bins - 1) / num_bins);
  ConstructIndicesKernel<<<grid_dim, block_dim, 0, stream>>>(
      indices, output_block_columns, block_size, padded_bins);
  // Kernel launches do not return errors directly; surface config errors.
  return cudaGetLastError();
}
|
|
| } |
|
|
// Torch entry point: fills `out` with block-column indices derived from
// `padded_bins` (see construct_indices::ConstructIndicesKernel).
//
// padded_bins          1-D contiguous int32 CUDA tensor of cumulative bin
//                      end offsets.
// block_size           divisor converting bin offsets to block rows.
// output_block_rows    total block rows in the output.
// output_block_columns block columns per row.
// out                  1-D contiguous int16 CUDA tensor with exactly
//                      output_block_rows * output_block_columns elements.
void indices(torch::Tensor padded_bins,
             int block_size,
             int output_block_rows,
             int output_block_columns,
             torch::Tensor out) {
  TORCH_CHECK(padded_bins.is_cuda());
  TORCH_CHECK(padded_bins.ndimension() == 1);
  TORCH_CHECK(padded_bins.scalar_type() == torch::kInt);
  // data_ptr() below indexes the raw buffer; a non-contiguous tensor would
  // silently be read with the wrong layout.
  TORCH_CHECK(padded_bins.is_contiguous());

  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 1);
  TORCH_CHECK(out.scalar_type() == torch::kInt16);
  TORCH_CHECK(out.is_contiguous());
  // Widen to int64_t before multiplying: rows * columns can overflow int.
  TORCH_CHECK(out.numel() ==
              (int64_t)output_block_rows * output_block_columns);

  // Nothing to write; also avoids launching a zero-sized grid.
  if (out.numel() == 0) return;

  CUDA_CALL(construct_indices::ConstructIndices(
      out.data_ptr<short>(),
      output_block_rows,
      output_block_columns,
      block_size,
      padded_bins.data_ptr<int>(),
      padded_bins.numel(),
      c10::cuda::getCurrentCUDAStream()));
}
|
|
| } |
|
|