karths committed
Commit 0aa8067 · verified · 1 Parent(s): 8a40e79

Update app.py

Files changed (1)
  1. app.py +195 -0
app.py CHANGED
@@ -0,0 +1,195 @@
+ import json
+ import gradio as gr
+ import os
+ import requests
+ from huggingface_hub import AsyncInferenceClient
+
+ HF_TOKEN = os.getenv('HF_TOKEN')
+ api_url = os.getenv('API_URL')
+ headers = {"Authorization": f"Bearer {HF_TOKEN}"}
+ client = AsyncInferenceClient(api_url)
+
+
+ system_message = """
+ Refactor the provided Python code to improve its maintainability and efficiency and reduce complexity. Include the refactored code along with the comments on the changes made for improving the metrics.
+ """
+ title = "Python Refactoring"
+ description = """
+ Please allow 3 to 4 minutes for the model to load and run. Due to GPU constraints, consider submitting Python code with fewer than 120 lines.
+ """
+ css = """.toast-wrap { display: none !important } """
+ examples = [
+     ['Hello there! How are you doing?'],
+     ['Can you explain to me briefly what is Python programming language?'],
+     ['Explain the plot of Cinderella in a sentence.'],
+     ['How many hours does it take a man to eat a Helicopter?'],
+     ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+ ]
+
+
+ # Note: We have removed default system prompt as requested by the paper authors [Dated: 13/Oct/2023]
+ # Prompting style for Llama2 without using system prompt
+ # <s>[INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s><s>[INST] {{ user_msg_2 }} [/INST]
+
+
+ # Stream text - stream tokens with InferenceClient from TGI
+ async def predict(message, chatbot, system_prompt="", temperature=0.1, max_new_tokens=4096, top_p=0.6, repetition_penalty=1.1):
+
+     if system_prompt != "":
+         input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
+     else:
+         input_prompt = f"<s>[INST] "
+
+     temperature = float(temperature)
+     if temperature < 1e-2:
+         temperature = 1e-2
+     top_p = float(top_p)
+
+     for interaction in chatbot:
+         input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s>[INST] "
+
+     input_prompt = input_prompt + str(message) + " [/INST] "
+
+     partial_message = ""
+     async for token in await client.text_generation(prompt=input_prompt,
+                                                     max_new_tokens=max_new_tokens,
+                                                     stream=True,
+                                                     best_of=1,
+                                                     temperature=temperature,
+                                                     top_p=top_p,
+                                                     do_sample=True,
+                                                     repetition_penalty=repetition_penalty):
+         partial_message = partial_message + token
+         yield partial_message
+
+
+ # No Stream - batch produce tokens using TGI inference endpoint
+ def predict_batch(message, chatbot, system_prompt="", temperature=0.1, max_new_tokens=4096, top_p=0.6, repetition_penalty=1.1):
+
+     if system_prompt != "":
+         input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
+     else:
+         input_prompt = f"<s>[INST] "
+
+     temperature = float(temperature)
+     if temperature < 1e-2:
+         temperature = 1e-2
+     top_p = float(top_p)
+
+     for interaction in chatbot:
+         input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s>[INST] "
+
+     input_prompt = input_prompt + str(message) + " [/INST] "
+     print(f"input_prompt - {input_prompt}")
+
+     data = {
+         "inputs": input_prompt,
+         "parameters": {
+             "max_new_tokens": max_new_tokens,
+             "temperature": temperature,
+             "top_p": top_p,
+             "repetition_penalty": repetition_penalty,
+             "do_sample": True,
+         },
+     }
+
+     response = requests.post(api_url, headers=headers, json=data)  # auth is carried in the Authorization header
+
+     if response.status_code == 200:  # check if the request was successful
+         try:
+             json_obj = response.json()
+             if 'generated_text' in json_obj[0] and len(json_obj[0]['generated_text']) > 0:
+                 return json_obj[0]['generated_text']
+             elif 'error' in json_obj[0]:
+                 return json_obj[0]['error'] + ' Please refresh and try again with a smaller input prompt.'
+             else:
+                 print(f"Unexpected response: {json_obj[0]}")
+         except json.JSONDecodeError:
+             print(f"Failed to decode response as JSON: {response.text}")
+     else:
+         print(f"Request failed with status code {response.status_code}")
+
+
+
+ def vote(data: gr.LikeData):
+     if data.liked:
+         print("You upvoted this response: " + str(data.value))
+     else:
+         print("You downvoted this response: " + str(data.value))
+
+
+ additional_inputs = [
+     gr.Textbox("", label="Optional system prompt"),
+     gr.Slider(
+         label="Temperature",
+         value=0.9,
+         minimum=0.0,
+         maximum=1.0,
+         step=0.05,
+         interactive=True,
+         info="Higher values produce more diverse outputs",
+     ),
+     gr.Slider(
+         label="Max new tokens",
+         value=256,
+         minimum=0,
+         maximum=4096,
+         step=64,
+         interactive=True,
+         info="The maximum number of new tokens",
+     ),
+     gr.Slider(
+         label="Top-p (nucleus sampling)",
+         value=0.6,
+         minimum=0.0,
+         maximum=1.0,
+         step=0.05,
+         interactive=True,
+         info="Higher values sample more low-probability tokens",
+     ),
+     gr.Slider(
+         label="Repetition penalty",
+         value=1.2,
+         minimum=1.0,
+         maximum=2.0,
+         step=0.05,
+         interactive=True,
+         info="Penalize repeated tokens",
+     ),
+ ]
+
+ chatbot_stream = gr.Chatbot(avatar_images=('user.png', 'bot2.png'), bubble_full_width=False)
+ chatbot_batch = gr.Chatbot(avatar_images=('user1.png', 'bot1.png'), bubble_full_width=False)
+ chat_interface_stream = gr.ChatInterface(predict,
+                                          title=title,
+                                          description=description,
+                                          textbox=gr.Textbox(),
+                                          chatbot=chatbot_stream,
+                                          css=css,
+                                          examples=examples,
+                                          # cache_examples=True,
+                                          additional_inputs=additional_inputs)
+ chat_interface_batch = gr.ChatInterface(predict_batch,
+                                         title=title,
+                                         description=description,
+                                         textbox=gr.Textbox(),
+                                         chatbot=chatbot_batch,
+                                         css=css,
+                                         examples=examples,
+                                         # cache_examples=True,
+                                         additional_inputs=additional_inputs)
+
+ # Gradio Demo
+ with gr.Blocks() as demo:
+
+     with gr.Tab("Streaming"):
+         # streaming chatbot
+         chatbot_stream.like(vote, None, None)
+         chat_interface_stream.render()
+
+     with gr.Tab("Batch"):
+         # non-streaming chatbot
+         chatbot_batch.like(vote, None, None)
+         chat_interface_batch.render()
+
+ demo.queue(max_size=100).launch()
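
For reference, both handlers serialize the chat history into the Llama-2 [INST] convention noted in the comments above. A minimal sketch of that serialization, using made-up example messages (not data from this Space), looks like this:

# Illustrative only: how predict()/predict_batch() build the prompt string.
# The history and message strings below are hypothetical examples.
history = [("How do I flatten a nested list?", "Use itertools.chain.from_iterable().")]
message = "Can you show a list-comprehension version?"

prompt = "<s>[INST] "
for user_msg, model_msg in history:
    prompt += f"{user_msg} [/INST] {model_msg} </s><s>[INST] "
prompt += f"{message} [/INST] "

print(prompt)
# <s>[INST] How do I flatten a nested list? [/INST] Use itertools.chain.from_iterable(). </s><s>[INST] Can you show a list-comprehension version? [/INST]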
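The Streaming tab drives AsyncInferenceClient.text_generation with stream=True against a TGI-backed endpoint. Outside Gradio, the same call can be exercised with a small standalone script; the endpoint URL and token are placeholders read from the environment, and the sampling values mirror the slider defaults above:

# Standalone sketch of the streaming call used by predict(); assumes a running
# TGI endpoint in API_URL and, if the endpoint is gated, a token in HF_TOKEN.
import asyncio
import os
from huggingface_hub import AsyncInferenceClient

async def main():
    client = AsyncInferenceClient(os.getenv("API_URL"), token=os.getenv("HF_TOKEN"))
    stream = await client.text_generation(
        prompt="<s>[INST] Refactor this function: def f(x): return x+1 [/INST] ",
        max_new_tokens=256,
        temperature=0.9,
        top_p=0.6,
        repetition_penalty=1.2,
        do_sample=True,
        stream=True,
    )
    async for token in stream:
        print(token, end="", flush=True)

asyncio.run(main())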
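The Batch tab bypasses the client and POSTs directly to the endpoint; the JSON body mirrors the data dict built in predict_batch. A minimal sketch, again with placeholder URL/token and slider-default parameters:

# Sketch of the raw request predict_batch() sends to the TGI endpoint.
import os
import requests

api_url = os.getenv("API_URL")
headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}
payload = {
    "inputs": "<s>[INST] Refactor this function: def f(x): return x+1 [/INST] ",
    "parameters": {
        "max_new_tokens": 256,
        "temperature": 0.9,
        "top_p": 0.6,
        "repetition_penalty": 1.2,
        "do_sample": True,
    },
}
response = requests.post(api_url, headers=headers, json=payload, timeout=120)
response.raise_for_status()
print(response.json()[0]["generated_text"])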