|
import os |
|
import time |
|
import streamlit as st |
|
from langchain.chat_models import ChatOpenAI |
|
from langchain.document_loaders import PyPDFLoader |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain.schema import Document |
|
from langchain.chains import RetrievalQA |
|
from langchain_core.retrievers import BaseRetriever |
|
from langchain_core.prompts import PromptTemplate |
|
from typing import List |
|
from pydantic import Field |
|
import numpy as np |
|
from sentence_transformers import SentenceTransformer |
|
import faiss |
|
|
|
|
|
# --- Page configuration and global styling -------------------------------
# Must be the first Streamlit call on the page.
st.set_page_config(page_title="چت بات توانا", page_icon="🪖", layout="wide")

# Inject global CSS: Persian (Vazirmatn) font with RTL layout, a fixed
# background image, styled chat bubbles / inputs / buttons, and a fade-in
# animation for chat messages.
# NOTE(review): the background uses a relative URL ("./military_bg.jpeg");
# Streamlit does not serve arbitrary local files this way — presumably this
# only works when the image is reachable from the app's static path. Verify.
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Vazirmatn:wght@400;700&display=swap');
html, body, [class*="css"] {
    font-family: 'Vazirmatn', Tahoma, sans-serif;
    direction: rtl;
    text-align: right;
}
.stApp {
    background: url("./military_bg.jpeg") no-repeat center center fixed;
    background-size: cover;
    backdrop-filter: blur(2px);
}
.stChatMessage {
    background-color: rgba(255,255,255,0.8);
    border: 1px solid #4e8a3e;
    border-radius: 12px;
    padding: 16px;
    margin-bottom: 15px;
    box-shadow: 0 4px 10px rgba(0,0,0,0.2);
    animation: fadeIn 0.4s ease-in-out;
}
.stTextInput > div > input, .stTextArea textarea {
    background-color: rgba(255,255,255,0.9) !important;
    border-radius: 8px !important;
    direction: rtl;
    text-align: right;
    font-family: 'Vazirmatn', Tahoma;
}
.stButton>button {
    background-color: #4e8a3e !important;
    color: white !important;
    font-weight: bold;
    border-radius: 10px;
    padding: 8px 20px;
    transition: 0.3s;
}
.stButton>button:hover {
    background-color: #3c6d30 !important;
}
.header-text {
    text-align: center;
    margin-top: 20px;
    margin-bottom: 40px;
    background-color: rgba(255, 255, 255, 0.75);
    padding: 20px;
    border-radius: 20px;
    box-shadow: 0 4px 12px rgba(0,0,0,0.2);
}
.header-text h1 {
    font-size: 42px;
    color: #2c3e50;
    margin: 0;
    font-weight: bold;
}
.subtitle {
    font-size: 18px;
    color: #34495e;
    margin-top: 8px;
}
@keyframes fadeIn {
    from { opacity: 0; transform: translateY(10px); }
    to { opacity: 1; transform: translateY(0); }
}
</style>
""", unsafe_allow_html=True)
|
|
|
# --- Header: centered logo and title -------------------------------------
# Three columns are used purely to center the logo in the narrow middle one.
# NOTE(review): the middle column has relative width 0.1 while the image is
# 240px wide — the image likely overflows/shrinks; confirm intended layout.
col1, col2, col3 = st.columns([1, 0.1, 1])
with col2:
    st.image("army.png", width=240)

# Title card: "Tavana Chatbot — smart assistant for battlefield decisions".
st.markdown("""
<div class="header-text">
    <h1>چت بات توانا</h1>
    <div class="subtitle">دستیار هوشمند برای تصمیمگیری در میدان نبرد</div>
</div>
""", unsafe_allow_html=True)
|
|
|
|
|
|
|
@st.cache_resource
def get_pdf_index(pdf_path: str = 'test1.pdf',
                  embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    """Load a PDF, chunk it, embed the chunks and build a FAISS L2 index.

    Cached by ``st.cache_resource`` so the expensive load/embed work runs
    once per process (per argument combination).

    Args:
        pdf_path: Path of the PDF to index. Defaults to the previously
            hard-coded ``'test1.pdf'`` so existing callers are unaffected.
        embedding_model_name: SentenceTransformer checkpoint used to embed
            the chunks (queries must later be embedded with the same model).

    Returns:
        Tuple ``(docs, embeddings, index, model)`` where ``docs`` is a list
        of ``{"text": chunk}`` dicts aligned row-for-row with ``embeddings``
        and with the FAISS ids in ``index``, and ``model`` is the loaded
        SentenceTransformer.
    """
    with st.spinner('📄 در حال پردازش فایل PDF...'):
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()

        # Split each page into overlapping chunks. Page metadata is
        # intentionally dropped — only the raw chunk text is kept.
        splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=128)
        texts = []
        for doc in documents:
            texts.extend(splitter.split_text(doc.page_content))

        model = SentenceTransformer(embedding_model_name)
        embeddings = model.encode(texts, convert_to_numpy=True)

        # Exact (brute-force) L2 index — fine for a single document's worth
        # of chunks; FAISS row ids equal positions in `texts`/`docs`.
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)

        docs = [{"text": text} for text in texts]

    return docs, embeddings, index, model
|
|
|
|
|
# Chat LLM served through Together AI's OpenAI-compatible endpoint.
# SECURITY FIX: the API key was previously hard-coded in this file and must
# be treated as leaked — rotate it. The key is now read from the
# TOGETHER_API_KEY environment variable instead of being committed to source.
llm = ChatOpenAI(
    base_url="https://api.together.xyz/v1",
    api_key=os.environ.get("TOGETHER_API_KEY", ""),
    model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free"
)
|
|
|
|
|
class SimpleRetriever(BaseRetriever):
    """Dense retriever over a prebuilt FAISS index.

    Invariant: ``documents[i]['text']`` is the chunk whose embedding sits in
    row ``i`` of ``embeddings`` and under FAISS id ``i`` in ``index``.

    Attributes:
        documents: Aligned list of ``{"text": chunk}`` dicts.
        embeddings: Chunk embedding matrix (kept for reference; search goes
            through ``index``).
        index: FAISS index built over ``embeddings``.
        model: SentenceTransformer used to embed queries — must be the same
            model that produced ``embeddings``.
        k: Number of chunks to return per query (was hard-coded to 5;
            the default preserves that behavior).
    """

    documents: List[dict] = Field(...)
    embeddings: np.ndarray = Field(...)
    index: faiss.Index
    model: SentenceTransformer
    k: int = 5

    def _get_relevant_documents(self, query: str) -> List[Document]:
        """Embed *query* and return its ``k`` nearest chunks as Documents."""
        query_embedding = self.model.encode([query], convert_to_numpy=True)
        _, indices = self.index.search(query_embedding, self.k)
        # FAISS pads with -1 when the index holds fewer than k vectors;
        # previously those padded ids silently selected documents[-1].
        return [
            Document(page_content=self.documents[i]['text'])
            for i in indices[0]
            if i != -1
        ]
|
|
|
|
|
# Build (or fetch from cache) the PDF index, then wire every piece into the
# retriever. The retriever needs the same embedding model that indexed the
# chunks so queries land in the same vector space.
documents, embeddings, index, model = get_pdf_index()

_retriever_parts = {
    "documents": documents,
    "embeddings": embeddings,
    "index": index,
    "model": model,
}
retriever = SimpleRetriever(**_retriever_parts)
|
# Persian RAG prompt: instructs the model to answer ONLY from the retrieved
# context; if the exact answer is missing it should give the closest helpful
# answer from the context, and otherwise admit it lacks the information
# rather than fabricate one. Placeholders: {context} = retrieved chunks,
# {question} = the user's query.
custom_prompt = PromptTemplate(
    input_variables=["context", "question"],
    template="""
شما فقط مجاز هستید از اطلاعات زیر پاسخ دهید. اگر اطلاعات لازم برای پاسخ دقیق وجود ندارد، لطفاً تلاش کنید تا نزدیکترین و مفیدترین پاسخ را از اطلاعات موجود پیدا کنید. در صورتی که اطلاعات مورد نظر وجود ندارد، بهجای دادن جواب نادرست، بگویید که اطلاعات مورد نیاز را ندارید یا نمیتوانید پاسخ دقیقی بدهید.

اطلاعات:
{context}

سوال:
{question}

پاسخ دقیق و کامل بده:
"""
)
|
|
|
|
|
|
|
# Assemble the RAG pipeline: "stuff" concatenates all retrieved chunks into
# the custom prompt's {context} slot and issues a single LLM call.
_chain_kwargs = {"prompt": custom_prompt}
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    chain_type_kwargs=_chain_kwargs,
)
|
|
|
|
|
# First-run initialization of conversation state:
#   messages       — full chat transcript ({'role', 'content'} dicts)
#   pending_prompt — question awaiting an answer after the queued rerun
for _key, _default in (("messages", []), ("pending_prompt", None)):
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Replay the stored transcript so the conversation survives Streamlit's
# top-to-bottom reruns.
for msg in st.session_state.messages:
    role, content = msg['role'], msg['content']
    with st.chat_message(role):
        st.markdown(f"🗨️ {content}", unsafe_allow_html=True)
|
|
|
|
|
# Capture the user's question. It is stored and the app rerun immediately,
# so the question renders as history before the answer is generated.
if prompt := st.chat_input("سوالی در مورد فایل بپرس..."):
    st.session_state.messages.append({'role': 'user', 'content': prompt})
    st.session_state.pending_prompt = prompt
    st.rerun()
|
|
|
|
|
# Answer the queued question (set by the chat-input handler before rerun).
if st.session_state.pending_prompt:
    with st.chat_message('ai'):
        # Temporary "thinking…" indicator, cleared once the answer arrives.
        thinking = st.empty()
        thinking.markdown("🤖 در حال فکر کردن...")

        try:
            # NOTE(review): Chain.run is deprecated in newer LangChain —
            # qa_chain.invoke(...) is the modern equivalent; confirm the
            # installed version before migrating.
            response = qa_chain.run(st.session_state.pending_prompt)
            answer = response.strip()
        except Exception as e:
            # Boundary catch-all: surface any failure as a Persian
            # "error while answering" chat message instead of crashing.
            answer = f"خطا در پاسخدهی: {str(e)}"

        thinking.empty()

        # Typewriter effect: re-render the growing answer word by word with
        # a cursor glyph, pausing 30 ms per word.
        full_response = ""
        placeholder = st.empty()
        for word in answer.split():
            full_response += word + " "
            placeholder.markdown(full_response + "▌")
            time.sleep(0.03)

        # Final render without the cursor; note the stored message keeps a
        # trailing space from the word loop.
        placeholder.markdown(full_response)
        st.session_state.messages.append({'role': 'ai', 'content': full_response})
        # Clear the queue so the next rerun does not re-answer.
        st.session_state.pending_prompt = None
|
|