|
import os
import time
from typing import List

import numpy as np
import streamlit as st
from pydantic import Field
from sentence_transformers import SentenceTransformer

from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.schema import Document as LangchainDocument
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
|
|
|
|
|
st.set_page_config(page_title="چت بات توانا", page_icon="🪖", layout="wide") |
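# Global CSS: Vazirmatn font with RTL layout and a military-themed look for the
# chat bubbles, inputs, and buttons. Note that background: url("./military_bg.jpeg")
# is only picked up if that path is actually served by the app; Streamlit does not
# serve arbitrary local files from CSS, so embedding the image as a base64 data URI
# is a common workaround.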
|
|
|
st.markdown(""" |
|
<style> |
|
@import url('https://fonts.googleapis.com/css2?family=Vazirmatn:wght@400;700&display=swap'); |
|
html, body, [class*="css"] { |
|
font-family: 'Vazirmatn', Tahoma, sans-serif; |
|
direction: rtl; |
|
text-align: right; |
|
} |
|
.stApp { |
|
background: url("./military_bg.jpeg") no-repeat center center fixed; |
|
background-size: cover; |
|
backdrop-filter: blur(2px); |
|
} |
|
.stChatMessage { |
|
background-color: rgba(255,255,255,0.8); |
|
border: 1px solid #4e8a3e; |
|
border-radius: 12px; |
|
padding: 16px; |
|
margin-bottom: 15px; |
|
box-shadow: 0 4px 10px rgba(0,0,0,0.2); |
|
animation: fadeIn 0.4s ease-in-out; |
|
} |
|
.stTextInput > div > input, .stTextArea textarea { |
|
background-color: rgba(255,255,255,0.9) !important; |
|
border-radius: 8px !important; |
|
direction: rtl; |
|
text-align: right; |
|
font-family: 'Vazirmatn', Tahoma; |
|
} |
|
.stButton>button { |
|
background-color: #4e8a3e !important; |
|
color: white !important; |
|
font-weight: bold; |
|
border-radius: 10px; |
|
padding: 8px 20px; |
|
transition: 0.3s; |
|
} |
|
.stButton>button:hover { |
|
background-color: #3c6d30 !important; |
|
} |
|
.header-text { |
|
text-align: center; |
|
margin-top: 20px; |
|
margin-bottom: 40px; |
|
background-color: rgba(255, 255, 255, 0.75); |
|
padding: 20px; |
|
border-radius: 20px; |
|
box-shadow: 0 4px 12px rgba(0,0,0,0.2); |
|
} |
|
.header-text h1 { |
|
font-size: 42px; |
|
color: #2c3e50; |
|
margin: 0; |
|
font-weight: bold; |
|
} |
|
.subtitle { |
|
font-size: 18px; |
|
color: #34495e; |
|
margin-top: 8px; |
|
} |
|
@keyframes fadeIn { |
|
from { opacity: 0; transform: translateY(10px); } |
|
to { opacity: 1; transform: translateY(0); } |
|
} |
|
</style> |
|
""", unsafe_allow_html=True) |
|
|
|
col1, col2, col3 = st.columns([1, 1, 1]) |
|
with col2: |
|
st.image("army.png", width=240) |
|
|
|
st.markdown(""" |
|
<div class="header-text"> |
|
<h1>چت بات توانا</h1> |
|
<div class="subtitle">دستیار هوشمند برای تصمیمگیری در میدان نبرد</div> |
|
</div> |
|
""", unsafe_allow_html=True) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
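# The same embedding model is needed both for indexing the PDF and for embedding
# incoming queries, so it is loaded once per session via st.cache_resource instead
# of being re-instantiated on every call.
# 'HooshvareLab/bert-fa-zwnj-base' is a plain Persian BERT checkpoint; when loaded
# through sentence-transformers it gets a default mean-pooling head.
@st.cache_resource
def load_embedding_model():
    return SentenceTransformer('HooshvareLab/bert-fa-zwnj-base')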
|
@st.cache_resource
def build_pdf_index():
    """Load test1.pdf, split it into chunks, and embed every chunk once per session."""
    with st.spinner('📄 در حال پردازش فایل ...'):
        loader = PyPDFLoader("test1.pdf")
        pages = loader.load()

        # Split each page into overlapping ~500-character chunks so retrieved
        # context stays small enough to stuff into the prompt.
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50
        )

        texts = []
        for page in pages:
            texts.extend(splitter.split_text(page.page_content))

        documents = [LangchainDocument(page_content=t) for t in texts]

        sentence_model = load_embedding_model()

        progress_bar = st.progress(0)
        total_docs = len(documents)
        texts_to_encode = [doc.page_content for doc in documents]

        # Encode in batches and report progress. Embeddings are L2-normalised so
        # that a plain dot product in the retriever behaves as cosine similarity.
        batch_size = 128
        embeddings = []
        for i in range(0, total_docs, batch_size):
            batch_texts = texts_to_encode[i:i + batch_size]
            batch_embeddings = sentence_model.encode(
                batch_texts,
                convert_to_numpy=True,
                normalize_embeddings=True
            )
            embeddings.extend(batch_embeddings)
            progress_bar.progress(min((i + batch_size) / total_docs, 1.0))

        progress_bar.empty()
        embeddings = np.array(embeddings)

        return documents, embeddings
|
|
|
|
|
|
|
|
|
|
|
# Chat model served through Together AI's OpenAI-compatible endpoint. The API key
# is read from the TOGETHER_API_KEY environment variable (set it before running
# the app) instead of being hardcoded in the source.
llm = ChatOpenAI(
    base_url="https://api.together.xyz/v1",
    api_key=os.environ["TOGETHER_API_KEY"],
    model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free"
)
|
|
|
|
|
class SimpleRetriever(BaseRetriever):
    """Similarity retriever over the pre-computed chunk embeddings (returns the top 5 chunks)."""

    documents: List[Document] = Field(...)
    embeddings: List = Field(...)

    def _get_relevant_documents(self, query: str) -> List[Document]:
        # Reuse the cached model instead of re-loading it on every query, and
        # normalise the query embedding so the dot product is cosine similarity.
        sentence_model = load_embedding_model()
        query_embedding = sentence_model.encode(
            query, convert_to_numpy=True, normalize_embeddings=True
        )

        similarities = []
        for doc_embedding in self.embeddings:
            similarities.append(float(np.dot(query_embedding, doc_embedding)))

        # Rank chunks by similarity and keep the 5 best matches.
        ranked_docs = sorted(
            zip(similarities, self.documents),
            key=lambda x: x[0],
            reverse=True
        )
        return [doc for _, doc in ranked_docs[:5]]
|
|
|
|
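# Build (or reuse the cached) index, wire it into the similarity retriever, and wrap
# both in a "stuff"-type RetrievalQA chain: the top retrieved chunks are inserted
# verbatim into the prompt that is sent to the LLM.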
|
documents, embeddings = build_pdf_index() |
|
retriever = SimpleRetriever(documents=documents, embeddings=embeddings) |
|
|
|
|
|
chain = RetrievalQA.from_chain_type( |
|
llm=llm, |
|
retriever=retriever, |
|
chain_type="stuff", |
|
input_key="question" |
|
) |
|
|
|
|
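# Chat state: `messages` holds the rendered history, while `pending_prompt` carries a
# just-submitted question across st.rerun() so the answer is generated on the next run.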
|
if 'messages' not in st.session_state: |
|
st.session_state.messages = [] |
|
|
|
if 'pending_prompt' not in st.session_state: |
|
st.session_state.pending_prompt = None |
|
|
|
|
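# Replay the stored conversation on every rerun.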
|
for msg in st.session_state.messages: |
|
with st.chat_message(msg['role']): |
|
st.markdown(f"🗨️ {msg['content']}", unsafe_allow_html=True) |
|
|
|
|
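# New question from the user: append it to the history, mark it as pending, and rerun
# so it is rendered as part of the history before the answer is generated.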
|
prompt = st.chat_input("سوالی در مورد فایل بپرس...") |
|
|
|
if prompt: |
|
st.session_state.messages.append({'role': 'user', 'content': prompt}) |
|
st.session_state.pending_prompt = prompt |
|
st.rerun() |
|
|
|
|
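# Answer the pending question: show a temporary "thinking" indicator, query the
# RetrievalQA chain, then stream the reply word by word as a lightweight typing effect.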
|
if st.session_state.pending_prompt:
    with st.chat_message('ai'):
        thinking = st.empty()
        thinking.markdown("🤖 در حال فکر کردن ...")

        # Any failure (network, model, chain) is surfaced to the user as an error message.
        try:
            response = chain.run(f"سوال: {st.session_state.pending_prompt}")
            answer = response.strip()
        except Exception as e:
            answer = f"خطا در پاسخ‌دهی: {str(e)}"

        thinking.empty()

        full_response = ""
        placeholder = st.empty()
        for word in answer.split():
            full_response += word + " "
            placeholder.markdown(full_response + "▌")
            time.sleep(0.03)

        placeholder.markdown(full_response)
        st.session_state.messages.append({'role': 'ai', 'content': full_response})
        st.session_state.pending_prompt = None
|
|