Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
from dataclasses import dataclass
|
2 |
-
from operator import itemgetter
|
3 |
from pathlib import Path
|
4 |
from typing import List, Optional, Dict, Any
|
5 |
import logging
|
@@ -8,17 +6,13 @@ from enum import Enum
|
|
8 |
import gradio as gr
|
9 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
10 |
from langchain_community.vectorstores import Chroma
|
11 |
-
from langchain.schema import BaseRetriever
|
12 |
from langchain.embeddings.base import Embeddings
|
13 |
-
from langchain.llms.base import BaseLanguageModel
|
14 |
import PyPDF2
|
15 |
from huggingface_hub import InferenceClient
|
16 |
-
# Install required packages
|
17 |
-
|
18 |
-
|
19 |
-
# Initialize models
|
20 |
import torch
|
21 |
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
|
|
|
|
22 |
embed_model = HuggingFaceBgeEmbeddings(
|
23 |
model_name="all-MiniLM-L6-v2",#"dunzhang/stella_en_1.5B_v5",
|
24 |
model_kwargs={'device': 'cpu'},
|
@@ -39,6 +33,7 @@ class DocumentFormat(Enum):
|
|
39 |
PDF = ".pdf"
|
40 |
# Can be extended for other document types
|
41 |
|
|
|
42 |
@dataclass
|
43 |
class RAGConfig:
|
44 |
"""Configuration for RAG system parameters"""
|
@@ -47,15 +42,14 @@ class RAGConfig:
|
|
47 |
retriever_k: int = 3
|
48 |
persist_directory: str = "./chroma_db"
|
49 |
|
|
|
50 |
class AdvancedRAGSystem:
|
51 |
"""Advanced RAG System with improved error handling and type safety"""
|
52 |
-
|
53 |
-
|
54 |
def __init__(
|
55 |
self,
|
56 |
-
embed_model
|
57 |
-
llm
|
58 |
-
config
|
59 |
):
|
60 |
"""Initialize the RAG system with required models and optional configuration"""
|
61 |
self.embed_model = embed_model
|
@@ -166,19 +160,12 @@ Context:
|
|
166 |
}
|
167 |
]
|
168 |
|
169 |
-
response_text = ""
|
170 |
return self.llm.chat.completions.create(
|
171 |
model=model_name,
|
172 |
messages=messages,
|
173 |
max_tokens=500,
|
174 |
# stream=True
|
175 |
).choices[0].message.content
|
176 |
-
# return stream.choices[0].message.content
|
177 |
-
# if hasattr(chunk.choices[0].delta, 'content'):
|
178 |
-
# content = chunk.choices[0].delta.content
|
179 |
-
# if content is not None:
|
180 |
-
# response_text += content
|
181 |
-
# yield response_text
|
182 |
|
183 |
except Exception as e:
|
184 |
error_msg = f"Error during query processing: {str(e)}"
|
@@ -186,7 +173,9 @@ Context:
|
|
186 |
return error_msg
|
187 |
|
188 |
|
189 |
-
|
|
|
|
|
190 |
def create_gradio_interface(rag_system: AdvancedRAGSystem) :
|
191 |
"""Create an improved Gradio interface for the RAG system"""
|
192 |
|
@@ -207,8 +196,6 @@ def create_gradio_interface(rag_system: AdvancedRAGSystem) :
|
|
207 |
def query_streaming(question: str) :
|
208 |
try:
|
209 |
return rag_system.query(question)
|
210 |
-
# for response in rag_system.query(question):
|
211 |
-
# yield response
|
212 |
except Exception as e:
|
213 |
return f"Error: {str(e)}"
|
214 |
|
|
|
|
|
|
|
1 |
from pathlib import Path
|
2 |
from typing import List, Optional, Dict, Any
|
3 |
import logging
|
|
|
6 |
import gradio as gr
|
7 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
8 |
from langchain_community.vectorstores import Chroma
|
|
|
9 |
from langchain.embeddings.base import Embeddings
|
|
|
10 |
import PyPDF2
|
11 |
from huggingface_hub import InferenceClient
|
|
|
|
|
|
|
|
|
12 |
import torch
|
13 |
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
14 |
+
# Install required packages
|
15 |
+
|
16 |
embed_model = HuggingFaceBgeEmbeddings(
|
17 |
model_name="all-MiniLM-L6-v2",#"dunzhang/stella_en_1.5B_v5",
|
18 |
model_kwargs={'device': 'cpu'},
|
|
|
33 |
PDF = ".pdf"
|
34 |
# Can be extended for other document types
|
35 |
|
36 |
+
|
37 |
@dataclass
|
38 |
class RAGConfig:
|
39 |
"""Configuration for RAG system parameters"""
|
|
|
42 |
retriever_k: int = 3
|
43 |
persist_directory: str = "./chroma_db"
|
44 |
|
45 |
+
|
46 |
class AdvancedRAGSystem:
|
47 |
"""Advanced RAG System with improved error handling and type safety"""
|
|
|
|
|
48 |
def __init__(
|
49 |
self,
|
50 |
+
embed_model,
|
51 |
+
llm,
|
52 |
+
config = None
|
53 |
):
|
54 |
"""Initialize the RAG system with required models and optional configuration"""
|
55 |
self.embed_model = embed_model
|
|
|
160 |
}
|
161 |
]
|
162 |
|
|
|
163 |
return self.llm.chat.completions.create(
|
164 |
model=model_name,
|
165 |
messages=messages,
|
166 |
max_tokens=500,
|
167 |
# stream=True
|
168 |
).choices[0].message.content
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
except Exception as e:
|
171 |
error_msg = f"Error during query processing: {str(e)}"
|
|
|
173 |
return error_msg
|
174 |
|
175 |
|
176 |
+
|
177 |
+
|
178 |
+
|
179 |
def create_gradio_interface(rag_system: AdvancedRAGSystem) :
|
180 |
"""Create an improved Gradio interface for the RAG system"""
|
181 |
|
|
|
196 |
def query_streaming(question: str) :
|
197 |
try:
|
198 |
return rag_system.query(question)
|
|
|
|
|
199 |
except Exception as e:
|
200 |
return f"Error: {str(e)}"
|
201 |
|