# data_project / app.py
# Author: kartik91 (Hugging Face Space, commit 5cefe7f)
import faiss
import numpy as np
import torch
from PIL import Image
from torchvision import models, transforms
from transformers import pipeline, AutoTokenizer, AutoModel
class TextClassifier:
    """Thin wrapper around a Hugging Face text-classification pipeline."""

    def __init__(self, model_name='distilbert-base-uncased'):
        # The pipeline is built once here and reused for every classify() call.
        self.classifier = pipeline("text-classification", model=model_name)

    def classify(self, text):
        """Return the top-1 predicted label for *text* as a string."""
        predictions = self.classifier(text)
        top = predictions[0]
        return top['label']
class SentimentAnalyzer:
    """Sentiment scoring via a Hugging Face sentiment-analysis pipeline."""

    def __init__(self, model_name='nlptown/bert-base-multilingual-uncased-sentiment'):
        # Loaded once at construction; analyze() reuses the same pipeline.
        self.analyzer = pipeline("sentiment-analysis", model=model_name)

    def analyze(self, text):
        """Return the top prediction dict (keys: 'label', 'score') for *text*."""
        results = self.analyzer(text)
        return results[0]
class ImageRecognizer:
    """Classify an image with a torchvision ImageNet model.

    Fixes over the original:
    - ``model_name`` was accepted but silently ignored (resnet50 was
      hard-coded); it now selects the torchvision constructor by name,
      with the default preserving the old behavior exactly.
    - Input images are converted to RGB so grayscale/palette/RGBA files
      do not break the 3-channel Normalize transform.
    - The image file handle is closed via a context manager.
    """

    def __init__(self, model_name='resnet50'):
        # Look up the constructor on torchvision.models by name; default
        # 'resnet50' keeps behavior identical to the original.
        self.model = getattr(models, model_name)(pretrained=True)
        self.model.eval()
        # Standard ImageNet preprocessing for torchvision classifiers.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def recognize(self, image_path):
        """Return the predicted ImageNet class index (int) for *image_path*."""
        # Context manager releases the file handle; convert('RGB') handles
        # non-RGB inputs the transform pipeline would otherwise reject.
        with Image.open(image_path) as img:
            batch = self.transform(img.convert('RGB')).unsqueeze(0)
        with torch.no_grad():  # inference only — skip autograd bookkeeping
            outputs = self.model(batch)
        _, predicted = torch.max(outputs, 1)
        return predicted.item()
class TextGenerator:
    """Free-form text generation via a causal-LM pipeline (GPT-2 by default)."""

    def __init__(self, model_name='gpt2'):
        self.generator = pipeline("text-generation", model=model_name)

    def generate(self, prompt):
        """Return one generated continuation of *prompt* (up to 100 tokens)."""
        outputs = self.generator(prompt, max_length=100, num_return_sequences=1)
        first = outputs[0]
        return first['generated_text']
class FAQRetriever:
    """Embed FAQ strings with a sentence-transformer and retrieve nearest
    matches with a FAISS flat L2 index.

    Fixes over the original:
    - ``retrieve`` takes a backward-compatible ``top_k`` parameter
      (default 5, the old hard-coded value) and clamps it to the number
      of indexed vectors, so FAISS never returns its ``-1`` "missing"
      sentinel for under-filled indexes (which a caller doing
      ``faqs[i]`` would silently misread as the last FAQ).
    """

    def __init__(self, model_name='sentence-transformers/all-MiniLM-L6-v2'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # 384 = hidden size of all-MiniLM-L6-v2. With L2-normalized vectors,
        # L2 distance is monotonic in cosine similarity, so IndexFlatL2
        # effectively ranks by cosine.
        self.index = faiss.IndexFlatL2(384)

    def embed(self, text):
        """Return a (1, 384) numpy embedding: mean-pooled last hidden state."""
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True)
        with torch.no_grad():  # inference only — no gradients needed
            embeddings = self.model(**inputs).last_hidden_state.mean(dim=1)
        return embeddings.cpu().numpy()

    def add_faqs(self, faqs):
        """Embed and index *faqs* (an iterable of strings)."""
        self.faq_embeddings = np.concatenate([self.embed(faq) for faq in faqs])
        # Normalize in place so L2 search behaves like cosine search.
        faiss.normalize_L2(self.faq_embeddings)
        self.index.add(self.faq_embeddings)

    def retrieve(self, query, top_k=5):
        """Return indices of the *top_k* FAQs most similar to *query*."""
        query_embedding = self.embed(query)
        faiss.normalize_L2(query_embedding)
        # Never request more neighbors than the index actually holds.
        k = min(top_k, self.index.ntotal)
        _, indices = self.index.search(query_embedding, k)
        return indices[0]
class CustomerSupportAssistant:
    """Facade wiring the NLP/vision components into one query pipeline."""

    def __init__(self):
        self.text_classifier = TextClassifier()
        self.sentiment_analyzer = SentimentAnalyzer()
        self.image_recognizer = ImageRecognizer()
        self.text_generator = TextGenerator()
        self.faq_retriever = FAQRetriever()
        # Static FAQ knowledge base, embedded and indexed once at startup.
        self.faqs = [
            "How to reset my password?",
            "What is the return policy?",
            "How to track my order?",
            "How to contact customer support?",
            "What payment methods are accepted?"
        ]
        self.faq_retriever.add_faqs(self.faqs)

    def process_query(self, text, image_path=None):
        """Classify, score sentiment, retrieve FAQs, optionally recognize an
        image, then generate and return a text response."""
        topic = self.text_classifier.classify(text)
        sentiment = self.sentiment_analyzer.analyze(text)
        # Image recognition is optional; a placeholder string is used otherwise.
        if image_path:
            image_info = self.image_recognizer.recognize(image_path)
        else:
            image_info = "No image provided."
        matched_indices = self.faq_retriever.retrieve(text)
        faq_responses = [self.faqs[i] for i in matched_indices]
        response_prompt = f"Topic: {topic}, Sentiment: {sentiment['label']} with confidence {sentiment['score']}. FAQs: {faq_responses}. Image info: {image_info}. Generate a response."
        return self.text_generator.generate(response_prompt)
# Example usage — guarded so that importing this module does not trigger
# model downloads and inference as a side effect.
if __name__ == "__main__":
    assistant = CustomerSupportAssistant()
    input_text = "I'm having trouble with my recent order."
    output = assistant.process_query(input_text)
    print(output)