hannahcyberey commited on
Commit
95f0c53
·
1 Parent(s): 0a56f88
Files changed (4) hide show
  1. .gitignore +2 -0
  2. app.py +65 -53
  3. scheduler.py +2 -0
  4. schemas.py +7 -3
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__
2
+ dummy/
app.py CHANGED
@@ -98,10 +98,6 @@ div#banner {
98
 
99
  }
100
 
101
- div#main-components {
102
- align-items: flex-end;
103
- }
104
-
105
  div#steering-toggle {
106
  padding-top: 8px;
107
  padding-bottom: 8px;
@@ -233,13 +229,31 @@ async def get_endpoint_state():
233
  yield "Server Error"
234
 
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  async def generate(
237
  session_id: str, prompt: str, steering: bool, coeff: float,
238
- max_new_tokens: int, top_p: float, temperature: float
239
  ):
240
  req = UserRequest(
241
  session_id=session_id, prompt=prompt, steering=steering, coeff=coeff,
242
- max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
243
  )
244
 
245
  instances[session_id].append(req)
@@ -260,28 +274,11 @@ async def generate(
260
  else:
261
  logger.error("API Error Ccode: %d, Error Message: %s", resp.status, resp.text())
262
  raise gr.Error("API Server Error")
263
-
264
- except:
265
- logger.info("Client session Error")
266
-
267
-
268
- async def post_process(session_id, output):
269
- req = instances[session_id].pop()
270
-
271
- if "</think>" in output:
272
- p = [p for p in output.partition("</think>") if p != ""]
273
- reasoning = "".join(p[:-1])
274
- if len(p) == 1:
275
- answer = None
276
- else:
277
- answer = p[-1]
278
- else:
279
- answer = None
280
- reasoning = output
281
 
282
- steering_output = SteeringOutput(**req.model_dump(), reasoning=reasoning, answer=answer)
283
- instances[session_id].append(steering_output)
284
- return gr.update(interactive=True), gr.update(interactive=True)
285
 
286
 
287
  async def output_feedback(session_id, feedback):
@@ -299,6 +296,10 @@ async def output_feedback(session_id, feedback):
299
  logger.debug("Feedback submission error")
300
 
301
 
 
 
 
 
302
  gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
303
  theme = gr.themes.Base(primary_hue="emerald", text_size=gr.themes.sizes.text_lg).set()
304
 
@@ -308,20 +309,20 @@ with gr.Blocks(title="LLM Censorship Steering", theme=theme, head=HEAD, css=CSS,
308
 
309
  gr.HTML(HTML)
310
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  with gr.Row(elem_id="main-components"):
312
  with gr.Column(scale=1):
313
- @gr.render(inputs=endpoint_state, triggers=[endpoint_state.change])
314
- def render_state(endpoint_state):
315
- if endpoint_state == "Ready":
316
- color = "green"
317
- elif endpoint_state == "Server Error":
318
- color = "red"
319
- else:
320
- color = "orange"
321
-
322
- if endpoint_state != None:
323
- gr.Markdown(f'🤖 {model_name} | Inference Endpoint State: <span style="color:{color}; font-weight: bold;">{endpoint_state}</span>', elem_id="model-state")
324
-
325
  with gr.Row():
326
  steer_toggle = Toggle(label="Steering", info="Turn off to generate original outputs", value=True, interactive=True, scale=2, elem_id="steering-toggle")
327
  coeff = gr.Slider(label="Coefficient:", value=-1.0, minimum=-2, maximum=2, step=0.1, scale=8, show_reset_button=False, elem_id="coeff-slider")
@@ -332,23 +333,28 @@ with gr.Blocks(title="LLM Censorship Steering", theme=theme, head=HEAD, css=CSS,
332
  return gr.update(label="Steering", info="Turn off to generate original outputs"), gr.update(interactive=True)
333
  else:
334
  return gr.update(label="No Steering", info="Turn on to steer model outputs"), gr.update(interactive=False)
335
-
336
- with gr.Accordion("⚙️ Advanced Settings", open=False):
337
- with gr.Row():
338
- temperature = gr.Slider(0, 1, step=0.1, value=CONFIG["temperature"], interactive=True, label="Temperature", scale=2)
339
- top_p = gr.Slider(0, 1, step=0.1, value=CONFIG["top_p"], interactive=True, label="Top p", scale=2)
340
- max_new_tokens = gr.Number(CONFIG["max_new_tokens"], minimum=10, maximum=CONFIG["max_new_tokens"], interactive=True, label="Max new tokens", scale=1)
341
 
342
  input_text = gr.Textbox(label="Input", placeholder="Enter your prompt here...", lines=6, interactive=True)
343
 
344
  with gr.Row():
345
  clear_btn = gr.ClearButton()
 
346
  generate_btn = gr.Button("Generate", variant="primary")
347
 
 
 
 
 
 
 
 
 
 
 
348
  with gr.Column(scale=1):
349
  output = gr.Textbox(label="Output", lines=15, max_lines=15, interactive=False)
350
 
351
- with gr.Row():
352
  upvote_btn = gr.Button("👍 Upvote", interactive=False)
353
  downvote_btn = gr.Button("👎 Downvote", interactive=False)
354
 
@@ -357,17 +363,23 @@ with gr.Blocks(title="LLM Censorship Steering", theme=theme, head=HEAD, css=CSS,
357
  gr.Examples(examples=examples[examples["type"] == "sensitive"].prompt.tolist(), inputs=input_text, label="Sensitive")
358
  gr.Examples(examples=examples[examples["type"] == "harmful"].prompt.tolist(), inputs=input_text, label="Harmful")
359
 
360
- @gr.on(triggers=[clear_btn.click], outputs=[upvote_btn, downvote_btn])
361
- def clear():
 
362
  return gr.update(interactive=False), gr.update(interactive=False)
 
 
 
 
363
 
364
- clear_btn.add([input_text, output])
365
- generate_btn.click(
366
- generate, inputs=[session_id, input_text, steer_toggle, coeff, max_new_tokens, top_p, temperature], outputs=output
367
- ).success(
368
- post_process, inputs=[session_id, output], outputs=[upvote_btn, downvote_btn]
369
  )
370
 
 
 
 
371
  upvote_btn.click(output_feedback, inputs=[session_id, upvote_btn])
372
  downvote_btn.click(output_feedback, inputs=[session_id, downvote_btn])
373
 
 
98
 
99
  }
100
 
 
 
 
 
101
  div#steering-toggle {
102
  padding-top: 8px;
103
  padding-bottom: 8px;
 
229
  yield "Server Error"
230
 
231
 
232
+ async def post_process(session_id, output):
233
+ req = instances[session_id].pop()
234
+
235
+ if "</think>" in output:
236
+ p = [p for p in output.partition("</think>") if p != ""]
237
+ reasoning = "".join(p[:-1])
238
+ if len(p) == 1:
239
+ answer = None
240
+ else:
241
+ answer = p[-1]
242
+ else:
243
+ answer = None
244
+ reasoning = output
245
+
246
+ steering_output = SteeringOutput(**req.model_dump(), reasoning=reasoning, answer=answer)
247
+ instances[session_id].append(steering_output)
248
+
249
+
250
  async def generate(
251
  session_id: str, prompt: str, steering: bool, coeff: float,
252
+ max_new_tokens: int, top_p: float, temperature: float, vec_scaling: float
253
  ):
254
  req = UserRequest(
255
  session_id=session_id, prompt=prompt, steering=steering, coeff=coeff,
256
+ max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, k=vec_scaling
257
  )
258
 
259
  instances[session_id].append(req)
 
274
  else:
275
  logger.error("API Error Ccode: %d, Error Message: %s", resp.status, resp.text())
276
  raise gr.Error("API Server Error")
277
+
278
+ await post_process(session_id, generated_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
 
280
+ except:
281
+ logger.info("Client session error")
 
282
 
283
 
284
  async def output_feedback(session_id, feedback):
 
296
  logger.debug("Feedback submission error")
297
 
298
 
299
+ async def show_feedback_buttons(upvote_btn, downvote_btn):
300
+ return gr.update(interactive=True), gr.update(interactive=True)
301
+
302
+
303
  gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
304
  theme = gr.themes.Base(primary_hue="emerald", text_size=gr.themes.sizes.text_lg).set()
305
 
 
309
 
310
  gr.HTML(HTML)
311
 
312
+ @gr.render(inputs=endpoint_state, triggers=[endpoint_state.change])
313
+ def render_state(endpoint_state):
314
+ if endpoint_state == "Ready":
315
+ color = "green"
316
+ elif endpoint_state == "Server Error":
317
+ color = "red"
318
+ else:
319
+ color = "orange"
320
+
321
+ if endpoint_state != None:
322
+ gr.Markdown(f'🤖 {model_name} | Inference Endpoint State: <span style="color:{color}; font-weight: bold;">{endpoint_state}</span>', elem_id="model-state")
323
+
324
  with gr.Row(elem_id="main-components"):
325
  with gr.Column(scale=1):
 
 
 
 
 
 
 
 
 
 
 
 
326
  with gr.Row():
327
  steer_toggle = Toggle(label="Steering", info="Turn off to generate original outputs", value=True, interactive=True, scale=2, elem_id="steering-toggle")
328
  coeff = gr.Slider(label="Coefficient:", value=-1.0, minimum=-2, maximum=2, step=0.1, scale=8, show_reset_button=False, elem_id="coeff-slider")
 
333
  return gr.update(label="Steering", info="Turn off to generate original outputs"), gr.update(interactive=True)
334
  else:
335
  return gr.update(label="No Steering", info="Turn on to steer model outputs"), gr.update(interactive=False)
 
 
 
 
 
 
336
 
337
  input_text = gr.Textbox(label="Input", placeholder="Enter your prompt here...", lines=6, interactive=True)
338
 
339
  with gr.Row():
340
  clear_btn = gr.ClearButton()
341
+ stop_btn = gr.Button("Stop")
342
  generate_btn = gr.Button("Generate", variant="primary")
343
 
344
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
345
+ with gr.Row():
346
+ temperature = gr.Slider(0, 1, step=0.1, value=CONFIG["temperature"], interactive=True, label="Temperature", scale=1)
347
+ top_p = gr.Slider(0, 1, step=0.1, value=CONFIG["top_p"], interactive=True, label="Top p", scale=1)
348
+
349
+ with gr.Row():
350
+ layer = gr.Slider(0, 27, step=1, value=CONFIG["layer"], interactive=True, label="Steering layer", scale=2)
351
+ max_new_tokens = gr.Number(CONFIG["max_new_tokens"], minimum=10, maximum=CONFIG["max_new_tokens"], interactive=True, label="Max new tokens", scale=1)
352
+ vec_scaling = gr.Number(CONFIG["k"], interactive=True, label="Vector scaling", scale=1)
353
+
354
  with gr.Column(scale=1):
355
  output = gr.Textbox(label="Output", lines=15, max_lines=15, interactive=False)
356
 
357
+ with gr.Row():
358
  upvote_btn = gr.Button("👍 Upvote", interactive=False)
359
  downvote_btn = gr.Button("👎 Downvote", interactive=False)
360
 
 
363
  gr.Examples(examples=examples[examples["type"] == "sensitive"].prompt.tolist(), inputs=input_text, label="Sensitive")
364
  gr.Examples(examples=examples[examples["type"] == "harmful"].prompt.tolist(), inputs=input_text, label="Harmful")
365
 
366
+
367
+ @gr.on(triggers=[clear_btn.click, stop_btn.click], outputs=[upvote_btn, downvote_btn])
368
+ def clear_feedback_buttons():
369
  return gr.update(interactive=False), gr.update(interactive=False)
370
+
371
+ @gr.on(triggers=[generate_btn.click], outputs=[upvote_btn, downvote_btn])
372
+ def show_feedback_buttons():
373
+ return gr.update(interactive=True), gr.update(interactive=True)
374
 
375
+
376
+ submission = generate_btn.click(
377
+ generate, inputs=[session_id, input_text, steer_toggle, coeff, max_new_tokens, top_p, temperature, vec_scaling], outputs=output
 
 
378
  )
379
 
380
+ clear_btn.add([input_text, output])
381
+ stop_btn.click(None, None, None, cancels=[submission], queue=False)
382
+
383
  upvote_btn.click(output_feedback, inputs=[session_id, upvote_btn])
384
  downvote_btn.click(output_feedback, inputs=[session_id, downvote_btn])
385
 
scheduler.py CHANGED
@@ -28,6 +28,8 @@ def load_scheduler():
28
  "answer": {"_type": "Value", "dtype": "string"},
29
  "upvote": {"_type": "Value", "dtype": "bool"},
30
  "timestamp": {"_type": "Value", "dtype": "string"},
 
 
31
  }
32
  )
33
 
 
28
  "answer": {"_type": "Value", "dtype": "string"},
29
  "upvote": {"_type": "Value", "dtype": "bool"},
30
  "timestamp": {"_type": "Value", "dtype": "string"},
31
+ "layer": {"_type": "Value", "dtype": "int64"},
32
+ "k": {"_type": "Value", "dtype": "float64"},
33
  }
34
  )
35
 
schemas.py CHANGED
@@ -7,7 +7,8 @@ CONFIG = {
7
  "max_new_tokens": 3048,
8
  "top_p": 0.95,
9
  "temperature": 0.6,
10
- "k": 200
 
11
  }
12
 
13
  class UserRequest(BaseModel):
@@ -18,13 +19,16 @@ class UserRequest(BaseModel):
18
  max_new_tokens: int = Field(CONFIG["max_new_tokens"], le=3048)
19
  top_p: float = Field(CONFIG["top_p"], ge=0.0, le=1.0)
20
  temperature: float = Field(CONFIG["temperature"], ge=0.0, le=1.0)
 
 
21
 
22
  def get_api_format(self):
23
  return {
24
  "prompt": self.prompt,
25
  "steering": self.steering,
26
  "coeff": self.coeff,
27
- "k": CONFIG["k"],
 
28
  "generation_config": {
29
  "max_new_tokens": self.max_new_tokens,
30
  "top_p": self.top_p,
@@ -36,6 +40,6 @@ class UserRequest(BaseModel):
36
  class SteeringOutput(UserRequest):
37
  max_new_tokens: SkipJsonSchema[int] = Field(exclude=True)
38
  reasoning: str = None
39
- answer: str = None
40
  upvote: Optional[bool] = None
41
  timestamp: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
 
7
  "max_new_tokens": 3048,
8
  "top_p": 0.95,
9
  "temperature": 0.6,
10
+ "k": 200,
11
+ "layer": 25
12
  }
13
 
14
  class UserRequest(BaseModel):
 
19
  max_new_tokens: int = Field(CONFIG["max_new_tokens"], le=3048)
20
  top_p: float = Field(CONFIG["top_p"], ge=0.0, le=1.0)
21
  temperature: float = Field(CONFIG["temperature"], ge=0.0, le=1.0)
22
+ k: float = Field(CONFIG["k"])
23
+ layer: int = Field(CONFIG["layer"])
24
 
25
  def get_api_format(self):
26
  return {
27
  "prompt": self.prompt,
28
  "steering": self.steering,
29
  "coeff": self.coeff,
30
+ "k": self.k,
31
+ "layer": self.layer,
32
  "generation_config": {
33
  "max_new_tokens": self.max_new_tokens,
34
  "top_p": self.top_p,
 
40
  class SteeringOutput(UserRequest):
41
  max_new_tokens: SkipJsonSchema[int] = Field(exclude=True)
42
  reasoning: str = None
43
+ answer: Optional[str] = None
44
  upvote: Optional[bool] = None
45
  timestamp: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())