Update transformers version to 5.0

#24
by SSON9 - opened
Files changed (4)
  1. README.md +1 -2
  2. config.json +0 -5
  3. configuration_solar_open.py +0 -242
  4. modeling_solar_open.py +0 -605
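The net effect of the four changes is that the checkpoint is expected to load through classes that ship inside transformers itself (>=5.0) rather than through the repository's `auto_map`/remote-code path. A minimal sketch of that loading path is below; the model id string is a placeholder and the in-library registration of the `solar_open` model type is an assumption implied by this PR, not something shown in the diff.

```python
# Sketch only: assumes transformers>=5.0 registers the `solar_open` model type natively.
from transformers import AutoConfig, AutoModelForCausalLM

MODEL_ID = "upstage/solar-open"  # placeholder model id, not taken from this PR

# Resolution now goes through config.json's "model_type" and the in-library class
# registry, so neither the removed "auto_map" block nor trust_remote_code=True is needed.
config = AutoConfig.from_pretrained(MODEL_ID)
print(type(config).__name__)  # expected to be the in-library SolarOpen config class

model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
```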
README.md CHANGED
@@ -118,7 +118,7 @@ top_k=50
  Install the required dependencies:
 
  ```bash
- pip install -U transformers kernels torch accelerate
+ pip install -U "transformers>=5.0" kernels torch accelerate
  ```
 
  Run inference with the following code:
@@ -136,7 +136,6 @@ model = AutoModelForCausalLM.from_pretrained(
  pretrained_model_name_or_path=MODEL_ID,
  torch_dtype=torch.bfloat16,
  device_map="auto",
- trust_remote_code=True,
  )
 
  # Prepare input
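For context, the two README edits combine into an inference snippet along the following lines. Only the pinned install, the `from_pretrained` arguments, and `top_k=50` come from the README shown here; `MODEL_ID`, the tokenizer/chat-template usage, and the remaining generation settings are illustrative assumptions.

```python
# Illustrative end-to-end snippet after this change; see the caveats above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "upstage/solar-open"  # placeholder model id

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # trust_remote_code=True is no longer passed; the architecture is assumed native in >=5.0
)

# Prepare input
messages = [{"role": "user", "content": "Hello!"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

output_ids = model.generate(input_ids, max_new_tokens=128, do_sample=True, top_k=50)
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))
```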
config.json CHANGED
@@ -3,11 +3,6 @@
  "architectures": [
    "SolarOpenForCausalLM"
  ],
- "auto_map": {
-   "AutoConfig": "configuration_solar_open.SolarOpenConfig",
-   "AutoModel": "modeling_solar_open.SolarOpenModel",
-   "AutoModelForCausalLM": "modeling_solar_open.SolarOpenForCausalLM"
- },
  "pad_token_id": 2,
  "bos_token_id": 1,
  "eos_token_id": 2,
configuration_solar_open.py DELETED
@@ -1,242 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2025 Upstage AI.
3
- # Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- #
17
- # This file has been modified by Upstage AI, including:
18
- # - Hyperparameter Adjustments: Modified the model architecture by increasing vocab_size and num_hidden_layers, while decreasing num_attention_heads, intermediate_size, and moe_intermediate_size.
19
- # - RoPE Configuration: Replaced the generic rope_parameters argument with explicit rope_theta and rope_scaling parameters to define the Rotary Positional Embedding settings.
20
- #
21
- # Based on code from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4_moe/configuration_glm4_moe.py
22
-
23
- from transformers.configuration_utils import PretrainedConfig
24
- from transformers.modeling_rope_utils import rope_config_validation
25
-
26
-
27
- class SolarOpenConfig(PretrainedConfig):
28
- r"""
29
- This is the configuration class to store the configuration of a [`SolarOpenModel`]. It is used to instantiate a
30
- SolarOpen model according to the specified arguments, defining the model architecture.
31
-
32
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
- documentation from [`PretrainedConfig`] for more information.
34
-
35
-
36
- Args:
37
- vocab_size (`int`, *optional*, defaults to 196608):
38
- Vocabulary size of the SolarOpen model. Defines the number of different tokens that can be represented by the
39
- `inputs_ids` passed when calling [`SolarOpenModel`]
40
- hidden_size (`int`, *optional*, defaults to 4096):
41
- Dimension of the hidden representations.
42
- intermediate_size (`int`, *optional*, defaults to 10240):
43
- Dimension of the MLP representations.
44
- num_hidden_layers (`int`, *optional*, defaults to 48):
45
- Number of hidden layers in the Transformer encoder.
46
- num_attention_heads (`int`, *optional*, defaults to 64):
47
- Number of attention heads for each attention layer in the Transformer encoder.
48
- partial_rotary_factor (`float`, *optional*, defaults to 1.0):
49
- The factor of the partial rotary position.
50
- num_key_value_heads (`int`, *optional*, defaults to 8):
51
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
52
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
53
- `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
54
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
55
- by meanpooling all the original heads within that group. For more details, check out [this
56
- paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
57
-
58
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
59
- The non-linear activation function (function or string) in the decoder.
60
- max_position_embeddings (`int`, *optional*, defaults to 131072):
61
- The maximum sequence length that this model might ever be used with.
62
- initializer_range (`float`, *optional*, defaults to 0.02):
63
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
64
- rms_norm_eps (`float`, *optional*, defaults to 1e-05):
65
- The epsilon used by the rms normalization layers.
66
- use_cache (`bool`, *optional*, defaults to `True`):
67
- Whether or not the model should return the last key/values attentions (not used by all models). Only
68
- relevant if `config.is_decoder=True`.
69
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
70
- Whether the model's input and output word embeddings should be tied.
71
- rope_theta (`float`, *optional*, defaults to 1000000.0):
72
- The base period of the RoPE embeddings.
73
- rope_scaling (`Dict`, *optional*):
74
- Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
75
- and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
76
- accordingly.
77
- Expected contents:
78
- `rope_type` (`str`):
79
- The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
80
- 'llama3'], with 'default' being the original RoPE implementation.
81
- `factor` (`float`, *optional*):
82
- Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
83
- most scaling types, a `factor` of x will enable the model to handle sequences of length x *
84
- original maximum pre-trained length.
85
- `original_max_position_embeddings` (`int`, *optional*):
86
- Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
87
- pretraining.
88
- `attention_factor` (`float`, *optional*):
89
- Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
90
- computation. If unspecified, it defaults to value recommended by the implementation, using the
91
- `factor` field to infer the suggested value.
92
- `beta_fast` (`float`, *optional*):
93
- Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
94
- ramp function. If unspecified, it defaults to 32.
95
- `beta_slow` (`float`, *optional*):
96
- Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
97
- ramp function. If unspecified, it defaults to 1.
98
- `short_factor` (`list[float]`, *optional*):
99
- Only used with 'longrope'. The scaling factor to be applied to short contexts (<
100
- `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
101
- size divided by the number of attention heads divided by 2
102
- `long_factor` (`list[float]`, *optional*):
103
- Only used with 'longrope'. The scaling factor to be applied to long contexts (>
104
- `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
105
- size divided by the number of attention heads divided by 2
106
- `low_freq_factor` (`float`, *optional*):
107
- Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
108
- `high_freq_factor` (`float`, *optional*):
109
- Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
110
- attention_bias (`bool`, *optional*, defaults to `False`):
111
- Whether to use a bias in the query, key, value and output projection layers during self-attention.
112
- attention_dropout (`float`, *optional*, defaults to 0.0):
113
- The dropout ratio for the attention probabilities.
114
- moe_intermediate_size (`int`, *optional*, defaults to 1280):
115
- Intermediate size of the routed expert.
116
- num_experts_per_tok (`int`, *optional*, defaults to 8):
117
- Number of experts per token.
118
- n_shared_experts (`int`, *optional*, defaults to 1):
119
- Number of shared experts.
120
- n_routed_experts (`int`, *optional*, defaults to 128):
121
- Number of routed experts.
122
- routed_scaling_factor (`float`, *optional*, defaults to 1.0):
123
- Scaling factor for routed experts.
124
- n_group (`int`, *optional*, defaults to 1):
125
- Number of groups for routed experts.
126
- topk_group (`int`, *optional*, defaults to 1):
127
- Number of selected groups for each token (ensuring the selected experts are only within `topk_group` groups).
128
- first_k_dense_replace (`int`, *optional*, defaults to 0):
129
- Number of dense layers in shallow layers (embed->dense->dense->...->dense->moe->moe...->lm_head).
130
- \--k dense layers--/
131
- norm_topk_prob (`bool`, *optional*, defaults to `True`):
132
- Whether to normalize the topk probabilities.
133
- use_qk_norm (`bool`, *optional*, defaults to `False`):
134
- Whether to use query-key normalization in the attention.
135
- ```python
136
- >>> from transformers import SolarOpenModel, SolarOpenConfig
137
-
138
- >>> # Initializing a SolarOpen style configuration
139
- >>> configuration = SolarOpenConfig()
140
-
141
- >>> # Initializing a model from the SolarOpen style configuration
142
- >>> model = SolarOpenModel(configuration)
143
-
144
- >>> # Accessing the model configuration
145
- >>> configuration = model.config
146
- ```"""
147
-
148
- model_type = "solar_open"
149
- keys_to_ignore_at_inference = ["past_key_values"]
150
-
151
- # Default tensor parallel plan for base model `SolarOpen`
152
- base_model_tp_plan = {
153
- "layers.*.self_attn.q_proj": "colwise",
154
- "layers.*.self_attn.k_proj": "colwise",
155
- "layers.*.self_attn.v_proj": "colwise",
156
- "layers.*.self_attn.o_proj": "rowwise",
157
- "layers.*.mlp.experts.*.gate_proj": "colwise",
158
- "layers.*.mlp.experts.*.up_proj": "colwise",
159
- "layers.*.mlp.experts.*.down_proj": "rowwise",
160
- "layers.*.mlp.gate_proj": "colwise",
161
- "layers.*.mlp.up_proj": "colwise",
162
- "layers.*.mlp.down_proj": "rowwise",
163
- }
164
- base_model_pp_plan = {
165
- "embed_tokens": (["input_ids"], ["inputs_embeds"]),
166
- "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
167
- "norm": (["hidden_states"], ["hidden_states"]),
168
- }
169
-
170
- def __init__(
171
- self,
172
- vocab_size=196608,
173
- hidden_size=4096,
174
- intermediate_size=10240,
175
- num_hidden_layers=48,
176
- num_attention_heads=64,
177
- partial_rotary_factor=1.0,
178
- num_key_value_heads=8,
179
- hidden_act="silu",
180
- max_position_embeddings=131072,
181
- initializer_range=0.02,
182
- rms_norm_eps=1e-5,
183
- use_cache=True,
184
- tie_word_embeddings=False,
185
- rope_theta=1000000.0,
186
- rope_scaling=None,
187
- attention_bias=False,
188
- attention_dropout=0.0,
189
- moe_intermediate_size=1280,
190
- num_experts_per_tok=8,
191
- n_shared_experts=1,
192
- n_routed_experts=128,
193
- routed_scaling_factor=1.0,
194
- n_group=1,
195
- topk_group=1,
196
- first_k_dense_replace=0,
197
- norm_topk_prob=True,
198
- use_qk_norm=False,
199
- **kwargs,
200
- ):
201
- self.vocab_size = vocab_size
202
- self.max_position_embeddings = max_position_embeddings
203
- self.hidden_size = hidden_size
204
- self.intermediate_size = intermediate_size
205
- self.num_hidden_layers = num_hidden_layers
206
- self.num_attention_heads = num_attention_heads
207
- self.partial_rotary_factor = partial_rotary_factor
208
-
209
- self.num_key_value_heads = num_key_value_heads
210
- self.hidden_act = hidden_act
211
- self.initializer_range = initializer_range
212
- self.rms_norm_eps = rms_norm_eps
213
- self.use_cache = use_cache
214
- self.rope_theta = rope_theta
215
- self.rope_scaling = rope_scaling
216
- self.attention_bias = attention_bias
217
- self.attention_dropout = attention_dropout
218
- # Validate the correctness of rotary position embeddings parameters
219
- # BC: if there is a 'type' field, move it to 'rope_type'.
220
- if self.rope_scaling is not None and "type" in self.rope_scaling:
221
- self.rope_scaling["rope_type"] = self.rope_scaling["type"]
222
- rope_config_validation(self)
223
-
224
- # MoE arguments
225
- self.moe_intermediate_size = moe_intermediate_size
226
- self.num_experts_per_tok = num_experts_per_tok
227
- self.n_group = n_group
228
- self.topk_group = topk_group
229
- self.n_shared_experts = n_shared_experts
230
- self.n_routed_experts = n_routed_experts
231
- self.routed_scaling_factor = routed_scaling_factor
232
- self.first_k_dense_replace = first_k_dense_replace
233
- self.norm_topk_prob = norm_topk_prob
234
- self.use_qk_norm = use_qk_norm
235
-
236
- super().__init__(
237
- tie_word_embeddings=tie_word_embeddings,
238
- **kwargs,
239
- )
240
-
241
-
242
- __all__ = ["SolarOpenConfig"]
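One behaviour of the deleted class worth noting: the constructor accepts the legacy `"type"` key in `rope_scaling`, copies it to `"rope_type"`, and only then runs `rope_config_validation`. A small sketch against the class exactly as defined above; the yarn values are made up for illustration.

```python
# Exercise the deleted SolarOpenConfig's rope_scaling back-compat shim.
# Assumes this file is still importable (i.e. before this PR is merged); values are illustrative.
from configuration_solar_open import SolarOpenConfig

config = SolarOpenConfig(
    rope_theta=1_000_000.0,
    rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 131072},
)

# The legacy "type" key has been mirrored into "rope_type" before validation.
assert config.rope_scaling["rope_type"] == "yarn"
# Defaults carried by the deleted file: 48 layers, 128 routed experts, 196608-token vocab.
print(config.num_hidden_layers, config.n_routed_experts, config.vocab_size)  # 48 128 196608
```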
 
modeling_solar_open.py DELETED
@@ -1,605 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2025 Upstage AI.
3
- # Copyright 2025 The GLM4 & ZhipuAI team and HuggingFace Inc. team.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- #
17
- # This file has been modified by Upstage AI including:
18
- # - Hybrid MoE Architecture: Replaced the standard dense structure with a depth-dependent Hybrid MoE, adding `SolarOpenMoE` and `SolarOpenTopkRouter` classes.
19
- # - RoPE Strategy: Changed the rotary position embedding strategy from GLM4's interleaved rotation to Llama-style block rotation (via modified `rotate_half`).
20
- # - Normalization Logic: Simplified the layer normalization structure by removing GLM4's extra post-operation norms and adding optional Query-Key Normalization (`use_qk_norm`).
21
- #
22
- # Based on code from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4/modeling_glm4.py
23
-
24
- from typing import Callable, Optional, Union
25
-
26
- import torch
27
- import torch.nn.functional as F
28
- from torch import nn
29
-
30
- from transformers.activations import ACT2FN
31
- from transformers.cache_utils import Cache, DynamicCache
32
- from transformers.generation import GenerationMixin
33
- from transformers.integrations import use_kernel_forward_from_hub
34
- from transformers.masking_utils import create_causal_mask
35
- from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
36
- from transformers.modeling_layers import GradientCheckpointingLayer
37
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
38
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
39
- from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
40
- from transformers.processing_utils import Unpack
41
- from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
42
- from transformers.utils.deprecation import deprecate_kwarg
43
- from transformers.utils.generic import check_model_inputs
44
- from .configuration_solar_open import SolarOpenConfig
45
-
46
-
47
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
48
- """
49
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
50
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
51
- """
52
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
53
- if n_rep == 1:
54
- return hidden_states
55
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
56
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
57
-
58
-
59
- def eager_attention_forward(
60
- module: nn.Module,
61
- query: torch.Tensor,
62
- key: torch.Tensor,
63
- value: torch.Tensor,
64
- attention_mask: Optional[torch.Tensor],
65
- scaling: float,
66
- dropout: float = 0.0,
67
- **kwargs: Unpack[TransformersKwargs],
68
- ):
69
- key_states = repeat_kv(key, module.num_key_value_groups)
70
- value_states = repeat_kv(value, module.num_key_value_groups)
71
-
72
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
73
- if attention_mask is not None:
74
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
75
- attn_weights = attn_weights + causal_mask
76
-
77
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
78
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
79
- attn_output = torch.matmul(attn_weights, value_states)
80
- attn_output = attn_output.transpose(1, 2).contiguous()
81
-
82
- return attn_output, attn_weights
83
-
84
-
85
- def rotate_half(x):
86
- """Rotates half the hidden dims of the input."""
87
- x1 = x[..., : x.shape[-1] // 2]
88
- x2 = x[..., x.shape[-1] // 2 :]
89
- return torch.cat((-x2, x1), dim=-1)
90
-
91
-
92
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
93
- """Applies Rotary Position Embedding to the query and key tensors.
94
-
95
- Args:
96
- q (`torch.Tensor`): The query tensor.
97
- k (`torch.Tensor`): The key tensor.
98
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
99
- sin (`torch.Tensor`): The sine part of the rotary embedding.
100
- position_ids (`torch.Tensor`, *optional*):
101
- Deprecated and unused.
102
- unsqueeze_dim (`int`, *optional*, defaults to 1):
103
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
104
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
105
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
106
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
107
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
108
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
109
- Returns:
110
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
111
- """
112
- cos = cos.unsqueeze(unsqueeze_dim)
113
- sin = sin.unsqueeze(unsqueeze_dim)
114
-
115
- # Keep half or full tensor for later concatenation
116
- rotary_dim = cos.shape[-1]
117
- q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
118
- k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
119
-
120
- # Apply rotary embeddings on the first half or full tensor
121
- q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
122
- k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
123
-
124
- # Concatenate back to full shape
125
- q_embed = torch.cat([q_embed, q_pass], dim=-1)
126
- k_embed = torch.cat([k_embed, k_pass], dim=-1)
127
- return q_embed, k_embed
128
-
129
-
130
- class SolarOpenAttention(nn.Module):
131
- """Multi-headed attention from 'Attention Is All You Need' paper"""
132
-
133
- def __init__(self, config: SolarOpenConfig, layer_idx: Optional[int] = None):
134
- super().__init__()
135
- self.config = config
136
- self.layer_idx = layer_idx
137
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
138
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
139
- self.scaling = self.head_dim**-0.5
140
- self.rope_scaling = config.rope_scaling
141
- self.attention_dropout = config.attention_dropout
142
- self.is_causal = True
143
-
144
- self.q_proj = nn.Linear(
145
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
146
- )
147
- self.k_proj = nn.Linear(
148
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
149
- )
150
- self.v_proj = nn.Linear(
151
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
152
- )
153
- self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
154
- self.use_qk_norm = config.use_qk_norm
155
- if self.use_qk_norm:
156
- self.q_norm = SolarOpenRMSNorm(self.head_dim, eps=config.rms_norm_eps)
157
- self.k_norm = SolarOpenRMSNorm(self.head_dim, eps=config.rms_norm_eps)
158
-
159
- @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
160
- def forward(
161
- self,
162
- hidden_states: torch.Tensor,
163
- position_embeddings: tuple[torch.Tensor, torch.Tensor],
164
- attention_mask: Optional[torch.Tensor],
165
- past_key_values: Optional[Cache] = None,
166
- cache_position: Optional[torch.LongTensor] = None,
167
- **kwargs: Unpack[FlashAttentionKwargs],
168
- ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
169
- input_shape = hidden_states.shape[:-1]
170
- hidden_shape = (*input_shape, -1, self.head_dim)
171
-
172
- query_states = self.q_proj(hidden_states).view(hidden_shape)
173
- key_states = self.k_proj(hidden_states).view(hidden_shape)
174
- value_states = self.v_proj(hidden_states).view(hidden_shape)
175
-
176
- if self.use_qk_norm: # main diff from Llama
177
- query_states = self.q_norm(query_states)
178
- key_states = self.k_norm(key_states)
179
-
180
- query_states = query_states.transpose(1, 2)
181
- key_states = key_states.transpose(1, 2)
182
- value_states = value_states.transpose(1, 2)
183
-
184
- cos, sin = position_embeddings
185
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
186
-
187
- if past_key_values is not None:
188
- # sin and cos are specific to RoPE models; position_ids needed for the static cache
189
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
190
- key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
191
-
192
- attention_interface: Callable = eager_attention_forward
193
- if self.config._attn_implementation != "eager":
194
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
195
-
196
- attn_output, attn_weights = attention_interface(
197
- self,
198
- query_states,
199
- key_states,
200
- value_states,
201
- attention_mask,
202
- dropout=0.0 if not self.training else self.attention_dropout,
203
- scaling=self.scaling,
204
- **kwargs,
205
- )
206
-
207
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
208
- attn_output = self.o_proj(attn_output)
209
- return attn_output, attn_weights
210
-
211
-
212
- class SolarOpenMLP(nn.Module):
213
- def __init__(self, config, hidden_size=None, intermediate_size=None):
214
- super().__init__()
215
- self.config = config
216
- self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
217
- self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
218
-
219
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
220
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
221
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
222
- self.act_fn = ACT2FN[config.hidden_act]
223
-
224
- def forward(self, x):
225
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
226
- return down_proj
227
-
228
-
229
- class SolarOpenTopkRouter(nn.Module):
230
- def __init__(self, config: SolarOpenConfig):
231
- super().__init__()
232
- self.config = config
233
- self.top_k = config.num_experts_per_tok
234
- self.n_routed_experts = config.n_routed_experts
235
- self.routed_scaling_factor = config.routed_scaling_factor
236
- self.n_group = config.n_group
237
- self.topk_group = config.topk_group
238
- self.norm_topk_prob = config.norm_topk_prob
239
-
240
- self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
241
- self.e_score_correction_bias = nn.Parameter(
242
- torch.zeros((self.n_routed_experts), dtype=torch.float32))
243
-
244
- @torch.no_grad()
245
- def get_topk_indices(self, scores):
246
- scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
247
- group_scores = (
248
- scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
249
- .topk(2, dim=-1)[0]
250
- .sum(dim=-1)
251
- )
252
- group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
253
- group_mask = torch.zeros_like(group_scores)
254
- group_mask.scatter_(1, group_idx, 1)
255
- score_mask = (
256
- group_mask.unsqueeze(-1)
257
- .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
258
- .reshape(-1, self.n_routed_experts)
259
- )
260
- scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
261
- topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
262
- return topk_indices
263
-
264
- def forward(self, hidden_states):
265
- hidden_states = hidden_states.view(-1, self.config.hidden_size)
266
- router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
267
- scores = router_logits.sigmoid()
268
- topk_indices = self.get_topk_indices(scores)
269
- topk_weights = scores.gather(1, topk_indices)
270
- if self.norm_topk_prob:
271
- denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
272
- topk_weights /= denominator
273
- topk_weights = topk_weights * self.routed_scaling_factor
274
- return topk_indices, topk_weights
275
-
276
-
277
- @use_kernel_forward_from_hub("RMSNorm")
278
- class SolarOpenRMSNorm(nn.Module):
279
- def __init__(self, hidden_size, eps=1e-6):
280
- """
281
- SolarOpenRMSNorm is equivalent to T5LayerNorm
282
- """
283
- super().__init__()
284
- self.weight = nn.Parameter(torch.ones(hidden_size))
285
- self.variance_epsilon = eps
286
-
287
- def forward(self, hidden_states):
288
- input_dtype = hidden_states.dtype
289
- hidden_states = hidden_states.to(torch.float32)
290
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
291
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
292
- return self.weight * hidden_states.to(input_dtype)
293
-
294
- def extra_repr(self):
295
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
296
-
297
-
298
- class SolarOpenMoE(nn.Module):
299
- """
300
- A mixed expert module containing shared experts.
301
- """
302
-
303
- def __init__(self, config):
304
- super().__init__()
305
- self.config = config
306
- self.experts = nn.ModuleList(
307
- [
308
- SolarOpenMLP(config, intermediate_size=config.moe_intermediate_size)
309
- for _ in range(config.n_routed_experts)
310
- ]
311
- )
312
- self.gate = SolarOpenTopkRouter(config)
313
- self.shared_experts = SolarOpenMLP(
314
- config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
315
- )
316
-
317
- @torch.compiler.disable()
318
- def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
319
- r"""
320
- MoE forward pass that only executes selected experts.
321
- Uses @torch.compiler.disable() to allow dynamic shape operations.
322
- Requires --enforce-eager flag when serving with vLLM.
323
- """
324
- final_hidden_states = torch.zeros_like(hidden_states)
325
-
326
- for expert_idx in range(len(self.experts)):
327
- expert = self.experts[expert_idx]
328
-
329
- # Find positions where this expert was selected
330
- batch_idx, topk_pos = torch.where(topk_indices == expert_idx)
331
-
332
- if batch_idx.numel() == 0:
333
- continue
334
-
335
- # Extract only the tokens routed to this expert
336
- expert_input = hidden_states[batch_idx]
337
- expert_output = expert(expert_input)
338
-
339
- # Apply weights and accumulate results
340
- weights = topk_weights[batch_idx, topk_pos].unsqueeze(-1)
341
- final_hidden_states.index_add_(0, batch_idx, (expert_output * weights).to(hidden_states.dtype))
342
-
343
- return final_hidden_states
344
-
345
- def forward(self, hidden_states):
346
- residuals = hidden_states
347
- orig_shape = hidden_states.shape
348
- topk_indices, topk_weights = self.gate(hidden_states)
349
- hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
350
- hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
351
- hidden_states = hidden_states + self.shared_experts(residuals)
352
- return hidden_states
353
-
354
-
355
- class SolarOpenDecoderLayer(GradientCheckpointingLayer):
356
- def __init__(self, config: SolarOpenConfig, layer_idx: int):
357
- super().__init__()
358
- self.hidden_size = config.hidden_size
359
-
360
- self.self_attn = SolarOpenAttention(config=config, layer_idx=layer_idx)
361
-
362
- if layer_idx >= config.first_k_dense_replace:
363
- self.mlp = SolarOpenMoE(config)
364
- else:
365
- self.mlp = SolarOpenMLP(config)
366
-
367
- self.input_layernorm = SolarOpenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
368
- self.post_attention_layernorm = SolarOpenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
369
-
370
- @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
371
- def forward(
372
- self,
373
- hidden_states: torch.Tensor,
374
- attention_mask: Optional[torch.Tensor] = None,
375
- position_ids: Optional[torch.LongTensor] = None,
376
- past_key_values: Optional[Cache] = None,
377
- use_cache: Optional[bool] = False,
378
- cache_position: Optional[torch.LongTensor] = None,
379
- position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
380
- **kwargs: Unpack[TransformersKwargs],
381
- ) -> torch.Tensor:
382
- residual = hidden_states
383
- hidden_states = self.input_layernorm(hidden_states)
384
- # Self Attention
385
- hidden_states, _ = self.self_attn(
386
- hidden_states=hidden_states,
387
- attention_mask=attention_mask,
388
- position_ids=position_ids,
389
- past_key_values=past_key_values,
390
- use_cache=use_cache,
391
- cache_position=cache_position,
392
- position_embeddings=position_embeddings,
393
- **kwargs,
394
- )
395
- hidden_states = residual + hidden_states
396
-
397
- # Fully Connected
398
- residual = hidden_states
399
- hidden_states = self.post_attention_layernorm(hidden_states)
400
- hidden_states = self.mlp(hidden_states)
401
- hidden_states = residual + hidden_states
402
- return hidden_states
403
-
404
-
405
- @auto_docstring
406
- class SolarOpenPreTrainedModel(PreTrainedModel):
407
- config: SolarOpenConfig
408
- base_model_prefix = "model"
409
- supports_gradient_checkpointing = True
410
- _no_split_modules = ["SolarOpenDecoderLayer"]
411
- _skip_keys_device_placement = ["past_key_values"]
412
- _supports_flash_attn = True
413
- _supports_sdpa = True
414
- _supports_flex_attn = True
415
- _can_compile_fullgraph = False
416
- _supports_attention_backend = True
417
- _can_record_outputs = {
418
- "hidden_states": SolarOpenDecoderLayer,
419
- "attentions": SolarOpenAttention,
420
- }
421
-
422
- def _init_weights(self, module):
423
- super()._init_weights(module)
424
- if isinstance(module, SolarOpenTopkRouter):
425
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
426
-
427
-
428
- class SolarOpenRotaryEmbedding(nn.Module):
429
- inv_freq: torch.Tensor # fix linting for `register_buffer`
430
-
431
- def __init__(self, config: SolarOpenConfig, device=None):
432
- super().__init__()
433
- # BC: "rope_type" was originally "type"
434
- if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
435
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
436
- else:
437
- self.rope_type = "default"
438
- self.max_seq_len_cached = config.max_position_embeddings
439
- self.original_max_seq_len = config.max_position_embeddings
440
-
441
- self.config = config
442
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
443
-
444
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
445
- self.register_buffer("inv_freq", inv_freq, persistent=False)
446
- self.original_inv_freq = self.inv_freq
447
-
448
- @torch.no_grad()
449
- @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
450
- def forward(self, x, position_ids):
451
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
452
- position_ids_expanded = position_ids[:, None, :].float()
453
-
454
- device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
455
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
456
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
457
- emb = torch.cat((freqs, freqs), dim=-1)
458
- cos = emb.cos() * self.attention_scaling
459
- sin = emb.sin() * self.attention_scaling
460
-
461
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
462
-
463
-
464
- @auto_docstring
465
- class SolarOpenModel(SolarOpenPreTrainedModel):
466
- _keys_to_ignore_on_load_unexpected = [r"model\.layers\.92.*", r"model\.layers\.46.*"]
467
-
468
- def __init__(self, config: SolarOpenConfig):
469
- super().__init__(config)
470
- self.padding_idx = config.pad_token_id
471
- self.vocab_size = config.vocab_size
472
-
473
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
474
- self.layers = nn.ModuleList(
475
- [SolarOpenDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
476
- )
477
- self.norm = SolarOpenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
478
- self.rotary_emb = SolarOpenRotaryEmbedding(config=config)
479
- self.gradient_checkpointing = False
480
-
481
- # Initialize weights and apply final processing
482
- self.post_init()
483
-
484
- @check_model_inputs()
485
- @auto_docstring
486
- def forward(
487
- self,
488
- input_ids: Optional[torch.LongTensor] = None,
489
- attention_mask: Optional[torch.Tensor] = None,
490
- position_ids: Optional[torch.LongTensor] = None,
491
- past_key_values: Optional[Cache] = None,
492
- inputs_embeds: Optional[torch.FloatTensor] = None,
493
- cache_position: Optional[torch.LongTensor] = None,
494
- use_cache: Optional[bool] = None,
495
- **kwargs: Unpack[TransformersKwargs],
496
- ) -> BaseModelOutputWithPast:
497
- if (input_ids is None) ^ (inputs_embeds is not None):
498
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
499
-
500
- if inputs_embeds is None:
501
- inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
502
-
503
- if use_cache and past_key_values is None:
504
- past_key_values = DynamicCache(config=self.config)
505
-
506
- if cache_position is None:
507
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
508
- cache_position: torch.Tensor = torch.arange(
509
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
510
- )
511
-
512
- if position_ids is None:
513
- position_ids = cache_position.unsqueeze(0)
514
-
515
- causal_mask = create_causal_mask(
516
- config=self.config,
517
- input_embeds=inputs_embeds,
518
- attention_mask=attention_mask,
519
- cache_position=cache_position,
520
- past_key_values=past_key_values,
521
- position_ids=position_ids,
522
- )
523
-
524
- hidden_states = inputs_embeds
525
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
526
-
527
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
528
- hidden_states = decoder_layer(
529
- hidden_states,
530
- attention_mask=causal_mask,
531
- position_ids=position_ids,
532
- past_key_values=past_key_values,
533
- cache_position=cache_position,
534
- position_embeddings=position_embeddings,
535
- **kwargs,
536
- )
537
-
538
- hidden_states = self.norm(hidden_states)
539
- return BaseModelOutputWithPast(
540
- last_hidden_state=hidden_states,
541
- past_key_values=past_key_values,
542
- )
543
-
544
-
545
- @auto_docstring
546
- class SolarOpenForCausalLM(SolarOpenPreTrainedModel, GenerationMixin):
547
- _tied_weights_keys = ["lm_head.weight"]
548
- _tp_plan = {"lm_head": "colwise_rep"}
549
- _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
550
-
551
- def __init__(self, config):
552
- super().__init__(config)
553
- self.model = SolarOpenModel(config)
554
- self.vocab_size = config.vocab_size
555
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
556
-
557
- # Initialize weights and apply final processing
558
- self.post_init()
559
-
560
- @can_return_tuple
561
- @auto_docstring
562
- def forward(
563
- self,
564
- input_ids: Optional[torch.LongTensor] = None,
565
- attention_mask: Optional[torch.Tensor] = None,
566
- position_ids: Optional[torch.LongTensor] = None,
567
- past_key_values: Optional[Cache] = None,
568
- inputs_embeds: Optional[torch.FloatTensor] = None,
569
- labels: Optional[torch.LongTensor] = None,
570
- use_cache: Optional[bool] = None,
571
- cache_position: Optional[torch.LongTensor] = None,
572
- logits_to_keep: Union[int, torch.Tensor] = 0,
573
- **kwargs: Unpack[TransformersKwargs],
574
- ) -> CausalLMOutputWithPast:
575
-
576
- outputs: BaseModelOutputWithPast = self.model(
577
- input_ids=input_ids,
578
- attention_mask=attention_mask,
579
- position_ids=position_ids,
580
- past_key_values=past_key_values,
581
- inputs_embeds=inputs_embeds,
582
- use_cache=use_cache,
583
- cache_position=cache_position,
584
- **kwargs,
585
- )
586
-
587
- hidden_states = outputs.last_hidden_state
588
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
589
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
590
- logits = self.lm_head(hidden_states[:, slice_indices, :])
591
-
592
- loss = None
593
- if labels is not None:
594
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
595
-
596
- return CausalLMOutputWithPast(
597
- loss=loss,
598
- logits=logits,
599
- past_key_values=outputs.past_key_values,
600
- hidden_states=outputs.hidden_states,
601
- attentions=outputs.attentions,
602
- )
603
-
604
-
605
- __all__ = ["SolarOpenPreTrainedModel", "SolarOpenModel", "SolarOpenForCausalLM"]
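The least self-explanatory part of the removed file is `SolarOpenTopkRouter.get_topk_indices`: expert scores are bias-corrected, experts are partitioned into `n_group` groups, each group is scored by the sum of its top-2 experts, only the best `topk_group` groups survive, and the per-token top-k is taken from the surviving experts. A standalone toy reproduction of that selection follows; shapes and hyperparameters are made up and much smaller than the real 128-expert configuration.

```python
import torch

# Toy reproduction of the group-limited top-k routing in SolarOpenTopkRouter.get_topk_indices.
# Hyperparameters are illustrative (the deleted config defaults to 128 experts, n_group=1, top_k=8).
n_tokens, n_routed_experts, n_group, topk_group, top_k = 3, 8, 4, 2, 2
experts_per_group = n_routed_experts // n_group

scores = torch.rand(n_tokens, n_routed_experts).sigmoid()   # stand-in for the sigmoid router scores
bias = torch.zeros(n_routed_experts)                        # stand-in for e_score_correction_bias

scores_for_choice = scores + bias
# Score each group by the sum of its two best experts, keep the best `topk_group` groups.
group_scores = scores_for_choice.view(n_tokens, n_group, experts_per_group).topk(2, dim=-1)[0].sum(dim=-1)
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
# Zero out experts outside the surviving groups, then take the per-token top-k experts.
score_mask = group_mask.unsqueeze(-1).expand(-1, n_group, experts_per_group).reshape(n_tokens, n_routed_experts)
masked_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
topk_indices = torch.topk(masked_scores, k=top_k, dim=-1, sorted=False)[1]
topk_weights = scores.gather(1, topk_indices)               # weights later normalized and scaled in forward()

print(topk_indices.shape, topk_weights.shape)               # torch.Size([3, 2]) torch.Size([3, 2])
```

With the shipped defaults (`n_group=1`, `topk_group=1`) the group mask keeps every expert, so the selection reduces to a plain biased top-k; the grouping only matters for configurations that actually partition the experts.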