M17idd committed on
Commit 2c4dca2 · verified · 1 Parent(s): 070bb28

Update app.py

Files changed (1)
  1. app.py +46 -63
app.py CHANGED
@@ -4,21 +4,15 @@ import streamlit as st
 from langchain.chat_models import ChatOpenAI
 from langchain.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.schema import Document as LangchainDocument
+from langchain.schema import Document
 from langchain.chains import RetrievalQA
-import torch
 from langchain_core.retrievers import BaseRetriever
-from langchain_core.documents import Document
 from typing import List
 from pydantic import Field
-from sentence_transformers import SentenceTransformer
 import numpy as np
-from langchain.indexes.vectorstore import VectorstoreIndexCreator
-
 from sentence_transformers import SentenceTransformer
 import faiss
 
-
 # ----------------- Page settings -----------------
 st.set_page_config(page_title="چت‌ بات توانا", page_icon="🪖", layout="wide")
 
@@ -105,79 +99,68 @@ st.markdown("""
 @st.cache_resource
 def get_pdf_index():
     with st.spinner('📄 در حال پردازش فایل PDF...'):
-        # Load the PDF
         loader = PyPDFLoader('test1.pdf')
-        documents = loader.load_and_split()  # extract the PDF text here
-
-        # Use the embeddings model
-        model = TogetherEmbeddings(
-            model_name="togethercomputer/m2-bert-80M-8k-retrieval",
-            api_key="0291f33aee03412a47fa5d8e562e515182dcc5d9aac5a7fb5eefdd1759005979"
-        )
+        documents = loader.load()
 
         splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=0)
         texts = []
        for doc in documents:
-            texts.extend(splitter.split_text(doc.page_content))  # access the text via the page_content attribute
+            texts.extend(splitter.split_text(doc.page_content))
+
+        # Embedding model
+        model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
         # Generate the embeddings
-        embeddings = []
-        batch_size = 512
-        for i in range(0, len(texts), batch_size):
-            batch_texts = texts[i:i + batch_size]
-            batch_embeddings = model.encode(batch_texts, convert_to_numpy=True)
-            embeddings.extend(batch_embeddings)
-
-        # Give the progress bar time to clear
-        time.sleep(1)
-
-        # Convert the embeddings list to a numpy array
-        embeddings = np.array(embeddings)
-
-        # Build the index with VectorstoreIndexCreator
-        return VectorstoreIndexCreator(
-            embedding_function=model.encode,  # use the encode function to produce embeddings
-            text_splitter=splitter
-        ).from_texts(texts, embeddings=embeddings)
-
-# ----------------- Define the LLM from Groq -----------------
+        embeddings = model.encode(texts, convert_to_numpy=True)
+
+        # Build the Faiss index
+        index = faiss.IndexFlatL2(embeddings.shape[1])
+        index.add(embeddings)
+
+        docs = [{"text": text} for text in texts]
+
+        return docs, embeddings, index, model
+
+# ----------------- Define the LLM -----------------
 llm = ChatOpenAI(
     base_url="https://api.together.xyz/v1",
     api_key='0291f33aee03412a47fa5d8e562e515182dcc5d9aac5a7fb5eefdd1759005979',
-    model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free"
+    model="meta-llama/Llama-3-70B-Instruct"
 )
 
-
 # ----------------- Define SimpleRetriever -----------------
 class SimpleRetriever(BaseRetriever):
-    documents: List[dict] = Field(...)  # document type changed to dict
-    embeddings: List[np.ndarray] = Field(...)
+    documents: List[dict] = Field(...)
+    embeddings: np.ndarray = Field(...)
     index: faiss.Index
-
-    def _get_relevant_documents(self, query: str) -> List[dict]:
-        sentence_model = SentenceTransformer("togethercomputer/m2-bert-80M-8k-retrieval", trust_remote_code=True)
-        query_embedding = sentence_model.encode(query, convert_to_numpy=True)
-
-        # Search for similar documents
-        _, indices = self.index.search(np.expand_dims(query_embedding, axis=0), 5)  # find the 5 most similar documents
-
-        return [self.documents[i] for i in indices[0]]
-
-
-# ----------------- Build the index -----------------
-documents, embeddings, index = get_pdf_index()
-retriever = SimpleRetriever(documents=documents, embeddings=embeddings, index=index)
-
+    model: SentenceTransformer
+
+    def _get_relevant_documents(self, query: str) -> List[Document]:
+        query_embedding = self.model.encode([query], convert_to_numpy=True)
+        _, indices = self.index.search(query_embedding, 5)
+        results = []
+        for i in indices[0]:
+            results.append(Document(page_content=self.documents[i]['text']))
+        return results
+
+# ----------------- Load the data -----------------
+documents, embeddings, index, model = get_pdf_index()
+
+retriever = SimpleRetriever(
+    documents=documents,
+    embeddings=embeddings,
+    index=index,
+    model=model
+)
 
 # ----------------- Build the chain -----------------
-chain = RetrievalQA.from_chain_type(
+qa_chain = RetrievalQA.from_chain_type(
     llm=llm,
     retriever=retriever,
-    chain_type="stuff",
-    input_key="question"
+    chain_type="stuff"
 )
 
-# ----------------- State for the chat -----------------
+# ----------------- Chat state -----------------
 if 'messages' not in st.session_state:
     st.session_state.messages = []
 
@@ -189,7 +172,7 @@ for msg in st.session_state.messages:
     with st.chat_message(msg['role']):
         st.markdown(f"🗨️ {msg['content']}", unsafe_allow_html=True)
 
-# ----------------- Chat input -----------------
+# ----------------- User input -----------------
 prompt = st.chat_input("سوالی در مورد فایل بپرس...")
 
 if prompt:
@@ -197,14 +180,14 @@ if prompt:
     st.session_state.pending_prompt = prompt
     st.rerun()
 
-# ----------------- Model answer -----------------
+# ----------------- Model response -----------------
 if st.session_state.pending_prompt:
     with st.chat_message('ai'):
         thinking = st.empty()
-        thinking.markdown("🤖 در حال فکر کردن ...")
+        thinking.markdown("🤖 در حال فکر کردن...")
 
         try:
-            response = chain.run(f"سوال: {st.session_state.pending_prompt}")
+            response = qa_chain.run(st.session_state.pending_prompt)
             answer = response.strip()
         except Exception as e:
             answer = f"خطا در پاسخ‌دهی: {str(e)}"
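For reference, the retrieval path introduced by this commit (the new get_pdf_index() plus SimpleRetriever._get_relevant_documents) reduces to: encode the text chunks with sentence-transformers, put them in a flat L2 FAISS index, then encode the query and take the nearest chunks. A minimal standalone sketch of that flow, using the same all-MiniLM-L6-v2 model but toy strings in place of the PDF chunks (the strings, the query, and the top-2 cutoff are illustrative assumptions, not part of the commit):

import faiss
from sentence_transformers import SentenceTransformer

# Toy chunks standing in for the split PDF text (illustrative only).
texts = [
    "Chapter 1 covers basic maintenance procedures.",
    "Chapter 2 describes the crew positions and duties.",
]

# Same embedding model the commit switches to.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = model.encode(texts, convert_to_numpy=True)  # shape (n_chunks, 384), float32

# Exact L2 index over the chunk embeddings, as in get_pdf_index().
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

# Encode a query and fetch the nearest chunks, as in _get_relevant_documents().
query_embedding = model.encode(["Who is in the crew?"], convert_to_numpy=True)
_, indices = index.search(query_embedding, 2)  # the app retrieves the top 5
print([texts[i] for i in indices[0]])

The RetrievalQA chain only ever sees the Document objects built from those chunks, so the FAISS and embedding details stay contained inside SimpleRetriever.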