dejanseo committed
Commit 4db9ce2 · verified · 1 Parent(s): 0598948

Upload 3 files

Files changed (3):
  1. .streamlit/config.toml +7 -0
  2. app.py +203 -0
  3. requirements.txt +5 -0
.streamlit/config.toml ADDED
@@ -0,0 +1,7 @@
+ [theme]
+ base = "light"
+ primaryColor = "#4a90e2"
+ backgroundColor = "#ffffff"
+ secondaryBackgroundColor = "#f0f2f6"
+ textColor = "#000000"
+ font = "roboto"
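
Note: Streamlit's built-in `theme.font` option accepts only `"sans serif"`, `"serif"`, or `"monospace"`, so `font = "roboto"` is most likely ignored and falls back to the default. app.py below compensates by loading Roboto through injected CSS.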
app.py ADDED
@@ -0,0 +1,203 @@
+ import streamlit as st
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ import re
+ import logging  # Optional: Add logging for better debugging
+
+ # Set up logging (optional but helpful)
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Set the page configuration
+ st.set_page_config(
+     page_title="AI Article Detection by DEJAN",
+     page_icon="🧠",
+     layout="wide"
+ )
+
+ # Logo as provided
+ st.logo(
+     image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
+     link="https://dejan.ai/",
+     # size="large"  # 'size' is not a valid argument for st.logo as of Streamlit 1.34 - remove or adjust if needed
+ )
+
+ # Font styling
+ st.markdown("""
+ <link href="https://fonts.googleapis.com/css2?family=Roboto&display=swap" rel="stylesheet">
+ <style>
+ html, body, [class*="css"] {
+     font-family: 'Roboto', sans-serif;
+ }
+ </style>
+ """, unsafe_allow_html=True)
+
+ @st.cache_resource  # Cache the model and tokenizer to avoid reloading on every interaction
+ def load_model_and_tokenizer(model_name):
+     """Loads the model and tokenizer."""
+     logger.info(f"Loading tokenizer: {model_name}")
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+     # Determine device
+     device_type = "cuda" if torch.cuda.is_available() else "cpu"
+     # Use bfloat16 if available on CUDA for potential speedup/memory saving, else float32
+     dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
+     logger.info(f"Using device: {device_type} with dtype: {dtype}")
+
+     logger.info(f"Loading model: {model_name}")
+     # Load model onto CPU first, then move to target device
+     model = AutoModelForSequenceClassification.from_pretrained(
+         model_name,
+         torch_dtype=dtype  # Use the determined dtype
+         # Removed device_map="auto"
+     )
+     logger.info("Moving model to target device...")
+     model.to(torch.device(device_type))  # Move the entire model to the target device
+     model.eval()  # Set model to evaluation mode
+     logger.info("Model loaded successfully.")
+     return tokenizer, model, torch.device(device_type)
+
+ # Load model and tokenizer using the cached function
+ MODEL_NAME = "dejanseo/ai-detection-base"
+ try:
+     tokenizer, model, device = load_model_and_tokenizer(MODEL_NAME)
+ except Exception as e:
+     st.error(f"Error loading model: {e}")
+     logger.error(f"Failed to load model or tokenizer: {e}", exc_info=True)
+     st.stop()  # Stop execution if model loading fails
+
+
+ # Static settings
+ LABELS = ["AI Content", "Human Content"]
+ COLORS = ["#ffe5e5", "#e6ffe6"]  # light red, light green
+
+ # Regex-based sentence splitter (improved slightly for robustness)
+ def sent_tokenize(text):
+     # Split by '.', '!', '?' followed by space(s) or end of string
+     sentences = re.split(r'(?<=[.!?])\s+', text.strip())
+     # Filter out empty strings that might result from splitting
+     return [s for s in sentences if s]
+
+ def split_into_chunks(text, tokenizer, max_length=512):
+     sentences = sent_tokenize(text)
+     if not sentences:
+         return []  # Handle empty input after tokenization
+
+     chunks, current_chunk_sentences, current_len = [], [], 0
+     max_tokens = max_length - 2  # Account for [CLS] and [SEP] tokens
+
+     for sent in sentences:
+         # Use tokenizer.encode to get accurate token count (more reliable than tokenize)
+         token_ids = tokenizer.encode(sent, add_special_tokens=False)
+         token_len = len(token_ids)
+
+         if token_len > max_tokens:
+             # Sentence is too long even by itself, handle appropriately
+             # Option 1: Truncate the sentence (simplest)
+             logger.warning(f"Sentence truncated as it exceeds max_length: '{sent[:100]}...'")
+             truncated_sent = tokenizer.decode(token_ids[:max_tokens])
+             # If there was a previous chunk, add it first
+             if current_chunk_sentences:
+                 chunks.append(" ".join(current_chunk_sentences))
+             chunks.append(truncated_sent)  # Add the single truncated sentence as its own chunk
+             current_chunk_sentences, current_len = [], 0  # Reset chunk
+             continue  # Move to the next sentence
+
+         if current_len + token_len <= max_tokens:
+             current_chunk_sentences.append(sent)
+             current_len += token_len
+         else:
+             # Current chunk is full, finalize it
+             if current_chunk_sentences:
+                 chunks.append(" ".join(current_chunk_sentences))
+             # Start a new chunk with the current sentence
+             current_chunk_sentences = [sent]
+             current_len = token_len
+
+     # Add the last remaining chunk
+     if current_chunk_sentences:
+         chunks.append(" ".join(current_chunk_sentences))
+
+     return chunks
+
+ # --- UI ---
+ st.title("AI Article Detection")
+ text = st.text_area("Enter text to classify", height=150, placeholder="Paste your text here...")
+
+ if st.button("Classify", type="primary"):
+     if not text or not text.strip():
+         st.warning("Please enter some text.")
+     else:
+         with st.spinner("Analyzing... Please wait."):
+             try:
+                 # Split text using the tokenizer reference
+                 chunks = split_into_chunks(text, tokenizer, max_length=model.config.max_position_embeddings)
+                 logger.info(f"Split text into {len(chunks)} chunks.")
+
+                 if not chunks:
+                     st.warning("Could not process the input text (perhaps it's too short or contains only delimiters?).")
+                     st.stop()
+
+                 # Tokenize chunks and move tensors to the correct device
+                 inputs = tokenizer(
+                     chunks,
+                     return_tensors="pt",
+                     padding=True,  # Pad sequences to the longest in the batch
+                     truncation=True,  # Truncate sequences longer than max_length
+                     max_length=model.config.max_position_embeddings  # Use model's max length
+                 ).to(device)  # Move inputs to the same device as the model
+
+                 # Perform inference
+                 with torch.no_grad():
+                     outputs = model(**inputs)
+                     logits = outputs.logits
+                     # Ensure probabilities are calculated on CPU if needed for aggregation later
+                     probs = F.softmax(logits, dim=-1).cpu()  # Move probs to CPU
+                     preds = torch.argmax(probs, dim=-1)  # Argmax on CPU probabilities
+
+                 # Process results
+                 chunk_results = []
+                 for i, chunk in enumerate(chunks):
+                     pred_index = preds[i].item()  # Get prediction index for this chunk
+                     chunk_results.append({
+                         "text": chunk,
+                         "label": LABELS[pred_index],
+                         "color": COLORS[pred_index],
+                         "conf": probs[i, pred_index].item() * 100,  # Get confidence for the predicted class
+                     })
+
+                 # Calculate overall prediction based on average probability across chunks
+                 if probs.numel() > 0:  # Check if probs tensor is not empty
+                     avg_probs = torch.mean(probs, dim=0)  # Average probabilities across the batch dimension
+                     final_class_index = torch.argmax(avg_probs).item()
+                     final_label = LABELS[final_class_index]
+                     final_conf = avg_probs[final_class_index].item() * 100
+
+                     # Display final prediction
+                     st.subheader("📊 Final Prediction")
+                     st.markdown(
+                         f"<div style='background-color:{COLORS[final_class_index]}; padding:1rem; border-radius:0.5rem; border: 1px solid #ccc;'>"
+                         f"Based on the analysis, the text is most likely: <b>{final_label}</b> (Confidence: {final_conf:.1f}%)</div>",
+                         unsafe_allow_html=True
+                     )
+                 else:
+                     st.warning("Could not generate predictions for the provided text.")
+
+
+                 # Display per-chunk predictions in an expander
+                 with st.expander("See per-chunk predictions and confidence"):
+                     if chunk_results:
+                         for result in chunk_results:
+                             st.markdown(
+                                 f"<div title='Confidence: {result['conf']:.1f}%' "
+                                 f"style='background-color:{result['color']}; padding:0.75rem; margin-bottom:0.5rem; border-radius:0.5rem; border: 1px solid #ddd;'>"
+                                 f"<i>({result['label']} - {result['conf']:.1f}%)</i><br>{result['text']}</div>",
+                                 unsafe_allow_html=True
+                             )
+                     else:
+                         st.write("No chunk predictions were generated.")
+
+             except Exception as e:
+                 st.error(f"An error occurred during analysis: {e}")
+                 logger.error(f"Analysis failed: {e}", exc_info=True)
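
The sentence splitter drives the chunking logic above, and it can be exercised outside Streamlit. The sketch below is not part of the commit: it duplicates the one-line regex from app.py and reuses the `dejanseo/ai-detection-base` tokenizer named above; once `split_into_chunks` is copied over as well, the same driver can exercise the full chunker.

```python
import re
from transformers import AutoTokenizer

def sent_tokenize(text):
    # Same regex splitter as app.py: split after '.', '!', '?' followed by whitespace.
    return [s for s in re.split(r'(?<=[.!?])\s+', text.strip()) if s]

tokenizer = AutoTokenizer.from_pretrained("dejanseo/ai-detection-base")

sample = "Short opener. A second sentence follows! And a third one? The end."
for i, sent in enumerate(sent_tokenize(sample)):
    # Per-sentence token counts are what the chunker sums against max_length - 2,
    # the budget it reserves for the [CLS] and [SEP] special tokens.
    n = len(tokenizer.encode(sent, add_special_tokens=False))
    print(f"sentence {i}: {n} tokens -> {sent!r}")
```

The document-level verdict is the argmax of the mean of the per-chunk softmax distributions, so high-confidence chunks outweigh borderline ones, unlike a simple majority vote. A worked example with hypothetical probabilities, in the same class order as `LABELS`:

```python
import torch

# Hypothetical softmax outputs for three chunks and two classes,
# ordered as LABELS = ["AI Content", "Human Content"] in app.py.
probs = torch.tensor([[0.90, 0.10],
                      [0.60, 0.40],
                      [0.20, 0.80]])

avg_probs = probs.mean(dim=0)            # tensor([0.5667, 0.4333])
final_index = avg_probs.argmax().item()  # 0 -> "AI Content"
print(f"{avg_probs[final_index].item() * 100:.1f}%")  # 56.7% confidence
```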
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ streamlit
+ torch
+ transformers
+ nltk
+ accelerate
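
Two of these dependencies appear to be leftovers: app.py ships its own regex-based `sent_tokenize` instead of using `nltk`, and `accelerate` is only required for `device_map`-style loading, which this commit explicitly removes. With the files in place, the app should run locally with `pip install -r requirements.txt` followed by `streamlit run app.py`.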