redfernstech committed on
Commit
4cfe99e
·
verified ·
1 Parent(s): 60a19b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -15
app.py CHANGED
@@ -235,7 +235,7 @@ import time
235
  from fastapi import FastAPI, Request
236
  from fastapi.responses import HTMLResponse
237
  from fastapi.staticfiles import StaticFiles
238
- from llama_index.core import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader, ChatPromptTemplate, Settings
239
  from llama_index.llms.huggingface import HuggingFaceLLM
240
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
241
  from pydantic import BaseModel
@@ -244,7 +244,7 @@ import datetime
244
  from fastapi.middleware.cors import CORSMiddleware
245
  from fastapi.templating import Jinja2Templates
246
  from simple_salesforce import Salesforce, SalesforceLogin
247
- from transformers import AutoModelForSeq2SeqLM
248
 
249
  # Define Pydantic model for incoming request body
250
  class MessageRequest(BaseModel):
@@ -288,6 +288,7 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
288
  templates = Jinja2Templates(directory="static")
289
 
290
  # Configure Llama index settings
 
291
  Settings.llm = HuggingFaceLLM(
292
  model_name="google/flan-t5-small",
293
  tokenizer_name="google/flan-t5-small",
@@ -295,6 +296,7 @@ Settings.llm = HuggingFaceLLM(
295
  max_new_tokens=256,
296
  generate_kwargs={"temperature": 0.1, "do_sample": True},
297
  model=AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small"),
 
298
  device_map="auto" # Automatically use GPU if available, else CPU
299
  )
300
  Settings.embed_model = HuggingFaceEmbedding(
@@ -341,18 +343,15 @@ def split_name(full_name):
341
  initialize() # Run initialization tasks
342
 
343
  def handle_query(query):
344
- chat_text_qa_msgs = [
345
- (
346
- "user",
347
- """
348
- You are Clara, a Redfernstech chatbot. Provide accurate, concise answers (10-15 words) based on company data.
349
- Context: {context_str}
350
- Question: {query_str}
351
- Answer:
352
- """
353
- )
354
- ]
355
- text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
356
 
357
  storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
358
  index = load_index_from_storage(storage_context)
@@ -361,7 +360,7 @@ def handle_query(query):
361
  if past_query.strip():
362
  context_str += f"User asked: '{past_query}'\nBot answered: '{response}'\n"
363
 
364
- query_engine = index.as_query_engine(text_qa_template=text_qa_template, context_str=context_str)
365
  answer = query_engine.query(query)
366
 
367
  if hasattr(answer, "response"):
 
235
  from fastapi import FastAPI, Request
236
  from fastapi.responses import HTMLResponse
237
  from fastapi.staticfiles import StaticFiles
238
+ from llama_index.core import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader, PromptTemplate, Settings
239
  from llama_index.llms.huggingface import HuggingFaceLLM
240
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
241
  from pydantic import BaseModel
 
244
  from fastapi.middleware.cors import CORSMiddleware
245
  from fastapi.templating import Jinja2Templates
246
  from simple_salesforce import Salesforce, SalesforceLogin
247
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
248
 
249
  # Define Pydantic model for incoming request body
250
  class MessageRequest(BaseModel):
 
288
  templates = Jinja2Templates(directory="static")
289
 
290
  # Configure Llama index settings
291
+ tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
292
  Settings.llm = HuggingFaceLLM(
293
  model_name="google/flan-t5-small",
294
  tokenizer_name="google/flan-t5-small",
 
296
  max_new_tokens=256,
297
  generate_kwargs={"temperature": 0.1, "do_sample": True},
298
  model=AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small"),
299
+ tokenizer=tokenizer,
300
  device_map="auto" # Automatically use GPU if available, else CPU
301
  )
302
  Settings.embed_model = HuggingFaceEmbedding(
 
343
  initialize() # Run initialization tasks
344
 
345
  def handle_query(query):
346
+ # Custom prompt template for flan-t5-small (no chat template)
347
+ text_qa_template = PromptTemplate(
348
+ """
349
+ You are Clara, a Redfernstech chatbot. Provide accurate, concise answers (10-15 words) based on company data.
350
+ Context: {context_str}
351
+ Question: {query_str}
352
+ Answer:
353
+ """
354
+ )
 
 
 
355
 
356
  storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
357
  index = load_index_from_storage(storage_context)
 
360
  if past_query.strip():
361
  context_str += f"User asked: '{past_query}'\nBot answered: '{response}'\n"
362
 
363
+ query_engine = index.as_query_engine(text_qa_template=text_qa_template)
364
  answer = query_engine.query(query)
365
 
366
  if hasattr(answer, "response"):