|
from fastapi import FastAPI, HTTPException |
|
from pydantic import BaseModel |
|
from typing import List, Dict, Any |
|
from pymongo import MongoClient |
|
from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration |
|
import spacy |
|
import os |
|
import logging |
|
import re |
|
import torch |
|
import random |
|
|
|
|
|
# Process-wide logging: INFO and above, with timestamp, logger name and level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# ASGI application object; the route and lifecycle decorators below attach to it.
app = FastAPI()
|
|
|
|
|
# MongoDB connection. The URI is read from the MONGO_URI environment variable.
#
# NOTE(review): the fallback literal below embeds a username/password in source
# control. It is kept unchanged so existing deployments that rely on it keep
# working, but it should be moved to configuration and the credentials rotated.
# The user-info/host segment of the literal also does not look like a valid
# URI -- TODO confirm against the value actually used in deployment.
_DEFAULT_MONGO_URI = "mongodb+srv://clician:[email protected]/?retryWrites=true&w=majority&appName=Hutterdev"

if "MONGO_URI" not in os.environ:
    # Make the silent fallback visible in the logs.
    logger.warning(
        "MONGO_URI is not set; falling back to the connection string embedded in the source."
    )

connection_string = os.getenv("MONGO_URI", _DEFAULT_MONGO_URI)
client = MongoClient(connection_string)
db = client["test"]
products_collection = db["products"]
|
|
|
|
|
# Where the BlenderBot weights come from (hub repo + subfolder) and where the
# local copy is kept.
model_repo = "SyedHutter/blenderbot_model"
model_subfolder = "blenderbot_model"
model_dir = "/home/user/app/blenderbot_model"

# Prefer the GPU when torch can see one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

if not os.path.exists(model_dir):
    # First run: fetch tokenizer and weights, then keep a copy in model_dir so
    # later starts take the else-branch.
    logger.info(f"Downloading {model_repo}/{model_subfolder} to {model_dir}...")
    tokenizer = BlenderbotTokenizer.from_pretrained(model_repo, subfolder=model_subfolder)
    model = BlenderbotForConditionalGeneration.from_pretrained(model_repo, subfolder=model_subfolder)
    os.makedirs(model_dir, exist_ok=True)
    tokenizer.save_pretrained(model_dir)
    model.save_pretrained(model_dir)
    logger.info("Model download complete.")
else:
    logger.info(f"Loading pre-existing model from {model_dir}.")
    tokenizer = BlenderbotTokenizer.from_pretrained(model_dir)
    model = BlenderbotForConditionalGeneration.from_pretrained(model_dir)

# Applied after BOTH branches: request handling moves the tokenized inputs to
# `device`, so the model must live there too, and it is only used for inference.
model = model.to(device)
model.eval()
|
|
|
|
|
# Persona sentence placed at the front of the text sent to the model.
context_msg = (
    "I am Hutter, your shopping guide for Hutter Products GmbH, "
    "here to help you find sustainable products."
)

# spaCy pipeline loaded from a fixed local directory.
spacy_model_path = "/home/user/app/en_core_web_sm-3.8.0"
nlp = spacy.load(spacy_model_path)
|
|
|
|
|
class PromptRequest(BaseModel):
    """Request body accepted by the ``/process/`` endpoint."""

    # The user's current message.
    input_text: str

    # Earlier conversation turns; optional. The response formatter reads the
    # last element as the most recent turn.
    conversation_history: List[str] = []
|
|
|
class CombinedResponse(BaseModel):
    """Response body returned by the ``/process/`` endpoint."""

    # Keyword extraction output; the handler fills in "extracted_keywords".
    ner: Dict[str, Any]

    # Chat answer; the handler fills in "question", "answer" and "score".
    qa: Dict[str, Any]

    # Product records matched for the extracted keywords (may be empty).
    products_matched: List[Dict[str, Any]]
|
|
|
|
|
def extract_keywords(text: str) -> List[str]:
    """Return the distinct nouns and proper nouns found in *text*.

    The text is run through the module-level spaCy pipeline and every token
    tagged ``NOUN`` or ``PROPN`` is kept. Duplicates are dropped while keeping
    the order of first appearance, so the result is deterministic for a given
    input (a plain ``set`` round-trip would yield an arbitrary order).

    Args:
        text: Raw user input.

    Returns:
        Unique keyword strings, in the order they first occur in *text*.
    """
    doc = nlp(text)
    nouns = (token.text for token in doc if token.pos_ in ("NOUN", "PROPN"))
    # dict keys preserve insertion order, giving an order-stable de-duplication.
    return list(dict.fromkeys(nouns))
|
|
|
def detect_intent(text: str) -> str:
    """Classify *text* into one of the coarse intents used by the chat flow.

    Rules are evaluated in order and the first match wins:

    1. ``"recommend_product"`` - a shopping-related token is present, or
       "what" is one of the first two whitespace-separated words.
    2. ``"company_info"`` - one of the tokens "company", "who" or "do".
    3. ``"ask_name"`` - the text contains "name" or "yourself" (substring
       match), or both "you" and "about" occur as tokens.
    4. ``"math_query"`` - the text contains ``<digits> <op> <digits>`` with
       ``op`` one of ``+ - * /``.
    5. ``"chat"`` - anything else.

    Args:
        text: Raw user input; matching is case-insensitive.

    Returns:
        One of the five intent labels listed above.
    """
    text_lower = text.lower()
    # Work on token *texts*. Testing ``"you" in doc`` directly would compare
    # the string against spaCy Token objects rather than their text, so the
    # "you" + "about" rule would not match as intended.
    tokens = {token.text for token in nlp(text_lower)}

    shopping_words = {
        "buy", "shop", "find", "recommend", "product",
        "products", "item", "store", "catalog",
    }
    if tokens & shopping_words or "what" in text_lower.split()[:2]:
        return "recommend_product"

    if tokens & {"company", "who", "do"}:
        return "company_info"

    if "name" in text_lower or "yourself" in text_lower or {"you", "about"} <= tokens:
        return "ask_name"

    if re.search(r"\d+\s*[\+\-\*/]\s*\d+", text_lower):
        return "math_query"

    return "chat"
|
|
|
def search_products_by_keywords(keywords: List[str]) -> List[Dict[str, Any]]:
    """Find products whose name contains any of *keywords* as a whole word.

    Each keyword becomes a case-insensitive ``\\b<keyword>\\b`` regex on the
    ``name`` field, and the per-keyword conditions are combined with ``$or``.
    Keywords come from user text, so they are escaped before being placed in
    the pattern; otherwise input such as ``c++`` or ``(`` would produce an
    invalid or unintended regular expression.

    Args:
        keywords: Keyword strings to look for; may be empty.

    Returns:
        One dict per matching product with the keys ``id``, ``name``,
        ``skuNumber`` and ``description``. Empty when *keywords* is empty
        (the database is not queried in that case).
    """
    if not keywords:
        return []

    query = {
        "$or": [
            {"name": {"$regex": rf"\b{re.escape(keyword)}\b", "$options": "i"}}
            for keyword in keywords
        ]
    }
    return [
        {
            "id": str(p["_id"]),
            "name": p.get("name", "Unknown"),
            "skuNumber": p.get("skuNumber", "N/A"),
            "description": p.get("description", "No description available"),
        }
        for p in products_collection.find(query)
    ]
|
|
|
def get_product_context(products: List[Dict]) -> str:
    """Summarise up to the first two products as a short context string.

    Args:
        products: Product dicts that each provide ``name`` and ``description``.

    Returns:
        ``"Products: 'name' - description, ..."`` for the first two entries,
        or an empty string when *products* is empty.
    """
    if products:
        snippets = []
        for item in products[:2]:
            snippets.append(f"'{item['name']}' - {item['description']}")
        return "Products: " + ", ".join(snippets)
    return ""
|
|
|
def format_response(response: str, products: List[Dict], intent: str, input_text: str, history: List[str]) -> str:
    """Extend the raw model reply with intent-specific follow-up text.

    Args:
        response: Text produced by the model; a stock greeting replaces it
            when empty.
        products: Matched product dicts (``name`` / ``description``); only the
            first one is ever mentioned.
        intent: Intent label; unrecognised labels return the reply unchanged.
        input_text: The user's message, used for arithmetic and "yes" checks.
        history: Previous conversation turns; only the last one is inspected.

    Returns:
        The final answer string.
    """
    reply = response or "I’m here to help—what’s on your mind?"

    if intent == "recommend_product":
        if not products:
            # Nothing to recommend: pick one of two open questions at random.
            follow_ups = [
                f"{reply} What sustainable items are you looking for today?",
                f"{reply} Any specific eco-friendly products you’re curious about?",
            ]
            return random.choice(follow_ups)
        name = products[0]['name']
        blurb = products[0]['description'].lower()
        return f"{reply} Speaking of sustainable products, check out our '{name}'—it’s {blurb}."

    if intent == "company_info":
        return f"{reply} I’m with Hutter Products GmbH—we focus on sustainable items like recycled textiles and ocean plastic goods."

    if intent == "ask_name":
        return f"{reply} I’m Hutter, your shopping guide for Hutter Products GmbH, here to assist with sustainable products."

    if intent == "math_query":
        found = re.search(r"(\d+)\s*([\+\-\*/])\s*(\d+)", input_text.lower())
        if not found:
            return f"{reply} I can help with math—try something like '2 + 2'."
        left, symbol, right = int(found.group(1)), found.group(2), int(found.group(3))
        # The pattern only admits + - * /, so the final branches cover "/".
        if symbol == "+":
            lead_in, result = "By the way,", left + right
        elif symbol == "-":
            lead_in, result = "Also,", left - right
        elif symbol == "*":
            lead_in, result = "Oh, and", left * right
        elif right != 0:
            lead_in, result = "Plus,", left / right
        else:
            return f"{reply} Can’t divide by zero, though!"
        return f"{reply} {lead_in} {left} {symbol} {right} = {result}."

    if intent != "chat":
        return reply

    # Plain chat: only react when the user says "yes" right after a turn that
    # mentioned one of the cue words.
    if "yes" not in input_text.lower() or not history:
        return reply
    last_turn = history[-1].lower()
    if not any(cue in last_turn for cue in ("hat", "product", "store")):
        return reply
    if not products:
        return f"{reply} Want me to suggest some sustainable items?"
    name = products[0]['name']
    blurb = products[0]['description'].lower()
    return f"{reply} Great! How about our '{name}'? It’s {blurb}."
|
|
|
|
|
@app.get("/")
async def root():
    """Return a static welcome payload for the API root."""
    welcome = {"message": "Welcome to the NER and Chat API!"}
    return welcome
|
|
|
def _build_model_input(input_text: str, product_context: str) -> str:
    """Join persona message, optional product context and user text with ' || '."""
    if product_context:
        return f"{context_msg} || {product_context} || {input_text}"
    return f"{context_msg} || {input_text}"


def _generate_reply(full_input: str) -> str:
    """Tokenize *full_input*, sample a reply from the model and decode it."""
    logger.info("Tokenizing input...")
    # NOTE(review): truncation keeps the first 64 tokens, and the user's text
    # sits at the END of full_input, so a long product context can cut into
    # it -- confirm this is acceptable.
    inputs = tokenizer(full_input, return_tensors="pt", truncation=True, max_length=64).to(device)
    logger.info("Input tokenized successfully.")

    logger.info("Generating model response...")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=30,
            do_sample=True,
            top_p=0.95,
            temperature=0.8,
            no_repeat_ngram_size=2,
        )
    logger.info("Model generation complete.")

    logger.info("Decoding model output...")
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


@app.post("/process/", response_model=CombinedResponse)
async def process_prompt(request: PromptRequest):
    """Run intent detection, keyword extraction, product search and chat generation.

    Args:
        request: The user's message plus optional conversation history. The
            history is passed to ``format_response`` only; it is not part of
            the text sent to the model.

    Returns:
        A dict with ``ner`` (extracted keywords), ``qa`` (question, answer,
        score) and ``products_matched``, matching ``CombinedResponse``.

    Raises:
        HTTPException: With status 500 when any step fails.
    """
    # NOTE(review): spaCy, PyMongo and model.generate are blocking calls made
    # inside an ``async def`` handler, so they occupy the event loop while
    # they run. Consider a plain ``def`` handler or a thread pool.
    try:
        input_text = request.input_text
        logger.info("Processing request: %s", input_text)

        intent = detect_intent(input_text)
        keywords = extract_keywords(input_text)
        logger.info("Intent: %s, Keywords: %s", intent, keywords)

        products = search_products_by_keywords(keywords)
        product_context = get_product_context(products)
        logger.info("Products matched: %s", len(products))

        full_input = _build_model_input(input_text, product_context)
        logger.info("Full input to model: %s", full_input)

        response = _generate_reply(full_input)
        logger.info("Model response: %s", response)

        enhanced_response = format_response(
            response, products, intent, input_text, request.conversation_history
        )
        qa_response = {
            "question": input_text,
            "answer": enhanced_response,
            "score": 1.0,
        }

        logger.info("Returning response...")
        return {
            "ner": {"extracted_keywords": keywords},
            "qa": qa_response,
            "products_matched": products,
        }
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error processing request: %s", e)
        # NOTE(review): the raw exception text is returned to the client; it is
        # kept for compatibility but may expose internal details.
        raise HTTPException(status_code=500, detail=f"Oops, something went wrong: {str(e)}") from e
|
|
|
@app.on_event("startup")
async def startup_event():
    """Log a fixed banner when the application starts; performs no checks."""
    banner = "API is running with BlenderBot-400M-distill, connected to MongoDB."
    logger.info(banner)
|
|
|
@app.on_event("shutdown")
def shutdown_event():
    """Close the module-level MongoDB client when the application shuts down."""
    client.close()