|
|
import flask |
|
|
from flask import request, jsonify |
|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
import torch |
|
|
|
|
|
|
|
|
app = flask.Flask(__name__)

# Hugging Face Hub identifier of the chat model this server wraps.
model_id = "Qwen/Qwen1.5-0.5B-Chat"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The tokenizer and weights are loaded once, at import time, and shared by
# every request handled below.
print(f"🔄 Loading {model_id} model...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Weights are requested in bfloat16 regardless of the selected device;
# `.to()` moves the module in place and returns it, so chaining is equivalent
# to a separate `model.to(device)` statement.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
).to(device)
print(f"✅ {model_id} Model loaded successfully!")
|
|
|
|
|
@app.route('/chat', methods=['POST'])
def chat():
    """Generate a single-turn chat reply for the posted user message.

    The request body must be a JSON object carrying a non-empty string under
    ``"message"``. Only that one message is sent to the model; no earlier
    conversation turns are kept between requests.

    Returns:
        JSON ``{"reply": <assistant text>}`` (status 200) on success.
        JSON ``{"error": ...}`` with status 400 when the body is not a JSON
        object, or ``message`` is missing, empty, or not a string.
        JSON ``{"error": str(exc)}`` with status 500 when chat templating,
        tokenization, generation, or decoding raises.
    """
    # silent=True makes get_json return None for a missing or malformed JSON
    # body instead of raising, so bad client input yields a 400, not a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    msg = data.get("message", "")
    if not msg:
        return jsonify({"error": "No message sent"}), 400
    if not isinstance(msg, str):
        return jsonify({"error": "'message' must be a string"}), 400

    try:
        chat_history = [{"role": "user", "content": msg}]

        # Render the conversation with the model's own chat template and append
        # the assistant turn header so the model continues as the assistant.
        formatted_prompt = tokenizer.apply_chat_template(
            chat_history,
            tokenize=False,
            add_generation_prompt=True,
        )

        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
        prompt_len = inputs["input_ids"].shape[1]

        output = model.generate(
            **inputs,
            # max_new_tokens bounds only the reply. The previous max_length=256
            # also counted prompt tokens, so long messages left little or no
            # room for an answer.
            max_new_tokens=256,
            do_sample=True,
            top_p=0.8,
            temperature=0.6,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Decode only the tokens generated after the prompt. Skipping special
        # tokens drops chat-template markers (e.g. <|im_end|>) without relying
        # on string-splitting the whole transcript.
        # NOTE(review): assumes generation stops at the model's end-of-turn
        # token via its generation config — confirm if replies run on.
        new_tokens = output[0][prompt_len:]
        reply = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        return jsonify({"reply": reply})

    except Exception as e:
        # Top-level request boundary: record the traceback server-side, then
        # report the failure to the client with the same response shape as before.
        app.logger.exception("Chat generation failed")
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Start Flask's built-in (development) server, listening on every IPv4
    # interface at port 7860.
    app.run(port=7860, host='0.0.0.0')