# Source: app.py — last commit 6555a28 "Update app.py" by shashwatashish (verified).
import os
import re
from collections import Counter

import pandas as pd
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
# Hugging Face Hub checkpoint loaded for zero-shot classification.
MODEL_NAME = "valhalla/distilbart-mnli-12-1"

# Device index handed to transformers.pipeline: the first CUDA GPU when one
# is available, otherwise -1 (CPU).
device = -1
if torch.cuda.is_available():
    device = 0
@st.cache_resource
def load_zero_shot_pipeline():
    """Build the zero-shot classification pipeline for ``MODEL_NAME``.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the
    already-loaded tokenizer/model instead of loading them again.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    clf_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    return pipeline(
        "zero-shot-classification",
        model=clf_model,
        tokenizer=tok,
        device=device,
    )


# Shared pipeline instance used by the analysis helpers.
zero_shot = load_zero_shot_pipeline()
st.title("ArticleInsight (Demo Offline Pipeline)")
# Blank lines separate the intro, the numbered list and the disclaimer;
# without them Markdown treats the disclaimer as a continuation of item 8.
st.markdown("""
**Upload a CSV** with an 'Abstract' column. We'll run a simple 8-step analysis:

1. Empirical Study?
2. Identify Construct
3. Sample Details
4. Main Research Question
5. Key Findings
6. Variables (IV/DV)
7. Antecedents, Outcomes
8. Unit of Analysis

**Disclaimer**: This is a *very naive* demonstration using zero-shot classification and simple regex.
It won't be super accurate, but requires no coding from you!
""")
# === Analysis helpers ===
# Each helper takes the text of one abstract and returns short display
# string(s); "Unknown" means the heuristic found nothing. Streamlit executes
# this script top to bottom on every rerun, so the UI lives in main(), which
# is only invoked at the very end, after every helper has been defined.

# Columns added to the uploaded DataFrame, in display order.
_RESULT_COLUMNS = (
    "Empirical Study",
    "Construct",
    "Sample Details",
    "Research Question",
    "Key Findings",
    "Variables",
    "Antecedents",
    "Outcomes",
    "Unit of Analysis",
)

# \b before "n" keeps phrases such as "mean = 4" from being read as "n = 4".
_SAMPLE_SIZE_RE = re.compile(
    r"(\bn\s*=\s*\d+|sample of \d+|\d+\s+participants|\d+\s+subjects)"
)
# "impact/influence/effect(s) of <iv> on <dv>"; groups 2 and 3 are IV and DV.
_IV_DV_RE = re.compile(r"\b(impact|influence|effect)s? of (\w+) on (\w+)")
# Checked in order. The \b anchors stop e.g. "confirm" matching "firm" and
# "steam" matching "team".
_UNIT_RULES = (
    ("Team", re.compile(r"\b(?:team|groups)")),
    ("Organization", re.compile(r"\b(?:organi[sz]ation|firms?\b)")),
    ("Individual", re.compile(r"\b(?:participant|individual|student|employee)")),
)


def classify_empirical(text, threshold=0.5):
    """Label an abstract as empirical ("Yes"), theoretical ("No") or "Unknown".

    Runs the module-level ``zero_shot`` pipeline with two candidate labels and
    accepts the top label only when its score exceeds ``threshold``.
    Blank text returns "Unknown" without calling the model, because the
    zero-shot pipeline raises on an empty sequence.
    """
    if not text.strip():
        return "Unknown"
    res = zero_shot(text, ["empirical study", "theoretical paper"])
    top_label, top_score = res["labels"][0], res["scores"][0]
    if top_score > threshold:
        if top_label == "empirical study":
            return "Yes"
        if top_label == "theoretical paper":
            return "No"
    return "Unknown"


def find_constructs(text, top_n=2):
    """Return the ``top_n`` most frequent words longer than five letters.

    Words are maximal runs of letters in the lower-cased text, so adjacent
    punctuation ("motivation," / "performance.") no longer hides a word.
    Ties keep first-occurrence order. Returns "Unknown" if no word qualifies.
    """
    words = re.findall(r"[^\W\d_]+", text.lower())
    counts = Counter(word for word in words if len(word) > 5)
    if not counts:
        return "Unknown"
    return ", ".join(word for word, _ in counts.most_common(top_n))


def extract_sample_details(text):
    """Summarise sample-size mentions and participant types found in ``text``.

    Returns the matched lower-cased phrases (e.g. "n = 120", "sample of 300",
    "45 participants"), followed by "students"/"employees" when those words
    occur, joined by "; ". Returns "Unknown" when nothing is found.
    """
    lowered = text.lower()
    # One capture group, so findall yields plain strings.
    parts = _SAMPLE_SIZE_RE.findall(lowered)
    if "student" in lowered:
        parts.append("students")
    if "employee" in lowered:
        parts.append("employees")
    return "; ".join(parts) if parts else "Unknown"


def _find_cue(text, phrase):
    """Case-insensitively locate ``phrase`` in ``text``; return the match or None.

    Searching ``text`` itself (rather than ``text.lower()``) keeps the match
    offsets valid for slicing ``text``: lower-casing can change the string's
    length (e.g. "İ" lowers to two code points).
    """
    return re.search(re.escape(phrase), text, re.IGNORECASE)


def guess_research_question(text, max_chars=60):
    """Build a rough research question from a cue phrase in ``text``.

    If "effect of" occurs, returns "Does <snippet>?". Otherwise, if
    "aim of this study" occurs, returns the bare snippet. Each snippet is up
    to ``max_chars`` characters of the original text starting at the cue.
    Returns "Unknown" when neither cue occurs.
    """
    match = _find_cue(text, "effect of")
    if match:
        return f"Does {text[match.start(): match.start() + max_chars]}?"
    match = _find_cue(text, "aim of this study")
    if match:
        return text[match.start(): match.start() + max_chars]
    return "Unknown"


def guess_key_findings(text, max_chars=100):
    """Return up to ``max_chars`` of ``text`` starting at a findings cue.

    Cues are tried in order: "we find that", then "results show"
    (case-insensitive). Returns "Unknown" when neither occurs.
    """
    for cue in ("we find that", "results show"):
        match = _find_cue(text, cue)
        if match:
            return text[match.start(): match.start() + max_chars]
    return "Unknown"


def identify_variables(text):
    """Extract one IV/DV pair from an "<impact|influence|effect>(s) of A on B" phrase.

    Matching is done on the lower-cased text, and only single words are
    captured for A and B. Returns ``(summary, iv, dv)``, e.g.
    ``("IV: training, DV: performance", "training", "performance")``, or
    three "Unknown" strings when no such phrase exists.
    """
    match = _IV_DV_RE.search(text.lower())
    if not match:
        return "Unknown", "Unknown", "Unknown"
    iv, dv = match.group(2), match.group(3)
    return f"IV: {iv}, DV: {dv}", iv, dv


def identify_unit_of_analysis(text):
    """Guess the unit of analysis: "Team", "Organization", "Individual" or "Unknown".

    Rules are checked in that priority order against whole-word prefixes in
    the lower-cased text.
    """
    lowered = text.lower()
    for unit, pattern in _UNIT_RULES:
        if pattern.search(lowered):
            return unit
    return "Unknown"


def _abstract_text(value):
    """Coerce one 'Abstract' cell to text; missing cells (NaN/None) become ""."""
    return "" if pd.isna(value) else str(value)


def _analyze_abstract(abstract):
    """Run every heuristic on one abstract and return a {column: value} dict."""
    empirical = classify_empirical(abstract)
    if empirical != "Yes":
        # Detail columns are only filled for abstracts classified as empirical.
        result = dict.fromkeys(_RESULT_COLUMNS, "N/A")
        result["Empirical Study"] = empirical
        return result
    variables, antecedents, outcomes = identify_variables(abstract)
    return {
        "Empirical Study": empirical,
        "Construct": find_constructs(abstract),
        "Sample Details": extract_sample_details(abstract),
        "Research Question": guess_research_question(abstract),
        "Key Findings": guess_key_findings(abstract),
        "Variables": variables,
        "Antecedents": antecedents,
        "Outcomes": outcomes,
        "Unit of Analysis": identify_unit_of_analysis(abstract),
    }


def main():
    """Render the CSV uploader and, when requested, the per-abstract analysis."""
    uploaded_file = st.file_uploader("Upload CSV with 'Abstract' column")
    if not uploaded_file:
        return
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as err:
        st.error(f"Could not read the CSV file: {err}")
        st.stop()
    if "Abstract" not in df.columns:
        st.error("CSV must have an 'Abstract' column.")
        st.stop()
    st.success("File uploaded successfully!")
    if not st.button("Run Analysis"):
        return
    with st.spinner("Analyzing each abstract..."):
        rows = [_analyze_abstract(_abstract_text(value)) for value in df["Abstract"]]
        for column in _RESULT_COLUMNS:
            df[column] = [row[column] for row in rows]
    st.success("Done!")
    st.dataframe(df.head(50))
    csv_data = df.to_csv(index=False).encode("utf-8")
    st.download_button(
        "Download Analyzed CSV",
        data=csv_data,
        file_name="analysis_output.csv",
        mime="text/csv",
    )


# NOTE(review): these are set while Streamlit is already serving this script,
# so they probably do not affect the running server; CLI flags or
# .streamlit/config.toml are the usual place for these settings — confirm.
os.environ["STREAMLIT_SERVER_HEADLESS"] = "true"
os.environ["STREAMLIT_SERVER_ADDRESS"] = "0.0.0.0"

# Streamlit executes the script as "__main__"; importing the module (e.g. from
# tests) defines the helpers without rendering the UI.
if __name__ == "__main__":
    main()