M17idd committed (verified)
Commit 5985f75 · 1 Parent(s): bc325e4

Update app.py

Files changed (1): app.py (+17 -30)
app.py CHANGED
@@ -7,15 +7,18 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain.schema import Document as LangchainDocument
  from langchain.chains import RetrievalQA
  from langchain.llms import OpenAI
- from groq import Groq
  import torch
  from langchain_core.retrievers import BaseRetriever
+ from langchain_core.documents import Document
+ from typing import List
+ from pydantic import Field
+ from groq import Groq

  # ----------------- Page settings -----------------
  st.set_page_config(page_title="چت‌بات ارتش - فقط از PDF", page_icon="🪖", layout="wide")

  # ----------------- Load the FarsiBERT model -----------------
- model_name = "HooshvareLab/bert-fa-zwnj-base"  # Persian BERT model
+ model_name = "HooshvareLab/bert-fa-zwnj-base"
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  model = AutoModel.from_pretrained(model_name)

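Note on the model loading above: Streamlit re-runs the whole script on every interaction, so loading FarsiBERT at module level reconstructs the model on each rerun. A minimal sketch of the usual caching workaround, assuming a Streamlit version that provides st.cache_resource (this is not part of the commit):

import streamlit as st
from transformers import AutoTokenizer, AutoModel

@st.cache_resource  # load FarsiBERT once and reuse it across Streamlit reruns
def load_farsibert(model_name: str = "HooshvareLab/bert-fa-zwnj-base"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model

tokenizer, model = load_farsibert()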
@@ -26,7 +29,6 @@ def build_pdf_index():
      loader = PyPDFLoader("test1.pdf")
      pages = loader.load()

-     # Split the PDF text into chunks
      splitter = RecursiveCharacterTextSplitter(
          chunk_size=500,
          chunk_overlap=50
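For reference, a small standalone illustration of how the splitter settings above behave (chunk_size=500, chunk_overlap=50); the sample string is made up and only stands in for one page of the PDF:

from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)

sample_page_text = "This sentence stands in for one page of the PDF. " * 60
chunks = splitter.split_text(sample_page_text)

# Each chunk is at most about 500 characters, and consecutive chunks share
# roughly 50 characters, so text cut at a chunk boundary still appears
# intact in at least one chunk.
print(len(chunks), [len(c) for c in chunks])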
@@ -36,42 +38,27 @@ def build_pdf_index():
      for page in pages:
          texts.extend(splitter.split_text(page.page_content))

-     # Convert to Document objects
      documents = [LangchainDocument(page_content=t) for t in texts]

-     # Use FarsiBERT to generate the embeddings
      embeddings = []
      for doc in documents:
          inputs = tokenizer(doc.page_content, return_tensors="pt", padding=True, truncation=True)
          with torch.no_grad():
              outputs = model(**inputs)
-         embeddings.append(outputs.last_hidden_state.mean(dim=1).numpy())  # mean of the embeddings
+         embeddings.append(outputs.last_hidden_state.mean(dim=1).numpy())

-     # Instead of FAISS, just return the list of embeddings
      return documents, embeddings

- # ----------------- Build the index from the PDF -----------------
-
  # ----------------- Define the LLM from Groq -----------------
  groq_api_key = "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"
- client = Groq(api_key=groq_api_key)
-
- class GroqLLM(OpenAI):
-     def __init__(self, api_key, model_name):
-         super().__init__(
-             openai_api_key=api_key,
-             model_name=model_name,
-             base_url="https://api.groq.com"  # just this
-         )
-
- # Build the model
- llm = GroqLLM(api_key=groq_api_key, model_name="deepseek-r1-distill-llama-70b")

- from langchain_core.retrievers import BaseRetriever
- from langchain_core.documents import Document
- from typing import List
- from pydantic import Field
+ # Use OpenAI directly, without an extra wrapper class
+ llm = OpenAI(
+     openai_api_key=groq_api_key,
+     model_name="deepseek-r1-distill-llama-70b"
+ )

+ # ----------------- Define SimpleRetriever -----------------
  class SimpleRetriever(BaseRetriever):
      documents: List[Document] = Field(...)
      embeddings: List = Field(...)
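One point worth noting about the new LLM setup: the removed GroqLLM wrapper passed a base_url, while the plain OpenAI(...) call that replaces it does not, so by default it will target OpenAI's endpoint rather than Groq's (Groq's OpenAI-compatible endpoint is normally reached through a base URL such as https://api.groq.com/openai/v1). Since "from groq import Groq" is still imported at the top of the new file, here is a minimal sketch of calling the same model through the native Groq SDK instead; the helper name and prompt wording are illustrative only:

from groq import Groq

client = Groq(api_key=groq_api_key)  # groq_api_key as defined above in app.py

def ask_groq(question: str, context: str) -> str:
    # One chat-completion call against Groq; the model name is taken from the diff
    resp = client.chat.completions.create(
        model="deepseek-r1-distill-llama-70b",
        messages=[
            {"role": "system", "content": "Answer only from the provided PDF context."},
            {"role": "user", "content": f"{context}\n\nسوال: {question}"},
        ],
    )
    return resp.choices[0].message.content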
@@ -89,16 +76,19 @@ class SimpleRetriever(BaseRetriever):

          ranked_docs = sorted(zip(similarities, self.documents), reverse=True)
          return [doc for _, doc in ranked_docs[:5]]
+
+ # ----------------- Build the index -----------------
  documents, embeddings = build_pdf_index()
  retriever = SimpleRetriever(documents=documents, embeddings=embeddings)

- # Then build the chain
+ # ----------------- Build the chain -----------------
  chain = RetrievalQA.from_chain_type(
      llm=llm,
      retriever=retriever,
      chain_type="stuff",
      input_key="question"
  )
+
  # ----------------- Chat state -----------------
  if 'messages' not in st.session_state:
      st.session_state.messages = []
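The hunk above shows only the tail of _get_relevant_documents; the similarities it sorts are computed in lines the diff does not display. A plausible, self-contained sketch of the full retriever, assuming cosine similarity over the same mean-pooled FarsiBERT vectors that build_pdf_index returns (the actual code in app.py may differ in detail):

import numpy as np
import torch
from typing import List
from pydantic import Field
from transformers import AutoTokenizer, AutoModel
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

model_name = "HooshvareLab/bert-fa-zwnj-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

class SimpleRetriever(BaseRetriever):
    documents: List[Document] = Field(...)
    embeddings: List = Field(...)

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        # Embed the query the same way the PDF chunks were embedded:
        # mean-pool the last hidden state of FarsiBERT
        inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            query_emb = model(**inputs).last_hidden_state.mean(dim=1).numpy().ravel()

        # Cosine similarity between the query vector and every stored chunk vector
        similarities = []
        for emb in self.embeddings:
            vec = np.asarray(emb).ravel()
            score = float(np.dot(query_emb, vec) /
                          (np.linalg.norm(query_emb) * np.linalg.norm(vec) + 1e-8))
            similarities.append(score)

        # Keep the five best chunks, matching ranked_docs[:5] in the diff
        ranked_docs = sorted(zip(similarities, self.documents),
                             key=lambda pair: pair[0], reverse=True)
        return [doc for _, doc in ranked_docs[:5]]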
@@ -119,23 +109,20 @@ if prompt:
      st.session_state.pending_prompt = prompt
      st.rerun()

- # ----------------- Model response, only from the PDF -----------------
+ # ----------------- Model response -----------------
  if st.session_state.pending_prompt:
      with st.chat_message('ai'):
          thinking = st.empty()
          thinking.markdown("🤖 در حال فکر کردن از روی PDF...")

          try:
-             # Get the answer only from the PDF
              response = chain.run(f"سوال: {st.session_state.pending_prompt}")
              answer = response.strip()
-
          except Exception as e:
              answer = f"خطا در پاسخ‌دهی: {str(e)}"

          thinking.empty()

-         # Typing animation for the answer
          full_response = ""
          placeholder = st.empty()
          for word in answer.split():
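The diff stops at the for-word loop on both sides, so the end of the typing animation is not shown. A minimal sketch of how such a loop typically finishes, reusing the session-state keys from earlier in app.py; the import of time, the delay value, and the message dictionary shape are assumptions:

import time

full_response = ""
placeholder = st.empty()
for word in answer.split():
    full_response += word + " "
    placeholder.markdown(full_response)  # progressively reveal the answer
    time.sleep(0.05)                     # small pause to imitate typing

# Store the exchange in the chat history and clear the pending prompt
st.session_state.messages.append({"role": "ai", "content": answer})
st.session_state.pending_prompt = None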
 