Update app.py
app.py CHANGED
@@ -4,21 +4,15 @@ import streamlit as st
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.schema import Document
+from langchain.schema import Document
from langchain.chains import RetrievalQA
-import torch
from langchain_core.retrievers import BaseRetriever
-from langchain_core.documents import Document
from typing import List
from pydantic import Field
-from sentence_transformers import SentenceTransformer
import numpy as np
-from langchain.indexes.vectorstore import VectorstoreIndexCreator
-
from sentence_transformers import SentenceTransformer
import faiss

-
# ----------------- Page settings -----------------
st.set_page_config(page_title="چت بات توانا", page_icon="🪖", layout="wide")

@@ -105,79 +99,68 @@ st.markdown("""
@st.cache_resource
def get_pdf_index():
    with st.spinner('📄 در حال پردازش فایل PDF...'):
-        # Load the PDF
        loader = PyPDFLoader('test1.pdf')
-        documents = loader.
-
-        # Use the embeddings model
-        model = TogetherEmbeddings(
-            model_name="togethercomputer/m2-bert-80M-8k-retrieval",
-            api_key="0291f33aee03412a47fa5d8e562e515182dcc5d9aac5a7fb5eefdd1759005979"
-        )
+        documents = loader.load()

        splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=0)
        texts = []
        for doc in documents:
-            texts.extend(splitter.split_text(doc.page_content))
+            texts.extend(splitter.split_text(doc.page_content))
+
+        # Embedding model
+        model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

        # Generate the embeddings
-        embeddings =
-
-
-
-
-
-
-
-
-
-        embeddings = np.array(embeddings)
-
-        # Create the index using VectorstoreIndexCreator
-        return VectorstoreIndexCreator(
-            embedding_function=model.encode,  # use the encode function to generate the embeddings
-            text_splitter=splitter
-        ).from_texts(texts, embeddings=embeddings)
-
-# ----------------- Define the LLM from Groq -----------------
+        embeddings = model.encode(texts, convert_to_numpy=True)
+
+        # Build the FAISS index
+        index = faiss.IndexFlatL2(embeddings.shape[1])
+        index.add(embeddings)
+
+        docs = [{"text": text} for text in texts]
+
+        return docs, embeddings, index, model
+
+# ----------------- Define the LLM -----------------
llm = ChatOpenAI(
    base_url="https://api.together.xyz/v1",
    api_key='0291f33aee03412a47fa5d8e562e515182dcc5d9aac5a7fb5eefdd1759005979',
-    model="meta-llama/Llama-3
+    model="meta-llama/Llama-3-70B-Instruct"
)

-
# ----------------- Define SimpleRetriever -----------------
class SimpleRetriever(BaseRetriever):
-    documents: List[dict] = Field(...)
-    embeddings:
+    documents: List[dict] = Field(...)
+    embeddings: np.ndarray = Field(...)
    index: faiss.Index
-
-
-
-        query_embedding =
-
-
-
-
-        return
-
-
-
-
-retriever = SimpleRetriever(
-
+    model: SentenceTransformer
+
+    def _get_relevant_documents(self, query: str) -> List[Document]:
+        query_embedding = self.model.encode([query], convert_to_numpy=True)
+        _, indices = self.index.search(query_embedding, 5)
+        results = []
+        for i in indices[0]:
+            results.append(Document(page_content=self.documents[i]['text']))
+        return results
+
+# ----------------- Load the data -----------------
+documents, embeddings, index, model = get_pdf_index()
+
+retriever = SimpleRetriever(
+    documents=documents,
+    embeddings=embeddings,
+    index=index,
+    model=model
+)

# ----------------- Build the Chain -----------------
-
+qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
-    chain_type="stuff"
-    input_key="question"
+    chain_type="stuff"
)

-# ----------------- State
+# ----------------- Chat state -----------------
if 'messages' not in st.session_state:
    st.session_state.messages = []

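Note on the retrieval path introduced above: get_pdf_index() plus SimpleRetriever._get_relevant_documents boil down to a plain FAISS nearest-neighbour lookup over sentence-transformers embeddings. Below is a minimal, self-contained sketch of that lookup, illustrative only — the chunk strings and the query are made up, and all-MiniLM-L6-v2 is assumed to produce 384-dimensional vectors as in the new code:

import faiss
from sentence_transformers import SentenceTransformer

# Stand-ins for the chunks produced by RecursiveCharacterTextSplitter.
chunks = ["first PDF chunk", "second PDF chunk", "third PDF chunk"]

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = model.encode(chunks, convert_to_numpy=True)   # float32 array, shape (3, 384)

index = faiss.IndexFlatL2(embeddings.shape[1])             # exact L2-distance index
index.add(embeddings)

query = model.encode(["an example question"], convert_to_numpy=True)  # shape (1, 384)
distances, indices = index.search(query, 2)                # each result has shape (1, 2)
print([chunks[i] for i in indices[0]])                     # the two closest chunks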
@@ -189,7 +172,7 @@ for msg in st.session_state.messages:
    with st.chat_message(msg['role']):
        st.markdown(f"🗨️ {msg['content']}", unsafe_allow_html=True)

-# ----------------- Input
+# ----------------- User input -----------------
prompt = st.chat_input("سوالی در مورد فایل بپرس...")

if prompt:
@@ -197,14 +180,14 @@ if prompt:
    st.session_state.pending_prompt = prompt
    st.rerun()

-# -----------------
+# ----------------- Model response -----------------
if st.session_state.pending_prompt:
    with st.chat_message('ai'):
        thinking = st.empty()
-        thinking.markdown("🤖 در حال فکر
+        thinking.markdown("🤖 در حال فکر کردن...")

        try:
-            response =
+            response = qa_chain.run(st.session_state.pending_prompt)
            answer = response.strip()
        except Exception as e:
            answer = f"خطا در پاسخدهی: {str(e)}"
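For context on chain_type="stuff" and the qa_chain.run call above: the "stuff" strategy simply concatenates ("stuffs") all retrieved chunks into a single prompt for the LLM. A rough, hedged sketch of the equivalent manual flow follows — the helper name and prompt wording are illustrative, and the real RetrievalQA prompt template and call path differ:

def answer_with_stuffing(question: str) -> str:
    # Retrieve the nearest chunks via the SimpleRetriever defined in this commit.
    docs = retriever.get_relevant_documents(question)
    # "Stuff" every retrieved chunk into one prompt.
    context = "\n\n".join(doc.page_content for doc in docs)
    prompt = (
        "Answer the question using only the context below.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}"
    )
    return llm.predict(prompt)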
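A quick way to exercise the new wiring outside Streamlit, illustrative only — it assumes test1.pdf sits next to app.py, the Together API key is valid, and the module-level retriever and qa_chain above have been created (for example by importing app from a REPL, in which case the Streamlit calls run in bare mode):

# Not part of the commit: manual smoke test of retrieval plus the QA chain.
question = "What is this file about?"
docs = retriever.get_relevant_documents(question)
print(len(docs), "chunks retrieved")   # 5, given the k=5 index.search above (if the PDF yields at least 5 chunks)
for doc in docs:
    print(doc.page_content[:80])
print(qa_chain.run(question))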