M17idd commited on
Commit
b7b439e
·
verified ·
1 Parent(s): fb38c86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -14
app.py CHANGED
@@ -64,36 +64,31 @@ class GroqLLM(OpenAI):
64
  llm = GroqLLM(api_key=groq_api_key, model_name="deepseek-r1-distill-llama-70b")
65
 
66
  # ----------------- ساخت SimpleRetriever -----------------
67
- class SimpleRetriever:
68
  def __init__(self, documents, embeddings):
69
  self.documents = documents
70
  self.embeddings = embeddings
71
 
72
- def retrieve(self, query, top_k=1):
73
- # تبدیل سوال به امبدینگ با استفاده از FarsiBERT
74
- inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
75
- with torch.no_grad():
76
- query_embedding = model(**inputs).last_hidden_state.mean(dim=1).numpy()
77
-
78
  similarities = []
79
  for doc_embedding in self.embeddings:
80
- similarity = query_embedding.dot(doc_embedding) # محاسبه شباهت
81
  similarities.append(similarity)
82
-
83
- # یافتن مستندات مشابه
84
  ranked_docs = sorted(zip(similarities, self.documents), reverse=True)
85
- return [doc for _, doc in ranked_docs[:top_k]]
 
86
 
87
- # استفاده از retriever برای جستجو
88
  retriever = SimpleRetriever(documents, embeddings)
89
 
90
  chain = RetrievalQA.from_chain_type(
91
  llm=llm,
92
- retriever=retriever,
93
  chain_type="stuff",
94
  input_key="question"
95
  )
96
-
97
  # ----------------- استیت برای چت -----------------
98
  if 'messages' not in st.session_state:
99
  st.session_state.messages = []
 
64
  llm = GroqLLM(api_key=groq_api_key, model_name="deepseek-r1-distill-llama-70b")
65
 
66
  # ----------------- ساخت SimpleRetriever -----------------
67
+ class SimpleRetriever(BaseRetriever):
68
  def __init__(self, documents, embeddings):
69
  self.documents = documents
70
  self.embeddings = embeddings
71
 
72
+ def _get_relevant_documents(self, query):
73
+ query_embedding = nlp(query).vector # تبدیل سوال به امبدینگ با استفاده از spaCy
 
 
 
 
74
  similarities = []
75
  for doc_embedding in self.embeddings:
76
+ similarity = query_embedding.dot(doc_embedding) # محاسبه شباهت بین امبدینگ‌ها
77
  similarities.append(similarity)
78
+
79
+ # یافتن مستندات مشابه بر اساس بیشترین شباهت
80
  ranked_docs = sorted(zip(similarities, self.documents), reverse=True)
81
+ return [doc for _, doc in ranked_docs[:5]] # بازگرداندن 5 مستند مشابه
82
+
83
 
 
84
  retriever = SimpleRetriever(documents, embeddings)
85
 
86
  chain = RetrievalQA.from_chain_type(
87
  llm=llm,
88
+ retriever=retriever, # ارسال نمونه از retriever
89
  chain_type="stuff",
90
  input_key="question"
91
  )
 
92
  # ----------------- استیت برای چت -----------------
93
  if 'messages' not in st.session_state:
94
  st.session_state.messages = []