import flask
from flask import request, jsonify
# Use AutoModelForCausalLM for decoder-only models like Qwen
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Initialize the Flask application
app = flask.Flask(__name__)

# Qwen1.5-0.5B-Chat model ID
model_id = "Qwen/Qwen1.5-0.5B-Chat"

print(f"🔄 Loading {model_id} model...")

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the model using the correct CausalLM class.
# bfloat16 reduces memory use and speeds up inference on compatible GPUs.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Set the device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print(f"✅ {model_id} model loaded successfully!")


@app.route('/chat', methods=['POST'])
def chat():
    try:
        data = request.get_json()
        msg = data.get("message", "")
        if not msg:
            return jsonify({"error": "No message sent"}), 400

        # --- Qwen1.5 chat template formatting ---
        # Qwen models expect input in the ChatML format.
        chat_history = [{"role": "user", "content": msg}]

        # apply_chat_template handles the specific formatting (e.g., <|im_start|>user\n...)
        formatted_prompt = tokenizer.apply_chat_template(
            chat_history,
            tokenize=False,
            add_generation_prompt=True
        )

        # Tokenize the formatted prompt
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)

        # Generation configuration
        output = model.generate(
            **inputs,
            max_new_tokens=256,  # cap only the generated tokens, so a long prompt can't crowd out the reply
            do_sample=True,
            top_p=0.8,
            temperature=0.6,
            # Set pad_token_id to eos_token_id, which is often necessary for causal LMs
            pad_token_id=tokenizer.eos_token_id
        )

        # Decode the full output (prompt + generation), keeping special tokens
        full_reply = tokenizer.decode(output[0], skip_special_tokens=False)

        # --- Extract only the generated response ---
        # The Qwen ChatML format places '<|im_start|>assistant\n' before the response
        assistant_tag = "<|im_start|>assistant\n"
        if assistant_tag in full_reply:
            # Split the full reply and take the content after the assistant tag
            reply = full_reply.split(assistant_tag)[-1].strip()
            # Remove the end-of-message tag if it was generated
            if "<|im_end|>" in reply:
                reply = reply.split("<|im_end|>")[0].strip()
        else:
            # Fallback: decode only the newly generated tokens
            reply = tokenizer.decode(
                output[0][inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            ).strip()

        return jsonify({"reply": reply})

    except Exception as e:
        # Catch any runtime errors
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    # Run the Flask app
    app.run(host='0.0.0.0', port=7860)
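
# --- Example client (sketch) ---
# A minimal way to exercise the /chat endpoint from a separate script or shell once
# the server is running. The localhost URL assumes the server runs on the same
# machine; the `requests` library is an assumption and is not required by the
# server itself.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/chat",
#       json={"message": "Hello, who are you?"},
#   )
#   body = resp.json()
#   print(body.get("reply") or body.get("error"))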