SergeyO7 commited on
Commit
c914465
·
verified ·
1 Parent(s): 70057d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -5
app.py CHANGED
@@ -2,15 +2,17 @@ import gradio as gr
2
  from langchain_community.document_loaders import UnstructuredMarkdownLoader
3
  from langchain.text_splitter import RecursiveCharacterTextSplitter
4
  from langchain_core.documents import Document
5
- from langchain_huggingface import HuggingFaceEmbeddings
6
  from langchain_community.vectorstores import FAISS
7
- from langchain_community.llms import HuggingFaceHub
8
  from langchain.prompts import ChatPromptTemplate
9
  from dotenv import load_dotenv
10
  import os
11
  from datetime import datetime
12
  from skyfield.api import load
13
  import matplotlib.pyplot as plt
 
 
 
14
 
15
  # Load environment variables
16
  load_dotenv()
@@ -121,12 +123,13 @@ def process_query(query_text: str, vectorstore):
121
  prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
122
  prompt = prompt_template.format(context=context_text, question=query_text)
123
 
124
- model = HuggingFaceHub(
125
- repo_id="https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud/",
126
  task="text2text-generation",
 
127
  model_kwargs={"temperature": 0.5, "max_length": 512}
128
  )
129
- response_text = model.predict(prompt)
130
 
131
  sources = list(set([doc.metadata.get("source", "") for doc, _ in results]))
132
  return response_text, sources
@@ -297,6 +300,12 @@ def chat_interface(query_text):
297
 
298
  # Generate plot
299
  fig = plot_pladder(PLadder)
 
 
 
 
 
 
300
 
301
  # Compile response text
302
  text = "Планетарная лестница: " + ", ".join(PLadder_ru) + "\n"
 
2
  from langchain_community.document_loaders import UnstructuredMarkdownLoader
3
  from langchain.text_splitter import RecursiveCharacterTextSplitter
4
  from langchain_core.documents import Document
5
+ from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
6
  from langchain_community.vectorstores import FAISS
 
7
  from langchain.prompts import ChatPromptTemplate
8
  from dotenv import load_dotenv
9
  import os
10
  from datetime import datetime
11
  from skyfield.api import load
12
  import matplotlib.pyplot as plt
13
+ from io import BytesIO
14
+ from PIL import Image
15
+
16
 
17
  # Load environment variables
18
  load_dotenv()
 
123
  prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
124
  prompt = prompt_template.format(context=context_text, question=query_text)
125
 
126
+ model = HuggingFaceEndpoint(
127
+ endpoint_url="https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud/",
128
  task="text2text-generation",
129
+ # huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"), # Include if token is required
130
  model_kwargs={"temperature": 0.5, "max_length": 512}
131
  )
132
+ response_text = model.invoke(prompt)
133
 
134
  sources = list(set([doc.metadata.get("source", "") for doc, _ in results]))
135
  return response_text, sources
 
300
 
301
  # Generate plot
302
  fig = plot_pladder(PLadder)
303
+ buf = BytesIO()
304
+ fig.savefig(buf, format='png') # Save figure to buffer as PNG
305
+ buf.seek(0)
306
+ img = Image.open(buf) # Convert to PIL image
307
+ plt.close(fig) # Close the figure to free memory
308
+ return text, img
309
 
310
  # Compile response text
311
  text = "Планетарная лестница: " + ", ".join(PLadder_ru) + "\n"