Spaces:
Build error
Build error
Commit ·
f0ac041
1
Parent(s): 6253bc5
Add generation configurations to chatbot interface
Browse files
app.py
CHANGED
|
@@ -25,7 +25,7 @@ terminators = [
|
|
| 25 |
]
|
| 26 |
|
| 27 |
@spaces.GPU(duration=120)
|
| 28 |
-
def generate_both(system_prompt, input_text, base_chatbot, new_chatbot):
|
| 29 |
base_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
|
| 30 |
new_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
|
| 31 |
|
|
@@ -60,22 +60,24 @@ def generate_both(system_prompt, input_text, base_chatbot, new_chatbot):
|
|
| 60 |
base_generation_kwargs = dict(
|
| 61 |
input_ids=base_input_ids,
|
| 62 |
streamer=base_text_streamer,
|
| 63 |
-
max_new_tokens=
|
| 64 |
eos_token_id=terminators,
|
| 65 |
pad_token_id=tokenizer.eos_token_id,
|
| 66 |
-
do_sample=True,
|
| 67 |
-
temperature=
|
| 68 |
-
top_p=
|
|
|
|
| 69 |
)
|
| 70 |
new_generation_kwargs = dict(
|
| 71 |
input_ids=new_input_ids,
|
| 72 |
streamer=new_text_streamer,
|
| 73 |
-
max_new_tokens=
|
| 74 |
eos_token_id=terminators,
|
| 75 |
pad_token_id=tokenizer.eos_token_id,
|
| 76 |
-
do_sample=True,
|
| 77 |
-
temperature=
|
| 78 |
-
top_p=
|
|
|
|
| 79 |
)
|
| 80 |
|
| 81 |
base_thread = Thread(target=base_model.generate, kwargs=base_generation_kwargs)
|
|
@@ -111,16 +113,21 @@ with gr.Blocks(title="Arabic-ORPO-Llama3") as demo:
|
|
| 111 |
gr.HTML("<center><h1>Arabic Chatbot Comparison</h1></center>")
|
| 112 |
system_prompt = gr.Textbox(lines=1, label="System Prompt", value="أنت متحدث لبق باللغة العربية!", rtl=True, text_align="right", show_copy_button=True)
|
| 113 |
with gr.Row(variant="panel"):
|
| 114 |
-
base_chatbot = gr.Chatbot(label=base_model_id, rtl=True, likeable=True, show_copy_button=True)
|
| 115 |
-
new_chatbot = gr.Chatbot(label=new_model_id, rtl=True, likeable=True, show_copy_button=True)
|
| 116 |
with gr.Row(variant="panel"):
|
| 117 |
with gr.Column(scale=1):
|
| 118 |
submit_btn = gr.Button(value="Generate", variant="primary")
|
| 119 |
clear_btn = gr.Button(value="Clear", variant="secondary")
|
| 120 |
input_text = gr.Textbox(lines=1, label="", value="مرحبا", rtl=True, text_align="right", scale=3, show_copy_button=True)
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
clear_btn.click(clear, outputs=[base_chatbot, new_chatbot])
|
| 125 |
|
| 126 |
demo.launch()
|
|
|
|
| 25 |
]
|
| 26 |
|
| 27 |
@spaces.GPU(duration=120)
|
| 28 |
+
def generate_both(system_prompt, input_text, base_chatbot, new_chatbot, max_new_tokens=2048, temperature=0.2, top_p=0.9, repetition_penalty=1.1):
|
| 29 |
base_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
|
| 30 |
new_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
|
| 31 |
|
|
|
|
| 60 |
base_generation_kwargs = dict(
|
| 61 |
input_ids=base_input_ids,
|
| 62 |
streamer=base_text_streamer,
|
| 63 |
+
max_new_tokens=max_new_tokens,
|
| 64 |
eos_token_id=terminators,
|
| 65 |
pad_token_id=tokenizer.eos_token_id,
|
| 66 |
+
do_sample=True if temperature > 0 else False,
|
| 67 |
+
temperature=temperature,
|
| 68 |
+
top_p=top_p,
|
| 69 |
+
repetition_penalty=repetition_penalty,
|
| 70 |
)
|
| 71 |
new_generation_kwargs = dict(
|
| 72 |
input_ids=new_input_ids,
|
| 73 |
streamer=new_text_streamer,
|
| 74 |
+
max_new_tokens=max_new_tokens,
|
| 75 |
eos_token_id=terminators,
|
| 76 |
pad_token_id=tokenizer.eos_token_id,
|
| 77 |
+
do_sample=True if temperature > 0 else False,
|
| 78 |
+
temperature=temperature,
|
| 79 |
+
top_p=top_p,
|
| 80 |
+
repetition_penalty=repetition_penalty,
|
| 81 |
)
|
| 82 |
|
| 83 |
base_thread = Thread(target=base_model.generate, kwargs=base_generation_kwargs)
|
|
|
|
| 113 |
gr.HTML("<center><h1>Arabic Chatbot Comparison</h1></center>")
|
| 114 |
system_prompt = gr.Textbox(lines=1, label="System Prompt", value="أنت متحدث لبق باللغة العربية!", rtl=True, text_align="right", show_copy_button=True)
|
| 115 |
with gr.Row(variant="panel"):
|
| 116 |
+
base_chatbot = gr.Chatbot(label=base_model_id, rtl=True, likeable=True, show_copy_button=True, height=500)
|
| 117 |
+
new_chatbot = gr.Chatbot(label=new_model_id, rtl=True, likeable=True, show_copy_button=True, height=500)
|
| 118 |
with gr.Row(variant="panel"):
|
| 119 |
with gr.Column(scale=1):
|
| 120 |
submit_btn = gr.Button(value="Generate", variant="primary")
|
| 121 |
clear_btn = gr.Button(value="Clear", variant="secondary")
|
| 122 |
input_text = gr.Textbox(lines=1, label="", value="مرحبا", rtl=True, text_align="right", scale=3, show_copy_button=True)
|
| 123 |
+
with gr.Accordion(label="Generation Configurations", open=False):
|
| 124 |
+
max_new_tokens = gr.Slider(minimum=128, maximum=4096, value=2048, label="Max New Tokens", step=128)
|
| 125 |
+
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, label="Temperature", step=0.01)
|
| 126 |
+
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top-p", step=0.01)
|
| 127 |
+
repetition_penalty = gr.Slider(minimum=0.1, maximum=2.0, value=1.1, label="Repetition Penalty", step=0.1)
|
| 128 |
+
|
| 129 |
+
input_text.submit(generate_both, inputs=[system_prompt, input_text, base_chatbot, new_chatbot, max_new_tokens, temperature, top_p, repetition_penalty], outputs=[base_chatbot, new_chatbot])
|
| 130 |
+
submit_btn.click(generate_both, inputs=[system_prompt, input_text, base_chatbot, new_chatbot, max_new_tokens, temperature, top_p, repetition_penalty], outputs=[base_chatbot, new_chatbot])
|
| 131 |
clear_btn.click(clear, outputs=[base_chatbot, new_chatbot])
|
| 132 |
|
| 133 |
demo.launch()
|