Spaces:
Running
Running
| # data_viewer.py | |
| import base64 | |
| import json | |
| from functools import lru_cache | |
| from io import BytesIO | |
| import gradio as gr | |
| from datasets import load_dataset | |
| from PIL import Image | |
# When True, the rows holding judge, query source, models, rationale and
# ground truth are created with visible=False (see create_data_viewer).
IGNORE_DETAILS = True
# Dataset identifier handed to load_dataset() when fetching samples.
DATASET_NAME = "MMInstruction/VRewardBench"
@lru_cache(maxsize=8)
def load_cached_dataset(dataset_name, split):
    """Load a dataset split, memoizing the result per ``(dataset_name, split)``.

    Every change of the split or sample index in the viewer calls this
    function, so the loaded split is kept in a small in-process cache
    instead of calling ``load_dataset`` again each time.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``.
        split: Split name passed to ``load_dataset`` as ``split=``.

    Returns:
        Whatever ``load_dataset(dataset_name, split=split)`` returns; repeated
        calls with the same arguments return the same cached object.
    """
    return load_dataset(dataset_name, split=split)
def base64_to_image(base64_string):
    """Decode a base64 payload and open the resulting bytes as a PIL image.

    Args:
        base64_string: Base64-encoded image data accepted by
            ``base64.b64decode``.

    Returns:
        The object returned by ``Image.open`` for the decoded bytes.
    """
    raw_bytes = base64.b64decode(base64_string)
    buffer = BytesIO(raw_bytes)
    return Image.open(buffer)
def get_responses(responses, rankings):
    """Pick the chosen and rejected responses according to their rankings.

    Args:
        responses: Sequence of responses, or a JSON string encoding one.
        rankings: Sequence of ranks aligned with ``responses``, or a JSON
            string encoding one.

    Returns:
        A ``(chosen, rejected)`` tuple: the first response whose rank equals
        0 and the first whose rank equals 1. A placeholder string is used
        for either slot when no response carries that rank.
    """
    parsed_responses = json.loads(responses) if isinstance(responses, str) else responses
    parsed_rankings = json.loads(rankings) if isinstance(rankings, str) else rankings

    def first_with_rank(wanted, fallback):
        # Scan the pairs in order and stop at the first matching rank.
        for candidate, rank in zip(parsed_responses, parsed_rankings):
            if rank == wanted:
                return candidate
        return fallback

    chosen = first_with_rank(0, "No chosen response")
    rejected = first_with_rank(1, "No rejected response")
    return chosen, rejected
def load_and_display_sample(split, idx):
    """Fetch one dataset sample and flatten it into the values the viewer shows.

    Args:
        split: Split of ``DATASET_NAME`` to read from.
        idx: Requested sample position. It is converted with ``int`` and
            clamped into ``[0, len(dataset) - 1]``. ``None`` (for example an
            emptied index field) is treated as ``0``.

    Returns:
        An 11-tuple ``(image, sample_id, chosen_response, rejected_response,
        judge, query_source, query, models_json, rationale, ground_truth,
        total_samples)``. ``models_json`` is the sample's ``models`` field
        serialized as indented JSON text and ``total_samples`` is a
        human-readable count string.

    Raises:
        gr.Error: If loading the dataset or reading the sample fails, or if
            the split contains no samples. The original exception is chained
            as ``__cause__``.
    """
    try:
        dataset = load_cached_dataset(DATASET_NAME, split)
        sample_count = len(dataset)
        if sample_count == 0:
            # Clamping would otherwise produce index -1 and an opaque IndexError.
            raise ValueError(f"split {split!r} contains no samples")

        # Fall back to the first sample when no index is supplied, then clamp.
        requested = 0 if idx is None else int(idx)
        position = min(max(0, requested), sample_count - 1)
        sample = dataset[position]

        # Get responses
        chosen_response, rejected_response = get_responses(sample["response"], sample["human_ranking"])

        # The models field may be stored either as JSON text or as parsed data.
        models = sample["models"]
        if isinstance(models, str):
            models = json.loads(models)

        return (
            sample["image"],  # image
            sample["id"],  # sample_id
            chosen_response,  # chosen_response
            rejected_response,  # rejected_response
            sample["judge"],  # judge
            sample["query_source"],  # query_source
            sample["query"],  # query
            json.dumps(models, indent=2),  # models_json
            sample["rationale"],  # rationale
            sample["ground_truth"],  # ground_truth
            f"Total samples: {sample_count}",  # total_samples
        )
    except Exception as e:
        # Surface the failure in the UI while keeping the root cause chained.
        raise gr.Error(f"Error loading dataset: {e}") from e
def create_data_viewer():
    """Create the sample-viewer components and wire them to the sample loader.

    The first sample of the "test" split is loaded up front so that every
    component is created with an initial value. Changing either the split
    selector or the sample index re-runs ``load_and_display_sample`` and
    refreshes all display components.

    Returns:
        The ``(dataset_split, sample_idx)`` input components.
    """
    default_split = "test"
    default_idx = 0

    # Pre-fetch the initial sample; the order matches the loader's return tuple.
    (
        first_image,
        first_sample_id,
        first_chosen,
        first_rejected,
        first_judge,
        first_query_source,
        first_query,
        first_models_json,
        first_rationale,
        first_ground_truth,
        first_total,
    ) = load_and_display_sample(default_split, default_idx)

    # Detail rows are only shown when IGNORE_DETAILS is switched off.
    show_details = not IGNORE_DETAILS

    with gr.Column():
        with gr.Row():
            dataset_split = gr.Radio(choices=["test"], value=default_split, label="Dataset Split")
            sample_idx = gr.Number(label="Sample Index", value=default_idx, minimum=0, step=1, interactive=True)
            total_samples = gr.Textbox(label="Total Samples", value=first_total, interactive=False)

        with gr.Row():
            with gr.Column():
                image = gr.Image(label="Sample Image", type="pil", value=first_image)
                query = gr.Textbox(label="Query", value=first_query, interactive=False)
            with gr.Column():
                sample_id = gr.Textbox(label="Sample ID", value=first_sample_id, interactive=False)
                chosen_response = gr.TextArea(
                    label="Chosen Response ✅",
                    value=first_chosen,
                    interactive=False,
                )
                rejected_response = gr.TextArea(
                    label="Rejected Response ❌",
                    value=first_rejected,
                    interactive=False,
                )

        with gr.Row(visible=show_details):
            judge = gr.Textbox(label="Judge", value=first_judge, interactive=False)
            query_source = gr.Textbox(label="Query Source", value=first_query_source, interactive=False)

        with gr.Row(visible=show_details):
            with gr.Column():
                # The loader returns JSON text; the JSON component gets parsed data.
                models_json = gr.JSON(label="Models", value=json.loads(first_models_json))
                rationale = gr.TextArea(label="Rationale", value=first_rationale, interactive=False)
            with gr.Column():
                ground_truth = gr.TextArea(label="Ground Truth", value=first_ground_truth, interactive=False)

    # Components refreshed on every change, in the loader's return order.
    refreshed = (
        image,
        sample_id,
        chosen_response,
        rejected_response,
        judge,
        query_source,
        query,
        models_json,
        rationale,
        ground_truth,
        total_samples,
    )

    # Auto-update when any input changes
    for trigger in (dataset_split, sample_idx):
        trigger.change(
            fn=load_and_display_sample,
            inputs=[dataset_split, sample_idx],
            outputs=list(refreshed),
        )

    return dataset_split, sample_idx