"""FastAPI service exposing combined NER and question-answering over a prompt.

A single POST endpoint takes a free-text prompt, rejects it unless it mentions
one of the allowed domain keywords, and otherwise returns both the named
entities found in the prompt and an answer extracted from a fixed context.
"""

from typing import List

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline

# Initialize the FastAPI app
app = FastAPI()

# Determine device (use GPU if available, otherwise CPU)
device = 0 if torch.cuda.is_available() else -1

# Initialize the NER pipeline
ner_pipeline = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    aggregation_strategy="simple",
    device=device,
)

# Initialize the QA pipeline
qa_pipeline = pipeline(
    "question-answering",
    model="deepset/roberta-base-squad2",
    device=device,
)

# Allowed domains for filtering (matched as lowercase substrings of the input)
allowed_domains = [
    "clothing",
    "fashion",
    "shopping",
    "accessories",
    "pants",
    "jeans",
    "shirts",
    "sustainable materials",
]

# Context for the QA pipeline. The surrounding whitespace (leading space,
# trailing space + newline) is kept exactly as in the original literal.
context_msg = (
    " We offer a wide variety of clothing options, including sustainable pants, jeans, chinos, and trousers. "
    "Our products are made with eco-friendly materials and are available in styles such as casual wear, formal wear, and activewear. \n"
)


# Pydantic models for structured responses
class Entity(BaseModel):
    """One entity reported by the NER pipeline."""

    word: str
    entity_group: str
    score: float


class NERResponse(BaseModel):
    """NER part of the response: full entity records plus just their words."""

    entities: List[Entity]
    words: List[str]  # List of extracted words (added)


class QAResponse(BaseModel):
    """QA part of the response: the question asked, the answer and its score."""

    question: str
    answer: str
    score: float


class CombinedRequest(BaseModel):
    """Request body for the combined endpoint."""

    text: str  # The input text prompt


class CombinedResponse(BaseModel):
    """Response body bundling the NER and QA results."""

    ner: NERResponse  # NER output
    qa: QAResponse  # QA output


def is_text_in_allowed_domain(text: str, domains: List[str]) -> bool:
    """Return True if any domain keyword occurs in the lowercased text.

    Matching is a plain substring test, so a keyword also matches inside a
    longer word (e.g. "pants" matches "participants"). Keywords are compared
    as given, so they are expected to be lowercase already.

    Args:
        text: The user-supplied prompt.
        domains: Keywords/phrases to look for.

    Returns:
        True when at least one keyword is a substring of ``text.lower()``,
        False otherwise (including when ``domains`` is empty).
    """
    # Lowercase once instead of once per keyword.
    lowered = text.lower()
    return any(domain in lowered for domain in domains)


# Combined endpoint for NER and QA with domain filtering
@app.post("/process/", response_model=CombinedResponse)
async def process_request(request: CombinedRequest):
    """Run NER and QA on the prompt after checking it is on-topic.

    Args:
        request: Body carrying the prompt in ``text``.

    Returns:
        A dict with ``"ner"`` (entities and their words) and ``"qa"``
        (question, answer, score), shaped like ``CombinedResponse``.

    Raises:
        HTTPException: 400 when the prompt contains none of the keywords in
            ``allowed_domains``.
    """
    input_text = request.text

    # Check if the input text belongs to the allowed domains
    if not is_text_in_allowed_domain(input_text, allowed_domains):
        raise HTTPException(
            status_code=400,
            detail=(
                "The input text does not match the allowed domains. "
                "Please provide a query related to clothing, fashion, or accessories."
            ),
        )

    # NOTE(review): both pipeline calls below are synchronous and are not
    # awaited, so they run inline in this coroutine. Confirm that is acceptable
    # for the expected request concurrency.

    # Perform Named Entity Recognition (NER)
    ner_entities = ner_pipeline(input_text)

    # Process NER results into a structured response; float() normalizes the
    # score to a built-in float for serialization.
    formatted_entities = [
        {
            "word": entity["word"],
            "entity_group": entity["entity_group"],
            "score": float(entity["score"]),
        }
        for entity in ner_entities
    ]
    ner_words = [entity["word"] for entity in ner_entities]  # Collect only the words

    # Perform Question Answering (QA): the prompt itself is the question.
    qa_result = qa_pipeline(question=input_text, context=context_msg)

    # Build the QA payload without mutating the pipeline's own result dict.
    qa_response = {
        "question": input_text,
        "answer": qa_result["answer"],
        "score": float(qa_result["score"]),  # Convert numpy.float32 to Python float
    }

    # Return both NER and QA responses
    return {
        "ner": {
            "entities": formatted_entities,
            "words": ner_words,  # Include the list of words
        },
        "qa": qa_response,
    }


# Root endpoint
@app.get("/")
async def root():
    """Return a static welcome message."""
    return {"message": "Welcome to the NER and QA API!"}