Shreyas94 committed
Commit 03797ca · verified · 1 Parent(s): 86399b5

Update app.py

Files changed (1)
  1. app.py +143 -156
app.py CHANGED
@@ -1,155 +1,98 @@
  import gradio as gr
  from transformers import AutoModelForCausalLM, AutoTokenizer
- import torch
  import logging
- import random
- import requests
- import urllib
- from bs4 import BeautifulSoup
- import os
 
- # Initialize logging
  logging.basicConfig(level=logging.DEBUG)
  logger = logging.getLogger(__name__)
 
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
- # List of user agents to choose from for requests
- _useragent_list = [
-     'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:66.0) Gecko/20100101 Firefox/66.0',
-     'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36',
-     'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36',
-     'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.0.0 Safari/537.36',
-     'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36',
-     'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36 Edg/111.0.1661.62',
-     'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0'
- ]
-
- def get_useragent():
-     """Returns a random user agent from the list."""
-     return random.choice(_useragent_list)
-
- def extract_text_from_webpage(html_content):
-     """Extracts visible text from HTML content using BeautifulSoup."""
-     soup = BeautifulSoup(html_content, "html.parser")
-     # Remove unwanted tags
-     for tag in soup(["script", "style", "header", "footer", "nav"]):
-         tag.extract()
-     # Get the remaining visible text
-     visible_text = soup.get_text(strip=True)
-     return visible_text
-
- def search(term, num_results=1, lang="en", advanced=True, sleep_interval=0, timeout=5, safe="active", ssl_verify=None):
-     """Performs a Google search and returns the results."""
-     escaped_term = urllib.parse.quote_plus(term)
-     start = 0
-     all_results = []
-
-     # Fetch results in batches
-     while start < num_results:
-         resp = requests.get(
-             url="https://www.google.com/search",
-             headers={"User-Agent": get_useragent()},  # Set random user agent
-             params={
-                 "q": term,
-                 "num": num_results - start,  # Number of results to fetch in this batch
-                 "hl": lang,
-                 "start": start,
-                 "safe": safe,
-             },
-             timeout=timeout,
-             verify=ssl_verify,
-         )
-         resp.raise_for_status()  # Raise an exception if request fails
-
-         soup = BeautifulSoup(resp.text, "html.parser")
-         result_block = soup.find_all("div", attrs={"class": "g"})
-
-         # If no results, continue to the next batch
-         if not result_block:
-             start += 1
-             continue
-
-         # Extract link and text from each result
-         for result in result_block:
-             link = result.find("a", href=True)
-             if link:
-                 link = link["href"]
-                 try:
-                     # Fetch webpage content
-                     webpage = requests.get(link, headers={"User-Agent": get_useragent()})
-                     webpage.raise_for_status()
-                     # Extract visible text from webpage
-                     visible_text = extract_text_from_webpage(webpage.text)
-                     all_results.append({"link": link, "text": visible_text})
-                 except requests.exceptions.RequestException as e:
-                     # Handle errors fetching or processing webpage
-                     print(f"Error fetching or processing {link}: {e}")
-                     all_results.append({"link": link, "text": None})
-             else:
-                 all_results.append({"link": None, "text": None})
-
-         start += len(result_block)  # Update starting index for next batch
-
-     return all_results
-
- # Load the model and tokenizer
- model_name = "mistralai/Mistral-7B-v0.3"
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
-
- def format_prompt(user_input, chat_history):
      prompt = ""
-     for user, bot in chat_history:
-         prompt += f"User: {user}\nBot: {bot}\n"
-     prompt += f"User: {user_input}\nBot: "
      return prompt
 
- def model_inference(user_prompt, chat_history, web_search, temperature, max_new_tokens, repetition_penalty, top_p):
-     try:
-         if not user_prompt["files"]:
-             if web_search:
-                 logger.debug("Performing web search")
-                 web_results = search(user_prompt["text"], num_results=3)  # Fetching more results for better context
-                 web2 = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results if res['text']])
-                 formatted_prompt = format_prompt(f"{user_prompt['text']} [WEB] {web2}", chat_history)
-                 inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)
-                 if model:
-                     outputs = model.generate(
-                         **inputs,
-                         max_new_tokens=max_new_tokens,
-                         repetition_penalty=repetition_penalty,
-                         do_sample=True,
-                         temperature=temperature,
-                         top_p=top_p
-                     )
-                     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-                 else:
-                     response = "Model is not available. Please try again later."
-                 logger.debug(f"Model response: {response}")
-                 return response
-             else:
-                 formatted_prompt = format_prompt(user_prompt["text"], chat_history)
-                 inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)
-                 if model:
-                     outputs = model.generate(
-                         **inputs,
-                         max_new_tokens=max_new_tokens,
-                         repetition_penalty=repetition_penalty,
-                         do_sample=True,
-                         temperature=temperature,
-                         top_p=top_p
-                     )
-                     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-                 else:
-                     response = "Model is not available. Please try again later."
-                 logger.debug(f"Model response: {response}")
-                 return response
          else:
-             return "Image input not supported in this implementation."
-     except Exception as e:
-         logger.error(f"Error during model inference: {e}")
-         return "An error occurred during model inference. Please try again."
 
  # Define Gradio interface components
  max_new_tokens = gr.Slider(
@@ -183,27 +126,71 @@ temperature = gr.Slider(
      minimum=0.0,
      maximum=2.0,
      value=0.5,
-     step=0.1,
      interactive=True,
-     label="Temperature",
-     info="Control randomness: lower temperature produces less randomness.",
  )
- web_search = gr.Checkbox(label="Enable Web Search", default=False, description="Enable web search for better responses")
 
- # Define the Gradio interface
- gr.Interface(
-     fn=model_inference,
-     inputs=[
-         gr.Textbox(label="User Input", placeholder="Type your input here..."),
-         gr.MultiText(label="Chat History", placeholder="User: ...\nBot: ...", optional=True),
          web_search,
          temperature,
          max_new_tokens,
          repetition_penalty,
          decoding_strategy,
      ],
-     outputs=gr.Textbox(label="AI Response"),
-     live=True,
-     title="OpenGPT 4o Demo",
-     description="An AI-powered assistant that can chat with you and provide informative responses.",
- ).launch()
+ import os
+ import urllib
+ import requests
+ import feedparser
+ from bs4 import BeautifulSoup
+ import torch
  import gradio as gr
  from transformers import AutoModelForCausalLM, AutoTokenizer
  import logging
 
+ # Set up logging
  logging.basicConfig(level=logging.DEBUG)
  logger = logging.getLogger(__name__)
 
+ # Define device and load model and tokenizer
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ MODEL_NAME = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+
+ # Load model and tokenizer
+ try:
+     logger.debug("Attempting to load the model and tokenizer")
+     model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+     logger.debug("Model and tokenizer loaded successfully")
+ except Exception as e:
+     logger.error(f"Error loading model and tokenizer: {e}")
+     model = None
+     tokenizer = None
+
+ # Function to fetch news from Google News RSS feed
+ def fetch_news(term, num_results=2):
+     logger.debug(f"Fetching news for term: {term}")
+     url = f"https://news.google.com/rss/search?q={term}"
+     feed = feedparser.parse(url)
+     results = []
+     for entry in feed.entries[:num_results]:
+         results.append({"link": entry.link, "text": entry.title})
+     logger.debug(f"Fetched news results: {results}")
+     return results
+
+ # Function to format the prompt for the language model
+ def format_prompt(user_prompt, chat_history):
+     logger.debug(f"Formatting prompt with user prompt: {user_prompt} and chat history: {chat_history}")
      prompt = ""
+     for item in chat_history:
+         prompt += f"User: {item[0]}\nAssistant: {item[1]}\n"
+     prompt += f"User: {user_prompt}\nAssistant:"
+     logger.debug(f"Formatted prompt: {prompt}")
      return prompt
 
+ # Function for model inference
+ def model_inference(
+     user_prompt,
+     chat_history,
+     web_search,
+     temperature,
+     max_new_tokens,
+     repetition_penalty,
+     top_p,
+     tokenizer  # Pass tokenizer as an argument
+ ):
+     logger.debug(f"Starting model inference with user prompt: {user_prompt}, chat history: {chat_history}, web_search: {web_search}")
+     if not isinstance(user_prompt, dict):
+         logger.error("Invalid input format. Expected a dictionary.")
+         return "Invalid input format. Expected a dictionary."
+
+     if "files" not in user_prompt:
+         user_prompt["files"] = []
+
+     if not user_prompt["files"]:
+         if web_search:
+             logger.debug("Performing news search")
+             news_results = fetch_news(user_prompt["text"])
+             news_text = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in news_results])
+             formatted_prompt = format_prompt(f"{user_prompt['text']} [NEWS] {news_text}", chat_history)
+         else:
+             formatted_prompt = format_prompt(user_prompt["text"], chat_history)
+
+         inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)
+         if model:
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=max_new_tokens,
+                 repetition_penalty=repetition_penalty,
+                 do_sample=True,
+                 temperature=temperature,
+                 top_p=top_p
+             )
+             response = tokenizer.decode(outputs[0], skip_special_tokens=True)
          else:
+             response = "Model is not available. Please try again later."
+         logger.debug(f"Model response: {response}")
+         return response
+     else:
+         return "Image input not supported in this implementation."
 
  # Define Gradio interface components
  max_new_tokens = gr.Slider(
 
      minimum=0.0,
      maximum=2.0,
      value=0.5,
+     step=0.05,
+     visible=True,
      interactive=True,
+     label="Sampling temperature",
+     info="Higher values will produce more diverse outputs.",
+ )
+ top_p = gr.Slider(
+     minimum=0.01,
+     maximum=0.99,
+     value=0.9,
+     step=0.01,
+     visible=True,
+     interactive=True,
+     label="Top P",
+     info="Higher values are equivalent to sampling more low-probability tokens.",
  )
 
+ # Create a chatbot interface
+ chatbot = gr.Chatbot(
+     label="OpenGPT-4o-Chatty",
+     show_copy_button=True,
+     likeable=True,
+     layout="panel"
+ )
+
+ # Define Gradio interface
+ def chat_interface(user_input, history, web_search, decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p):
+     logger.debug(f"Chat interface called with user_input: {user_input}")
+     if isinstance(user_input, str):
+         user_input = {"text": user_input, "files": []}
+     response = model_inference(
+         user_input,
+         history,
          web_search,
          temperature,
          max_new_tokens,
          repetition_penalty,
+         top_p,
+         tokenizer=tokenizer  # Pass tokenizer to model_inference
+     )
+     history.append((user_input["text"], response))
+     logger.debug(f"Updated chat history: {history}")
+     return history, history
+
+ # Create Gradio interface
+ interface = gr.Interface(
+     fn=chat_interface,
+     inputs=[
+         gr.Textbox(label="User Input"),
+         gr.State([]),
+         gr.Checkbox(label="Fetch News", value=True),
          decoding_strategy,
+         temperature,
+         max_new_tokens,
+         repetition_penalty,
+         top_p
+     ],
+     outputs=[
+         chatbot,
+         gr.State([])
      ],
+     title="OpenGPT-4o-Chatty",
+     description="An AI assistant capable of insightful conversations and news fetching."
+ )
+
+ if __name__ == "__main__":
+     logger.debug("Launching Gradio interface")
+     interface.launch()
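
Below is a minimal, self-contained sketch (not part of the commit) of the news-augmented prompt flow the updated app.py uses: fetch Google News RSS entries with feedparser, then splice them into the chat transcript before it is sent to the model. URL-encoding the search term with urllib.parse.quote_plus is an extra precaution the committed fetch_news does not apply; everything else mirrors the fetch_news and format_prompt functions added above.

# Standalone sketch of the news-augmented prompt flow (illustration only).
import urllib.parse

import feedparser


def fetch_news(term, num_results=2):
    # Query the public Google News RSS endpoint; quote_plus is an addition not in the commit.
    url = f"https://news.google.com/rss/search?q={urllib.parse.quote_plus(term)}"
    feed = feedparser.parse(url)
    return [{"link": entry.link, "text": entry.title} for entry in feed.entries[:num_results]]


def format_prompt(user_prompt, chat_history):
    # Same User/Assistant transcript format as the committed format_prompt().
    prompt = ""
    for user, assistant in chat_history:
        prompt += f"User: {user}\nAssistant: {assistant}\n"
    return prompt + f"User: {user_prompt}\nAssistant:"


if __name__ == "__main__":
    news = fetch_news("open source language models")
    news_text = " ".join(f"Link: {item['link']}\nText: {item['text']}\n\n" for item in news)
    print(format_prompt(f"Summarize today's headlines. [NEWS] {news_text}", []))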