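"""Minimal Flask server that exposes a /chat endpoint backed by a local
Hugging Face text-generation pipeline (SmolLM-1.7B)."""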
import flask
from flask import request, jsonify
from transformers import pipeline, AutoTokenizer
import torch
import warnings 

# Suppress minor warnings that occur on CPU runs
warnings.filterwarnings("ignore")

app = flask.Flask(__name__)

# ===========================
# LOAD MODEL (SmolLM-1.7B)
# This model is small (1.7B) and fully open-access.
# ===========================
model_id = "HuggingFaceTB/SmolLM-1.7B" 
print("🔄 Loading model...")

# Select device: GPU 0 if available, otherwise CPU (-1)
device = 0 if torch.cuda.is_available() else -1
# Use float32 for CPU (or bfloat16 for GPU)
dtype = torch.float32 if device == -1 else torch.bfloat16

try:
    # 1. Load Tokenizer and set pad_token for stability
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Set pad_token to eos_token to avoid padding warnings during generation
        tokenizer.pad_token = tokenizer.eos_token 

    # 2. Load Pipeline with the fixed tokenizer
    ai = pipeline(
        "text-generation", 
        model=model_id, 
        tokenizer=tokenizer, # Passing the configured tokenizer here
        max_new_tokens=200, 
        device=device,
        torch_dtype=dtype, 
        trust_remote_code=True
    )
    print("✅ Model loaded!")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    ai = None 

# ===========================
# CHAT API
# ===========================
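# Expects a JSON body like {"message": "<user text>"} and responds with
# {"reply": "<model output>"} on success or {"error": "..."} on failure.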
@app.route('/chat', methods=['POST'])
def chat():
    if ai is None:
        return jsonify({"error": "Model initialization failed."}), 500
        
    try:
        data = request.get_json()
        msg = data.get("message", "")
        if not msg:
            return jsonify({"error": "No message sent"}), 400

        # Prompt format: a simple User/Assistant template for this model
        prompt = f"User: {msg}\nAssistant:"
        
        output = ai(prompt)[0]["generated_text"]
        
        # Clean the output to extract only the model's reply.
        # The pipeline returns the prompt plus the completion, so take the text
        # after the first 'Assistant:' tag from the prompt template.
        if "Assistant:" in output:
            reply = output.split("Assistant:", 1)[1].strip()
        elif "User:" in output:  # Sometimes the model repeats the prompt
            reply = output.split("User:")[0].strip()
        else:
            reply = output.strip()

        # A base model often keeps generating extra dialogue turns;
        # cut the reply off before the next 'User:' marker if one appears.
        reply = reply.split("\nUser:")[0].strip()

        # Remove the echoed user message if it still leads the reply.
        if reply.startswith(msg):
            reply = reply[len(msg):].strip()

        return jsonify({"reply": reply})
    except Exception as e:
        return jsonify({"error": str(e)}), 500

# ===========================
# RUN SERVER
# ===========================
if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7860)
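
# ===========================
# EXAMPLE USAGE (sketch)
# ===========================
# A minimal sketch of how to call the endpoint once the server is running,
# assuming the default host/port above and the `requests` package installed:
#
#   import requests
#   resp = requests.post("http://localhost:7860/chat", json={"message": "Hello!"})
#   print(resp.json().get("reply"))
#
# Or with curl:
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello!"}'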