"""Streamlit app that visualizes retrieval over a JSONL file of precomputed embeddings.

Launch with ``streamlit run <this_script> -- key=value ...``; every
``key=value`` pair overrides the matching entry of ``DEFAULT_CONFIG``.
"""
import os

# Restrict the process to GPU 3. Set before the GPU-capable libraries below are
# imported so the restriction is in place before CUDA initializes. Note this
# overwrites any CUDA_VISIBLE_DEVICES value exported by the caller.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import html
import sys

import faiss  # noqa: F401  (not referenced directly in this script)
import numpy as np  # noqa: F401  (not referenced directly in this script)
import streamlit as st
from text2vec import SentenceModel

# Make sure JSONLIndexer lives under src/ or is otherwise importable.
from src.jsonl_Indexer import JSONLIndexer


def get_cli_args():
    """Parse ``key=value`` pairs passed to the script on the command line.

    ``streamlit run app.py -- a=1 b=2`` hands the script
    ``sys.argv == ['app.py', 'a=1', 'b=2']``, so only the script path
    (``argv[0]``) is skipped. Arguments without ``=`` are ignored.

    Returns:
        dict mapping each stripped key to its stripped string value.
    """
    args = {}
    for arg in sys.argv[1:]:
        if '=' in arg:
            key, value = arg.split('=', 1)
            args[key.strip()] = value.strip()
    return args


# Command-line overrides.
cli_args = get_cli_args()

# Defaults for a JSONL dataset of precomputed embeddings.
DEFAULT_CONFIG = {
    'model_path': 'BAAI/bge-base-en-v1.5',
    'dataset_path': 'src/tool-embedding.jsonl',  # path to the JSONL file
    'vector_size': 768,
    'embedding_field': 'embedding',  # JSON key holding each record's embedding
    'id_field': 'id',  # JSON key whose value search_return_id reports back
}

# Command-line values take precedence over the defaults.
config = {**DEFAULT_CONFIG, **cli_args}

# CLI values arrive as strings; fail with a readable message instead of a traceback.
try:
    config['vector_size'] = int(config['vector_size'])
except ValueError:
    st.error(f"vector_size must be an integer, got {config['vector_size']!r}")
    st.stop()


@st.cache_resource
def get_model(model_path: str = config['model_path']):
    """Load the SentenceModel at ``model_path``, cached by ``st.cache_resource``."""
    model = SentenceModel(model_path)
    return model


@st.cache_resource
def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
    """Build a JSONLIndexer and load the records from ``dataset_path``.

    The leading underscore on ``_model`` tells ``st.cache_resource`` not to
    hash that argument; the cache key is formed from the remaining parameters.
    """
    retriever = JSONLIndexer(vector_sz=vector_sz, model=_model)
    retriever.load_jsonl(dataset_path, embedding_field=embedding_field, id_field=id_field)
    return retriever


# Optionally show the effective configuration in the sidebar.
if st.sidebar.checkbox("Show Configuration"):
    st.sidebar.write("Current Configuration:")
    for key, value in config.items():
        st.sidebar.write(f"{key}: {value}")

# Initialize the (cached) model and retriever.
model = get_model(config['model_path'])
retriever = create_retriever(
    config['vector_size'],
    config['dataset_path'],
    config['embedding_field'],
    config['id_field'],
    _model=model,
)

# Streamlit UI.
st.title("JSONL Data Retrieval Visualization")
st.write("该应用基于预计算的 JSONL 文件 embedding,输入查询后将检索相似记录。")

# Query input.
query = st.text_input("Enter a search query:")
top_k = st.slider("Select number of results to display", min_value=1, max_value=100, value=5)

# Run the search and render the results; whitespace-only queries are ignored.
if st.button("Search") and query.strip():
    # JSONLIndexer exposes search_return_id, which returns the values of the
    # JSON id field together with their scores.
    rec_ids, scores = retriever.search_return_id(query, top_k)
    st.write("### Results:")
    with st.expander("Retrieval Results (click to expand)"):
        for rank, (rec_id, score) in enumerate(zip(rec_ids, scores), start=1):
            # Rendered with unsafe_allow_html=True, so escape the dataset-supplied
            # id to keep it from injecting markup (escaped text displays the same).
            safe_id = html.escape(str(rec_id))
            st.markdown(
                "\n\n"
                f"Record {rank} ID: {safe_id}\n"
                "\n"
                f"Score: {score:.4f}\n"
                "\n",
                unsafe_allow_html=True,
            )