# pure_RAG / app.py
import gradio as gr
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
from langchain.prompts import ChatPromptTemplate
from dotenv import load_dotenv
import os
from datetime import datetime
from skyfield.api import load
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
# Load environment variables
load_dotenv()
DATA_PATH = "" # Directory containing the source markdown file ("" = current working directory)
PROMPT_TEMPLATE = """
Ответь на вопрос, используя только следующий контекст:
{context}
---
Ответь на вопрос на основе приведенного контекста: {question}
"""
# Global variable for status
status_message = "Инициализация..."
# Translation dictionaries
classification_ru = {
'Swallowed': 'проглоченная',
'Tiny': 'сверхмалая',
'Small': 'малая',
'Normal': 'нормальная',
'Ideal': 'идеальная',
'Big': 'большая'
}
planet_ru = {
'Sun': 'Солнце',
'Moon': 'Луна',
'Mercury': 'Меркурий',
'Venus': 'Венера',
'Mars': 'Марс',
'Jupiter': 'Юпитер',
'Saturn': 'Сатурн'
}
planet_symbols = {
'Sun': '☉',
'Moon': '☾',
'Mercury': '☿',
'Venus': '♀',
'Mars': '♂',
'Jupiter': '♃',
'Saturn': '♄'
}
def initialize_vectorstore():
"""Initialize the FAISS vector store for document retrieval."""
global status_message
try:
status_message = "Загрузка и обработка документов..."
documents = load_documents()
chunks = split_text(documents)
status_message = "Создание векторной базы..."
vectorstore = save_to_faiss(chunks)
status_message = "База данных готова к использованию."
return vectorstore
except Exception as e:
status_message = f"Ошибка инициализации: {str(e)}"
raise
def load_documents():
"""Load documents from the specified file path."""
file_path = os.path.join(DATA_PATH, "pl250320252.md")
if not os.path.exists(file_path):
raise FileNotFoundError(f"Файл {file_path} не найден")
loader = UnstructuredMarkdownLoader(file_path)
return loader.load()
def split_text(documents: list[Document]):
"""Split documents into chunks for vectorization."""
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=900,
chunk_overlap=300,
length_function=len,
add_start_index=True,
)
return text_splitter.split_documents(documents)
def save_to_faiss(chunks: list[Document]):
"""Save document chunks to a FAISS vector store."""
embeddings = HuggingFaceEmbeddings(
model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
model_kwargs={'device': 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
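    # normalize_embeddings=True keeps the FAISS scores roughly cosine-based, so the
    # relevance scores returned by similarity_search_with_relevance_scores stay meaningful.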
return FAISS.from_documents(chunks, embeddings)
def process_query(query_text: str, vectorstore):
"""Process a query using the RAG system."""
if vectorstore is None:
return "База данных не инициализирована", []
try:
results = vectorstore.similarity_search_with_relevance_scores(query_text, k=3)
global status_message
status_message += f"\nНайдено {len(results)} результатов"
if not results:
return "Не найдено результатов.", []
context_text = "\n\n---\n\n".join([
f"Релевантность: {score:.2f}\n{doc.page_content}"
for doc, score in results
])
prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
prompt = prompt_template.format(context=context_text, question=query_text)
model = HuggingFaceEndpoint(
endpoint_url="https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud/",
task="text2text-generation",
# huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"), # Include if token is required
model_kwargs={"temperature": 0.5, "max_length": 512}
)
response_text = model.invoke(prompt)
sources = list(set([doc.metadata.get("source", "") for doc, _ in results]))
return response_text, sources
except Exception as e:
return f"Ошибка обработки запроса: {str(e)}", []
def PLadder_ZSizes(date_time_iso: str):
"""
Calculate the planetary ladder and zone sizes for a given date and time.
Args:
date_time_iso (str): Date and time in ISO format (e.g., '2023-10-10T12:00:00')
Returns:
dict: Contains 'PLadder' (list of planets) and 'ZSizes' (list of zone sizes with classifications)
or an error message if unsuccessful
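    Example (illustrative):
        PLadder_ZSizes('2023-10-10T12:00:00')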
"""
try:
dt = datetime.fromisoformat(date_time_iso)
if dt.year < 1900 or dt.year > 2050:
return {"error": "Дата вне диапазона. Должна быть между 1900 и 2050 годами."}
# Load ephemeris
planets = load('de421.bsp')
earth = planets['earth']
# Define planet objects
planet_objects = {
'Sun': planets['sun'],
'Moon': planets['moon'],
'Mercury': planets['mercury'],
'Venus': planets['venus'],
'Mars': planets['mars'],
'Jupiter': planets['jupiter barycenter'],
'Saturn': planets['saturn barycenter']
}
# Create time object
ts = load.timescale()
t = ts.utc(dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second)
# Compute ecliptic longitudes
longitudes = {}
for planet in planet_objects:
apparent = earth.at(t).observe(planet_objects[planet]).apparent()
_, lon, _ = apparent.ecliptic_latlon()
longitudes[planet] = lon.degrees
# Sort planets by longitude to form PLadder
sorted_planets = sorted(longitudes.items(), key=lambda x: x[1])
PLadder = [p for p, _ in sorted_planets]
sorted_lons = [lon for _, lon in sorted_planets]
# Calculate zone sizes
zone_sizes = [sorted_lons[0]] + [sorted_lons[i+1] - sorted_lons[i] for i in range(6)] + [360 - sorted_lons[6]]
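        # zone_sizes has 8 entries: 0° up to the first planet, the six gaps
        # between consecutive planets, and the last planet back to 360°.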
# Determine bordering planets for classification
bordering = [[PLadder[0]]] + [[PLadder[i-1], PLadder[i]] for i in range(1, 7)] + [[PLadder[6]]]
# Classify each zone
ZSizes = []
for i, size in enumerate(zone_sizes):
bord = bordering[i]
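            # X is the "Tiny" threshold in degrees: 7 if a luminary (Sun/Moon)
            # borders the zone, 6 for Mercury/Venus/Mars, otherwise 5.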
if any(p in ['Sun', 'Moon'] for p in bord):
X = 7
elif any(p in ['Mercury', 'Venus', 'Mars'] for p in bord):
X = 6
else:
X = 5
if size <= 1:
classification = 'Swallowed'
elif size <= X:
classification = 'Tiny'
elif size <= 40:
classification = 'Small'
elif size < 60:
if 50 <= size <= 52:
classification = 'Ideal'
else:
classification = 'Normal'
else:
classification = 'Big'
# Convert size to degrees and minutes
d = int(size)
m = int((size - d) * 60)
size_str = f"{d}°{m}'"
ZSizes.append((size_str, classification))
return {'PLadder': PLadder, 'ZSizes': ZSizes}
except ValueError:
return {"error": "Неверный формат даты и времени. Используйте ISO формат, например, '2023-10-10T12:00:00'"}
except Exception as e:
return {"error": f"Ошибка при вычислении: {str(e)}"}
def plot_pladder(PLadder):
"""
Plot the planetary ladder as a right triangle with planet symbols.
Args:
PLadder (list): List of planet names in order
Returns:
matplotlib.figure.Figure: The generated plot
"""
fig, ax = plt.subplots()
# Draw triangle with vertices (0,0), (0,3), (3,0)
ax.plot([0, 0, 3, 0], [0, 3, 0, 0], 'k-')
# Draw horizontal lines dividing height into three equal parts
ax.plot([0, 3], [1, 1], 'k--')
ax.plot([0, 3], [2, 2], 'k--')
# Define positions for planets 1 to 7
positions = [(0, 0), (0, 1), (0, 2), (0, 3), (1, 2), (2, 1), (3, 0)]
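    # The first four planets climb the vertical leg; the last three descend along the hypotenuse.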
for i, pos in enumerate(positions):
symbol = planet_symbols[PLadder[i]]
ax.text(pos[0], pos[1], symbol, ha='center', va='center', fontsize=12)
ax.set_xlim(-0.5, 3.5)
ax.set_ylim(-0.5, 3.5)
ax.set_aspect('equal')
ax.axis('off')
return fig
def chat_interface(query_text):
"""
Handle user queries, either for planetary ladder or general RAG questions.
Args:
query_text (str): User's input query
Returns:
tuple: (text response, plot figure or None)
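    Example (illustrative):
        chat_interface("PLadder 2023-10-10T12:00:00")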
"""
global status_message
try:
vectorstore = initialize_vectorstore()
if query_text.startswith("PLadder "):
# Extract date and time from query
date_time_iso = query_text.split(" ", 1)[1]
result = PLadder_ZSizes(date_time_iso)
if "error" in result:
return result["error"], None
PLadder = result["PLadder"]
ZSizes = result["ZSizes"]
# Translate to Russian
PLadder_ru = [planet_ru[p] for p in PLadder]
ZSizes_ru = [(size_str, classification_ru[classification]) for size_str, classification in ZSizes]
# Prepare queries and get responses
responses = []
for i in range(7):
planet = PLadder_ru[i]
size_str, class_ru = ZSizes_ru[i]
query = f"Что значит {planet} на {i+1}-й ступени и {size_str} {class_ru} {i+1}-я зона?"
response, _ = process_query(query, vectorstore)
responses.append(f"Интерпретация для {i+1}-й ступени и {i+1}-й зоны: {response}")
# Query for 8th zone
size_str, class_ru = ZSizes_ru[7]
query = f"Что значит {size_str} {class_ru} восьмая зона?"
response, _ = process_query(query, vectorstore)
responses.append(f"Интерпретация для 8-й зоны: {response}")
            # Compile response text
            text = "Планетарная лестница: " + ", ".join(PLadder_ru) + "\n"
            text += "Размеры зон:\n" + "\n".join([
                f"Зона {i+1}: {size_str} {class_ru}"
                for i, (size_str, class_ru) in enumerate(ZSizes_ru)
            ]) + "\n\n"
            text += "\n".join(responses)
            # Generate the plot and convert it to a PIL image for Gradio
            fig = plot_pladder(PLadder)
            buf = BytesIO()
            fig.savefig(buf, format='png')  # Save figure to buffer as PNG
            buf.seek(0)
            img = Image.open(buf)  # Convert buffer contents to a PIL image
            plt.close(fig)  # Close the figure to free memory
            return text, img
else:
# Handle regular RAG query
response, sources = process_query(query_text, vectorstore)
full_response = f"{status_message}\n\nОтвет: {response}\n\nИсточники: {', '.join(sources) if sources else 'Нет источников'}"
return full_response, None
except Exception as e:
return f"Критическая ошибка: {str(e)}", None
# Define Gradio Interface
interface = gr.Interface(
fn=chat_interface,
inputs=gr.Textbox(lines=2, placeholder="Введите ваш вопрос здесь..."),
outputs=[gr.Textbox(), gr.Image()],
title="Чат с документами",
description="Задайте вопрос, и я отвечу на основе загруженных документов. "
"Для запроса планетарной лестницы используйте формат: PLadder YYYY-MM-DDTHH:MM:SS"
)
if __name__ == "__main__":
interface.launch()