KitsuVp committed · verified
Commit 6f890a3 · 1 Parent(s): 40e8a9a

Upload model

Files changed (5)
  1. README.md +199 -0
  2. config.json +39 -0
  3. configuration_unified.py +129 -0
  4. model.safetensors +3 -0
  5. modeling_unified.py +783 -0
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases, and limitations of the model. More information is needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
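+ Until the section above is filled in, here is a minimal, hedged sketch of how this checkpoint is meant to be loaded. The repo id `KitsuVp/<this-repo>` is a placeholder; the custom code path requires `trust_remote_code=True` (see `auto_map` in `config.json`); `flash-attn` must be installed; and this repository ships no tokenizer, so a compatible one with a 49152-token vocabulary has to be supplied separately.
+
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # Placeholder ids -- replace with the actual repo and a compatible tokenizer.
+ model = AutoModelForCausalLM.from_pretrained(
+     "KitsuVp/<this-repo>", trust_remote_code=True
+ )
+ tokenizer = AutoTokenizer.from_pretrained("<compatible-tokenizer>")
+
+ inputs = tokenizer("Hello, world", return_tensors="pt")
+ # UnifiedModel defines its own sampling loop (see modeling_unified.py).
+ output_ids = model.generate(inputs["input_ids"], max_new_tokens=20)
+ print(tokenizer.decode(output_ids[0]))
+ ```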
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and BibTeX information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,39 @@
+ {
+   "architectures": [
+     "UnifiedModel"
+   ],
+   "attention_dropout": 0.1,
+   "auto_map": {
+     "AutoConfig": "configuration_unified.UnifiedModelConfig",
+     "AutoModel": "modeling_unified.UnifiedModel",
+     "AutoModelForCausalLM": "modeling_unified.UnifiedModel"
+   },
+   "bos_token_id": 0,
+   "canon_a_enabled": true,
+   "canon_c_enabled": true,
+   "canon_enabled": true,
+   "canon_kernel_size": 4,
+   "embedding_dropout": 0.0,
+   "eos_token_id": 0,
+   "fanformer_p": 0.15,
+   "hidden_size": 512,
+   "intermediate_size": 2048,
+   "lax_enabled": true,
+   "lax_gate_type": "linear",
+   "max_position_embeddings": 1024,
+   "mlp_dropout": 0.1,
+   "model_type": "unified_model",
+   "num_attention_heads": 8,
+   "num_hidden_layers": 8,
+   "num_key_value_heads": 4,
+   "pad_token_id": 0,
+   "rms_norm_eps": 1e-06,
+   "rope_theta": 10000.0,
+   "tie_word_embeddings": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.55.0",
+   "vocab_size": 49152,
+   "xielu_alpha_n_init": 0.8,
+   "xielu_alpha_p_init": 0.8,
+   "xielu_beta": 0.5
+ }
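
A hedged sketch (the repo id is a placeholder): `auto_map` above routes `AutoConfig` to the custom `UnifiedModelConfig`, so the file can be inspected programmatically, and a couple of quantities the modeling code relies on fall straight out of these fields.

```python
from transformers import AutoConfig

# Hypothetical repo id -- substitute the actual one.
cfg = AutoConfig.from_pretrained("KitsuVp/<this-repo>", trust_remote_code=True)

head_dim = cfg.hidden_size // cfg.num_attention_heads            # 512 // 8 = 64
kv_groups = cfg.num_attention_heads // cfg.num_key_value_heads   # 8 // 4 = 2 (GQA)
print(head_dim, kv_groups, cfg.vocab_size)                       # 64 2 49152
```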
configuration_unified.py ADDED
@@ -0,0 +1,129 @@
+ # ====================================================================
+ # configuration_unified.py
+ # ====================================================================
+
+ """
+ Configuration class for the Unified Language Model.
+ HuggingFace Transformers compatible configuration with AutoClass support.
+ """
+
+ from typing import Optional
+
+ from transformers import PretrainedConfig
+
+ class UnifiedModelConfig(PretrainedConfig):
+     """
+     Configuration class for UnifiedModel.
+     Inherits from PretrainedConfig for full HuggingFace compatibility.
+     """
+     model_type = "unified_model"
+
+     def __init__(
+         self,
+         vocab_size: Optional[int] = None,
+         hidden_size: int = 256,
+         intermediate_size: int = 1024,
+         num_hidden_layers: int = 6,
+         num_attention_heads: int = 8,
+         num_key_value_heads: int = 4,
+         max_position_embeddings: int = 2048,
+         rms_norm_eps: float = 1e-6,
+         rope_theta: float = 10000.0,
+
+         attention_dropout: float = 0.1,
+         mlp_dropout: float = 0.1,
+         embedding_dropout: float = 0.1,
+
+         xielu_alpha_p_init: float = 0.8,
+         xielu_alpha_n_init: float = 0.8,
+         xielu_beta: float = 0.5,
+
+         tie_word_embeddings: bool = True,  # HuggingFace standard parameter name
+
+         # LaX configuration (linear gate only)
+         lax_enabled: bool = True,
+         lax_gate_type: str = "linear",  # Only "linear" is supported now
+
+         # Canon Layers configuration (A+C only)
+         canon_enabled: bool = True,
+         canon_kernel_size: int = 4,
+         canon_a_enabled: bool = True,  # Before attention
+         canon_c_enabled: bool = True,  # Before MLP
+         # Canon B and D are permanently disabled
+
+         # FANFormer configuration
+         fanformer_p: float = 0.15,
+
+         # HuggingFace standard parameters
+         pad_token_id: Optional[int] = None,
+         bos_token_id: Optional[int] = None,
+         eos_token_id: Optional[int] = None,
+
+         **kwargs
+     ):
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs
+         )
+
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.max_position_embeddings = max_position_embeddings
+         self.rms_norm_eps = rms_norm_eps
+         self.rope_theta = rope_theta
+
+         self.attention_dropout = attention_dropout
+         self.mlp_dropout = mlp_dropout
+         self.embedding_dropout = embedding_dropout
+
+         self.xielu_alpha_p_init = xielu_alpha_p_init
+         self.xielu_alpha_n_init = xielu_alpha_n_init
+         self.xielu_beta = xielu_beta
+         self.tie_word_embeddings = tie_word_embeddings
+
+         # LaX configuration
+         self.lax_enabled = lax_enabled
+         self.lax_gate_type = lax_gate_type
+
+         # Canon Layers configuration
+         self.canon_enabled = canon_enabled
+         self.canon_kernel_size = canon_kernel_size
+         self.canon_a_enabled = canon_a_enabled
+         self.canon_c_enabled = canon_c_enabled
+
+         # FANFormer
+         self.fanformer_p = fanformer_p
+
+         # ✅ FIXED: force a complete auto_map into config.json
+         self.auto_map = {
+             "AutoConfig": "configuration_unified.UnifiedModelConfig",
+             "AutoModel": "modeling_unified.UnifiedModel",
+             "AutoModelForCausalLM": "modeling_unified.UnifiedModel"
+         }
+
+     def to_diff_dict(self):
+         """
+         ✅ FIXED: forces tie_word_embeddings to be serialized into config.json.
+
+         Overrides to_diff_dict() to make sure tie_word_embeddings always
+         appears in config.json, avoiding loading problems where HuggingFace
+         does not recognize the weight tying.
+
+         Returns:
+             Dict: configuration dict with tie_word_embeddings forced in
+         """
+         # Get the normal serialization (differences from defaults only)
+         output = super().to_diff_dict()
+
+         # ✅ FORCE the inclusion of tie_word_embeddings: this guarantees it
+         # shows up in config.json regardless of whether HF considers it a
+         # "default" value or not
+         output["tie_word_embeddings"] = self.tie_word_embeddings
+
+         return output
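
A small check of the `to_diff_dict` override above (a sketch, assuming `configuration_unified.py` is on the import path): since `tie_word_embeddings=True` can coincide with the upstream default, a plain diff-dict could silently drop the key; the override pins it in place.

```python
from configuration_unified import UnifiedModelConfig

cfg = UnifiedModelConfig(vocab_size=49152, hidden_size=512)
diff = cfg.to_diff_dict()

# The override guarantees the key survives serialization into config.json.
assert diff["tie_word_embeddings"] is True
```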
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:24295526b57a7302b9a383da6bc54b22bc3ebe3ea59b759ae3aae0ab83cfac8f
+ size 173989448
modeling_unified.py ADDED
@@ -0,0 +1,783 @@
+ # ====================================================================
+ # modeling_unified.py
+ # ====================================================================
+
+ """
+ Unified Language Model with GPAS + LNS integration + xIELU activation + CoLA (linear only) + LaX + weight tying + Canon layers (A+C only).
+ MIGRATED TO HUGGINGFACE TRANSFORMERS - FINAL VERSION WITH ALL FIXES.
+ UPDATED: standard Transformer with advanced variance control, parameter efficiency, and Canon horizontal information flow.
+ Combines an advanced Transformer architecture with corrected variance-control mechanisms
+ (GPAS and LNS), the xIELU activation function, LaX integration, and Canon layers (A+C only).
+ Based on the LLaMA 3 architecture with 30M parameters.
+
+ MIGRATION TO HUGGINGFACE - FINAL FIXED VERSION:
+ ==============================================
+
+ 1. **HUGGINGFACE INTEGRATION**: Migrated from PyTorch Lightning to Transformers v4.53.3
+ 2. **UPDATED API**: processing_class instead of tokenizer (deprecated)
+ 3. **UPDATED COMPUTE_LOSS**: Method updated with the num_items_in_batch parameter
+ 4. **FIXED LOGGING**: Corrected self.log() syntax per the official HF documentation
+ 5. **RESTORED PAD HANDLING**: pad_token_id → -100 conversion for CrossEntropyLoss (from the original code)
+ 6. **NATIVE TORCH COMPILE**: Moved to TrainingArguments (torch_compile=True)
+ 7. **FIXED WEIGHT TYING**: Corrected _tied_weights_keys as a class attribute (HF standard)
+ 8. **VALIDATION DIAGNOSTIC**: Added a simple method to diagnose validation-loss issues
+ 9. **CUSTOM CONFIGURATION**: Custom PretrainedConfig carrying all parameters
+ 10. **PRETRAINED MODEL**: Inherits from PreTrainedModel for full compatibility
+ 11. **MAINTAINED OPTIMIZERS**: Hybrid Muon + AdamW preserved
+ 12. **MAINTAINED PRECISION**: bf16-true preserved
+ 13. **MAINTAINED TRAINING**: Custom Trainer with all metrics and logging
+ 14. **MAINTAINED ARCHITECTURE**: All of the custom architecture preserved
+ 15. **AUTO TOKENIZER**: Full integration with a dynamic AutoTokenizer
+ 16. **AUTOCLASS SUPPORT**: Complete registration for AutoConfig and AutoModel
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.checkpoint import checkpoint
+ from transformers import (
+     AutoTokenizer,
+     AutoConfig,
+     AutoModel,
+     AutoModelForCausalLM,
+     PreTrainedModel,
+ )
+ import math
+ import os
+ from typing import Optional, Tuple, Dict, Any, cast, List
+ from flash_attn import flash_attn_func
+ import numpy as np
+
+ # ✅ ABSOLUTE IMPORT - no relative imports, for Hub compatibility
+ from configuration_unified import UnifiedModelConfig
+
+ # Fix tokenizer parallelism warnings
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ torch.set_float32_matmul_precision('high')
+
+ def init_cola_components(A: nn.Linear, B: nn.Linear):
+     # Kaiming for the down-projection A, scaled Xavier for the up-projection B
+     nn.init.kaiming_normal_(A.weight, mode='fan_in', nonlinearity='relu')
+     nn.init.xavier_normal_(B.weight, gain=0.8)
+     if B.bias is not None:
+         nn.init.zeros_(B.bias)
+
+ def init_embedding(embedding: nn.Embedding):
+     nn.init.normal_(embedding.weight, mean=0.0, std=0.02)
+
+ class CanonLayer(nn.Module):
+     def __init__(self, hidden_dim: int, kernel_size: int = 4):
+         """
+         Canon layer using a 1D causal convolution with a residual connection.
+         """
+         super().__init__()
+         self.hidden_dim = hidden_dim
+         self.kernel_size = kernel_size
+
+         # Use a causal convolution with explicit initialization
+         self.causal_conv1d = nn.Conv1d(
+             in_channels=hidden_dim,
+             out_channels=hidden_dim,
+             kernel_size=kernel_size,
+             groups=hidden_dim,  # Depthwise convolution
+             padding=0,  # No automatic padding
+             bias=True
+         )
+
+         # Initialize weights conservatively (as per the paper)
+         nn.init.zeros_(self.causal_conv1d.weight)
+         nn.init.zeros_(self.causal_conv1d.bias)
+
+     def forward(self, h: torch.Tensor) -> torch.Tensor:
+         """
+         Applies the Canon layer transformation with causal masking.
+         """
+         # Conv1d expects input shape (batch_size, channels, sequence_length)
+         h_permuted = h.permute(0, 2, 1)  # (batch, hidden_dim, seq_len)
+
+         # Add padding of (kernel_size - 1) only on the left side
+         padding = self.kernel_size - 1
+         h_padded = F.pad(h_permuted, (padding, 0))
+
+         # Apply the causal convolution
+         conv_out = self.causal_conv1d(h_padded)
+
+         # Permute back to the original shape
+         conv_out_permuted = conv_out.permute(0, 2, 1)
+
+         # Add the residual connection
+         output = h + conv_out_permuted
+
+         return output
+
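+ # Concretely, with kernel_size=4 the F.pad(..., (3, 0)) above left-pads three
+ # steps, so the output at position t is a depthwise-weighted sum of h[t-3..t]
+ # plus the residual h[t]. Because weights and bias start at zero, every
+ # CanonLayer begins as an exact identity; the horizontal mixing is learned.
+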
+ class CoLA_Linear(nn.Module):
+     def __init__(self, in_features: int, out_features: int, rank: Optional[int] = None, activation=F.gelu, bias: bool = True):
+         super().__init__()
+         if rank is None:
+             rank = in_features // 4
+         self.rank = rank
+         self.activation = activation
+
+         self.A = nn.Linear(in_features, rank, bias=False)
+         self.B = nn.Linear(rank, out_features, bias=bias)
+
+         # LaX gate components (linear gate only)
+         self.lax_gate = nn.Parameter(torch.zeros(1))
+
+         # Storage for the previous layer's latent representation
+         self.prev_latent = None
+
+         init_cola_components(self.A, self.B)
+
+     def apply_lax_gate(self, prev_latent: torch.Tensor) -> torch.Tensor:
+         """Apply the linear gate to the previous latent representation."""
+         return torch.sigmoid(self.lax_gate) * prev_latent
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Standard CoLA forward: low-rank A, nonlinearity, then B
+         latent = self.A(x)
+         latent_activated = self.activation(latent)
+
+         # Apply LaX if a previous latent with a matching shape exists
+         if self.prev_latent is not None and self.prev_latent.shape == latent_activated.shape:
+             gated_prev = self.apply_lax_gate(self.prev_latent)
+             latent_activated = latent_activated + gated_prev
+
+         output = self.B(latent_activated)
+
+         # Store the current latent for the next layer (detached to avoid gradient issues)
+         self.prev_latent = latent_activated.detach()
+
+         return output
+
+     def reset_lax_state(self):
+         self.prev_latent = None
+
+ class LayerNormScaling(nn.Module):
+     def __init__(self, layer_depth: int):
+         super().__init__()
+
+         if layer_depth < 1:
+             raise ValueError(f"layer_depth must be >= 1, got {layer_depth}")
+
+         self.layer_depth = layer_depth
+         # LNS: scale the normalized input by 1/sqrt(depth) to control variance growth
+         self.scaling_factor = 1.0 / math.sqrt(float(layer_depth))
+
+     def forward(self, normalized_input: torch.Tensor) -> torch.Tensor:
+         return normalized_input * self.scaling_factor
+
+ class GPAS(nn.Module):
+     def __init__(self, d_model: int):
+         super().__init__()
+
+         self.d_model = d_model
+         self.alpha = nn.Parameter(torch.zeros(1))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Gradient-preserving activation scaling: subtract a *detached* scaled
+         # copy, so the forward activation shrinks by silu(alpha) * x while the
+         # gradient with respect to x stays the identity.
+         x_detached = x.detach()
+         scaled_component = F.silu(self.alpha) * x_detached
+         x_scaled = x - scaled_component
+
+         return x_scaled
+
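+ # Doctest-style sketch of the GPAS behavior above (alpha is trainable and
+ # starts at zero, so a fresh GPAS module is the identity):
+ #
+ #     >>> gpas = GPAS(d_model=512)
+ #     >>> x = torch.randn(2, 3, 512, requires_grad=True)
+ #     >>> y = gpas(x)                                    # silu(0) = 0
+ #     >>> torch.allclose(y, x)                           # identity at init
+ #     True
+ #     >>> y.sum().backward()
+ #     >>> torch.allclose(x.grad, torch.ones_like(x))     # gradient preserved
+ #     True
+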
+ class RotaryEmbedding(nn.Module):
+     def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000):
+         super().__init__()
+         self.dim = dim
+         self.max_position_embeddings = max_position_embeddings
+         self.base = base
+
+         # Standard RoPE inverse frequencies: base^(-2i/dim) for i in [0, dim/2)
+         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+     def forward(self, x, seq_len=None):
+         if seq_len is None:
+             seq_len = x.shape[-2]
+         t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
+         freqs = torch.outer(t, self.inv_freq)
+         emb = torch.cat((freqs, freqs), dim=-1)
+         return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
+
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
+     def rotate_half(x):
+         x1 = x[..., : x.shape[-1] // 2]
+         x2 = x[..., x.shape[-1] // 2 :]
+         return torch.cat((-x2, x1), dim=-1)
+
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+ class XIELU(nn.Module):
+     def __init__(self, alpha_p_init: float = 0.8, alpha_n_init: float = 0.8, beta: float = 0.5):
+         super().__init__()
+
+         self.beta = beta
+
+         # Store the alphas through an inverse-softplus so that
+         # softplus(param) recovers the requested initial values
+         self.alpha_p = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_p_init)) - 1))
+         self.alpha_n = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_n_init - self.beta)) - 1))
+
+         self.register_buffer('eps', torch.tensor(-1e-6))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         alpha_p = F.softplus(self.alpha_p)
+         alpha_n = self.beta + F.softplus(self.alpha_n)
+
+         # The negative branch caps x at eps (a small negative number) so that
+         # expm1 tracks x on the negative side; clamp(..., max=...) is the
+         # intended op here (clamp(..., min=eps) would freeze the expm1 term).
+         return torch.where(
+             x > 0,
+             alpha_p * x * x + self.beta * x,
+             alpha_n * torch.expm1(torch.clamp(x, max=self.eps)) - alpha_n * x + self.beta * x
+         )
+
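+ # Reading the forward pass above off in closed form, with
+ # a_p = softplus(alpha_p) and a_n = beta + softplus(alpha_n):
+ #
+ #     xIELU(x) = a_p * x**2 + beta * x                              for x > 0
+ #     xIELU(x) = a_n * expm1(min(x, eps)) - a_n * x + beta * x      for x <= 0
+ #
+ # i.e. a trainable quadratic on the positive side and a trainable ELU-like
+ # curve on the negative side, sharing the linear beta term.
+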
+ class StandardMLP(nn.Module):
+     def __init__(self, hidden_size: int, intermediate_size: int, dropout: float = 0.0, config=None):
+         super().__init__()
+
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.config = config
+
+         self.up_proj = CoLA_Linear(hidden_size, intermediate_size, bias=False)
+         self.down_proj = CoLA_Linear(intermediate_size, hidden_size, bias=False)
+
+         if config is not None:
+             self.activation = XIELU(
+                 alpha_p_init=config.xielu_alpha_p_init,
+                 alpha_n_init=config.xielu_alpha_n_init,
+                 beta=config.xielu_beta
+             )
+         else:
+             self.activation = XIELU(alpha_p_init=0.8, alpha_n_init=0.8, beta=0.5)
+
+         self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+
+         # Canon-D is permanently disabled
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         intermediate = self.up_proj(x)
+
+         # No Canon-D applied (eliminated)
+
+         activated = self.activation(intermediate)
+         activated = self.dropout(activated)
+         output = self.down_proj(activated)
+
+         return output
+
+     def reset_lax_state(self):
+         self.up_proj.reset_lax_state()
+         self.down_proj.reset_lax_state()
+
+ class GroupedQueryAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.num_key_value_heads = config.num_key_value_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+
+         # FANFormer components: split the hidden dimension into a periodic
+         # part (passed through cos/sin) and a standard part
+         self.fanformer_p = getattr(config, 'fanformer_p', 0.15)
+
+         self.d_periodic = int(self.hidden_size * self.fanformer_p)
+         self.d_standard = self.hidden_size - 2 * self.d_periodic
+
+         assert self.d_standard > 0, \
+             f"fanformer_p={self.fanformer_p} is too high. d_standard={self.d_standard} must be > 0"
+
+         self.fan_w_p = CoLA_Linear(self.hidden_size, self.d_periodic, bias=False)
+         self.fan_w_p_bar = CoLA_Linear(self.hidden_size, self.d_standard, bias=False)
+
+         self.q_proj = CoLA_Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+         self.k_proj = CoLA_Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.v_proj = CoLA_Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.o_proj = CoLA_Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+         self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+         self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+         self.v_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+         self.rotary_emb = RotaryEmbedding(
+             self.head_dim,
+             max_position_embeddings=config.max_position_embeddings,
+             base=config.rope_theta
+         )
+
+         # Canon-B is permanently disabled (no more canon_b)
+
+     def _fan_layer_prime(self, x: torch.Tensor) -> torch.Tensor:
+         periodic_proj = self.fan_w_p(x)
+         standard_proj = self.fan_w_p_bar(x)
+
+         cos_component = torch.cos(periodic_proj)
+         sin_component = torch.sin(periodic_proj)
+
+         # Concatenation restores the full hidden size: d_periodic + d_periodic + d_standard
+         x_f = torch.cat([cos_component, sin_component, standard_proj], dim=-1)
+
+         return x_f
+
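+ # With the shipped config.json (hidden_size=512, fanformer_p=0.15) the split
+ # above works out to d_periodic = int(512 * 0.15) = 76 and
+ # d_standard = 512 - 2 * 76 = 360, so cat([cos, sin, standard]) restores
+ # 76 + 76 + 360 = 512 -- exactly the width the q/k/v projections expect.
+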
+     def _compute_flash_attention(
+         self,
+         query_states: torch.Tensor,
+         key_states: torch.Tensor,
+         value_states: torch.Tensor,
+         seq_len: int,
+         position_ids: Optional[torch.Tensor] = None
+     ) -> torch.Tensor:
+         # flash_attn_func expects (batch, seq_len, heads, head_dim); RoPE is
+         # applied in (batch, heads, seq_len, head_dim) layout, hence the transposes
+         q_rope = query_states.transpose(1, 2)
+         k_rope = key_states.transpose(1, 2)
+
+         cos, sin = self.rotary_emb(value_states, seq_len=seq_len)
+         q_rope, k_rope = apply_rotary_pos_emb(q_rope, k_rope, cos, sin, position_ids)
+
+         query_states = q_rope.transpose(1, 2)
+         key_states = k_rope.transpose(1, 2)
+
+         # flash_attn_func is imported at module level; it supports GQA natively
+         # (num_key_value_heads < num_heads) and applies the causal mask itself
+         attn_output = flash_attn_func(
+             query_states,
+             key_states,
+             value_states,
+             dropout_p=self.config.attention_dropout if self.training else 0.0,
+             causal=True,
+         )
+
+         return attn_output
+
+     def forward(self, hidden_states, position_ids=None, attention_mask=None):
+         batch_size, seq_len, _ = hidden_states.shape
+
+         enhanced_input = self._fan_layer_prime(hidden_states)
+
+         query_states = self.q_proj(enhanced_input)
+         key_states = self.k_proj(enhanced_input)
+         value_states = self.v_proj(enhanced_input)
+
+         # No Canon-B applied (eliminated)
+
+         query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim)
+         key_states = key_states.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
+         value_states = value_states.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
+
+         # Per-head RMSNorm on queries, keys, and values, applied on flattened head vectors
+         q_flat = query_states.reshape(-1, self.head_dim)
+         k_flat = key_states.reshape(-1, self.head_dim)
+         v_flat = value_states.reshape(-1, self.head_dim)
+
+         q_normalized = self.q_norm(q_flat)
+         k_normalized = self.k_norm(k_flat)
+         v_normalized = self.v_norm(v_flat)
+
+         query_states = q_normalized.view(batch_size, seq_len, self.num_heads, self.head_dim)
+         key_states = k_normalized.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
+         value_states = v_normalized.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
+
+         attn_output = self._compute_flash_attention(
+             query_states=query_states,
+             key_states=key_states,
+             value_states=value_states,
+             seq_len=seq_len,
+             position_ids=position_ids
+         )
+
+         attn_output = attn_output.reshape(batch_size, seq_len, self.hidden_size)
+         return self.o_proj(attn_output)
+
+     def reset_lax_state(self):
+         self.fan_w_p.reset_lax_state()
+         self.fan_w_p_bar.reset_lax_state()
+         self.q_proj.reset_lax_state()
+         self.k_proj.reset_lax_state()
+         self.v_proj.reset_lax_state()
+         self.o_proj.reset_lax_state()
+
+ class DecoderLayer(nn.Module):
+     def __init__(self, config, layer_idx: int):
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+
+         if layer_idx < 0:
+             raise ValueError(f"layer_idx must be >= 0, got {layer_idx}")
+
+         self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.self_attn = GroupedQueryAttention(config)
+         self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         self.mlp = StandardMLP(
+             config.hidden_size,
+             config.intermediate_size,
+             config.mlp_dropout,
+             config
+         )
+
+         self.dropout_output = nn.Dropout(0.01)
+
+         # LNS depth scaling is 1-indexed (layer 0 scales by 1/sqrt(1))
+         self.lns_attention = LayerNormScaling(layer_depth=layer_idx + 1)
+         self.lns_mlp = LayerNormScaling(layer_depth=layer_idx + 1)
+
+         self.gpas_attention = GPAS(config.hidden_size)
+         self.gpas_mlp = GPAS(config.hidden_size)
+
+         # Canon layers (A+C only)
+         # Canon-A: before the attention block
+         if config.canon_enabled and config.canon_a_enabled:
+             self.canon_a = CanonLayer(config.hidden_size, config.canon_kernel_size)
+         else:
+             self.canon_a = None
+
+         # Canon-C: before the MLP block
+         if config.canon_enabled and config.canon_c_enabled:
+             self.canon_c = CanonLayer(config.hidden_size, config.canon_kernel_size)
+         else:
+             self.canon_c = None
+
+     def forward(self, hidden_states: torch.Tensor, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+         residual = hidden_states
+
+         # Apply Canon-A before attention
+         if self.canon_a is not None:
+             hidden_states = self.canon_a(hidden_states)
+
+         attention_input = self.input_layernorm(hidden_states)
+         attention_input = self.lns_attention(attention_input)
+         attention_output = self.self_attn(attention_input, position_ids, attention_mask)
+         hidden_states = residual + attention_output
+         hidden_states = self.gpas_attention(hidden_states)
+         hidden_states = self.dropout_output(hidden_states)
+
+         residual = hidden_states
+
+         # Apply Canon-C before the MLP
+         if self.canon_c is not None:
+             hidden_states = self.canon_c(hidden_states)
+
+         mlp_input = self.post_attention_layernorm(hidden_states)
+         mlp_input = self.lns_mlp(mlp_input)
+         mlp_output = self.mlp(mlp_input)
+         hidden_states = residual + mlp_output
+         hidden_states = self.gpas_mlp(hidden_states)
+         hidden_states = self.dropout_output(hidden_states)
+
+         return hidden_states
+
+     def reset_lax_state(self):
+         self.self_attn.reset_lax_state()
+         self.mlp.reset_lax_state()
+
+ class UnifiedModel(PreTrainedModel):
+     """
+     UnifiedModel that inherits from PreTrainedModel for full HuggingFace compatibility,
+     with AutoClass support for seamless Hub integration.
+     """
+     config_class = UnifiedModelConfig
+
+     # ✅ FIXED: _tied_weights_keys as a class attribute (HuggingFace standard)
+     _tied_weights_keys = ["lm_head.weight"]
+
+     def __init__(self, config: UnifiedModelConfig):
+         super().__init__(config)
+         self.config = config
+
+         if config.vocab_size is None:
+             raise ValueError("config.vocab_size must be set from the tokenizer before model initialization")
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.embedding_dropout = nn.Dropout(config.embedding_dropout)
+         self.output_dropout = nn.Dropout(0.05)
+
+         # Create lm_head for the output projection
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         self.layers = nn.ModuleList()
+         for i in range(config.num_hidden_layers):
+             self.layers.append(DecoderLayer(config, i))
+
+         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         # Initialize weights
+         self.post_init()
+
+         self._print_configuration()
+
+     def tie_weights(self):
+         """
+         ✅ FIXED: Simplified tie_weights method following the HuggingFace standard.
+         Ties the word embeddings and the output layer.
+         Called automatically when config.tie_word_embeddings is True.
+         """
+         if self.config.tie_word_embeddings:
+             print("🔗 Applying weight tying: lm_head.weight = embed_tokens.weight")
+             self.lm_head.weight = self.embed_tokens.weight
+             print("✅ Weight tying successful: parameters are properly shared")
+
+     def _init_weights(self, module):
+         """Initialize weights following the custom initialization scheme."""
+         if isinstance(module, nn.Linear):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.trunc_normal_(module.weight, mean=0.0, std=0.02, a=-0.04, b=0.04)
+         elif isinstance(module, CoLA_Linear):
+             pass  # CoLA_Linear has its own initialization
+
+     def _print_configuration(self):
+         # Naive count of every registered parameter
+         total_params_naive = sum(p.numel() for p in self.parameters())
+
+         # Smarter count that accounts for weight tying
+         total_params_actual = total_params_naive
+         vocab_params = self.config.vocab_size * self.config.hidden_size
+         tied_savings = 0
+
+         # ✅ FIX: detect and adjust for weight tying that has actually been applied
+         if self.config.tie_word_embeddings:
+             # Check whether the tensors are really tied in memory
+             embed_weight = self.embed_tokens.weight
+             lm_head_weight = self.lm_head.weight
+
+             if embed_weight is lm_head_weight:
+                 # The tensors are identical - subtract the duplication
+                 tied_savings = vocab_params
+                 total_params_actual = total_params_naive - tied_savings
+             else:
+                 # Weight tying is configured but not applied yet
+                 tied_savings = 0
+
+         # Existing optimization accounting
+         total_linear_params = 0
+         total_cola_params = 0
+         canon_params = 0
+
+         for name, module in self.named_modules():
+             if isinstance(module, CoLA_Linear):
+                 in_features = module.A.in_features
+                 out_features = module.B.out_features
+                 rank = module.rank
+
+                 standard_params = in_features * out_features
+                 cola_params = (in_features * rank) + (rank * out_features)
+
+                 total_linear_params += standard_params
+                 total_cola_params += cola_params
+             elif isinstance(module, CanonLayer):
+                 # Canon layer parameters: depthwise conv1d weights + bias
+                 canon_layer_params = module.hidden_dim * module.kernel_size + module.hidden_dim
+                 canon_params += canon_layer_params
+
+         cola_reduction = ((total_linear_params - total_cola_params) / total_linear_params) * 100 if total_linear_params > 0 else 0
+         canon_overhead = (canon_params / total_params_actual) * 100 if total_params_actual > 0 else 0
+
+         print(f"\n📊 UNIFIED Model + GPAS + LNS + xIELU + CoLA (Linear Only) + LaX + Canon (A+C) + Weight Tying:")
+
+         # ✅ IMPROVED: show the real vs. naive count for transparency
+         if self.config.tie_word_embeddings and tied_savings > 0:
+             print(f"🎯 Total Parameters: {total_params_actual/1e6:.2f}M (effective)")
+             print(f"📊 Parameter Breakdown:")
+             print(f"   • Naive count: {total_params_naive/1e6:.2f}M (all registered params)")
+             print(f"   • Actual count: {total_params_actual/1e6:.2f}M (after weight tying)")
+             print(f"   • Weight tying savings: {tied_savings/1e6:.2f}M ({tied_savings/total_params_naive*100:.1f}%)")
+         else:
+             print(f"🎯 Total Parameters: {total_params_actual/1e6:.2f}M")
+
+         print(f"📚 DYNAMIC Vocabulary Size: {self.config.vocab_size} (from tokenizer)")
+         print(f"🔗 ✅ PROPER Weight Tying: {'ENABLED' if self.config.tie_word_embeddings else 'DISABLED'}")
+
+         # ✅ FIX: report the real weight-tying state
+         if self.config.tie_word_embeddings:
+             if tied_savings > 0:
+                 print(f"💾 Weight Tying Status: ✅ ACTIVE (tensors are shared in memory)")
+             else:
+                 print(f"💾 Weight Tying Status: ⏳ CONFIGURED (will be applied during post_init)")
+
+         print(f"🚀 ACTIVATION: xIELU (αp_init={self.config.xielu_alpha_p_init}, αn_init={self.config.xielu_alpha_n_init}, β={self.config.xielu_beta})")
+         print(f"🔄 UPGRADE: SwiGLU → StandardMLP + xIELU (better efficiency & adaptability)")
+         print(f"🗜️ CoLA Integration: {cola_reduction:.1f}% parameter reduction in internal projections")
+         print(f"🔀 LaX Enabled: {'YES' if self.config.lax_enabled else 'NO'} (Gate: LINEAR ONLY)")
+         print(f"🎼 Canon Layers Enabled: {'YES' if self.config.canon_enabled else 'NO'} (A+C ONLY)")
+         if self.config.canon_enabled:
+             print(f"   • Canon-A (Before Attention): {'✅' if self.config.canon_a_enabled else '❌'}")
+             print(f"   • Canon-B (Inside Attention): ❌ PERMANENTLY DISABLED")
+             print(f"   • Canon-C (Before MLP): {'✅' if self.config.canon_c_enabled else '❌'}")
+             print(f"   • Canon-D (Inside MLP): ❌ PERMANENTLY DISABLED")
+             print(f"   • Canon Kernel Size: {self.config.canon_kernel_size}")
+             print(f"   • Canon Parameters Overhead: {canon_overhead:.3f}% ({canon_params/1e3:.1f}K params)")
+         print(f"⚡ GPAS Enabled: ALWAYS (dynamic variance control)")
+         print(f"📏 LNS Enabled: ALWAYS (static depth scaling)")
+         print(f"🔧 Variance Control: triple-level (LNS + GPAS + Canon A+C), ALWAYS")
+         print(f"🔗 Residual Connections: STANDARD + HORIZONTAL (Canon A+C only)")
+         print(f"🧹 CLEAN: standard transformer architecture - CrossEntropyLoss handles PAD naturally")
+         print(f"⚡ FlashAttention: scaled dot-product attention with GQA + automatic causal masking")
+         print(f"🎯 TOKENIZER AGNOSTIC: dynamic vocab_size and pad_token_id")
+         print(f"🎯 SIMPLIFIED: CoLA Linear Only + Canon A+C Only = better performance & less overhead")
+         print(f"🔗 ✅ FIXED Weight Tying: _tied_weights_keys as a class attribute (HF standard)")
+         print(f"🎼 Canon A+C BENEFITS: strategic horizontal information flow with minimal parameters")
+         print(f"🤗 HUGGINGFACE COMPATIBLE: full PreTrainedModel integration, v4.53.3")
+         print(f"⚡ ✅ NATIVE TORCH COMPILE: handled by TrainingArguments")
+         print(f"🚀 ✅ AUTOCLASS SUPPORT: compatible with AutoConfig.from_pretrained() and AutoModel.from_pretrained()")
+
+     def reset_lax_state(self):
+         """Reset LaX state for all layers."""
+         for layer in self.layers:
+             layer.reset_lax_state()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         **kwargs
+     ):
+         batch_size, seq_len = input_ids.shape
+
+         # Reset LaX state at the beginning of each forward pass
+         self.reset_lax_state()
+
+         hidden_states = self.embed_tokens(input_ids)
+         # Note: this detach blocks gradient flow into embed_tokens from the
+         # input side; with weight tying the embedding matrix still receives
+         # gradients through the tied lm_head
+         hidden_states = hidden_states.detach()
+         hidden_states = self.embedding_dropout(hidden_states)
+
+         for layer in self.layers:
+             hidden_states = layer(hidden_states, position_ids=position_ids, attention_mask=attention_mask)
+
+         hidden_states = self.norm(hidden_states)
+         hidden_states = self.output_dropout(hidden_states)
+
+         logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = nn.CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Enable model parallelism
+             shift_labels = shift_labels.to(shift_logits.device)
+
+             # ✅ RESTORED: map pad tokens to -100 so CrossEntropyLoss ignores them (from the original code)
+             if self.config.pad_token_id is not None:
+                 shift_labels[shift_labels == self.config.pad_token_id] = -100
+
+             loss = loss_fct(shift_logits, shift_labels)
+
+         # Return in HuggingFace format
+         from transformers.modeling_outputs import CausalLMOutputWithPast
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=None,
+             hidden_states=None,
+             attentions=None,
+         )
+
+     def get_input_embeddings(self):
+         return self.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.embed_tokens = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     @torch.no_grad()
+     def generate(
+         self,
+         input_ids: torch.Tensor,
+         max_new_tokens: int = 50,
+         temperature: float = 1.0,
+         top_p: float = 0.9,
+         top_k: Optional[int] = None,
+         do_sample: bool = True,
+         pad_token_id: Optional[int] = None,
+         eos_token_id: Optional[int] = None,
+         **kwargs
+     ) -> torch.Tensor:
+         """
+         Generate sequences with the model.
+         Compatible with the AutoModelForCausalLM interface.
+         """
+         # Set default token IDs
+         if pad_token_id is None:
+             pad_token_id = self.config.pad_token_id
+         if eos_token_id is None:
+             eos_token_id = self.config.eos_token_id
+
+         batch_size = input_ids.shape[0]
+         device = input_ids.device
+
+         # Reset LaX state for generation
+         self.reset_lax_state()
+
+         generated = input_ids.clone()
+
+         for _ in range(max_new_tokens):
+             # Forward pass (no KV cache: the full sequence is re-encoded each step)
+             outputs = self.forward(generated)
+             logits = outputs.logits
+
+             # Get the logits for the last token
+             next_token_logits = logits[:, -1, :]
+
+             if do_sample:
+                 # Apply temperature
+                 if temperature != 1.0:
+                     next_token_logits = next_token_logits / temperature
+
+                 # Apply top-k filtering
+                 if top_k is not None:
+                     values, indices = torch.topk(next_token_logits, top_k)
+                     next_token_logits[next_token_logits < values[:, [-1]]] = -float('inf')
+
+                 # Apply top-p (nucleus) filtering
+                 if top_p < 1.0:
+                     sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                     cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+                     # Remove tokens with cumulative probability above the threshold
+                     sorted_indices_to_remove = cumulative_probs > top_p
+                     # Shift the mask right to also keep the first token above the threshold
+                     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                     sorted_indices_to_remove[..., 0] = 0
+
+                     # Scatter the sorted mask back to the original indexing
+                     indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                     next_token_logits[indices_to_remove] = -float('inf')
+
+                 # Sample from the filtered distribution
+                 probs = F.softmax(next_token_logits, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
+             else:
+                 # Greedy decoding
+                 next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+
+             # Append the new token
+             generated = torch.cat([generated, next_token], dim=1)
+
+             # Stop once every sequence in the batch has emitted EOS
+             if eos_token_id is not None and (next_token == eos_token_id).all():
+                 break
+
+         return generated
+
+
+ # ✅ AUTOCLASS REGISTRATION - required for Hub compatibility
+ # Register the configuration and model for AutoClass support
+ AutoConfig.register("unified_model", UnifiedModelConfig)
+ AutoModel.register(UnifiedModelConfig, UnifiedModel)
+ AutoModelForCausalLM.register(UnifiedModelConfig, UnifiedModel)
+
+ print("🚀 ✅ AUTOCLASS REGISTRATION COMPLETE:")
+ print("   • AutoConfig.register('unified_model', UnifiedModelConfig)")
+ print("   • AutoModel.register(UnifiedModelConfig, UnifiedModel)")
+ print("   • AutoModelForCausalLM.register(UnifiedModelConfig, UnifiedModel)")
+ print("   • Users can now load with: AutoModel.from_pretrained('your-repo', trust_remote_code=True)")