import gradio as gr
from FlagEmbedding import BGEM3FlagModel
import numpy as np
import json
import os
import re
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from scipy.special import softmax
import asyncio

# --- Configuration and Global Data Loading ---
# Determine the directory of the script so files can be loaded relative to it
script_dir = os.path.dirname(os.path.abspath(__file__))

# Original issue-level artifacts (kept for sparse/loose and strict search)
issue_embeddings_paths = {
    # The original dense (semantic) embeddings are still loaded if present,
    # but semantic search uses the component-level embeddings below. Optional.
    'semantic': os.path.join(script_dir, 'ns_issues_semantic_bge-m3.npy'),
    'loose': os.path.join(script_dir, 'ns_issues_loose_bge-m3.npy'),
}
issue_titles_path = os.path.join(script_dir, 'issue_titles.json')

# Component-level artifacts (used for semantic search only)
issue_components_paths = {
    'semantic': os.path.join(script_dir, 'ns_issue_components_semantic_bge-m3.npy'),
    # Component-level embeddings are dense-only; loose search stays issue-level.
}
issue_components_meta_path = os.path.join(script_dir, 'ns_issue_components_meta.json')
issue_titles_components_path = os.path.join(script_dir, 'issue_titles_components.json')

# GA resolution artifacts (unchanged)
ga_embeddings_paths = {
    'semantic': os.path.join(script_dir, 'ns_ga_resolutions_semantic_bge-m3.npy'),
    'loose': os.path.join(script_dir, 'ns_ga_resolutions_loose_bge-m3.npy'),
}
ga_resolutions_path = os.path.join(script_dir, 'parsed_ga_resolutions.json')

print("Loading BGE-M3 model...")
try:
    model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Please ensure you have an internet connection or the model is cached locally.")
    model = None  # Indicate model loading failed

# Issue data storage (issue-level and component-level)
issue_all_embeddings = {
    'semantic': None,  # optional legacy dense; not used for semantic queries in this app
    'loose': None,     # issue-level sparse, used for loose search
}
issue_titles = {}
all_issue_raw_texts = []  # For strict search (issue-level)
issue_components_embeddings = {
    'semantic': None,  # dense component-level embedding matrix
}
issue_components_meta = []  # list of dicts aligned to component rows
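# Each meta entry is assumed (from its usage in search_issues below) to look like:
# {"issue_index": 12, "component_type": "desc" or "option", "option_index": 0}
# (the values here are made up; only the keys are taken from the code)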
issue_titles_components = {}

print("Loading issue data...")
try:
    # Load issue-level embeddings (kept for sparse/loose and optional legacy dense)
    for embed_type, path in issue_embeddings_paths.items():
        if os.path.exists(path):
            if embed_type == 'loose':
                issue_all_embeddings[embed_type] = np.load(path, allow_pickle=True).tolist()
            else:
                issue_all_embeddings[embed_type] = np.load(path)
            loaded = issue_all_embeddings[embed_type]
            shape_or_len = loaded.shape if hasattr(loaded, 'shape') else len(loaded)
            print(f" Loaded {embed_type} issue embeddings from {path} (Shape/Len: {shape_or_len})")
        else:
            print(f" Warning: {embed_type} issue embeddings not found at {path}.")
            issue_all_embeddings[embed_type] = None

    # Load titles (issue-level)
    if os.path.exists(issue_titles_path):
        with open(issue_titles_path, encoding='utf-8') as file:
            issue_titles = json.load(file)
        print(f"Issue titles loaded: {len(issue_titles)} issues.")
    else:
        print(f" Warning: issue_titles.json not found at {issue_titles_path}")

    # Load raw issue texts for strict search
    issues_input_dir = os.path.join(script_dir, 'small_scripts', 'make_embedding',
                                    'NationStates-Issue-Megathread', '002 - Issue Megalist (MAIN)')
    issue_files_for_raw_load = []
    file_pattern = re.compile(r'(\d+) TO (\d+)\.txt')
    if os.path.isdir(issues_input_dir):
        for filename in os.listdir(issues_input_dir):
            if filename.endswith('.txt'):
                match = file_pattern.match(filename)
                if match:
                    start_num = int(match.group(1))
                    issue_files_for_raw_load.append((start_num, filename))
        issue_files_for_raw_load.sort(key=lambda x: x[0])
        issue_files_for_raw_load = [os.path.join(issues_input_dir, filename) for _, filename in issue_files_for_raw_load]
        for filepath in issue_files_for_raw_load:
            with open(filepath, 'r', encoding='utf-8') as file:
                issues_text_in_file = file.read()
            issues_list_in_file = [
                issue.strip() for issue in issues_text_in_file.split("[hr][/hr]") if issue.strip()
            ]
            all_issue_raw_texts.extend(issues_list_in_file)
        print(f" Loaded {len(all_issue_raw_texts)} raw issue texts for strict search.")
    else:
        print(f" Warning: Issue text directory '{issues_input_dir}' not found. Strict issue search will not work.")

    # Load component-level artifacts (semantic only)
    for embed_type, path in issue_components_paths.items():
        if os.path.exists(path):
            issue_components_embeddings[embed_type] = np.load(path)
            print(f" Loaded component {embed_type} embeddings from {path} (Shape: {issue_components_embeddings[embed_type].shape})")
        else:
            print(f" Warning: component {embed_type} embeddings not found at {path}.")
    if os.path.exists(issue_components_meta_path):
        with open(issue_components_meta_path, encoding='utf-8') as f:
            issue_components_meta = json.load(f)
        print(f" Loaded component meta: {len(issue_components_meta)} items.")
    else:
        print(f" Warning: component meta not found at {issue_components_meta_path}.")
    if os.path.exists(issue_titles_components_path):
        with open(issue_titles_components_path, encoding='utf-8') as f:
            issue_titles_components = json.load(f)
        print(f" Loaded component issue titles: {len(issue_titles_components)}")
    else:
        # Fall back to issue-level titles if component titles are not present
        issue_titles_components = issue_titles
except FileNotFoundError as e:
    print(f"Error loading issue data: {e}")
    print(f"Please ensure embedding files and '{os.path.basename(issue_titles_path)}' are in the same directory as app.py")
except Exception as e:
    print(f"Error loading issue data: {e}")

# GA resolution data storage (unchanged)
ga_all_embeddings = {
    'semantic': None,
    'loose': None,
}
ga_resolutions_data = []

print("Loading GA resolution data...")
try:
    if model:  # Only attempt to load embeddings if the model is available
        for embed_type, path in ga_embeddings_paths.items():
            if os.path.exists(path):
                if embed_type == 'loose':
                    ga_all_embeddings[embed_type] = np.load(path, allow_pickle=True).tolist()
                else:
                    ga_all_embeddings[embed_type] = np.load(path)
                loaded = ga_all_embeddings[embed_type]
                shape_or_len = loaded.shape if hasattr(loaded, 'shape') else len(loaded)
                print(f" Loaded {embed_type} GA embeddings from {path} (Shape/Len: {shape_or_len})")
            else:
                print(f" Warning: {embed_type} GA embeddings not found at {path}.")
                ga_all_embeddings[embed_type] = None
    if os.path.exists(ga_resolutions_path):
        with open(ga_resolutions_path, encoding='utf-8') as file:
            ga_resolutions_data = json.load(file)
        print(f"GA resolution data loaded: {len(ga_resolutions_data)} resolutions.")
    else:
        print(f" Warning: GA data file not found at {ga_resolutions_path}")
except FileNotFoundError as e:
    print(f"Error loading GA resolution data: {e}")
    print(f"Please ensure GA embedding files and '{os.path.basename(ga_resolutions_path)}' are in the same directory as app.py")
except Exception as e:
    print(f"Error loading GA resolution data: {e}")

# --- Search Utilities ---
def _extract_context(text: str, query: str):
    """Extracts the first line containing the query and highlights all mentions of it (case-insensitive)."""
    text_lines = text.split('\n')
    query_lower = query.lower()
    for line in text_lines:
        if query_lower in line.lower():
            highlighted_line = re.sub(re.escape(query), lambda m: f"**{m.group(0)}**", line, flags=re.IGNORECASE)
            return f'> {highlighted_line}'
    return ""

def embedding_compare(query: str, corpus: dict[str, str]) -> list[tuple[str, float]]:
    """Rank corpus entries against a query by dense similarity (passed to ns_ai_bot as its comparison function)."""
    query_embeddings = model.encode([query],
                                    return_dense=True,
                                    return_sparse=False,
                                    return_colbert_vecs=False)
    corpus_embeddings = model.encode(list(corpus.values()),
                                     return_dense=True,
                                     return_sparse=False,
                                     return_colbert_vecs=False)
    q = query_embeddings['dense_vecs']   # shape (1, d)
    c = corpus_embeddings['dense_vecs']  # shape (len(corpus), d)
    scores = (q @ c.T)[0]                # shape (len(corpus),)
    results = sorted(zip(corpus.keys(), scores.tolist()), key=lambda x: x[1], reverse=True)
    return results
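
# Minimal usage sketch (illustrative; assumes the BGE-M3 model loaded above).
# FlagEmbedding returns normalized dense vectors by default, so the dot product
# behaves as cosine similarity (assumption based on library defaults):
# embedding_compare("raise taxes", {"a": "increase taxation", "b": "ban all cars"})
# -> [("a", 0.72), ("b", 0.35)]  # scores made up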

# --- Issue Search (Component-level semantic, Issue-level loose/strict) ---
def search_issues(query: str, search_type: str = 'semantic', scope: str = 'both'):
    """
    Issue search dispatcher:
    - semantic: component-level dense with scope (descriptions | options | both).
    - loose: issue-level sparse (scope is ignored).
    - strict: issue-level exact substring match over raw texts (scope is ignored).
    """
    try:
        if not model:
            return "Model failed to load. Cannot perform search."
        if not query:
            return "Please enter a search term."

        # --- Semantic (component-level) ---
        if search_type == 'semantic':
            corpus = issue_components_embeddings.get('semantic')
            if corpus is None or not len(issue_components_meta):
                return "Component-level semantic embeddings or metadata not loaded. Cannot run semantic search."
            query_embeddings = model.encode([query],
                                            return_dense=True,
                                            return_sparse=False,  # only dense vectors are used here
                                            return_colbert_vecs=False)
            q = query_embeddings['dense_vecs']  # shape (1, d)
            scores = (q @ corpus.T)[0]          # shape (N_components,)
            indexed = list(enumerate(scores))

            # Scope filter
            def allow(meta):
                t = meta.get('component_type')
                if scope == 'descriptions':
                    return t == 'desc'
                elif scope == 'options':
                    return t == 'option'
                return True

            filtered = [(i, s) for i, s in indexed if allow(issue_components_meta[i])]
            filtered.sort(key=lambda x: x[1], reverse=True)
            out = [f"# Top 20 Issue Results (Semantic, scope={scope})"]
            if not filtered:
                out.append("No matches found.")
                return "\n".join(out)
            topk = filtered[:20]
            for rank, (idx, score) in enumerate(topk, start=1):
                meta = issue_components_meta[idx]
                issue_idx = meta['issue_index']
                ctype = meta['component_type']
                opt_idx = meta['option_index']
                title = issue_titles_components.get(str(issue_idx), f"Issue {issue_idx}")
                if ctype == 'desc':
                    label = f"{title} — Description"
                else:
                    label = f"{title} — Option {opt_idx}"
                out.append(f"{rank}. {label}, Similarity: {score:.4f}")
            return "\n".join(out)

        # --- Loose (issue-level sparse) ---
        elif search_type == 'loose':
            corpus_sparse = issue_all_embeddings.get('loose')
            if corpus_sparse is None:
                return "Issue-level sparse embeddings not loaded. Cannot run loose search."
            query_embeddings = model.encode([query],
                                            return_dense=False,  # only lexical weights are used here
                                            return_sparse=True,
                                            return_colbert_vecs=False)
            if 'lexical_weights' not in query_embeddings or not query_embeddings['lexical_weights']:
                return "Sparse query failed (no lexical weights)."
            q_sparse = query_embeddings['lexical_weights'][0]
            scores = [model.compute_lexical_matching_score(q_sparse, d) for d in corpus_sparse]
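            # compute_lexical_matching_score sums q_weight * d_weight over the
            # tokens shared by the query and each document (FlagEmbedding's
            # BGE-M3 sparse scoring), so unmatched tokens contribute nothing.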
            indexed = list(enumerate(scores))
            indexed.sort(key=lambda x: x[1], reverse=True)
            out = ["# Top 20 Issue Results (Loose)"]
            if not indexed:
                out.append("No matches found.")
                return "\n".join(out)
            for rank, (idx, score) in enumerate(indexed[:20], start=1):
                title = issue_titles.get(str(idx), f"Unknown Issue (Index {idx})")
                out.append(f"{rank}. {title}, Similarity: {score:.4f}")
            return "\n".join(out)

        # --- Strict (issue-level exact substring) ---
        elif search_type == 'strict':
            if not all_issue_raw_texts:
                return "Raw issue texts not loaded. Strict search is unavailable."
            strict_matches = []
            ql = query.lower()
            for i, issue_text in enumerate(all_issue_raw_texts):
                if ql in issue_text.lower():
                    strict_matches.append(i)
            out = ["# Top 20 Issue Search Results (Strict)"]
            if not strict_matches:
                out.append("No exact matches found.")
                return "\n".join(out)
            for rank, index in enumerate(strict_matches[:20], start=1):
                issue_title = issue_titles.get(str(index), f"Unknown Issue (Index {index})")
                context = _extract_context(all_issue_raw_texts[index], query)
                out.append(f"{rank}. {issue_title}\n{context}\n")
            return "\n".join(out)

        else:
            return f"Unsupported search type: {search_type}"
    except Exception as e:
        return f"An error occurred during issue search: {e}"

# --- GA Resolution Search (unchanged logic) ---
def _perform_search_ga(search_term: str, corpus_embeddings_dict: dict, search_type: str):
    if not model:
        raise ValueError("Model failed to load. Cannot perform search.")
    if not search_term:
        raise ValueError("Please enter a search term.")
    corpus_embeddings = corpus_embeddings_dict.get(search_type)
    if corpus_embeddings is None:
        raise ValueError(f"Corpus data for search type '{search_type}' not loaded. Cannot perform search.")
    query_embeddings = model.encode([search_term],
                                    return_dense=True,
                                    return_sparse=True,
                                    return_colbert_vecs=False)
    if search_type == 'semantic':
        query_vec = query_embeddings['dense_vecs']  # Shape: (1, embedding_dim)
        similarity_scores = (query_vec @ corpus_embeddings.T)[0]
    elif search_type == 'loose':
        if 'lexical_weights' not in query_embeddings or not query_embeddings['lexical_weights']:
            raise ValueError("Lexical weights (sparse) not returned for query. Model or configuration issue.")
        query_sparse_dict = query_embeddings['lexical_weights'][0]
        similarity_scores = np.array([
            model.compute_lexical_matching_score(query_sparse_dict, doc_sparse_dict)
            for doc_sparse_dict in corpus_embeddings
        ])
    else:
        raise ValueError(f"Unsupported embedding search type: {search_type}")
    indexed_similarities = list(enumerate(similarity_scores))
    sorted_similarities = sorted(indexed_similarities, key=lambda item: item[1], reverse=True)
    return sorted_similarities
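
# For reference, _perform_search_ga returns (corpus row index, score) pairs
# sorted by descending score, e.g. [(42, 0.81), (7, 0.63), ...] (scores made
# up); the indices align with rows of ga_resolutions_data.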

def search_ga_resolutions(search_term: str, hide_repealed: bool, hide_repeal_category: bool,
                          search_type: str = 'semantic'):
    try:
        if not search_term:
            return "Please enter a search term."

        if search_type == 'strict':
            if not ga_resolutions_data:
                return "GA resolution data not loaded. Strict search is unavailable."
            strict_matches = []
            ql = search_term.lower()
            for i, resolution in enumerate(ga_resolutions_data):
                body = resolution.get('body', '')
                if ql in body.lower():
                    status = resolution.get('status')
                    category = resolution.get('category')
                    if hide_repealed and status == "Repealed":
                        continue
                    if hide_repeal_category and category == "Repeal":
                        continue
                    strict_matches.append(i)
            out = ["# Top 20 GA Resolution Search Results (Strict)"]
            if not strict_matches:
                status_msgs = []
                if hide_repealed:
                    status_msgs.append("Repealed")
                if hide_repeal_category:
                    status_msgs.append("Repeal Category")
                filter_msg = " (Filtered out " + " and ".join(status_msgs) + ")" if status_msgs else ""
                return "\n".join(out + [f"No exact matches found{filter_msg}."])
            for rank, index in enumerate(strict_matches[:20], start=1):
                resolution = ga_resolutions_data[index]
                title = resolution.get('title', 'Untitled Resolution')
                res_id = resolution.get('id', 'N/A')
                council = resolution.get('council', 1)
                status = resolution.get('status')
                status_marker = "[REPEALED] " if status == "Repealed" else ""
                url = f"https://www.nationstates.net/page=WA_past_resolution/id={res_id}/council={council}"
                context = _extract_context(resolution.get('body', ''), search_term)
                out.append(f"{rank}. {status_marker}[#{res_id} {title}]({url}), Match: 1.0000\n{context}\n")
            return "\n".join(out)

        # Embedding-based GA search
        raw_sorted = _perform_search_ga(search_term, ga_all_embeddings, search_type)

        # Filter by status/category
        filtered = []
        for index, score in raw_sorted:
            if index >= len(ga_resolutions_data):
                continue
            resolution = ga_resolutions_data[index]
            status = resolution.get('status')
            category = resolution.get('category')
            if hide_repealed and status == "Repealed":
                continue
            if hide_repeal_category and category == "Repeal":
                continue
            filtered.append((index, score))

        out = [f"# Top 20 GA Resolution Search Results ({search_type.capitalize()})"]
        if not filtered:
            status_msgs = []
            if hide_repealed:
                status_msgs.append("Repealed")
            if hide_repeal_category:
                status_msgs.append("Repeal Category")
            filter_msg = " (Filtered out " + " and ".join(status_msgs) + ")" if status_msgs else ""
            return "\n".join(out + [f"No matching resolutions found{filter_msg}."])
        for rank, (index, score) in enumerate(filtered[:20], start=1):
            resolution = ga_resolutions_data[index]
            title = resolution.get('title', 'Untitled Resolution')
            res_id = resolution.get('id', 'N/A')
            council = resolution.get('council', 1)
            status = resolution.get('status')
            status_marker = "[REPEALED] " if status == "Repealed" else ""
            url = f"https://www.nationstates.net/page=WA_past_resolution/id={res_id}/council={council}"
            out.append(f"{rank}. {status_marker}[#{res_id} {title}]({url}), Similarity: {score:.4f}")
        return "\n".join(out)
    except Exception as e:
        return f"An error occurred during GA resolution search: {e}"

# --- Sentiment Analysis Model and Functions ---
print("Loading sentiment analysis model...")
try:
    SENTIMENT_MODEL_ID = "cardiffnlp/twitter-roberta-base-sentiment-latest"
    sentiment_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL_ID)
    sentiment_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL_ID)
    print("Sentiment analysis model loaded successfully.")
except Exception as e:
    print(f"Error loading sentiment analysis model: {e}")
    sentiment_model = None


def sentiment_analysis_func(text: str):
    """Return {label: probability} scores for the input text, or an error string."""
    if not sentiment_model:
        return "Sentiment model not loaded."
    try:
        # Truncate to the model's maximum length so long inputs do not crash RoBERTa
        encoded_input = sentiment_tokenizer(text, return_tensors='pt', truncation=True)
        output = sentiment_model(**encoded_input)
        scores = output.logits[0].detach().numpy()
        scores = softmax(scores)
        labels = sentiment_model.config.id2label
        return {labels[i]: round(float(scores[i]), 4) for i in range(len(scores))}
    except Exception as e:
        return f"An error occurred during sentiment analysis: {e}"

# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("""
    # NationStates Semantic Search
    Search NationStates issues and GA resolutions. Choose semantic for conceptual similarity, loose for keyword matching, and strict for exact substring queries.
    For semantic search, you can restrict the query to descriptions only, options only, or both. For finding duplicate topics, description-only search is recommended.
    Please check the text of issue search results when determining whether your idea is a duplicate.
    """)
    with gr.Tabs() as tabs:
        # Issue Search Tab
        with gr.TabItem("Issue Search"):
            gr.Markdown("""
            ### Search NationStates Issues
            """)
            issue_search_interface = gr.Interface(
                fn=search_issues,
                inputs=[
                    gr.Textbox(label="Search term", placeholder="What issue or option are you looking for?"),
                    gr.Radio(["semantic", "loose", "strict"], label="Search Type", value="semantic",
                             info="semantic: meaning-based; loose: keyword; strict: exact substring"),
                    gr.Radio(["both", "descriptions", "options"], label="Scope (semantic only)", value="both",
                             info="Only applies to semantic search; ignored for loose and strict.")
                ],
                outputs=gr.Markdown(),
                examples=[
                    ["coffee", "semantic", "both"],
                    ["land value tax", "semantic", "descriptions"],
                    ["chainsaw maniacs", "semantic", "options"],
                    ["Elon Musk", "loose", "both"],
                    ["environmental protection", "strict", "both"]
                ],
                title=None,
                description=None,
                submit_btn="Search Issues",
                article="Made by [Jiangbei](https://www.nationstates.net/nation=jiangbei). Issue data from Valentine Z. Powered by BAAI/bge-m3."
            )

        # GA Resolution Search Tab
        with gr.TabItem("GA Resolution Search"):
            gr.Markdown("""
            ### Search NationStates General Assembly Resolutions
            Use semantic for concepts, loose for keyword matching, or strict for exact substring.
            """)
            ga_search_term_input = gr.Textbox(label="Search term", placeholder="What are you looking for?")
            ga_hide_repealed_checkbox = gr.Checkbox(value=True, label="Hide repealed resolutions")
            ga_hide_repeal_category_checkbox = gr.Checkbox(value=True, label="Hide repeals")
            ga_search_type_radio = gr.Radio(["semantic", "loose", "strict"], label="Search Type", value="semantic",
                                            info="semantic: conceptual similarity; loose: keyword matching; strict: exact substring")
            ga_search_interface = gr.Interface(
                fn=search_ga_resolutions,
                inputs=[
                    ga_search_term_input,
                    ga_hide_repealed_checkbox,
                    ga_hide_repeal_category_checkbox,
                    ga_search_type_radio
                ],
                outputs=gr.Markdown(),
                examples=[
                    ["condemn genocide", True, True, "semantic"],
                    ["rights of animals", True, True, "loose"],
                    ["regulating space mining", True, True, "semantic"],
                    ["founding of the World Assembly", True, True, "semantic"],
                    ["environmental protection", True, True, "semantic"],
                    ["human rights", True, True, "strict"],
                    ["World Assembly", True, True, "strict"]
                ],
                title=None,
                description=None,
                submit_btn="Search Resolutions",
                article="Made by [Jiangbei](https://www.nationstates.net/nation=jiangbei). GA data parsed from NationStates. Powered by BAAI/bge-m3."
            )
    gr.api(sentiment_analysis_func, api_name="sentiment")

from nationstates_ai import ns_ai_bot
import threading

USER_AGENT = os.environ.get("USER_AGENT")
# Required configuration comes from environment variables. Do not print these:
# AI_NATIONSTATES_PASSWORD in particular must never be leaked to logs.
AI_NATIONS = json.loads(os.environ["AI_NATIONS"])
AI_NATIONSTATES_PASSWORD = os.environ["AI_NATIONSTATES_PASSWORD"]
AI_PROMPTS = json.loads(os.environ["AI_PROMPTS"])

def get_ai_coroutines(user_agent, compare_func, ns_password, nations, prompts):
    """Build one ns_ai_bot coroutine per nation; prompts[i] pairs with nations[i]."""
    ns_ai_coroutines = []
    for index, nation in enumerate(nations):
        ns_ai_coroutines.append(
            ns_ai_bot(
                nation,
                ns_password,
                compare_func,
                prompts[index],
                user_agent,
                index * 5,  # staggered offset per bot (presumably a start delay in seconds)
            ))
    return ns_ai_coroutines


def run_ai_coroutines():
    print("Starting NationStates AI...")
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    ai_coroutines = get_ai_coroutines(USER_AGENT, embedding_compare, AI_NATIONSTATES_PASSWORD, AI_NATIONS, AI_PROMPTS)
    loop.run_until_complete(asyncio.gather(*ai_coroutines))
    print("NationStates AI coroutines exited; they are expected to run indefinitely, so this indicates an error.")
    loop.close()

# --- Launch App ---
if __name__ == "__main__":
    # Run the NationStates AI bots on a background thread so Gradio can serve requests
    thread = threading.Thread(target=run_ai_coroutines)
    thread.start()
    demo.launch()