Update app.py
Browse files
app.py
CHANGED
@@ -2,6 +2,7 @@ import os
|
|
2 |
import time
|
3 |
import streamlit as st
|
4 |
from langchain.chat_models import ChatOpenAI
|
|
|
5 |
from transformers import AutoTokenizer, AutoModel
|
6 |
from langchain.document_loaders import PyPDFLoader
|
7 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
@@ -13,26 +14,18 @@ from langchain_core.retrievers import BaseRetriever
|
|
13 |
from langchain_core.documents import Document
|
14 |
from typing import List
|
15 |
from pydantic import Field
|
16 |
-
from
|
|
|
17 |
|
18 |
# ----------------- تنظیمات صفحه -----------------
|
19 |
st.set_page_config(page_title="چتبات ارتش - فقط از PDF", page_icon="🪖", layout="wide")
|
20 |
|
21 |
# ----------------- بارگذاری مدل FarsiBERT -----------------
|
22 |
-
model_name = "HooshvareLab/bert-fa-zwnj-base"
|
23 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
24 |
-
model = AutoModel.from_pretrained(model_name)
|
25 |
|
26 |
# ----------------- لود PDF و ساخت ایندکس -----------------
|
27 |
-
import os
|
28 |
-
import streamlit as st
|
29 |
-
import torch
|
30 |
-
from transformers import AutoTokenizer, AutoModel
|
31 |
-
from langchain.document_loaders import PyPDFLoader
|
32 |
-
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
33 |
-
from langchain.schema import Document as LangchainDocument
|
34 |
-
from sentence_transformers import SentenceTransformer
|
35 |
-
import numpy as np
|
36 |
|
37 |
@st.cache_resource
|
38 |
def build_pdf_index():
|
@@ -54,52 +47,30 @@ def build_pdf_index():
|
|
54 |
documents = [LangchainDocument(page_content=t) for t in texts]
|
55 |
|
56 |
# مدلهای Embedding
|
57 |
-
tokenizer = AutoTokenizer.from_pretrained("HooshvareLab/bert-fa-zwnj-base")
|
58 |
-
bert_model = AutoModel.from_pretrained("HooshvareLab/bert-fa-zwnj-base")
|
59 |
-
|
60 |
sentence_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
|
61 |
|
62 |
embeddings = []
|
63 |
-
batch_size = 16
|
64 |
|
65 |
-
#
|
66 |
progress_bar = st.progress(0)
|
67 |
-
|
68 |
-
|
69 |
-
for i in range(0, len(documents), batch_size):
|
70 |
-
batch_docs = documents[i:i+batch_size]
|
71 |
-
batch_texts = [doc.page_content for doc in batch_docs]
|
72 |
-
|
73 |
-
# اول تلاش با مدل SentenceTransformer (خیلی سریعتره)
|
74 |
-
try:
|
75 |
-
batch_embeddings = sentence_model.encode(batch_texts, batch_size=batch_size, convert_to_numpy=True)
|
76 |
-
except Exception as e:
|
77 |
-
st.error(f"❌ خطا در SentenceTransformer: {e}")
|
78 |
-
batch_embeddings = []
|
79 |
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
outputs = bert_model(**inputs)
|
85 |
-
batch_embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
|
86 |
-
|
87 |
-
embeddings.extend(batch_embeddings)
|
88 |
|
89 |
# بهروزرسانی پروگرس بار
|
90 |
-
progress_bar.progress((i +
|
91 |
|
92 |
-
# اطمینان
|
93 |
embeddings = np.array(embeddings)
|
94 |
|
95 |
return documents, embeddings
|
96 |
|
97 |
|
98 |
# ----------------- تعریف LLM از Groq -----------------
|
99 |
-
# SECURITY: a literal Groq API key used to be hard-coded on this line. Treat that key as leaked: revoke it and supply the key through the GROQ_API_KEY environment variable instead.
groq_api_key = os.environ.get("GROQ_API_KEY")
|
100 |
-
|
101 |
-
from langchain.llms import HuggingFaceEndpoint
|
102 |
-
groq_api_key = os.environ.get("GROQ_API_KEY")
|
103 |
|
104 |
# به جای OpenAI اینو بذار:
|
105 |
llm = ChatOpenAI(
|
@@ -114,10 +85,9 @@ class SimpleRetriever(BaseRetriever):
|
|
114 |
embeddings: List = Field(...)
|
115 |
|
116 |
def _get_relevant_documents(self, query: str) -> List[Document]:
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
query_embedding = outputs.last_hidden_state.mean(dim=1).numpy()
|
121 |
|
122 |
similarities = []
|
123 |
for doc_embedding in self.embeddings:
|
@@ -162,21 +132,24 @@ if prompt:
|
|
162 |
# ----------------- پاسخ مدل -----------------
|
163 |
if st.session_state.pending_prompt:
|
164 |
with st.chat_message('ai'):
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
try:
|
169 |
response = chain.run(f"سوال: {st.session_state.pending_prompt}")
|
170 |
answer = response.strip()
|
171 |
-
|
172 |
-
# شبیه سازی پردازش برای به روز کردن پروگرس بار
|
173 |
-
for i in range(0, 101, 20):
|
174 |
-
progress_bar.progress(i)
|
175 |
-
time.sleep(0.1) # شبیه سازی سرعت پردازش
|
176 |
-
|
177 |
except Exception as e:
|
178 |
answer = f"خطا در پاسخدهی: {str(e)}"
|
179 |
|
180 |
-
|
181 |
-
|
182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import time
|
3 |
import streamlit as st
|
4 |
from langchain.chat_models import ChatOpenAI
|
5 |
+
|
6 |
from transformers import AutoTokenizer, AutoModel
|
7 |
from langchain.document_loaders import PyPDFLoader
|
8 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
14 |
from langchain_core.documents import Document
|
15 |
from typing import List
|
16 |
from pydantic import Field
|
17 |
+
from sentence_transformers import SentenceTransformer
|
18 |
+
import numpy as np
|
19 |
|
20 |
# ----------------- تنظیمات صفحه -----------------
|
21 |
st.set_page_config(page_title="چتبات ارتش - فقط از PDF", page_icon="🪖", layout="wide")
|
22 |
|
23 |
# ----------------- FarsiBERT model loading (currently disabled) -----------------
|
24 |
+
# model_name = "HooshvareLab/bert-fa-zwnj-base"
|
25 |
+
# tokenizer = AutoTokenizer.from_pretrained(model_name)
|
26 |
+
# model = AutoModel.from_pretrained(model_name)
|
27 |
|
28 |
# ----------------- لود PDF و ساخت ایندکس -----------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
@st.cache_resource
|
31 |
def build_pdf_index():
|
|
|
47 |
documents = [LangchainDocument(page_content=t) for t in texts]
|
48 |
|
49 |
# مدلهای Embedding
|
|
|
|
|
|
|
50 |
sentence_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
|
51 |
|
52 |
embeddings = []
|
|
|
53 |
|
54 |
+
# پروگرس بار اضافه میکنیم
|
55 |
progress_bar = st.progress(0)
|
56 |
+
total_docs = len(documents)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
+
# Embed every document with SentenceTransformer only, one document per encode() call
|
59 |
+
for i, doc in enumerate(documents):
|
60 |
+
batch_embedding = sentence_model.encode(doc.page_content, convert_to_numpy=True)
|
61 |
+
embeddings.append(batch_embedding)
|
|
|
|
|
|
|
|
|
62 |
|
63 |
# بهروزرسانی پروگرس بار
|
64 |
+
progress_bar.progress((i + 1) / total_docs)
|
65 |
|
66 |
+
# Make sure the result is a NumPy array
|
67 |
embeddings = np.array(embeddings)
|
68 |
|
69 |
return documents, embeddings
|
70 |
|
71 |
|
72 |
# ----------------- تعریف LLM از Groq -----------------
|
73 |
+
# groq_api_key = "<REDACTED>"  # SECURITY: a real key was committed on this line; revoke it and read GROQ_API_KEY from the environment instead
|
|
|
|
|
|
|
74 |
|
75 |
# Use this client in place of the plain OpenAI setup:
|
76 |
llm = ChatOpenAI(
|
|
|
85 |
embeddings: List = Field(...)
|
86 |
|
87 |
def _get_relevant_documents(self, query: str) -> List[Document]:
|
88 |
+
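# Only sentence_model is used to embed the query. NOTE(review): the model is constructed on every call here; consider loading it once and reusing it.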
|
89 |
+
sentence_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
|
90 |
+
query_embedding = sentence_model.encode(query, convert_to_numpy=True)
|
|
|
91 |
|
92 |
similarities = []
|
93 |
for doc_embedding in self.embeddings:
|
|
|
132 |
# ----------------- پاسخ مدل -----------------
|
133 |
# Answer the pending prompt: show a temporary "thinking" notice, query the chain,
# reveal the answer piece by piece, then store it in the chat history.
if st.session_state.pending_prompt:
    with st.chat_message('ai'):
        thinking = st.empty()
        thinking.markdown("🤖 در حال فکر کردن از روی PDF...")

        try:
            response = chain.run(f"سوال: {st.session_state.pending_prompt}")
            answer = response.strip()
        except Exception as e:
            # Best effort: show the failure as the chat reply instead of
            # crashing the page.
            answer = f"خطا در پاسخدهی: {str(e)}"

        thinking.empty()

        # Split on single spaces only, so newlines and other Markdown
        # structure inside the answer survive the typing animation.
        # (Splitting on all whitespace and re-joining with " " would flatten
        # lists and paragraphs into one line.)
        placeholder = st.empty()
        pieces = answer.split(" ") if answer else []
        revealed = []
        for piece in pieces:
            revealed.append(piece)
            placeholder.markdown(" ".join(revealed) + "▌")
            time.sleep(0.03)

        # Final render without the cursor; keep the exact answer text in the
        # history (no trailing space, original line breaks intact).
        placeholder.markdown(answer)
        st.session_state.messages.append({'role': 'ai', 'content': answer})
        st.session_state.pending_prompt = None
|