# Source: Hugging Face Space by ginipick — app.py
# Last commit: "Update app.py" (52ba3fd, verified), 11.8 kB
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc
import os
import datetime
import time
import spaces # Import spaces module for GPU acceleration
# --- Configuration ---
MODEL_ID = "naver-hyperclovax/HyperCLOVAX-SEED-Text-Instruct-0.5B"
MAX_NEW_TOKENS = 512  # cap on the number of tokens generated per reply
USE_GPU = True  # use CUDA when available; set to False to force CPU

# Hugging Face access token, read from the environment. The warning below notes
# that private models may be inaccessible when it is missing.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print("경고: HF_TOKEN 환경 변수가 설정되지 않았습니다. 비공개 모델에 접근할 수 없을 수 있습니다.")

# --- Environment setup ---
print("--- Environment Setup ---")
_cuda_selected = torch.cuda.is_available() and USE_GPU
device = torch.device("cuda" if _cuda_selected else "cpu")
_token_status = "있음" if HF_TOKEN else "없음"
# Log the runtime environment once at startup.
for _report_line in (
    f"PyTorch version: {torch.__version__}",
    f"Running on device: {device}",
    f"Torch Threads: {torch.get_num_threads()}",
    f"HF_TOKEN 설정 여부: {_token_status}",
):
    print(_report_line)
# Custom stylesheet passed to gr.Blocks(css=...) for the demo page.
# - .gradio-container: centers the app and caps its width at 850px.
# - #intro-message / .footer: styled elements created in the UI section via
#   elem_id="intro-message" and elem_classes="footer".
# NOTE(review): .gr-chat, .user-message, .assistant-message,
# .gr-button.primary-button and .gr-form are assumed to match class names Gradio
# renders internally; those names differ between Gradio versions — confirm they
# still take effect with the installed version.
custom_css = """
.gradio-container {
max-width: 850px !important;
margin: auto;
}
.gr-chat {
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
.user-message {
background-color: #f0f7ff !important;
border-radius: 8px;
}
.assistant-message {
background-color: #f9f9f9 !important;
border-radius: 8px;
}
.gr-button.primary-button {
background-color: #1f4e79 !important;
}
.gr-form {
padding: 20px;
border-radius: 10px;
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.05);
}
#intro-message {
text-align: center;
margin-bottom: 20px;
padding: 15px;
background: linear-gradient(135deg, #e8f4ff 0%, #f0f7ff 100%);
border-radius: 10px;
border-left: 4px solid #1f4e79;
}
.footer {
text-align: center;
margin-top: 20px;
font-size: 0.8em;
color: #666;
}
"""
# --- Model and Tokenizer Loading ---
print(f"--- Loading Model: {MODEL_ID} ---")
print("This might take a few minutes, especially on the first launch...")
model = None
tokenizer = None
load_successful = False
stop_token_ids_list = []  # token ids that end generation; filled in after loading


def _resolve_stop_token_ids(tok, token_strings):
    """Return the token ids generation should stop on.

    Each string in *token_strings* is looked up in the tokenizer's vocabulary,
    then the tokenizer's EOS id (if it has one) is appended. Order is preserved
    and duplicates are removed.

    Lookups that return ``None`` are dropped. So are lookups that return the
    unknown-token id: ``convert_tokens_to_ids`` maps strings missing from the
    vocabulary to ``unk_token_id``, and stopping on every UNK would cut replies
    short.
    """
    unk_id = getattr(tok, "unk_token_id", None)
    unk_token = getattr(tok, "unk_token", None)
    ids = []
    for token in token_strings:
        token_id = tok.convert_tokens_to_ids(token)
        if token_id is None:
            continue
        if unk_id is not None and token_id == unk_id and token != unk_token:
            print(f"Warning: stop token {token!r} is not in the vocabulary; ignoring it.")
            continue
        if token_id not in ids:
            ids.append(token_id)
    if tok.eos_token_id is None:
        print("Warning: tokenizer.eos_token_id is None. Cannot add to stop tokens.")
    elif tok.eos_token_id not in ids:
        ids.append(tok.eos_token_id)
    return ids


try:
    start_load_time = time.time()
    # Tokenizer loading; the token is only passed when HF_TOKEN is set.
    tokenizer_kwargs = {"trust_remote_code": True}
    if HF_TOKEN:
        tokenizer_kwargs["token"] = HF_TOKEN
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **tokenizer_kwargs)

    # Model loading: fp16 with automatic placement on CUDA, fp32 on CPU otherwise.
    model_kwargs = {
        "trust_remote_code": True,
        "device_map": "auto" if device.type == "cuda" else "cpu",
        "torch_dtype": torch.float16 if device.type == "cuda" else torch.float32,
    }
    if HF_TOKEN:
        model_kwargs["token"] = HF_TOKEN
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **model_kwargs)
    model.eval()

    load_time = time.time() - start_load_time
    print(f"--- Model and Tokenizer Loaded Successfully in {load_time:.2f} seconds ---")
    load_successful = True

    # --- Stop Token Configuration ---
    stop_token_strings = ["<|endofturn|>", "<|stop|>"]
    stop_token_ids_list = _resolve_stop_token_ids(tokenizer, stop_token_strings)
    if not stop_token_ids_list:
        # Only reachable when none of the strings exist AND there is no EOS id.
        print("Warning: Could not find any stop token IDs. Using default EOS if available, otherwise generation might not stop correctly.")
        print("Error: No stop tokens found, including default EOS. Generation may run indefinitely.")
    print(f"Using Stop Token IDs: {stop_token_ids_list}")
except Exception as e:
    print(f"!!! Error loading model: {e}")
    # Drop references to any partially loaded objects before aborting startup.
    model = None
    tokenizer = None
    gc.collect()
    # Raised at import time (outside any Gradio event handler), so this aborts the
    # script with a traceback in the logs rather than showing a message in the UI.
    raise gr.Error(f"Failed to load the model {MODEL_ID}. Cannot start the application. Error: {e}") from e
# --- System Prompt Definition ---
def get_system_prompt():
    """Build the Korean system prompt, stamped with the current local date.

    Returns three newline-separated lines:
    - the assistant's name ("CLOVA X") and maker;
    - today's date as ``YYYY-MM-DD (Weekday)``;
    - an instruction to answer kindly and in detail in Korean.
    """
    today = datetime.datetime.now().strftime("%Y-%m-%d (%A)")
    prompt_lines = [
        "- AI 언어모델의 이름은 \"CLOVA X\" 이며 네이버에서 만들었다.",
        f"- 오늘은 {today}이다.",
        "- 사용자의 질문에 대해 친절하고 자세하게 한국어로 답변해야 한다.",
    ]
    return "\n".join(prompt_lines)
# --- Warm-up Function ---
def warmup_model():
    """Run one short greedy generation (10 new tokens) as a best-effort warm-up.

    Only logs a message and returns if the model or tokenizer did not load.
    Any exception raised during warm-up is logged and swallowed.
    """
    model_ready = load_successful and model is not None and tokenizer is not None
    if not model_ready:
        print("Skipping warmup: Model not loaded successfully.")
        return

    print("--- Starting Model Warm-up ---")
    try:
        started_at = time.time()
        # Same conversation layout predict() uses: empty tool_list, system, user.
        probe_chat = [
            {"role": "tool_list", "content": ""},
            {"role": "system", "content": get_system_prompt()},
            {"role": "user", "content": "안녕하세요"},
        ]
        encoded = tokenizer.apply_chat_template(
            probe_chat,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(device)

        # Pad with EOS when the tokenizer has one, otherwise with its pad token.
        pad_id = tokenizer.eos_token_id
        if pad_id is None:
            pad_id = tokenizer.pad_token_id
        settings = {"max_new_tokens": 10, "pad_token_id": pad_id, "do_sample": False}
        if not stop_token_ids_list:
            print("Warmup Warning: No stop tokens defined for generation.")
        else:
            settings["eos_token_id"] = stop_token_ids_list

        with torch.no_grad():
            generated = model.generate(**encoded, **settings)
        del encoded, generated
        gc.collect()
        elapsed = time.time() - started_at
        print(f"--- Model Warm-up Completed in {elapsed:.2f} seconds ---")
    except Exception as e:
        print(f"!!! Error during model warm-up: {e}")
    finally:
        gc.collect()
# --- Inference Function with GPU decorator ---
def _history_to_messages(history):
    """Convert Gradio chat history into a list of {"role", "content"} dicts.

    Accepts either gr.ChatInterface history format:
    - "tuples": [[user_msg, assistant_msg], ...]; assistant_msg may be None/empty
      and is then skipped.
    - "messages": [{"role": "user" | "assistant", "content": ...}, ...]; entries
      with another role or with empty/non-string content are skipped.
    Anything that is not a list yields an empty history.
    """
    messages = []
    if not isinstance(history, list):
        return messages
    for turn in history:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
        else:
            user_msg, assistant_msg = turn
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:  # pending turns have no assistant reply yet
                messages.append({"role": "assistant", "content": assistant_msg})
    return messages


@spaces.GPU()  # Important: Add the spaces.GPU() decorator for ZeroGPU
def predict(message, history):
    """Generate a HyperCLOVAX reply to *message* given the prior chat *history*.

    Args:
        message: The latest user message (presumably a plain string from the
            ChatInterface textbox).
        history: Prior turns, in Gradio "tuples" or "messages" format (see
            _history_to_messages).

    Returns:
        The decoded reply with special tokens stripped. Returns a Korean error
        string instead if the model is not loaded, or if chat templating,
        generation or decoding fails.
    """
    if model is None or tokenizer is None:
        return "오류: 모델이 로드되지 않았습니다."

    # An empty tool_list turn precedes the system prompt, as required by the model card.
    chat_history_formatted = [
        {"role": "tool_list", "content": ""},
        {"role": "system", "content": get_system_prompt()},
    ]
    chat_history_formatted.extend(_history_to_messages(history))
    chat_history_formatted.append({"role": "user", "content": message})

    inputs = None
    output_ids = None
    try:
        try:
            inputs = tokenizer.apply_chat_template(
                chat_history_formatted,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            ).to(device)
            input_length = inputs["input_ids"].shape[1]
            print(f"\nInput tokens: {input_length}")
        except Exception as e:
            print(f"!!! Error applying chat template: {e}")
            return f"오류: 입력 형식을 처리하는 중 문제가 발생했습니다. ({e})"

        try:
            print("Generating response...")
            generation_start_time = time.time()
            gen_kwargs = {
                "max_new_tokens": MAX_NEW_TOKENS,
                "pad_token_id": tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
                "do_sample": True,
                "temperature": 0.7,
                "top_p": 0.9,
            }
            if stop_token_ids_list:
                gen_kwargs["eos_token_id"] = stop_token_ids_list
            else:
                print("Generation Warning: No stop tokens defined.")
            with torch.no_grad():
                output_ids = model.generate(**inputs, **gen_kwargs)
            generation_time = time.time() - generation_start_time
            print(f"Generation complete in {generation_time:.2f} seconds.")
        except Exception as e:
            print(f"!!! Error during model generation: {e}")
            return f"오류: 응답을 생성하는 중 문제가 발생했습니다. ({e})"

        try:
            # generate() returns prompt + continuation; keep only the new tokens.
            new_tokens = output_ids[0, input_length:]
            response = tokenizer.decode(new_tokens, skip_special_tokens=True)
            print(f"Output tokens: {len(new_tokens)}")
            del new_tokens
        except Exception as e:
            print(f"!!! Error decoding response: {e}")
            response = "오류: 응답을 디코딩하는 중 문제가 발생했습니다."
        return response
    finally:
        # Release tensor references on every exit path, including error returns.
        del inputs, output_ids
        gc.collect()
        print("Memory cleaned.")
# --- Gradio Interface Setup ---
print("--- Setting up Gradio Interface ---")

# Example prompts listed under the chat box (one single-input example each).
_EXAMPLE_PROMPTS = [
    "네이버 클로바X는 무엇인가요?",
    "슈뢰딩거 방정식과 양자역학의 관계를 설명해주세요.",
    "딥러닝 모델 학습 과정을 단계별로 알려줘.",
    "제주도 여행 계획을 세우고 있는데, 3박 4일 추천 코스 좀 짜줄래?",
    "한국 역사에서 가장 중요한 사건 5가지는 무엇인가요?",
    "인공지능 윤리에 대해 설명해주세요.",
]
_HARDWARE_LABEL = "GPU" if device.type == "cuda" else "CPU"

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("""
# NAVER hyperclovax: HyperCLOVAX-SEED-Text-Instruct-0.5B
""", elem_id="intro-message")
    # Chat UI wired to predict(); examples are not pre-computed at startup.
    chatbot = gr.ChatInterface(
        fn=predict,
        examples=[[prompt] for prompt in _EXAMPLE_PROMPTS],
        cache_examples=False,
    )
    with gr.Accordion("모델 정보", open=False):
        gr.Markdown(f"""
- **모델**: {MODEL_ID}
- **환경**: ZeroGPU 공유 환경에서 실행 중
- **토큰 제한**: 최대 생성 토큰 수는 {MAX_NEW_TOKENS}개로 제한됩니다.
- **하드웨어**: {_HARDWARE_LABEL} 환경에서 실행 중
""")
    gr.Markdown(
        "© 2025 네이버 HyperCLOVA X 데모 | Powered by Hugging Face & ZeroGPU",
        elem_classes="footer",
    )
# --- Application Launch ---
if __name__ == "__main__":
    if not load_successful:
        print("Skipping warm-up because model loading failed.")
    else:
        warmup_model()
    print("--- Launching Gradio App ---")
    # Bind to all interfaces so the server is reachable from outside the container.
    # Adding share=True to launch() would also create a temporary public link.
    demo.queue().launch(server_name="0.0.0.0")