Nigel Thomas committed on
Commit
e38a4a7
·
1 Parent(s): 1cda6e3

Updated code

Browse files
Files changed (1) hide show
  1. app.py +208 -0
app.py CHANGED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import numpy as np
4
+ import re
5
+ import tempfile
6
+ import torch
7
+ from datetime import datetime
8
+ from langchain_community.document_loaders import PDFPlumberLoader
9
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
10
+ from langchain_community.embeddings import HuggingFaceEmbeddings
11
+ from langchain_community.vectorstores import FAISS
12
+ from langchain_community.llms import Ollama
13
+ from langchain.retrievers import BM25Retriever, EnsembleRetriever
14
+ from sentence_transformers import CrossEncoder
15
+ from transformers import pipeline
16
+ from langchain_core.prompts import PromptTemplate
17
+ from langchain.chains import LLMChain
18
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
19
+ from langchain.llms.huggingface_pipeline import HuggingFacePipeline
20
+ from huggingface_hub import login
21
+
22
+
23
# Hugging Face model id of the causal LM used for answer generation.
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Streamlit UI Configuration.
# set_page_config is issued before anything else touches Streamlit so it
# remains the first Streamlit command of the script run.
st.set_page_config(page_title="Multi-File Financial Analyzer", layout="wide")
st.title("📊 Financial Analysis System")


@st.cache_resource(show_spinner=False)
def _load_guardrail_classifier():
    """Build the zero-shot classifier used as the input guardrail.

    Streamlit re-executes the whole script on every widget interaction;
    caching the pipeline as a resource avoids re-instantiating the model
    on each of those reruns.
    """
    return pipeline(
        "zero-shot-classification",
        model="typeform/distilbert-base-uncased-mnli",
    )


# Initialize classifier once for input guardrail
classifier = _load_guardrail_classifier()

# Sidebar Controls
with st.sidebar:
    st.header("Configuration Panel")
    # Only one model is offered; the selection is kept for UI purposes.
    model_choice = st.selectbox(
        "LLM Model",
        [model_name],
        help="Choose the core analysis engine",
    )
    # Target character length of each chunk handed to the text splitter.
    chunk_size = st.slider("Document Chunk Size", 500, 2000, 1000)
    # Minimum cross-encoder score a chunk needs to be kept as context.
    rerank_threshold = st.slider("Re-ranking Threshold", 0.0, 1.0, 0.1)

# File Upload Handling for multiple files
# NOTE(review): rendered in the main area (outside the sidebar) — confirm
# this matches the intended layout.
uploaded_files = st.file_uploader(
    "Upload Financial PDFs",
    type="pdf",
    accept_multiple_files=True,
)
47
+
48
# Prompt sent to the LLM; {context} receives the re-ranked chunks and
# {question} the user's query.
PROMPT_TEMPLATE = """
<|system|>
You are a senior financial analyst. Analyze these financial reports:
1. Compare key metrics between documents
2. Identify trends across reporting periods
3. Highlight differences/similarities
4. Provide risk assessment
5. Offer recommendations

Format response with clear sections and bullet points. Keep under 300 words.

Context: {context}
Question: {question}
<|assistant|>
"""


def _load_pdf_pages(uploaded_file):
    """Parse one uploaded PDF into LangChain documents.

    The upload is written to a temporary file because the loader works on
    a path. The file is created with ``delete=False`` so it can be reopened
    by the loader after being closed, and is removed explicitly afterwards
    so temp files do not accumulate across reruns.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
        tmp.write(uploaded_file.getvalue())
        tmp_path = tmp.name
    try:
        return PDFPlumberLoader(tmp_path).load()
    finally:
        os.remove(tmp_path)


@st.cache_resource(show_spinner=False)
def _load_embedder():
    """Return the sentence-embedding model, built once per server process."""
    return HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")


@st.cache_resource(show_spinner=False)
def _load_cross_encoder():
    """Return the re-ranking cross-encoder, built once per server process."""
    return CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')


@st.cache_resource(show_spinner=False)
def _load_llm(hf_model_name):
    """Return a LangChain LLM wrapping a text-generation pipeline.

    Cached per model name so the tokenizer and weights are not reloaded on
    every Streamlit rerun.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        hf_model_name,
        trust_remote_code=True,
        padding_side="left",  # Important for some models
    )
    model = AutoModelForCausalLM.from_pretrained(
        hf_model_name,
        trust_remote_code=True,
    )
    # Create pipeline with generation parameters
    pipeline_llm = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,
        temperature=0.3,
        top_p=0.95,
        repetition_penalty=1.15,
        return_full_text=False,  # Important for response formatting
    )
    return HuggingFacePipeline(pipeline=pipeline_llm)


def _financial_score(classification):
    """Return the score the guardrail assigned to the "financial" label.

    The zero-shot pipeline orders labels by likelihood, so index 0 is the
    winning label (possibly "other"), not necessarily "financial"; the
    score is therefore looked up by label name.
    """
    by_label = dict(zip(classification["labels"], classification["scores"]))
    return by_label["financial"]


def _rerank_documents(query, docs, scorer, threshold, limit=7):
    """Re-rank retrieved chunks and derive a confidence percentage.

    Returns ``(kept_docs, confidence)`` where ``kept_docs`` are at most
    ``limit`` chunks whose own score exceeds ``threshold``, best first, and
    ``confidence`` is the mean of the top three scores times 100, clamped
    to [0, 100]. An empty ``docs`` yields ``([], 0.0)``.
    """
    if not docs:
        return [], 0.0

    scores = np.asarray(scorer.predict([(query, d.page_content) for d in docs]))
    order = np.argsort(scores)[::-1]
    ordered_scores = scores[order]

    # Pair each document with its own score (both in ranked order) before
    # applying the threshold.
    kept = [
        docs[idx]
        for idx, score in zip(order, ordered_scores)
        if score > threshold
    ][:limit]

    # NOTE(review): treats scores as if they were on a 0-1 scale; confirm
    # the cross-encoder's output range, since the result is simply clamped.
    confidence = float(np.mean(ordered_scores[:3])) * 100
    confidence = min(100, max(0, round(confidence, 1)))
    return kept, confidence


def _clean_analysis(raw_text):
    """Tidy raw LLM output for display.

    Strips ``<think>`` tags and runs of three or more newlines, inserts a
    space between a digit and a following letter, and adds thousands
    separators to digit runs.
    """
    text = re.sub(r"<think>|</think>|\n{3,}", "", raw_text)
    text = re.sub(r'(\d)([A-Za-z])', r'\1 \2', text)
    # Insert a comma before every trailing group of three digits so long
    # numbers group correctly (1234567 -> 1,234,567).
    # NOTE(review): four-digit years (e.g. 2023) also receive a comma.
    return re.sub(r'(?<=\d)(?=(?:\d{3})+(?!\d))', ',', text)


if uploaded_files:
    with st.spinner("Processing Multiple Financial Documents..."):
        # Load and process each document
        all_docs = []
        for uploaded_file in uploaded_files:
            all_docs.extend(_load_pdf_pages(uploaded_file))

        # Combined Document Processing
        # Separators are matched literally, so the sentence separator is a
        # plain ". " rather than an escaped regex.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=200,
            separators=["\n\n", "\n", ". ", "! ", "? ", " ", ""],
        )
        documents = text_splitter.split_documents(all_docs)

        # Scanned/image-only PDFs can yield no text; the retrievers cannot
        # be built from an empty corpus.
        if not documents:
            st.error("No extractable text was found in the uploaded PDFs.")
            st.stop()

        # Hybrid Retrieval Setup for combined documents
        embedder = _load_embedder()
        vector_store = FAISS.from_documents(documents, embedder)
        bm25_retriever = BM25Retriever.from_documents(documents)
        bm25_retriever.k = 5
        faiss_retriever = vector_store.as_retriever(search_kwargs={"k": 5})
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever],
            weights=[0.4, 0.6],
        )

        # Re-ranking Model
        cross_encoder = _load_cross_encoder()

        # Financial Analysis LLM Configuration
        llm = _load_llm(model_name)

    # chat prompt template
    qa_prompt = PromptTemplate(
        template=PROMPT_TEMPLATE,
        input_variables=["context", "question"],
    )
    llm_chain = LLMChain(llm=llm, prompt=qa_prompt)

    # Interactive Q&A Interface
    st.header("🔍 Cross-Document Financial Inquiry")

    # Suggested Comparative Questions
    comparative_questions = [
        "Analyze changes in debt structure across both reports",
        "Show expense ratio differences between the two years",
        "What are the main liquidity changes across both periods?",
    ]
    user_query = st.selectbox("Sample Financial Questions",
                              [""] + comparative_questions)
    user_input = st.text_input("Or enter custom financial query:",
                               value=user_query)

    if user_input:
        # Input Validation Guardrail
        classification = classifier(user_input,
                                    ["financial", "other"],
                                    multi_label=False)
        financial_score = _financial_score(classification)
        print(f"-- Guard rail check is completed for query with prob:{financial_score}")

        if financial_score < 0.7:
            st.error("Query not related to financial. Ask about financial related queries")
            st.stop()

        with st.spinner("Performing Cross-Document Analysis..."):
            # Hybrid Document Retrieval
            initial_docs = ensemble_retriever.get_relevant_documents(user_input)

            # Context Re-ranking + Confidence Calculation
            filtered_docs, confidence_score = _rerank_documents(
                user_input, initial_docs, cross_encoder, rerank_threshold
            )
            print(f"-- Retrieved chunks:{filtered_docs}")

            if not filtered_docs:
                st.warning(
                    "No document chunks passed the re-ranking threshold; "
                    "the answer below is generated without supporting context."
                )

            # Response Generation
            context = "\n".join([doc.page_content for doc in filtered_docs])
            print(f"-- Retrieved context:{context}")

            analysis = llm_chain.run(
                context=context,
                question=user_input,
            )
            print(f"Analysis result:{analysis}")

            # Response Cleaning
            clean_analysis = _clean_analysis(analysis)

        # Input Display
        st.subheader("User Query+Context to the LLM")
        st.markdown(f"```\n{qa_prompt.format(context=context, question=user_input)}\n```")

        # Results Display
        st.subheader("Integrated Financial Analysis")
        st.markdown(f"```\n{clean_analysis}\n```")
        st.progress(int(confidence_score) / 100)
        st.caption(f"Analysis Confidence: {confidence_score}%")

        # Export Functionality
        # The download button is rendered directly: gating it behind a
        # regular button forces a full rerun (and a fresh generation)
        # before the file can be fetched.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        export_content = f"COMPARATIVE QUERY: {user_input}\n\nANALYSIS:\n{clean_analysis}"
        st.download_button("Download Full Report", export_content,
                           file_name=f"Comparative_Analysis_{timestamp}.txt",
                           mime="text/plain")

else:
    st.info("Please upload PDF financial reports to begin financial analysis")