rahideer committed
Commit 855a31e · verified
1 Parent(s): 7bf94f0

Update app.py

Files changed (1)
  1. app.py +43 -68
app.py CHANGED
@@ -1,69 +1,44 @@
- import zipfile
- import os
- import pandas as pd
- import torch
- from sentence_transformers import SentenceTransformer, util
- from transformers import pipeline
import streamlit as st
-
- st.set_page_config(page_title="News Fact Checker", page_icon="📰")
-
- # Step 1: Unzip dataset
- @st.cache_data
- def extract_dataset():
-     zip_path = "climate.zip"
-     extract_dir = "climate_extracted"
-
-     if not os.path.exists(extract_dir):
-         with zipfile.ZipFile(zip_path, 'r') as zip_ref:
-             zip_ref.extractall(extract_dir)
-
-     train_path = os.path.join(extract_dir, "climate", "train")
-
-     # Try CSV or TSV format detection
-     try:
-         df = pd.read_csv(train_path, header=None)
-     except:
-         df = pd.read_csv(train_path, sep='\t', header=None)
-
-     df.columns = ["label", "title", "description"]
-     df["text"] = df["title"].astype(str) + ". " + df["description"].astype(str)
-     return df.head(1000)
-
- # Step 2: Load models
- @st.cache_resource
- def load_models():
-     embedder = SentenceTransformer('all-MiniLM-L6-v2')
-     summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
-     return embedder, summarizer
-
- st.title("📰 News Fact Checker")
- st.markdown("Enter a **claim** about climate or news events. We'll pull relevant facts from real news and summarize them for you.")
-
- # Step 3: User input
- claim = st.text_input("🔍 Enter your claim:")
- data = extract_dataset()
- embedder, summarizer = load_models()
-
- # Step 4: Fact checking
- if claim:
-     with st.spinner("Searching news..."):
-         corpus = data["text"].tolist()
-         corpus_embeddings = embedder.encode(corpus, convert_to_tensor=True)
-         query_embedding = embedder.encode(claim, convert_to_tensor=True)
-
-         hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=3)[0]
-         top_passages = [corpus[hit['corpus_id']] for hit in hits]
-
-         combined = " ".join(top_passages)
-         if len(combined) > 1024:
-             combined = combined[:1024]
-
-         summary = summarizer(combined, max_length=150, min_length=40, do_sample=False)[0]["summary_text"]
-
-         st.markdown("### ✅ Summary Based on News")
-         st.success(summary)
-
-         with st.expander("🔎 View Related News Snippets"):
-             for i, passage in enumerate(top_passages, 1):
-                 st.markdown(f"**Snippet {i}:** {passage}")
+ import pandas as pd
+ from datasets import load_dataset
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
+
+ # Load the AG News dataset from the Hugging Face Hub
+ # (note: the retriever below queries the checkpoint's own pretrained index, not this dataset)
+ dataset = load_dataset("kk0105/ag-news", split="train")
+
+ # Tokenizer, retriever, and model setup for RAG
+ # (facebook/rag-sequence-nq pairs with RagSequenceForGeneration)
+ tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
+ # use_dummy_dataset keeps the retrieval index small for a demo; drop it to use the full index
+ retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True)
+ model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
+
+ # Function to generate a response using RAG
+ def generate_answer(query):
+     # Tokenize the input query
+     inputs = tokenizer(query, return_tensors="pt")
+
+     # Generate an answer; the retriever attached to the model fetches the supporting passages
+     outputs = model.generate(input_ids=inputs["input_ids"])
+
+     # Decode the answer and return it
+     answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+     return answer
+
+ # Streamlit interface
+ st.title("News Fact Checker")
+ st.write("""
+ **Welcome to the News Fact Checker!**
+ Input a claim or question about a news topic, and we will verify or refute it based on recent news snippets.
+ """)
+
+ # User input for claim
+ user_claim = st.text_input("Enter your claim or question:")
+
+ if user_claim:
+     with st.spinner('Fetching relevant news snippets...'):
+         answer = generate_answer(user_claim)
+     st.write(f"**Fact Check Answer:** {answer}")
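A side note on the new setup: Streamlit re-executes the whole script on every interaction, so the tokenizer, retriever, and model above are reloaded on each rerun; the previous version avoided this with @st.cache_data / @st.cache_resource. Below is a minimal sketch of how the RAG components could be cached the same way — the helper name load_rag is illustrative, and it assumes the same facebook/rag-sequence-nq checkpoint and dummy index used above.

import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration

@st.cache_resource
def load_rag():
    # Loaded once per process and reused across script reruns
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
    )
    model = RagSequenceForGeneration.from_pretrained(
        "facebook/rag-sequence-nq", retriever=retriever
    )
    return tokenizer, model

tokenizer, model = load_rag()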