import gradio as gr
import torch
import spaces
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import DirectoryLoader, TextLoader
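
# Rough dependency list implied by the imports above (usual PyPI names; exact versions
# are not pinned here): gradio, torch, transformers, spaces, langchain,
# langchain-community, langchain-huggingface, sentence-transformers, faiss-cpu.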

# --- 1. Model and Tokenizer Loading ---
print("Initializing model and tokenizer...")
model_name = "gitglubber/Ntfy"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto" if torch.cuda.is_available() else "cpu"
    )
    print(f"Model and tokenizer loaded successfully on {'GPU' if torch.cuda.is_available() else 'CPU'}.")
except Exception as e:
    print(f"Error loading model: {e}")
    raise

# --- 2. RAG Knowledge Base Setup (from Directory) ---
print("Setting up RAG knowledge base from directory...")
try:
    # Look for .md files in the current directory
    current_dir = './'

    # Load all .md files from the current directory only (not subdirectories)
    loader = DirectoryLoader(current_dir, glob="*.md", loader_cls=TextLoader)
    documents = loader.load()

    # If no .md files found, create a default one
    if not documents:
        print("No .md files found in current directory. Creating default knowledge base.")
        with open('./default.md', 'w') as f:
            f.write("# Default Knowledge Base\nNo additional documentation found.")
        # Load the newly created file
        loader = DirectoryLoader(current_dir, glob="*.md", loader_cls=TextLoader)
        documents = loader.load()

    if not documents:
        print("No documents found in knowledge base directory.")
        # Create a minimal document to prevent errors
        from langchain.schema import Document
        documents = [Document(page_content="No additional documentation available.", metadata={"source": "default"})]

    # Split the documents into chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.split_documents(documents)

    # Create embeddings - force CPU to avoid ZeroGPU device conflicts
    embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
    model_kwargs = {'device': 'cpu'}  # Force CPU for embeddings to avoid device conflicts
    embeddings = HuggingFaceEmbeddings(
        model_name=embedding_model_name,
        model_kwargs=model_kwargs
    )

    # Create a FAISS vector store from the documents
    vector_store = FAISS.from_documents(docs, embeddings)
    retriever = vector_store.as_retriever()
    print(f"RAG knowledge base created successfully from {len(documents)} document(s).")
except Exception as e:
    print(f"Error setting up knowledge base: {e}")
    raise
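
# The default retriever above returns the 4 most similar chunks per query (LangChain's
# default). If answers seem to lack context, a larger k could be requested instead, e.g.:
#   retriever = vector_store.as_retriever(search_kwargs={"k": 6})
# (k=6 is only an illustrative value, not one tuned for this knowledge base.)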

# --- 3. System Message ---
system_message = """You are an expert technical assistant for ntfy, a publish-subscribe notification service.
Your purpose is to provide users with accurate, clear, and helpful information about the ntfy project.
You will be provided with relevant context from the ntfy documentation to help you answer the user's question. Prioritize this information in your response.
Always structure your answers for clarity using Markdown (lists, bold text, code blocks).
If the provided context does not contain the answer, state that the information is not available in the documentation.
Stick strictly to the topic of ntfy.

**Fun Facts** to keep in mind:
- wunter8 is a moderator on the Discord and its most active member.
- binwiederhier is the owner and developer of the project.
- The official GitHub repository is https://github.com/binwiederhier/ntfy.
**End Fun Facts**

**Common Question:**
Question: Why aren't my iOS push notifications working?
Answer: To get iOS push notifications working, you need to:
1. Open a browser to the web app of your ntfy instance and copy the URL (including "http://" or "https://", your domain or IP address, and any port, but excluding any trailing slash).
2. Put the URL you copied in the base-url setting in server.yml, or in the NTFY_BASE_URL environment variable.
3. Put the URL you copied in the default server URL setting in the iOS ntfy app.
4. Set upstream-base-url in server.yml, or NTFY_UPSTREAM_BASE_URL in the environment, to "https://ntfy.sh" (without a trailing slash).
Questions about server.yml configuration are answered by the config document in the knowledge base.
"""

# --- 4. Gradio Interface with RAG and Streaming ---
with gr.Blocks(fill_height=True, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# NTFY Expert Chat Bot (with RAG)")
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
        avatar_images=(None, "https://docs.ntfy.sh/static/img/ntfy.png"),
        scale=1
    )
    msg = gr.Textbox(label="Input", scale=0, placeholder="Ask me a question about ntfy...")
    clear = gr.Button("Clear")
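
    # Chat history is kept as (user, assistant) tuples, the format both the Chatbot
    # component above and the respond() generator below rely on.
    # The @spaces.GPU decorator is an assumption: `spaces` is imported but otherwise
    # unused, and the embeddings comment above refers to ZeroGPU, so generation is
    # presumably meant to run inside a ZeroGPU-allocated function.
    @spaces.GPU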
    def respond(message, chat_history):
        """
        Gradio response function that uses RAG and streams model output.
        """
        if not message.strip():
            yield "", chat_history
            return

        chat_history.append((message, ""))
        yield "", chat_history

        # --- RAG: Retrieve relevant context ---
        retrieved_docs = retriever.get_relevant_documents(message)
        context = "\n\n".join([doc.page_content for doc in retrieved_docs])

        # --- Prepare model input with context ---
        rag_prompt = f"""Use the following context to answer the user's question.
**Context:**
---
{context}
---
**User Question:** {message}
"""
        messages = [{"role": "system", "content": system_message}]

        # Add previous user/assistant messages for conversation history
        for user_msg, assistant_msg in chat_history[:-1]:
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg is not None:
                messages.append({"role": "assistant", "content": assistant_msg})

        # Add the new RAG-enhanced prompt
        messages.append({"role": "user", "content": rag_prompt})

        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

        # --- Setup the streamer and generation thread ---
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True
        )
        generation_kwargs = dict(
            **model_inputs,
            streamer=streamer,
            max_new_tokens=8192,
            do_sample=True,
            top_p=0.95,
            top_k=50,
            temperature=0.7,
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
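
        # generate() blocks until it finishes, so it runs in a background thread while
        # the TextIteratorStreamer hands decoded text pieces to the loop below as they
        # are produced.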
        # --- Yield tokens as they become available ---
        bot_response = ""
        for new_text in streamer:
            bot_response += new_text
            chat_history[-1] = (message, bot_response)
            yield "", chat_history

    # Wire up the Gradio components
    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: [], None, chatbot, queue=False)

# Launch the app
demo.queue().launch(debug=True)