Spaces:
Runtime error
Runtime error
Commit
·
8f36cc4
1
Parent(s):
6da5bec
Update app.py
Browse files
app.py
CHANGED
@@ -1,22 +1,11 @@
|
|
1 |
import streamlit as st
|
2 |
|
3 |
-
# # Library for Sentence Similarity
|
4 |
-
# import pandas as pd
|
5 |
-
# from sentence_transformers import SentenceTransformer
|
6 |
-
# from sklearn.metrics.pairwise import cosine_similarity
|
7 |
|
8 |
# Library for Entailment
|
9 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
10 |
import torch
|
11 |
|
12 |
-
|
13 |
-
# # Library for keyword extraction
|
14 |
-
# import yake
|
15 |
-
|
16 |
-
|
17 |
-
# Load models and tokenisers for both sentence transformers and text classification
|
18 |
-
|
19 |
-
# sentence_transformer_model = SentenceTransformer('all-MiniLM-L6-v2')
|
20 |
|
21 |
tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
|
22 |
|
@@ -26,157 +15,57 @@ text_classification_model = AutoModelForSequenceClassification.from_pretrained("
|
|
26 |
|
27 |
### Streamlit interface ###
|
28 |
|
29 |
-
st.title("
|
30 |
|
31 |
-
|
32 |
-
"What would you like to work with?",
|
33 |
-
("Compare two sentences", "Bulk upload and mark")
|
34 |
-
)
|
35 |
|
36 |
-
|
37 |
|
38 |
-
|
39 |
|
40 |
-
st.
|
41 |
|
42 |
-
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
sentence_2 = st.text_input("Sentence 2 input")
|
47 |
-
|
48 |
-
submit_button_compare = st.form_submit_button("Compare Sentences")
|
49 |
-
|
50 |
-
# If submit_button_compare clicked
|
51 |
-
if submit_button_compare:
|
52 |
-
|
53 |
-
print("Comparing sentences...")
|
54 |
-
|
55 |
-
# ### Compare Sentence Similarity ###
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
# #Initialise sentences
|
60 |
-
# sentences = []
|
61 |
-
|
62 |
-
# # Append input sentences to 'sentences' list
|
63 |
-
# sentences.append(sentence_1)
|
64 |
-
# sentences.append(sentence_2)
|
65 |
-
|
66 |
-
# # Create embeddings for both sentences
|
67 |
-
# sentence_embeddings = sentence_transformer_model.encode(sentences)
|
68 |
-
|
69 |
-
# cos_sim = cosine_similarity(sentence_embeddings[0].reshape(1, -1), sentence_embeddings[1].reshape(1, -1))[0][0]
|
70 |
-
# cos_sim = round(cos_sim * 100) # Convert to percentage and round-off
|
71 |
-
|
72 |
-
|
73 |
-
# # st.write('Similarity between "{}" and "{}" is {}%'.format(sentence_1,
|
74 |
-
# # sentence_2, cos_sim))
|
75 |
-
|
76 |
-
# st.subheader("Similarity")
|
77 |
-
# st.write(f"Similarity between the two sentences is {cos_sim}%.")
|
78 |
-
|
79 |
-
|
80 |
-
### Text classification - entailment, neutral or contradiction ###
|
81 |
-
|
82 |
-
raw_inputs = [f"{sentence_1}</s></s>{sentence_2}"]
|
83 |
|
84 |
-
|
85 |
|
86 |
-
|
87 |
|
88 |
-
|
89 |
|
90 |
-
|
91 |
-
# print(outputs)
|
92 |
|
93 |
-
|
94 |
|
95 |
-
|
96 |
-
print(text_classification_model.config.id2label[1], ":", round(outputs[0][1].item()*100,2),"%")
|
97 |
-
print(text_classification_model.config.id2label[2], ":", round(outputs[0][2].item()*100,2),"%")
|
98 |
|
99 |
-
|
|
|
100 |
|
101 |
-
|
102 |
-
st.write(text_classification_model.config.id2label[0], ":", round(outputs[0][0].item()*100,2),"%")
|
103 |
-
st.write(text_classification_model.config.id2label[2], ":", round(outputs[0][2].item()*100,2),"%")
|
104 |
|
|
|
|
|
|
|
105 |
|
106 |
-
|
107 |
|
108 |
-
|
|
|
|
|
109 |
|
110 |
-
|
111 |
-
# keywords = kw_extractor.extract_keywords(sentence_2)
|
112 |
|
113 |
-
|
|
|
|
|
|
|
|
|
114 |
|
115 |
-
# for kw, v in keywords:
|
116 |
-
# print("Keyphrase: ", kw, ": score", v)
|
117 |
-
# keywords_array.append(kw)
|
118 |
|
119 |
-
# st.write(kw)
|
120 |
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
if sidebar_selectbox == "Bulk upload and mark":
|
126 |
-
|
127 |
-
st.subheader("Bulk compare similarity of sentences")
|
128 |
-
|
129 |
-
sentence_reference = st.text_input("Reference sentence input")
|
130 |
-
|
131 |
-
# Only allow user to upload CSV files
|
132 |
-
data_file = st.file_uploader("Upload CSV",type=["csv"])
|
133 |
-
|
134 |
-
if data_file is not None:
|
135 |
-
with st.spinner('Wait for it...'):
|
136 |
-
file_details = {"filename":data_file.name, "filetype":data_file.type, "filesize":data_file.size}
|
137 |
-
# st.write(file_details)
|
138 |
-
df = pd.read_csv(data_file)
|
139 |
-
|
140 |
-
# Get length of df.shape (might not need this)
|
141 |
-
#total_rows = df.shape[0]
|
142 |
-
|
143 |
-
similarity_scores = []
|
144 |
-
|
145 |
-
for idx, row in df.iterrows():
|
146 |
-
# st.write(idx, row['Sentences'])
|
147 |
-
|
148 |
-
# Create an empty sentence list
|
149 |
-
sentences = []
|
150 |
-
|
151 |
-
# Compare the sentences two by two
|
152 |
-
sentence_comparison = row['Sentences']
|
153 |
-
sentences.append(sentence_reference)
|
154 |
-
sentences.append(sentence_comparison)
|
155 |
-
|
156 |
-
sentence_embeddings = sentence_transformer_model.encode(sentences)
|
157 |
-
|
158 |
-
cos_sim = cosine_similarity(sentence_embeddings[0].reshape(1, -1), sentence_embeddings[1].reshape(1, -1))[0][0]
|
159 |
-
cos_sim = round(cos_sim * 100)
|
160 |
-
|
161 |
-
similarity_scores.append(cos_sim)
|
162 |
-
|
163 |
-
# Append new column to dataframe
|
164 |
-
|
165 |
-
df['Similarity (%)'] = similarity_scores
|
166 |
-
|
167 |
-
st.dataframe(df)
|
168 |
-
st.success('Done!')
|
169 |
-
|
170 |
-
@st.cache
|
171 |
-
def convert_df(df):
|
172 |
-
return df.to_csv().encode('utf-8')
|
173 |
-
|
174 |
-
csv = convert_df(df)
|
175 |
-
|
176 |
-
st.download_button(
|
177 |
-
"Press to Download",
|
178 |
-
csv,
|
179 |
-
"marked assignment.csv",
|
180 |
-
"text/csv",
|
181 |
-
key='download-csv'
|
182 |
-
)
|
|
|
1 |
import streamlit as st
|
2 |
|
|
|
|
|
|
|
|
|
3 |
|
4 |
# Library for Entailment
|
5 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
6 |
import torch
|
7 |
|
8 |
+
# Load model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
|
11 |
|
|
|
15 |
|
16 |
### Streamlit interface ###
|
17 |
|
18 |
+
st.title("Text Classification")
|
19 |
|
20 |
+
st.subheader("Entailment, neutral, or contradiction?")
|
|
|
|
|
|
|
21 |
|
22 |
+
with st.form("submission_form", clear_on_submit=False):
|
23 |
|
24 |
+
threshold = st.slider("Threshold", min_value=0.0, max_value=1.0, step=0.1, value=0.7)
|
25 |
|
26 |
+
sentence_1 = st.text_input("Sentence 1 input")
|
27 |
|
28 |
+
sentence_2 = st.text_input("Sentence 2 input")
|
29 |
|
30 |
+
submit_button_compare = st.form_submit_button("Compare Sentences")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
+
# If submit_button_compare clicked
|
33 |
+
if submit_button_compare:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
+
print("Comparing sentences...")
|
36 |
|
37 |
+
### Text classification - entailment, neutral or contradiction ###
|
38 |
|
39 |
+
raw_inputs = [f"{sentence_1}</s></s>{sentence_2}"]
|
40 |
|
41 |
+
inputs = tokenizer(raw_inputs, padding=True, truncation=True, return_tensors="pt")
|
|
|
42 |
|
43 |
+
# print(inputs)
|
44 |
|
45 |
+
outputs = text_classification_model(**inputs)
|
|
|
|
|
46 |
|
47 |
+
outputs = torch.nn.functional.softmax(outputs.logits, dim = -1)
|
48 |
+
# print(outputs)
|
49 |
|
50 |
+
# argmax_index = torch.argmax(outputs).item()
|
|
|
|
|
51 |
|
52 |
+
print(text_classification_model.config.id2label[0], ":", round(outputs[0][0].item()*100,2),"%")
|
53 |
+
print(text_classification_model.config.id2label[1], ":", round(outputs[0][1].item()*100,2),"%")
|
54 |
+
print(text_classification_model.config.id2label[2], ":", round(outputs[0][2].item()*100,2),"%")
|
55 |
|
56 |
+
st.subheader("Text classification for both sentences:")
|
57 |
|
58 |
+
st.write(text_classification_model.config.id2label[1], ":", round(outputs[0][1].item()*100,2),"%")
|
59 |
+
st.write(text_classification_model.config.id2label[0], ":", round(outputs[0][0].item()*100,2),"%")
|
60 |
+
st.write(text_classification_model.config.id2label[2], ":", round(outputs[0][2].item()*100,2),"%")
|
61 |
|
62 |
+
entailment_score = outputs[0][2].item()  # entailment probability in [0, 1], the same scale as the threshold slider (0.0-1.0)
|
|
|
63 |
|
64 |
+
if entailment_score >= threshold:
|
65 |
+
st.subheader("The statements are very similar!")
|
66 |
+
st.balloons()
|
67 |
+
else:
|
68 |
+
st.subheader("The statements are not close enough")
|
69 |
|
|
|
|
|
|
|
70 |
|
|
|
71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|