Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| import gradio as gr | |
| import json | |
| import requests | |
# Location of the benchmark metadata. The code below treats the payload as a
# list of entries, each carrying at least 'index' and 'results' fields.
# NOTE(review): URL_BASE in this file uses https for the same host; consider
# switching this URL to https as well once confirmed to work.
data_url = "http://opencompass.openxlab.space/utils/RiseVis/data.json"

# Fetch once at import time. The timeout keeps a stalled server from hanging
# start-up indefinitely, and raise_for_status() reports an HTTP failure as
# such instead of as a confusing JSON decode error on an error page.
_response = requests.get(data_url, timeout=30)
_response.raise_for_status()
data = json.loads(_response.text)

# Get model names from the first entry
model_names = list(data[0]['results'])

# Fragments and layout parameters for the gallery table.
HTML_HEAD = '<table class="center">'
HTML_TAIL = '</table>'
N_COL = 5  # number of models shown per table row
WIDTH = 100 // N_COL  # width of one table column, in percent
def get_image_gallery(idx, models):
    """Build an HTML table showing each selected model's result for a problem.

    Models are sorted by name and laid out ``N_COL`` per group. Every group
    produces one row of model-name headings followed by one row of images.
    A group shorter than ``N_COL`` is padded with empty cells so that all
    rows have the same number of columns.

    Args:
        idx: Value of the ``'index'`` field of the entry in ``data`` to show.
        models: Iterable of model names; each must be a key of that entry's
            ``'results'`` mapping.

    Returns:
        The complete table markup as a single string.

    Raises:
        AssertionError: If ``idx`` is not a string or does not match exactly
            one entry in ``data``.
        KeyError: If a requested model has no result for the entry.
    """
    # Imported here so the function is self-contained; only ``escape`` is
    # bound, which avoids clashing with any local named ``html``.
    from html import escape

    assert isinstance(idx, str)
    matches = [x for x in data if x['index'] == idx]
    assert len(matches) == 1
    results = matches[0]['results']

    cell_open = f'<td width={WIDTH}% style="text-align:center;">'
    empty_cell = cell_open + '</td>'
    ordered = sorted(models)

    # Collect fragments and join once instead of growing a string with +=.
    parts = [HTML_HEAD]
    for start in range(0, len(ordered), N_COL):
        group = ordered[start:start + N_COL]
        padding = [empty_cell] * (N_COL - len(group))

        # Heading row. Names originate from remote JSON, so escape them
        # before embedding them in markup.
        parts.append('<tr>')
        parts.extend(
            f'{cell_open}<h3>{escape(str(name))}</h3></td>' for name in group
        )
        parts.extend(padding)

        # Image row. escape() also quotes '"', keeping the src attribute
        # well-formed whatever the file name contains.
        parts.append('</tr><tr>')
        parts.extend(
            f'{cell_open}<img src="{escape(URL_BASE + results[name])}"></td>'
            for name in group
        )
        parts.extend(padding)
        parts.append('</tr>')

    parts.append(HTML_TAIL)
    return ''.join(parts)
# Prefix that is concatenated with the file names stored in `data`
# (an entry's 'image' field and the values of its 'results') to form image URLs.
URL_BASE = 'https://opencompass.openxlab.space/utils/RiseVis/'
def get_origin_image(idx, model='original'):
    """Return the URL of an image belonging to one problem.

    Args:
        idx: Value of the ``'index'`` field of the entry in ``data``.
        model: ``'original'`` (the default) selects the entry's ``'image'``
            field; any other value is used as a key into the entry's
            ``'results'`` mapping.

    Returns:
        ``URL_BASE`` concatenated with the selected file name.

    Raises:
        AssertionError: If ``idx`` is not a string or does not match exactly
            one entry in ``data``.
        KeyError: If ``model`` is not ``'original'`` and the entry has no
            result stored under that name.
    """
    assert isinstance(idx, str)
    matches = [x for x in data if x['index'] == idx]
    assert len(matches) == 1
    entry = matches[0]

    if model == 'original':
        file_name = entry['image']
    else:
        # Index by the value of ``model``; indexing by the literal string
        # 'model' ignored the argument entirely.
        file_name = entry['results'][model]
    return URL_BASE + file_name
def read_instruction(idx):
    """Look up the instruction text of the problem identified by ``idx``.

    Args:
        idx: Value of the ``'index'`` field of the entry in ``data``.

    Returns:
        The ``'instruction'`` field of the matching entry.

    Raises:
        AssertionError: If ``idx`` is not a string or does not match exactly
            one entry in ``data``.
    """
    assert isinstance(idx, str)
    matches = [entry for entry in data if entry['index'] == idx]
    assert len(matches) == 1
    (entry,) = matches
    return entry['instruction']
def on_prev(state):
    """Step to the entry that precedes the current one in ``data``.

    Args:
        state: ``'index'`` value of the entry currently shown.

    Returns:
        A 2-tuple holding the new index twice, one value per output
        component this callback is wired to. Stepping back from the first
        entry wraps around to the last one. If ``state`` matches no entry,
        the first entry is selected.

    Raises:
        IndexError: If ``data`` is empty.
    """
    position = next(
        (i for i, item in enumerate(data) if item['index'] == state),
        None,
    )
    if position is None:
        # Unknown index (e.g. a typo submitted through the textbox): fall
        # back to a valid entry rather than stepping from an arbitrary one.
        new_index = data[0]['index']
    else:
        # Explicit modulo keeps the first -> last wraparound intentional
        # instead of relying on negative list indexing.
        new_index = data[(position - 1) % len(data)]['index']
    return new_index, new_index
def on_next(state):
    """Step to the entry that follows the current one in ``data``.

    Args:
        state: ``'index'`` value of the entry currently shown.

    Returns:
        A 2-tuple holding the new index twice, one value per output
        component this callback is wired to. Stepping forward from the last
        entry wraps around to the first one, mirroring how ``on_prev`` wraps
        from the first entry to the last. If ``state`` matches no entry, the
        first entry is selected.

    Raises:
        IndexError: If ``data`` is empty.
    """
    position = next(
        (i for i, item in enumerate(data) if item['index'] == state),
        None,
    )
    if position is None:
        # Unknown index (e.g. a typo submitted through the textbox): fall
        # back to a valid entry instead of indexing past the end of the list.
        new_index = data[0]['index']
    else:
        # Modulo wraps last -> first; a plain ``position + 1`` raised
        # IndexError on the final entry.
        new_index = data[(position + 1) % len(data)]['index']
    return new_index, new_index
# Assemble the Gradio UI: navigation controls, the instruction text, the input
# image and the gallery of model outputs.
# NOTE(review): the nesting of the layout containers below is a best-effort
# reading of this block -- confirm it against the rendered page.
with gr.Blocks() as demo:
    gr.Markdown("# Gallery of Generation Results on RISEBench")
    with gr.Row():
        with gr.Column(scale=2):
            with gr.Row():
                prev_button = gr.Button("PREV")
                next_button = gr.Button("NEXT")
            # Editable textbox through which a problem index can be submitted.
            problem_index = gr.Textbox(value='causal_reasoning_1', label='Problem Index', interactive=True, visible=True)
            # Hidden component holding the index currently displayed; it is the
            # input of every callable-valued component below.
            state = gr.Markdown(value='causal_reasoning_1', label='Current Problem Index', visible=False)

            def update_state(problem_index):
                """Return the submitted textbox value unchanged so it can be stored in ``state``."""
                return problem_index

            # Submitting the textbox copies its value into `state`.
            problem_index.submit(fn=update_state, inputs=[problem_index], outputs=[state])
            # PREV / NEXT compute the neighbouring index from `state` and write
            # it to both the hidden state and the visible textbox.
            prev_button.click(fn=on_prev, inputs=[state], outputs=[state, problem_index])
            next_button.click(fn=on_next, inputs=[state], outputs=[state, problem_index])
            # All models are selected initially.
            model_checkboxes = gr.CheckboxGroup(label="Select Models", choices=model_names, value=model_names)
        with gr.Column(scale=2):
            # `value` is a callable fed by `inputs`; presumably Gradio
            # re-evaluates it whenever `state` changes -- see the Gradio docs.
            instruction = gr.Textbox(label="Instruction", interactive=False, value=read_instruction, inputs=[state])
        with gr.Column(scale=1):
            # Only `state` is passed, so get_origin_image uses its default
            # model='original'.
            image = gr.Image(label="Input Image", value=get_origin_image, inputs=[state])
    # The gallery depends on both the current index and the selected models.
    gallery = gr.HTML(value=get_image_gallery, inputs=[state, model_checkboxes])

# Start the server only when executed as a script. 0.0.0.0 makes it listen on
# all network interfaces rather than only on localhost.
if __name__ == "__main__":
    demo.launch(server_name='0.0.0.0')