dejanseo's picture
Update app.py
d6bd636 verified
raw
history blame
5.17 kB
import ast
import html
import json
import logging
import math
import os
import re

import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Browser-tab title/icon and a full-width layout; issued before any other st.* call.
st.set_page_config(
    layout="wide",
    page_title="AI Article Detection by DEJAN",
    page_icon="🧠",
)
# --- Load heuristic weights from environment secrets, with JSON→Python fallback ---
@st.cache_resource
def load_heuristic_weights():
    """Load the AI and "OG" word-weight tables from environment secrets.

    Each of ``AI_WEIGHTS_JSON`` and ``OG_WEIGHTS_JSON`` is parsed as JSON first.
    If that fails, it is parsed as a Python literal via ``ast.literal_eval``,
    which only accepts literals and never executes code.

    Returns:
        tuple[dict, dict]: ``(ai_weights, og_weights)``.

    Raises:
        KeyError: if either environment variable is not set.
        ValueError, SyntaxError: (or similar from ``ast.literal_eval``) if a
            value is neither valid JSON nor a valid Python literal.
        TypeError: if a value parses but is not a dict. The tables are later
            queried with ``.get`` and ``in``, so a dict is required.
    """
    def _load(env_key):
        # Parse one secret, trying JSON before the Python-literal syntax.
        raw = os.environ[env_key]
        try:
            value = json.loads(raw)
        except json.JSONDecodeError:
            value = ast.literal_eval(raw)
        if not isinstance(value, dict):
            raise TypeError(
                f"{env_key} must decode to a dict, got {type(value).__name__}"
            )
        return value

    ai = _load("AI_WEIGHTS_JSON")
    og = _load("OG_WEIGHTS_JSON")
    return ai, og


try:
    AI_WEIGHTS, OG_WEIGHTS = load_heuristic_weights()
except Exception as e:
    # Same policy as the model-loading step: show the problem in the UI and
    # halt this run instead of crashing the script with a raw traceback.
    st.error(f"Error loading heuristic weights: {e}")
    st.stop()
SIGMOID_K = 0.5
# Whole words of two or more ASCII letters, applied to already-lowercased text.
_WORD_RE = re.compile(r'\b[a-z]{2,}\b')


def tokenize(text):
    """Return the lowercased runs of two or more ASCII letters that stand as whole words in *text*."""
    return _WORD_RE.findall(text.lower())


def classify_text_likelihood(text: str, ai_weights=None, og_weights=None) -> float:
    """Score *text* with the word-weight heuristic.

    Every token that has a nonzero weight in either table contributes its AI
    weight and its OG weight. The result is
    ``sigmoid(SIGMOID_K * (sum_ai - sum_og))``.

    Args:
        text: Raw input text. It is tokenized with ``tokenize``, so matching
            is case-insensitive.
        ai_weights: Mapping from lowercase word to numeric weight. Defaults to
            the module-level ``AI_WEIGHTS``.
        og_weights: Mapping from lowercase word to numeric weight. Defaults to
            the module-level ``OG_WEIGHTS``.

    Returns:
        A value in [0, 1]. It is exactly 0.5 when no token has a nonzero
        weight in either table.
    """
    if ai_weights is None:
        ai_weights = AI_WEIGHTS
    if og_weights is None:
        og_weights = OG_WEIGHTS

    ai_score = og_score = 0
    matched = False
    for token in tokenize(text):
        aw = ai_weights.get(token, 0)
        ow = og_weights.get(token, 0)
        if aw or ow:
            matched = True
            ai_score += aw
            og_score += ow
    if not matched:
        return 0.5

    z = SIGMOID_K * (ai_score - og_score)
    # Numerically stable logistic. math.exp raises OverflowError once its
    # argument exceeds ~709, and a long text dominated by one vocabulary can
    # get there. So only ever exponentiate a non-positive number.
    if z >= 0:
        return 1 / (1 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1 + ez)
def highlight_heuristic_words(text: str, ai_weights=None, og_weights=None) -> str:
    """Return *text* as HTML with heuristic keywords underlined.

    Words found in the AI table get a dark-red underline. Words found in the
    "OG" table (presumably human-written vocabulary) get a dark-green one. If
    a word is in both tables, the AI colour wins.

    Matching is case-insensitive. This mirrors ``classify_text_likelihood``,
    which lowercases text before looking words up, so "Delve" and "delve" are
    treated alike. All other text is HTML-escaped, because the result is meant
    to be rendered as raw HTML.

    Args:
        text: Raw user text.
        ai_weights: Mapping keyed by lowercase word. Defaults to the
            module-level ``AI_WEIGHTS``.
        og_weights: Mapping keyed by lowercase word. Defaults to the
            module-level ``OG_WEIGHTS``.
    """
    if ai_weights is None:
        ai_weights = AI_WEIGHTS
    if og_weights is None:
        og_weights = OG_WEIGHTS

    out = []
    # The capture group keeps the matched words in the split result, so the
    # parts alternate between surrounding text and candidate words.
    for part in re.split(r'(\b[A-Za-z]{2,}\b)', text):
        lower = part.lower()
        if lower in ai_weights:
            color = "darkred"
        elif lower in og_weights:
            color = "darkgreen"
        else:
            out.append(html.escape(part, quote=False))
            continue
        out.append(
            "<span style='text-decoration: underline; "
            f"text-decoration-color: {color}; text-decoration-thickness: 2px;'>"
            f"{html.escape(part, quote=False)}</span>"
        )
    return ''.join(out)
# --- Logging & Streamlit setup ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Pull Roboto from Google Fonts and make it the page-wide font via injected CSS.
_ROBOTO_FONT_CSS = """
<link href="https://fonts.googleapis.com/css2?family=Roboto&display=swap" rel="stylesheet">
<style>
html, body, [class*="css"] {
font-family: 'Roboto', sans-serif;
}
</style>
"""
st.markdown(_ROBOTO_FONT_CSS, unsafe_allow_html=True)
@st.cache_resource
def load_model_and_tokenizer(model_name):
    """Load the tokenizer and sequence classifier for *model_name*.

    Device and precision are chosen as follows:
    - CUDA when available, in bfloat16 if the GPU supports it and float32
      otherwise.
    - CPU in float32 when CUDA is not available.

    The model is moved to the chosen device and switched to eval mode.
    Wrapped in ``st.cache_resource`` so calls with the same *model_name*
    reuse the loaded objects.

    Returns:
        tuple: ``(tokenizer, model, device)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    on_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if on_gpu else "cpu")
    use_bf16 = on_gpu and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float32

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, torch_dtype=dtype
    )
    model.to(device)
    model.eval()
    return tokenizer, model, device
MODEL_NAME = "dejanseo/ai-detection-small"

try:
    tokenizer, model, device = load_model_and_tokenizer(MODEL_NAME)
except Exception as exc:
    # Surface the failure both in the UI and in the server log (with
    # traceback), then halt this run; nothing below works without the model.
    st.error(f"Error loading model: {exc}")
    logger.exception("Failed to load model: %s", exc)
    st.stop()
# Whitespace run that immediately follows '.', '!' or '?'.
_SENTENCE_BOUNDARY = re.compile(r'(?<=[\.!?])\s+')


def sent_tokenize(text):
    """Split *text* into sentences at whitespace that follows '.', '!' or '?'.

    Leading and trailing whitespace is removed first, and empty pieces are
    dropped, so blank input yields ``[]``. Line breaks are treated like any
    other whitespace: they only split after terminal punctuation.
    """
    pieces = _SENTENCE_BOUNDARY.split(text.strip())
    return list(filter(None, pieces))
# --- Main UI ---
# Flow:
# 1. Classify each sentence with the model.
# 2. Render the sentences colour-coded by predicted label, with heuristic
#    keywords underlined.
# 3. Report the model, heuristic and combined AI likelihoods.
st.title("AI Article Detection")
text = st.text_area("Enter text to classify", height=200, placeholder="Paste your text here…")

if st.button("Classify", type="primary"):
    if not text.strip():
        st.warning("Please enter some text.")
    else:
        with st.spinner("Analyzing…"):
            sentences = sent_tokenize(text)
            if not sentences:
                st.warning("No sentences detected.")
                st.stop()

            # Cap the sequence length at what both the tokenizer and the model
            # support. Some configs (e.g. RoBERTa-style) report a
            # max_position_embeddings larger than the usable input length.
            # Some tokenizers report a huge sentinel for model_max_length.
            # Taking the minimum guards against both.
            max_len = min(tokenizer.model_max_length, model.config.max_position_embeddings)
            inputs = tokenizer(
                sentences,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_len,
            ).to(device)

            with torch.no_grad():
                logits = model(**inputs).logits
            # Softmax in float32: the model may be running in bfloat16 on GPU.
            probs = F.softmax(logits.float(), dim=-1).cpu()
            preds = torch.argmax(probs, dim=-1)

            chunks = []
            for i, sentence in enumerate(sentences):
                inner = highlight_heuristic_words(sentence)
                label = preds[i].item()
                # Label 0 is the AI class (it is read as model_ai below) and is
                # shaded red. Any other label is shaded green. Opacity equals
                # the model's confidence in the predicted label.
                r, g = (255, 0) if label == 0 else (0, 255)
                alpha = probs[i, label].item()
                chunks.append(
                    f"<span style='background-color: rgba({r},{g},0,{alpha:.2f}); "
                    f"padding:2px; margin:0 2px; border-radius:3px;'>{inner}</span>"
                )
            st.markdown("".join(chunks), unsafe_allow_html=True)

            # Unweighted mean over sentences of the class-0 (AI) probability.
            model_ai = torch.mean(probs, dim=0)[0].item()
            heuristic_ai = classify_text_likelihood(text)
            # Average the two estimates so the result stays a probability.
            # A plain sum hits the 1.0 cap whenever the two add up to 1 or
            # more, e.g. a neutral 0.5 + 0.5.
            combined = (model_ai + heuristic_ai) / 2

            st.subheader(f"🤖 Model AI Likelihood: {model_ai*100:.1f}%")
            st.subheader(f"🛠️ Heuristic AI Likelihood: {heuristic_ai*100:.1f}%")
            st.subheader(f"⚖️ Combined AI Likelihood: {combined*100:.1f}%")