Oscar Wang commited on
Commit
34531ec
·
verified ·
1 Parent(s): 844edf1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -9
app.py CHANGED
@@ -2,13 +2,26 @@ import gradio as gr
2
  from transformers import RobertaTokenizer, RobertaForSequenceClassification
3
  import torch
4
 
5
- # Load the model and tokenizer from the specified directory
6
- model_path = 'GoalZero/aidetection-ada-v0.1'
7
- tokenizer = RobertaTokenizer.from_pretrained(model_path)
8
- model = RobertaForSequenceClassification.from_pretrained(model_path)
 
 
 
 
 
 
9
 
10
  # Define the prediction function
11
- def classify_text(text):
 
 
 
 
 
 
 
12
  # Remove periods and new lines from the input text
13
  cleaned_text = text.replace('.', '').replace('\n', ' ')
14
 
@@ -30,12 +43,15 @@ def classify_text(text):
30
  # Create the Gradio interface
31
  iface = gr.Interface(
32
  fn=classify_text,
33
- inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
 
 
 
34
  outputs="json",
35
- title="GoalZero Ada v0.1 Demo",
36
- description="Enter some text and get the probability of the text being written by AI. Full checkpoints of the model will be released soon.",
37
  )
38
 
39
  # Launch the app
40
  if __name__ == "__main__":
41
- iface.launch(share=True)
 
2
  from transformers import RobertaTokenizer, RobertaForSequenceClassification
3
  import torch
4
 
5
+ # Define available models
6
+ model_options = {
7
+ "GoalZero/aidetection-ada-v0.2": "GoalZero/aidetection-ada-v0.2",
8
+ "GoalZero/aidetection-ada-v0.1": "GoalZero/aidetection-ada-v0.1"
9
+ }
10
+
11
+ # Initialize tokenizer and model with the default model
12
+ default_model = model_options["GoalZero/aidetection-ada-v0.2"]
13
+ tokenizer = RobertaTokenizer.from_pretrained(default_model)
14
+ model = RobertaForSequenceClassification.from_pretrained(default_model)
15
 
16
  # Define the prediction function
17
+ def classify_text(text, model_choice):
18
+ global model, tokenizer # Access the global model and tokenizer variables
19
+
20
+ # Check if the model needs to be changed
21
+ if model_choice != model.name_or_path:
22
+ model = RobertaForSequenceClassification.from_pretrained(model_choice)
23
+ tokenizer = RobertaTokenizer.from_pretrained(model_choice)
24
+
25
  # Remove periods and new lines from the input text
26
  cleaned_text = text.replace('.', '').replace('\n', ' ')
27
 
 
43
  # Create the Gradio interface
44
  iface = gr.Interface(
45
  fn=classify_text,
46
+ inputs=[
47
+ gr.Textbox(lines=2, placeholder="Enter text here..."),
48
+ gr.Dropdown(choices=list(model_options.keys()), value=default_model, label="Select Model")
49
+ ],
50
  outputs="json",
51
+ title="GoalZero Ada Model Selector",
52
+ description="Enter text to get the probability of it being AI-written. Select a model version to use.",
53
  )
54
 
55
  # Launch the app
56
  if __name__ == "__main__":
57
+ iface.launch(share=True)