Commit
·
7b0a595
1
Parent(s):
6f69f46
Update space
Browse files
app.py
CHANGED
@@ -6,11 +6,36 @@ import streamlit as st
|
|
6 |
model = AutoModelForSequenceClassification.from_pretrained('paramasivan27/RetailProductClassification_bert-base-uncased')
|
7 |
tokenizer = AutoTokenizer.from_pretrained('paramasivan27/RetailProductClassification_bert-base-uncased')
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
def predict(text):
|
10 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
|
11 |
outputs = model(**inputs)
|
12 |
predicted_class = torch.argmax(outputs.logits, dim=1).item()
|
13 |
-
|
|
|
14 |
|
15 |
# Streamlit app layout
|
16 |
st.title("Retail Product Classification")
|
|
|
6 |
model = AutoModelForSequenceClassification.from_pretrained('paramasivan27/RetailProductClassification_bert-base-uncased')
|
7 |
tokenizer = AutoTokenizer.from_pretrained('paramasivan27/RetailProductClassification_bert-base-uncased')
|
8 |
|
9 |
+
label_to_id = {'Electronics':0,
|
10 |
+
'Sports & Outdoors':1,
|
11 |
+
'Cell Phones & Accessories':2,
|
12 |
+
'Automotive':3,
|
13 |
+
'Toys & Games':4,
|
14 |
+
'Tools & Home Improvement':5,
|
15 |
+
'Health & Personal Care':6,
|
16 |
+
'Beauty':7,
|
17 |
+
'Grocery & Gourmet Food':8,
|
18 |
+
'Office Products':9,
|
19 |
+
'Arts, Crafts & Sewing':10,
|
20 |
+
'Pet Supplies':11,
|
21 |
+
'Patio, Lawn & Garden':12,
|
22 |
+
'Clothing, Shoes & Jewelry':13,
|
23 |
+
'Baby':14,
|
24 |
+
'Musical Instruments':15,
|
25 |
+
'Industrial & Scientific':16,
|
26 |
+
'Baby Products':17,
|
27 |
+
'Appliances':18,
|
28 |
+
'All Beauty':19,
|
29 |
+
'All Electronics':20}
|
30 |
+
|
31 |
+
id_to_label = {v: k for k, v in label_to_id.items()}
|
32 |
+
|
33 |
def predict(text):
|
34 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
|
35 |
outputs = model(**inputs)
|
36 |
predicted_class = torch.argmax(outputs.logits, dim=1).item()
|
37 |
+
class_name = id_to_label[predicted_class]
|
38 |
+
return class_name
|
39 |
|
40 |
# Streamlit app layout
|
41 |
st.title("Retail Product Classification")
|