Upload inference.py
Browse files- inference.py +1 -10
inference.py
CHANGED
|
@@ -5,9 +5,7 @@ import re
|
|
| 5 |
import difflib
|
| 6 |
from utils import *
|
| 7 |
from config import *
|
| 8 |
-
from transformers import GPT2Config
|
| 9 |
-
from bitsandbytes.nn import Linear8bitLt
|
| 10 |
-
from bitsandbytes.optim import GlobalOptimManager
|
| 11 |
from abctoolkit.utils import Exclaim_re, Quote_re, SquareBracket_re, Barline_regexPattern
|
| 12 |
from abctoolkit.transpose import Note_list, Pitch_sign_list
|
| 13 |
from abctoolkit.duration import calculate_bartext_duration
|
|
@@ -42,13 +40,6 @@ byte_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
|
|
| 42 |
num_attention_heads=HIDDEN_SIZE // 64,
|
| 43 |
vocab_size=128)
|
| 44 |
|
| 45 |
-
quantization_config = BitsAndBytesConfig(
|
| 46 |
-
load_in_8bit=True,
|
| 47 |
-
llm_int8_skip_modules=["patch_embedding"],
|
| 48 |
-
bnb_4bit_use_double_quant=True # 双重量化进一步压缩
|
| 49 |
-
)
|
| 50 |
-
|
| 51 |
-
|
| 52 |
model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=byte_config).to(device)
|
| 53 |
|
| 54 |
def download_model_weights():
|
|
|
|
| 5 |
import difflib
|
| 6 |
from utils import *
|
| 7 |
from config import *
|
| 8 |
+
from transformers import GPT2Config
|
|
|
|
|
|
|
| 9 |
from abctoolkit.utils import Exclaim_re, Quote_re, SquareBracket_re, Barline_regexPattern
|
| 10 |
from abctoolkit.transpose import Note_list, Pitch_sign_list
|
| 11 |
from abctoolkit.duration import calculate_bartext_duration
|
|
|
|
| 40 |
num_attention_heads=HIDDEN_SIZE // 64,
|
| 41 |
vocab_size=128)
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=byte_config).to(device)
|
| 44 |
|
| 45 |
def download_model_weights():
|