Update app.py
Browse files
app.py
CHANGED
@@ -2,15 +2,17 @@ import gradio as gr
|
|
2 |
from langchain_community.document_loaders import UnstructuredMarkdownLoader
|
3 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
4 |
from langchain_core.documents import Document
|
5 |
-
from langchain_huggingface import HuggingFaceEmbeddings
|
6 |
from langchain_community.vectorstores import FAISS
|
7 |
-
from langchain_community.llms import HuggingFaceHub
|
8 |
from langchain.prompts import ChatPromptTemplate
|
9 |
from dotenv import load_dotenv
|
10 |
import os
|
11 |
from datetime import datetime
|
12 |
from skyfield.api import load
|
13 |
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
14 |
|
15 |
# Load environment variables
|
16 |
load_dotenv()
|
@@ -121,12 +123,13 @@ def process_query(query_text: str, vectorstore):
|
|
121 |
prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
|
122 |
prompt = prompt_template.format(context=context_text, question=query_text)
|
123 |
|
124 |
-
model =
|
125 |
-
|
126 |
task="text2text-generation",
|
|
|
127 |
model_kwargs={"temperature": 0.5, "max_length": 512}
|
128 |
)
|
129 |
-
response_text = model.
|
130 |
|
131 |
sources = list(set([doc.metadata.get("source", "") for doc, _ in results]))
|
132 |
return response_text, sources
|
@@ -297,6 +300,12 @@ def chat_interface(query_text):
|
|
297 |
|
298 |
# Generate plot
|
299 |
fig = plot_pladder(PLadder)
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
|
301 |
# Compile response text
|
302 |
text = "Планетарная лестница: " + ", ".join(PLadder_ru) + "\n"
|
|
|
2 |
from langchain_community.document_loaders import UnstructuredMarkdownLoader
|
3 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
4 |
from langchain_core.documents import Document
|
5 |
+
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
|
6 |
from langchain_community.vectorstores import FAISS
|
|
|
7 |
from langchain.prompts import ChatPromptTemplate
|
8 |
from dotenv import load_dotenv
|
9 |
import os
|
10 |
from datetime import datetime
|
11 |
from skyfield.api import load
|
12 |
import matplotlib.pyplot as plt
|
13 |
+
from io import BytesIO
|
14 |
+
from PIL import Image
|
15 |
+
|
16 |
|
17 |
# Load environment variables
|
18 |
load_dotenv()
|
|
|
123 |
prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
|
124 |
prompt = prompt_template.format(context=context_text, question=query_text)
|
125 |
|
126 |
+
model = HuggingFaceEndpoint(
|
127 |
+
endpoint_url="https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud/",
|
128 |
task="text2text-generation",
|
129 |
+
# huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"), # Include if token is required
|
130 |
model_kwargs={"temperature": 0.5, "max_length": 512}
|
131 |
)
|
132 |
+
response_text = model.invoke(prompt)
|
133 |
|
134 |
sources = list(set([doc.metadata.get("source", "") for doc, _ in results]))
|
135 |
return response_text, sources
|
|
|
300 |
|
301 |
# Generate plot
|
302 |
fig = plot_pladder(PLadder)
|
303 |
+
buf = BytesIO()
|
304 |
+
fig.savefig(buf, format='png') # Save figure to buffer as PNG
|
305 |
+
buf.seek(0)
|
306 |
+
img = Image.open(buf) # Convert to PIL image
|
307 |
+
plt.close(fig) # Close the figure to free memory
|
308 |
+
return text, img
|
309 |
|
310 |
# Compile response text
|
311 |
text = "Планетарная лестница: " + ", ".join(PLadder_ru) + "\n"
|