Spaces:
Running
Running
Lev McKinney
commited on
Commit
Β·
4004daa
1
Parent(s):
c35da92
fixed several bugs in app.py
Browse files- .dockerignore +1 -1
- README.md +0 -1
- app.py +7 -8
.dockerignore
CHANGED
|
@@ -1,2 +1,2 @@
|
|
| 1 |
lens
|
| 2 |
-
.git
|
|
|
|
| 1 |
lens
|
| 2 |
+
.git
|
README.md
CHANGED
|
@@ -3,7 +3,6 @@ title: Tuned Lens
|
|
| 3 |
emoji: π
|
| 4 |
colorFrom: pink
|
| 5 |
colorTo: blue
|
| 6 |
-
port: 7860
|
| 7 |
sdk: docker
|
| 8 |
pinned: false
|
| 9 |
license: mit
|
|
|
|
| 3 |
emoji: π
|
| 4 |
colorFrom: pink
|
| 5 |
colorTo: blue
|
|
|
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
license: mit
|
app.py
CHANGED
|
@@ -7,7 +7,7 @@ from plotly import graph_objects as go
|
|
| 7 |
|
| 8 |
device = torch.device("cpu")
|
| 9 |
print(f"Using device {device} for inference")
|
| 10 |
-
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped")
|
| 11 |
model = model.to(device)
|
| 12 |
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
|
| 13 |
tuned_lens = TunedLens.from_model_and_pretrained(
|
|
@@ -29,19 +29,19 @@ statistic_options_dict = {
|
|
| 29 |
|
| 30 |
|
| 31 |
def make_plot(lens, text, statistic, token_cutoff):
|
| 32 |
-
input_ids = tokenizer.encode(text
|
| 33 |
input_ids = [tokenizer.bos_token_id] + input_ids
|
| 34 |
targets = input_ids[1:] + [tokenizer.eos_token_id]
|
| 35 |
|
| 36 |
-
if len(input_ids
|
| 37 |
return go.Figure(layout=dict(title="Please enter some text."))
|
| 38 |
|
| 39 |
if token_cutoff < 1:
|
| 40 |
return go.Figure(layout=dict(title="Please provide valid token cut off."))
|
| 41 |
|
| 42 |
-
start_pos=max(len(input_ids
|
| 43 |
pred_traj = PredictionTrajectory.from_lens_and_model(
|
| 44 |
-
lens=lens,
|
| 45 |
model=model,
|
| 46 |
input_ids=input_ids,
|
| 47 |
tokenizer=tokenizer,
|
|
@@ -49,7 +49,7 @@ def make_plot(lens, text, statistic, token_cutoff):
|
|
| 49 |
start_pos=start_pos,
|
| 50 |
)
|
| 51 |
|
| 52 |
-
return getattr(pred_traj, statistic)().figure(
|
| 53 |
title=f"{lens.__class__.__name__} ({model.name_or_path}) {statistic}",
|
| 54 |
)
|
| 55 |
|
|
@@ -114,5 +114,4 @@ with gr.Blocks() as demo:
|
|
| 114 |
demo.load(make_plot, [lens_options, text, statistic, token_cutoff], plot)
|
| 115 |
|
| 116 |
if __name__ == "__main__":
|
| 117 |
-
demo.launch()
|
| 118 |
-
|
|
|
|
| 7 |
|
| 8 |
device = torch.device("cpu")
|
| 9 |
print(f"Using device {device} for inference")
|
| 10 |
+
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped", torch_dtype="auto")
|
| 11 |
model = model.to(device)
|
| 12 |
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
|
| 13 |
tuned_lens = TunedLens.from_model_and_pretrained(
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
def make_plot(lens, text, statistic, token_cutoff):
|
| 32 |
+
input_ids = tokenizer.encode(text)
|
| 33 |
input_ids = [tokenizer.bos_token_id] + input_ids
|
| 34 |
targets = input_ids[1:] + [tokenizer.eos_token_id]
|
| 35 |
|
| 36 |
+
if len(input_ids) == 1:
|
| 37 |
return go.Figure(layout=dict(title="Please enter some text."))
|
| 38 |
|
| 39 |
if token_cutoff < 1:
|
| 40 |
return go.Figure(layout=dict(title="Please provide valid token cut off."))
|
| 41 |
|
| 42 |
+
start_pos=max(len(input_ids) - token_cutoff, 0)
|
| 43 |
pred_traj = PredictionTrajectory.from_lens_and_model(
|
| 44 |
+
lens=lens_options_dict[lens],
|
| 45 |
model=model,
|
| 46 |
input_ids=input_ids,
|
| 47 |
tokenizer=tokenizer,
|
|
|
|
| 49 |
start_pos=start_pos,
|
| 50 |
)
|
| 51 |
|
| 52 |
+
return getattr(pred_traj, statistic_options_dict[statistic])().figure(
|
| 53 |
title=f"{lens.__class__.__name__} ({model.name_or_path}) {statistic}",
|
| 54 |
)
|
| 55 |
|
|
|
|
| 114 |
demo.load(make_plot, [lens_options, text, statistic, token_cutoff], plot)
|
| 115 |
|
| 116 |
if __name__ == "__main__":
|
| 117 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|