File size: 3,559 Bytes
0888a3b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
from transformers import pipeline
import torch
# Application object that the route decorators in this module attach to.
app = FastAPI()

# Device index shared by both pipelines: 0 when CUDA is available (GPU),
# otherwise -1 (CPU).
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# Named-entity-recognition pipeline. The request handler reads the "word",
# "entity_group" and "score" keys from each result it returns.
ner_pipeline = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    aggregation_strategy="simple",
    device=device,
)

# Extractive question-answering pipeline, queried against `context_msg`.
qa_pipeline = pipeline(
    "question-answering",
    model="deepset/roberta-base-squad2",
    device=device,
)

# Keywords/phrases that mark a prompt as in scope; a prompt has to contain
# at least one of them to be processed.
allowed_domains = [
    "clothing",
    "fashion",
    "shopping",
    "accessories",
    "pants",
    "jeans",
    "shirts",
    "sustainable materials",
]

# Fixed passage the QA pipeline extracts its answers from.
context_msg = """
We offer a wide variety of clothing options, including sustainable pants, jeans, chinos, and trousers.
Our products are made with eco-friendly materials and are available in styles such as casual wear, formal wear, and activewear.
"""
# Pydantic models for structured responses
class Entity(BaseModel):
    # One named entity from the NER output. Documented with comments rather
    # than a docstring so the generated JSON schema stays as it was.

    # Text of the recognized entity.
    word: str
    # Label the NER pipeline assigned to the entity.
    entity_group: str
    # Score reported by the pipeline, stored as a plain Python float.
    score: float
class NERResponse(BaseModel):
    # NER half of the combined response. Documented with comments rather
    # than a docstring so the generated JSON schema stays as it was.

    # Structured record for every recognized entity.
    entities: List[Entity]
    # Just the entity texts, in the same order as `entities`.
    words: List[str]
class QAResponse(BaseModel):
    # QA half of the combined response. Documented with comments rather
    # than a docstring so the generated JSON schema stays as it was.

    # The question that was put to the QA pipeline.
    question: str
    # Answer text returned by the QA pipeline.
    answer: str
    # Score the pipeline reported for that answer.
    score: float
class CombinedRequest(BaseModel):
    # Request body of the combined NER + QA endpoint. Documented with
    # comments rather than a docstring so the JSON schema stays as it was.

    # The input text prompt.
    text: str
class CombinedResponse(BaseModel):
    # Response body of the combined endpoint. Documented with comments
    # rather than a docstring so the JSON schema stays as it was.

    # NER output for the prompt.
    ner: NERResponse
    # QA output for the prompt.
    qa: QAResponse
# Function to check if the input text belongs to allowed domains
def is_text_in_allowed_domain(text: str, domains: List[str]) -> bool:
    """Return True if any domain keyword occurs in *text*, ignoring case.

    Matching is plain substring containment, so a keyword also matches
    when it appears inside a longer word (e.g. "pants" in "participants"),
    and an empty-string keyword matches every text.

    Args:
        text: The prompt to inspect.
        domains: Keywords or phrases that mark a prompt as in scope.

    Returns:
        True when at least one keyword is contained in the text, False
        otherwise (always False for an empty ``domains``).
    """
    # Lowercase the text once instead of once per keyword.
    lowered_text = text.lower()
    # Lowercase the keyword as well: compared as-is, a keyword containing
    # an uppercase letter could never match the lowercased text.
    return any(domain.lower() in lowered_text for domain in domains)
# Combined endpoint for NER and QA with domain filtering
@app.post("/process/", response_model=CombinedResponse)
async def process_request(request: CombinedRequest):
    # Runs NER and QA over the prompt in `request.text` and returns both
    # results in the shape of CombinedResponse. A prompt containing none of
    # the `allowed_domains` keywords is rejected with HTTP 400.
    # NOTE: described with comments rather than a docstring because FastAPI
    # publishes a handler's docstring as the operation description.
    prompt = request.text

    # Guard clause: refuse out-of-scope prompts before running any model.
    in_scope = is_text_in_allowed_domain(prompt, allowed_domains)
    if not in_scope:
        raise HTTPException(
            status_code=400,
            detail="The input text does not match the allowed domains. Please provide a query related to clothing, fashion, or accessories.",
        )

    # Named Entity Recognition: a single pass over the hits fills both the
    # structured entity list and the flat list of matched words.
    detections = ner_pipeline(prompt)
    entities = []
    words = []
    for hit in detections:
        entities.append(
            {
                "word": hit["word"],
                "entity_group": hit["entity_group"],
                "score": float(hit["score"]),
            }
        )
        words.append(hit["word"])

    # Question Answering: the prompt doubles as the question and is answered
    # against the fixed passage in `context_msg`.
    prediction = qa_pipeline(question=prompt, context=context_msg)
    confidence = float(prediction["score"])  # numpy.float32 -> Python float
    answer_text = prediction["answer"]

    # Both halves go back together.
    return {
        "ner": {"entities": entities, "words": words},
        "qa": {"question": prompt, "answer": answer_text, "score": confidence},
    }
# Root endpoint
@app.get("/")
async def root():
    # Landing route: answers with a fixed welcome message.
    # (Comment instead of docstring so the OpenAPI description is unchanged.)
    greeting = "Welcome to the NER and QA API!"
    return {"message": greeting}
|