from pysolr import Solr
import os
import csv
from sentence_transformers import SentenceTransformer, util
import torch
from get_keywords import get_keywords
| """ | |
| This function creates top 15 articles from Solr and saves them in a csv file | |
| Input: | |
| query: str | |
| num_articles: int | |
| keyword_type: str (openai, rake, or na) | |
| Output: path to csv file | |
| """ | |
| def save_solr_articles_full(query: str, num_articles=15, keyword_type="openai") -> str: | |
| keywords = get_keywords(query, keyword_type) | |
| if keyword_type == "na": | |
| keywords = query | |
| return save_solr_articles(keywords, num_articles) | |
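
# Usage sketch (assumes the SOLR_KEY environment variable is set and the Solr endpoint below is
# reachable; the query string here is only an example):
#     csv_path = save_solr_articles_full("rising measles cases in europe", num_articles=10, keyword_type="rake")
#     # csv_path -> "data/articles.csv"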
| """ | |
| Removes spaces and newlines from text | |
| Input: text: str | |
| Output: text: str | |
| """ | |
| def remove_spaces_newlines(text: str) -> str: | |
| text = text.replace('\n', ' ') | |
| text = text.replace(' ', ' ') | |
| return text | |
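
# Example: remove_spaces_newlines("a\nb  c") -> "a b c"
# (newlines and double spaces collapse to single spaces)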

# truncates long articles to 1500 words
def truncate_article(text: str) -> str:
    split = text.split()
    if len(split) > 1500:
        split = split[:1500]
        text = ' '.join(split)
    return text
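
# Example: truncate_article(" ".join(["word"] * 2000)) returns only the first 1500 words;
# texts at or under 1500 words are returned unchanged.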
| """ | |
| Searches Solr for articles based on keywords and saves them in a csv file | |
| Input: | |
| keywords: str | |
| num_articles: int | |
| Output: path to csv file | |
| Minor details: | |
| Removes duplicate articles to start with. | |
| Articles with dead urls are removed since those articles are often wierd. | |
| Articles with titles that start with five starting words are removed. they are usually duplicates with minor changes. | |
| If one of title, uuid, cleaned_content, url are missing the article is skipped. | |
| """ | |
| def save_solr_articles(keywords: str, num_articles=15) -> str: | |
| solr_key = os.getenv("SOLR_KEY") | |
| SOLR_ARTICLES_URL = f"https://website:{solr_key}@solr.machines.globalhealthwatcher.org:8080/solr/articles/" | |
| solr = Solr(SOLR_ARTICLES_URL, verify=False) | |
| # No duplicates | |
| fq = ['-dups:0'] | |
| query = f'text:({keywords})' + " AND " + "dead_url:(false)" | |
| # Get top 2*num_articles articles and then remove misformed or duplicate articles | |
| outputs = solr.search(query, fq=fq, sort="score desc", rows=num_articles * 2) | |
| article_count = 0 | |
| save_path = os.path.join("data", "articles.csv") | |
| if not os.path.exists(os.path.dirname(save_path)): | |
| os.makedirs(os.path.dirname(save_path)) | |
| with open(save_path, 'w', newline='') as csvfile: | |
| fieldnames = ['title', 'uuid', 'content', 'url', 'domain'] | |
| writer = csv.DictWriter(csvfile, fieldnames=fieldnames, quoting=csv.QUOTE_NONNUMERIC) | |
| writer.writeheader() | |
| title_five_words = set() | |
| for d in outputs.docs: | |
| print('dictionary of article',d) | |
| if article_count == num_articles: | |
| break | |
| # skip if title returns a keyerror | |
| if 'title' not in d or 'uuid' not in d or 'cleaned_content' not in d or 'url' not in d: | |
| continue | |
| title_cleaned = remove_spaces_newlines(d['title']) | |
| split = title_cleaned.split() | |
| # skip if title is a duplicate | |
| if not len(split) < 5: | |
| five_words = title_cleaned.split()[:5] | |
| five_words = ' '.join(five_words) | |
| if five_words in title_five_words: | |
| continue | |
| title_five_words.add(five_words) | |
| article_count += 1 | |
| cleaned_content = remove_spaces_newlines(d['cleaned_content']) | |
| cleaned_content = truncate_article(cleaned_content) | |
| domain = "" | |
| if 'domain' not in d: | |
| domain = "Not Specified" | |
| else: | |
| domain = d['domain'] | |
| print(domain) | |
| writer.writerow({'title': title_cleaned, 'uuid': d['uuid'], 'content': cleaned_content, 'url': d['url'], | |
| 'domain': domain}) | |
| return save_path | |
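
# Usage sketch (assumes a valid SOLR_KEY environment variable and a reachable Solr endpoint;
# the keywords below are illustrative only):
#     path = save_solr_articles("cholera outbreak kenya", num_articles=5)
#     # path -> "data/articles.csv" with columns: title, uuid, content, url, domain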

def save_embedding_base_articles(query, article_embeddings, titles, contents, uuids, urls, num_articles=15):
    """
    Ranks pre-embedded articles against the query with a bi-encoder and saves the top matches to a CSV file.

    Input:
        query: str
        article_embeddings: tensor of article embeddings (same order as titles/contents/uuids/urls)
        titles, contents, uuids, urls: lists of article fields
        num_articles: int
    Output: path to the CSV file
    """
    bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(query_embedding, article_embeddings, top_k=num_articles)
    hits = hits[0]

    corpus_ids = [item['corpus_id'] for item in hits]
    r_contents = [contents[idx] for idx in corpus_ids]
    r_titles = [titles[idx] for idx in corpus_ids]
    r_uuids = [uuids[idx] for idx in corpus_ids]
    r_urls = [urls[idx] for idx in corpus_ids]

    save_path = os.path.join("data", "articles.csv")
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))

    with open(save_path, 'w', newline='', encoding="utf-8") as csvfile:
        fieldnames = ['title', 'uuid', 'content', 'url']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, quoting=csv.QUOTE_NONNUMERIC)
        writer.writeheader()
        for i in range(min(num_articles, len(corpus_ids))):
            writer.writerow({'title': r_titles[i], 'uuid': r_uuids[i], 'content': r_contents[i], 'url': r_urls[i]})

    return save_path
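

# Minimal sketch of how the embedding-based path might be exercised (not part of the original
# pipeline): the article texts below are made up, and the bi-encoder is re-instantiated here only
# to precompute `article_embeddings` in the shape save_embedding_base_articles expects.
if __name__ == "__main__":
    demo_titles = ["Dengue cases rise in Brazil", "New malaria vaccine trial begins"]
    demo_contents = ["Health officials report a rise in dengue cases across several states.",
                     "A phase 3 trial of a new malaria vaccine has started enrolling participants."]
    demo_uuids = ["uuid-1", "uuid-2"]
    demo_urls = ["https://example.org/dengue", "https://example.org/malaria"]

    # Encode the demo corpus with the same bi-encoder used inside save_embedding_base_articles
    encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    demo_embeddings = encoder.encode(demo_contents, convert_to_tensor=True)

    csv_path = save_embedding_base_articles("mosquito-borne disease outbreaks",
                                            demo_embeddings, demo_titles, demo_contents,
                                            demo_uuids, demo_urls, num_articles=2)
    print("Saved demo articles to", csv_path)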