| """ | |
| ThinkingDhenu CRSA Chat – LLaMA-Factory-style WebUI | |
| """ | |
| import os, threading, sys, re | |
| import gradio as gr | |
| import torch | |
| from logger import log_qa | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForCausalLM, | |
| TextIteratorStreamer, | |
| GenerationConfig, | |
| ) | |
| MODEL_NAME = "KissanAI/ThinkingDhenu1-Extension-USA-research-preview" | |
| DEFAULT_SYSTEM_PROMPT = "You are a helpful Agronomist." | |
| # ---------- Model ---------- | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False, trust_remote_code=True) | |
| if tokenizer.pad_token_id is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto", trust_remote_code=True) | |
| model.eval() | |
| try: | |
| model.generation_config = GenerationConfig.from_pretrained(MODEL_NAME) | |
| except Exception: | |
| pass | |
| DEVICE = next(model.parameters()).device | |
| print(f"[Startup] Model on {DEVICE}", flush=True) | |

# ---------- Utility ----------
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL | re.IGNORECASE)
def _escape(text: str) -> str:
    # Escape angle brackets so raw <tags> in the model output don't break the rendered Markdown/HTML.
    return text.replace("<", "&lt;").replace(">", "&gt;")
def _format_result(raw: str) -> str:
    m = _THINK_RE.search(raw)
    if m:
        thinking = _escape(m.group(1).strip())
        details = (
            "<details><summary>🤔 Thinking… (click to expand)</summary>\n\n" + thinking + "\n\n</details>\n\n"
        )
        answer = _escape(raw[m.end():].strip())
        return details + answer
    return _escape(raw.strip())
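# Example: "<think>reason about rotation…</think>Plant a cereal rye cover crop." becomes a
# collapsed <details> block holding the reasoning, followed by the escaped final answer.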

# ---------- Helper ----------
def _apply_template(msgs):
    if hasattr(tokenizer, "apply_chat_template"):
        return tokenizer.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
    prompt = ""
    for m in msgs:
        role, content = m["role"], m["content"]
        if role == "system":
            prompt += content.strip() + "\n\n"
        elif role == "user":
            prompt += f"<human>: {content.strip()}\n\n<assistant>: "
        else:
            prompt += f"{content.strip()}\n"
    return prompt
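# The manual <human>/<assistant> format above is only a fallback; when the tokenizer ships a
# chat template, apply_chat_template builds the prompt instead.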
def _history_to_pairs(msgs):
    pairs = []
    for m in msgs:
        if m["role"] == "system":
            continue
        if m["role"] == "user":
            pairs.append((m["content"], ""))
        else:
            if pairs:
                pairs[-1] = (pairs[-1][0], m["content"])
    return pairs
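# The gr.Chatbot component below renders a list of (user, assistant) tuples, so the
# message-dict history kept in gr.State is converted with _history_to_pairs on every update.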

# ---------- Callbacks ----------
def user_send(user_text, hist, sys_prompt):
    if hist is None:
        hist = []
    if not hist or hist[0]["role"] != "system" or hist[0]["content"] != sys_prompt:
        hist = [{"role": "system", "content": sys_prompt}]
    hist.append({"role": "user", "content": user_text})
    return "", _history_to_pairs(hist), hist
def bot_reply(
    hist, sys_prompt,
    max_new_tokens, temperature, top_p, rep_penalty,
    stream_output
):
    if not hist:
        hist = [{"role": "system", "content": sys_prompt}]
    prompt_text = _apply_template(hist)
    inputs = tokenizer(prompt_text, return_tensors="pt", padding=False)
    input_ids = inputs["input_ids"].to(DEVICE)
    attention_mask = inputs["attention_mask"].to(DEVICE)
    gen_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=rep_penalty,
        do_sample=True,
    )
    if stream_output:
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs["streamer"] = streamer
        threading.Thread(target=model.generate, kwargs=gen_kwargs, daemon=True).start()

        partial = ""
        pairs = _history_to_pairs(hist)   # defined up front in case the streamer yields nothing
        for tok in streamer:              # ⇢ incremental updates
            partial += tok
            display = _format_result(partial)
            pairs = _history_to_pairs(hist)
            if pairs:                     # overwrite last assistant bubble
                pairs[-1] = (pairs[-1][0], display + "▌")
            yield pairs, hist             # UI shows “typing…”

        final_md = _format_result(partial)
        pairs[-1] = (pairs[-1][0], final_md)
        hist.append({"role": "assistant", "content": final_md})
        log_qa(hist[-2]["content"], final_md)
        yield pairs, hist
        return  # defensively end here
    output = model.generate(**gen_kwargs)
    raw = tokenizer.decode(output[0][input_ids.shape[-1]:],
                           skip_special_tokens=True)
    reply_md = _format_result(raw)
    hist.append({"role": "assistant", "content": reply_md})
    log_qa(hist[-2]["content"], reply_md)
    yield _history_to_pairs(hist), hist
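# bot_reply is a generator, so with the queue enabled (see demo.queue() at the bottom) Gradio
# pushes each yielded (chatbot, state) pair to the browser as it is produced.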
def clear_chat():
    return "", [], [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]

# ---------- UI ----------
LOGO_URL = "file=dhenu_logo.webp"   # or .png – must live in the repo root

CSS = f"""
/* ─────────────────────────────────────────────
   Chatbot watermark
   ───────────────────────────────────────────── */
#chatbot {{
    position: relative;                /* anchor for the pseudo-element */
}}
#chatbot::before {{
    content: "";
    position: absolute;
    top: 50%; left: 50%;
    width: 200px; height: 200px;       /* adjust size here */
    transform: translate(-50%, -50%);  /* centre it */
    background: url('{LOGO_URL}') center / contain no-repeat;
    opacity: 0.08;                     /* 0-1; raise if too faint */
    pointer-events: none;              /* let clicks pass through */
    z-index: 0;                        /* stays behind message bubbles */
}}

/* keep your original bubble colours & sample-prompt layout */
#chatbot .message.user {{background: var(--color-primary-subdued); z-index: 1;}}
#chatbot .message.bot  {{background: var(--color-accent-subdued);  z-index: 1;}}

.sample-prompts button {{
    margin: 2px 4px;       /* keep your spacing */
    font-size: 0.8rem;     /* smaller text; adjust as you like */
    font-weight: 400;      /* 400 = “normal”; removes boldness */
    flex: 0 0 220px;       /* every button 220 px wide */
    white-space: normal;   /* let long text wrap inside */
    text-align: center;    /* centre multi-line text */
    padding: 10px 12px;    /* a bit more breathing room */
}}
"""
SAMPLE_QS = [
    # General, high-level
    "What are the advantages and disadvantages of using preemergence versus postemergence herbicides?",
    "What practical steps can improve soil health for Iowa cornfields within one growing season?",
    "What are the primary benefits of using cereal cover crops in rotation with summer crops?",
    "Which intercropping combinations maximise corn yield while improving soil health in the Southeast?",
    # Simple day-to-day queries
    "Develop a strategy for Winston County farmers to mitigate the economic impacts of crop failures or market fluctuations.",
    "How does the risk of SWD oviposition change with increasing Brix levels in grapes?",
    "What are the critical factors to consider when developing a manure nutrient management plan for a corn-soybean rotation in Iowa?",
    "What seed rate and row spacing should I use for hybrid field corn on one acre in Illinois?",
    # Medium-depth, decision-making prompts
    "For corn and soybean rotations, what are the most effective integrated pest management strategies for managing multiple pest species?",
    "For bell pepper in Florida, compare fertigation versus broadcasting urea on yield and cost–benefit.",
    "Draft a three-step plan to manage fall armyworm in corn with pheromone traps and bio-pesticides.",
    "Suggest two cover crops after soybean harvest in Illinois that raise soil carbon and add income.",
    # Deeper reasoning and climate-smart agriculture tests
    "A 2,000-acre dryland wheat farm in eastern Washington gets erratic 12 in (305 mm) rainfall—rank the three biggest climate risks to winter wheat and justify the order.",
    "On sandy-loam soil in the Texas High Plains, which saves more water—subsurface drip versus center-pivot LEPA irrigation for cotton—and what are the expected yield trade-offs?",
    "In El Niño years, Iowa corn faces heightened tar spot risk: build an IPM calendar using a 7-day ensemble weather forecast.",
    "If 3 t ac⁻¹ corn stover is incorporated annually into a Mollisol in Minnesota, how much soil organic carbon (t C ac⁻¹) could accumulate over 10 years?",
]
with gr.Blocks(title="ThinkingDhenu US Agriculture Extensions Chat", css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("<h2>🐮 ThinkingDhenu US Agriculture Extensions Chat</h2>")
    gr.Markdown("<h4>Research preview of a reasoning model built on US agricultural extension knowledge, in collaboration with the Extension Foundation</h4>")

    with gr.Row():
        # ---- Left panel ----
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(elem_id="chatbot", height=580, label="Dialogue")
            with gr.Row():
                user_box = gr.Textbox(placeholder="Ask me anything about crops, soil, …", show_label=False, lines=2, container=True, autofocus=True, scale=6)
                send_btn = gr.Button("Send", variant="primary", scale=1)
                stop_btn = gr.Button("⏹ Stop", variant="stop", scale=1)
            clear_btn = gr.Button("🔄 Clear conversation")

        # ---- Right panel ----
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Inference Settings")
            sys_prompt_in = gr.Textbox(DEFAULT_SYSTEM_PROMPT, label="System Prompt", lines=3)
            max_tokens_sl = gr.Slider(2048, 4096, value=2048, step=1, label="Max new tokens")
            temp_sl = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
            top_p_sl = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p")
            rep_pen_sl = gr.Slider(1.0, 2.0, value=1.1, step=0.05, label="Repetition penalty")
            stream_ck = gr.Checkbox(True, label="Stream response")

    with gr.Row():
        with gr.Column(scale=4):
            # Clickable sample prompts
            gr.Markdown("#### 💡 Sample questions")
            sample_btns = []
            with gr.Row(elem_classes=["sample-prompts"]):
                for q in SAMPLE_QS:
                    sample_btns.append(gr.Button(q))
    # ----- State & wiring -----
    state = gr.State([{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}])

    # --- helper to attach send logic ---
    def wire_send(trigger):
        ev1 = trigger.then(user_send, [user_box, state, sys_prompt_in], [user_box, chatbot, state], queue=False)
        ev2 = ev1.then(
            bot_reply,
            [state, sys_prompt_in, max_tokens_sl, temp_sl, top_p_sl, rep_pen_sl, stream_ck],
            [chatbot, state],
        )
        return ev2  # generation event (for cancellation)

    # Textbox submit and Send button
    gen_events = []
    gen_events.append(wire_send(user_box.submit(lambda x: x, user_box, user_box, queue=False)))
    gen_events.append(wire_send(send_btn.click(lambda x: x, user_box, user_box, queue=False)))
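    # Every generation event returned by wire_send is collected in gen_events so the Stop
    # button can cancel all of them at once via the `cancels=` argument below.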
    # Sample prompt buttons
    for btn, q in zip(sample_btns, SAMPLE_QS):
        # inputs=None means the callback is invoked with no arguments, so bind the question
        # as a default argument rather than expecting a positional one
        fill = btn.click(lambda q=q: q, None, user_box, queue=False)
        gen_events.append(wire_send(fill))

    # Clear conversation button
    clear_btn.click(clear_chat, None, [user_box, chatbot, state], queue=False)
    # ---- Stop button: cancels all in-flight generation and resets the assistant bubble ----
    def _abort():
        # Return unchanged inputs (they'll be cleared in cancel chain), no outputs
        return None

    stop_btn.click(_abort, None, None, cancels=gen_events, queue=False).then(
        lambda: ("", _history_to_pairs(state.value), state.value) if state.value else ("", [], state.value),
        None,
        [user_box, chatbot, state],
        queue=False,
    )

# ---------- Launch ----------
if __name__ == "__main__":
    demo.queue(max_size=32).launch(show_api=False)