Erland committed (verified)
Commit 632c6f7 · 1 parent: 335ed55

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
Files changed (50):
  1. configs/delta_net_1B.json +29 -0
  2. configs/delta_net_340M.json +27 -0
  3. configs/gla_340M.json +24 -0
  4. configs/gla_7B.json +25 -0
  5. configs/gsa_340M.json +29 -0
  6. configs/hgrn2_340M.json +20 -0
  7. configs/rectified_transformer_120M.json +19 -0
  8. configs/rectified_transformer_340M.json +19 -0
  9. configs/scaled_softpick_transformer_120M.json +19 -0
  10. configs/scaled_softpick_transformer_340M.json +19 -0
  11. configs/scaled_vanilla_transformer_120M.json +19 -0
  12. configs/scaled_vanilla_transformer_340M.json +19 -0
  13. configs/softpick_transformer_120M.json +19 -0
  14. configs/softpick_transformer_1B.json +23 -0
  15. configs/softpick_transformer_340M.json +19 -0
  16. configs/softpick_transformer_7B.json +22 -0
  17. configs/softpick_transformer_with_pruning_340M.json +63 -0
  18. configs/stochastic_softpick_transformer_120M.json +20 -0
  19. configs/transformer_120M.json +19 -0
  20. configs/transformer_340M.json +18 -0
  21. configs/vanilla_transformer_340M.json +19 -0
  22. fla/__init__.py +114 -0
  23. fla/__pycache__/__init__.cpython-311.pyc +0 -0
  24. fla/__pycache__/utils.cpython-311.pyc +0 -0
  25. fla/layers/__init__.py +44 -0
  26. fla/layers/__pycache__/__init__.cpython-311.pyc +0 -0
  27. fla/layers/__pycache__/abc.cpython-311.pyc +0 -0
  28. fla/layers/__pycache__/attn.cpython-311.pyc +0 -0
  29. fla/layers/__pycache__/based.cpython-311.pyc +0 -0
  30. fla/layers/__pycache__/bitattn.cpython-311.pyc +0 -0
  31. fla/layers/__pycache__/delta_net.cpython-311.pyc +0 -0
  32. fla/layers/__pycache__/gsa.cpython-311.pyc +0 -0
  33. fla/layers/__pycache__/linear_attn.cpython-311.pyc +0 -0
  34. fla/layers/__pycache__/rebased.cpython-311.pyc +0 -0
  35. fla/layers/abc.py +218 -0
  36. fla/layers/attn.py +490 -0
  37. fla/layers/based.py +96 -0
  38. fla/layers/bitattn.py +192 -0
  39. fla/layers/delta_net.py +291 -0
  40. fla/layers/forgetting_attn.py +109 -0
  41. fla/layers/gated_deltanet.py +293 -0
  42. fla/layers/gated_deltaproduct.py +351 -0
  43. fla/layers/gla.py +294 -0
  44. fla/layers/gsa.py +227 -0
  45. fla/layers/hgrn.py +168 -0
  46. fla/layers/hgrn2.py +211 -0
  47. fla/layers/lightnet.py +210 -0
  48. fla/layers/linear_attn.py +166 -0
  49. fla/layers/multiscale_retention.py +298 -0
  50. fla/layers/nsa.py +138 -0
configs/delta_net_1B.json ADDED
@@ -0,0 +1,29 @@
+{
+  "attn": null,
+  "attn_mode": "chunk",
+  "bos_token_id": 1,
+  "conv_size": 4,
+  "eos_token_id": 2,
+  "expand_k": 1,
+  "expand_v": 1,
+  "fuse_cross_entropy": true,
+  "fuse_norm": true,
+  "hidden_act": "swish",
+  "hidden_ratio": 4,
+  "hidden_size": 2048,
+  "initializer_range": 0.006,
+  "intermediate_size": null,
+  "model_type": "delta_net",
+  "norm_eps": 1e-06,
+  "num_heads": 16,
+  "num_hidden_layers": 24,
+  "pad_token_id": 2,
+  "qk_activation": "silu",
+  "qk_norm": "l2",
+  "tie_word_embeddings": false,
+  "use_beta": true,
+  "use_cache": true,
+  "use_gate": false,
+  "use_output_norm": true,
+  "use_short_conv": true
+}
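
A minimal loading sketch for the configs added in this commit (a hedged example, not part of the commit): it assumes the accompanying fla package registers model_type "delta_net" and the other model types with transformers' Auto classes when imported.

import fla  # noqa: F401 -- imported for its assumed model-registration side effect
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("configs/delta_net_1B.json")
model = AutoModelForCausalLM.from_config(config)
print(f"{sum(p.numel() for p in model.parameters()) / 1e9:.2f}B parameters")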
configs/delta_net_340M.json ADDED
@@ -0,0 +1,27 @@
+{
+  "attn_mode": "chunk",
+  "bos_token_id": 1,
+  "conv_size": 4,
+  "eos_token_id": 2,
+  "expand_k": 1,
+  "expand_v": 1,
+  "fuse_cross_entropy": true,
+  "hidden_act": "swish",
+  "hidden_ratio": 4,
+  "hidden_size": 1024,
+  "initializer_range": 0.006,
+  "intermediate_size": null,
+  "model_type": "delta_net",
+  "norm_eps": 1e-06,
+  "norm_first": false,
+  "num_heads": 8,
+  "num_hidden_layers": 24,
+  "qk_activation": "silu",
+  "qk_norm": "l2",
+  "tie_word_embeddings": false,
+  "use_beta": true,
+  "use_cache": true,
+  "use_gate": false,
+  "use_output_norm": true,
+  "use_short_conv": true
+}
configs/gla_340M.json ADDED
@@ -0,0 +1,24 @@
+{
+  "attn_mode": "chunk",
+  "bos_token_id": 1,
+  "clamp_min": null,
+  "eos_token_id": 2,
+  "expand_k": 0.5,
+  "expand_v": 1,
+  "fuse_cross_entropy": true,
+  "fuse_norm": true,
+  "hidden_act": "swish",
+  "hidden_ratio": 4,
+  "hidden_size": 1024,
+  "initializer_range": 0.006,
+  "intermediate_size": null,
+  "model_type": "gla",
+  "num_heads": 4,
+  "num_hidden_layers": 24,
+  "norm_eps": 1e-06,
+  "tie_word_embeddings": false,
+  "use_cache": true,
+  "use_gk": true,
+  "use_gv": false,
+  "vocab_size": 32000
+}
configs/gla_7B.json ADDED
@@ -0,0 +1,25 @@
+{
+  "attn": null,
+  "attn_mode": "chunk",
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "expand_k": 0.5,
+  "expand_v": 1,
+  "fuse_cross_entropy": true,
+  "fuse_norm": true,
+  "hidden_act": "swish",
+  "hidden_ratio": 4,
+  "hidden_size": 4096,
+  "initializer_range": 0.006,
+  "intermediate_size": 11008,
+  "model_type": "gla",
+  "norm_eps": 1e-06,
+  "num_heads": 16,
+  "num_hidden_layers": 32,
+  "tie_word_embeddings": false,
+  "use_cache": true,
+  "use_gk": true,
+  "use_gv": false,
+  "use_output_gate": true,
+  "use_short_conv": false
+}
configs/gsa_340M.json ADDED
@@ -0,0 +1,29 @@
+{
+  "bos_token_id": 1,
+  "conv_size": 4,
+  "eos_token_id": 2,
+  "expand_k": 1,
+  "expand_v": 1,
+  "elementwise_affine": false,
+  "feature_map": "swish",
+  "fuse_cross_entropy": true,
+  "fuse_norm": true,
+  "gate_logit_normalizer": 4,
+  "hidden_act": "swish",
+  "hidden_ratio": 4,
+  "hidden_size": 1024,
+  "initializer_range": 0.006,
+  "intermediate_size": null,
+  "model_type": "gsa",
+  "num_heads": 4,
+  "num_hidden_layers": 24,
+  "num_slots": 64,
+  "norm_eps": 1e-06,
+  "share_conv_kernel": true,
+  "tie_word_embeddings": false,
+  "use_cache": true,
+  "use_norm": true,
+  "use_output_gate": true,
+  "use_rope": false,
+  "use_short_conv": false
+}
configs/hgrn2_340M.json ADDED
@@ -0,0 +1,20 @@
+{
+  "attn_mode": "chunk",
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "expand_ratio": 128,
+  "fuse_cross_entropy": true,
+  "fuse_norm": true,
+  "hidden_act": "swish",
+  "hidden_ratio": 4,
+  "hidden_size": 1024,
+  "initializer_range": 0.006,
+  "intermediate_size": null,
+  "model_type": "hgrn2",
+  "num_heads": 8,
+  "num_hidden_layers": 24,
+  "norm_eps": 1e-06,
+  "tie_word_embeddings": false,
+  "use_cache": true,
+  "vocab_size": 32000
+}
configs/rectified_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
+{
+  "attention_bias": false,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "fuse_cross_entropy": true,
+  "fuse_norm": false,
+  "hidden_act": "swish",
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "max_position_embeddings": 4096,
+  "model_type": "transformer",
+  "num_heads": 12,
+  "num_hidden_layers": 14,
+  "norm_eps": 1e-06,
+  "tie_word_embeddings": true,
+  "use_cache": true,
+  "vocab_size": 32000,
+  "attn_impl": "naive_rectified_attn"
+}
configs/rectified_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
+{
+  "attention_bias": false,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "fuse_cross_entropy": true,
+  "fuse_norm": true,
+  "hidden_act": "swish",
+  "hidden_size": 1024,
+  "initializer_range": 0.006,
+  "max_position_embeddings": 8192,
+  "model_type": "transformer",
+  "num_heads": 16,
+  "num_hidden_layers": 24,
+  "norm_eps": 1e-06,
+  "tie_word_embeddings": false,
+  "use_cache": true,
+  "vocab_size": 32000,
+  "attn_impl": "parallel_rectified_attn"
+}
configs/scaled_softpick_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
+{
+  "attention_bias": false,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "fuse_cross_entropy": true,
+  "fuse_norm": false,
+  "hidden_act": "swish",
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "max_position_embeddings": 4096,
+  "model_type": "transformer",
+  "num_heads": 12,
+  "num_hidden_layers": 14,
+  "norm_eps": 1e-06,
+  "tie_word_embeddings": true,
+  "use_cache": true,
+  "vocab_size": 32000,
+  "attn_impl": "parallel_scaled_softpick_attn"
+}
configs/scaled_softpick_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
+{
+  "attention_bias": false,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "fuse_cross_entropy": true,
+  "fuse_norm": true,
+  "hidden_act": "swish",
+  "hidden_size": 1024,
+  "initializer_range": 0.006,
+  "max_position_embeddings": 8192,
+  "model_type": "transformer",
+  "num_heads": 16,
+  "num_hidden_layers": 24,
+  "norm_eps": 1e-06,
+  "tie_word_embeddings": false,
+  "use_cache": true,
+  "vocab_size": 32000,
+  "attn_impl": "parallel_scaled_softpick_attn"
+}
configs/scaled_vanilla_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
+{
+  "attention_bias": false,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "fuse_cross_entropy": true,
+  "fuse_norm": false,
+  "hidden_act": "swish",
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "max_position_embeddings": 4096,
+  "model_type": "transformer",
+  "num_heads": 12,
+  "num_hidden_layers": 14,
+  "norm_eps": 1e-06,
+  "tie_word_embeddings": true,
+  "use_cache": true,
+  "vocab_size": 32000,
+  "attn_impl": "parallel_scaled_attn"
+}
configs/scaled_vanilla_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
+{
+  "attention_bias": false,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "fuse_cross_entropy": true,
+  "fuse_norm": true,
+  "hidden_act": "swish",
+  "hidden_size": 1024,
+  "initializer_range": 0.006,
+  "max_position_embeddings": 8192,
+  "model_type": "transformer",
+  "num_heads": 16,
+  "num_hidden_layers": 24,
+  "norm_eps": 1e-06,
+  "tie_word_embeddings": false,
+  "use_cache": true,
+  "vocab_size": 32000,
+  "attn_impl": "parallel_scaled_attn"
+}
configs/softpick_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
+{
+  "attention_bias": false,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "fuse_cross_entropy": false,
+  "fuse_norm": false,
+  "hidden_act": "swish",
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "max_position_embeddings": 4096,
+  "model_type": "transformer",
+  "num_heads": 12,
+  "num_hidden_layers": 14,
+  "norm_eps": 1e-06,
+  "tie_word_embeddings": true,
+  "use_cache": true,
+  "vocab_size": 32000,
+  "attn_impl": "parallel_softpick_attn"
+}
configs/softpick_transformer_1B.json ADDED
@@ -0,0 +1,23 @@
+{
+  "bos_token_id": 1,
+  "elementwise_affine": true,
+  "eos_token_id": 2,
+  "fuse_cross_entropy": true,
+  "fuse_norm": true,
+  "fuse_swiglu": true,
+  "hidden_act": "swish",
+  "hidden_ratio": 4,
+  "hidden_size": 2048,
+  "initializer_range": 0.006,
+  "intermediate_size": null,
+  "max_position_embeddings": 8192,
+  "model_type": "transformer",
+  "norm_eps": 1e-06,
+  "num_heads": 32,
+  "num_hidden_layers": 32,
+  "num_kv_heads": null,
+  "pad_token_id": 2,
+  "rope_theta": 10000.0,
+  "tie_word_embeddings": false,
+  "attn_impl": "parallel_softpick_attn"
+}
configs/softpick_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
+{
+  "attention_bias": false,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "fuse_cross_entropy": true,
+  "fuse_norm": true,
+  "hidden_act": "swish",
+  "hidden_size": 1024,
+  "initializer_range": 0.006,
+  "max_position_embeddings": 8192,
+  "model_type": "transformer",
+  "num_heads": 16,
+  "num_hidden_layers": 24,
+  "norm_eps": 1e-06,
+  "tie_word_embeddings": false,
+  "use_cache": true,
+  "vocab_size": 32000,
+  "attn_impl": "parallel_softpick_attn"
+}
configs/softpick_transformer_7B.json ADDED
@@ -0,0 +1,22 @@
+{
+  "attention_bias": false,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "fuse_cross_entropy": true,
+  "fuse_norm": true,
+  "hidden_act": "swish",
+  "hidden_ratio": 4,
+  "hidden_size": 4096,
+  "initializer_range": 0.006,
+  "intermediate_size": 14336,
+  "model_type": "transformer",
+  "norm_eps": 1e-06,
+  "num_heads": 32,
+  "num_hidden_layers": 32,
+  "num_kv_heads": 8,
+  "rope_theta": 10000.0,
+  "tie_word_embeddings": false,
+  "use_cache": true,
+  "window_size": null,
+  "attn_impl": "parallel_softpick_attn"
+}
configs/softpick_transformer_with_pruning_340M.json ADDED
@@ -0,0 +1,63 @@
+{
+  "attention_bias": false,
+  "attn_impl": "parallel_softpick_attn",
+  "bos_token_id": 1,
+  "elementwise_affine": true,
+  "eos_token_id": 2,
+  "fuse_cross_entropy": true,
+  "fuse_norm": true,
+  "fuse_swiglu": true,
+  "hidden_act": "swish",
+  "hidden_ratio": 4,
+  "hidden_size": 1024,
+  "initializer_range": 0.006,
+  "intermediate_size": null,
+  "layer_head_pruned": [
+    [
+      2,
+      1
+    ],
+    [
+      2,
+      7
+    ],
+    [
+      2,
+      12
+    ],
+    [
+      2,
+      13
+    ],
+    [
+      3,
+      5
+    ],
+    [
+      3,
+      13
+    ],
+    [
+      3,
+      14
+    ],
+    [
+      13,
+      6
+    ]
+  ],
+  "max_position_embeddings": 8192,
+  "model_type": "transformer_with_pruning",
+  "norm_eps": 1e-06,
+  "num_heads": 16,
+  "num_hidden_layers": 24,
+  "num_kv_heads": null,
+  "qk_norm": false,
+  "qkv_bias": false,
+  "rope_theta": 10000.0,
+  "tie_word_embeddings": false,
+  "transformers_version": "4.51.3",
+  "use_cache": true,
+  "vocab_size": 32000,
+  "window_size": null
+}
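
The layer_head_pruned entries above appear to be (layer index, head index) pairs; that reading is an assumption, not documented in this commit. A small standard-library sketch that groups the pruned heads by layer for inspection:

import json
from collections import defaultdict

with open("configs/softpick_transformer_with_pruning_340M.json") as f:
    cfg = json.load(f)

pruned = defaultdict(list)
for layer_idx, head_idx in cfg["layer_head_pruned"]:  # assumed (layer, head) ordering
    pruned[layer_idx].append(head_idx)
print(dict(pruned))  # {2: [1, 7, 12, 13], 3: [5, 13, 14], 13: [6]}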
configs/stochastic_softpick_transformer_120M.json ADDED
@@ -0,0 +1,20 @@
+{
+  "attention_bias": false,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "fuse_cross_entropy": false,
+  "fuse_norm": false,
+  "hidden_act": "swish",
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "max_position_embeddings": 4096,
+  "model_type": "stochastic_softpick_transformer",
+  "num_heads": 12,
+  "num_hidden_layers": 14,
+  "norm_eps": 1e-06,
+  "tie_word_embeddings": true,
+  "use_cache": true,
+  "vocab_size": 32000,
+  "attn_impl": "parallel_softpick_attn",
+  "stochastic_p": 0.9
+}
configs/transformer_120M.json ADDED
@@ -0,0 +1,19 @@
+{
+  "attention_bias": false,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "fuse_cross_entropy": false,
+  "fuse_norm": false,
+  "hidden_act": "swish",
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "max_position_embeddings": 4096,
+  "model_type": "transformer",
+  "num_heads": 12,
+  "num_hidden_layers": 14,
+  "norm_eps": 1e-06,
+  "tie_word_embeddings": true,
+  "use_cache": true,
+  "vocab_size": 32000,
+  "attn_impl" : "parallel_attn"
+}
configs/transformer_340M.json ADDED
@@ -0,0 +1,18 @@
+{
+  "attention_bias": false,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "fuse_cross_entropy": false,
+  "fuse_norm": false,
+  "hidden_act": "swish",
+  "hidden_size": 1024,
+  "initializer_range": 0.006,
+  "max_position_embeddings": 8192,
+  "model_type": "transformer",
+  "num_heads": 16,
+  "num_hidden_layers": 24,
+  "norm_eps": 1e-06,
+  "tie_word_embeddings": false,
+  "use_cache": true,
+  "vocab_size": 32000
+}
configs/vanilla_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
+{
+  "attention_bias": false,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "fuse_cross_entropy": true,
+  "fuse_norm": true,
+  "hidden_act": "swish",
+  "hidden_size": 1024,
+  "initializer_range": 0.006,
+  "max_position_embeddings": 8192,
+  "model_type": "transformer",
+  "num_heads": 16,
+  "num_hidden_layers": 24,
+  "norm_eps": 1e-06,
+  "tie_word_embeddings": false,
+  "use_cache": true,
+  "vocab_size": 32000,
+  "attn_impl": "parallel_attn"
+}
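
A quick, standard-library way to compare the configs added above side by side (illustrative only):

import glob
import json

for path in sorted(glob.glob("configs/*.json")):
    with open(path) as f:
        cfg = json.load(f)
    print(path, cfg.get("model_type"), cfg.get("attn_impl", "-"),
          cfg.get("hidden_size"), cfg.get("num_hidden_layers"))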
fla/__init__.py ADDED
@@ -0,0 +1,114 @@
+# -*- coding: utf-8 -*-
+
+from fla.layers import (
+    ABCAttention,
+    Attention,
+    BasedLinearAttention,
+    BitAttention,
+    DeltaNet,
+    GatedDeltaNet,
+    GatedDeltaProduct,
+    GatedLinearAttention,
+    GatedSlotAttention,
+    HGRN2Attention,
+    HGRNAttention,
+    LightNetAttention,
+    LinearAttention,
+    MultiScaleRetention,
+    NativeSparseAttention,
+    ReBasedLinearAttention,
+    RWKV6Attention,
+    RWKV7Attention,
+)
+from fla.models import (
+    ABCForCausalLM,
+    ABCModel,
+    BitNetForCausalLM,
+    BitNetModel,
+    DeltaNetForCausalLM,
+    DeltaNetModel,
+    GatedDeltaNetForCausalLM,
+    GatedDeltaNetModel,
+    GatedDeltaProductForCausalLM,
+    GatedDeltaProductModel,
+    GLAForCausalLM,
+    GLAModel,
+    GSAForCausalLM,
+    GSAModel,
+    HGRN2ForCausalLM,
+    HGRN2Model,
+    HGRNForCausalLM,
+    LightNetForCausalLM,
+    LightNetModel,
+    LinearAttentionForCausalLM,
+    LinearAttentionModel,
+    NSAForCausalLM,
+    NSAModel,
+    RetNetForCausalLM,
+    RetNetModel,
+    RWKV6ForCausalLM,
+    RWKV6Model,
+    RWKV7ForCausalLM,
+    RWKV7Model,
+    TransformerForCausalLM,
+    TransformerModel,
+    TransformerWithPruningForCausalLM,
+    TransformerWithPruningModel
+)
+
+__all__ = [
+    'ABCAttention',
+    'Attention',
+    'BasedLinearAttention',
+    'BitAttention',
+    'DeltaNet',
+    'GatedDeltaNet',
+    'GatedDeltaProduct',
+    'GatedLinearAttention',
+    'GatedSlotAttention',
+    'HGRNAttention',
+    'HGRN2Attention',
+    'LightNetAttention',
+    'LinearAttention',
+    'MultiScaleRetention',
+    'NativeSparseAttention',
+    'ReBasedLinearAttention',
+    'RWKV6Attention',
+    'RWKV7Attention',
+    'ABCForCausalLM',
+    'ABCModel',
+    'BitNetForCausalLM',
+    'BitNetModel',
+    'DeltaNetForCausalLM',
+    'DeltaNetModel',
+    'GatedDeltaNetForCausalLM',
+    'GatedDeltaNetModel',
+    'GatedDeltaProductForCausalLM',
+    'GatedDeltaProductModel',
+    'GLAForCausalLM',
+    'GLAModel',
+    'GSAForCausalLM',
+    'GSAModel',
+    'HGRNForCausalLM',
+    'HGRNModel',
+    'HGRN2ForCausalLM',
+    'HGRN2Model',
+    'LightNetForCausalLM',
+    'LightNetModel',
+    'LinearAttentionForCausalLM',
+    'LinearAttentionModel',
+    'NSAForCausalLM',
+    'NSAModel',
+    'RetNetForCausalLM',
+    'RetNetModel',
+    'RWKV6ForCausalLM',
+    'RWKV6Model',
+    'RWKV7ForCausalLM',
+    'RWKV7Model',
+    'TransformerForCausalLM',
+    'TransformerModel',
+    'TransformerWithPruningForCausalLM',
+    'TransformerWithPruningModel',
+]
+
+__version__ = '0.1.2'
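
Everything listed in __all__ above is re-exported from the package root; a minimal import check (illustrative, not part of the commit):

import fla

print(fla.__version__)              # '0.1.2'
print(len(fla.__all__), "exports")  # layer and model classes at the top level
from fla import DeltaNet, GatedLinearAttention, TransformerForCausalLM  # noqa: F401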
fla/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.46 kB).

fla/__pycache__/utils.cpython-311.pyc ADDED
Binary file (13.9 kB).
 
fla/layers/__init__.py ADDED
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+from .abc import ABCAttention
+from .attn import Attention
+from .based import BasedLinearAttention
+from .bitattn import BitAttention
+from .delta_net import DeltaNet
+from .forgetting_attn import ForgettingAttention
+from .gated_deltanet import GatedDeltaNet
+from .gated_deltaproduct import GatedDeltaProduct
+from .gla import GatedLinearAttention
+from .gsa import GatedSlotAttention
+from .hgrn import HGRNAttention
+from .hgrn2 import HGRN2Attention
+from .lightnet import LightNetAttention
+from .linear_attn import LinearAttention
+from .multiscale_retention import MultiScaleRetention
+from .nsa import NativeSparseAttention
+from .rebased import ReBasedLinearAttention
+from .rwkv6 import RWKV6Attention
+from .rwkv7 import RWKV7Attention
+
+__all__ = [
+    'ABCAttention',
+    'Attention',
+    'BasedLinearAttention',
+    'BitAttention',
+    'DeltaNet',
+    'ForgettingAttention',
+    'GatedDeltaNet',
+    'GatedDeltaProduct',
+    'GatedLinearAttention',
+    'GatedSlotAttention',
+    'HGRNAttention',
+    'HGRN2Attention',
+    'LightNetAttention',
+    'LinearAttention',
+    'MultiScaleRetention',
+    'NativeSparseAttention',
+    'ReBasedLinearAttention',
+    'RWKV6Attention',
+    'RWKV7Attention',
+]
fla/layers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.52 kB).

fla/layers/__pycache__/abc.cpython-311.pyc ADDED
Binary file (9.8 kB).

fla/layers/__pycache__/attn.cpython-311.pyc ADDED
Binary file (25 kB).

fla/layers/__pycache__/based.cpython-311.pyc ADDED
Binary file (6.93 kB).

fla/layers/__pycache__/bitattn.cpython-311.pyc ADDED
Binary file (9.64 kB).

fla/layers/__pycache__/delta_net.cpython-311.pyc ADDED
Binary file (13.1 kB).

fla/layers/__pycache__/gsa.cpython-311.pyc ADDED
Binary file (10.3 kB).

fla/layers/__pycache__/linear_attn.cpython-311.pyc ADDED
Binary file (7.99 kB).

fla/layers/__pycache__/rebased.cpython-311.pyc ADDED
Binary file (7.2 kB).
 
fla/layers/abc.py ADDED
@@ -0,0 +1,218 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+from __future__ import annotations
+
+import warnings
+from typing import TYPE_CHECKING, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from fla.modules import FusedRMSNormGated, RMSNorm, RotaryEmbedding, ShortConvolution
+from fla.modules.activations import swiglu, swish
+from fla.ops.abc.chunk import chunk_abc
+
+if TYPE_CHECKING:
+    from fla.models.utils import Cache
+
+
+class ABCAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int = 1024,
+        expand_k: float = 0.5,
+        expand_v: float = 1.0,
+        num_heads: int = 4,
+        use_short_conv: bool = False,
+        conv_size: int = 4,
+        conv_bias: bool = False,
+        num_slots: Optional[int] = None,
+        elementwise_affine: Optional[bool] = True,
+        norm_eps: float = 1e-5,
+        gate_low_rank_dim: int = 16,
+        gate_logit_normalizer: int = 16,
+        use_rope: bool = True,
+        use_input_gate: bool = False,
+        use_output_gate: bool = True,
+        use_norm: bool = True,
+        clamp_min: Optional[float] = -32,
+        clamp_max: Optional[float] = 32,
+        layer_idx: Optional[int] = None,
+        **kwargs
+    ) -> ABCAttention:
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.expand_k = expand_k
+        self.expand_v = expand_v
+        self.num_heads = num_heads
+        self.key_dim = int(self.hidden_size * self.expand_k)
+        self.value_dim = int(self.hidden_size * self.expand_v)
+        self.head_k_dim = self.key_dim // self.num_heads
+        self.head_v_dim = self.value_dim // self.num_heads
+
+        self.use_short_conv = use_short_conv
+        self.conv_size = conv_size
+        self.conv_bias = conv_bias
+
+        self.gate_low_rank_dim = gate_low_rank_dim
+        self.gate_logit_normalizer = gate_logit_normalizer
+
+        self.use_rope = use_rope
+        self.use_input_gate = use_input_gate
+        self.use_output_gate = use_output_gate
+        self.use_norm = use_norm
+
+        if num_slots is None:
+            num_slots = self.head_k_dim
+        self.num_slots = num_slots
+
+        self.norm_eps = norm_eps
+
+        self.clamp_min = clamp_min
+        self.clamp_max = clamp_max
+        self.layer_idx = layer_idx
+
+        if layer_idx is None:
+            warnings.warn(
+                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
+        self.k_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
+        self.v_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
+
+        if use_output_gate:
+            self.g_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
+        self.s_proj = nn.Linear(self.hidden_size, self.num_heads * self.num_slots, bias=False)
+        self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
+
+        if use_short_conv:
+            self.conv_size = conv_size
+            self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
+            self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
+            self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu')
+
+        if self.use_norm:
+            if self.use_output_gate:
+                self.g_norm = FusedRMSNormGated(
+                    hidden_size=self.head_v_dim,
+                    elementwise_affine=elementwise_affine,
+                    eps=norm_eps
+                )
+            else:
+                self.g_norm = RMSNorm(
+                    hidden_size=self.head_v_dim,
+                    elementwise_affine=elementwise_affine,
+                    eps=norm_eps
+                )
+
+        if self.use_rope:
+            self.rotary = RotaryEmbedding(self.head_k_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        **kwargs
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+        if attention_mask is not None:
+            assert len(attention_mask.shape) == 2, (
+                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
+                "for padding purposes (0 indicating padding). "
+                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
+            )
+
+        last_state = None
+        if past_key_values is not None and len(past_key_values) > self.layer_idx:
+            last_state = past_key_values[self.layer_idx]
+
+        cu_seqlens = kwargs.get('cu_seqlens', None)
+        if cu_seqlens is not None:
+            raise NotImplementedError("Training with cu_seqlens is not supported yet for ABCAttention")
+        if self.use_short_conv:
+            conv_state_q, conv_state_k, conv_state_v = None, None, None
+            if last_state is not None:
+                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
+            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
+            q, conv_state_q = self.q_conv1d(
+                x=self.q_proj(hidden_states),
+                mask=conv_mask,
+                cache=conv_state_q,
+                output_final_state=use_cache,
+                cu_seqlens=cu_seqlens
+            )
+            k, conv_state_k = self.k_conv1d(
+                x=self.k_proj(hidden_states),
+                mask=conv_mask,
+                cache=conv_state_k,
+                output_final_state=use_cache,
+                cu_seqlens=cu_seqlens
+            )
+            v, conv_state_v = self.v_conv1d(
+                x=self.v_proj(hidden_states),
+                mask=conv_mask,
+                cache=conv_state_v,
+                output_final_state=use_cache,
+                cu_seqlens=cu_seqlens
+            )
+        else:
+            q = self.q_proj(hidden_states)
+            k = self.k_proj(hidden_states)
+            v = self.v_proj(hidden_states)
+
+        if self.use_input_gate:
+            q, k, v = map(lambda x: swish(x), (q, k, v))
+        # dealing with left-padding
+        if attention_mask is not None:
+            v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
+
+        q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
+        v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
+        if self.use_rope:
+            seqlen_offset = 0
+            if past_key_values is not None:
+                seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
+            q, k = self.rotary(q, k, seqlen_offset=seqlen_offset)
+
+        s = rearrange(self.s_proj(hidden_states), '... (h m) -> ... h m', m=self.num_slots)
+        s = s.clamp_(self.clamp_min, self.clamp_max)
+
+        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
+        o, recurrent_state = chunk_abc(
+            q=q,
+            k=k,
+            v=v,
+            s=s,
+            initial_state=recurrent_state,
+            output_final_state=use_cache,
+            head_first=False
+        )
+        if past_key_values is not None:
+            past_key_values.update(
+                recurrent_state=recurrent_state,
+                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
+                layer_idx=self.layer_idx,
+                offset=q.shape[1]
+            )
+
+        if self.use_norm and not self.use_output_gate:
+            o = self.g_norm(o)
+        elif self.use_output_gate:
+            g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
+            o = self.g_norm(o, g) if self.use_norm else swiglu(g, o)
+        o = rearrange(o, '... h d -> ... (h d)')
+        o = self.o_proj(o)
+
+        return o, None, past_key_values
+
+    def state_size(self, seq_len: int = 2048):
+        return 2 * self.num_slots * self.hidden_size
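
A hedged smoke test for the ABCAttention layer above; it assumes the Triton kernels behind fla.ops.abc.chunk are available and a CUDA device is present, and the tensor sizes are purely illustrative.

import torch
from fla.layers.abc import ABCAttention

layer = ABCAttention(hidden_size=1024, num_heads=4, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.bfloat16)
o, attentions, past_key_values = layer(x)  # forward returns (output, None, past_key_values)
print(o.shape)  # expected: torch.Size([2, 128, 1024])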
fla/layers/attn.py ADDED
@@ -0,0 +1,490 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+from __future__ import annotations
+
+import warnings
+from typing import TYPE_CHECKING, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from einops import rearrange
+from transformers.utils import logging
+
+from fla.modules import RMSNorm, RotaryEmbedding
+from fla.ops import parallel_attn, parallel_rectified_attn, parallel_softpick_attn, naive_attn, naive_rectified_attn, naive_softpick_attn
+
+if TYPE_CHECKING:
+    from fla.models.utils import Cache
+
+try:
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+except ImportError:
+    warnings.warn(
+        "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
+        category=ImportWarning
+    )
+    flash_attn_func = None
+
+logger = logging.get_logger(__name__)
+
+
+class Attention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int = 2048,
+        num_heads: int = 32,
+        num_kv_heads: Optional[int] = None,
+        qkv_bias: bool = False,
+        qk_norm: bool = False,
+        window_size: Optional[int] = None,
+        rope_theta: Optional[float] = 10000.,
+        max_position_embeddings: Optional[int] = None,
+        layer_idx: int = None,
+        attn_impl: str = "flash_attn",
+    ):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        if num_kv_heads is None:
+            self.num_kv_heads = self.num_heads
+        else:
+            self.num_kv_heads = num_kv_heads
+        self.num_kv_groups = num_heads // self.num_kv_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.kv_dim = self.num_kv_heads * self.head_dim
+        self.qkv_bias = qkv_bias
+        self.qk_norm = qk_norm
+
+        self.window_size = window_size
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+        self.layer_idx = layer_idx
+        self.attn_impl = attn_impl
+
+        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
+        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+
+        if "scaled" in self.attn_impl:
+            self.s = nn.Parameter(torch.empty(self.num_heads, 1))
+            self.register_buffer("logn", torch.log(torch.arange(2, self.max_position_embeddings*4+2, dtype=self.s.dtype)[:, None, None]))
+
+        if qk_norm:
+            self.q_norm = RMSNorm(self.head_dim)
+            self.k_norm = RMSNorm(self.head_dim)
+
+        self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
+
+    def reset_parameters(self):
+        if "scaled" in self.attn_impl:
+            nn.init.constant_(self.s, 0.3)
+            self.logn.copy_(torch.log(torch.arange(2, self.max_position_embeddings*4+2, dtype=self.s.dtype)[:, None, None]))
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if attention_mask is not None:
+            assert len(attention_mask.shape) == 2, (
+                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
+                "for padding purposes (0 indicating padding). "
+                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
+            )
+
+        batch_size, q_len, _ = hidden_states.size()
+
+        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
+
+        q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
+        k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
+        v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
+
+        if self.qk_norm:
+            q, k = self.q_norm(q), self.k_norm(k)
+
+        # equivalent to cu_seqlens in `flash_attn`
+        cu_seqlens = kwargs.get('cu_seqlens', None)
+
+        seqlen_offset, max_seqlen = 0, q_len
+        if past_key_values is not None:
+            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
+            max_seqlen = q.shape[1] + seqlen_offset
+
+            if attention_mask is not None:
+                # to deliminate the offsets of padding tokens
+                seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
+                max_seqlen = q.shape[1] + max(seqlen_offset)
+
+        if self.max_position_embeddings is not None:
+            max_seqlen = max(max_seqlen, self.max_position_embeddings)
+        q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
+
+        if past_key_values is not None:
+            cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
+            k_cached, v_cached = past_key_values.update(
+                attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
+                layer_idx=self.layer_idx,
+                offset=q_len,
+                cache_kwargs=dict(window_size=self.window_size)
+            )['attn_state']
+            if cache_has_content:
+                k, v = k_cached, v_cached
+                k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
+                v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
+
+        # if flash_attn_func is None:
+        #     raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
+
+        if "scaled" in self.attn_impl:
+            k_len = k.shape[1]
+            q = q * self.s.to(q.dtype) * self.logn[k_len-q_len:k_len].to(q.dtype)
+
+        # Contains at least one padding token in the sequence
+        if self.attn_impl == "flash_attn":
+            if attention_mask is not None:
+                q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
+                cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+                max_seqlen_q, max_seqlen_k = max_seq_lens
+                o = flash_attn_varlen_func(
+                    q, k, v,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=max_seqlen_q,
+                    max_seqlen_k=max_seqlen_k,
+                    causal=True,
+                    window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
+                )
+                o = pad_input(o, indices_q, batch_size, q_len)
+            elif cu_seqlens is not None:
+                o = flash_attn_varlen_func(
+                    q.squeeze(0), k.squeeze(0), v.squeeze(0),
+                    cu_seqlens_q=cu_seqlens,
+                    cu_seqlens_k=cu_seqlens,
+                    max_seqlen_q=max_seqlen,
+                    max_seqlen_k=max_seqlen,
+                    causal=True,
+                    window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
+                ).unsqueeze(0)
+            else:
+                o = flash_attn_func(
+                    q, k, v,
+                    causal=True,
+                    window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
+                )
+        elif self.attn_impl == "parallel_attn":
+            o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        elif self.attn_impl == "parallel_scaled_attn":
+            o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        elif self.attn_impl == "parallel_rectified_attn":
+            o = parallel_rectified_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        elif self.attn_impl == "parallel_softpick_attn":
+            o = parallel_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        elif self.attn_impl == "parallel_scaled_softpick_attn":
+            o = parallel_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        elif self.attn_impl == "naive_attn":
+            o, attentions = naive_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        elif self.attn_impl == "naive_scaled_attn":
+            o, attentions = naive_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        elif self.attn_impl == "naive_rectified_attn":
+            o, attentions = naive_rectified_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        elif self.attn_impl == "naive_softpick_attn":
+            o, attentions = naive_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        elif self.attn_impl == "naive_scaled_softpick_attn":
+            o, attentions = naive_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        else:
+            raise ValueError(f"Unknown attention implementation: {self.attn_impl}")
+
+        o = o.reshape(batch_size, q_len, -1)
+        o = self.o_proj(o)
+
+        if not output_attentions or "parallel" in self.attn_impl or "flash" in self.attn_impl:
+            attentions = None
+
+        return o, attentions, past_key_values
+
+    def _upad_input(self, q, k, v, attention_mask, q_len):
+        batch_size, seq_len, num_key_value_heads, head_dim = k.shape
+        cache_mask = attention_mask[:, -seq_len:]
+        seqlens = cache_mask.sum(-1, dtype=torch.int32)
+        indices_k = torch.nonzero(cache_mask.flatten(), as_tuple=False).flatten()
+        max_seqlen_k = seqlens.max().item()
+        cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
+
+        k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
+        v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
+        if q_len == seq_len:
+            q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_q = max_seqlen_k
+            indices_q = indices_k
+        elif q_len == 1:
+            max_seqlen_q = 1
+            # There is a memcpy here, that is very bad.
+            cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
+            indices_q = cu_seqlens_q[:-1]
+            q = q.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -q_len:]
+            q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)
+
+        return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+
+class StochasticSoftpickAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int = 2048,
+        num_heads: int = 32,
+        num_kv_heads: Optional[int] = None,
+        qkv_bias: bool = False,
+        qk_norm: bool = False,
+        window_size: Optional[int] = None,
+        rope_theta: Optional[float] = 10000.,
+        max_position_embeddings: Optional[int] = None,
+        layer_idx: int = None,
+        attn_impl: str = "flash_attn",
+        stochastic_p: float = 0.5,
+    ):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        if num_kv_heads is None:
+            self.num_kv_heads = self.num_heads
+        else:
+            self.num_kv_heads = num_kv_heads
+        self.num_kv_groups = num_heads // self.num_kv_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.kv_dim = self.num_kv_heads * self.head_dim
+        self.qkv_bias = qkv_bias
+        self.qk_norm = qk_norm
+
+        self.window_size = window_size
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+        self.layer_idx = layer_idx
+        self.attn_impl = attn_impl
+        self.stochastic_value = stochastic_p
+        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
+        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+
+        if "scaled" in self.attn_impl:
+            self.s = nn.Parameter(torch.empty(self.num_heads, 1))
+            self.register_buffer("logn", torch.log(torch.arange(2, self.max_position_embeddings*4+2, dtype=self.s.dtype)[:, None, None]))
+
+        if qk_norm:
+            self.q_norm = RMSNorm(self.head_dim)
+            self.k_norm = RMSNorm(self.head_dim)
+
+        self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
+
+    def reset_parameters(self):
+        if "scaled" in self.attn_impl:
+            nn.init.constant_(self.s, 0.3)
+            self.logn.copy_(torch.log(torch.arange(2, self.max_position_embeddings*4+2, dtype=self.s.dtype)[:, None, None]))


+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if attention_mask is not None:
+            assert len(attention_mask.shape) == 2, (
+                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
+                "for padding purposes (0 indicating padding). "
+                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
+            )
+
+        batch_size, q_len, _ = hidden_states.size()
+
+        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
+
+        q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
+        k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
+        v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
+
+        if self.qk_norm:
+            q, k = self.q_norm(q), self.k_norm(k)
+
+        # equivalent to cu_seqlens in `flash_attn`
+        cu_seqlens = kwargs.get('cu_seqlens', None)
+
+        seqlen_offset, max_seqlen = 0, q_len
+        if past_key_values is not None:
+            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
+            max_seqlen = q.shape[1] + seqlen_offset
+
+            if attention_mask is not None:
+                # to deliminate the offsets of padding tokens
+                seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
+                max_seqlen = q.shape[1] + max(seqlen_offset)
+
+        if self.max_position_embeddings is not None:
+            max_seqlen = max(max_seqlen, self.max_position_embeddings)
+        q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
+
+        if past_key_values is not None:
+            cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
+            k_cached, v_cached = past_key_values.update(
+                attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
+                layer_idx=self.layer_idx,
+                offset=q_len,
+                cache_kwargs=dict(window_size=self.window_size)
+            )['attn_state']
+            if cache_has_content:
+                k, v = k_cached, v_cached
+                k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
+                v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
+
+        # if flash_attn_func is None:
+        #     raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
+
+        if "scaled" in self.attn_impl:
+            k_len = k.shape[1]
+            q = q * self.s.to(q.dtype) * self.logn[k_len-q_len:k_len].to(q.dtype)
+
+        # Contains at least one padding token in the sequence
+
+        p = torch.rand(1, device=q.device)
+        stochastic_p = torch.tensor(self.stochastic_value, dtype=torch.float32, device=q.device)
+        cond = torch.where(p < stochastic_p, torch.tensor(1, dtype=torch.bool, device=q.device), torch.tensor(0, dtype=torch.bool, device=q.device))
+        if self.attn_impl == "flash_attn":
+            if attention_mask is not None:
+                q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
+                cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+                max_seqlen_q, max_seqlen_k = max_seq_lens
+                o = flash_attn_varlen_func(
+                    q, k, v,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=max_seqlen_q,
+                    max_seqlen_k=max_seqlen_k,
+                    causal=True,
+                    window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
+                )
+                o = pad_input(o, indices_q, batch_size, q_len)
+            elif cu_seqlens is not None:
+                o = flash_attn_varlen_func(
+                    q.squeeze(0), k.squeeze(0), v.squeeze(0),
+                    cu_seqlens_q=cu_seqlens,
+                    cu_seqlens_k=cu_seqlens,
+                    max_seqlen_q=max_seqlen,
+                    max_seqlen_k=max_seqlen,
+                    causal=True,
+                    window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
+                ).unsqueeze(0)
+            else:
+                o = flash_attn_func(
+                    q, k, v,
+                    causal=True,
+                    window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
+                )
+
+        elif self.attn_impl == "parallel_attn":
+            if cond:
+                o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+            else:
+                o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        elif self.attn_impl == "parallel_scaled_attn":
+            if cond:
+                o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+            else:
+                o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        elif self.attn_impl == "parallel_rectified_attn":
+            if cond:
+                o = parallel_rectified_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+            else:
+                o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        elif self.attn_impl == "parallel_softpick_attn":
+            if cond:
+                o = parallel_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+            else:
+                o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        elif self.attn_impl == "parallel_scaled_softpick_attn":
+            if cond:
+                o = parallel_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+            else:
+                o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        elif self.attn_impl == "naive_attn":
+            if cond:
+                o, attentions = naive_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+            else:
+                o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        elif self.attn_impl == "naive_scaled_attn":
+            if cond:
+                o, attentions = naive_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+            else:
+                o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        elif self.attn_impl == "naive_rectified_attn":
+            if cond:
+                o, attentions = naive_rectified_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+            else:
+                o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        elif self.attn_impl == "naive_softpick_attn":
+            if cond:
+                o, attentions = naive_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+            else:
+                o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        elif self.attn_impl == "naive_scaled_softpick_attn":
+            if cond:
+                o, attentions = naive_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+            else:
+                o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens)
+        else:
+            raise ValueError(f"Unknown attention implementation: {self.attn_impl}")
+
+        o = o.reshape(batch_size, q_len, -1)
+        o = self.o_proj(o)
+
+        if not output_attentions or "parallel" in self.attn_impl or "flash" in self.attn_impl:
+            attentions = None
+
+        return o, attentions, past_key_values
+
+    def _upad_input(self, q, k, v, attention_mask, q_len):
+        batch_size, seq_len, num_key_value_heads, head_dim = k.shape
+        cache_mask = attention_mask[:, -seq_len:]
+        seqlens = cache_mask.sum(-1, dtype=torch.int32)
+        indices_k = torch.nonzero(cache_mask.flatten(), as_tuple=False).flatten()
+        max_seqlen_k = seqlens.max().item()
+        cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
+
+        k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
+        v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
+        if q_len == seq_len:
+            q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_q = max_seqlen_k
+            indices_q = indices_k
+        elif q_len == 1:
+            max_seqlen_q = 1
+            # There is a memcpy here, that is very bad.
+            cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
+            indices_q = cu_seqlens_q[:-1]
+            q = q.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -q_len:]
+            q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)
+
+        return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
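
Distilled from StochasticSoftpickAttention.forward above: a single uniform draw per forward pass decides whether the configured softpick kernel or plain parallel attention runs. A minimal sketch of that gating logic (the function name is illustrative; the ops are the ones imported at the top of this file):

import torch
from fla.ops import parallel_attn, parallel_softpick_attn

def stochastic_softpick(q, k, v, scale, stochastic_p=0.9, cu_seqlens=None):
    # one Bernoulli(stochastic_p) draw, shared by every head and token in the batch
    use_softpick = torch.rand(1, device=q.device) < stochastic_p
    if use_softpick:
        return parallel_softpick_attn(q, k, v, scale=scale, cu_seqlens=cu_seqlens)
    return parallel_attn(q, k, v, scale=scale, cu_seqlens=cu_seqlens)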
fla/layers/based.py ADDED
@@ -0,0 +1,96 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+"""
+Linear attention in Based.
+https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py
+"""
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from fla.modules.feature_map import TaylorFeatureMap
+from fla.ops.based import parallel_based
+from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
+
+
+class BasedLinearAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        feature_dim: int = 16,
+        num_key_value_heads: int = 12,
+        num_heads: int = 12,
+        feature_name: str = "taylor_exp",
+        eps: float = 1e-12,
+        causal: bool = True,
+        mode: str = "parallel",
+    ):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.mode = mode
+        self.feature_name = feature_name
+        self.feature_dim = feature_dim
+        self.num_key_value_heads = num_key_value_heads
+        self.num_heads = num_heads
+        self.head_dim = self.hidden_size // self.num_key_value_heads
+        assert self.hidden_size % self.head_dim == 0
+        self.causal = causal
+
+        self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
+        self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+        self.dropout = nn.Identity()
+        self.feature_map = TaylorFeatureMap(feature_dim)
+        self.eps = eps
+
+    def forward(self, hidden_states: torch.Tensor, **kwargs):
+        mode = self.mode
+        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
+        q, k, v = map(lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_dim), [q, k, v])
+        if mode == "fused_chunk":
+            q, k = self.feature_map(q), self.feature_map(k)
+            o, _ = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1, head_first=False)
+        elif mode == 'chunk':
+            q, k = self.feature_map(q), self.feature_map(k)
+            o, _ = chunk_linear_attn(q, k, v, normalize=True, scale=1, head_first=False)
+        elif mode == 'parallel':
+            assert q.shape[-1] <= 128
+            o = parallel_based(q, k, v, scale=1, use_norm=True, head_first=False)
+        o = rearrange(o, 'b t h d -> b t (h d)')
+        o = self.o_proj(o)
+        o = self.dropout(o)
+        return o
+
+    # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
+
+    def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs):
+        """
+        x (torch.Tensor): tensor of shape (b, d, t)
+        y (torch.Tensor): tensor of shape (b, d, t)
+        """
+        # hidden_states = hidden_states.transpose(1, 2)
+        b, t, _ = hidden_states.size()
+        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
+
+        q = q.view(b, t, self.num_heads, self.feature_dim).transpose(1, 2)
+        k = k.view(b, t, self.num_key_value_heads, self.feature_dim).transpose(1, 2)
+        v = v.view(b, t, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        # Linear attention
+        q, k = self.feature_map(q), self.feature_map(k)
+        q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
+
+        # Compute attention
+        if self.causal:
+            y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
+        else:
+            y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
+        y = rearrange(y, 'b h t d -> b t (h d)')
+        y = self.o_proj(y.to(hidden_states.dtype))
+        y = self.dropout(y)
+        return y.to(hidden_states.dtype)
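
The closed form in forward_reference above, written out for a single head as a standalone sketch: causal linear attention keeps running prefix sums of k_s v_s^T and of k_s, then normalizes the query's dot product against both.

import torch

def causal_linear_attn_reference(q, k, v, eps=1e-12):
    # q, k: [batch, seq, feature_dim]; v: [batch, seq, head_dim] (one head, after the feature map)
    kv = torch.einsum('btf,btd->btfd', k, v).cumsum(1)               # running sum of k_s v_s^T
    num = torch.einsum('btf,btfd->btd', q, kv)                       # q_t . sum_{s<=t} k_s v_s^T
    den = torch.einsum('btf,btf->bt', q, k.cumsum(1)).unsqueeze(-1) + eps
    return num / den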
fla/layers/bitattn.py ADDED
@@ -0,0 +1,192 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from einops import rearrange
14
+ from transformers.utils import logging
15
+
16
+ from fla.modules import RotaryEmbedding
17
+ from fla.modules.fused_bitlinear import FusedBitLinear
18
+
19
+ if TYPE_CHECKING:
20
+ from fla.models.utils import Cache
21
+
22
+ try:
23
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
24
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
25
+ except ImportError:
26
+ warnings.warn(
27
+ "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
28
+ category=ImportWarning
29
+ )
30
+ flash_attn_func = None
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ class BitAttention(nn.Module):
36
+
37
+ def __init__(
38
+ self,
39
+ hidden_size: int = 2048,
40
+ num_heads: int = 32,
41
+ num_kv_heads: Optional[int] = None,
42
+ window_size: Optional[int] = None,
43
+ rope_theta: Optional[float] = 10000.,
44
+ max_position_embeddings: Optional[int] = None,
45
+ norm_eps: float = 1e-5,
46
+ layer_idx: int = None
47
+ ):
48
+ super().__init__()
49
+
50
+ self.num_heads = num_heads
51
+ if num_kv_heads is None:
52
+ self.num_kv_heads = self.num_heads
53
+ else:
54
+ self.num_kv_heads = num_kv_heads
55
+ self.num_kv_groups = num_heads // self.num_kv_heads
56
+ self.hidden_size = hidden_size
57
+ self.head_dim = self.hidden_size // self.num_heads
58
+ self.kv_dim = self.num_kv_heads * self.head_dim
60
+ self.window_size = window_size
61
+ self.rope_theta = rope_theta
62
+ self.max_position_embeddings = max_position_embeddings
63
+ self.layer_idx = layer_idx
64
+
65
+ self.q_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False)
66
+ self.k_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False)
67
+ self.v_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False)
68
+ self.o_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False)
69
+
70
+ self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
71
+
72
+ def forward(
73
+ self,
74
+ hidden_states: torch.Tensor,
75
+ attention_mask: Optional[torch.LongTensor] = None,
76
+ past_key_values: Optional[Cache] = None,
77
+ output_attentions: bool = False,
78
+ use_cache: bool = False,
79
+ **kwargs,
80
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
81
+ if attention_mask is not None:
82
+ assert len(attention_mask.shape) == 2, (
83
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
84
+ "for padding purposes (0 indicating padding). "
85
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
86
+ )
87
+
88
+ batch_size, q_len, _ = hidden_states.size()
89
+
90
+ q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
91
+ k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
92
+ v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
93
+
94
+ # equivalent to cu_seqlens in `flash_attn`
95
+ cu_seqlens = kwargs.get('cu_seqlens', None)
96
+
97
+ seqlen_offset, max_seqlen = 0, q_len
98
+ if past_key_values is not None:
99
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
100
+ max_seqlen = q.shape[1] + seqlen_offset
101
+
102
+ if attention_mask is not None:
103
+ # to account for the offsets introduced by padding tokens
104
+ seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
105
+ max_seqlen = q.shape[1] + max(seqlen_offset)
106
+
107
+ if self.max_position_embeddings is not None:
108
+ max_seqlen = max(max_seqlen, self.max_position_embeddings)
109
+ q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
110
+
111
+ if past_key_values is not None:
112
+ cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
113
+ k_cached, v_cached = past_key_values.update(
114
+ attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
115
+ layer_idx=self.layer_idx,
116
+ offset=q_len,
117
+ cache_kwargs=dict(window_size=self.window_size)
118
+ )['attn_state']
119
+ if cache_has_content:
120
+ k, v = k_cached, v_cached
121
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
122
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
123
+
124
+ if flash_attn_func is None:
125
+ raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
126
+
127
+ # Contains at least one padding token in the sequence
128
+ if attention_mask is not None:
129
+ q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
130
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
131
+ max_seqlen_q, max_seqlen_k = max_seq_lens
132
+ o = flash_attn_varlen_func(
133
+ q, k, v,
134
+ cu_seqlens_q=cu_seqlens_q,
135
+ cu_seqlens_k=cu_seqlens_k,
136
+ max_seqlen_q=max_seqlen_q,
137
+ max_seqlen_k=max_seqlen_k,
138
+ causal=True,
139
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
140
+ )
141
+ o = pad_input(o, indices_q, batch_size, q_len)
142
+ elif cu_seqlens is not None:
143
+ o = flash_attn_varlen_func(
144
+ q.squeeze(0), k.squeeze(0), v.squeeze(0),
145
+ cu_seqlens_q=cu_seqlens,
146
+ cu_seqlens_k=cu_seqlens,
147
+ max_seqlen_q=max_seqlen,
148
+ max_seqlen_k=max_seqlen,
149
+ causal=True,
150
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
151
+ ).unsqueeze(0)
152
+ else:
153
+ o = flash_attn_func(
154
+ q, k, v,
155
+ causal=True,
156
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
157
+ )
158
+ o = o.reshape(batch_size, q_len, -1)
159
+ o = self.o_proj(o)
160
+
161
+ if not output_attentions:
162
+ attentions = None
163
+
164
+ return o, attentions, past_key_values
165
+
166
+ def _upad_input(self, q, k, v, attention_mask, q_len):
167
+ batch_size, seq_len, num_key_value_heads, head_dim = k.shape
168
+ cache_mask = attention_mask[:, -seq_len:]
169
+ seqlens = cache_mask.sum(-1, dtype=torch.int32)
170
+ indices_k = torch.nonzero(cache_mask.flatten(), as_tuple=False).flatten()
171
+ max_seqlen_k = seqlens.max().item()
172
+ cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
173
+
174
+ k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
175
+ v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
176
+ if q_len == seq_len:
177
+ q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
178
+ cu_seqlens_q = cu_seqlens_k
179
+ max_seqlen_q = max_seqlen_k
180
+ indices_q = indices_k
181
+ elif q_len == 1:
182
+ max_seqlen_q = 1
183
+ # There is a memcpy here, that is very bad.
184
+ cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
185
+ indices_q = cu_seqlens_q[:-1]
186
+ q = q.squeeze(1)
187
+ else:
188
+ # The -q_len: slice assumes left padding.
189
+ attention_mask = attention_mask[:, -q_len:]
190
+ q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)
191
+
192
+ return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
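A minimal usage sketch for the BitAttention layer above, not part of the diff: it assumes the fla package plus flash-attn are installed and a CUDA device is available (FusedBitLinear and flash_attn_func both run on GPU). The sizes are illustrative only.

import torch
from fla.layers.bitattn import BitAttention

attn = BitAttention(hidden_size=1024, num_heads=16, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
# With no attention_mask, no cu_seqlens and no cache, forward takes the plain flash_attn_func path.
o, attentions, past_key_values = attn(x)
assert o.shape == x.shape  # (2, 128, 1024)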
fla/layers/delta_net.py ADDED
@@ -0,0 +1,291 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange
11
+ from torch.nn import functional as F
12
+
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
14
+ from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule
15
+
16
+ if TYPE_CHECKING:
17
+ from transformers.processing_utils import Unpack
18
+
19
+ from fla.models.utils import Cache
20
+
21
+
22
+ def elu_p1(x):
23
+ return (F.elu(x, 1., False) + 1.).to(x)
24
+
25
+
26
+ def sum_norm(x):
27
+ return (x / x.sum(-1, keepdim=True)).to(x)
28
+
29
+
30
+ class DeltaNet(nn.Module):
31
+ r"""
32
+ The layer implementation for [Parallelizing Linear Transformers with the Delta Rule over Sequence Length](https://arxiv.org/abs/2406.06484). # noqa:
33
+ DeltaNet was originally proposed in [Linear Transformers Are Secretly Fast Weight Programmers](https://arxiv.org/abs/2102.11174). # noqa
34
+
35
+ Args:
36
+ mode (str, Optional):
37
+ Which DeltaNet kernel to use.
38
+ Currently available: `chunk` and `fused_recurrent` (`fused_chunk` is deprecated).
39
+ Default: `chunk`.
40
+ hidden_size (int, Optional):
41
+ The hidden size of the input. Default: 1024.
42
+ expand_k (float, Optional):
43
+ The expansion ratio for the key dim. Default: 1.0.
44
+ expand_v (float, Optional):
45
+ The expansion ratio for the value dim. Default: 1.0.
46
+ num_heads (int, Optional):
47
+ The number of heads. Default: 4.
48
+ use_beta (bool, Optional):
49
+ Whether to use beta. Default: `True`.
50
+ use_gate (bool, Optional):
51
+ Whether to use output gate. Default: `False`.
52
+ use_short_conv (bool, Optional):
53
+ Whether to use short convolutions. Default: `True`.
54
+ conv_size (int, Optional):
55
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
56
+ conv_bias (bool, Optional):
57
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
58
+ allow_neg_eigval (bool, Optional):
59
+ Allow negative eigenvalues. Default: `False`. If set to `True`, the beta will be multiplied by 2.
60
+ See reference: [Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues](https://arxiv.org/abs/2411.12537)
61
+ layer_idx (int, Optional):
62
+ The index of the layer. Default: None.
63
+ norm_eps (float, Optional):
64
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
65
+ qk_activation (str, Optional):
66
+ The activation function for the query and key. Default: `silu`.
67
+ qk_norm (str, Optional):
68
+ The normalization method for the query and key. Default: `l2`.
69
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ mode: str = 'chunk',
74
+ d_model: int = None,
75
+ hidden_size: int = 1024,
76
+ expand_k: float = 1.0,
77
+ expand_v: float = 1.0,
78
+ num_heads: int = 4,
79
+ use_beta: bool = True,
80
+ use_gate: bool = False,
81
+ use_short_conv: bool = True,
82
+ conv_size: int = 4,
83
+ conv_bias: bool = False,
84
+ allow_neg_eigval: bool = False,
85
+ layer_idx: int = None,
86
+ qk_activation: str = 'silu',
87
+ qk_norm: str = 'l2',
88
+ norm_eps: float = 1e-5,
89
+ **kwargs
90
+ ) -> DeltaNet:
91
+ super().__init__()
92
+
93
+ self.mode = mode
94
+ self.qk_activation = qk_activation
95
+ self.qk_norm = qk_norm
96
+
97
+ assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
98
+ assert self.qk_norm in ['l2', 'sum']
99
+
100
+ if d_model is not None:
101
+ hidden_size = d_model
102
+ self.hidden_size = hidden_size
103
+ self.expand_k = expand_k
104
+ self.expand_v = expand_v
105
+ self.num_heads = num_heads
106
+ self.use_gate = use_gate
107
+ self.use_short_conv = use_short_conv
108
+ self.conv_size = conv_size
109
+ self.conv_bias = conv_bias
110
+ self.allow_neg_eigval = allow_neg_eigval
111
+
112
+ self.key_dim = int(hidden_size * expand_k)
113
+ self.value_dim = int(hidden_size * expand_v)
114
+ self.head_k_dim = self.key_dim // num_heads
115
+ self.head_v_dim = self.value_dim // num_heads
116
+ self.layer_idx = layer_idx
117
+
118
+ self.silu = nn.SiLU()
119
+ if mode == 'fused_chunk':
120
+ raise NotImplementedError("fused_chunk_delta_rule is now deprecated. Please use `chunk_delta_rule` instead.")
121
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
122
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
123
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
124
+
125
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
126
+ self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
127
+ self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
128
+
129
+ self.use_beta = use_beta
130
+ if self.use_beta:
131
+ self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
132
+ if use_short_conv:
133
+ self.conv_size = conv_size
134
+ self.q_conv1d = ShortConvolution(
135
+ hidden_size=self.key_dim,
136
+ kernel_size=conv_size,
137
+ activation='silu' if qk_activation == 'silu' else None
138
+ )
139
+ self.k_conv1d = ShortConvolution(
140
+ hidden_size=self.key_dim,
141
+ kernel_size=conv_size,
142
+ activation='silu' if qk_activation == 'silu' else None
143
+ )
144
+ self.v_conv1d = ShortConvolution(
145
+ hidden_size=self.value_dim,
146
+ kernel_size=conv_size,
147
+ activation='silu'
148
+ )
149
+ else:
150
+ raise UserWarning(
151
+ "ShortConvolution is crucial to the performance. "
152
+ "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
153
+ )
154
+ if use_gate:
155
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
156
+ self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
157
+ else:
158
+ self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
159
+
160
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
161
+
162
+ def forward(
163
+ self,
164
+ hidden_states: torch.Tensor,
165
+ attention_mask: Optional[torch.Tensor] = None,
166
+ past_key_values: Optional[Cache] = None,
167
+ use_cache: Optional[bool] = False,
168
+ output_attentions: Optional[bool] = False,
169
+ **kwargs: Unpack[Dict]
170
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
171
+ if attention_mask is not None:
172
+ assert len(attention_mask.shape) == 2, (
173
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
174
+ "for padding purposes (0 indicating padding). "
175
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
176
+ )
177
+
178
+ # fall back to the fused recurrent kernel for short sequences (e.g., single-token decoding).
179
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
180
+
181
+ last_state = None
182
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
183
+ last_state = past_key_values[self.layer_idx]
184
+
185
+ cu_seqlens = kwargs.get('cu_seqlens', None)
186
+ if self.use_short_conv:
187
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
188
+ if last_state is not None:
189
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
190
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
191
+ q, conv_state_q = self.q_conv1d(
192
+ x=self.q_proj(hidden_states),
193
+ mask=conv_mask,
194
+ cache=conv_state_q,
195
+ output_final_state=use_cache,
196
+ cu_seqlens=cu_seqlens
197
+ )
198
+ k, conv_state_k = self.k_conv1d(
199
+ x=self.k_proj(hidden_states),
200
+ mask=conv_mask,
201
+ cache=conv_state_k,
202
+ output_final_state=use_cache,
203
+ cu_seqlens=cu_seqlens
204
+ )
205
+ v, conv_state_v = self.v_conv1d(
206
+ x=self.v_proj(hidden_states),
207
+ mask=conv_mask,
208
+ cache=conv_state_v,
209
+ output_final_state=use_cache,
210
+ cu_seqlens=cu_seqlens
211
+ )
212
+ else:
213
+ q = self.q_proj(hidden_states)
214
+ k = self.k_proj(hidden_states)
215
+ if self.qk_activation == 'silu':
216
+ q, k = self.silu(q), self.silu(k)
217
+ v = self.silu(self.v_proj(hidden_states))
218
+
219
+ q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
220
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
221
+ if self.qk_activation != 'silu':
222
+ if self.qk_activation == 'relu':
223
+ q, k = q.relu(), k.relu()
224
+ elif self.qk_activation == 'elu':
225
+ q, k = elu_p1(q), elu_p1(k)
226
+ elif self.qk_activation == 'identity':
227
+ pass
228
+ else:
229
+ raise NotImplementedError
230
+
231
+ if self.qk_norm == 'sum':
232
+ q = sum_norm(q).to(q)
233
+ k = sum_norm(k).to(k)
234
+
235
+ if self.use_beta:
236
+ beta = self.b_proj(hidden_states).sigmoid()
237
+ else:
238
+ beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])
239
+
240
+ if self.allow_neg_eigval:
241
+ beta = beta * 2.
242
+
243
+ # dealing with padding
244
+ if attention_mask is not None:
245
+ beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])
246
+
247
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
248
+ if mode == 'fused_recurrent':
249
+ o, recurrent_state = fused_recurrent_delta_rule(
250
+ q=q,
251
+ k=k,
252
+ v=v,
253
+ beta=beta,
254
+ initial_state=recurrent_state,
255
+ output_final_state=use_cache,
256
+ cu_seqlens=cu_seqlens,
257
+ head_first=False,
258
+ use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
259
+ )
260
+ elif mode == 'chunk':
261
+ o, recurrent_state = chunk_delta_rule(
262
+ q=q,
263
+ k=k,
264
+ v=v,
265
+ beta=beta,
266
+ initial_state=recurrent_state,
267
+ output_final_state=use_cache,
268
+ cu_seqlens=cu_seqlens,
269
+ head_first=False,
270
+ use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
271
+ )
272
+ else:
273
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
274
+
275
+ if past_key_values is not None:
276
+ past_key_values.update(
277
+ recurrent_state=recurrent_state,
278
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
279
+ layer_idx=self.layer_idx,
280
+ offset=q.shape[1]
281
+ )
282
+
283
+ if self.use_gate:
284
+ g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
285
+ o = self.o_norm(o, g)
286
+ else:
287
+ o = self.o_norm(o)
288
+ o = rearrange(o, 'b t h d -> b t (h d)')
289
+ o = self.o_proj(o)
290
+
291
+ return o, None, past_key_values
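A minimal usage sketch for the DeltaNet layer above, not part of the diff: it assumes the fla package is installed and a CUDA device with Triton is available, since chunk_delta_rule and the short convolutions are Triton kernels. The sizes are illustrative.

import torch
from fla.layers.delta_net import DeltaNet

# hidden_size=1024 with num_heads=8 gives 128-dim key/value heads.
layer = DeltaNet(hidden_size=1024, num_heads=8, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 256, 1024, device='cuda', dtype=torch.bfloat16)
# Sequences longer than 64 tokens use the `chunk` kernel; shorter ones fall back to `fused_recurrent`.
o, _, past_key_values = layer(x)
assert o.shape == x.shape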
fla/layers/forgetting_attn.py ADDED
@@ -0,0 +1,109 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from einops import rearrange
13
+ from transformers.utils import logging
14
+
15
+ from fla.modules import GroupNorm
16
+ from fla.ops.forgetting_attn.parallel import parallel_forgetting_attn
17
+
18
+ if TYPE_CHECKING:
19
+ from fla.models.utils import Cache
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class ForgettingAttention(nn.Module):
26
+
27
+ def __init__(
28
+ self,
29
+ hidden_size: int = 2048,
30
+ num_heads: int = 32,
31
+ num_kv_heads: Optional[int] = None,
32
+ qkv_bias: bool = False,
33
+ qk_norm: bool = False,
34
+ window_size: Optional[int] = None,
35
+ use_output_gate: bool = False,
36
+ layer_idx: int = None
37
+ ):
38
+ super().__init__()
39
+
40
+ self.hidden_size = hidden_size
41
+ self.num_heads = num_heads
42
+ if num_kv_heads is None:
43
+ self.num_kv_heads = self.num_heads
44
+ else:
45
+ self.num_kv_heads = num_kv_heads
46
+ self.num_kv_groups = num_heads // self.num_kv_heads
47
+ self.head_dim = self.hidden_size // self.num_heads
48
+ self.kv_dim = self.num_kv_heads * self.head_dim
49
+ self.qkv_bias = qkv_bias
50
+ self.qk_norm = qk_norm
51
+
52
+ self.window_size = window_size
53
+ self.use_output_gate = use_output_gate
54
+ self.layer_idx = layer_idx
55
+
56
+ self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
57
+ self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
58
+ self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
59
+ self.f_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
60
+
61
+ if use_output_gate:
62
+ self.g_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
63
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
64
+
65
+ if qk_norm:
66
+ self.q_norm = GroupNorm(
67
+ num_groups=self.num_heads,
68
+ hidden_size=self.hidden_size,
69
+ is_rms_norm=True,
70
+ )
71
+ self.k_norm = GroupNorm(
72
+ num_groups=self.num_kv_heads,
73
+ hidden_size=self.kv_dim,
74
+ is_rms_norm=True,
75
+ )
76
+
77
+ def forward(
78
+ self,
79
+ hidden_states: torch.Tensor,
80
+ attention_mask: Optional[torch.LongTensor] = None,
81
+ past_key_values: Optional[Cache] = None,
82
+ output_attentions: bool = False,
83
+ use_cache: bool = False,
84
+ **kwargs,
85
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
86
+ if attention_mask is not None:
87
+ assert len(attention_mask.shape) == 2, (
88
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
89
+ "for padding purposes (0 indicating padding). "
90
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
91
+ )
92
+
93
+ cu_seqlens = kwargs.get('cu_seqlens', None)
94
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
95
+ f = F.logsigmoid(self.f_proj(hidden_states).float())
96
+ if self.qk_norm:
97
+ q, k = self.q_norm(q), self.k_norm(k)
98
+
99
+ q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
100
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
101
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
102
+
103
+ o = parallel_forgetting_attn(q, k, v, f, cu_seqlens=cu_seqlens)
104
+ o = rearrange(o, '... h d -> ... (h d)')
105
+ if self.use_output_gate:
106
+ o = self.g_proj(hidden_states).sigmoid() * o
107
+ o = self.o_proj(o)
108
+
109
+ return o, None, past_key_values
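A minimal usage sketch for the ForgettingAttention layer above, not part of the diff: it assumes fla is installed and a CUDA device with Triton is available for parallel_forgetting_attn; the sizes are illustrative.

import torch
from fla.layers.forgetting_attn import ForgettingAttention

attn = ForgettingAttention(hidden_size=1024, num_heads=8, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
# f = logsigmoid(f_proj(x)) supplies one log-forget-gate per head and time step.
o, _, _ = attn(x)
assert o.shape == x.shape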
fla/layers/gated_deltanet.py ADDED
@@ -0,0 +1,293 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import math
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+ from torch.nn import functional as F
13
+
14
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
15
+ from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
16
+
17
+ if TYPE_CHECKING:
18
+ from transformers.processing_utils import Unpack
19
+
20
+ from fla.models.utils import Cache
21
+
22
+
23
+ @torch.compile
24
+ def elu_p1(x):
25
+ return (F.elu(x, 1., False) + 1.).to(x)
26
+
27
+
28
+ @torch.compile
29
+ def sum_norm(x):
30
+ return (x / x.sum(-1, keepdim=True)).to(x)
31
+
32
+
33
+ class GatedDeltaNet(nn.Module):
34
+ """
35
+ The layer implementation for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa
36
+
37
+ Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters.
38
+
39
+ Parameter allocation when use_gate=True:
40
+ - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each
41
+ - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each
42
+ - Others are ignorably small.
43
+ - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size
44
+ NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim.
45
+
46
+ Parameter allocation when use_gate=False:
47
+ - 1 * hidden_size * hidden_size for the q_proj and k_proj each
48
+ - 2 * hidden_size * hidden_size for the v_proj and o_proj each
49
+ - Others are ignorably small.
50
+ - In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size
51
+
52
+ Args:
53
+ hidden_size (int, Optional):
54
+ The hidden size of the input. Default: 2048.
55
+ expand_v (float, Optional):
56
+ The expansion ratio for the value dim. Default: 2.0.
57
+ head_dim (int, Optional):
58
+ The dimension of each head. Default: 256.
59
+ num_heads (int, Optional):
60
+ The number of heads. Default: 4.
61
+ mode (str, Optional):
62
+ Which Gated DeltaNet kernel to use.
63
+ Currently available: `chunk` and `fused_recurrent`.
64
+ Default: `chunk`.
65
+ use_beta (bool, Optional):
66
+ Whether to use beta. Default: `True`.
67
+ use_gate (bool, Optional):
68
+ Whether to use output gate. Default: `True`.
69
+ use_short_conv (bool, Optional):
70
+ Whether to use short convolutions. Default: `True`.
71
+ conv_size (int, Optional):
72
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
73
+ conv_bias (bool, Optional):
74
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
75
+ layer_idx (int, Optional):
76
+ The index of the layer. Default: None.
77
+ norm_eps (float, Optional):
78
+ The epsilon value for the normalization layer. Default: 1e-5.
79
+ """
80
+
81
+ def __init__(
82
+ self,
83
+ hidden_size: int = 2048,
84
+ expand_v: float = 2,
85
+ head_dim: int = 256,
86
+ num_heads: int = 6,
87
+ mode: str = 'chunk',
88
+ use_gate: bool = True,
89
+ use_short_conv: bool = True,
90
+ conv_size: int = 4,
91
+ conv_bias: bool = False,
92
+ layer_idx: int = None,
93
+ norm_eps: float = 1e-5,
94
+ **kwargs
95
+ ) -> GatedDeltaNet:
96
+ super().__init__()
97
+
98
+ self.mode = mode
99
+
100
+ self.hidden_size = hidden_size
101
+ self.expand_v = expand_v
102
+
103
+ self.use_gate = use_gate
104
+ self.use_short_conv = use_short_conv
105
+ self.conv_size = conv_size
106
+ self.conv_bias = conv_bias
107
+
108
+ self.head_dim = head_dim
109
+ self.num_heads = num_heads
110
+
111
+ self.key_dim = int(self.num_heads * self.head_dim)
112
+ self.value_dim = int(self.key_dim * self.expand_v)
113
+ self.head_k_dim = head_dim
114
+ self.head_v_dim = int(head_dim * self.expand_v)
115
+ self.layer_idx = layer_idx
116
+
117
+ # Consistency check: Ensure expand_v produces integer values
118
+ if not math.isclose(self.key_dim * expand_v, self.value_dim, rel_tol=1e-5):
119
+ raise ValueError(
120
+ f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
121
+ f"Resulting value_dim would be {self.key_dim * expand_v}, which is invalid for nn.Linear."
122
+ )
123
+ if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5):
124
+ raise ValueError(
125
+ f"expand_v={expand_v} does not produce an integer value when multiplied by head_dim={head_dim}. "
126
+ f"Resulting head_v_dim would be {head_dim * expand_v}, which is invalid for FusedRMSNormGated."
127
+ )
128
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
129
+
130
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
131
+ self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
132
+ self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
133
+ self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
134
+ self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
135
+
136
+ A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
137
+ self.A_log = nn.Parameter(torch.log(A))
138
+ self.A_log._no_weight_decay = True
139
+ # hard coded for now
140
+ dt_min = 0.001
141
+ dt_max = 0.1
142
+ dt_init_floor = 1e-4
143
+ dt = torch.exp(
144
+ torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min))
145
+ + math.log(dt_min)
146
+ )
147
+ dt = torch.clamp(dt, min=dt_init_floor)
148
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
149
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
150
+ self.dt_bias = nn.Parameter(inv_dt)
151
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
152
+ # name.endswith("bias") in param_grouping.py
153
+ self.dt_bias._no_weight_decay = True
154
+
155
+ if use_short_conv:
156
+ self.conv_size = conv_size
157
+ self.q_conv1d = ShortConvolution(
158
+ hidden_size=self.key_dim,
159
+ kernel_size=conv_size,
160
+ activation='silu'
161
+ )
162
+ self.k_conv1d = ShortConvolution(
163
+ hidden_size=self.key_dim,
164
+ kernel_size=conv_size,
165
+ activation='silu'
166
+ )
167
+ self.v_conv1d = ShortConvolution(
168
+ hidden_size=self.value_dim,
169
+ kernel_size=conv_size,
170
+ activation='silu'
171
+ )
172
+ else:
173
+ raise UserWarning(
174
+ "ShortConvolution is crucial to the performance. "
175
+ "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
176
+ )
177
+ if use_gate:
178
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
179
+ self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
180
+ else:
181
+ self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
182
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
183
+
184
+ def forward(
185
+ self,
186
+ hidden_states: torch.Tensor,
187
+ attention_mask: Optional[torch.Tensor] = None,
188
+ past_key_values: Optional[Cache] = None,
189
+ use_cache: Optional[bool] = False,
190
+ output_attentions: Optional[bool] = False,
191
+ **kwargs: Unpack[Dict]
192
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
193
+ if attention_mask is not None:
194
+ assert len(attention_mask.shape) == 2, (
195
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
196
+ "for padding purposes (0 indicating padding). "
197
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
198
+ )
199
+
200
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
201
+ if self.training:
202
+ assert mode == 'chunk', "Only chunk mode is supported in training."
203
+
204
+ last_state = None
205
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
206
+ last_state = past_key_values[self.layer_idx]
207
+
208
+ cu_seqlens = kwargs.get('cu_seqlens', None)
209
+ if self.use_short_conv:
210
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
211
+ if last_state is not None:
212
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
213
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
214
+ q, conv_state_q = self.q_conv1d(
215
+ x=self.q_proj(hidden_states),
216
+ mask=conv_mask,
217
+ cache=conv_state_q,
218
+ output_final_state=use_cache,
219
+ cu_seqlens=cu_seqlens
220
+ )
221
+ k, conv_state_k = self.k_conv1d(
222
+ x=self.k_proj(hidden_states),
223
+ mask=conv_mask,
224
+ cache=conv_state_k,
225
+ output_final_state=use_cache,
226
+ cu_seqlens=cu_seqlens
227
+ )
228
+ v, conv_state_v = self.v_conv1d(
229
+ x=self.v_proj(hidden_states),
230
+ mask=conv_mask,
231
+ cache=conv_state_v,
232
+ output_final_state=use_cache,
233
+ cu_seqlens=cu_seqlens
234
+ )
235
+ else:
236
+ q = F.silu(self.q_proj(hidden_states))
237
+ k = F.silu(self.k_proj(hidden_states))
238
+ v = F.silu(self.v_proj(hidden_states))
239
+
240
+ q, k = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim), (q, k))
241
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
242
+ beta = self.b_proj(hidden_states).sigmoid()
243
+ g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias)
244
+
245
+ # dealing with padding
246
+ if attention_mask is not None:
247
+ beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])
248
+ g = g.mul(attention_mask[:, -g.shape[-2]:, None])
249
+
250
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
251
+ if mode == 'chunk':
252
+ o, recurrent_state = chunk_gated_delta_rule(
253
+ q=q,
254
+ k=k,
255
+ v=v,
256
+ g=g,
257
+ beta=beta,
258
+ initial_state=recurrent_state,
259
+ output_final_state=use_cache,
260
+ cu_seqlens=cu_seqlens,
261
+ head_first=False,
262
+ use_qk_l2norm_in_kernel=True
263
+ )
264
+ elif mode == 'fused_recurrent':
265
+ o, recurrent_state = fused_recurrent_gated_delta_rule(
266
+ q=q,
267
+ k=k,
268
+ v=v,
269
+ g=g,
270
+ beta=beta,
271
+ initial_state=recurrent_state,
272
+ output_final_state=use_cache,
273
+ cu_seqlens=cu_seqlens,
274
+ head_first=False,
275
+ use_qk_l2norm_in_kernel=True
276
+ )
277
+ if past_key_values is not None:
278
+ past_key_values.update(
279
+ recurrent_state=recurrent_state,
280
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
281
+ layer_idx=self.layer_idx,
282
+ offset=q.shape[1]
283
+ )
284
+
285
+ if self.use_gate:
286
+ g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
287
+ o = self.o_norm(o, g)
288
+ else:
289
+ o = self.o_norm(o)
290
+ o = rearrange(o, 'b t h d -> b t (h d)')
291
+ o = self.o_proj(o)
292
+
293
+ return o, None, past_key_values
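The docstring above requires num_heads * head_dim = 0.75 * hidden_size when use_gate=True. The following configuration sketch satisfies this relation (6 * 128 = 768 = 0.75 * 1024). It is not part of the diff and assumes fla is installed and a CUDA device with Triton is available.

import torch
from fla.layers.gated_deltanet import GatedDeltaNet

layer = GatedDeltaNet(hidden_size=1024, head_dim=128, num_heads=6, expand_v=2, layer_idx=0)
layer = layer.cuda().to(torch.bfloat16)
x = torch.randn(2, 256, 1024, device='cuda', dtype=torch.bfloat16)
# key_dim = 6 * 128 = 768, value_dim = 768 * 2 = 1536; o_proj maps 1536 back to 1024.
o, _, _ = layer(x)
assert o.shape == x.shape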
fla/layers/gated_deltaproduct.py ADDED
@@ -0,0 +1,351 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+
11
+ from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
12
+ from fla.ops.delta_rule import chunk_delta_rule
13
+ from fla.ops.gated_delta_rule import chunk_gated_delta_rule
14
+
15
+ if TYPE_CHECKING:
16
+ from transformers.processing_utils import Unpack
17
+
18
+ from fla.models.utils import Cache
19
+
20
+
21
+ def elu_p1(x):
22
+ return (F.elu(x, 1.0, False) + 1.0).to(x)
23
+
24
+
25
+ def sum_norm(x):
26
+ return (x / x.sum(-1, keepdim=True)).to(x)
27
+
28
+
29
+ def interleave_multiple_sequences(*sequences):
30
+ """
31
+ Interleave multiple sequences together.
32
+ For example, with sequences [A1, A2], [B1, B2], [C1, C2],
33
+ returns [A1, B1, C1, A2, B2, C2]
34
+ """
35
+ if isinstance(sequences[0], (list, tuple)):
36
+ sequences = sequences[0]
37
+
38
+ if len(sequences) == 1:
39
+ return sequences[0]
40
+
41
+ # All sequences should have the same shape
42
+ assert all(s.shape == sequences[0].shape for s in sequences)
43
+
44
+ # Get the original shape
45
+ batch_size, seq_len, *rest = sequences[0].shape
46
+
47
+ # Stack sequences along a new dimension
48
+ stacked = torch.stack(sequences, dim=2)
49
+
50
+ # Reshape to interleave
51
+ reshaped = stacked.view(batch_size, seq_len * len(sequences), *rest)
52
+
53
+ return reshaped
54
+
55
+
56
+ class GatedDeltaProduct(nn.Module):
57
+ """
58
+ Generalized version of GatedDoubleDeltaNet that supports an arbitrary number of Householder transformations.
59
+ """
60
+
61
+ def __init__(
62
+ self,
63
+ hidden_size: int = 2048,
64
+ expand_v: float = 2,
65
+ head_dim: int = 256,
66
+ num_heads: int = 6,
67
+ num_householder: int = 2, # New parameter for number of householder transformations
68
+ mode: str = "chunk",
69
+ use_gate: bool = True,
70
+ use_forget_gate: bool = True, # when true Gated DeltaProduct, when false DeltaProduct
71
+ use_short_conv: bool = True,
72
+ conv_size: int = 4,
73
+ conv_bias: bool = False,
74
+ layer_idx: int | None = None,
75
+ norm_eps: float = 1e-5,
76
+ allow_neg_eigval: bool = False, # when true (Gated) DeltaProduct [-1, 1], when false (Gated) DeltaProduct [0, 1]
77
+ **kwargs,
78
+ ) -> None:
79
+ super().__init__()
80
+
81
+ self.mode = mode
82
+ self.hidden_size = hidden_size
83
+ self.expand_v = expand_v
84
+ self.use_gate = use_gate
85
+ self.use_short_conv = use_short_conv
86
+ self.conv_size = conv_size
87
+ self.conv_bias = conv_bias
88
+ self.head_dim = head_dim
89
+ self.num_heads = num_heads
90
+ self.num_householder = num_householder
91
+ self.allow_neg_eigval = allow_neg_eigval
92
+ self.use_forget_gate = use_forget_gate
93
+ self.key_dim = self.num_heads * self.head_dim
94
+ self.value_dim = int(self.key_dim * self.expand_v)
95
+ self.head_qk_dim = head_dim
96
+ self.head_v_dim = int(head_dim * self.expand_v)
97
+ self.layer_idx = layer_idx
98
+ self.silu = nn.SiLU()
99
+ assert mode in ["chunk", "fused_recurrent"], f"Not supported mode `{mode}`."
100
+ # Create multiple projection layers for each householder transformation
101
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
102
+
103
+ self.k_projs = nn.ModuleList(
104
+ [
105
+ nn.Linear(hidden_size, self.key_dim, bias=False)
106
+ for _ in range(num_householder)
107
+ ]
108
+ )
109
+ self.v_projs = nn.ModuleList(
110
+ [
111
+ nn.Linear(hidden_size, self.value_dim, bias=False)
112
+ for _ in range(num_householder)
113
+ ]
114
+ )
115
+ self.b_projs = nn.ModuleList(
116
+ [
117
+ nn.Linear(hidden_size, self.num_heads, bias=False)
118
+ for _ in range(num_householder)
119
+ ]
120
+ )
121
+ if use_short_conv:
122
+ self.q_conv1ds = nn.ModuleList(
123
+ [
124
+ ShortConvolution(
125
+ hidden_size=self.key_dim,
126
+ kernel_size=conv_size,
127
+ activation="silu",
128
+ )
129
+ for _ in range(num_householder)
130
+ ]
131
+ )
132
+ self.k_conv1ds = nn.ModuleList(
133
+ [
134
+ ShortConvolution(
135
+ hidden_size=self.key_dim,
136
+ kernel_size=conv_size,
137
+ activation="silu",
138
+ )
139
+ for _ in range(num_householder)
140
+ ]
141
+ )
142
+ self.v_conv1ds = nn.ModuleList(
143
+ [
144
+ ShortConvolution(
145
+ hidden_size=self.value_dim,
146
+ kernel_size=conv_size,
147
+ activation="silu",
148
+ )
149
+ for _ in range(num_householder)
150
+ ]
151
+ )
152
+
153
+ if self.use_forget_gate:
154
+ self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
155
+ A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
156
+ A_log = torch.log(A)
157
+ self.A_log = nn.Parameter(A_log)
158
+ self.A_log._no_weight_decay = True
159
+
160
+ # Initialize dt parameters
161
+ dt_min = 0.001
162
+ dt_max = 0.1
163
+ dt_init_floor = 1e-4
164
+ dt = torch.exp(
165
+ torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min))
166
+ + math.log(dt_min)
167
+ )
168
+ dt = torch.clamp(dt, min=dt_init_floor)
169
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
170
+ self.dt_bias = nn.Parameter(inv_dt)
171
+ self.dt_bias._no_weight_decay = True
172
+
173
+ if use_gate:
174
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
175
+ self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps)
176
+ else:
177
+ self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
178
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
179
+ self.k_id = torch.nn.Identity()
180
+ self.apply(self._initialize_weights)
181
+
182
+ def _initialize_weights(self, module: nn.Module):
183
+ if getattr(module, "_is_hf_initialized", False):
184
+ return
185
+ if isinstance(module, nn.Linear):
186
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
187
+ if module.bias is not None:
188
+ nn.init.zeros_(module.bias)
189
+ module._is_hf_initialized = True
190
+
191
+ def forward(
192
+ self,
193
+ hidden_states: torch.Tensor,
194
+ attention_mask: Optional[torch.Tensor] = None,
195
+ past_key_values: Optional[Cache] = None,
196
+ use_cache: Optional[bool] = False,
197
+ output_attentions: Optional[bool] = False,
198
+ **kwargs: Unpack[Dict],
199
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
200
+ if attention_mask is not None:
201
+ assert len(attention_mask.shape) == 2, (
202
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
203
+ "for padding purposes (0 indicating padding)."
204
+ )
205
+
206
+ mode = (
207
+ "chunk" # 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
208
+ )
209
+ if self.training:
210
+ assert mode == "chunk", "Only chunk mode is supported in training."
211
+
212
+ last_state = None
213
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
214
+ last_state = past_key_values[self.layer_idx]
215
+
216
+ # Process each householder transformation
217
+ ks, vs, betas = [], [], []
218
+ conv_states = []
219
+
220
+ for i in range(self.num_householder):
221
+ if self.use_short_conv:
222
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
223
+ if last_state is not None:
224
+ conv_state_q, conv_state_k, conv_state_v = last_state["conv_state"][
225
+ i
226
+ ]
227
+ conv_mask = (
228
+ attention_mask[:, -hidden_states.shape[1]:]
229
+ if attention_mask is not None
230
+ else None
231
+ )
232
+
233
+ k, conv_state_k = self.k_conv1ds[i](
234
+ x=self.k_projs[i](hidden_states),
235
+ mask=conv_mask,
236
+ cache=conv_state_k,
237
+ output_final_state=use_cache,
238
+ )
239
+ v, conv_state_v = self.v_conv1ds[i](
240
+ x=self.v_projs[i](hidden_states),
241
+ mask=conv_mask,
242
+ cache=conv_state_v,
243
+ output_final_state=use_cache,
244
+ )
245
+ conv_states.append((conv_state_q, conv_state_k, conv_state_v))
246
+ else:
247
+ k = self.silu(self.k_projs[i](hidden_states))
248
+ v = self.silu(self.v_projs[i](hidden_states))
249
+
250
+ ks.append(k)
251
+ vs.append(v)
252
+
253
+ beta = self.b_projs[i](
254
+ hidden_states
255
+ ).sigmoid() # bs, sequence_length, num_heads
256
+ if attention_mask is not None:
257
+ beta = beta.mul(attention_mask[:, -hidden_states.shape[1]:, None])
258
+ if self.allow_neg_eigval:
259
+ beta = beta * 2
260
+ betas.append(beta)
261
+
262
+ if self.use_short_conv:
263
+ q, conv_state_q = self.q_conv1ds[0](
264
+ x=self.q_proj(hidden_states),
265
+ mask=conv_mask,
266
+ cache=conv_state_q,
267
+ output_final_state=use_cache,
268
+ )
269
+ else:
270
+ q = self.silu(self.q_proj(hidden_states))
271
+ q = interleave_multiple_sequences(
272
+ [torch.zeros_like(q)] * (self.num_householder - 1) + [q]
273
+ )
274
+ # Interleave all sequences
275
+ k = interleave_multiple_sequences(ks)
276
+ v = interleave_multiple_sequences(vs)
277
+ beta = interleave_multiple_sequences(betas)
278
+
279
+ q, k, v = (
280
+ rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in (q, k, v)
281
+ )
282
+
283
+ recurrent_state = (
284
+ last_state["recurrent_state"] if last_state is not None else None
285
+ )
286
+ offsets = kwargs.get("offsets")
287
+
288
+ if mode == "chunk":
289
+ if self.use_forget_gate:
290
+ g = -self.A_log.float().exp() * F.softplus(
291
+ self.a_proj(hidden_states).float() + self.dt_bias
292
+ )
293
+ if attention_mask is not None:
294
+ g = g.mul(attention_mask[:, -g.shape[-2]:, None])
295
+
296
+ # Interleave g with zeros for non-first transformations
297
+ g = interleave_multiple_sequences(
298
+ [g] + [torch.zeros_like(g)] * (self.num_householder - 1)
299
+ )
300
+
301
+ o, recurrent_state = chunk_gated_delta_rule(
302
+ q=q,
303
+ k=k,
304
+ v=v,
305
+ g=g,
306
+ beta=beta,
307
+ initial_state=recurrent_state,
308
+ output_final_state=use_cache,
309
+ cu_seqlens=offsets,
310
+ head_first=False,
311
+ use_qk_l2norm_in_kernel=True
312
+ )
313
+ else:
314
+ o, recurrent_state = chunk_delta_rule(
315
+ q=q,
316
+ k=k,
317
+ v=v,
318
+ beta=beta,
319
+ initial_state=recurrent_state,
320
+ output_final_state=use_cache,
321
+ cu_seqlens=offsets,
322
+ head_first=False,
323
+ use_qk_l2norm_in_kernel=True
324
+ )
325
+ else:
326
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
327
+
328
+ # Take every nth element for n householder transformations
329
+ o = o[:, self.num_householder - 1:: self.num_householder, :]
330
+
331
+ if past_key_values is not None:
332
+ past_key_values.update(
333
+ recurrent_state=recurrent_state,
334
+ conv_state=conv_states if self.use_short_conv else None,
335
+ layer_idx=self.layer_idx,
336
+ offset=q.shape[2],
337
+ )
338
+
339
+ if self.use_gate:
340
+ g = rearrange(
341
+ self.g_proj(hidden_states),
342
+ "... (h d) -> ... h d",
343
+ h=self.num_heads,
344
+ )
345
+ o = self.o_norm(o, g)
346
+ else:
347
+ o = self.o_norm(o)
348
+ o = rearrange(o, "b t h d -> b t (h d)")
349
+ o = self.o_proj(o)
350
+
351
+ return o, None, past_key_values
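A quick check of interleave_multiple_sequences, which is what lets GatedDeltaProduct apply num_householder rank-1 updates per token: each input position is expanded into num_householder consecutive steps, and the output keeps only every n-th step. The snippet assumes the fla package (and its Triton dependency) is importable; the tensors themselves are tiny and CPU-only.

import torch
from fla.layers.gated_deltaproduct import interleave_multiple_sequences

a = torch.tensor([[[1.], [2.]]])    # (batch=1, seq=2, dim=1): A1, A2
b = torch.tensor([[[10.], [20.]]])  # B1, B2
print(interleave_multiple_sequences(a, b).squeeze(-1))  # tensor([[ 1., 10.,  2., 20.]])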
fla/layers/gla.py ADDED
@@ -0,0 +1,294 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange, repeat
13
+
14
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
15
+ from fla.modules.activations import ACT2FN
16
+ from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
17
+
18
+ if TYPE_CHECKING:
19
+ from transformers.processing_utils import Unpack
20
+
21
+ from fla.models.utils import Cache
22
+
23
+
24
+ class GatedLinearAttention(nn.Module):
25
+ r"""
26
+ The layer implementation for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa
27
+
28
+ Args:
29
+ mode (str, Optional):
30
+ Which GLA kernel to use.
31
+ Currently available: `chunk`, `fused_recurrent`, and `fused_chunk`.
32
+ Default: `chunk`.
33
+ hidden_size (int, Optional):
34
+ The hidden size of the input. Default: 1024.
35
+ expand_k (float, Optional):
36
+ The expansion ratio for the key dim. Default: 0.5.
37
+ expand_v (float, Optional):
38
+ The expansion ratio for the value dim. Default: 1.0.
39
+ num_heads (int, Optional):
40
+ The number of heads. Default: 4.
41
+ num_kv_heads (int, Optional):
42
+ The number of key/value heads, used for MQA. Default: None.
43
+ feature_map (str, Optional):
44
+ Feature map function applied to queries/keys. Default: None.
45
+ use_short_conv (bool, Optional):
46
+ Whether to use short convolutions. Default: `False`.
47
+ conv_size (int, Optional):
48
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
49
+ conv_bias (bool, Optional):
50
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
51
+ use_output_gate (bool, Optional):
52
+ Whether to use output gate. Default: `True`.
53
+ gate_fn (str, Optional):
54
+ The activation function for the output gate. Default: `swish`.
55
+ elementwise_affine (bool, Optional):
56
+ If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
57
+ norm_eps (float, Optional):
58
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
59
+ gate_logit_normalizer (int, Optional):
60
+ The normalizer for the gate logits, appied after `logsigmoid`. Default: 16.
61
+ gate_low_rank_dim (int, Optional):
62
+ The low rank dim for the gate projection. Default: 16.
63
+ clamp_min (float, Optional):
64
+ The minimum value for the gate logits. Default: None.
65
+ fuse_norm (bool, Optional):
66
+ Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
67
+ layer_idx (int, Optional):
68
+ The index of the layer. Default: None.
69
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ mode: str = 'chunk',
74
+ hidden_size: int = 1024,
75
+ expand_k: float = 0.5,
76
+ expand_v: float = 1.0,
77
+ num_heads: int = 4,
78
+ num_kv_heads: Optional[int] = None,
79
+ feature_map: Optional[str] = None,
80
+ use_short_conv: bool = False,
81
+ conv_size: int = 4,
82
+ conv_bias: bool = False,
83
+ use_output_gate: bool = True,
84
+ gate_fn: str = 'swish',
85
+ elementwise_affine: Optional[bool] = True,
86
+ norm_eps: float = 1e-5,
87
+ gate_logit_normalizer: int = 16,
88
+ gate_low_rank_dim: int = 16,
89
+ clamp_min: Optional[float] = None,
90
+ fuse_norm: bool = True,
91
+ layer_idx: int = None,
92
+ ) -> GatedLinearAttention:
93
+ super().__init__()
94
+
95
+ self.mode = mode
96
+ self.hidden_size = hidden_size
97
+ self.expand_k = expand_k
98
+ self.expand_v = expand_v
99
+ self.num_heads = num_heads
100
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
101
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
102
+ self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
103
+
104
+ self.use_short_conv = use_short_conv
105
+ self.conv_size = conv_size
106
+ self.conv_bias = conv_bias
107
+ self.use_output_gate = use_output_gate
108
+
109
+ self.key_dim = int(hidden_size * expand_k)
110
+ self.value_dim = int(hidden_size * expand_v)
111
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
112
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
113
+ self.clamp_min = clamp_min
114
+ self.layer_idx = layer_idx
115
+
116
+ assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not supported mode `{mode}`."
117
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
118
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
119
+
120
+ self.head_k_dim = self.key_dim // num_heads
121
+ self.head_v_dim = self.value_dim // num_heads
122
+
123
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
124
+ self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
125
+ self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
126
+ if self.use_output_gate:
127
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
128
+
129
+ if use_short_conv:
130
+ self.conv_size = conv_size
131
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
132
+ self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
133
+ self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
134
+
135
+ self.gk_proj = nn.Sequential(nn.Linear(hidden_size, gate_low_rank_dim, bias=False),
136
+ nn.Linear(gate_low_rank_dim, self.key_dim_per_group, bias=True))
137
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
138
+
139
+ if gate_fn == 'swish' and fuse_norm and use_output_gate:
140
+ self.g_norm_swish_gate = FusedRMSNormGated(
141
+ hidden_size=self.head_v_dim,
142
+ elementwise_affine=elementwise_affine,
143
+ eps=norm_eps
144
+ )
145
+ self.fuse_norm_and_gate = True
146
+ else:
147
+ self.fuse_norm_and_gate = False
148
+ self.g_norm = RMSNorm(
149
+ hidden_size=self.head_v_dim,
150
+ elementwise_affine=elementwise_affine,
151
+ eps=norm_eps
152
+ )
153
+ self.gate_fn = ACT2FN[gate_fn]
154
+
155
+ self.gate_logit_normalizer = gate_logit_normalizer
156
+
157
+ def forward(
158
+ self,
159
+ hidden_states: torch.Tensor,
160
+ attention_mask: Optional[torch.Tensor] = None,
161
+ past_key_values: Optional[Cache] = None,
162
+ use_cache: Optional[bool] = False,
163
+ output_attentions: Optional[bool] = False,
164
+ **kwargs: Unpack[Dict]
165
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
166
+ if attention_mask is not None:
167
+ assert len(attention_mask.shape) == 2, (
168
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
169
+ "for padding purposes (0 indicating padding). "
170
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
171
+ )
172
+
173
+ # launching the triton kernel for just one token will actually be slower
174
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
175
+
176
+ last_state = None
177
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
178
+ last_state = past_key_values[self.layer_idx]
179
+
180
+ cu_seqlens = kwargs.get('cu_seqlens', None)
181
+ if self.use_short_conv:
182
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
183
+ if last_state is not None:
184
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
185
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
186
+ q, conv_state_q = self.q_conv1d(
187
+ x=self.q_proj(hidden_states),
188
+ mask=conv_mask,
189
+ cache=conv_state_q,
190
+ output_final_state=use_cache,
191
+ cu_seqlens=cu_seqlens
192
+ )
193
+ k, conv_state_k = self.k_conv1d(
194
+ x=self.k_proj(hidden_states),
195
+ mask=conv_mask,
196
+ cache=conv_state_k,
197
+ output_final_state=use_cache,
198
+ cu_seqlens=cu_seqlens
199
+ )
200
+ v, conv_state_v = self.v_conv1d(
201
+ x=self.v_proj(hidden_states),
202
+ mask=conv_mask,
203
+ cache=conv_state_v,
204
+ output_final_state=use_cache,
205
+ cu_seqlens=cu_seqlens
206
+ )
207
+ else:
208
+ q = self.q_proj(hidden_states)
209
+ k = self.k_proj(hidden_states)
210
+ v = self.v_proj(hidden_states)
211
+ gk = self.gk_proj(hidden_states)
212
+
213
+ if self.feature_map_fn is not None:
214
+ q, k = map(self.feature_map_fn, (q, k))
215
+ # dealing with left-padding
216
+ if attention_mask is not None:
217
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
218
+ q = rearrange(q, 'b t (h d) -> b t h d', d=self.head_k_dim)
219
+ if self.num_kv_groups > 1:
220
+ k, gk = (repeat(x, 'b t (h d) -> b t (h g) d', g=self.num_kv_groups, d=self.head_k_dim) for x in (k, gk))
221
+ v = repeat(v, 'b t (h d) -> b t (h g) d', g=self.num_kv_groups, d=self.head_v_dim)
222
+ else:
223
+ k, gk = (rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim) for x in (k, gk))
224
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
225
+ gk = F.logsigmoid(gk) / self.gate_logit_normalizer
226
+
227
+ if self.clamp_min is not None:
228
+ gk = torch.clamp_min(gk, self.clamp_min)
229
+
230
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
231
+ if mode == 'fused_recurrent':
232
+ o, recurrent_state = fused_recurrent_gla(
233
+ q=q,
234
+ k=k,
235
+ v=v,
236
+ gk=gk,
237
+ initial_state=recurrent_state,
238
+ output_final_state=use_cache,
239
+ cu_seqlens=cu_seqlens,
240
+ head_first=False
241
+ )
242
+ elif mode == 'fused_chunk':
243
+ o, recurrent_state = fused_chunk_gla(
244
+ q=q,
245
+ k=k,
246
+ v=v,
247
+ g=gk,
248
+ initial_state=recurrent_state,
249
+ output_final_state=use_cache,
250
+ head_first=False
251
+ )
252
+ elif mode == 'chunk':
253
+ o, recurrent_state = chunk_gla(
254
+ q=q,
255
+ k=k,
256
+ v=v,
257
+ g=gk,
258
+ initial_state=recurrent_state,
259
+ output_final_state=use_cache,
260
+ cu_seqlens=cu_seqlens,
261
+ head_first=False
262
+ )
263
+ else:
264
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
265
+
266
+ if past_key_values is not None:
267
+ past_key_values.update(
268
+ recurrent_state=recurrent_state,
269
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
270
+ layer_idx=self.layer_idx,
271
+ offset=q.shape[1]
272
+ )
273
+
274
+ if self.use_output_gate:
275
+ g = self.g_proj(hidden_states)
276
+ if self.fuse_norm_and_gate:
277
+ g = rearrange(g, 'b t (h d) -> b t h d', d=self.head_v_dim)
278
+ o = self.g_norm_swish_gate(o, g)
279
+ o = rearrange(o, 'b t h d -> b t (h d)')
280
+ else:
281
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
282
+ o = o * self.gate_fn(g)
283
+ else:
284
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
285
+ o = self.o_proj(o)
286
+
287
+ return o, None, past_key_values
288
+
289
+ def state_size(self, **kwargs) -> int:
290
+ state_size = self.key_dim * self.head_v_dim
291
+ for module in self.children():
292
+ if isinstance(module, ShortConvolution):
293
+ state_size += module.state_size
294
+ return state_size
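For orientation, a minimal usage sketch of the layer added in this file. The class name `GatedLinearAttention` and the constructor arguments `hidden_size`, `num_heads` and `layer_idx` are assumptions inferred from the sibling layers in this commit (they are not shown in this hunk), and a CUDA device with bfloat16 weights is assumed because the chunk/fused Triton kernels target GPU half precision.

# Hypothetical sketch; class name, constructor arguments, CUDA and bfloat16 are assumptions.
import torch
from fla.layers.gla import GatedLinearAttention

layer = GatedLinearAttention(hidden_size=1024, num_heads=4, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)          # inputs longer than 64 tokens use self.mode; shorter ones fall back to fused_recurrent
print(o.shape)              # torch.Size([2, 128, 1024])
print(layer.state_size())   # key_dim * head_v_dim plus any short-conv state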
fla/layers/gsa.py ADDED
@@ -0,0 +1,227 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange
13
+
14
+ from fla.modules import RMSNorm, ShortConvolution
15
+ from fla.modules.feature_map import ReLUFeatureMap, SwishFeatureMap, T2RFeatureMap
16
+ from fla.modules.layernorm import rms_norm_linear
17
+ from fla.ops.gsa import chunk_gsa, fused_recurrent_gsa
18
+
19
+ if TYPE_CHECKING:
20
+ from transformers.processing_utils import Unpack
21
+
22
+ from fla.models.utils import Cache
23
+
24
+
25
+ class GatedSlotAttention(nn.Module):
26
+
27
+ def __init__(
28
+ self,
29
+ mode: str = 'chunk',
30
+ hidden_size: int = 1024,
31
+ expand_k: float = 1.,
32
+ expand_v: float = 1.,
33
+ num_heads: int = 4,
34
+ num_kv_heads: Optional[int] = None,
35
+ use_short_conv: bool = False,
36
+ conv_size: int = 4,
37
+ conv_bias: bool = False,
38
+ num_slots: Optional[int] = None,
39
+ elementwise_affine: Optional[bool] = True,
40
+ norm_eps: float = 1e-5,
41
+ gate_logit_normalizer: int = 8,
42
+ feature_map: str = 'swish',
43
+ use_output_gate: bool = False,
44
+ use_norm: bool = True,
45
+ layer_idx: Optional[int] = None,
46
+ scale: Optional[float] = 1.,
47
+ **kwargs
48
+ ) -> GatedSlotAttention:
49
+ super().__init__()
50
+
51
+ self.mode = mode
52
+ self.hidden_size = hidden_size
53
+ self.expand_k = expand_k
54
+ self.expand_v = expand_v
55
+ self.num_heads = num_heads
56
+ self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
57
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
58
+ self.key_dim = int(hidden_size * expand_k)
59
+ self.value_dim = int(hidden_size * expand_v)
60
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
61
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
62
+ self.head_k_dim = self.key_dim // self.num_heads
63
+ self.head_v_dim = self.value_dim // self.num_heads
64
+
65
+ self.use_short_conv = use_short_conv
66
+ self.conv_size = conv_size
67
+ self.conv_bias = conv_bias
68
+
69
+ self.gate_logit_normalizer = gate_logit_normalizer
70
+
71
+ self.use_output_gate = use_output_gate
72
+ self.use_norm = use_norm
73
+ self.scale = scale
74
+
75
+ if num_slots is None:
76
+ num_slots = self.head_k_dim
77
+ self.num_slots = num_slots
78
+
79
+ self.layer_idx = layer_idx
80
+
81
+ if layer_idx is None:
82
+ warnings.warn(
83
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
84
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
85
+ "when creating this class."
86
+ )
87
+
88
+ self.register_module('feature_map', None)
89
+ if feature_map == 'swish':
90
+ self.feature_map = SwishFeatureMap()
91
+ elif feature_map == 'relu':
92
+ self.feature_map = ReLUFeatureMap()
93
+ elif feature_map == 't2r':
94
+ self.feature_map = T2RFeatureMap(self.head_k_dim, self.head_k_dim)
95
+ else:
96
+ raise NotImplementedError(f"Feature map `{feature_map}` is not supported now.")
97
+
98
+ self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
99
+ self.k_proj = nn.Linear(self.hidden_size, self.key_dim_per_group, bias=False)
100
+ self.v_proj = nn.Linear(self.hidden_size, self.value_dim_per_group, bias=False)
101
+ self.f_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.num_slots, bias=False)
102
+
103
+ if use_short_conv:
104
+ self.conv_size = conv_size
105
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
106
+ self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
107
+ self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
108
+
109
+ self.g_norm = RMSNorm(self.hidden_size, elementwise_affine, eps=norm_eps)
110
+ self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
111
+
112
+ def forward(
113
+ self,
114
+ hidden_states: torch.Tensor,
115
+ attention_mask: Optional[torch.Tensor] = None,
116
+ past_key_values: Optional[Cache] = None,
117
+ use_cache: Optional[bool] = False,
118
+ output_attentions: Optional[bool] = False,
119
+ **kwargs: Unpack[Dict]
120
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
121
+ if attention_mask is not None:
122
+ assert len(attention_mask.shape) == 2, (
123
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
124
+ "for padding purposes (0 indicating padding). "
125
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
126
+ )
127
+
128
+ # launching the triton kernel for just one token will actually be slower
129
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
130
+
131
+ last_state = None
132
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
133
+ last_state = past_key_values[self.layer_idx]
134
+
135
+ cu_seqlens = kwargs.get('cu_seqlens', None)
136
+ if self.use_short_conv:
137
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
138
+ if last_state is not None:
139
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
140
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
141
+ q, conv_state_q = self.q_conv1d(
142
+ x=self.q_proj(hidden_states),
143
+ mask=conv_mask,
144
+ cache=conv_state_q,
145
+ output_final_state=use_cache,
146
+ cu_seqlens=cu_seqlens
147
+ )
148
+ k, conv_state_k = self.k_conv1d(
149
+ x=self.k_proj(hidden_states),
150
+ mask=conv_mask,
151
+ cache=conv_state_k,
152
+ output_final_state=use_cache,
153
+ cu_seqlens=cu_seqlens
154
+ )
155
+ v, conv_state_v = self.v_conv1d(
156
+ x=self.v_proj(hidden_states),
157
+ mask=conv_mask,
158
+ cache=conv_state_v,
159
+ output_final_state=use_cache,
160
+ cu_seqlens=cu_seqlens
161
+ )
162
+ else:
163
+ q = self.q_proj(hidden_states)
164
+ k = self.k_proj(hidden_states)
165
+ v = self.v_proj(hidden_states)
166
+ f = self.f_proj(hidden_states)
167
+
168
+ q = rearrange(q, 'b t (h d) -> b t h d', d=self.head_k_dim)
169
+ k = rearrange(k, 'b t (h d) -> b t h d', d=self.head_k_dim)
170
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
171
+ f = rearrange(f, 'b t (h m) -> b t h m', m=self.num_slots)
172
+
173
+ if self.feature_map is not None:
174
+ q, k = map(lambda x: self.feature_map(x), (q, k))
175
+ v = F.silu(v)
176
+
177
+ f = F.logsigmoid(f) / self.gate_logit_normalizer
178
+ s = (1 - f.exp()).to(f.dtype)
179
+ # dealing with left-padding
180
+ if attention_mask is not None:
181
+ s = s.mul_(attention_mask[:, -s.shape[1]:, None, None])
182
+ v = v.mul_(attention_mask[:, -v.shape[1]:, None, None])
183
+
184
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
185
+ if mode == 'fused_recurrent':
186
+ o, recurrent_state = fused_recurrent_gsa(
187
+ q=q,
188
+ k=k,
189
+ v=v,
190
+ s=s,
191
+ g=f,
192
+ initial_state=recurrent_state,
193
+ output_final_state=use_cache,
194
+ scale=self.scale,
195
+ cu_seqlens=cu_seqlens,
196
+ head_first=False
197
+ )
198
+ elif mode == 'chunk':
199
+ o, recurrent_state = chunk_gsa(
200
+ q=q,
201
+ k=k,
202
+ v=v,
203
+ s=s,
204
+ g=f,
205
+ initial_state=recurrent_state,
206
+ output_final_state=use_cache,
207
+ scale=self.scale,
208
+ cu_seqlens=cu_seqlens,
209
+ head_first=False
210
+ )
211
+ else:
212
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
213
+
214
+ if past_key_values is not None:
215
+ past_key_values.update(
216
+ recurrent_state=recurrent_state,
217
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
218
+ layer_idx=self.layer_idx,
219
+ offset=q.shape[1]
220
+ )
221
+
222
+ o = rearrange(o, 'b t h d -> b t (h d)')
223
+ o = rms_norm_linear(F.silu(o), self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias)
224
+ return o, None, past_key_values
225
+
226
+ def state_size(self, *args, **kwargs) -> int:
227
+ return 2 * self.num_slots * self.hidden_size
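A minimal sketch of calling GatedSlotAttention as defined above. A CUDA device with bfloat16 weights is assumed for the GSA Triton kernels; with these defaults `num_slots` falls back to `head_k_dim`, so `state_size()` reports `2 * num_slots * hidden_size`.

# Sketch under the stated assumptions (CUDA + bfloat16); not a definitive recipe.
import torch
from fla.layers.gsa import GatedSlotAttention

layer = GatedSlotAttention(hidden_size=1024, num_heads=4, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 256, 1024, device='cuda', dtype=torch.bfloat16)
o, _, past = layer(x)       # past stays None unless a Cache object is passed in
print(o.shape)              # torch.Size([2, 256, 1024])
print(layer.state_size())   # 2 * num_slots * hidden_size = 2 * 256 * 1024 here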
fla/layers/hgrn.py ADDED
@@ -0,0 +1,168 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # "Hierarchically Gated Recurrent Neural Network for Sequence Modeling" [https://arxiv.org/abs/2311.04823]
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from fla.modules import FusedRMSNormGated, ShortConvolution
15
+ from fla.modules.activations import swiglu
16
+ from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn
17
+
18
+ if TYPE_CHECKING:
19
+ from transformers.processing_utils import Unpack
20
+
21
+ from fla.models.utils import Cache
22
+
23
+
24
+ class HGRNAttention(nn.Module):
25
+
26
+ def __init__(
27
+ self,
28
+ mode: str = 'chunk',
29
+ hidden_size: int = 1024,
30
+ expand_ratio: Optional[int] = 1,
31
+ use_short_conv: bool = False,
32
+ conv_size: int = 4,
33
+ conv_bias: bool = False,
34
+ elementwise_affine: Optional[bool] = True,
35
+ norm_eps: float = 1e-5,
36
+ layer_idx: int = None
37
+ ) -> HGRNAttention:
38
+ super().__init__()
39
+
40
+ self.mode = mode
41
+ self.hidden_size = hidden_size
42
+ self.expand_ratio = expand_ratio
43
+ self.input_dim = int(hidden_size * expand_ratio)
44
+
45
+ self.use_short_conv = use_short_conv
46
+ self.conv_size = conv_size
47
+ self.conv_bias = conv_bias
48
+
49
+ self.layer_idx = layer_idx
50
+
51
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
52
+
53
+ self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
54
+ self.f_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
55
+ self.g_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
56
+
57
+ if use_short_conv:
58
+ self.conv_size = conv_size
59
+ self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
60
+ self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
61
+ self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
62
+
63
+ self.g_norm = FusedRMSNormGated(
64
+ hidden_size=self.input_dim,
65
+ elementwise_affine=elementwise_affine,
66
+ eps=norm_eps
67
+ )
68
+ self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)
69
+
70
+ def forward(
71
+ self,
72
+ hidden_states: torch.Tensor,
73
+ attention_mask: Optional[torch.Tensor] = None,
74
+ past_key_values: Optional[Cache] = None,
75
+ use_cache: Optional[bool] = False,
76
+ output_attentions: Optional[bool] = False,
77
+ lower_bound: Optional[torch.Tensor] = None,
78
+ **kwargs: Unpack[Dict]
79
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
80
+ if attention_mask is not None:
81
+ assert len(attention_mask.shape) == 2, (
82
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
83
+ "for padding purposes (0 indicating padding). "
84
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
85
+ )
86
+
87
+ # launching the triton kernel for just one token will actually be slower
88
+ mode = 'fused_recurrent' if not self.training and hidden_states.shape[1] <= 64 else self.mode
89
+
90
+ last_state = None
91
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
92
+ last_state = past_key_values[self.layer_idx]
93
+
94
+ cu_seqlens = kwargs.get('cu_seqlens', None)
95
+ if self.use_short_conv:
96
+ conv_state_i, conv_state_f = None, None
97
+ if last_state is not None:
98
+ conv_state_i, conv_state_f = last_state['conv_state']
99
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
100
+ i, conv_state_i = self.i_conv1d(
101
+ x=self.i_proj(hidden_states),
102
+ mask=conv_mask,
103
+ cache=conv_state_i,
104
+ output_final_state=use_cache,
105
+ cu_seqlens=cu_seqlens
106
+ )
107
+ f, conv_state_f = self.f_conv1d(
108
+ x=self.f_proj(hidden_states),
109
+ mask=conv_mask,
110
+ cache=conv_state_f,
111
+ output_final_state=use_cache,
112
+ cu_seqlens=cu_seqlens
113
+ )
114
+ else:
115
+ i = self.i_proj(hidden_states)
116
+ f = self.f_proj(hidden_states)
117
+
118
+ # the lower bound for the first layer is zero
119
+ if lower_bound is None or self.layer_idx == 0:
120
+ i, f = swiglu(i, 1 - f.sigmoid()), F.logsigmoid(f)
121
+ else:
122
+ g = lower_bound + (1 - lower_bound) * f.sigmoid()
123
+ i, f = swiglu(i, 1 - g), g.log()
124
+
125
+ # dealing with left-padding
126
+ if attention_mask is not None:
127
+ i = i.mul_(attention_mask[:, -i.shape[-2]:, None])
128
+
129
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
130
+ if mode == 'chunk':
131
+ if cu_seqlens is not None:
132
+ raise NotImplementedError("Chunk mode does not support variable-length sequences.")
133
+ o, recurrent_state = chunk_hgrn(
134
+ x=i,
135
+ g=f,
136
+ initial_state=recurrent_state,
137
+ output_final_state=use_cache,
138
+ )
139
+ elif mode == 'fused_recurrent':
140
+ o, recurrent_state = fused_recurrent_hgrn(
141
+ x=i,
142
+ g=f,
143
+ initial_state=recurrent_state,
144
+ output_final_state=use_cache,
145
+ cu_seqlens=cu_seqlens
146
+ )
147
+ else:
148
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
149
+
150
+ if past_key_values is not None:
151
+ past_key_values.update(
152
+ recurrent_state=recurrent_state,
153
+ conv_state=(conv_state_i, conv_state_f) if self.use_short_conv else None,
154
+ layer_idx=self.layer_idx,
155
+ offset=i.shape[1]
156
+ )
157
+
158
+ o = self.g_norm(o, self.g_proj(hidden_states))
159
+ o = self.o_proj(o)
160
+
161
+ return o, None, past_key_values
162
+
163
+ def state_size(self, **kwargs) -> int:
164
+ state_size = self.hidden_size
165
+ for module in self.children():
166
+ if isinstance(module, ShortConvolution):
167
+ state_size += module.state_size
168
+ return state_size
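A minimal sketch for HGRNAttention, assuming a CUDA device with bfloat16 weights for the Triton kernels; `lower_bound` is left as None, which matches the first-layer case handled in the forward above.

# Sketch under the stated assumptions; not a definitive recipe.
import torch
from fla.layers.hgrn import HGRNAttention

layer = HGRNAttention(hidden_size=512, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(1, 128, 512, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)          # lower_bound=None, so the swiglu/logsigmoid branch is taken
print(o.shape)              # torch.Size([1, 128, 512])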
fla/layers/hgrn2.py ADDED
@@ -0,0 +1,211 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # "HGRN2: Gated Linear RNNs with State Expansion"[https://arxiv.org/abs/2404.07904]
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from einops import rearrange
14
+
15
+ from fla.modules import RMSNorm, ShortConvolution
16
+ from fla.modules.activations import swish
17
+ from fla.modules.layernorm import rms_norm_linear
18
+ from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
19
+
20
+ if TYPE_CHECKING:
21
+ from transformers.processing_utils import Unpack
22
+
23
+ from fla.models.utils import Cache
24
+
25
+
26
+ class HGRN2Attention(nn.Module):
27
+
28
+ def __init__(
29
+ self,
30
+ mode: str = 'chunk',
31
+ hidden_size: int = 1024,
32
+ num_heads: Optional[int] = None,
33
+ expand_ratio: Optional[int] = 128,
34
+ use_short_conv: bool = False,
35
+ conv_size: int = 4,
36
+ conv_bias: bool = False,
37
+ elementwise_affine: Optional[bool] = True,
38
+ norm_eps: float = 1e-5,
39
+ layer_idx: int = None
40
+ ) -> HGRN2Attention:
41
+ super().__init__()
42
+
43
+ self.mode = mode
44
+ self.hidden_size = hidden_size
45
+
46
+ if expand_ratio is None and num_heads is not None:
47
+ expand_ratio = hidden_size // num_heads
48
+ elif expand_ratio is not None and num_heads is None:
49
+ num_heads = hidden_size // expand_ratio
50
+ elif expand_ratio is None and num_heads is None:
51
+ raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.")
52
+ self.num_heads = num_heads
53
+ self.expand_ratio = expand_ratio
54
+
55
+ self.use_short_conv = use_short_conv
56
+ self.conv_size = conv_size
57
+ self.conv_bias = conv_bias
58
+
59
+ self.forget_dim = int(self.num_heads * self.expand_ratio)
60
+ self.input_dim = hidden_size
61
+ self.layer_idx = layer_idx
62
+
63
+ assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not supported mode `{mode}`."
64
+ assert self.forget_dim % num_heads == 0, f"forget dim must be divisible by num_heads of {num_heads}"
65
+ assert self.input_dim % num_heads == 0, f"input dim must be divisible by num_heads of {num_heads}"
66
+
67
+ self.head_f_dim = self.expand_ratio
68
+ self.head_i_dim = self.hidden_size // num_heads
69
+
70
+ self.q_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
71
+ self.f_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
72
+ self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
73
+
74
+ if use_short_conv:
75
+ self.conv_size = conv_size
76
+ self.q_conv1d = ShortConvolution(self.forget_dim, conv_size, activation=None)
77
+ self.f_conv1d = ShortConvolution(self.forget_dim, conv_size, activation=None)
78
+ self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
79
+
80
+ self.g_norm = RMSNorm(hidden_size=self.hidden_size, elementwise_affine=elementwise_affine, eps=norm_eps)
81
+ self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)
82
+
83
+ def forward(
84
+ self,
85
+ hidden_states: torch.Tensor,
86
+ attention_mask: Optional[torch.Tensor] = None,
87
+ past_key_values: Optional[Cache] = None,
88
+ use_cache: Optional[bool] = False,
89
+ output_attentions: Optional[bool] = False,
90
+ lower_bound: Optional[torch.Tensor] = None,
91
+ **kwargs: Unpack[Dict]
92
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
93
+ if attention_mask is not None:
94
+ assert len(attention_mask.shape) == 2, (
95
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
96
+ "for padding purposes (0 indicating padding). "
97
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
98
+ )
99
+
100
+ # launching the triton kernel for just one token will actually be slower
101
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
102
+
103
+ last_state = None
104
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
105
+ last_state = past_key_values[self.layer_idx]
106
+
107
+ cu_seqlens = kwargs.get('cu_seqlens', None)
108
+ if self.use_short_conv:
109
+ conv_state_q, conv_state_f, conv_state_i = None, None, None
110
+ if last_state is not None:
111
+ conv_state_q, conv_state_f, conv_state_i = last_state['conv_state']
112
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
113
+ q, conv_state_q = self.q_conv1d(
114
+ x=self.q_proj(hidden_states),
115
+ mask=conv_mask,
116
+ cache=conv_state_q,
117
+ output_final_state=use_cache,
118
+ cu_seqlens=cu_seqlens
119
+ )
120
+ f, conv_state_f = self.f_conv1d(
121
+ x=self.f_proj(hidden_states),
122
+ mask=conv_mask,
123
+ cache=conv_state_f,
124
+ output_final_state=use_cache,
125
+ cu_seqlens=cu_seqlens
126
+ )
127
+ i, conv_state_i = self.i_conv1d(
128
+ x=self.i_proj(hidden_states),
129
+ mask=conv_mask,
130
+ cache=conv_state_i,
131
+ output_final_state=use_cache,
132
+ cu_seqlens=cu_seqlens
133
+ )
134
+ else:
135
+ q = self.q_proj(hidden_states)
136
+ f = self.f_proj(hidden_states)
137
+ i = self.i_proj(hidden_states)
138
+
139
+ # dealing with left-padding
140
+ if attention_mask is not None:
141
+ i = i.mul_(attention_mask[:, -i.shape[-2]:, None])
142
+
143
+ q = swish(q)
144
+
145
+ # improve precision
146
+ f = f.float()
147
+
148
+ # the lower bound for the first layer is zero
149
+ if lower_bound is None or self.layer_idx == 0:
150
+ k, g = 1 - f.sigmoid(), F.logsigmoid(f)
151
+ else:
152
+ g = lower_bound + (1 - lower_bound) * f.sigmoid()
153
+ k, g = 1 - g, g.log()
154
+
155
+ q, k, g = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_f_dim), (q, k.to(i), g))
156
+ i = rearrange(i, '... (h d) -> ... h d', d=self.head_i_dim)
157
+
158
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
159
+ if mode == 'fused_recurrent':
160
+ o, recurrent_state = fused_recurrent_gla(
161
+ q=q,
162
+ k=k,
163
+ v=i,
164
+ gk=g,
165
+ initial_state=recurrent_state,
166
+ output_final_state=use_cache,
167
+ cu_seqlens=cu_seqlens,
168
+ head_first=False
169
+ )
170
+ elif mode == 'fused_chunk':
171
+ o, recurrent_state = fused_chunk_gla(
172
+ q=q,
173
+ k=k,
174
+ v=i,
175
+ g=g,
176
+ initial_state=recurrent_state,
177
+ output_final_state=use_cache,
178
+ head_first=False
179
+ )
180
+ elif mode == 'chunk':
181
+ o, recurrent_state = chunk_gla(
182
+ q=q,
183
+ k=k,
184
+ v=i,
185
+ g=g,
186
+ initial_state=recurrent_state,
187
+ output_final_state=use_cache,
188
+ cu_seqlens=cu_seqlens,
189
+ head_first=False
190
+ )
191
+ else:
192
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
193
+
194
+ if past_key_values is not None:
195
+ past_key_values.update(
196
+ recurrent_state=recurrent_state,
197
+ conv_state=(conv_state_q, conv_state_f, conv_state_i) if self.use_short_conv else None,
198
+ layer_idx=self.layer_idx,
199
+ offset=q.shape[1]
200
+ )
201
+
202
+ o = rearrange(o, '... h d -> ... (h d)')
203
+ o = rms_norm_linear(o, self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias)
204
+ return o, None, past_key_values
205
+
206
+ def state_size(self, **kwargs) -> int:
207
+ state_size = self.forget_dim * self.head_i_dim
208
+ for module in self.children():
209
+ if isinstance(module, ShortConvolution):
210
+ state_size += module.state_size
211
+ return state_size
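A minimal sketch for HGRN2Attention, assuming a CUDA device with bfloat16 weights; with `hidden_size=1024` and the default `expand_ratio=128`, the constructor above derives 8 heads.

# Sketch under the stated assumptions; not a definitive recipe.
import torch
from fla.layers.hgrn2 import HGRN2Attention

layer = HGRN2Attention(hidden_size=1024, expand_ratio=128, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)
print(o.shape, layer.num_heads)   # torch.Size([2, 128, 1024]) 8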
fla/layers/lightnet.py ADDED
@@ -0,0 +1,210 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # ["You Only Scan Once: Efficient Multi-dimension Sequential Modeling with LightNet"](https://arxiv.org/abs/2405.21022)
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from einops import rearrange
14
+
15
+ from fla.modules import FusedRMSNormGated, ShortConvolution
16
+ from fla.modules.fused_norm_gate import rms_norm_swish_gate_linear
17
+ from fla.ops.gla import chunk_gla, fused_recurrent_gla
18
+
19
+ if TYPE_CHECKING:
20
+ from transformers.processing_utils import Unpack
21
+
22
+ from fla.models.utils import Cache
23
+
24
+
25
+ class LightNetAttention(nn.Module):
26
+
27
+ def __init__(
28
+ self,
29
+ mode: str = 'chunk',
30
+ hidden_size: int = 1024,
31
+ num_heads: Optional[int] = None,
32
+ expand_ratio: Optional[int] = 128,
33
+ use_short_conv: bool = False,
34
+ conv_size: int = 4,
35
+ conv_bias: bool = False,
36
+ gate_low_rank_dim: int = 128,
37
+ elementwise_affine: Optional[bool] = True,
38
+ norm_eps: float = 1e-5,
39
+ layer_idx: int = None
40
+ ) -> LightNetAttention:
41
+ super().__init__()
42
+
43
+ self.mode = mode
44
+ self.hidden_size = hidden_size
45
+
46
+ if expand_ratio is None and num_heads is not None:
47
+ expand_ratio = hidden_size // num_heads
48
+ elif expand_ratio is not None and num_heads is None:
49
+ num_heads = hidden_size // expand_ratio
50
+ elif expand_ratio is None and num_heads is None:
51
+ raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.")
52
+ self.num_heads = num_heads
53
+ self.expand_ratio = expand_ratio
54
+
55
+ self.use_short_conv = use_short_conv
56
+ self.conv_size = conv_size
57
+ self.conv_bias = conv_bias
58
+
59
+ self.key_dim = int(self.num_heads * self.expand_ratio)
60
+ self.value_dim = hidden_size
61
+ self.gate_low_rank_dim = gate_low_rank_dim
62
+ self.layer_idx = layer_idx
63
+
64
+ assert mode in ['chunk', 'fused_chunk'], f"Not suppoerted mode `{mode}`."
65
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
66
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
67
+
68
+ self.head_f_dim = self.expand_ratio
69
+ self.head_i_dim = self.hidden_size // num_heads
70
+
71
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
72
+ self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
73
+ self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
74
+
75
+ if use_short_conv:
76
+ self.conv_size = conv_size
77
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation=None)
78
+ self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation=None)
79
+ self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation=None)
80
+
81
+ self.g_proj = nn.Sequential(
82
+ nn.Linear(hidden_size, gate_low_rank_dim, bias=False),
83
+ nn.Linear(gate_low_rank_dim, hidden_size, bias=False)
84
+ )
85
+ self.g_norm = FusedRMSNormGated(
86
+ hidden_size=hidden_size,
87
+ elementwise_affine=elementwise_affine,
88
+ eps=norm_eps
89
+ )
90
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
91
+
92
+ def forward(
93
+ self,
94
+ hidden_states: torch.Tensor,
95
+ attention_mask: Optional[torch.Tensor] = None,
96
+ past_key_values: Optional[Cache] = None,
97
+ use_cache: Optional[bool] = False,
98
+ output_attentions: Optional[bool] = False,
99
+ **kwargs: Unpack[Dict]
100
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
101
+ if attention_mask is not None:
102
+ assert len(attention_mask.shape) == 2, (
103
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
104
+ "for padding purposes (0 indicating padding). "
105
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
106
+ )
107
+
108
+ # launching the triton kernel for just one token will actually be slower
109
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
110
+
111
+ last_state = None
112
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
113
+ last_state = past_key_values[self.layer_idx]
114
+
115
+ cu_seqlens = kwargs.get('cu_seqlens', None)
116
+ if self.use_short_conv:
117
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
118
+ if last_state is not None:
119
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
120
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
121
+ q, conv_state_q = self.q_conv1d(
122
+ x=self.q_proj(hidden_states),
123
+ mask=conv_mask,
124
+ cache=conv_state_q,
125
+ output_final_state=use_cache,
126
+ cu_seqlens=cu_seqlens
127
+ )
128
+ k, conv_state_k = self.k_conv1d(
129
+ x=self.k_proj(hidden_states),
130
+ mask=conv_mask,
131
+ cache=conv_state_k,
132
+ output_final_state=use_cache,
133
+ cu_seqlens=cu_seqlens
134
+ )
135
+ v, conv_state_v = self.v_conv1d(
136
+ x=self.v_proj(hidden_states),
137
+ mask=conv_mask,
138
+ cache=conv_state_v,
139
+ output_final_state=use_cache,
140
+ cu_seqlens=cu_seqlens
141
+ )
142
+ else:
143
+ q = self.q_proj(hidden_states)
144
+ k = self.k_proj(hidden_states)
145
+ v = self.v_proj(hidden_states)
146
+
147
+ # dealing with left-padding
148
+ if attention_mask is not None:
149
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
150
+
151
+ q = F.silu(q)
152
+ q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_f_dim), (q, k))
153
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_i_dim)
154
+ # TODO: these two steps take a huge amount of time and should be optimized
155
+ z = k.float().logcumsumexp(1)
156
+
157
+ if cu_seqlens is not None:
158
+ raise NotImplementedError("LightNet does not support variable-length sequences for now.")
159
+ k, g = torch.exp(k - z).to(k.dtype), (torch.cat((z[:, :1], z[:, :-1]), 1) - z).to(k.dtype)
160
+
161
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
162
+ if mode == 'fused_recurrent':
163
+ o, recurrent_state = fused_recurrent_gla(
164
+ q=q,
165
+ k=k,
166
+ v=v,
167
+ gk=g,
168
+ initial_state=recurrent_state,
169
+ output_final_state=use_cache,
170
+ cu_seqlens=cu_seqlens,
171
+ head_first=False
172
+ )
173
+ elif mode == 'chunk':
174
+ o, recurrent_state = chunk_gla(
175
+ q=q,
176
+ k=k,
177
+ v=v,
178
+ g=g,
179
+ initial_state=recurrent_state,
180
+ output_final_state=use_cache,
181
+ cu_seqlens=cu_seqlens,
182
+ head_first=False
183
+ )
184
+ else:
185
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
186
+
187
+ if past_key_values is not None:
188
+ past_key_values.update(
189
+ recurrent_state=recurrent_state,
190
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
191
+ layer_idx=self.layer_idx,
192
+ offset=q.shape[1]
193
+ )
194
+
195
+ o = rms_norm_swish_gate_linear(
196
+ rearrange(o, 'b t h d -> b t (h d)'),
197
+ self.g_proj(hidden_states),
198
+ self.g_norm.weight,
199
+ self.g_norm.bias,
200
+ self.o_proj.weight,
201
+ self.o_proj.bias
202
+ )
203
+ return o, None, past_key_values
204
+
205
+ def state_size(self, **kwargs) -> int:
206
+ state_size = self.key_dim * self.head_i_dim
207
+ for module in self.children():
208
+ if isinstance(module, ShortConvolution):
209
+ state_size += module.state_size
210
+ return state_size
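A minimal sketch for LightNetAttention, assuming a CUDA device with bfloat16 weights; note that the forward above rejects `cu_seqlens` (variable-length) inputs, so a plain padded batch is assumed.

# Sketch under the stated assumptions; not a definitive recipe.
import torch
from fla.layers.lightnet import LightNetAttention

layer = LightNetAttention(hidden_size=1024, layer_idx=0).cuda().to(torch.bfloat16)  # 8 heads from the default expand_ratio=128
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)
print(o.shape)              # torch.Size([2, 128, 1024])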
fla/layers/linear_attn.py ADDED
@@ -0,0 +1,166 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
+ from fla.modules import RMSNorm
12
+ from fla.modules.feature_map import DPFPFeatureMap, HadamardFeatureMap, HedgehogFeatureMap, T2RFeatureMap
13
+ from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn, fused_recurrent_linear_attn
14
+
15
+
16
+ class LinearAttention(nn.Module):
17
+
18
+ def __init__(
19
+ self,
20
+ mode: str = 'chunk',
21
+ hidden_size: int = 1024,
22
+ expand_k: float = 1.0,
23
+ expand_v: float = 1.0,
24
+ num_heads: int = 8,
25
+ num_kv_heads: Optional[int] = None,
26
+ feature_map: str = 'elementwise_product',
27
+ tie_feature_map_qk: bool = False,
28
+ output_norm: str = 'rmsnorm',
29
+ norm_q: bool = False,
30
+ norm_k: bool = False,
31
+ do_feature_map_norm: bool = False,
32
+ elementwise_affine: bool = True,
33
+ norm_eps: float = 1e-5,
34
+ **kwargs
35
+ ):
36
+ super().__init__()
37
+
38
+ self.hidden_size = hidden_size
39
+ self.mode = mode
40
+ self.num_heads = num_heads
41
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
42
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
43
+ self.key_dim = int(hidden_size * expand_k)
44
+ self.value_dim = int(hidden_size * expand_v)
45
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
46
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
47
+
48
+ assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
49
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
50
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
51
+
52
+ self.head_k_dim = self.key_dim // num_heads
53
+ self.head_v_dim = self.value_dim // num_heads
54
+ self.do_feature_map_norm = do_feature_map_norm
55
+
56
+ if feature_map == 'hedgehog':
57
+ if tie_feature_map_qk:
58
+ self.feature_map_q = self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_k_dim)
59
+ else:
60
+ self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_k_dim)
61
+ self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_k_dim)
62
+
63
+ elif feature_map == 't2r':
64
+ if tie_feature_map_qk:
65
+ self.feature_map_q = self.feature_map_k = T2RFeatureMap(head_dim=self.head_k_dim)
66
+ else:
67
+ self.feature_map_q = T2RFeatureMap(head_dim=self.head_k_dim)
68
+ self.feature_map_k = T2RFeatureMap(head_dim=self.head_k_dim)
69
+
70
+ elif feature_map == 'elementwise_product':
71
+ if tie_feature_map_qk:
72
+ self.feature_map_q = self.feature_map_k = HadamardFeatureMap(head_dim=self.head_k_dim)
73
+ else:
74
+ self.feature_map_q = HadamardFeatureMap(head_dim=self.head_k_dim)
75
+ self.feature_map_k = HadamardFeatureMap(head_dim=self.head_k_dim)
76
+
77
+ elif feature_map == 'dpfp':
78
+ self.feature_map_q = DPFPFeatureMap(head_dim=self.head_k_dim)
79
+ self.feature_map_k = DPFPFeatureMap(head_dim=self.head_k_dim)
80
+
81
+ elif feature_map == 'elu':
82
+ def elu(x):
83
+ return F.elu(x) + 1
84
+ self.feature_map_q = elu
85
+ self.feature_map_k = elu
86
+
87
+ elif feature_map == 'relu':
88
+ self.feature_map_q = nn.ReLU()
89
+ self.feature_map_k = nn.ReLU()
90
+
91
+ elif feature_map == 'identity':
92
+ self.feature_map_q = nn.Identity()
93
+ self.feature_map_k = nn.Identity()
94
+ else:
95
+ raise NotImplementedError(f"Not supported feature map `{feature_map}`.")
96
+
97
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
98
+ self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
99
+ self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
100
+
101
+ if output_norm == 'rmsnorm':
102
+ self.norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps)
103
+ elif output_norm == 'identity':
104
+ self.norm = nn.Identity()
105
+ else:
106
+ raise NotImplementedError(f"Not supported output norm `{output_norm}`.")
107
+
108
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
109
+
110
+ self.norm_q = norm_q
111
+ self.norm_k = norm_k
112
+
113
+ def forward(
114
+ self,
115
+ hidden_states: torch.Tensor,
116
+ **kwargs
117
+ ) -> torch.Tensor:
118
+ mode = self.mode
119
+ q = self.q_proj(hidden_states)
120
+ k = self.k_proj(hidden_states)
121
+ v = self.v_proj(hidden_states)
122
+
123
+ q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim)
124
+ if self.num_kv_groups > 1:
125
+ k = repeat(k, '... (h d) -> ... (h g) d', d=self.head_k_dim, g=self.num_kv_groups)
126
+ v = repeat(v, '... (h d) -> ... (h g) d', d=self.head_v_dim, g=self.num_kv_groups)
127
+ else:
128
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_k_dim)
129
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
130
+
131
+ q = self.feature_map_q(q)
132
+ k = self.feature_map_k(k)
133
+
134
+ if self.norm_q:
135
+ q = q / (q.sum(-1, True) + 1e-4)
136
+ if self.norm_k:
137
+ k = k / (k.sum(-1, True) + 1e-4)
138
+
139
+ if mode == 'chunk':
140
+ o, final_state = chunk_linear_attn(
141
+ q=q,
142
+ k=k,
143
+ v=v,
144
+ normalize=self.do_feature_map_norm,
145
+ head_first=False
146
+ )
147
+ elif mode == 'fused_chunk':
148
+ o, final_state = fused_chunk_linear_attn(
149
+ q=q,
150
+ k=k,
151
+ v=v,
152
+ normalize=self.do_feature_map_norm,
153
+ )
154
+ elif mode == 'fused_recurrent':
155
+ o, final_state = fused_recurrent_linear_attn(
156
+ q=q,
157
+ k=k,
158
+ v=v,
159
+ normalize=self.do_feature_map_norm,
160
+ )
161
+ else:
162
+ raise NotImplementedError
163
+ o = self.norm(o)
164
+ o = rearrange(o, '... h d -> ... (h d)')
165
+ o = self.o_proj(o)
166
+ return o
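A minimal sketch for LinearAttention, which, unlike the other layers in this commit, returns only the output tensor and keeps no cache; the default `elementwise_product` feature map and `chunk` kernel are used. A CUDA device with bfloat16 weights is assumed.

# Sketch under the stated assumptions; not a definitive recipe.
import torch
from fla.layers.linear_attn import LinearAttention

layer = LinearAttention(hidden_size=1024, num_heads=8).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
o = layer(x)                # a single tensor, no cache or attention weights
print(o.shape)              # torch.Size([2, 128, 1024])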
fla/layers/multiscale_retention.py ADDED
@@ -0,0 +1,298 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange, repeat
11
+ from transformers.activations import ACT2FN
12
+
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
14
+ from fla.modules.rotary import RotaryEmbedding
15
+ from fla.ops.retention import chunk_retention, fused_chunk_retention, fused_recurrent_retention, parallel_retention
16
+
17
+ if TYPE_CHECKING:
18
+ from fla.models.utils import Cache
19
+
20
+
21
+ class MultiScaleRetention(nn.Module):
22
+ r"""
23
+ The layer implementation for [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/pdf/2307.08621.pdf). # noqa
24
+
25
+ Args:
26
+ mode (str, Optional):
27
+ Which Retention kernel to use.
28
+ Currently available: `chunk`, `fused_recurrent`, `parallel`, and `fused_chunk`.
29
+ Default: `chunk`.
30
+ hidden_size (int, Optional):
31
+ The hidden size of the input. Default: 1024.
32
+ expand_k (float, Optional):
33
+ The expansion ratio for the key dim. Default: 1.0.
34
+ expand_v (float, Optional):
35
+ The expansion ratio for the value dim. Default: 2.0.
36
+ num_heads (int, Optional):
37
+ The number of heads. Default: 8.
38
+ num_kv_heads (int, Optional):
39
+ The number of key/value heads, used for MQA. Default: None.
40
+ feature_map (str, Optional):
41
+ Feature map function applied to queries/keys. Default: None.
42
+ use_short_conv (bool, Optional):
43
+ Whether to use short convolutions. Default: `False`.
44
+ conv_size (int, Optional):
45
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
46
+ conv_bias (bool, Optional):
47
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
48
+ use_output_gate (bool, Optional):
49
+ Whether to use output gate. Default: `True`.
50
+ gate_fn (str, Optional):
51
+ The activation function for the output gate. Default: `swish`.
52
+ elementwise_affine (bool, Optional):
53
+ If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
54
+ norm_eps (float, Optional):
55
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
56
+ fuse_norm (bool, Optional):
57
+ Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
58
+ layer_idx (int, Optional):
59
+ The index of the layer. Default: None.
60
+ """
61
+
62
+ def __init__(
63
+ self,
64
+ mode: str = 'chunk',
65
+ hidden_size: int = 1024,
66
+ expand_k: float = 1.0,
67
+ expand_v: float = 2.0,
68
+ num_heads: int = 8,
69
+ num_kv_heads: Optional[int] = None,
70
+ feature_map: Optional[str] = None,
71
+ use_short_conv: bool = False,
72
+ conv_size: int = 4,
73
+ conv_bias: bool = False,
74
+ use_output_gate: bool = True,
75
+ gate_fn: str = 'swish',
76
+ elementwise_affine: Optional[bool] = True,
77
+ norm_eps: float = 1e-5,
78
+ fuse_norm: bool = True,
79
+ layer_idx: int = None,
80
+ **kwargs
81
+ ) -> MultiScaleRetention:
82
+ super().__init__()
83
+
84
+ self.mode = mode
85
+ self.hidden_size = hidden_size
86
+ self.expand_k = expand_k
87
+ self.expand_v = expand_v
88
+ self.num_heads = num_heads
89
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
90
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
91
+ self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
92
+
93
+ self.use_short_conv = use_short_conv
94
+ self.conv_size = conv_size
95
+ self.conv_bias = conv_bias
96
+ self.use_output_gate = use_output_gate
97
+
98
+ self.key_dim = int(hidden_size * expand_k)
99
+ self.value_dim = int(hidden_size * expand_v)
100
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
101
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
102
+ self.layer_idx = layer_idx
103
+
104
+ assert mode in ['chunk', 'fused_chunk', 'parallel', 'fused_recurrent'], f"Not supported mode `{mode}`."
105
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
106
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
107
+
108
+ self.head_k_dim = self.key_dim // num_heads
109
+ self.head_v_dim = self.value_dim // num_heads
110
+
111
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
112
+ self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
113
+ self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
114
+ if self.use_output_gate:
115
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
116
+
117
+ if use_short_conv:
118
+ self.conv_size = conv_size
119
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
120
+ self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
121
+ self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
122
+
123
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
124
+
125
+ if gate_fn == 'swish' and fuse_norm and use_output_gate:
126
+ self.g_norm_swish_gate = FusedRMSNormGated(
127
+ hidden_size=self.head_v_dim,
128
+ elementwise_affine=elementwise_affine,
129
+ eps=norm_eps
130
+ )
131
+ self.fuse_norm_and_gate = True
132
+ else:
133
+ self.fuse_norm_and_gate = False
134
+ self.g_norm = RMSNorm(
135
+ hidden_size=self.head_v_dim,
136
+ elementwise_affine=elementwise_affine,
137
+ eps=norm_eps
138
+ )
139
+ self.gate_fn = ACT2FN[gate_fn]
140
+
141
+ # TODO: fix this issue
142
+ # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py#L180
143
+ # Ideally, we would want to support arbitrary d_head_qk
144
+ assert self.head_k_dim <= 256, "head_k_dim must be less than or equal to 256"
145
+ self.rotary = RotaryEmbedding(dim=self.head_k_dim)
146
+
147
+ def forward(
148
+ self,
149
+ hidden_states: torch.Tensor,
150
+ attention_mask: Optional[torch.Tensor] = None,
151
+ past_key_values: Optional[Cache] = None,
152
+ use_cache: Optional[bool] = False,
153
+ output_attentions: Optional[bool] = False,
154
+ **kwargs
155
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
156
+ if attention_mask is not None:
157
+ assert len(attention_mask.shape) == 2, (
158
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
159
+ "for padding purposes (0 indicating padding). "
160
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
161
+ )
162
+
163
+ # launching the triton kernel for just one token will actually be slower
164
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
165
+
166
+ last_state = None
167
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
168
+ last_state = past_key_values[self.layer_idx]
169
+
170
+ cu_seqlens = kwargs.get('cu_seqlens', None)
171
+ if self.use_short_conv:
172
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
173
+ if last_state is not None:
174
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
175
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
176
+ q, conv_state_q = self.q_conv1d(
177
+ x=self.q_proj(hidden_states),
178
+ mask=conv_mask,
179
+ cache=conv_state_q,
180
+ output_final_state=use_cache,
181
+ cu_seqlens=cu_seqlens
182
+ )
183
+ k, conv_state_k = self.k_conv1d(
184
+ x=self.k_proj(hidden_states),
185
+ mask=conv_mask,
186
+ cache=conv_state_k,
187
+ output_final_state=use_cache,
188
+ cu_seqlens=cu_seqlens
189
+ )
190
+ v, conv_state_v = self.v_conv1d(
191
+ x=self.v_proj(hidden_states),
192
+ mask=conv_mask,
193
+ cache=conv_state_v,
194
+ output_final_state=use_cache,
195
+ cu_seqlens=cu_seqlens
196
+ )
197
+ else:
198
+ q = self.q_proj(hidden_states)
199
+ k = self.k_proj(hidden_states)
200
+ v = self.v_proj(hidden_states)
201
+
202
+ # dealing with left-padding
203
+ if attention_mask is not None:
204
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
205
+ q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim)
206
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_k_dim)
207
+ if self.feature_map_fn is not None:
208
+ q, k = map(self.feature_map_fn, (q, k))
209
+
210
+ seqlen_offset, max_seqlen = 0, q.shape[1]
211
+ if past_key_values is not None:
212
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
213
+ max_seqlen = q.shape[1] + seqlen_offset
214
+
215
+ if attention_mask is not None:
216
+ # adjust the offsets to account for left-padding tokens
217
+ seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
218
+ max_seqlen = q.shape[1] + max(seqlen_offset)
219
+
220
+ q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
221
+
222
+ if self.num_kv_groups > 1:
223
+ k = repeat(k, 'b t h d -> b t (h g) d', g=self.num_kv_groups)
224
+ v = repeat(v, 'b t (h d) -> b t (h g) d', d=self.head_v_dim, g=self.num_kv_groups)
225
+ else:
226
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
227
+
228
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
229
+ if mode == 'chunk':
230
+ o, recurrent_state = chunk_retention(
231
+ q=q,
232
+ k=k,
233
+ v=v,
234
+ initial_state=recurrent_state,
235
+ output_final_state=use_cache,
236
+ cu_seqlens=cu_seqlens,
237
+ head_first=False
238
+ )
239
+ elif mode == 'fused_chunk':
240
+ o, recurrent_state = fused_chunk_retention(
241
+ q=q,
242
+ k=k,
243
+ v=v,
244
+ initial_state=recurrent_state,
245
+ output_final_state=use_cache,
246
+ cu_seqlens=cu_seqlens,
247
+ head_first=False
248
+ )
249
+ elif mode == 'parallel':
250
+ o, recurrent_state = parallel_retention(
251
+ q=q,
252
+ k=k,
253
+ v=v,
254
+ cu_seqlens=cu_seqlens,
255
+ head_first=False
256
+ )
257
+ elif mode == 'fused_recurrent':
258
+ o, recurrent_state = fused_recurrent_retention(
259
+ q=q,
260
+ k=k,
261
+ v=v,
262
+ initial_state=recurrent_state,
263
+ output_final_state=use_cache,
264
+ cu_seqlens=cu_seqlens,
265
+ head_first=False
266
+ )
267
+ else:
268
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
269
+
270
+ if past_key_values is not None:
271
+ past_key_values.update(
272
+ recurrent_state=recurrent_state,
273
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
274
+ layer_idx=self.layer_idx,
275
+ offset=q.shape[1]
276
+ )
277
+
278
+ if self.use_output_gate:
279
+ g = self.g_proj(hidden_states)
280
+ if self.fuse_norm_and_gate:
281
+ g = rearrange(g, 'b t (h d) -> b t h d', d=self.head_v_dim)
282
+ o = self.g_norm_swish_gate(o, g)
283
+ o = rearrange(o, 'b t h d -> b t (h d)')
284
+ else:
285
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
286
+ o = o * self.gate_fn(g)
287
+ else:
288
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
289
+ o = self.o_proj(o)
290
+
291
+ return o, None, past_key_values
292
+
293
+ def state_size(self, **kwargs) -> int:
294
+ state_size = self.key_dim * self.head_v_dim
295
+ for module in self.children():
296
+ if isinstance(module, ShortConvolution):
297
+ state_size += module.state_size
298
+ return state_size
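A minimal sketch for MultiScaleRetention, mirroring the documented defaults (`expand_v=2.0`, 8 heads, fused gate/norm); a CUDA device with bfloat16 weights is assumed for the retention Triton kernels.

# Sketch under the stated assumptions; not a definitive recipe.
import torch
from fla.layers.multiscale_retention import MultiScaleRetention

layer = MultiScaleRetention(hidden_size=1024, num_heads=8, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)
print(o.shape)              # torch.Size([2, 128, 1024])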
fla/layers/nsa.py ADDED
@@ -0,0 +1,138 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange
11
+ from transformers.utils import logging
12
+
13
+ from fla.modules import RotaryEmbedding
14
+ from fla.ops.nsa.parallel import parallel_nsa
15
+
16
+ if TYPE_CHECKING:
17
+ from fla.models.utils import Cache
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ class NativeSparseAttention(nn.Module):
23
+
24
+ def __init__(
25
+ self,
26
+ hidden_size: int = 2048,
27
+ num_heads: int = 64,
28
+ num_kv_heads: Optional[int] = 4,
29
+ head_dim: int = 64,
30
+ qkv_bias: bool = False,
31
+ block_size: Optional[int] = 64,
32
+ block_counts: Optional[Union[torch.LongTensor, int]] = 16,
33
+ window_size: Optional[int] = 512,
34
+ rope_theta: Optional[float] = 10000.,
35
+ max_position_embeddings: Optional[int] = None,
36
+ layer_idx: int = None
37
+ ):
38
+ super().__init__()
39
+
40
+ self.hidden_size = hidden_size
41
+ self.num_heads = num_heads
42
+ if num_kv_heads is None:
43
+ self.num_kv_heads = self.num_heads
44
+ else:
45
+ self.num_kv_heads = num_kv_heads
46
+ self.num_kv_groups = num_heads // self.num_kv_heads
47
+ self.head_dim = head_dim
48
+ self.kv_dim = self.num_kv_heads * self.head_dim
49
+ self.qkv_bias = qkv_bias
50
+
51
+ self.block_size = block_size
52
+ self.block_counts = block_counts
53
+ self.window_size = window_size
54
+ self.rope_theta = rope_theta
55
+ self.max_position_embeddings = max_position_embeddings
56
+ self.layer_idx = layer_idx
57
+
58
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.qkv_bias)
59
+ self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
60
+ self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
61
+ self.g_proj = nn.Linear(self.hidden_size, self.num_heads * 3, bias=False)
62
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
63
+
64
+ self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
65
+
66
+ def forward(
67
+ self,
68
+ hidden_states: torch.Tensor,
69
+ attention_mask: Optional[torch.LongTensor] = None,
70
+ past_key_values: Optional[Cache] = None,
71
+ output_attentions: bool = False,
72
+ use_cache: bool = False,
73
+ **kwargs,
74
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
75
+ if attention_mask is not None:
76
+ assert len(attention_mask.shape) == 2, (
77
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
78
+ "for padding purposes (0 indicating padding). "
79
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
80
+ )
81
+
82
+ batch_size, seq_len, _ = hidden_states.size()
83
+
84
+ q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
85
+ k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
86
+ v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
87
+ g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=3)
88
+ g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)
89
+
90
+ cu_seqlens = kwargs.get('cu_seqlens', None)
91
+
92
+ seqlen_offset, max_seqlen = 0, seq_len
93
+ if past_key_values is not None:
94
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
95
+ max_seqlen = q.shape[1] + seqlen_offset
96
+
97
+ if attention_mask is not None:
98
+ # to deliminate the offsets of padding tokens
99
+ seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
100
+ max_seqlen = q.shape[1] + max(seqlen_offset)
101
+
102
+ if self.max_position_embeddings is not None:
103
+ max_seqlen = max(max_seqlen, self.max_position_embeddings)
104
+ q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
105
+
106
+ if past_key_values is not None:
107
+ cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
108
+ k_cached, v_cached = past_key_values.update(
109
+ attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
110
+ layer_idx=self.layer_idx,
111
+ offset=seq_len,
112
+ cache_kwargs=dict(window_size=self.window_size)
113
+ )['attn_state']
114
+ if cache_has_content:
115
+ k, v = k_cached, v_cached
116
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
117
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
118
+
119
+ o = parallel_nsa(
120
+ q=q,
121
+ k=k,
122
+ v=v,
123
+ g_cmp=g_cmp,
124
+ g_slc=g_slc,
125
+ g_swa=g_swa,
126
+ block_size=self.block_size,
127
+ block_counts=self.block_counts,
128
+ window_size=self.window_size,
129
+ cu_seqlens=cu_seqlens,
130
+ head_first=False
131
+ )
132
+ o = o.reshape(batch_size, seq_len, -1)
133
+ o = self.o_proj(o)
134
+
135
+ if not output_attentions:
136
+ attentions = None
137
+
138
+ return o, attentions, past_key_values
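A minimal sketch for NativeSparseAttention with the defaults above (64 query heads of dim 64 over 4 KV heads, block size 64, 16 selected blocks, window 512); a CUDA device and bfloat16 tensors are assumed, since `parallel_nsa` is a Triton kernel that targets half precision.

# Sketch under the stated assumptions; not a definitive recipe.
import torch
from fla.layers.nsa import NativeSparseAttention

layer = NativeSparseAttention(hidden_size=2048, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(1, 1024, 2048, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)
print(o.shape)              # torch.Size([1, 1024, 2048])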