M17idd commited on
Commit
5fc8461
·
verified ·
1 Parent(s): 32a149b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -15
app.py CHANGED
@@ -2,13 +2,10 @@ import os
2
  import time
3
  import streamlit as st
4
  from langchain.chat_models import ChatOpenAI
5
-
6
- from transformers import AutoTokenizer, AutoModel
7
  from langchain.document_loaders import PyPDFLoader
8
  from langchain.text_splitter import RecursiveCharacterTextSplitter
9
  from langchain.schema import Document as LangchainDocument
10
  from langchain.chains import RetrievalQA
11
- from langchain.llms import OpenAI
12
  import torch
13
  from langchain_core.retrievers import BaseRetriever
14
  from langchain_core.documents import Document
@@ -158,24 +155,21 @@ llm = ChatOpenAI(
158
  # ----------------- تعریف SimpleRetriever -----------------
159
  class SimpleRetriever(BaseRetriever):
160
  documents: List[Document] = Field(...)
161
- embeddings: List = Field(...)
162
 
163
  def _get_relevant_documents(self, query: str) -> List[Document]:
164
- # فقط از sentence_model استفاده می‌کنیم
165
  sentence_model = SentenceTransformer("aubmindlab/bert-base-arabert")
166
  query_embedding = sentence_model.encode(query, convert_to_numpy=True)
167
 
168
- similarities = []
169
- for doc_embedding in self.embeddings:
170
- similarity = (query_embedding * doc_embedding).sum()
171
- similarities.append(similarity)
172
 
173
- ranked_docs = sorted(
174
- zip(similarities, self.documents),
175
- key=lambda x: x[0],
176
- reverse=True
177
- )
178
- return [doc for _, doc in ranked_docs[:5]]
179
 
180
  # ----------------- ساخت Index -----------------
181
  documents, embeddings = build_pdf_index()
 
2
  import time
3
  import streamlit as st
4
  from langchain.chat_models import ChatOpenAI
 
 
5
  from langchain.document_loaders import PyPDFLoader
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
7
  from langchain.schema import Document as LangchainDocument
8
  from langchain.chains import RetrievalQA
 
9
  import torch
10
  from langchain_core.retrievers import BaseRetriever
11
  from langchain_core.documents import Document
 
155
  # ----------------- تعریف SimpleRetriever -----------------
156
  class SimpleRetriever(BaseRetriever):
157
  documents: List[Document] = Field(...)
158
+ embeddings: List[np.ndarray] = Field(...)
159
 
160
  def _get_relevant_documents(self, query: str) -> List[Document]:
161
+ # استفاده از sentence_model برای تبدیل query به بردار
162
  sentence_model = SentenceTransformer("aubmindlab/bert-base-arabert")
163
  query_embedding = sentence_model.encode(query, convert_to_numpy=True)
164
 
165
+ # محاسبه شباهت‌های برداری برای تمام اسناد
166
+ similarities = np.dot(self.embeddings, query_embedding)
 
 
167
 
168
+ # ترتیب‌دهی اسناد بر اساس شباهت‌ها
169
+ ranked_docs = np.argsort(similarities)[::-1]
170
+
171
+ # برگشتن به ۵ سند برتر
172
+ return [self.documents[i] for i in ranked_docs[:5]]
 
173
 
174
  # ----------------- ساخت Index -----------------
175
  documents, embeddings = build_pdf_index()