"""
ThinkingDhenu CRSA Chat – LLaMA-Factory-style WebUI
"""
import re, threading
import gradio as gr
import torch
from logger import log_qa
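# `log_qa` comes from the local logger.py module; it is assumed here to
# persist each (question, answer) pair for later review.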
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
    GenerationConfig,
)
MODEL_NAME = "KissanAI/ThinkingDhenu1-Extension-USA-research-preview"
DEFAULT_SYSTEM_PROMPT = "You are a helpful Agronomist."
# ---------- Model ----------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False, trust_remote_code=True)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto", trust_remote_code=True)
model.eval()
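# Optional memory saving (an assumption, not part of the original config):
# on GPU hosts one could pass torch_dtype=torch.bfloat16 to from_pretrained
# above, provided the hardware supports bf16.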
try:
    model.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
except Exception:
    pass
DEVICE = next(model.parameters()).device
print(f"[Startup] Model on {DEVICE}", flush=True)
# ---------- Utility ----------
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL | re.IGNORECASE)
def _escape(text: str) -> str:
    return text.replace("<", "&lt;").replace(">", "&gt;")
def _format_result(raw: str) -> str:
    m = _THINK_RE.search(raw)
    if m:
        thinking = _escape(m.group(1).strip())
        details = (
            "<details><summary>🤔 Thinking… (click to expand)</summary>\n\n" + thinking + "\n\n</details>\n\n"
        )
        answer = _escape(raw[m.end():].strip())
        return details + answer
    return _escape(raw.strip())
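# Illustrative example:
#   _format_result("<think>check GDD</think>Plant now.")
#   -> a collapsible <details> block containing "check GDD", then "Plant now."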
# ---------- Helper ----------
def _apply_template(msgs):
    if hasattr(tokenizer, "apply_chat_template"):
        return tokenizer.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
    # Fallback for tokenizers without a chat template: a simple tagged transcript.
    prompt = ""
    for m in msgs:
        role, content = m["role"], m["content"]
        if role == "system":
            prompt += content.strip() + "\n\n"
        elif role == "user":
            prompt += f"<human>: {content.strip()}\n\n<assistant>: "
        else:
            prompt += f"{content.strip()}\n"
    return prompt
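# With the fallback, a one-turn history renders roughly as:
#   "You are a helpful Agronomist.\n\n<human>: QUESTION\n\n<assistant>: "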
def _history_to_pairs(msgs):
    # Convert role-tagged history into (user, assistant) tuples for gr.Chatbot.
    pairs = []
    for m in msgs:
        if m["role"] == "system":
            continue
        if m["role"] == "user":
            pairs.append((m["content"], ""))
        else:
            if pairs:
                pairs[-1] = (pairs[-1][0], m["content"])
    return pairs
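# Illustrative example:
#   [{"role": "system", ...}, {"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello"}]  ->  [("Hi", "Hello")]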
# ---------- Callbacks ----------
def user_send(user_text, hist, sys_prompt):
    if hist is None:
        hist = []
    # Restart the conversation whenever the system prompt changes.
    if not hist or hist[0]["role"] != "system" or hist[0]["content"] != sys_prompt:
        hist = [{"role": "system", "content": sys_prompt}]
    hist.append({"role": "user", "content": user_text})
    return "", _history_to_pairs(hist), hist
def bot_reply(
    hist, sys_prompt,
    max_new_tokens, temperature, top_p, rep_penalty,
    stream_output,
):
    if not hist:
        hist = [{"role": "system", "content": sys_prompt}]
    prompt_text = _apply_template(hist)
    inputs = tokenizer(prompt_text, return_tensors="pt", padding=False)
    input_ids = inputs["input_ids"].to(DEVICE)
    attention_mask = inputs["attention_mask"].to(DEVICE)
    gen_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=rep_penalty,
        do_sample=temperature > 0,  # greedy decoding when temperature is 0
    )
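    # TextIteratorStreamer decodes tokens as generate() runs in a background
    # thread, so the UI can be updated incrementally while text is produced.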
    if stream_output:
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs["streamer"] = streamer
        threading.Thread(target=model.generate, kwargs=gen_kwargs, daemon=True).start()
        partial = ""
        pairs = _history_to_pairs(hist)  # hoisted so `pairs` is defined even if no tokens arrive
        for tok in streamer:  # incremental updates
            partial += tok
            if pairs:  # overwrite the last assistant bubble
                pairs[-1] = (pairs[-1][0], _format_result(partial) + "▌")
            yield pairs, hist  # UI shows a typing cursor
        final_md = _format_result(partial)
        if pairs:
            pairs[-1] = (pairs[-1][0], final_md)
        hist.append({"role": "assistant", "content": final_md})
        log_qa(hist[-2]["content"], final_md)
        yield pairs, hist
        return  # defensively end the generator here
    output = model.generate(**gen_kwargs)
    raw = tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
    reply_md = _format_result(raw)
    hist.append({"role": "assistant", "content": reply_md})
    log_qa(hist[-2]["content"], reply_md)
    yield _history_to_pairs(hist), hist
def clear_chat():
return "", [], [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]
# ---------- UI ----------
LOGO_URL = "file=dhenu_logo.webp" # or .png – must live in the repo root
CSS = f"""
/* ─────────────────────────────────────────────
Chatbot watermark
───────────────────────────────────────────── */
#chatbot {{
position: relative; /* anchor for the pseudo-element */
}}
#chatbot::before {{
content: "";
position: absolute;
top: 50%; left: 50%;
width: 200px; height: 200px; /* adjust size here */
transform: translate(-50%, -50%); /* centre it */
background: url('{LOGO_URL}') center / contain no-repeat;
opacity: 0.08; /* 0-1; raise if too faint */
pointer-events: none; /* let clicks pass through */
z-index: 0; /* stays behind message bubbles */
}}
/* keep your original bubble colours & sample-prompt layout */
#chatbot .message.user {{background: var(--color-primary-subdued); z-index:1;}}
#chatbot .message.bot {{background: var(--color-accent-subdued); z-index:1;}}
.sample-prompts button {{
margin: 2px 4px; /* keep your spacing */
font-size: 0.8rem; /* smaller text; adjust as you like */
font-weight: 400; /* 400 = “normal”; removes boldness */
flex: 0 0 220px; /* every button 220 px wide */
white-space: normal; /* let long text wrap inside */
text-align: center; /* centre multi-line text */
padding: 10px 12px; /* a bit more breathing room */
}}
"""
SAMPLE_QS = [
    # General, high-level
    "What are the advantages and disadvantages of using preemergence versus postemergence herbicides?",
    "What practical steps can improve soil health for Iowa cornfields within one growing season?",
    "What are the primary benefits of using cereal cover crops in rotation with summer crops?",
    "Which intercropping combinations maximise corn yield while improving soil health in the Southeast?",
    # Simple day-to-day queries
    "Develop a strategy for Winston County farmers to mitigate the economic impacts of crop failures or market fluctuations.",
    "How does the risk of SWD oviposition change with increasing Brix levels in grapes?",
    "What are the critical factors to consider when developing a manure nutrient management plan for a corn-soybean rotation in Iowa?",
    "What seed rate and row spacing should I use for hybrid field corn on one acre in Illinois?",
    # Medium-depth, decision-making prompts
    "For corn and soybean rotations, what are the most effective integrated pest management strategies for managing multiple pest species?",
    "For bell pepper in Florida, compare fertigation versus broadcasting urea on yield and cost–benefit.",
    "Draft a three-step plan to manage fall armyworm in corn with pheromone traps and bio-pesticides.",
    "Suggest two cover crops after soybean harvest in Illinois that raise soil carbon and add income.",
    # Deeper reasoning and climate-smart agriculture tests
    "A 2,000-acre dryland wheat farm in eastern Washington gets erratic 12 in (305 mm) rainfall—rank the three biggest climate risks to winter wheat and justify the order.",
    "On sandy-loam soil in the Texas High Plains, which saves more water—subsurface drip versus center-pivot LEPA irrigation for cotton—and what are the expected yield trade-offs?",
    "In El Niño years, Iowa corn faces heightened tar spot risk: build an IPM calendar using a 7-day ensemble weather forecast.",
    "If 3 t ac⁻¹ corn stover is incorporated annually into a Mollisol in Minnesota, how much soil organic carbon (t C ac⁻¹) could accumulate over 10 years?"
]
with gr.Blocks(title="ThinkingDhenu US Agriculture Extensions Chat", css=CSS, theme=gr.themes.Soft()) as demo:
gr.Markdown("<h2>🐮 ThinkingDhenu US Agriculture Extensions Chat</h2>")
gr.Markdown("<h4>Research Preview of reasoning model based on US Agricultural extension knowledge in collaboration with Extension Foundation</h4>")
    with gr.Row():
        # ---- Left panel ----
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(elem_id="chatbot", height=580, label="Dialogue")
            with gr.Row():
                user_box = gr.Textbox(placeholder="Ask me anything about crops, soil, …", show_label=False, lines=2, container=True, autofocus=True, scale=6)
                send_btn = gr.Button("Send", variant="primary", scale=1)
                stop_btn = gr.Button("⏹ Stop", variant="stop", scale=1)
            clear_btn = gr.Button("🔄 Clear conversation")
        # ---- Right panel ----
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Inference Settings")
            sys_prompt_in = gr.Textbox(DEFAULT_SYSTEM_PROMPT, label="System Prompt", lines=3)
            max_tokens_sl = gr.Slider(2048, 4096, value=2048, step=1, label="Max new tokens")
            temp_sl = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
            top_p_sl = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p")
            rep_pen_sl = gr.Slider(1.0, 2.0, value=1.1, step=0.05, label="Repetition penalty")
            stream_ck = gr.Checkbox(True, label="Stream response")
    with gr.Row():
        with gr.Column(scale=4):
            # Clickable sample prompts
            gr.Markdown("#### 💡 Sample questions")
            sample_btns = []
            with gr.Row(elem_classes=["sample-prompts"]):
                for q in SAMPLE_QS:
                    sample_btns.append(gr.Button(q))
    # ----- State & wiring -----
    state = gr.State([{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}])
    # --- helper to attach send logic ---
    def wire_send(trigger):
        ev1 = trigger.then(user_send, [user_box, state, sys_prompt_in], [user_box, chatbot, state], queue=False)
        ev2 = ev1.then(
            bot_reply,
            [state, sys_prompt_in, max_tokens_sl, temp_sl, top_p_sl, rep_pen_sl, stream_ck],
            [chatbot, state],
        )
        return ev2  # generation event (for cancellation)
    # Textbox submit and Send button
    gen_events = []
    gen_events.append(wire_send(user_box.submit(lambda x: x, user_box, user_box, queue=False)))
    gen_events.append(wire_send(send_btn.click(lambda x: x, user_box, user_box, queue=False)))
    # Sample prompt buttons: with inputs=None the handler is called with no
    # arguments, so the prompt must be bound via a default argument only.
    for btn, q in zip(sample_btns, SAMPLE_QS):
        fill = btn.click(lambda q=q: q, None, user_box, queue=False)
        gen_events.append(wire_send(fill))
    # Clear conversation button
    clear_btn.click(clear_chat, None, [user_box, chatbot, state], queue=False)
    # ---- Stop button: cancel in-flight generation, then redraw the chat from state ----
    def _abort():
        # No-op handler; the actual cancellation happens via `cancels=gen_events`.
        return None
    # Pass `state` as an input: reading `state.value` inside a lambda would
    # return the value captured at build time, not the current session's history.
    stop_btn.click(_abort, None, None, cancels=gen_events, queue=False).then(
        lambda hist: ("", _history_to_pairs(hist), hist) if hist else ("", [], hist),
        state,
        [user_box, chatbot, state],
        queue=False,
    )
# ---------- Launch ----------
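# queue() is required for generator-based streaming; max_size caps the number
# of pending requests, and show_api=False hides the auto-generated API page.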
if __name__ == "__main__":
    demo.queue(max_size=32).launch(show_api=False)