Commit ·
95f0c53
1
Parent(s): 0a56f88
update
Browse files- .gitignore +2 -0
- app.py +65 -53
- scheduler.py +2 -0
- schemas.py +7 -3
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
dummy/
|
app.py
CHANGED
|
@@ -98,10 +98,6 @@ div#banner {
|
|
| 98 |
|
| 99 |
}
|
| 100 |
|
| 101 |
-
div#main-components {
|
| 102 |
-
align-items: flex-end;
|
| 103 |
-
}
|
| 104 |
-
|
| 105 |
div#steering-toggle {
|
| 106 |
padding-top: 8px;
|
| 107 |
padding-bottom: 8px;
|
|
@@ -233,13 +229,31 @@ async def get_endpoint_state():
|
|
| 233 |
yield "Server Error"
|
| 234 |
|
| 235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
async def generate(
|
| 237 |
session_id: str, prompt: str, steering: bool, coeff: float,
|
| 238 |
-
max_new_tokens: int, top_p: float, temperature: float
|
| 239 |
):
|
| 240 |
req = UserRequest(
|
| 241 |
session_id=session_id, prompt=prompt, steering=steering, coeff=coeff,
|
| 242 |
-
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
| 243 |
)
|
| 244 |
|
| 245 |
instances[session_id].append(req)
|
|
@@ -260,28 +274,11 @@ async def generate(
|
|
| 260 |
else:
|
| 261 |
logger.error("API Error Ccode: %d, Error Message: %s", resp.status, resp.text())
|
| 262 |
raise gr.Error("API Server Error")
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
logger.info("Client session Error")
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
async def post_process(session_id, output):
|
| 269 |
-
req = instances[session_id].pop()
|
| 270 |
-
|
| 271 |
-
if "</think>" in output:
|
| 272 |
-
p = [p for p in output.partition("</think>") if p != ""]
|
| 273 |
-
reasoning = "".join(p[:-1])
|
| 274 |
-
if len(p) == 1:
|
| 275 |
-
answer = None
|
| 276 |
-
else:
|
| 277 |
-
answer = p[-1]
|
| 278 |
-
else:
|
| 279 |
-
answer = None
|
| 280 |
-
reasoning = output
|
| 281 |
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
return gr.update(interactive=True), gr.update(interactive=True)
|
| 285 |
|
| 286 |
|
| 287 |
async def output_feedback(session_id, feedback):
|
|
@@ -299,6 +296,10 @@ async def output_feedback(session_id, feedback):
|
|
| 299 |
logger.debug("Feedback submission error")
|
| 300 |
|
| 301 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
|
| 303 |
theme = gr.themes.Base(primary_hue="emerald", text_size=gr.themes.sizes.text_lg).set()
|
| 304 |
|
|
@@ -308,20 +309,20 @@ with gr.Blocks(title="LLM Censorship Steering", theme=theme, head=HEAD, css=CSS,
|
|
| 308 |
|
| 309 |
gr.HTML(HTML)
|
| 310 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
with gr.Row(elem_id="main-components"):
|
| 312 |
with gr.Column(scale=1):
|
| 313 |
-
@gr.render(inputs=endpoint_state, triggers=[endpoint_state.change])
|
| 314 |
-
def render_state(endpoint_state):
|
| 315 |
-
if endpoint_state == "Ready":
|
| 316 |
-
color = "green"
|
| 317 |
-
elif endpoint_state == "Server Error":
|
| 318 |
-
color = "red"
|
| 319 |
-
else:
|
| 320 |
-
color = "orange"
|
| 321 |
-
|
| 322 |
-
if endpoint_state != None:
|
| 323 |
-
gr.Markdown(f'🤖 {model_name} | Inference Endpoint State: <span style="color:{color}; font-weight: bold;">{endpoint_state}</span>', elem_id="model-state")
|
| 324 |
-
|
| 325 |
with gr.Row():
|
| 326 |
steer_toggle = Toggle(label="Steering", info="Turn off to generate original outputs", value=True, interactive=True, scale=2, elem_id="steering-toggle")
|
| 327 |
coeff = gr.Slider(label="Coefficient:", value=-1.0, minimum=-2, maximum=2, step=0.1, scale=8, show_reset_button=False, elem_id="coeff-slider")
|
|
@@ -332,23 +333,28 @@ with gr.Blocks(title="LLM Censorship Steering", theme=theme, head=HEAD, css=CSS,
|
|
| 332 |
return gr.update(label="Steering", info="Turn off to generate original outputs"), gr.update(interactive=True)
|
| 333 |
else:
|
| 334 |
return gr.update(label="No Steering", info="Turn on to steer model outputs"), gr.update(interactive=False)
|
| 335 |
-
|
| 336 |
-
with gr.Accordion("⚙️ Advanced Settings", open=False):
|
| 337 |
-
with gr.Row():
|
| 338 |
-
temperature = gr.Slider(0, 1, step=0.1, value=CONFIG["temperature"], interactive=True, label="Temperature", scale=2)
|
| 339 |
-
top_p = gr.Slider(0, 1, step=0.1, value=CONFIG["top_p"], interactive=True, label="Top p", scale=2)
|
| 340 |
-
max_new_tokens = gr.Number(CONFIG["max_new_tokens"], minimum=10, maximum=CONFIG["max_new_tokens"], interactive=True, label="Max new tokens", scale=1)
|
| 341 |
|
| 342 |
input_text = gr.Textbox(label="Input", placeholder="Enter your prompt here...", lines=6, interactive=True)
|
| 343 |
|
| 344 |
with gr.Row():
|
| 345 |
clear_btn = gr.ClearButton()
|
|
|
|
| 346 |
generate_btn = gr.Button("Generate", variant="primary")
|
| 347 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
with gr.Column(scale=1):
|
| 349 |
output = gr.Textbox(label="Output", lines=15, max_lines=15, interactive=False)
|
| 350 |
|
| 351 |
-
with gr.Row():
|
| 352 |
upvote_btn = gr.Button("👍 Upvote", interactive=False)
|
| 353 |
downvote_btn = gr.Button("👎 Downvote", interactive=False)
|
| 354 |
|
|
@@ -357,17 +363,23 @@ with gr.Blocks(title="LLM Censorship Steering", theme=theme, head=HEAD, css=CSS,
|
|
| 357 |
gr.Examples(examples=examples[examples["type"] == "sensitive"].prompt.tolist(), inputs=input_text, label="Sensitive")
|
| 358 |
gr.Examples(examples=examples[examples["type"] == "harmful"].prompt.tolist(), inputs=input_text, label="Harmful")
|
| 359 |
|
| 360 |
-
|
| 361 |
-
|
|
|
|
| 362 |
return gr.update(interactive=False), gr.update(interactive=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
|
| 364 |
-
|
| 365 |
-
generate_btn.click(
|
| 366 |
-
generate, inputs=[session_id, input_text, steer_toggle, coeff, max_new_tokens, top_p, temperature], outputs=output
|
| 367 |
-
).success(
|
| 368 |
-
post_process, inputs=[session_id, output], outputs=[upvote_btn, downvote_btn]
|
| 369 |
)
|
| 370 |
|
|
|
|
|
|
|
|
|
|
| 371 |
upvote_btn.click(output_feedback, inputs=[session_id, upvote_btn])
|
| 372 |
downvote_btn.click(output_feedback, inputs=[session_id, downvote_btn])
|
| 373 |
|
|
|
|
| 98 |
|
| 99 |
}
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
div#steering-toggle {
|
| 102 |
padding-top: 8px;
|
| 103 |
padding-bottom: 8px;
|
|
|
|
| 229 |
yield "Server Error"
|
| 230 |
|
| 231 |
|
| 232 |
+
async def post_process(session_id, output):
|
| 233 |
+
req = instances[session_id].pop()
|
| 234 |
+
|
| 235 |
+
if "</think>" in output:
|
| 236 |
+
p = [p for p in output.partition("</think>") if p != ""]
|
| 237 |
+
reasoning = "".join(p[:-1])
|
| 238 |
+
if len(p) == 1:
|
| 239 |
+
answer = None
|
| 240 |
+
else:
|
| 241 |
+
answer = p[-1]
|
| 242 |
+
else:
|
| 243 |
+
answer = None
|
| 244 |
+
reasoning = output
|
| 245 |
+
|
| 246 |
+
steering_output = SteeringOutput(**req.model_dump(), reasoning=reasoning, answer=answer)
|
| 247 |
+
instances[session_id].append(steering_output)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
async def generate(
|
| 251 |
session_id: str, prompt: str, steering: bool, coeff: float,
|
| 252 |
+
max_new_tokens: int, top_p: float, temperature: float, vec_scaling: float
|
| 253 |
):
|
| 254 |
req = UserRequest(
|
| 255 |
session_id=session_id, prompt=prompt, steering=steering, coeff=coeff,
|
| 256 |
+
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, k=vec_scaling
|
| 257 |
)
|
| 258 |
|
| 259 |
instances[session_id].append(req)
|
|
|
|
| 274 |
else:
|
| 275 |
logger.error("API Error Ccode: %d, Error Message: %s", resp.status, resp.text())
|
| 276 |
raise gr.Error("API Server Error")
|
| 277 |
+
|
| 278 |
+
await post_process(session_id, generated_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
|
| 280 |
+
except:
|
| 281 |
+
logger.info("Client session error")
|
|
|
|
| 282 |
|
| 283 |
|
| 284 |
async def output_feedback(session_id, feedback):
|
|
|
|
| 296 |
logger.debug("Feedback submission error")
|
| 297 |
|
| 298 |
|
| 299 |
+
async def show_feedback_buttons(upvote_btn, downvote_btn):
|
| 300 |
+
return gr.update(interactive=True), gr.update(interactive=True)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
|
| 304 |
theme = gr.themes.Base(primary_hue="emerald", text_size=gr.themes.sizes.text_lg).set()
|
| 305 |
|
|
|
|
| 309 |
|
| 310 |
gr.HTML(HTML)
|
| 311 |
|
| 312 |
+
@gr.render(inputs=endpoint_state, triggers=[endpoint_state.change])
|
| 313 |
+
def render_state(endpoint_state):
|
| 314 |
+
if endpoint_state == "Ready":
|
| 315 |
+
color = "green"
|
| 316 |
+
elif endpoint_state == "Server Error":
|
| 317 |
+
color = "red"
|
| 318 |
+
else:
|
| 319 |
+
color = "orange"
|
| 320 |
+
|
| 321 |
+
if endpoint_state != None:
|
| 322 |
+
gr.Markdown(f'🤖 {model_name} | Inference Endpoint State: <span style="color:{color}; font-weight: bold;">{endpoint_state}</span>', elem_id="model-state")
|
| 323 |
+
|
| 324 |
with gr.Row(elem_id="main-components"):
|
| 325 |
with gr.Column(scale=1):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
with gr.Row():
|
| 327 |
steer_toggle = Toggle(label="Steering", info="Turn off to generate original outputs", value=True, interactive=True, scale=2, elem_id="steering-toggle")
|
| 328 |
coeff = gr.Slider(label="Coefficient:", value=-1.0, minimum=-2, maximum=2, step=0.1, scale=8, show_reset_button=False, elem_id="coeff-slider")
|
|
|
|
| 333 |
return gr.update(label="Steering", info="Turn off to generate original outputs"), gr.update(interactive=True)
|
| 334 |
else:
|
| 335 |
return gr.update(label="No Steering", info="Turn on to steer model outputs"), gr.update(interactive=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
|
| 337 |
input_text = gr.Textbox(label="Input", placeholder="Enter your prompt here...", lines=6, interactive=True)
|
| 338 |
|
| 339 |
with gr.Row():
|
| 340 |
clear_btn = gr.ClearButton()
|
| 341 |
+
stop_btn = gr.Button("Stop")
|
| 342 |
generate_btn = gr.Button("Generate", variant="primary")
|
| 343 |
|
| 344 |
+
with gr.Accordion("⚙️ Advanced Settings", open=False):
|
| 345 |
+
with gr.Row():
|
| 346 |
+
temperature = gr.Slider(0, 1, step=0.1, value=CONFIG["temperature"], interactive=True, label="Temperature", scale=1)
|
| 347 |
+
top_p = gr.Slider(0, 1, step=0.1, value=CONFIG["top_p"], interactive=True, label="Top p", scale=1)
|
| 348 |
+
|
| 349 |
+
with gr.Row():
|
| 350 |
+
layer = gr.Slider(0, 27, step=1, value=CONFIG["layer"], interactive=True, label="Steering layer", scale=2)
|
| 351 |
+
max_new_tokens = gr.Number(CONFIG["max_new_tokens"], minimum=10, maximum=CONFIG["max_new_tokens"], interactive=True, label="Max new tokens", scale=1)
|
| 352 |
+
vec_scaling = gr.Number(CONFIG["k"], interactive=True, label="Vector scaling", scale=1)
|
| 353 |
+
|
| 354 |
with gr.Column(scale=1):
|
| 355 |
output = gr.Textbox(label="Output", lines=15, max_lines=15, interactive=False)
|
| 356 |
|
| 357 |
+
with gr.Row():
|
| 358 |
upvote_btn = gr.Button("👍 Upvote", interactive=False)
|
| 359 |
downvote_btn = gr.Button("👎 Downvote", interactive=False)
|
| 360 |
|
|
|
|
| 363 |
gr.Examples(examples=examples[examples["type"] == "sensitive"].prompt.tolist(), inputs=input_text, label="Sensitive")
|
| 364 |
gr.Examples(examples=examples[examples["type"] == "harmful"].prompt.tolist(), inputs=input_text, label="Harmful")
|
| 365 |
|
| 366 |
+
|
| 367 |
+
@gr.on(triggers=[clear_btn.click, stop_btn.click], outputs=[upvote_btn, downvote_btn])
|
| 368 |
+
def clear_feedback_buttons():
|
| 369 |
return gr.update(interactive=False), gr.update(interactive=False)
|
| 370 |
+
|
| 371 |
+
@gr.on(triggers=[generate_btn.click], outputs=[upvote_btn, downvote_btn])
|
| 372 |
+
def show_feedback_buttons():
|
| 373 |
+
return gr.update(interactive=True), gr.update(interactive=True)
|
| 374 |
|
| 375 |
+
|
| 376 |
+
submission = generate_btn.click(
|
| 377 |
+
generate, inputs=[session_id, input_text, steer_toggle, coeff, max_new_tokens, top_p, temperature, vec_scaling], outputs=output
|
|
|
|
|
|
|
| 378 |
)
|
| 379 |
|
| 380 |
+
clear_btn.add([input_text, output])
|
| 381 |
+
stop_btn.click(None, None, None, cancels=[submission], queue=False)
|
| 382 |
+
|
| 383 |
upvote_btn.click(output_feedback, inputs=[session_id, upvote_btn])
|
| 384 |
downvote_btn.click(output_feedback, inputs=[session_id, downvote_btn])
|
| 385 |
|
scheduler.py
CHANGED
|
@@ -28,6 +28,8 @@ def load_scheduler():
|
|
| 28 |
"answer": {"_type": "Value", "dtype": "string"},
|
| 29 |
"upvote": {"_type": "Value", "dtype": "bool"},
|
| 30 |
"timestamp": {"_type": "Value", "dtype": "string"},
|
|
|
|
|
|
|
| 31 |
}
|
| 32 |
)
|
| 33 |
|
|
|
|
| 28 |
"answer": {"_type": "Value", "dtype": "string"},
|
| 29 |
"upvote": {"_type": "Value", "dtype": "bool"},
|
| 30 |
"timestamp": {"_type": "Value", "dtype": "string"},
|
| 31 |
+
"layer": {"_type": "Value", "dtype": "int64"},
|
| 32 |
+
"k": {"_type": "Value", "dtype": "float64"},
|
| 33 |
}
|
| 34 |
)
|
| 35 |
|
schemas.py
CHANGED
|
@@ -7,7 +7,8 @@ CONFIG = {
|
|
| 7 |
"max_new_tokens": 3048,
|
| 8 |
"top_p": 0.95,
|
| 9 |
"temperature": 0.6,
|
| 10 |
-
"k": 200
|
|
|
|
| 11 |
}
|
| 12 |
|
| 13 |
class UserRequest(BaseModel):
|
|
@@ -18,13 +19,16 @@ class UserRequest(BaseModel):
|
|
| 18 |
max_new_tokens: int = Field(CONFIG["max_new_tokens"], le=3048)
|
| 19 |
top_p: float = Field(CONFIG["top_p"], ge=0.0, le=1.0)
|
| 20 |
temperature: float = Field(CONFIG["temperature"], ge=0.0, le=1.0)
|
|
|
|
|
|
|
| 21 |
|
| 22 |
def get_api_format(self):
|
| 23 |
return {
|
| 24 |
"prompt": self.prompt,
|
| 25 |
"steering": self.steering,
|
| 26 |
"coeff": self.coeff,
|
| 27 |
-
"k":
|
|
|
|
| 28 |
"generation_config": {
|
| 29 |
"max_new_tokens": self.max_new_tokens,
|
| 30 |
"top_p": self.top_p,
|
|
@@ -36,6 +40,6 @@ class UserRequest(BaseModel):
|
|
| 36 |
class SteeringOutput(UserRequest):
|
| 37 |
max_new_tokens: SkipJsonSchema[int] = Field(exclude=True)
|
| 38 |
reasoning: str = None
|
| 39 |
-
answer: str = None
|
| 40 |
upvote: Optional[bool] = None
|
| 41 |
timestamp: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
|
|
|
| 7 |
"max_new_tokens": 3048,
|
| 8 |
"top_p": 0.95,
|
| 9 |
"temperature": 0.6,
|
| 10 |
+
"k": 200,
|
| 11 |
+
"layer": 25
|
| 12 |
}
|
| 13 |
|
| 14 |
class UserRequest(BaseModel):
|
|
|
|
| 19 |
max_new_tokens: int = Field(CONFIG["max_new_tokens"], le=3048)
|
| 20 |
top_p: float = Field(CONFIG["top_p"], ge=0.0, le=1.0)
|
| 21 |
temperature: float = Field(CONFIG["temperature"], ge=0.0, le=1.0)
|
| 22 |
+
k: float = Field(CONFIG["k"])
|
| 23 |
+
layer: int = Field(CONFIG["layer"])
|
| 24 |
|
| 25 |
def get_api_format(self):
|
| 26 |
return {
|
| 27 |
"prompt": self.prompt,
|
| 28 |
"steering": self.steering,
|
| 29 |
"coeff": self.coeff,
|
| 30 |
+
"k": self.k,
|
| 31 |
+
"layer": self.layer,
|
| 32 |
"generation_config": {
|
| 33 |
"max_new_tokens": self.max_new_tokens,
|
| 34 |
"top_p": self.top_p,
|
|
|
|
| 40 |
class SteeringOutput(UserRequest):
|
| 41 |
max_new_tokens: SkipJsonSchema[int] = Field(exclude=True)
|
| 42 |
reasoning: str = None
|
| 43 |
+
answer: Optional[str] = None
|
| 44 |
upvote: Optional[bool] = None
|
| 45 |
timestamp: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|