|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import gc |
|
|
import os |
|
|
import datetime |
|
|
import time |
|
|
import spaces |
|
|
|
|
|
|
|
|
# --- Configuration -----------------------------------------------------------
# Hugging Face repo id of the checkpoint served by this demo.
MODEL_ID = "naver-hyperclovax/HyperCLOVAX-SEED-Text-Instruct-0.5B"
# Upper bound on tokens generated for a single chat reply.
MAX_NEW_TOKENS = 512
# Set to False to force CPU inference even when CUDA is available.
USE_GPU = True

# Optional access token; only passed to from_pretrained() when it is non-empty.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print("경고: HF_TOKEN 환경 변수가 설정되지 않았습니다. 비공개 모델에 접근할 수 없을 수 있습니다.")

# --- Environment report --------------------------------------------------------
print("--- Environment Setup ---")
device = torch.device(
    "cuda" if (torch.cuda.is_available() and USE_GPU) else "cpu"
)
print(f"PyTorch version: {torch.__version__}")
print(f"Running on device: {device}")
print(f"Torch Threads: {torch.get_num_threads()}")
print("HF_TOKEN 설정 여부: " + ("있음" if HF_TOKEN else "없음"))
|
|
|
|
|
|
|
|
# Stylesheet passed to gr.Blocks(css=...) when the UI is built.
# Selector notes (only the ids/classes this file itself assigns are certain):
#   .gradio-container - caps the page width at 850px and centres it.
#   #intro-message    - header Markdown block (elem_id="intro-message").
#   .footer           - footer Markdown block (elem_classes="footer").
# The remaining selectors (.gr-chat, .user-message, .assistant-message,
# .gr-button.primary-button, .gr-form) presumably target Gradio's built-in
# class names - NOTE(review): confirm they still match the Gradio version in use.
custom_css = """


.gradio-container {


max-width: 850px !important;


margin: auto;


}


.gr-chat {


border-radius: 10px;


box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);


}


.user-message {


background-color: #f0f7ff !important;


border-radius: 8px;


}


.assistant-message {


background-color: #f9f9f9 !important;


border-radius: 8px;


}


.gr-button.primary-button {


background-color: #1f4e79 !important;


}


.gr-form {


padding: 20px;


border-radius: 10px;


box-shadow: 0 2px 6px rgba(0, 0, 0, 0.05);


}


#intro-message {


text-align: center;


margin-bottom: 20px;


padding: 15px;


background: linear-gradient(135deg, #e8f4ff 0%, #f0f7ff 100%);


border-radius: 10px;


border-left: 4px solid #1f4e79;


}


.footer {


text-align: center;


margin-top: 20px;


font-size: 0.8em;


color: #666;


}


"""
|
|
|
|
|
|
|
|
print(f"--- Loading Model: {MODEL_ID} ---")
print("This might take a few minutes, especially on the first launch...")

# Module-level handles read by warmup_model() and predict(). They keep these
# "not loaded" values unless every loading step below succeeds.
model = None
tokenizer = None
load_successful = False
stop_token_ids_list = []


def _resolve_stop_token_ids(tok, stop_token_strings=("<|endofturn|>", "<|stop|>")):
    """Return the ordered, de-duplicated token ids that should end generation.

    Args:
        tok: A loaded Hugging Face tokenizer.
        stop_token_strings: Special-token strings to look up in ``tok``'s
            vocabulary. The defaults are the HyperCLOVAX turn/stop markers.

    Returns:
        A list of ints: every stop string found in the vocabulary, followed by
        ``tok.eos_token_id`` when it is set and not already present. The list
        is empty only when no stop string exists and EOS is undefined.
    """
    unk_id = getattr(tok, "unk_token_id", None)
    unk_token = getattr(tok, "unk_token", None)
    ids = []
    for token in stop_token_strings:
        tid = tok.convert_tokens_to_ids(token)
        # Out-of-vocabulary strings come back as unk_token_id (or None when the
        # tokenizer has no unk token). Using that id as a stop token would cut
        # replies short at any unknown token, so skip it - unless the requested
        # string genuinely *is* the unk token.
        if tid is None or (tid == unk_id and token != unk_token):
            print(f"Warning: stop token {token!r} not found in the tokenizer vocabulary; skipping it.")
            continue
        if tid not in ids:
            ids.append(tid)

    eos_id = tok.eos_token_id
    if eos_id is None:
        print("Warning: tokenizer.eos_token_id is None. Cannot add to stop tokens.")
    elif eos_id not in ids:
        ids.append(eos_id)

    if not ids:
        # EOS is always appended when defined, so an empty list means EOS is None too.
        print("Warning: Could not find any stop token IDs. Using default EOS if available, otherwise generation might not stop correctly.")
        print("Error: No stop tokens found, including default EOS. Generation may run indefinitely.")
    return ids


try:
    start_load_time = time.time()

    # Shared from_pretrained() options; the token is only sent when provided.
    common_kwargs = {"trust_remote_code": True}
    if HF_TOKEN:
        common_kwargs["token"] = HF_TOKEN

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **common_kwargs)

    on_gpu = device.type == "cuda"
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto" if on_gpu else "cpu",
        # Half precision only on GPU; CPU kernels expect float32.
        torch_dtype=torch.float16 if on_gpu else torch.float32,
        **common_kwargs,
    )
    model.eval()

    load_time = time.time() - start_load_time
    print(f"--- Model and Tokenizer Loaded Successfully in {load_time:.2f} seconds ---")

    stop_token_ids_list = _resolve_stop_token_ids(tokenizer)
    print(f"Using Stop Token IDs: {stop_token_ids_list}")

    # Only flag success once the stop tokens are resolved as well.
    load_successful = True

except Exception as e:
    print(f"!!! Error loading model: {e}")
    # Drop references (rather than `del`, which would unbind the globals) so the
    # partially loaded weights can be reclaimed.
    model = None
    tokenizer = None
    load_successful = False
    gc.collect()
    raise gr.Error(f"Failed to load the model {MODEL_ID}. Cannot start the application. Error: {e}") from e
|
|
|
|
|
|
|
|
def get_system_prompt():
    """Build the Korean system prompt sent ahead of every conversation.

    The prompt has three newline-separated rules:
    - the assistant is "CLOVA X", made by NAVER;
    - today's local date, formatted as ``YYYY-MM-DD (Weekday)``;
    - answer kindly, in detail, and in Korean.

    It is rebuilt on each call, so the embedded date follows the wall clock.
    """
    today = datetime.datetime.now().strftime("%Y-%m-%d (%A)")
    rules = [
        '- AI 언어모델의 이름은 "CLOVA X" 이며 네이버에서 만들었다.',
        f"- 오늘은 {today}이다.",
        "- 사용자의 질문에 대해 친절하고 자세하게 한국어로 답변해야 한다.",
    ]
    return "\n".join(rules)
|
|
|
|
|
|
|
|
def warmup_model():
    """Run one short greedy generation so the first real request is not cold.

    Returns early (after logging) when loading did not succeed or the model /
    tokenizer globals are unset. Any exception raised while warming up is
    printed and swallowed: warm-up is best effort and must not stop the app
    from launching.
    """
    ready = load_successful and model is not None and tokenizer is not None
    if not ready:
        print("Skipping warmup: Model not loaded successfully.")
        return

    print("--- Starting Model Warm-up ---")
    try:
        t0 = time.time()
        # Same turn layout predict() uses: tool_list, system, then the user turn.
        chat = [
            {"role": "tool_list", "content": ""},
            {"role": "system", "content": get_system_prompt()},
            {"role": "user", "content": "안녕하세요"},
        ]
        inputs = tokenizer.apply_chat_template(
            chat,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(device)

        eos_id = tokenizer.eos_token_id
        gen_kwargs = dict(
            max_new_tokens=10,
            pad_token_id=eos_id if eos_id is not None else tokenizer.pad_token_id,
            do_sample=False,
        )
        if not stop_token_ids_list:
            print("Warmup Warning: No stop tokens defined for generation.")
        else:
            gen_kwargs["eos_token_id"] = stop_token_ids_list

        with torch.no_grad():
            output_ids = model.generate(**inputs, **gen_kwargs)

        # Release the tensors before timing so the log reflects the full cost.
        del inputs, output_ids
        gc.collect()
        elapsed = time.time() - t0
        print(f"--- Model Warm-up Completed in {elapsed:.2f} seconds ---")

    except Exception as e:
        print(f"!!! Error during model warm-up: {e}")
    finally:
        gc.collect()
|
|
|
|
|
|
|
|
def _history_to_messages(history):
    """Convert Gradio chat history into HyperCLOVAX role/content turns.

    Accepts both history formats ``gr.ChatInterface`` can pass:

    * "tuples": ``[[user_msg, assistant_msg], ...]`` (the assistant entry may
      be None or empty);
    * "messages": ``[{"role": "user" | "assistant", "content": str, ...}, ...]``.

    Non-text content (e.g. file attachments), roles other than user/assistant
    (our own system prompt is supplied separately), and empty assistant
    replies are skipped. A non-list ``history`` yields no turns.
    """
    if not isinstance(history, list):
        return []

    pairs = []  # (role, content) in conversation order
    for entry in history:
        if isinstance(entry, dict):
            pairs.append((entry.get("role"), entry.get("content")))
        elif isinstance(entry, (list, tuple)) and len(entry) == 2:
            pairs.append(("user", entry[0]))
            pairs.append(("assistant", entry[1]))

    turns = []
    for role, content in pairs:
        if not isinstance(content, str):
            continue
        if role == "user" or (role == "assistant" and content):
            turns.append({"role": role, "content": content})
    return turns


@spaces.GPU()
def predict(message, history):
    """Generate one HyperCLOVAX reply for a ``gr.ChatInterface`` turn.

    Args:
        message: The user's newest message text.
        history: Previous turns, in either the Gradio "tuples" format (list of
            ``[user, assistant]`` pairs) or the "messages" format (list of
            ``{"role", "content"}`` dicts).

    Returns:
        The decoded reply text (prompt tokens excluded, special tokens
        stripped), or a Korean error message if the model is not loaded or
        chat templating, generation, or decoding fails.
    """
    if model is None or tokenizer is None:
        return "오류: 모델이 로드되지 않았습니다."

    # HyperCLOVAX chat layout: an empty tool_list turn and the system prompt
    # come first, then the prior conversation, then the new user message.
    chat_history_formatted = [
        {"role": "tool_list", "content": ""},
        {"role": "system", "content": get_system_prompt()},
    ]
    chat_history_formatted.extend(_history_to_messages(history))
    chat_history_formatted.append({"role": "user", "content": message})

    try:
        inputs = tokenizer.apply_chat_template(
            chat_history_formatted,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(device)
        input_length = inputs["input_ids"].shape[1]
        print(f"\nInput tokens: {input_length}")
    except Exception as e:
        print(f"!!! Error applying chat template: {e}")
        return f"오류: 입력 형식을 처리하는 중 문제가 발생했습니다. ({e})"

    output_ids = None
    try:
        print("Generating response...")
        generation_start_time = time.time()

        gen_kwargs = {
            "max_new_tokens": MAX_NEW_TOKENS,
            # Fall back to the pad id when EOS is undefined.
            "pad_token_id": tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
        }
        if stop_token_ids_list:
            gen_kwargs["eos_token_id"] = stop_token_ids_list
        else:
            print("Generation Warning: No stop tokens defined.")

        with torch.no_grad():
            output_ids = model.generate(**inputs, **gen_kwargs)

        generation_time = time.time() - generation_start_time
        print(f"Generation complete in {generation_time:.2f} seconds.")
    except Exception as e:
        print(f"!!! Error during model generation: {e}")
        del inputs, output_ids
        gc.collect()
        return f"오류: 응답을 생성하는 중 문제가 발생했습니다. ({e})"

    response = "오류: 응답 생성에 실패했습니다."
    if output_ids is not None:
        try:
            # Decode only the tokens produced after the prompt.
            new_tokens = output_ids[0, input_length:]
            response = tokenizer.decode(new_tokens, skip_special_tokens=True)
            print(f"Output tokens: {len(new_tokens)}")
            del new_tokens
        except Exception as e:
            print(f"!!! Error decoding response: {e}")
            response = "오류: 응답을 디코딩하는 중 문제가 발생했습니다."

    # Drop tensor references before collecting so memory is freed between turns.
    del inputs, output_ids
    gc.collect()
    print("Memory cleaned.")

    return response
|
|
|
|
|
|
|
|
print("--- Setting up Gradio Interface ---")

with gr.Blocks(css=custom_css) as demo:
    # Page header; styled by the #intro-message rule in custom_css.
    gr.Markdown(
        value="""


# NAVER hyperclovax: HyperCLOVAX-SEED-Text-Instruct-0.5B





""",
        elem_id="intro-message",
    )

    # Chat widget backed by predict(). Example prompts are shown as clickable
    # suggestions and are not pre-computed (cache_examples=False).
    chatbot = gr.ChatInterface(
        fn=predict,
        cache_examples=False,
        examples=[
            [prompt]
            for prompt in (
                "네이버 클로바X는 무엇인가요?",
                "슈뢰딩거 방정식과 양자역학의 관계를 설명해주세요.",
                "딥러닝 모델 학습 과정을 단계별로 알려줘.",
                "제주도 여행 계획을 세우고 있는데, 3박 4일 추천 코스 좀 짜줄래?",
                "한국 역사에서 가장 중요한 사건 5가지는 무엇인가요?",
                "인공지능 윤리에 대해 설명해주세요.",
            )
        ],
    )

    # Collapsed panel with model id, token limit and the detected hardware.
    with gr.Accordion("모델 정보", open=False):
        gr.Markdown(value=f"""


- **모델**: {MODEL_ID}


- **환경**: ZeroGPU 공유 환경에서 실행 중


- **토큰 제한**: 최대 생성 토큰 수는 {MAX_NEW_TOKENS}개로 제한됩니다.


- **하드웨어**: {"GPU" if device.type == "cuda" else "CPU"} 환경에서 실행 중


""")

    # Footer line; styled by the .footer rule in custom_css.
    gr.Markdown(
        value="© 2025 네이버 HyperCLOVA X 데모 | Powered by Hugging Face & ZeroGPU",
        elem_classes="footer",
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Warm the model up only if loading succeeded; otherwise just say why not.
    if not load_successful:
        print("Skipping warm-up because model loading failed.")
    else:
        warmup_model()

    print("--- Launching Gradio App ---")
    # Enable the request queue, then bind on all interfaces (needed inside
    # a container / Space).
    queued_demo = demo.queue()
    queued_demo.launch(server_name="0.0.0.0")