# army / app.py
# M17idd's picture
# Update app.py
# b5be236 verified
# raw
# history blame
# 6.66 kB
import base64
import os
import re
import time
from typing import List

import streamlit as st
import streamlit as st
import torch
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings.base import Embeddings
from langchain.indexes import VectorstoreIndexCreator
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from PIL import Image
from together import Together
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModel
st.set_page_config(page_title="چت‌ بات توانا", page_icon="🪖", layout="wide")


def _image_data_uri(path: str, mime: str = "image/jpeg") -> str:
    """Return the file at *path* encoded as a ``data:`` URI, or ``""`` if it is missing.

    Streamlit does not serve arbitrary files from the app directory, so a
    relative ``url("file.jpeg")`` inside injected CSS is never fetched by the
    browser. Inlining the bytes as base64 makes the image available.

    Args:
        path: Local path of the image file.
        mime: MIME type placed in the URI.
    """
    try:
        with open(path, "rb") as fh:
            payload = base64.b64encode(fh.read()).decode("ascii")
    except FileNotFoundError:
        return ""
    return f"data:{mime};base64,{payload}"


# Global styling: Vazirmatn font, right-to-left layout, translucent chat cards,
# green buttons and the header card.
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Vazirmatn:wght@400;700&display=swap');
html, body, [class*="css"] {
font-family: 'Vazirmatn', Tahoma, sans-serif;
direction: rtl;
text-align: right;
}
.stApp {
backdrop-filter: blur(2px);
}
.stChatMessage {
background-color: rgba(255,255,255,0.8);
border: 1px solid #4e8a3e;
border-radius: 12px;
padding: 16px;
margin-bottom: 15px;
box-shadow: 0 4px 10px rgba(0,0,0,0.2);
animation: fadeIn 0.4s ease-in-out;
}
.stTextInput > div > input, .stTextArea textarea {
background-color: rgba(255,255,255,0.9) !important;
border-radius: 8px !important;
direction: rtl;
text-align: right;
font-family: 'Vazirmatn', Tahoma;
}
.stButton>button {
background-color: #4e8a3e !important;
color: white !important;
font-weight: bold;
border-radius: 10px;
padding: 8px 20px;
transition: 0.3s;
}
.stButton>button:hover {
background-color: #3c6d30 !important;
}
.header-text {
text-align: center;
margin-top: 20px;
margin-bottom: 40px;
background-color: rgba(255, 255, 255, 0.75);
padding: 20px;
border-radius: 20px;
box-shadow: 0 4px 12px rgba(0,0,0,0.2);
}
.header-text h1 {
font-size: 42px;
color: #2c3e50;
margin: 0;
font-weight: bold;
}
.subtitle {
font-size: 18px;
color: #34495e;
margin-top: 8px;
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(10px); }
to { opacity: 1; transform: translateY(0); }
}
</style>
""", unsafe_allow_html=True)

# Background image. It is inlined as a data URI so the browser can actually
# load it, and it is skipped quietly when the file is not present.
_bg_uri = _image_data_uri("military_bg.jpeg")
if _bg_uri:
    st.markdown(
        f'<style>.stApp {{ background: url("{_bg_uri}") no-repeat center center fixed; '
        'background-size: cover; }</style>',
        unsafe_allow_html=True,
    )
# Logo: centred by placing it in the middle of three equal-width columns.
col1, col2, col3 = st.columns([1, 1, 1])
with col2:
    try:
        image = Image.open("army.png")
    except FileNotFoundError:
        # Tell the user (in Persian) that army.png must sit next to the app file.
        st.error("📁 فایل 'army.png' پیدا نشد. مطمئن شو کنار فایل اصلی Streamlit هست.")
    else:
        st.image(image, width=240)

# Title card; the classes are styled by the CSS injected above.
_HEADER_HTML = """
<div class="header-text">
<h1>چت‌ بات توانا</h1>
<div class="subtitle">دستیار هوشمند برای تصمیم‌گیری در میدان نبرد</div>
</div>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
from transformers import AutoTokenizer, AutoModel
class HuggingFaceEmbeddings(Embeddings):
    """LangChain ``Embeddings`` backed by a Hugging Face encoder and mean pooling.

    Texts are tokenized, run through ``AutoModel``, and the final hidden states
    are averaged over the real (non-padding) tokens to give one vector per text.
    """

    def __init__(self, model_name: str, batch_size: int = 16):
        """Load the tokenizer and model for *model_name*.

        Args:
            model_name: Hub id or local path passed to ``from_pretrained``.
            batch_size: Number of texts encoded per forward pass.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.batch_size = batch_size

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one mean-pooled embedding per input text, in input order."""
        embeddings: List[List[float]] = []
        for start in range(0, len(texts), self.batch_size):
            batch = list(texts[start:start + self.batch_size])
            inputs = self.tokenizer(batch, return_tensors="pt", truncation=True, padding=True)
            with torch.no_grad():
                hidden = self.model(**inputs).last_hidden_state
            # Zero out padding positions so shorter texts in the batch are
            # averaged only over their own tokens (matches unbatched pooling).
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            summed = (hidden * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1)
            embeddings.extend((summed / counts).tolist())
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string with the same pooling as documents."""
        return self.embed_documents([text])[0]
@st.cache_resource
def get_pdf_index(
    pdf_path: str = 'test1.pdf',
    model_name: str = "FacebookAI/xlm-roberta-large",
    chunk_size: int = 1024,
    chunk_overlap: int = 128,
):
    """Build a FAISS vector store over the text of a PDF.

    The result is cached by Streamlit per distinct argument set, so the PDF is
    parsed and embedded only once per process.

    Args:
        pdf_path: PDF file to index.
        model_name: Hugging Face encoder used for the embeddings.
        chunk_size: Maximum characters per text chunk.
        chunk_overlap: Characters shared between consecutive chunks.

    Returns:
        The ``FAISS`` vector store built from the chunks.

    Raises:
        ValueError: If no text could be extracted from the PDF (for example a
            scanned, image-only file). Without this check, FAISS fails with an
            opaque error on an empty chunk list.
    """
    with st.spinner('📄 در حال پردازش فایل PDF...'):
        pages = PyPDFLoader(pdf_path).load()
        # Concatenate all pages so chunks may span page boundaries.
        full_text = "\n".join(page.page_content for page in pages)
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        texts = splitter.split_text(full_text)
        if not texts:
            raise ValueError(f"No extractable text found in PDF {pdf_path!r}")
        embeddings = HuggingFaceEmbeddings(model_name=model_name)
        return FAISS.from_texts(texts, embeddings)
# The Together AI key is read from the environment (e.g. a deployment secret)
# instead of being committed to source.
# NOTE(review): a key was previously hard-coded here; treat it as leaked and revoke it.
together_api_key = os.environ.get("TOGETHER_API_KEY")
if not together_api_key:
    # Fail fast, before the expensive PDF indexing below.
    st.error("🔑 متغیر محیطی TOGETHER_API_KEY تنظیم نشده است.")
    st.stop()

index = get_pdf_index()

# Together AI exposes an OpenAI-compatible endpoint, so the OpenAI chat client is reused.
llm = ChatOpenAI(
    base_url="https://api.together.xyz/v1",
    api_key=together_api_key,
    model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
)

# `index` is the FAISS store itself (get_pdf_index returns FAISS.from_texts(...)),
# so the retriever is taken from it directly; it has no `.vectorstore` attribute.
chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type='stuff',
    retriever=index.as_retriever(),
    input_key='question',
)
# Session state: `messages` is the chat history to render; `pending_prompt` is
# a submitted question that has not been answered yet.
if 'messages' not in st.session_state:
    st.session_state.messages = []
if 'pending_prompt' not in st.session_state:
    st.session_state.pending_prompt = None

# Replay the history. The content comes from the user and the LLM, so it is
# rendered as Markdown only, never as raw HTML.
for msg in st.session_state.messages:
    with st.chat_message(msg['role']):
        st.markdown(f"🗨️ {msg['content']}")

prompt = st.chat_input("چطور می‌تونم کمک کنم؟")
if prompt:
    # Record the question, then rerun so it is drawn in the history before
    # the (slow) answer is generated on the next pass.
    st.session_state.messages.append({'role': 'user', 'content': prompt})
    st.session_state.pending_prompt = prompt
    st.rerun()

if st.session_state.pending_prompt:
    with st.chat_message('ai'):
        thinking = st.empty()
        thinking.markdown("🤖 در حال فکر کردن...")
        try:
            response = chain.run(f'لطفاً فقط به زبان فارسی پاسخ بده: {st.session_state.pending_prompt}')
        except Exception as err:  # UI boundary: report network/API failures instead of crashing
            thinking.empty()
            # Clear the prompt so the failing request is not re-sent on every rerun.
            st.session_state.pending_prompt = None
            st.error(f"⚠️ خطا در دریافت پاسخ: {err}")
            st.stop()
        # Keep only the text after the "Helpful Answer:" marker, if the model emits one.
        answer = response.split("Helpful Answer:")[-1].strip()
        if not answer:
            answer = "متأسفم، اطلاعات دقیقی در این مورد ندارم."
        thinking.empty()

        # Typewriter effect. Split on whitespace but keep the separators, so
        # newlines (paragraphs, Markdown lists) survive in the final text.
        placeholder = st.empty()
        shown = ""
        for piece in re.split(r"(\s+)", answer):
            shown += piece
            if piece.strip():
                placeholder.markdown(shown + "▌")
                time.sleep(0.03)
        placeholder.markdown(shown)
    st.session_state.messages.append({'role': 'ai', 'content': shown})
    st.session_state.pending_prompt = None