Update app.py
app.py CHANGED
@@ -235,7 +235,7 @@ import time
 from fastapi import FastAPI, Request
 from fastapi.responses import HTMLResponse
 from fastapi.staticfiles import StaticFiles
-from llama_index.core import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader,
+from llama_index.core import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader, PromptTemplate, Settings
 from llama_index.llms.huggingface import HuggingFaceLLM
 from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 from pydantic import BaseModel
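This fixes the truncated import (the trailing comma was a SyntaxError) and adds the PromptTemplate and Settings names used further down. Since SimpleDirectoryReader and VectorStoreIndex are imported here, the app presumably builds and persists the index it later reloads; a minimal sketch of that step, assuming a "data" source folder and a PERSIST_DIR constant (neither is shown in this diff):

# Hypothetical index-build step (likely what initialize() does elsewhere in app.py).
# The "data" folder and PERSIST_DIR value are assumptions, not taken from this diff.
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex

PERSIST_DIR = "db"  # assumption: app.py defines its own persist directory

def build_index():
    # Load raw documents, embed them into a vector index, then persist to disk
    documents = SimpleDirectoryReader("data").load_data()
    index = VectorStoreIndex.from_documents(documents)
    index.storage_context.persist(persist_dir=PERSIST_DIR)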
@@ -244,7 +244,7 @@ import datetime
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.templating import Jinja2Templates
 from simple_salesforce import Salesforce, SalesforceLogin
-from transformers import AutoModelForSeq2SeqLM
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

 # Define Pydantic model for incoming request body
 class MessageRequest(BaseModel):
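AutoTokenizer joins the import so the tokenizer can be built once and handed to HuggingFaceLLM below. A standalone sanity check (not part of app.py) that the flan-t5-small model/tokenizer pair loads and generates; the prompt is an arbitrary example:

# Standalone smoke test for the seq2seq pair; runnable outside the app.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

inputs = tokenizer("Summarize: FastAPI serves the chatbot UI.", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))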
@@ -288,6 +288,7 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
 templates = Jinja2Templates(directory="static")

 # Configure Llama index settings
+tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
 Settings.llm = HuggingFaceLLM(
     model_name="google/flan-t5-small",
     tokenizer_name="google/flan-t5-small",
@@ -295,6 +296,7 @@ Settings.llm = HuggingFaceLLM(
     max_new_tokens=256,
     generate_kwargs={"temperature": 0.1, "do_sample": True},
     model=AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small"),
+    tokenizer=tokenizer,
     device_map="auto"  # Automatically use GPU if available, else CPU
 )
 Settings.embed_model = HuggingFaceEmbedding(
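Passing model= and tokenizer= explicitly matters here: flan-t5 is an encoder-decoder model, and left to load from model_name alone the HuggingFaceLLM wrapper would fall back to a causal-LM loader that does not fit T5. Once configured, the wrapper can be exercised through the generic llama_index completion interface; a minimal check (the prompt is invented):

# Hypothetical smoke test after Settings.llm is configured; .complete() is the
# generic llama_index LLM entry point, so this exercises the wrapper end to end.
resp = Settings.llm.complete("Answer briefly: what is Redfernstech?")
print(resp.text)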
@@ -341,18 +343,15 @@ def split_name(full_name):
 initialize()  # Run initialization tasks

 def handle_query(query):
-    ...
-        )
-    ]
-    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
+    # Custom prompt template for flan-t5-small (no chat template)
+    text_qa_template = PromptTemplate(
+        """
+        You are Clara, a Redfernstech chatbot. Provide accurate, concise answers (10-15 words) based on company data.
+        Context: {context_str}
+        Question: {query_str}
+        Answer:
+        """
+    )

     storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
     index = load_index_from_storage(storage_context)
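flan-t5-small ships no chat template, so the chat-message construction built with ChatPromptTemplate.from_messages is swapped for a plain completion-style PromptTemplate; the query engine fills {context_str} with retrieved chunks and {query_str} with the user question at answer time. A quick way to preview the exact string the model will see (both field values here are invented):

# Preview the rendered prompt; context and question are made-up examples.
preview = text_qa_template.format(
    context_str="Redfernstech provides Salesforce consulting and integrations.",
    query_str="What does Redfernstech do?",
)
print(preview)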
@@ -361,7 +360,7 @@ def handle_query(query):
     if past_query.strip():
         context_str += f"User asked: '{past_query}'\nBot answered: '{response}'\n"

-    query_engine = index.as_query_engine(text_qa_template=text_qa_template
+    query_engine = index.as_query_engine(text_qa_template=text_qa_template)
     answer = query_engine.query(query)

     if hasattr(answer, "response"):
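The old call was missing its closing parenthesis (another SyntaxError); the fix also wires the new template into the engine. A sketch of the repaired call path in isolation, assuming the persisted index exists (the question is an invented example):

# Isolated replay of the repaired path; PERSIST_DIR and the question are assumed.
storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
index = load_index_from_storage(storage_context)
query_engine = index.as_query_engine(text_qa_template=text_qa_template)
answer = query_engine.query("What services does Redfernstech offer?")
print(answer.response if hasattr(answer, "response") else str(answer))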
|