Spaces:
Running
Running
| from datetime import datetime, timedelta | |
| import json | |
| import requests | |
| import streamlit as st | |
| from any_agent import AgentFramework | |
| from any_agent.tracing.trace import _is_tracing_supported | |
| from any_agent.evaluation import EvaluationCase | |
| from any_agent.evaluation.schemas import CheckpointCriteria | |
| import pandas as pd | |
| from constants import DEFAULT_EVALUATION_CASE, MODEL_OPTIONS | |
| import copy | |
| from pydantic import BaseModel, ConfigDict | |
class UserInputs(BaseModel):
    """Validated bundle of the values collected by the input form.

    Fields that are not declared here are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # Model chosen for the agent (one of MODEL_OPTIONS).
    model_id: str
    # Free-text location typed by the user.
    location: str
    # Upper bound on driving time, in hours.
    max_driving_hours: int
    # Selected calendar date combined with the selected time of day.
    date: datetime
    # Agent framework picked from the tracing-capable frameworks.
    framework: str
    # Evaluation case carrying the judge model and the edited checkpoints.
    evaluation_case: EvaluationCase
    # Whether the "Run Evaluation" checkbox is ticked.
    run_evaluation: bool
def get_area(area_name: str) -> list[dict]:
    """Search Nominatim for places matching a free-text name.

    Uses the [Nominatim API](https://nominatim.org/release-docs/develop/api/Search/).

    Args:
        area_name (str): The name of the area.

    Returns:
        list[dict]: The decoded JSON body of the search response. The search
            endpoint answers with an array of matches, so an empty list
            means nothing was found.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems, a timeout, or a
            body that is not valid JSON.
    """
    response = requests.get(
        "https://nominatim.openstreetmap.org/search",
        # Hand the query to requests so it is URL-encoded; interpolating the
        # raw user text into the URL breaks on "&", "#", "+" and similar.
        params={"q": area_name, "format": "jsonv2"},
        headers={"User-Agent": "Mozilla/5.0"},
        timeout=5,
    )
    response.raise_for_status()
    return response.json()
# Limits applied to the user-editable evaluation checkpoints in this demo.
_MAX_CHECKPOINTS = 20
_MAX_CRITERIA_WORDS = 100


def _validate_location(location: str) -> None:
    """Report in the UI when ``location`` cannot be resolved via get_area."""
    try:
        location_check = get_area(location)
    except (requests.RequestException, ValueError) as e:
        # A network failure, error status or undecodable body should not
        # take the whole form down; surface it and let the user carry on.
        st.error(f"Could not validate location: {e}")
        return
    if not location_check:
        st.error("β Invalid location")


def _build_checkpoints(checkpoints_df: pd.DataFrame) -> list:
    """Turn the edited checkpoint table back into CheckpointCriteria objects."""
    # don't let a user add more than the demo's checkpoint limit
    if len(checkpoints_df) > _MAX_CHECKPOINTS:
        st.error(
            f"You can only add up to {_MAX_CHECKPOINTS} checkpoints"
            " for the purpose of this demo."
        )
        checkpoints_df = checkpoints_df[:_MAX_CHECKPOINTS]

    new_ckpts = []
    for _, row in checkpoints_df.iterrows():
        criteria = row["criteria"]
        # Skip blank rows. Rows added in the editor but not filled in yet can
        # hold a missing (non-string) value here, which would otherwise fail
        # on .split() with a confusing message.
        if not isinstance(criteria, str) or criteria == "":
            continue
        # Don't let people write essays for criteria in this demo
        if len(criteria.split(" ")) > _MAX_CRITERIA_WORDS:
            st.error("Error creating checkpoint: Criteria is too long")
            continue
        try:
            new_ckpts.append(
                CheckpointCriteria(criteria=criteria, points=row["points"])
            )
        except Exception as e:
            # UI boundary: report whatever the model rejected (for example a
            # missing points value) and keep processing the remaining rows.
            st.error(f"Error creating checkpoint: {e}")
    return new_ckpts


def get_user_inputs() -> UserInputs:
    """Render the input form and return the collected values.

    Widgets are created in this order: location, maximum driving hours,
    date and time, agent framework, agent model, the "Custom Evaluation"
    expander (judge model plus editable checkpoints), and finally the
    "Run Evaluation" checkbox.

    Returns:
        UserInputs: The validated selections. Problems with the location or
            with individual checkpoints are shown as errors in the UI rather
            than raised.
    """
    default_val = "Los Angeles California, US"
    location = st.text_input("Enter a location", value=default_val)
    if location:
        _validate_location(location)

    max_driving_hours = st.number_input(
        "Enter the maximum driving hours", min_value=1, value=2
    )

    col_date, col_time = st.columns([2, 1])
    with col_date:
        selected_date = st.date_input(
            "Select a date in the future", value=datetime.now() + timedelta(days=1)
        )
    with col_time:
        # One option per full hour; index 9 makes 09:00 the default.
        selected_time = st.selectbox(
            "Select a time",
            [datetime.strptime(f"{i:02d}:00", "%H:%M").time() for i in range(24)],
            index=9,
        )
    date = datetime.combine(selected_date, selected_time)

    supported_frameworks = [
        framework for framework in AgentFramework if _is_tracing_supported(framework)
    ]
    # NOTE(review): the fixed default indices below assume the option lists
    # have at least that many entries -- confirm against constants.
    framework = st.selectbox(
        "Select the agent framework to use",
        supported_frameworks,
        index=2,
        format_func=lambda x: x.name,
    )
    model_id = st.selectbox(
        "Select the model to use",
        MODEL_OPTIONS,
        index=1,
        format_func=lambda x: "/".join(x.split("/")[-3:]),
    )

    # Add evaluation case section
    with st.expander("Custom Evaluation"):
        evaluation_model_id = st.selectbox(
            "Select the model to use for LLM-as-a-Judge evaluation",
            MODEL_OPTIONS,
            index=2,
            format_func=lambda x: "/".join(x.split("/")[-3:]),
        )
        # Work on a copy so the shared default case is never mutated.
        evaluation_case = copy.deepcopy(DEFAULT_EVALUATION_CASE)
        evaluation_case.llm_judge = evaluation_model_id

        # Show the checkpoints as an editable table, one row per checkpoint.
        checkpoints_df = pd.DataFrame(
            [checkpoint.model_dump() for checkpoint in evaluation_case.checkpoints]
        )
        checkpoints_df = st.data_editor(
            checkpoints_df,
            column_config={
                "points": st.column_config.NumberColumn(label="Points"),
                "criteria": st.column_config.TextColumn(label="Criteria"),
            },
            hide_index=True,
            num_rows="dynamic",
        )
        evaluation_case.checkpoints = _build_checkpoints(checkpoints_df)

    run_evaluation = st.checkbox("Run Evaluation", value=True)
    return UserInputs(
        model_id=model_id,
        location=location,
        max_driving_hours=max_driving_hours,
        date=date,
        framework=framework,
        evaluation_case=evaluation_case,
        run_evaluation=run_evaluation,
    )