Update app.py
Browse files
app.py
CHANGED
@@ -1,24 +1,21 @@
|
|
1 |
import os
|
2 |
import time
|
3 |
import streamlit as st
|
4 |
-
import
|
5 |
-
from groq import Groq
|
6 |
from langchain.document_loaders import PyPDFLoader
|
7 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
8 |
from langchain.schema import Document as LangchainDocument
|
9 |
from langchain.chains import RetrievalQA
|
10 |
from langchain.llms import OpenAI
|
|
|
11 |
|
12 |
# ----------------- تنظیمات صفحه -----------------
|
13 |
st.set_page_config(page_title="چتبات ارتش - فقط از PDF", page_icon="🪖", layout="wide")
|
14 |
|
15 |
-
#
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
# ----------------- بارگذاری مدل spaCy برای امبدینگ -----------------
|
21 |
-
nlp = spacy.load("MultiHashEmbed.v2") # یا 'en_core_web_lg' برای مدل بزرگتر
|
22 |
|
23 |
# ----------------- لود PDF و ساخت ایندکس -----------------
|
24 |
@st.cache_resource
|
@@ -40,11 +37,13 @@ def build_pdf_index():
|
|
40 |
# تبدیل به Document
|
41 |
documents = [LangchainDocument(page_content=t) for t in texts]
|
42 |
|
43 |
-
# استفاده از
|
44 |
embeddings = []
|
45 |
for doc in documents:
|
46 |
-
|
47 |
-
|
|
|
|
|
48 |
|
49 |
# به جای FAISS، فقط لیست امبدینگها را برمیگردانیم
|
50 |
return documents, embeddings
|
@@ -52,40 +51,45 @@ def build_pdf_index():
|
|
52 |
# ----------------- ساختن Index از PDF -----------------
|
53 |
documents, embeddings = build_pdf_index()
|
54 |
|
55 |
-
# ----------------- تعریف LLM Groq -----------------
|
|
|
56 |
client = Groq(api_key=groq_api_key)
|
57 |
|
58 |
class GroqLLM(OpenAI):
|
59 |
def __init__(self, api_key, model_name):
|
60 |
super().__init__(openai_api_key=api_key, model_name=model_name, base_url="https://api.groq.com/openai/v1")
|
61 |
|
|
|
62 |
llm = GroqLLM(api_key=groq_api_key, model_name="deepseek-r1-distill-llama-70b")
|
63 |
|
64 |
-
# -----------------
|
65 |
-
# به جای استفاده از FAISS، اکنون از لیست امبدینگها استفاده میکنیم
|
66 |
class SimpleRetriever:
|
67 |
def __init__(self, documents, embeddings):
|
68 |
self.documents = documents
|
69 |
self.embeddings = embeddings
|
70 |
|
71 |
def retrieve(self, query, top_k=1):
|
72 |
-
|
|
|
|
|
|
|
|
|
73 |
similarities = []
|
74 |
for doc_embedding in self.embeddings:
|
75 |
-
similarity = query_embedding.dot(doc_embedding) # محاسبه شباهت
|
76 |
similarities.append(similarity)
|
77 |
-
|
78 |
-
# یافتن مستندات مشابه
|
79 |
ranked_docs = sorted(zip(similarities, self.documents), reverse=True)
|
80 |
return [doc for _, doc in ranked_docs[:top_k]]
|
81 |
|
82 |
-
# استفاده از retriever
|
83 |
retriever = SimpleRetriever(documents, embeddings)
|
84 |
|
85 |
# ----------------- Retrieval Chain -----------------
|
86 |
chain = RetrievalQA.from_chain_type(
|
87 |
llm=llm,
|
88 |
-
retriever=retriever.retrieve,
|
89 |
chain_type="stuff",
|
90 |
input_key="question"
|
91 |
)
|
|
|
1 |
import os
|
2 |
import time
|
3 |
import streamlit as st
|
4 |
+
from transformers import AutoTokenizer, AutoModel
|
|
|
5 |
from langchain.document_loaders import PyPDFLoader
|
6 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
7 |
from langchain.schema import Document as LangchainDocument
|
8 |
from langchain.chains import RetrievalQA
|
9 |
from langchain.llms import OpenAI
|
10 |
+
from groq import Groq
|
11 |
|
12 |
# ----------------- تنظیمات صفحه -----------------
|
13 |
st.set_page_config(page_title="چتبات ارتش - فقط از PDF", page_icon="🪖", layout="wide")
|
14 |
|
15 |
+
# ----------------- بارگذاری مدل FarsiBERT -----------------
|
16 |
+
model_name = "HooshvareLab/bert-fa-zwnj-base" # مدل BERT فارسی
|
17 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
18 |
+
model = AutoModel.from_pretrained(model_name)
|
|
|
|
|
|
|
19 |
|
20 |
# ----------------- لود PDF و ساخت ایندکس -----------------
|
21 |
@st.cache_resource
|
|
|
37 |
# تبدیل به Document
|
38 |
documents = [LangchainDocument(page_content=t) for t in texts]
|
39 |
|
40 |
+
# استفاده از FarsiBERT برای تولید امبدینگ
|
41 |
embeddings = []
|
42 |
for doc in documents:
|
43 |
+
inputs = tokenizer(doc.page_content, return_tensors="pt", padding=True, truncation=True)
|
44 |
+
with torch.no_grad():
|
45 |
+
outputs = model(**inputs)
|
46 |
+
embeddings.append(outputs.last_hidden_state.mean(dim=1).numpy()) # میانگین امبدینگها
|
47 |
|
48 |
# به جای FAISS، فقط لیست امبدینگها را برمیگردانیم
|
49 |
return documents, embeddings
|
|
|
51 |
# ----------------- ساختن Index از PDF -----------------
|
52 |
documents, embeddings = build_pdf_index()
|
53 |
|
54 |
+
# ----------------- تعریف LLM از Groq -----------------
|
55 |
+
groq_api_key = "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"
|
56 |
client = Groq(api_key=groq_api_key)
|
57 |
|
58 |
class GroqLLM(OpenAI):
|
59 |
def __init__(self, api_key, model_name):
|
60 |
super().__init__(openai_api_key=api_key, model_name=model_name, base_url="https://api.groq.com/openai/v1")
|
61 |
|
62 |
+
# مدل Groq را با API خود بارگذاری کنید
|
63 |
llm = GroqLLM(api_key=groq_api_key, model_name="deepseek-r1-distill-llama-70b")
|
64 |
|
65 |
+
# ----------------- ساخت SimpleRetriever -----------------
|
|
|
66 |
class SimpleRetriever:
|
67 |
def __init__(self, documents, embeddings):
|
68 |
self.documents = documents
|
69 |
self.embeddings = embeddings
|
70 |
|
71 |
def retrieve(self, query, top_k=1):
|
72 |
+
# تبدیل سوال به امبدینگ با استفاده از FarsiBERT
|
73 |
+
inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
|
74 |
+
with torch.no_grad():
|
75 |
+
query_embedding = model(**inputs).last_hidden_state.mean(dim=1).numpy()
|
76 |
+
|
77 |
similarities = []
|
78 |
for doc_embedding in self.embeddings:
|
79 |
+
similarity = query_embedding.dot(doc_embedding) # محاسبه شباهت
|
80 |
similarities.append(similarity)
|
81 |
+
|
82 |
+
# یافتن مستندات مشابه
|
83 |
ranked_docs = sorted(zip(similarities, self.documents), reverse=True)
|
84 |
return [doc for _, doc in ranked_docs[:top_k]]
|
85 |
|
86 |
+
# استفاده از retriever برای جستجو
|
87 |
retriever = SimpleRetriever(documents, embeddings)
|
88 |
|
89 |
# ----------------- Retrieval Chain -----------------
|
90 |
chain = RetrievalQA.from_chain_type(
|
91 |
llm=llm,
|
92 |
+
retriever=retriever.retrieve,
|
93 |
chain_type="stuff",
|
94 |
input_key="question"
|
95 |
)
|