Shreyas94 committed on
Commit eb9ecb0 · verified · 1 Parent(s): 9d2f962

Update app.py

Files changed (1)
  1. app.py +8 -17
app.py CHANGED
@@ -4,8 +4,7 @@ import requests
 from bs4 import BeautifulSoup
 import torch
 import gradio as gr
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
-from huggingface_hub import InferenceClient
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import logging
 
 # Set up logging
@@ -14,12 +13,12 @@ logger = logging.getLogger(__name__)
 
 # Define device and load model and tokenizer
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
+MODEL_NAME = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 
 # Load model and tokenizer
 try:
     logger.debug("Attempting to load the model and tokenizer")
-    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
+    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     logger.debug("Model and tokenizer loaded successfully")
 except Exception as e:
@@ -27,9 +26,6 @@ except Exception as e:
     model = None
     tokenizer = None
 
-# Assert to ensure tokenizer is loaded
-assert tokenizer is not None, "Tokenizer failed to load and is None"
-
 # Function to perform a Google search and return the results
 def search(term, num_results=2, lang="en", timeout=5, safe="active", ssl_verify=None):
     logger.debug(f"Starting search for term: {term}")
@@ -94,13 +90,10 @@ def extract_text_from_webpage(html_content):
 # Function to format the prompt for the language model
 def format_prompt(user_prompt, chat_history):
     logger.debug(f"Formatting prompt with user prompt: {user_prompt} and chat history: {chat_history}")
-    prompt = "<s>"
+    prompt = ""
     for item in chat_history:
-        if isinstance(item, tuple):
-            prompt += f"[INST] {item[0]} [/INST] {item[1]}</s>"
-        else:
-            prompt += f" [Image] "
-    prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f"User: {item[0]}\nAssistant: {item[1]}\n"
+    prompt += f"User: {user_prompt}\nAssistant:"
     logger.debug(f"Formatted prompt: {prompt}")
     return prompt
 
@@ -109,7 +102,6 @@ def model_inference(
     user_prompt,
     chat_history,
     web_search,
-    decoding_strategy,
     temperature,
     max_new_tokens,
     repetition_penalty,
@@ -167,9 +159,9 @@ def model_inference(
 
 # Define Gradio interface components
 max_new_tokens = gr.Slider(
-    minimum=2048,
+    minimum=1,
     maximum=16000,
-    value=4096,
+    value=2048,
     step=64,
     interactive=True,
     label="Maximum number of new tokens to generate",
@@ -231,7 +223,6 @@ def chat_interface(user_input, history, web_search, decoding_strategy, temperatu
     user_input,
     history,
     web_search,
-    decoding_strategy,
     temperature,
     max_new_tokens,
     repetition_penalty,
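A note on the prompt change: `mistralai/Mixtral-8x7B-Instruct-v0.1` was fine-tuned on the `[INST] ... [/INST]` markup this commit removes, so plain `User:`/`Assistant:` turns are not the model's native format. Below is a minimal sketch (not part of the commit; the helper name is ours) of how the same tuple-shaped `chat_history` could instead be rendered through the tokenizer's built-in chat template:

```python
# Hypothetical alternative to the hand-rolled "User:"/"Assistant:" prompt.
# Assumes the `tokenizer` loaded in app.py; the Mixtral checkpoint ships a
# chat template that emits the [INST] ... [/INST] markup natively.
def format_prompt_with_template(user_prompt, chat_history, tokenizer):
    messages = []
    for user_turn, assistant_turn in chat_history:  # (user, assistant) tuples
        messages.append({"role": "user", "content": user_turn})
        messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": user_prompt})
    # add_generation_prompt=True appends the tokens that cue the model to reply
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
```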
 
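The commit also drops `decoding_strategy` from both `model_inference` and `chat_interface`, leaving `temperature`, `max_new_tokens`, and `repetition_penalty` as the generation controls. The body of `model_inference` sits outside this diff, so the following is only an illustrative sketch of how those slider values would typically reach `model.generate`:

```python
# Illustrative only: model_inference's actual body is not shown in this diff.
# Assumes the module-level `model`, `tokenizer`, and `DEVICE` from app.py.
def generate_reply(prompt, temperature=0.7, max_new_tokens=2048, repetition_penalty=1.1):
    # Tokenize the formatted prompt and move it to the same device as the model
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    output_ids = model.generate(
        **inputs,
        do_sample=True,  # sampling, since temperature is exposed in the UI
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Drop the prompt tokens so only the newly generated completion is returned
    new_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
```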