M17idd commited on
Commit
5054e30
·
verified ·
1 Parent(s): 3652b60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -16
app.py CHANGED
@@ -64,31 +64,33 @@ class GroqLLM(OpenAI):
64
  llm = GroqLLM(api_key=groq_api_key, model_name="deepseek-r1-distill-llama-70b")
65
 
66
  # ----------------- ساخت SimpleRetriever -----------------
 
 
 
 
 
 
67
  class SimpleRetriever(BaseRetriever):
 
 
 
68
  def __init__(self):
 
69
  self.documents, self.embeddings = build_pdf_index()
70
 
 
 
 
 
 
71
 
72
- def _get_relevant_documents(self, query):
73
- query_embedding = nlp(query).vector # تبدیل سوال به امبدینگ با استفاده از spaCy
74
  similarities = []
75
  for doc_embedding in self.embeddings:
76
- similarity = query_embedding.dot(doc_embedding) # محاسبه شباهت بین امبدینگ‌ها
77
  similarities.append(similarity)
78
-
79
- # یافتن مستندات مشابه بر اساس بیشترین شباهت
80
- ranked_docs = sorted(zip(similarities, self.documents), reverse=True)
81
- return [doc for _, doc in ranked_docs[:5]] # بازگرداندن 5 مستند مشابه
82
 
83
-
84
- retriever = SimpleRetriever()
85
-
86
- chain = RetrievalQA.from_chain_type(
87
- llm=llm,
88
- retriever=retriever, # ارسال نمونه از retriever
89
- chain_type="stuff",
90
- input_key="question"
91
- )
92
  # ----------------- استیت برای چت -----------------
93
  if 'messages' not in st.session_state:
94
  st.session_state.messages = []
 
64
  llm = GroqLLM(api_key=groq_api_key, model_name="deepseek-r1-distill-llama-70b")
65
 
66
  # ----------------- ساخت SimpleRetriever -----------------
67
+ from langchain_core.retrievers import BaseRetriever
68
+ from langchain_core.documents import Document
69
+ from typing import List
70
+ from dataclasses import dataclass, field
71
+
72
+ @dataclass
73
  class SimpleRetriever(BaseRetriever):
74
+ documents: List[Document] = field(default_factory=list)
75
+ embeddings: List = field(default_factory=list)
76
+
77
  def __init__(self):
78
+ super().__init__()
79
  self.documents, self.embeddings = build_pdf_index()
80
 
81
+ def _get_relevant_documents(self, query: str) -> List[Document]:
82
+ inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
83
+ with torch.no_grad():
84
+ outputs = model(**inputs)
85
+ query_embedding = outputs.last_hidden_state.mean(dim=1).numpy()
86
 
 
 
87
  similarities = []
88
  for doc_embedding in self.embeddings:
89
+ similarity = (query_embedding * doc_embedding).sum()
90
  similarities.append(similarity)
 
 
 
 
91
 
92
+ ranked_docs = sorted(zip(similarities, self.documents), reverse=True)
93
+ return [doc for _, doc in ranked_docs[:5]]
 
 
 
 
 
 
 
94
  # ----------------- استیت برای چت -----------------
95
  if 'messages' not in st.session_state:
96
  st.session_state.messages = []