File size: 4,343 Bytes
cbcd78b
1ea874c
be195b7
 
4530b74
be195b7
 
 
 
cbcd78b
be195b7
 
 
1ea874c
be195b7
 
 
 
 
 
1ea874c
be195b7
 
80cfb3b
3fac692
be195b7
cbcd78b
0a5100e
80cfb3b
be195b7
9d03f28
3fac692
be195b7
3fac692
 
be195b7
3fac692
cbcd78b
 
 
be195b7
 
 
 
9d03f28
511d4e8
 
 
 
be195b7
 
 
 
 
 
511d4e8
b9c37e3
 
511d4e8
 
80cfb3b
be195b7
80cfb3b
511d4e8
80cfb3b
511d4e8
9d32e7a
 
 
511d4e8
 
 
 
 
 
 
 
 
 
80cfb3b
be195b7
 
 
1ea874c
 
be195b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
511d4e8
be195b7
 
511d4e8
be195b7
 
 
cbcd78b
511d4e8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
from transformers import pipeline
import pandas as pd
import spaces

# Authenticate with the Hugging Face Hub before any Hub download below
# (dataset and model weights), so the token applies to all of them.
from huggingface_hub import login
import os

# Login using the API key stored as an environment variable. Skip it when the
# variable is unset: login(token=None) falls back to an interactive prompt.
hf_api_key = os.getenv("API_KEY")
if hf_api_key:
    login(token=hf_api_key)

# Load dataset (the bundled template data the UI starts with)
from datasets import load_dataset
ds = load_dataset('ZennyKenny/demo_customer_nps')
df = pd.DataFrame(ds['train'])

# Initialize model pipelines: binary sentiment + free-text category generation
classifier = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
generator = pipeline("text2text-generation", model="google/flan-t5-base")

# Function to classify customer comments
@spaces.GPU
def classify_comments(category_boxes):
    """Label every comment in the global ``df`` with a sentiment and a category.

    Args:
        category_boxes: Category names from the UI State. Blank entries are
            ignored and surrounding whitespace is stripped.

    Returns:
        An HTML table with the columns ``customer_comment``,
        ``comment_sentiment`` and ``comment_category``, or a short HTML
        message when no usable categories were supplied.

    Side effects:
        Adds/overwrites the ``comment_sentiment`` and ``comment_category``
        columns on the global ``df``.
    """
    # The category list is identical for every comment, so build it once.
    category_list = [box.strip() for box in (category_boxes or []) if box.strip() != '']
    if not category_list:
        # Without categories the generator prompt would be meaningless.
        return "<p>Please add at least one category before classifying comments.</p>"
    category_str = ', '.join(category_list)

    sentiments = []
    categories = []
    # Uploaded CSVs may contain blank cells (NaN), which the pipelines cannot
    # tokenize; treat them as empty strings.
    for comment in df['customer_comment'].fillna('').astype(str):
        sentiments.append(classifier(comment)[0]['label'])
        prompt = f"What category best describes this comment? '{comment}' Please answer using only the name of the category: {category_str}."
        categories.append(generator(prompt, max_length=30)[0]['generated_text'])
    df['comment_sentiment'] = sentiments
    df['comment_category'] = categories
    return df[['customer_comment', 'comment_sentiment', 'comment_category']].to_html(index=False)

# Gradio Interface
with gr.Blocks() as nps:
    # NOTE(review): `df` is a module-level global, so an uploaded CSV (or a
    # template reset) is shared by every concurrent session of this app.

    # Pristine copy of the bundled template dataset, taken before any upload
    # can replace the global `df` or a classification run can add columns to it.
    template_df = df.copy()

    def add_category(category_list, new_category):
        """Return the category list with ``new_category`` appended.

        Args:
            category_list: Current list of category names (from ``gr.State``).
            new_category: Raw text from the "New Category" textbox.

        Returns:
            A new list. Blank or duplicate input leaves the categories
            unchanged. A fresh list is returned rather than mutating the State
            value in place, so the stored value is never aliased.
        """
        new_category = (new_category or "").strip()
        if not new_category or new_category in category_list:
            return category_list
        return [*category_list, new_category]

    def remove_category(category, category_list):
        """Return a copy of ``category_list`` without the first ``category``.

        An empty or unknown selection (e.g. nothing chosen in the dropdown) is
        ignored instead of raising ``ValueError`` as ``list.remove`` would.
        """
        if category not in category_list:
            return category_list
        remaining = list(category_list)
        remaining.remove(category)
        return remaining

    def display_categories(categories):
        """Render the current categories as a Markdown bullet list."""
        if not categories:
            return "_No categories added yet._"
        return "\n".join(f"- {cat}" for cat in categories)

    def _refresh_category_views(categories):
        """Sync the Markdown list and the removal dropdown with the State."""
        return display_categories(categories), gr.update(choices=list(categories), value=None)

    def use_template():
        """Return the built-in template categories."""
        return ["Product Experience", "Customer Support", "Price of Service", "Other"]

    def use_template_dataset():
        """Point the global ``df`` back at a fresh copy of the template data."""
        global df
        df = template_df.copy()
        return "Using Template Dataset"

    def load_data(file):
        """Replace the global ``df`` with an uploaded CSV.

        Args:
            file: Value from ``gr.File(type="filepath")``. This is a path
                string, or ``None`` when the upload is cleared. Objects with a
                ``.name`` attribute (older Gradio file wrappers) are also
                accepted.

        Returns:
            A status message for the HTML output.
        """
        global df
        if file is None:
            return "No file uploaded."
        path = str(getattr(file, "name", file))
        if not path.lower().endswith(".csv"):
            return "Error: Uploaded file is not a CSV."
        try:
            custom_df = pd.read_csv(path, encoding="utf-8")
        except (UnicodeDecodeError, pd.errors.ParserError, pd.errors.EmptyDataError):
            return "Error: Could not read the uploaded file as a UTF-8 CSV."
        if 'customer_comment' not in custom_df.columns:
            return "Error: Uploaded CSV must contain a column named 'customer_comment'"
        df = custom_df
        return "Custom CSV loaded successfully!"

    gr.Markdown("# NPS Comment Categorization")

    category_boxes = gr.State([])  # list of category names entered so far

    with gr.Row():
        category_input = gr.Textbox(label="New Category", placeholder="Enter category name")
        add_category_btn = gr.Button("Add Category")

    # Components are created once here and only updated by event handlers.
    # Gradio cannot add components or register events from inside a handler.
    with gr.Column() as category_column:
        category_display = gr.Markdown(display_categories([]))
        with gr.Row():
            remove_choice = gr.Dropdown(choices=[], label="Remove Category", interactive=True)
            remove_category_btn = gr.Button("Remove Selected")

    uploaded_file = gr.File(label="Upload CSV", type="filepath")
    template_btn = gr.Button("Use Template")
    classify_btn = gr.Button("Classify Comments")
    output = gr.HTML()

    category_views = [category_display, remove_choice]

    # Every State update is followed by an explicit refresh of the category
    # views, so the UI does not depend on State change detection.
    add_category_btn.click(
        fn=add_category,
        inputs=[category_boxes, category_input],
        outputs=category_boxes,
    ).then(fn=_refresh_category_views, inputs=category_boxes, outputs=category_views)

    remove_category_btn.click(
        fn=remove_category,
        inputs=[remove_choice, category_boxes],
        outputs=category_boxes,
    ).then(fn=_refresh_category_views, inputs=category_boxes, outputs=category_views)

    template_btn.click(fn=use_template, outputs=category_boxes).then(
        fn=_refresh_category_views, inputs=category_boxes, outputs=category_views
    )
    template_btn.click(fn=use_template_dataset, outputs=output)

    uploaded_file.change(fn=load_data, inputs=uploaded_file, outputs=output)
    classify_btn.click(fn=classify_comments, inputs=category_boxes, outputs=output)

# Launch only when run as a script, so importing this module (tests, reload
# tooling) builds the UI without starting a server.
if __name__ == "__main__":
    # NOTE(review): share=True creates a public tunnel for local runs; hosted
    # Spaces are already public, so confirm it is still wanted.
    nps.launch(share=True)