Update app.py
Browse files
app.py
CHANGED
@@ -64,31 +64,33 @@ class GroqLLM(OpenAI):
|
|
64 |
llm = GroqLLM(api_key=groq_api_key, model_name="deepseek-r1-distill-llama-70b")
|
65 |
|
66 |
# ----------------- ساخت SimpleRetriever -----------------
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
class SimpleRetriever(BaseRetriever):
|
|
|
|
|
|
|
68 |
def __init__(self):
|
|
|
69 |
self.documents, self.embeddings = build_pdf_index()
|
70 |
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
-
def _get_relevant_documents(self, query):
|
73 |
-
query_embedding = nlp(query).vector # تبدیل سوال به امبدینگ با استفاده از spaCy
|
74 |
similarities = []
|
75 |
for doc_embedding in self.embeddings:
|
76 |
-
similarity = query_embedding.
|
77 |
similarities.append(similarity)
|
78 |
-
|
79 |
-
# یافتن مستندات مشابه بر اساس بیشترین شباهت
|
80 |
-
ranked_docs = sorted(zip(similarities, self.documents), reverse=True)
|
81 |
-
return [doc for _, doc in ranked_docs[:5]] # بازگرداندن 5 مستند مشابه
|
82 |
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
chain = RetrievalQA.from_chain_type(
|
87 |
-
llm=llm,
|
88 |
-
retriever=retriever, # ارسال نمونه از retriever
|
89 |
-
chain_type="stuff",
|
90 |
-
input_key="question"
|
91 |
-
)
|
92 |
# ----------------- استیت برای چت -----------------
|
93 |
if 'messages' not in st.session_state:
|
94 |
st.session_state.messages = []
|
|
|
64 |
llm = GroqLLM(api_key=groq_api_key, model_name="deepseek-r1-distill-llama-70b")
|
65 |
|
66 |
# ----------------- ساخت SimpleRetriever -----------------
|
67 |
+
from langchain_core.retrievers import BaseRetriever
|
68 |
+
from langchain_core.documents import Document
|
69 |
+
from typing import List
|
70 |
+
from dataclasses import dataclass, field
|
71 |
+
|
72 |
+
@dataclass
class SimpleRetriever(BaseRetriever):
    """Retriever that ranks the indexed PDF documents against a query.

    The documents and their embeddings are obtained from
    ``build_pdf_index()`` once, when the retriever is constructed.
    """

    documents: List[Document] = field(default_factory=list)
    embeddings: List = field(default_factory=list)

    def __init__(self):
        """Initialise the base retriever, then load the PDF index."""
        super().__init__()
        self.documents, self.embeddings = build_pdf_index()

    def _get_relevant_documents(self, query: str) -> List[Document]:
        """Return up to five documents with the highest similarity to *query*.

        Args:
            query: The user question to embed and compare against the index.

        Returns:
            At most five documents, ordered from highest to lowest score.
            Documents with equal scores keep their original index order.
        """
        inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = model(**inputs)
        # Average the hidden states over dim 1 (presumably the token axis --
        # TODO confirm) to get a single vector for the query.
        query_embedding = outputs.last_hidden_state.mean(dim=1).numpy()

        # Unnormalised dot product between the query and each stored embedding.
        # NOTE(review): this is not cosine similarity; confirm the embeddings
        # returned by build_pdf_index() are meant to be compared this way.
        similarities = [
            (query_embedding * doc_embedding).sum()
            for doc_embedding in self.embeddings
        ]

        # Rank index positions by score only. Sorting (score, document) tuples
        # falls back to comparing the documents themselves whenever two scores
        # are equal, and documents do not define an ordering, so that raises
        # TypeError.
        pair_count = min(len(similarities), len(self.documents))
        ranked_indices = sorted(
            range(pair_count),
            key=lambda index: similarities[index],
            reverse=True,
        )
        return [self.documents[index] for index in ranked_indices[:5]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
# ----------------- استیت برای چت -----------------
|
95 |
if 'messages' not in st.session_state:
|
96 |
st.session_state.messages = []
|