Spaces:
Sleeping
Sleeping
Oscar Wang
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -2,13 +2,26 @@ import gradio as gr
|
|
2 |
from transformers import RobertaTokenizer, RobertaForSequenceClassification
|
3 |
import torch
|
4 |
|
5 |
-
#
|
6 |
-
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
# Define the prediction function
|
11 |
-
def classify_text(text):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
# Remove periods and new lines from the input text
|
13 |
cleaned_text = text.replace('.', '').replace('\n', ' ')
|
14 |
|
@@ -30,12 +43,15 @@ def classify_text(text):
|
|
30 |
# Create the Gradio interface
|
31 |
iface = gr.Interface(
|
32 |
fn=classify_text,
|
33 |
-
inputs=
|
|
|
|
|
|
|
34 |
outputs="json",
|
35 |
-
title="GoalZero Ada
|
36 |
-
description="Enter
|
37 |
)
|
38 |
|
39 |
# Launch the app
|
40 |
if __name__ == "__main__":
|
41 |
-
iface.launch(share=True)
|
|
|
2 |
from transformers import RobertaTokenizer, RobertaForSequenceClassification
|
3 |
import torch
|
4 |
|
5 |
+
# Define available models
|
6 |
+
model_options = {
|
7 |
+
"GoalZero/aidetection-ada-v0.2": "GoalZero/aidetection-ada-v0.2",
|
8 |
+
"GoalZero/aidetection-ada-v0.1": "GoalZero/aidetection-ada-v0.1"
|
9 |
+
}
|
10 |
+
|
11 |
+
# Initialize tokenizer and model with the default model
|
12 |
+
default_model = model_options["GoalZero/aidetection-ada-v0.2"]
|
13 |
+
tokenizer = RobertaTokenizer.from_pretrained(default_model)
|
14 |
+
model = RobertaForSequenceClassification.from_pretrained(default_model)
|
15 |
|
16 |
# Define the prediction function
|
17 |
+
def classify_text(text, model_choice):
|
18 |
+
global model, tokenizer # Access the global model and tokenizer variables
|
19 |
+
|
20 |
+
# Check if the model needs to be changed
|
21 |
+
if model_choice != model.name_or_path:
|
22 |
+
model = RobertaForSequenceClassification.from_pretrained(model_choice)
|
23 |
+
tokenizer = RobertaTokenizer.from_pretrained(model_choice)
|
24 |
+
|
25 |
# Remove periods and new lines from the input text
|
26 |
cleaned_text = text.replace('.', '').replace('\n', ' ')
|
27 |
|
|
|
43 |
# Create the Gradio interface
|
44 |
iface = gr.Interface(
|
45 |
fn=classify_text,
|
46 |
+
inputs=[
|
47 |
+
gr.Textbox(lines=2, placeholder="Enter text here..."),
|
48 |
+
gr.Dropdown(choices=list(model_options.keys()), value=default_model, label="Select Model")
|
49 |
+
],
|
50 |
outputs="json",
|
51 |
+
title="GoalZero Ada Model Selector",
|
52 |
+
description="Enter text to get the probability of it being AI-written. Select a model version to use.",
|
53 |
)
|
54 |
|
55 |
# Launch the app
|
56 |
if __name__ == "__main__":
|
57 |
+
iface.launch(share=True)
|