# SentinelAI102 / app.py
import os
import urllib.parse  # "import urllib" alone does not guarantee the parse submodule is loaded
import requests
from bs4 import BeautifulSoup
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
import feedparser
# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Define device and load model and tokenizer
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
# Load model and tokenizer
try:
    logger.debug("Attempting to load the model and tokenizer")
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    logger.debug("Model and tokenizer loaded successfully")
except Exception as e:
    logger.error(f"Error loading model and tokenizer: {e}")
    model = None
    tokenizer = None
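# Note: depending on the installed transformers version, loading Phi-3 may also
# require trust_remote_code=True, and torch_dtype=torch.bfloat16 is a common
# memory saver on GPU; both are left out here to match the original setup.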
# Function to fetch news from Google News RSS feed
def fetch_news(term, num_results=2):
    logger.debug(f"Fetching news for term: {term}")
    encoded_term = urllib.parse.quote(term)
    url = f"https://news.google.com/rss/search?q={encoded_term}"
    feed = feedparser.parse(url)
    results = []
    for entry in feed.entries[:num_results]:
        results.append({"link": entry.link, "text": entry.title})
    logger.debug(f"Fetched news results: {results}")
    return results
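# Illustrative shape of the return value (links and titles depend on the live feed):
#   fetch_news("quantum computing")
#   -> [{"link": "https://news.google.com/...", "text": "<headline>"}, ...]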
# Function to perform a Google search and return the results
def search(term, num_results=2, lang="en", timeout=5, safe="active", ssl_verify=None):
    logger.debug(f"Starting search for term: {term}")
    start = 0
    all_results = []
    max_chars_per_page = 8000  # truncate page text so prompts stay manageable
    headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"}
    with requests.Session() as session:
        while start < num_results:
            try:
                resp = session.get(
                    url="https://www.google.com/search",
                    headers=headers,
                    params={
                        "q": term,
                        "num": num_results - start,
                        "hl": lang,
                        "start": start,
                        "safe": safe,
                    },
                    timeout=timeout,
                    verify=ssl_verify,
                )
                resp.raise_for_status()
                soup = BeautifulSoup(resp.text, "html.parser")
                result_block = soup.find_all("div", attrs={"class": "g"})
                if not result_block:
                    start += 1
                    continue
                for result in result_block:
                    link = result.find("a", href=True)
                    if link:
                        link = link["href"]
                        try:
                            webpage = session.get(link, headers=headers, timeout=timeout)
                            webpage.raise_for_status()
                            visible_text = extract_text_from_webpage(webpage.text)
                            if len(visible_text) > max_chars_per_page:
                                visible_text = visible_text[:max_chars_per_page] + "..."
                            all_results.append({"link": link, "text": visible_text})
                        except requests.exceptions.RequestException as e:
                            logger.error(f"Error fetching or processing {link}: {e}")
                            all_results.append({"link": link, "text": None})
                    else:
                        all_results.append({"link": None, "text": None})
                start += len(result_block)
            except Exception as e:
                logger.error(f"Error during search: {e}")
                break
    logger.debug(f"Search results: {all_results}")
    return all_results
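# Note: Google's result markup changes frequently and scraping may be blocked,
# so the "div.g" selector above is best-effort. model_inference below relies on
# fetch_news() (RSS) rather than this scraper.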
# Function to extract visible text from HTML content
def extract_text_from_webpage(html_content):
    soup = BeautifulSoup(html_content, "html.parser")
    # Drop non-content tags before extracting text
    for tag in soup(["script", "style", "header", "footer", "nav"]):
        tag.extract()
    # A space separator keeps words from adjacent tags from running together
    visible_text = soup.get_text(separator=" ", strip=True)
    return visible_text
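# Example: extract_text_from_webpage("<p>Hi</p><script>x()</script>") -> "Hi"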
# Function to format the prompt for the language model
def format_prompt(user_prompt, chat_history):
    logger.debug(f"Formatting prompt with user prompt: {user_prompt} and chat history: {chat_history}")
    prompt = ""
    for user_msg, assistant_msg in chat_history:
        prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
    prompt += f"User: {user_prompt}\nAssistant:"
    logger.debug(f"Formatted prompt: {prompt}")
    return prompt
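# Example: format_prompt("How are you?", [["Hi", "Hello!"]]) returns:
#   User: Hi
#   Assistant: Hello!
#   User: How are you?
#   Assistant:
# Note: Phi-3-mini ships a chat template; tokenizer.apply_chat_template would be
# closer to the model's training format than this plain transcript.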
# Function for model inference
def model_inference(
    user_prompt,
    chat_history,
    web_search,
    temperature,
    max_new_tokens,
    repetition_penalty,
    top_p,
    tokenizer,  # passed in explicitly rather than relying on the module-level name
    decoding_strategy="Top P Sampling",  # "Greedy" disables sampling
):
    logger.debug(f"Starting model inference with user prompt: {user_prompt}, chat history: {chat_history}, web_search: {web_search}")
    if not isinstance(user_prompt, dict):
        logger.error("Invalid input format. Expected a dictionary.")
        return "Invalid input format. Expected a dictionary."
    user_prompt.setdefault("files", [])
    if user_prompt["files"]:
        return "Image input not supported in this implementation."
    if model is None or tokenizer is None:
        return "Model is not available. Please try again later."
    if web_search:
        logger.debug("Performing news search")
        news_results = fetch_news(user_prompt["text"])
        news2 = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in news_results])
        formatted_prompt = format_prompt(f"{user_prompt['text']} [NEWS] {news2}", chat_history)
    else:
        formatted_prompt = format_prompt(user_prompt["text"], chat_history)
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)
    # Only pass sampling parameters when sampling is actually enabled
    gen_kwargs = {"max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty}
    if decoding_strategy != "Greedy":
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    outputs = model.generate(**inputs, **gen_kwargs)
    # Decode only the newly generated tokens, not the echoed prompt
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    logger.debug(f"Model response: {response}")
    return response
# Define Gradio interface components
max_new_tokens = gr.Slider(
    minimum=1,
    maximum=16000,
    value=2048,
    step=64,
    interactive=True,
    label="Maximum number of new tokens to generate",
)
repetition_penalty = gr.Slider(
    minimum=0.01,
    maximum=5.0,
    value=1.0,
    step=0.01,
    interactive=True,
    label="Repetition penalty",
    info="1.0 is equivalent to no penalty",
)
decoding_strategy = gr.Radio(
    [
        "Greedy",
        "Top P Sampling",
    ],
    value="Top P Sampling",
    label="Decoding strategy",
    interactive=True,
    info="Greedy always picks the most likely next token; Top P Sampling draws from the smallest set of tokens whose probabilities sum to Top P.",
)
temperature = gr.Slider(
    minimum=0.0,
    maximum=2.0,
    value=0.5,
    step=0.05,
    visible=True,
    interactive=True,
    label="Sampling temperature",
    info="Higher values will produce more diverse outputs.",
)
top_p = gr.Slider(
    minimum=0.01,
    maximum=0.99,
    value=0.9,
    step=0.01,
    visible=True,
    interactive=True,
    label="Top P",
    info="Higher values are equivalent to sampling more low-probability tokens.",
)
# Create a chatbot interface
chatbot = gr.Chatbot(
    label="OpenGPT-4o-Chatty",
    show_copy_button=True,
    likeable=True,
    layout="panel",
)
# Define Gradio interface
def chat_interface(user_input, history, web_search, decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p):
    # gr.Textbox yields a plain string, but model_inference expects a dict
    prompt = {"text": user_input, "files": []}
    response = model_inference(
        user_prompt=prompt,
        chat_history=history,
        web_search=web_search,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        tokenizer=tokenizer,  # module-level tokenizer loaded above
        decoding_strategy=decoding_strategy,
    )
    # Update the chat history with the new interaction
    history.append([user_input, response])
    # Return the history twice: once for the state, once for the chatbot display
    return history, history
# Define the Gradio interface components
interface = gr.Interface(
    fn=chat_interface,
    inputs=[
        gr.Textbox(label="User Input", placeholder="Type your message here..."),
        gr.State([]),  # chat history carried between calls
        gr.Checkbox(label="Perform Web Search"),
        decoding_strategy,
        temperature,
        max_new_tokens,
        repetition_penalty,
        top_p,
    ],
    outputs=[gr.State([]), chatbot],
    title="OpenGPT-4o-Chatty",
    description="Chat with the AI and optionally perform web searches to enhance responses.",
)
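# Chaining interface.queue() before launch() serializes requests, which can help
# when a single GPU-backed model serves multiple users (optional).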
# Launch the Gradio interface
interface.launch()