Spaces:
Sleeping
Sleeping
File size: 6,825 Bytes
0aa8067 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import json
import gradio as gr
import os
import requests
from huggingface_hub import AsyncInferenceClient
# Inference endpoint settings are read from the environment; either value is
# None when the corresponding variable is not set.
api_url = os.environ.get('API_URL')
HF_TOKEN = os.environ.get('HF_TOKEN')

# Async client used by the streaming handler.
client = AsyncInferenceClient(api_url)

# Bearer-token header used by the batch handler's plain HTTP request.
headers = {"Authorization": f"Bearer {HF_TOKEN}"}
# Refactoring instruction text.
# NOTE(review): not referenced anywhere else in the visible file.
system_message = (
    "\n"
    "Refactor the provided Python code to improve its maintainability and efficiency and reduce complexity. Include the refactored code along with the comments on the changes made for improving the metrics.\n"
)

# Heading and blurb passed to both chat interfaces.
title = "Python Refactoring"
description = (
    "\n"
    "Please give it 3 to 4 minutes for the model to load and Run , consider using Python code with less than 120 lines of code due to GPU constrainst\n"
)

# Stylesheet passed to both chat interfaces; hides elements with the
# `toast-wrap` class.
css = ".toast-wrap { display: none !important } "

# Sample prompts; each example is a one-element list holding the message text.
examples = [
    [prompt]
    for prompt in (
        "Hello there! How are you doing?",
        "Can you explain to me briefly what is Python programming language?",
        "Explain the plot of Cinderella in a sentence.",
        "How many hours does it take a man to eat a Helicopter?",
        "Write a 100-word article on 'Benefits of Open-Source in AI research'",
    )
]
# Note: The default system prompt has been removed, as requested by the paper authors [Dated: 13/Oct/2023].
# Llama 2 prompting style when no system prompt is used:
# <s>[INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s><s>[INST] {{ user_msg_2 }} [/INST]
# Stream text - stream tokens with InferenceClient from TGI
async def predict(message, chatbot, system_prompt="", temperature=0.1, max_new_tokens=4096, top_p=0.6, repetition_penalty=1.1):
    """Stream a reply to ``message``, yielding the text accumulated so far.

    The prompt is assembled in the Llama 2 ``[INST]`` chat layout: an optional
    ``<<SYS>>`` section, then every previous exchange from ``chatbot``, then
    the new message.

    ``top_p`` sits between ``max_new_tokens`` and ``repetition_penalty`` so the
    positional order matches the UI's additional inputs (system prompt,
    temperature, max new tokens, top-p, repetition penalty). The body already
    read ``top_p``, but without the parameter every call raised
    ``UnboundLocalError``.

    Args:
        message: The new user message.
        chatbot: Prior turns; element ``[0]`` of each turn is the user text
            and element ``[1]`` the model reply.
        system_prompt: Optional system prompt; omitted from the prompt when
            it is the empty string.
        temperature: Sampling temperature; values below 0.01 are raised to
            0.01.
        max_new_tokens: Forwarded to the generation call unchanged.
        top_p: Nucleus-sampling threshold, converted with ``float``.
        repetition_penalty: Forwarded to the generation call unchanged.

    Yields:
        The partial response string, growing by one streamed token per step.
    """
    if system_prompt != "":
        prompt_head = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
    else:
        prompt_head = "<s>[INST] "

    # Keep the temperature strictly positive.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    # Each completed exchange closes its [INST] block and opens the next one.
    past_turns = "".join(
        f"{turn[0]!s} [/INST] {turn[1]!s} </s><s>[INST] "
        for turn in (chatbot or ())
    )
    input_prompt = prompt_head + past_turns + str(message) + " [/INST] "

    token_stream = await client.text_generation(
        prompt=input_prompt,
        max_new_tokens=max_new_tokens,
        stream=True,
        best_of=1,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
    )

    partial_message = ""
    async for token in token_stream:
        partial_message = partial_message + token
        yield partial_message
# No Stream - batch produce tokens using TGI inference endpoint
def predict_batch(message, chatbot, system_prompt="", temperature=0.1, max_new_tokens=4096, top_p=0.6, repetition_penalty=1.1):
    """Generate a full reply to ``message`` with a single HTTP request.

    The prompt is assembled in the Llama 2 ``[INST]`` chat layout: an optional
    ``<<SYS>>`` section, then every previous exchange from ``chatbot``, then
    the new message. It is POSTed as JSON to ``api_url``.

    ``top_p`` sits between ``max_new_tokens`` and ``repetition_penalty`` so the
    positional order matches the UI's additional inputs (system prompt,
    temperature, max new tokens, top-p, repetition penalty). The body already
    read ``top_p``, but without the parameter every call raised
    ``UnboundLocalError``.

    Args:
        message: The new user message.
        chatbot: Prior turns; element ``[0]`` of each turn is the user text
            and element ``[1]`` the model reply.
        system_prompt: Optional system prompt; omitted from the prompt when
            it is the empty string.
        temperature: Sampling temperature; values below 0.01 are raised to
            0.01.
        max_new_tokens: Sent to the endpoint unchanged.
        top_p: Nucleus-sampling threshold, converted with ``float``.
        repetition_penalty: Sent to the endpoint unchanged.

    Returns:
        The generated text, or the endpoint's error text with a retry hint
        appended. Returns ``None`` after printing a diagnostic when the
        request fails, the body is not JSON, or the payload has neither key.
    """
    if system_prompt != "":
        prompt_head = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
    else:
        prompt_head = "<s>[INST] "

    # Keep the temperature strictly positive.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    # Each completed exchange closes its [INST] block and opens the next one.
    past_turns = "".join(
        f"{turn[0]!s} [/INST] {turn[1]!s} </s><s>[INST] "
        for turn in (chatbot or ())
    )
    input_prompt = prompt_head + past_turns + str(message) + " [/INST] "
    print(f"input_prompt - {input_prompt}")

    data = {
        "inputs": input_prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
        },
    }

    response = requests.post(api_url, headers=headers, json=data)
    if response.status_code != 200:
        print(f"Request failed with status code {response.status_code}")
        return None

    try:
        json_obj = response.json()
    except ValueError:
        # ValueError is the common base of json.JSONDecodeError and of the
        # JSON decoding error raised by requests, so this catches both.
        print(f"Failed to decode response as JSON: {response.text}")
        return None

    # NOTE(review): assumes the endpoint answers with a JSON list whose first
    # element is a dict - confirm against the deployed endpoint.
    first_result = json_obj[0]
    if 'generated_text' in first_result and len(first_result['generated_text']) > 0:
        return first_result['generated_text']
    if 'error' in first_result:
        return first_result['error'] + ' Please refresh and try again with smaller input prompt'

    print(f"Unexpected response: {first_result}")
    return None
def vote(data: gr.LikeData):
    """Print whether the user upvoted or downvoted a response, with its text.

    Args:
        data: Like-event payload; ``data.liked`` selects the wording and
            ``data.value`` is appended to the printed line.
    """
    prefix = (
        "You upvoted this response: "
        if data.liked
        else "You downvoted this response: "
    )
    print(prefix + data.value)
# Extra controls handed to both chat interfaces. Per the Gradio ChatInterface
# documentation, their values are forwarded to the chat function as additional
# positional arguments after the message and history, in list order - so the
# order here is significant.
additional_inputs = [
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        value=0.9,
        label="Temperature",
        info="Higher values produce more diverse outputs",
        interactive=True,
    ),
    gr.Slider(
        minimum=0,
        maximum=4096,
        step=64,
        value=256,
        label="Max new tokens",
        info="The maximum numbers of new tokens",
        interactive=True,
    ),
    gr.Slider(
        minimum=0.0,
        maximum=1,
        step=0.05,
        value=0.6,
        label="Top-p (nucleus sampling)",
        info="Higher values sample more low-probability tokens",
        interactive=True,
    ),
    gr.Slider(
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        value=1.2,
        label="Repetition penalty",
        info="Penalize repeated tokens",
        interactive=True,
    ),
]
# One chat display per tab, each with its own (user, bot) avatar images.
chatbot_stream = gr.Chatbot(
    bubble_full_width=False,
    avatar_images=('user.png', 'bot2.png'),
)
chatbot_batch = gr.Chatbot(
    bubble_full_width=False,
    avatar_images=('user1.png', 'bot1.png'),
)

# Streaming interface: driven by the async generator `predict`.
chat_interface_stream = gr.ChatInterface(
    fn=predict,
    chatbot=chatbot_stream,
    textbox=gr.Textbox(),
    additional_inputs=additional_inputs,
    examples=examples,
    title=title,
    description=description,
    css=css,
)

# Batch interface: driven by `predict_batch`, which returns the whole reply.
chat_interface_batch = gr.ChatInterface(
    fn=predict_batch,
    chatbot=chatbot_batch,
    textbox=gr.Textbox(),
    additional_inputs=additional_inputs,
    examples=examples,
    title=title,
    description=description,
    css=css,
)
# Gradio Demo: two tabs, one per chat interface. Each tab attaches `vote` to
# its chatbot's like event and then renders the pre-built interface.
with gr.Blocks() as demo:
    with gr.Tab("Streaming"):
        # streaming chatbot
        chatbot_stream.like(vote, None, None)
        chat_interface_stream.render()
    with gr.Tab("Batch"):
        # non-streaming chatbot
        chatbot_batch.like(vote, None, None)
        chat_interface_batch.render()

# Enable the request queue (max_size=100) and start the app. The stray
# trailing `|` after launch() was a syntax error and has been removed.
demo.queue(max_size=100).launch()