gch0301 committed on
Commit c794ce2 · verified · 1 parent: 22bad4b

Delete app.py

Files changed (1)
  1. app.py +0 -97
app.py DELETED
@@ -1,97 +0,0 @@
- import gradio as gr
- from matplotlib import gridspec
- import matplotlib.pyplot as plt
- import numpy as np
- from PIL import Image
- import torch
- from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
-
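- # SegFormer-B4 fine-tuned on Cityscapes (19 classes) at 1024x1024 input resolution.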
- MODEL_ID = "nvidia/segformer-b4-finetuned-cityscapes-1024-1024"
- processor = AutoImageProcessor.from_pretrained(MODEL_ID)
- model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID)
-
- def cityscapes_palette():
-     """Cityscapes palette mapping each of the 19 class indices to an RGB color."""
-     return [
-         [204, 27, 92], [112, 185, 212], [45, 189, 106], [234, 123, 67], [78, 56, 123], [210, 32, 89],
-         [90, 180, 56], [155, 102, 200], [33, 147, 176], [255, 183, 76], [67, 123, 89], [190, 190, 0],
-         [134, 112, 200], [56, 45, 189], [200, 56, 123], [87, 92, 204], [120, 56, 123], [45, 78, 123],
-         [156, 200, 56],
-     ]
-
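- # labels.txt presumably lists one class name per line, ordered by class index (19 lines for Cityscapes).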
- labels_list = []
- with open("labels.txt", "r", encoding="utf-8") as fp:
-     for line in fp:
-         labels_list.append(line.rstrip("\n"))
-
- colormap = np.asarray(cityscapes_palette(), dtype=np.uint8)
-
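- # numpy fancy indexing: a 2-D integer class map indexes the (19, 3) palette, yielding an (H, W, 3) RGB image.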
- def label_to_color_image(label):
-     if label.ndim != 2:
-         raise ValueError("Expect 2-D input label")
-     if np.max(label) >= len(colormap):
-         raise ValueError("label value too large.")
-     return colormap[label]
-
- def draw_plot(pred_img, seg_np):
-     fig = plt.figure(figsize=(20, 15))
-     grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
-
-     plt.subplot(grid_spec[0])
-     plt.imshow(pred_img)
-     plt.axis('off')
-
-     LABEL_NAMES = np.asarray(labels_list)
-     FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
-     FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
-
-     unique_labels = np.unique(seg_np.astype("uint8"))
-     ax = plt.subplot(grid_spec[1])
-     plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
-     ax.yaxis.tick_right()
-     plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
-     plt.xticks([], [])
-     ax.tick_params(width=0.0, labelsize=25)
-     return fig
-
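- # Gradio callback: numpy image in, matplotlib figure (blended overlay plus class legend) out.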
- def run_inference(input_img):
-     # Gradio passes a numpy array; convert to PIL and force RGB.
-     img = Image.fromarray(input_img.astype(np.uint8)) if isinstance(input_img, np.ndarray) else input_img
-     if img.mode != "RGB":
-         img = img.convert("RGB")
-
-     inputs = processor(images=img, return_tensors="pt")
-     with torch.no_grad():
-         outputs = model(**inputs)
-     logits = outputs.logits  # (1, num_classes, H/4, W/4)
-
-     # Upsample logits to the input resolution; img.size is (W, H), interpolate wants (H, W).
-     upsampled = torch.nn.functional.interpolate(
-         logits, size=img.size[::-1], mode="bilinear", align_corners=False
-     )
-     seg = upsampled.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)  # (H, W)
-
-     # Colorize the class map and blend it 50/50 with the input image.
-     color_seg = colormap[seg]  # (H, W, 3)
-     pred_img = (np.array(img) * 0.5 + color_seg * 0.5).astype(np.uint8)
-
-     fig = draw_plot(pred_img, seg)
-     return fig
-
- demo = gr.Interface(
-     fn=run_inference,
-     inputs=gr.Image(type="numpy", label="Input Image"),
-     outputs=gr.Plot(label="Overlay + Legend"),
-     examples=[
-         "person-1.jpg",
-         "person-2.jpg",
-         "person-3.jpg",
-         "person-4.jpg",
-         "person-5.jpg",
-     ],
-     flagging_mode="never",
-     cache_examples=False,
- )
-
- if __name__ == "__main__":
-     demo.launch()
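
For reference, a minimal standalone sketch of the inference path the deleted app implemented, without the Gradio UI (MODEL_ID matches the deleted file; "street.jpg" is a placeholder input path):

    from PIL import Image
    import torch
    from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation

    MODEL_ID = "nvidia/segformer-b4-finetuned-cityscapes-1024-1024"
    processor = AutoImageProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID)

    img = Image.open("street.jpg").convert("RGB")  # placeholder path
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits  # (1, 19, H/4, W/4)
    # Upsample to the input resolution and take the per-pixel argmax.
    seg = torch.nn.functional.interpolate(
        logits, size=img.size[::-1], mode="bilinear", align_corners=False
    ).argmax(dim=1)[0].numpy()
    print(seg.shape)  # (H, W) integer class map, values 0-18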