File size: 7,403 Bytes
d1021bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# -*- coding: utf-8 -*-
"""Now that we've built a powerful LLM-based classifier, let's showcase it to the world by creating an interactive demo. In this chapter, we'll learn how to:
- Create a user-friendly web interface using Gradio
- Package our demo for deployment
- Deploy it on Hugging Face Spaces for free
- Use the Hugging Face Inference API for model access
"""

import json
import time
import os
import sys
from retry import retry
from rich.progress import track
from huggingface_hub import InferenceClient
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
import pandas as pd
import gradio as gr

# Hugging Face Inference API client. The token is read from the HF_TOKEN
# environment variable; os.getenv returns None when it is unset.
# NOTE(review): what InferenceClient does with token=None (anonymous access
# vs. locally cached credentials) is decided by huggingface_hub -- confirm
# for the deployment target.
api_key = os.getenv("HF_TOKEN")
client = InferenceClient(token=api_key)

# Labelled sample data; the code below reads its "payee" and "category"
# columns. The relative path is resolved against the current working directory.
sample_df = pd.read_csv("sample.csv")

def get_batch_list(li, n=10):
    """Return consecutive slices of ``li``, each holding at most ``n`` items.

    The last slice is shorter when ``len(li)`` is not a multiple of ``n``;
    an empty input produces an empty list.
    """
    starts = range(0, len(li), n)
    return [li[start:start + n] for start in starts]

# Hold out a third of the labelled sample (random_state fixed so the split is
# reproducible). The training portion is turned into few-shot examples; the
# test portion is classified by the LLM later in this script.
# NOTE(review): test_output is not used in the code shown here -- presumably
# kept for scoring the LLM's predictions; confirm or drop it.
training_input, test_input, training_output, test_output = train_test_split(
    sample_df[['payee']],
    sample_df['category'],
    test_size=0.33,
    random_state=42
)

# Function to create few-shot examples
def get_fewshots(training_input, training_output, batch_size=10):
    """Turn a labelled training split into alternating user/assistant messages.

    Args:
        training_input: DataFrame with a ``payee`` column (the X side of
            sklearn's ``train_test_split``).
        training_output: Iterable of category labels aligned with the payees.
        batch_size: Number of payees per example exchange.

    Returns:
        A flat list of chat-message dicts. Each batch contributes a ``user``
        message whose content is the payees joined by newlines, followed by an
        ``assistant`` message holding the JSON-encoded list of matching labels.
        If the two sides produce a different number of batches, the extra
        batches on the longer side are ignored.
    """
    payee_chunks = get_batch_list(list(training_input.payee), n=batch_size)
    label_chunks = get_batch_list(list(training_output), n=batch_size)

    messages = []
    # zip stops at the shorter side, so only complete input/output pairs are used.
    for payees, labels in zip(payee_chunks, label_chunks):
        # Same newline-separated layout as a real classification request.
        messages.append({"role": "user", "content": "\n".join(payees)})
        # The answer we want the model to imitate: a JSON list of categories.
        messages.append({"role": "assistant", "content": json.dumps(labels)})
    return messages

# Built once at import time; classify_payees splices these messages between
# its system prompt and the user message carrying each request's names.
fewshot_list = get_fewshots(training_input, training_output)

def _strip_code_fence(text):
    """Remove a surrounding Markdown code fence (e.g. ```json ... ```), if any.

    Text without a leading fence is only stripped of outer whitespace, which
    json.loads would ignore anyway.
    """
    text = text.strip()
    if not text.startswith("```"):
        return text
    # Drop the opening fence together with any language tag on its line.
    _, newline, rest = text.partition("\n")
    text = rest if newline else text[3:]
    text = text.rstrip()
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()


@retry(ValueError, tries=2, delay=2)
def classify_payees(name_list):
    """Classify business names as Restaurant, Bar, Hotel or Other via the LLM.

    The system prompt and the module-level ``fewshot_list`` examples precede
    the names, which are sent one per line as the final user message.

    Args:
        name_list: List of business-name strings.

    Returns:
        dict mapping each name to its category. Repeated names collapse into a
        single key (the later answer wins). An empty ``name_list`` returns
        ``{}`` without calling the API.

    Raises:
        ValueError: if the API call fails, the reply is not a JSON list, a
            category is outside the allowed set, or the number of categories
            differs from the number of names. ``@retry`` makes one more attempt
            (tries=2, 2 s delay) before the error propagates.
    """
    if not name_list:
        # Nothing to classify -- skip the API round-trip.
        return {}

    prompt = """You are an AI model trained to categorize businesses based on their names.

You will be given a list of business names, each separated by a new line.

Your task is to analyze each name and classify it into one of the following categories: Restaurant, Bar, Hotel, or Other.

It is extremely critical that there is a corresponding category output for each business name provided as an input.

If a business does not clearly fall into Restaurant, Bar, or Hotel categories, you should classify it as "Other".

Even if the type of business is not immediately clear from the name, it is essential that you provide your best guess based on the information available to you. If you can't make a good guess, classify it as Other.

For example, if given the following input:

"Intercontinental Hotel\nPizza Hut\nCheers\nWelsh's Family Restaurant\nKTLA\nDirect Mailing"

Your output should be a JSON list in the following format:

["Hotel", "Restaurant", "Bar", "Restaurant", "Other", "Other"]

This means that you have classified "Intercontinental Hotel" as a Hotel, "Pizza Hut" as a Restaurant, "Cheers" as a Bar, "Welsh's Family Restaurant" as a Restaurant, and both "KTLA" and "Direct Mailing" as Other.

If a business name contains both the word "Restaurant" and the word "Bar", you should classify it as a Restaurant.

Ensure that the number of classifications in your output matches the number of business names in the input. It is very important that the length of the JSON list you return is exactly the same as the number of business names you receive.
"""
    # A tuple (not a set) so an unhashable answer is simply "not in" it.
    acceptable_answers = ("Restaurant", "Bar", "Hotel", "Other")

    try:
        response = client.chat.completions.create(
            messages=[
                # System role message that explains the classification task
                {"role": "system", "content": prompt},
                *fewshot_list,
                {"role": "user", "content": "\n".join(name_list)},
            ],
            model="meta-llama/Llama-3.3-70B-Instruct",
            temperature=0,
        )

        answer_str = response.choices[0].message.content
        # Defensive: tolerate a reply wrapped in a Markdown code fence.
        answer_list = json.loads(_strip_code_fence(answer_str))

        if not isinstance(answer_list, list):
            raise ValueError(
                f"Expected a JSON list, got {type(answer_list).__name__}"
            )

        for answer in answer_list:
            if answer not in acceptable_answers:
                raise ValueError(f"{answer} not in list of acceptable answers")

        if len(name_list) != len(answer_list):
            raise ValueError(f"Number of inputs ({len(name_list)}) does not equal the number of outputs ({len(answer_list)})")

    except Exception as e:
        # Re-raise as ValueError so @retry tries again; chain the cause so the
        # original traceback is preserved.
        raise ValueError(f"Error during classification: {e}") from e

    return dict(zip(name_list, answer_list))

def classify_batches(name_list, batch_size=10, wait=2):
    """Classify ``name_list`` in batches and collect the results in a DataFrame.

    Args:
        name_list: List of business names.
        batch_size: Number of names sent per classify_payees call.
        wait: Seconds to pause between consecutive batches (rate limiting).

    Returns:
        DataFrame with ``payee`` and ``category`` columns, one row per distinct
        name that was classified. Names from batches that still failed after
        classify_payees' own retries are omitted; each such failure is printed
        to stderr.
    """
    all_results = {}
    batch_list = get_batch_list(name_list, n=batch_size)
    last_index = len(batch_list) - 1

    for index, batch in enumerate(track(batch_list)):
        try:
            all_results.update(classify_payees(batch))
        except Exception as e:
            # Best-effort: report and keep going so one bad batch doesn't
            # discard the rest of the run.
            print(f"Error processing batch: {e}", file=sys.stderr)

        # Tap the brakes between requests -- after failures too, since that is
        # when the API is most likely throttling us -- but not after the last.
        if index < last_index:
            time.sleep(wait)

    return pd.DataFrame(
        list(all_results.items()),
        columns=["payee", "category"],
    )

# Run classification on the held-out test names (at import time).
# NOTE(review): this makes one classify_payees API call per batch of test
# names before the Gradio app below is even defined, and llm_df is not used
# by the demo -- consider moving this evaluation step out of the app script.
llm_df = classify_batches(list(test_input.payee))

# -- Gradio interface function --
def classify_business_names(input_text):
    """Gradio callback: classify newline-separated business names.

    Args:
        input_text: Raw textbox contents, one business name per line.
            ``None`` is treated like an empty box.

    Returns:
        A JSON string: a ``{name: category}`` object on success, or
        ``{"error": "..."}`` when no names were entered or classification
        failed.
    """
    # Strip each line, drop blank ones, and de-duplicate while keeping the
    # user's order: the result is keyed by name, so repeats would only add
    # tokens to the request.
    name_list = list(dict.fromkeys(
        line.strip() for line in (input_text or "").splitlines() if line.strip()
    ))

    if not name_list:
        return json.dumps({"error": "No business names provided. Please enter at least one business name."})

    try:
        result = classify_payees(name_list)
        return json.dumps(result, indent=2)
    except Exception as e:
        # Show the failure in the UI rather than a Gradio stack trace.
        return json.dumps({"error": f"Classification failed: {e}"})

# -- Launch the demo --
# Single-textbox Gradio UI wired to classify_business_names.
# NOTE(review): classify_business_names returns a JSON *string*; gr.JSON is
# expected to parse it for display -- confirm against the installed Gradio
# version.
demo = gr.Interface(
    fn=classify_business_names,
    inputs=gr.Textbox(lines=10, placeholder="Enter business names, one per line"),
    outputs=gr.JSON(),
    title="Business Category Classifier",
    description="Enter business names and get a classification: Restaurant, Bar, Hotel, or Other.",
    examples=[
        ["Marriott Hotel\nTaco Bell\nThe Tipsy Cow\nStarbucks\nApple Store"]
    ]
)

if __name__ == "__main__":
    # NOTE(review): share=True requests a public Gradio share link (see the
    # launch() docs); anyone with the link can spend this app's HF_TOKEN
    # quota -- confirm that is intended.
    demo.launch(share=True)