import flask
from flask import request, jsonify
from transformers import pipeline, AutoTokenizer # Added AutoTokenizer
import torch
import warnings
# Suppress minor warnings that occur on CPU runs
warnings.filterwarnings("ignore")
app = flask.Flask(__name__)
# ===========================
# LOAD MODEL (SmolLM-1.7B)
# This model is small (1.7B parameters) and fully open-access.
# ===========================
model_id = "HuggingFaceTB/SmolLM-1.7B"
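# Note: "HuggingFaceTB/SmolLM-1.7B" is the base (non-chat) checkpoint.
# For chat-tuned behavior, the instruct variant
# "HuggingFaceTB/SmolLM-1.7B-Instruct" (which ships with a chat template)
# may be a better fit; prompting the base model with a plain
# "User:/Assistant:" format, as done below, is a heuristic.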
print("🔄 Loading model...")
# Select device: GPU (index 0) if available, otherwise CPU (-1)
device = 0 if torch.cuda.is_available() else -1
# Use float32 for CPU (or bfloat16 for GPU)
dtype = torch.float32 if device == -1 else torch.bfloat16
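# Note (assumption, not in the original code): bfloat16 reduces memory use
# on GPUs that support it; on older GPUs without bf16 support,
# torch.float16 is a common fallback.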
try:
    # 1. Load tokenizer and set pad_token for stability
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Set pad_token to eos_token to avoid a generation warning/error
        tokenizer.pad_token = tokenizer.eos_token

    # 2. Load pipeline with the configured tokenizer
    ai = pipeline(
        "text-generation",
        model=model_id,
        tokenizer=tokenizer,  # pass the configured tokenizer here
        max_new_tokens=200,
        device=device,
        torch_dtype=dtype,
        trust_remote_code=True,
    )
    print("✅ Model loaded!")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    ai = None
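# Note (optional tweak, not in the original code): the text-generation
# pipeline returns the prompt plus the continuation in "generated_text"
# by default. Passing return_full_text=False to the pipeline would return
# only the newly generated tokens and make most of the string cleanup in
# /chat below unnecessary.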
# ===========================
# CHAT API
# ===========================
@app.route('/chat', methods=['POST'])
def chat():
    if ai is None:
        return jsonify({"error": "Model initialization failed."}), 500
    try:
        # silent=True returns None instead of raising on a missing/invalid
        # JSON body, so bad requests get a 400 rather than a 500
        data = request.get_json(silent=True) or {}
        msg = data.get("message", "")
        if not msg:
            return jsonify({"error": "No message sent"}), 400

        # Instruction format: a simple User/Assistant template for this model
        prompt = f"User: {msg}\nAssistant:"
        output = ai(prompt)[0]["generated_text"]

        # Clean the output to extract only the model's reply.
        # Split on the 'Assistant:' tag from the prompt template.
        if "Assistant:" in output:
            reply = output.split("Assistant:")[-1].strip()
        elif "User:" in output:  # sometimes the model repeats the prompt
            reply = output.split("User:")[0].strip()
        else:
            reply = output.strip()

        # Remove any remaining echo of the user's message from the start
        if reply.startswith(msg):
            reply = reply[len(msg):].strip()
        return jsonify({"reply": reply})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
# ===========================
# RUN SERVER
# ===========================
if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7860)
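# ---------------------------------------------------------------
# Example request (illustrative sketch; assumes the server is
# running locally on port 7860, not part of the app itself):
#
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello!"}'
#
# Or from Python, using the requests library:
#
#   import requests
#   r = requests.post("http://localhost:7860/chat",
#                     json={"message": "Hello!"})
#   print(r.json()["reply"])
# ---------------------------------------------------------------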