karths's picture
Update app.py
0aa8067 verified
raw
history blame
6.83 kB
import json
import gradio as gr
import os
import requests
from huggingface_hub import AsyncInferenceClient
# Endpoint configuration is read from the environment; either value is None
# when the corresponding variable is not set.
HF_TOKEN = os.environ.get('HF_TOKEN')
api_url = os.environ.get('API_URL')

# Bearer header sent with the plain-HTTP (requests) call in predict_batch.
headers = {"Authorization": "Bearer " + str(HF_TOKEN)}

# Async client used by the streaming predict() path.
client = AsyncInferenceClient(api_url)
# Refactoring instruction text.
# NOTE(review): nothing in the visible part of this file references
# `system_message`; the prompt builders only use the "Optional system prompt"
# textbox value. Confirm whether this was meant to be that textbox's default.
system_message = """
Refactor the provided Python code to improve its maintainability and efficiency and reduce complexity. Include the refactored code along with the comments on the changes made for improving the metrics.
"""

# Heading and help text shown at the top of each chat interface.
title = "Python Refactoring"
description = """
Please give it 3 to 4 minutes for the model to load and run. Consider using Python code with fewer than 120 lines due to GPU constraints.
"""

# Hides elements with the `toast-wrap` class.
css = """.toast-wrap { display: none !important } """
# Sample prompts offered in the chat UI. Each example is wrapped in its own
# single-item list.
examples = [
    [prompt]
    for prompt in (
        'Hello there! How are you doing?',
        'Can you explain to me briefly what is Python programming language?',
        'Explain the plot of Cinderella in a sentence.',
        'How many hours does it take a man to eat a Helicopter?',
        "Write a 100-word article on 'Benefits of Open-Source in AI research'",
    )
]
# Note: We have removed default system prompt as requested by the paper authors [Dated: 13/Oct/2023]
# Prompting style for Llama2 without using system prompt
# <s>[INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s><s>[INST] {{ user_msg_2 }} [/INST]
# Stream text - stream tokens with InferenceClient from TGI
async def predict(message, chatbot, system_prompt="", temperature=0.1, max_new_tokens=4096, top_p=0.6, repetition_penalty=1.1,):
    """Stream a model reply for `message`, yielding the text generated so far.

    The prompt follows the Llama-2 chat layout described in the comment above
    this function: an optional ``<<SYS>>`` section, every previous
    (user, assistant) exchange, then the new user message.

    Args:
        message: The new user message.
        chatbot: Chat history as a sequence of (user_text, assistant_text)
            pairs.
        system_prompt: Optional system prompt; omitted from the prompt when
            it is the empty string.
        temperature: Sampling temperature; values below 1e-2 are raised to
            1e-2.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass. The original signature was
            missing this parameter, so ``float(top_p)`` raised
            ``UnboundLocalError`` on every call. It is placed between
            ``max_new_tokens`` and ``repetition_penalty`` so the five extra
            UI inputs (system prompt, temperature, max new tokens, top-p,
            repetition penalty) bind to the right names positionally.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The accumulated response text after each streamed token.
    """
    if system_prompt != "":
        input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
    else:
        input_prompt = "<s>[INST] "

    # Keep the temperature strictly positive.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    # Replay earlier turns, then append the new user message.
    history = "".join(
        str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s>[INST] "
        for interaction in chatbot
    )
    input_prompt = input_prompt + history + str(message) + " [/INST] "

    partial_message = ""
    async for token in await client.text_generation(
        prompt=input_prompt,
        max_new_tokens=max_new_tokens,
        stream=True,
        best_of=1,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
    ):
        partial_message = partial_message + token
        yield partial_message
# No Stream - batch produce tokens using TGI inference endpoint
def predict_batch(message, chatbot, system_prompt="", temperature=0.1, max_new_tokens=4096, top_p=0.6, repetition_penalty=1.1):
    """Generate a complete (non-streamed) reply through one HTTP POST.

    The prompt follows the same Llama-2 chat layout as the streaming path: an
    optional ``<<SYS>>`` section, every previous (user, assistant) exchange,
    then the new user message.

    Args:
        message: The new user message.
        chatbot: Chat history as a sequence of (user_text, assistant_text)
            pairs.
        system_prompt: Optional system prompt; omitted from the prompt when
            it is the empty string.
        temperature: Sampling temperature; values below 1e-2 are raised to
            1e-2.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass. The original signature was
            missing this parameter, so ``float(top_p)`` raised
            ``UnboundLocalError`` on every call. It is placed between
            ``max_new_tokens`` and ``repetition_penalty`` so the five extra
            UI inputs bind to the right names positionally.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        The generated text on success. On failure, a short human-readable
        message (also printed) so the chat window never receives ``None``.
    """
    if system_prompt != "":
        input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
    else:
        input_prompt = "<s>[INST] "

    # Keep the temperature strictly positive.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    # Replay earlier turns, then append the new user message.
    history = "".join(
        str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s>[INST] "
        for interaction in chatbot
    )
    input_prompt = input_prompt + history + str(message) + " [/INST] "
    print(f"input_prompt - {input_prompt}")

    data = {
        "inputs": input_prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
        },
    }
    response = requests.post(api_url, headers=headers, json=data)

    if response.status_code != 200:
        failure = f"Request failed with status code {response.status_code}"
        print(failure)
        return failure

    try:
        json_obj = response.json()
    except ValueError:
        # ValueError is the common base of the JSON decode errors that
        # Response.json() can raise (stdlib json and simplejson variants).
        print(f"Failed to decode response as JSON: {response.text}")
        return "Failed to decode the model response. Please try again."

    # The original code indexed json_obj[0] unconditionally; tolerate an empty
    # list or a bare object instead of raising IndexError/KeyError.
    first = json_obj[0] if isinstance(json_obj, list) and json_obj else json_obj
    if isinstance(first, dict):
        generated_text = first.get('generated_text')
        if generated_text:
            return generated_text
        if 'error' in first:
            return str(first['error']) + ' Please refresh and try again with smaller input prompt'

    print(f"Unexpected response: {first}")
    return "Unexpected response from the model. Please try again."
def vote(data: gr.LikeData):
    """Print whether the user upvoted or downvoted a chatbot response.

    Args:
        data: Like/dislike event payload; `data.liked` selects the wording and
            `data.value` is appended to the printed line.
    """
    prefix = "You upvoted this response: " if data.liked else "You downvoted this response: "
    print(prefix + data.value)
# Extra controls shown with each chat interface. Gradio hands their values to
# the chat function positionally, after the message and history, in the order
# they appear in this list.
additional_inputs = [
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        label="Temperature",
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        value=0.9,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        minimum=0,
        maximum=4096,
        step=64,
        value=256,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        minimum=0.0,
        maximum=1,
        step=0.05,
        value=0.6,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        value=1.2,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
# One Chatbot widget per mode, each with its own avatar pair.
chatbot_stream = gr.Chatbot(
    bubble_full_width=False,
    avatar_images=('user.png', 'bot2.png'),
)
chatbot_batch = gr.Chatbot(
    bubble_full_width=False,
    avatar_images=('user1.png', 'bot1.png'),
)

# Streaming interface: replies come from the async generator `predict`.
# cache_examples is intentionally not passed.
chat_interface_stream = gr.ChatInterface(
    predict,
    chatbot=chatbot_stream,
    textbox=gr.Textbox(),
    title=title,
    description=description,
    examples=examples,
    additional_inputs=additional_inputs,
    css=css,
)

# Batch interface: replies come from the blocking `predict_batch`.
# cache_examples is intentionally not passed.
chat_interface_batch = gr.ChatInterface(
    predict_batch,
    chatbot=chatbot_batch,
    textbox=gr.Textbox(),
    title=title,
    description=description,
    examples=examples,
    additional_inputs=additional_inputs,
    css=css,
)
# Gradio Demo
with gr.Blocks() as demo:
    # One tab per mode, streaming first. Inside each tab the like/dislike
    # handler is attached to the chatbot and then the interface is rendered.
    for _tab_label, _chatbot, _interface in (
        ("Streaming", chatbot_stream, chat_interface_stream),
        ("Batch", chatbot_batch, chat_interface_batch),
    ):
        with gr.Tab(_tab_label):
            _chatbot.like(vote, None, None)
            _interface.render()

# Enable the request queue (max_size=100), then start the app.
(
    demo
    .queue(max_size=100)
    .launch()
)