Kernels

Activation

Activation is a Python package providing custom CUDA-based activation kernels, primarily targeting AMD GPUs.

  • Currently implemented
    • PolyNorm

    • RMSNorm

    • FusedAddRMSNorm

      A fused operator that combines residual addition (x + residual) with RMSNorm in a single kernel.

      • Instead of:

        y = x + residual
        hidden_state = rms_norm(y, weight, eps)
        out = y + some_op(hidden_state)
        
      • Fused as:

        hidden_state, y = fused_add_rms_norm(x, residual, weight, eps)
        out = y + some_op(hidden_state)
        
    • FusedMulPolyNorm

      A fused operator that combines PolyNorm with an element-wise multiplication by another tensor.

      • Instead of:

        y = poly_norm(x, weight, bias, eps)
        out = y * a
        
      • Fused as:

        out = fused_mul_poly_norm(x, a, weight, bias, eps)
        
    • GroupedFusedMulPolyNorm (Triton)

      A Triton-accelerated grouped variant of FusedMulPolyNorm for MoE (Mixture of Experts) models. Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), with per-expert weights/bias and in-kernel binary search for expert mapping.

      • Instead of:

        for i, expert in enumerate(experts):
            out[start:end] = fused_mul_poly_norm(x[start:end], mul[start:end], weight[i], bias[i], eps)
        
      • Fused as:

        out = grouped_fused_mul_poly_norm(x, mul, weight, bias, offsets, eps)
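      The in-kernel binary search for expert mapping can be illustrated in plain PyTorch. This is a sketch, assuming `offsets` holds cumulative row boundaries per expert (so expert `i` owns rows `offsets[i]` through `offsets[i+1] - 1`); the actual kernel performs the equivalent search per row inside Triton.

      ```python
      import torch

      # Hypothetical layout: 3 experts owning rows [0, 3), [3, 7), [7, 10).
      offsets = torch.tensor([0, 3, 7, 10])
      rows = torch.arange(10)

      # torch.searchsorted performs the same binary search the kernel does
      # per row: locate the expert whose range contains the row index.
      expert_of_row = torch.searchsorted(offsets, rows, right=True) - 1
      print(expert_of_row.tolist())  # [0, 0, 0, 1, 1, 1, 1, 2, 2, 2]
      ```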
        

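The fused operators above can be described by unfused PyTorch sketches. These are illustrative references only, not the package's kernels, and the PolyNorm formula (a three-element weight, one entry per power of x, applied to RMS-normalized powers) is an assumption based on the descriptions above.

```python
import torch

def rms_norm(x, weight, eps):
    # RMSNorm: scale x by the reciprocal root-mean-square of its last dim.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

def fused_add_rms_norm_ref(x, residual, weight, eps):
    # Unfused reference: the kernel fuses the residual add with the norm
    # and returns both the normalized state and the residual sum y.
    y = x + residual
    return rms_norm(y, weight, eps), y

def poly_norm_ref(x, weight, bias, eps):
    # Assumed PolyNorm form: a weighted sum of RMS-normalized powers of x.
    def _norm(t):
        return t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + eps)
    return (weight[0] * _norm(x ** 3)
            + weight[1] * _norm(x ** 2)
            + weight[2] * _norm(x)
            + bias)

def fused_mul_poly_norm_ref(x, a, weight, bias, eps):
    # Unfused reference: PolyNorm followed by an element-wise multiply.
    return poly_norm_ref(x, weight, bias, eps) * a
```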
Usage

import torch
from kernels import get_kernel

activation = get_kernel("motif-technologies/activation")

torch.set_default_device("cuda")
poly_norm = activation.layers.PolyNorm(eps=1e-6)
x = torch.randn(10, 10)

print(poly_norm(x))

Performance

  • Test cases are from the Motif LLM
  • The results can be reproduced using the provided benchmarking tools.
  • For details on how to use the benchmarking tools, please refer to the benchmarks README.
  • The benchmark results may show fluctuations, especially in the backward pass and when the dimension size is small.

RMSNorm

H100 Results

Forward Performance

[Figure: RMSNorm forward performance on H100]

Backward Performance

[Figure: RMSNorm backward performance on H100]

MI250 Results

Forward Performance

[Figure: RMSNorm forward performance on MI250]

Backward Performance

[Figure: RMSNorm backward performance on MI250]


FusedAddRMSNorm

For fusion case performance, the non-fused baseline was implemented with our custom kernels.

H100 Results

Forward Performance

[Figure: FusedAddRMSNorm forward performance on H100]

Backward Performance

[Figure: FusedAddRMSNorm backward performance on H100]

MI250 Results

Forward Performance

[Figure: FusedAddRMSNorm forward performance on MI250]

Backward Performance

[Figure: FusedAddRMSNorm backward performance on MI250]


PolyNorm

H100 Results

Forward Performance

[Figure: PolyNorm forward performance on H100]

Backward Performance

[Figure: PolyNorm backward performance on H100]

MI250 Results

Forward Performance

[Figure: PolyNorm forward performance on MI250]

Backward Performance

[Figure: PolyNorm backward performance on MI250]


FusedMulPolyNorm

For fusion case performance, the non-fused baseline was implemented with our custom kernels.

H100 Results

Forward Performance

[Figure: FusedMulPolyNorm forward performance on H100]

Backward Performance

[Figure: FusedMulPolyNorm backward performance on H100]

MI250 Results

Forward Performance

[Figure: FusedMulPolyNorm forward performance on MI250]

Backward Performance

[Figure: FusedMulPolyNorm backward performance on MI250]


GroupedFusedMulPolyNorm (Triton)

This kernel is implemented in Triton (JIT-compiled, so no CUDA C++ build is required). Benchmarks compare three variants: Naive (a raw PyTorch reference), Compiled (the same reference wrapped in torch.compile), and Triton (the fused Triton kernel). Benchmark dimensions: hidden size 1280, 384 experts.

B200 Results (bf16)

Forward Performance
batch_size seq_len Naive (us) Compiled (us) Triton (us) Triton vs Naive
1 1024 294.54 73.46 64.33 4.58x
1 2048 373.50 94.88 65.26 5.72x
1 4096 372.65 94.90 66.90 5.57x
1 8192 486.98 102.33 72.71 6.70x
2 4096 486.66 101.87 72.27 6.73x
2 8192 950.62 106.96 90.06 10.56x
4 4096 950.72 107.17 71.28 13.34x
4 8192 1779.12 198.91 96.93 18.35x
8 4096 1778.73 199.10 96.88 18.36x
8 8192 3384.03 381.91 179.57 18.85x
Backward Performance
batch_size seq_len Naive (us) Compiled (us) Triton (us) Triton vs Naive
1 1024 1690.61 999.66 1017.66 1.66x
1 8192 1680.39 906.43 906.41 1.85x
2 8192 2466.73 870.74 862.78 2.86x
4 4096 2466.04 942.62 945.68 2.61x
4 8192 4543.10 941.01 908.30 5.00x
8 4096 4542.91 814.73 900.01 5.05x
8 8192 8599.41 956.81 955.07 9.00x
Forward + Backward Combined
batch_size seq_len Naive (us) Compiled (us) Triton (us) Triton vs Naive Triton vs Compiled
1 1024 1985.15 1073.12 1081.99 1.83x 0.99x
1 4096 2085.10 974.32 960.73 2.17x 1.01x
1 8192 2167.37 1008.76 979.12 2.21x 1.03x
2 4096 2083.49 1001.03 965.30 2.16x 1.04x
2 8192 3417.35 977.70 952.84 3.59x 1.03x
4 4096 3416.76 1049.79 1016.97 3.36x 1.03x
4 8192 6322.22 1139.92 1005.23 6.29x 1.13x
8 4096 6321.64 1013.83 996.89 6.34x 1.02x
8 8192 11983.44 1338.71 1134.64 10.56x 1.18x

B200 Results (fp32)

Forward Performance
batch_size seq_len Naive (us) Compiled (us) Triton (us) Triton vs Naive
1 1024 318.05 83.29 64.24 4.95x
1 2048 311.14 95.19 63.64 4.89x
1 8192 401.78 101.61 68.21 5.89x
2 4096 403.42 100.97 68.01 5.93x
2 8192 803.31 130.51 68.21 11.78x
4 4096 802.86 130.61 66.97 11.99x
4 8192 1505.96 246.77 100.49 14.99x
8 4096 1507.87 246.84 100.23 15.04x
8 8192 2856.93 476.34 184.40 15.49x
Backward Performance
batch_size seq_len Naive (us) Compiled (us) Triton (us) Triton vs Naive
1 1024 1604.25 989.30 1114.12 1.44x
1 8192 1996.40 1117.71 1115.47 1.79x
2 8192 2353.87 1119.41 1118.57 2.10x
4 4096 2358.47 1102.23 1125.16 2.10x
4 8192 4346.92 1125.33 1135.36 3.83x
8 4096 4347.47 1104.27 1119.63 3.88x
8 8192 8226.50 1172.66 1197.68 6.87x
Forward + Backward Combined
batch_size seq_len Naive (us) Compiled (us) Triton (us) Triton vs Naive Triton vs Compiled
1 1024 1922.30 1072.59 1178.36 1.63x 0.91x
1 4096 2367.77 1208.69 1192.07 1.99x 1.01x
1 8192 2398.19 1219.32 1183.69 2.03x 1.03x
2 4096 2401.39 1248.87 1154.72 2.08x 1.08x
2 8192 3157.18 1249.92 1186.77 2.66x 1.05x
4 4096 3161.33 1232.84 1192.13 2.65x 1.03x
4 8192 5852.88 1372.10 1235.86 4.74x 1.11x
8 4096 5855.34 1351.11 1219.85 4.80x 1.11x
8 8192 11083.43 1649.00 1382.07 8.02x 1.19x

Pre-commit Hooks

This project uses pre-commit to automatically check and format code before commits.

Setup

  1. Install pre-commit:

    pip install pre-commit
    
  2. Install the git hooks:

    pre-commit install

Once installed, the configured hooks will run automatically on each commit.

Included Hooks

The following tools are run via pre-commit:

  • yapf – Python code formatter
  • typos – Spell checker for common typos
  • isort – Organizes and sorts Python imports
  • clang-format – Formats C++/CUDA code (--style=file)
  • pymarkdown – Lints and auto-fixes Markdown files
  • actionlint – Validates GitHub Actions workflows
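
For reference, a minimal `.pre-commit-config.yaml` wiring these tools together might look like the sketch below. The repository URLs point at each tool's published pre-commit hook; the `rev` values are placeholders and should be pinned to real tags, and the exact hook set in this project may differ.

```yaml
repos:
  - repo: https://github.com/google/yapf
    rev: <pinned-tag>
    hooks:
      - id: yapf
  - repo: https://github.com/crate-ci/typos
    rev: <pinned-tag>
    hooks:
      - id: typos
  - repo: https://github.com/PyCQA/isort
    rev: <pinned-tag>
    hooks:
      - id: isort
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: <pinned-tag>
    hooks:
      - id: clang-format
        args: [--style=file]
  - repo: https://github.com/jackdewinter/pymarkdown
    rev: <pinned-tag>
    hooks:
      - id: pymarkdown
  - repo: https://github.com/rhysd/actionlint
    rev: <pinned-tag>
    hooks:
      - id: actionlint
```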

Usage

  • Run all checks on the entire codebase:

    pre-commit run --all-files
    
  • Run a specific hook (example: isort):

    pre-commit run isort --all-files