jer233 committed
Commit 2cc64b2 · verified · 1 Parent(s): 6741d4b

Update demo/demo.py

Files changed (1)
  1. demo/demo.py +34 -19
demo/demo.py CHANGED
@@ -1,43 +1,56 @@
 import gradio as gr
 from transformers import AutoTokenizer, AutoModel
-# from MMD_calculate import mmd_two_sample_baseline  # Adjust path based on your structure
-# from utils_MMD import extract_features  # Example helper from your utils
+from utils_MMD import extract_features  # Adjust the import path
+from MMD_calculate import mmd_two_sample_baseline  # Adjust the import path
 
 MINIMUM_TOKENS = 64
+THRESHOLD = 0.5  # Threshold for classification
 
 def count_tokens(text, tokenizer):
+    """
+    Counts the number of tokens in the text using the provided tokenizer.
+    """
     return len(tokenizer(text).input_ids)
 
-def run_test_power(model_name, tokenizer_name, real_text, generated_text, N):
+def run_test_power(model_name, real_text, generated_text, N=10):
     """
     Runs the test power calculation for provided real and generated texts.
-    """
 
-    # load tokenizer and model
-    tokenizer = AutoTokenizer.from_pretrained(model_name).cuda()
-    model = AutoModel.from_pretrained(model)
+    Args:
+        model_name (str): Hugging Face model name.
+        real_text (str): Example real text for comparison.
+        generated_text (str): The input text to classify.
+        N (int): Number of repetitions for MMD calculation.
 
+    Returns:
+        str: "Prediction: Human" or "Prediction: AI".
+    """
+    # Load tokenizer and model
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModel.from_pretrained(model_name).cuda()
+    model.eval()
+
+    # Ensure minimum token length
     if count_tokens(real_text, tokenizer) < MINIMUM_TOKENS or count_tokens(generated_text, tokenizer) < MINIMUM_TOKENS:
-        return "Too short length. Need minimum 64 tokens to calculated Test Power."
+        return "Too short length. Need a minimum of 64 tokens to calculate Test Power."
 
     # Extract features
-    fea_real_ls = extract_features(model_name, tokenizer_name, [real_text])
-    fea_generated_ls = extract_features(model_name, tokenizer_name, [generated_text])
+    fea_real_ls = extract_features([real_text], tokenizer, model)
+    fea_generated_ls = extract_features([generated_text], tokenizer, model)
 
-    # Calculate test power list
-    test_power_ls = mmd_two_sample_baseline(fea_real_ls, fea_generated_ls, N=10)
+    # Calculate test power list
+    test_power_ls = mmd_two_sample_baseline(fea_real_ls, fea_generated_ls, N=N)
 
     # Compute the average test power value
     power_test_value = sum(test_power_ls) / len(test_power_ls)
 
     # Classify the text
-    if power_test_value < threshold:
+    if power_test_value < THRESHOLD:
         return "Prediction: Human"
     else:
         return "Prediction: AI"
 
-
-
+# CSS for custom styling
 css = """
 #header { text-align: center; font-size: 1.5em; margin-bottom: 20px; }
 #output-text { font-weight: bold; font-size: 1.2em; }
@@ -78,9 +91,9 @@ with gr.Blocks(css=css) as app:
         clear_button = gr.Button("Clear", variant="secondary")
     with gr.Row():
         output = gr.Textbox(
-            label = "Prediction",
-            placeholder = "Prediction: Human or AI",
-            elem_id = "output-text",
+            label="Prediction",
+            placeholder="Prediction: Human or AI",
+            elem_id="output-text",
         )
     with gr.Accordion("Disclaimer", open=False):
         gr.Markdown(
@@ -102,7 +115,9 @@ with gr.Blocks(css=css) as app:
     ```
     """
     )
-    submit_button.click(detect_text, inputs=[input_text, model_name], outputs=output)
+    submit_button.click(
+        run_test_power, inputs=[model_name, "The cat sat on the mat.", input_text], outputs=output
+    )
     clear_button.click(lambda: ("", ""), inputs=[], outputs=[input_text, output])
 
 app.launch()
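
One detail of the new wiring worth flagging: Gradio expects every entry in `inputs=` to be a component (or `gr.State`), so passing the raw string `"The cat sat on the mat."` directly to `submit_button.click(...)` will likely raise an error when the Blocks are built. Below is a minimal sketch of one way the handler could be wired instead. It assumes `model_name`, `input_text`, `submit_button`, and `output` are the components defined in the part of demo.py outside the hunks shown above; the `real_text_state` holder and the stub `run_test_power` are illustrative placeholders, not part of the commit.

```python
import gradio as gr

# Stand-in for the run_test_power defined in the diff above, so the sketch runs on its own.
def run_test_power(model_name, real_text, generated_text, N=10):
    return "Prediction: Human"

with gr.Blocks() as app:
    # Assumed components; the real demo.py defines these outside the hunks shown above.
    model_name = gr.Textbox(label="Model name", value="gpt2")
    input_text = gr.Textbox(label="Text to classify", lines=8)
    # Keep the fixed human-written reference passage in a gr.State so it can be
    # passed to the handler without appearing as an on-screen component.
    real_text_state = gr.State("The cat sat on the mat.")
    submit_button = gr.Button("Submit")
    output = gr.Textbox(label="Prediction", elem_id="output-text")

    # Argument order matches run_test_power(model_name, real_text, generated_text, N=10).
    submit_button.click(
        run_test_power,
        inputs=[model_name, real_text_state, input_text],
        outputs=output,
    )

app.launch()
```

An alternative with the same effect would be to bind the reference text with `functools.partial` and list only the two visible components in `inputs`.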