SergeyO7 committed
Commit bfe504a · verified · Parent: 6911e00

Update app.py

Files changed (1)
1. app.py (+12, -4)
app.py CHANGED
@@ -11,6 +11,8 @@ import os
 from datetime import datetime
 from skyfield.api import load
 import matplotlib.pyplot as plt
+from io import BytesIO
+from PIL import Image
 
 # Load environment variables
 load_dotenv()
@@ -121,12 +123,13 @@ def process_query(query_text: str, vectorstore):
     prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
     prompt = prompt_template.format(context=context_text, question=query_text)
 
-    model = HuggingFaceHub(
-        repo_id="https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud/",
+    model = HuggingFaceEndpoint(
+        endpoint_url="https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud/",
         task="text2text-generation",
+        # huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
         model_kwargs={"temperature": 0.5, "max_length": 512}
     )
-    response_text = model.predict(prompt)
+    response_text = model.invoke(prompt)
 
     sources = list(set([doc.metadata.get("source", "") for doc, _ in results]))
     return response_text, sources
@@ -297,13 +300,18 @@ def chat_interface(query_text):
 
         # Generate plot
         fig = plot_pladder(PLadder)
+        buf = BytesIO()
+        fig.savefig(buf, format='png')  # Save figure to buffer as PNG
+        buf.seek(0)
+        img = Image.open(buf)  # Convert to PIL image
+        plt.close(fig)  # Close the figure to free memory
 
         # Compile response text
         text = "Планетарная лестница: " + ", ".join(PLadder_ru) + "\n"
         text += "Размеры зон:\n" + "\n".join([f"Зона {i+1}: {size_str} {class_ru}"
                                               for i, (size_str, class_ru) in enumerate(ZSizes_ru)]) + "\n\n"
         text += "\n".join(responses)
-        return text, fig
+        return text, img
 
     else:
         # Handle regular RAG query
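The model change swaps the HuggingFaceHub wrapper, which had been pointed at a dedicated Inference Endpoint via repo_id, for HuggingFaceEndpoint, whose endpoint_url parameter is intended for exactly that, and moves from predict() to the Runnable-style invoke(). A minimal standalone sketch of the new call pattern, assuming the class is imported from langchain_community.llms (the import line is outside these hunks), that the token comes from the HUGGINGFACEHUB_API_TOKEN environment variable as the commented-out argument suggests, and using a placeholder endpoint URL:

import os

from langchain_community.llms import HuggingFaceEndpoint  # assumed import path

# Placeholder URL for illustration; substitute your own dedicated endpoint.
ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud/"

model = HuggingFaceEndpoint(
    endpoint_url=ENDPOINT_URL,
    task="text2text-generation",
    huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
    model_kwargs={"temperature": 0.5, "max_length": 512},
)

# invoke() is the Runnable-style replacement for the older predict() call.
response_text = model.invoke("Briefly describe the planetary ladder.")
print(response_text)

Note that, depending on the installed LangChain version, generation settings such as temperature may need to be passed as explicit constructor arguments rather than inside model_kwargs.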
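The figure-handling change renders the Matplotlib figure to an in-memory PNG and returns it as a PIL image, presumably so the chat interface can display a plain image object, and closes the figure so repeated queries do not accumulate open figures. A self-contained sketch of that conversion, using a hypothetical fig_to_pil helper that wraps the same buffer steps as the diff:

from io import BytesIO

import matplotlib
matplotlib.use("Agg")  # headless backend, typical for server apps (assumption, not shown in the diff)
import matplotlib.pyplot as plt
from PIL import Image

def fig_to_pil(fig: plt.Figure) -> Image.Image:
    """Render a Matplotlib figure to an in-memory PNG and return it as a PIL image."""
    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)        # rewind so PIL reads from the start of the buffer
    img = Image.open(buf)
    img.load()         # force a full read before the buffer goes out of scope
    plt.close(fig)     # free the figure; avoids leaks across requests
    return img

# Usage: any figure works; in app.py the figure comes from plot_pladder(PLadder).
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 9])
img = fig_to_pil(fig)
print(img.size, img.mode)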