# (Scraped file-view header, not code — kept as comments so the file parses.)
# Bahaedev's picture
# Update app.py
# cb69e12 verified
# raw
# history blame
# 2.41 kB
import os
import threading

import gradio as gr
import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
# =======================
# Load Secrets
# =======================
# The hidden system prompt is read from the ``prompt`` environment variable.
# When that variable is unset, a harmless placeholder is used so the app still
# starts instead of failing on a missing secret.
SYSTEM_PROMPT = os.getenv(
    "prompt",
    "You are a placeholder Sovereign. No secrets found in environment.",
)
# =======================
# Model Initialization
# =======================
MODEL_ID = "tiiuae/Falcon3-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Load the weights 4-bit quantized with bitsandbytes to reduce memory use.
# Passing ``load_in_4bit=True`` directly to ``from_pretrained`` is deprecated;
# transformers expects an explicit ``BitsAndBytesConfig`` instead. The compute
# dtype is pinned to float16 to match ``torch_dtype``; with the default
# (float32) compute dtype, 4-bit layers take a slower path.
# NOTE(review): bitsandbytes 4-bit generally requires a CUDA GPU — confirm the
# deployment hardware supports it.
# ``trust_remote_code`` is intentionally not enabled: this checkpoint uses an
# architecture built into transformers, so no repository code needs to run.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    ),
    device_map="auto",
    torch_dtype=torch.float16,
)

# Text-generation pipeline with fixed sampling settings. No ``device_map`` is
# passed here: the model above is already placed on devices, and ``pipeline``
# only uses ``device_map`` when it loads a model from a name itself.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=False,  # return only the new completion, not the prompt
    max_new_tokens=256,
    do_sample=True,
    temperature=0.8,
    top_p=0.9,
    eos_token_id=tokenizer.eos_token_id,
)
# =======================
# Core Chat Function
# =======================
def chat_fn(user_input: str) -> str:
    """Return the model's reply to *user_input* under the hidden system prompt.

    Instruct checkpoints are trained on their tokenizer's chat template (its
    own role markers). So when the tokenizer defines a template, the
    conversation is handed to the pipeline as role/content messages, and the
    pipeline applies that template itself. This also avoids double-adding
    special tokens. The plain ``### System / ### User / ### Assistant`` layout
    is kept only as a fallback for tokenizers without a template.

    Args:
        user_input: Raw text entered by the user (Gradio textbox or API body).

    Returns:
        The generated completion, stripped of surrounding whitespace. The
        prompt is excluded because the pipeline is built with
        ``return_full_text=False``.
    """
    if getattr(tokenizer, "chat_template", None):
        conversation = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_input},
        ]
        result = pipe(conversation)
    else:
        # Same text as the original hand-rolled prompt format.
        legacy_prompt = (
            f"### System:\n{SYSTEM_PROMPT}\n\n"
            f"### User:\n{user_input}\n\n"
            "### Assistant:"
        )
        result = pipe(legacy_prompt)
    return result[0]["generated_text"].strip()
# =======================
# Gradio UI
# =======================
def gradio_chat(user_input: str) -> str:
    # Gradio callback: forwards the textbox contents to chat_fn unchanged and
    # returns its reply for display.
    reply = chat_fn(user_input)
    return reply
# Single-textbox UI: the prompt box feeds gradio_chat and the reply is shown
# as plain text.
iface = gr.Interface(
    title="Prompt Cracking Challenge",
    description="Does he really think he is the king?",
    fn=gradio_chat,
    inputs=gr.Textbox(placeholder="Enter your prompt…", lines=5),
    outputs="text",
)
# =======================
# FastAPI for API access
# =======================
app = FastAPI(title="Prompt Cracking Challenge API")


# JSON body accepted by POST /generate, e.g. {"prompt": "..."}.
class Request(BaseModel):
    prompt: str  # user message, passed unchanged to chat_fn


@app.post("/generate")
def generate(req: Request):
    # Declared with plain ``def`` (not ``async def``), so FastAPI runs this
    # blocking model call in its worker threadpool, off the event loop.
    reply = chat_fn(req.prompt)
    return {"response": reply}
# =======================
# Launch Both Servers
# =======================
def run_api():
    # Serve the FastAPI app on all interfaces. The port comes from the API_PORT
    # environment variable (default 8000). uvicorn.run blocks until shutdown,
    # which is why the caller runs this on a separate thread.
    api_port = int(os.environ.get("API_PORT", 8000))
    uvicorn.run(app, port=api_port, host="0.0.0.0")
if __name__ == "__main__":
    # The API server runs on a daemon thread so it does not keep the process
    # alive on its own; the Gradio UI is launched from the main thread.
    # NOTE(review): hosts that expose only port 7860 may not make the API
    # port reachable — confirm for the deployment target.
    api_thread = threading.Thread(target=run_api, daemon=True)
    api_thread.start()
    iface.launch(server_port=7860, server_name="0.0.0.0")