Spaces:
Sleeping
Sleeping
| import os | |
# Restrict visible GPUs before faiss / text2vec are imported further down, so
# the restriction is already in place if those libraries initialise CUDA.
# setdefault lets the caller override the choice from the environment
# (e.g. `CUDA_VISIBLE_DEVICES=0 streamlit run ...`); GPU 3 stays the default.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '3')
| import os | |
| import sys | |
| import faiss | |
| import numpy as np | |
| import streamlit as st | |
| from text2vec import SentenceModel | |
| # 请确保 JSONLIndexer 在 src 目录下或者已正确安装 | |
| from src.jsonl_Indexer import JSONLIndexer | |
def get_cli_args(argv=None):
    """Parse ``key=value`` configuration overrides from the command line.

    Args:
        argv: Argument list to parse. Defaults to ``sys.argv``. Its first
            element (the script path) is always skipped.

    Returns:
        Dict mapping stripped keys to stripped string values. Tokens that
        contain no ``=`` are ignored, as are tokens whose key is empty. Only
        the first ``=`` splits, so values may themselves contain ``=``. When
        a key repeats, the last occurrence wins.
    """
    if argv is None:
        argv = sys.argv
    args = {}
    # Only argv[0] (the script path) is skipped. `streamlit run app.py -- k=v`
    # hands the script sys.argv == [script_path, 'k=v', ...], so skipping
    # argv[1] as well would silently drop the first override. Any stray
    # non-"key=value" tokens are filtered out below anyway.
    for arg in argv[1:]:
        key, sep, value = arg.partition('=')
        key = key.strip()
        if not sep or not key:
            continue
        args[key] = value.strip()
    return args
# Collect "key=value" overrides supplied on the command line (see get_cli_args).
cli_args = get_cli_args()

# Built-in defaults, set up for a JSONL file of precomputed embeddings.
DEFAULT_CONFIG = {
    # Model path handed to SentenceModel.
    'model_path': 'BAAI/bge-base-en-v1.5',
    # JSONL file loaded into the retriever.
    'dataset_path': 'src/tool-embedding.jsonl',
    # Dimension passed to JSONLIndexer; presumably must match the stored
    # embeddings and the model output — confirm when changing the model.
    'vector_size': 768,
    # JSON key that holds each record's precomputed embedding.
    'embedding_field': 'embedding',
    # JSON key whose value search results report as the record ID.
    'id_field': 'id',
}

# Command-line values take precedence over the defaults. Unknown keys are
# kept, and every CLI value arrives as a string.
config = {**DEFAULT_CONFIG, **cli_args}
# vector_size may have come from the CLI as a string; the indexer needs an int.
config['vector_size'] = int(config['vector_size'])
@st.cache_resource
def get_model(model_path: str = config['model_path']):
    """Load and return the ``SentenceModel`` found at *model_path*.

    Streamlit re-executes the whole script on every widget interaction.
    Without ``st.cache_resource`` the model would be reloaded on each rerun.
    With it, one instance per distinct ``model_path`` is kept and shared
    across reruns and sessions.

    Args:
        model_path: Model name or path passed to ``SentenceModel``. Defaults
            to the merged configuration value (evaluated at definition time).

    Returns:
        The loaded ``SentenceModel``.
    """
    return SentenceModel(model_path)
@st.cache_resource
def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
    """Build a ``JSONLIndexer`` and load the records from *dataset_path* into it.

    Cached with ``st.cache_resource`` so the JSONL file is not reloaded and
    re-indexed on every Streamlit rerun. The leading underscore on ``_model``
    tells Streamlit to leave that argument out of the cache key. The cache is
    therefore keyed only on the four configuration values.

    Args:
        vector_sz: Dimension passed to ``JSONLIndexer`` (presumably must match
            the stored embeddings — confirm).
        dataset_path: JSONL file to load.
        embedding_field: JSON key holding each record's precomputed embedding.
        id_field: JSON key whose value searches return as the record ID.
        _model: Encoder handed to ``JSONLIndexer`` (presumably used to embed
            queries — confirm against JSONLIndexer).

    Returns:
        The populated ``JSONLIndexer``. It is shared across sessions, so
        callers should treat it as read-only.
    """
    retriever = JSONLIndexer(vector_sz=vector_sz, model=_model)
    retriever.load_jsonl(dataset_path, embedding_field=embedding_field, id_field=id_field)
    return retriever
# Optional sidebar panel that echoes the effective (merged) configuration.
if st.sidebar.checkbox("Show Configuration"):
    st.sidebar.write("Current Configuration:")
    for name in config:
        st.sidebar.write(f"{name}: {config[name]}")
# Build the embedding model, then the JSONL-backed retriever that uses it.
model = get_model(config['model_path'])
retriever = create_retriever(
    vector_sz=config['vector_size'],
    dataset_path=config['dataset_path'],
    embedding_field=config['embedding_field'],
    id_field=config['id_field'],
    _model=model,
)
# Page header and the two user inputs: the query text and how many results to show.
st.title("JSONL Data Retrieval Visualization")
st.write("该应用基于预计算的 JSONL 文件 embedding,输入查询后将检索相似记录。")

query = st.text_input(label="Enter a search query:")
top_k = st.slider(
    label="Select number of results to display",
    min_value=1,
    max_value=100,
    value=5,
)
import html  # used to escape record IDs before embedding them in raw HTML

# Run the search when the button is pressed. An empty query does nothing.
if st.button("Search") and query:
    # JSONLIndexer exposes search_return_id. It returns the values of the
    # configured id field together with their scores.
    rec_ids, scores = retriever.search_return_id(query, top_k)

    st.write("### Results:")
    with st.expander("Retrieval Results (click to expand)"):
        for rank, (rec_id, score) in enumerate(zip(rec_ids, scores), start=1):
            # rec_id comes from the dataset and is rendered with
            # unsafe_allow_html=True. Escape it so an ID containing "<" or "&"
            # is shown literally instead of being parsed as (or injecting) HTML.
            safe_id = html.escape(str(rec_id))
            st.markdown(
                f"""
                <div style="border:1px solid #ccc; padding:10px; border-radius:5px; margin-bottom:10px; background-color:#f9f9f9;">
                <p><b>Record {rank} ID:</b> {safe_id}</p>
                <p><b>Score:</b> {score:.4f}</p>
                </div>
                """,
                unsafe_allow_html=True,
            )