import random

import gradio as gr
import torch
from gradio.themes.utils import sizes
from transformers import AutoModelForCausalLM, AutoTokenizer

import utils
from constants import END_OF_TEXT
from settings import DEFAULT_PORT

tokenizer = AutoTokenizer.from_pretrained(
    "BEE-spoke-data/smol_llama-101M-GQA-python",
    use_fast=False,
)
# Pad with the end-of-text token so padded batches use a valid, known token id.
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = END_OF_TEXT

model = AutoModelForCausalLM.from_pretrained(
    "BEE-spoke-data/smol_llama-101M-GQA-python",
    device_map="auto",
)
model = torch.compile(model, mode="reduce-overhead")

_styles = utils.get_file_as_string("styles.css")

readme_file_content = utils.get_file_as_string("README.md", path="./")
(
    manifest,
    description,
    disclaimer,
    base_model_info,
    formats,
) = utils.get_sections(readme_file_content, "---", up_to=5)

theme = gr.themes.Soft(
    primary_hue="yellow",
    secondary_hue="orange",
    neutral_hue="slate",
    radius_size=sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("IBM Plex Sans", [400, 600]),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    text_size=sizes.text_lg,
)


def run_inference(prompt, temperature, max_new_tokens, top_p, repetition_penalty):
    """Generate a completion for `prompt`; the decoded text includes the prompt."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        min_new_tokens=8,
        renormalize_logits=True,
        no_repeat_ngram_size=6,
        repetition_penalty=repetition_penalty,
        num_beams=3,
        early_stopping=True,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
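# num_beams=3 combined with do_sample=True gives beam-sample decoding, which
# can help keep a 101M-parameter model on-topic at the cost of extra latency;
# plain sampling (num_beams=1) would be a faster alternative if responsiveness
# matters more than output quality.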


def gradio_interface(
    prompt: str,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    repetition_penalty: float,
) -> str:
    """Thin typed wrapper so the UI event handlers share one entry point."""
    return run_inference(prompt, temperature, max_new_tokens, top_p, repetition_penalty)
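

# Each row mirrors gradio_interface's inputs:
# [prompt, temperature, max_new_tokens, top_p, repetition_penalty].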
examples = [
    ["def add_numbers(a, b):\n    return", 0.2, 192, 0.9, 1.2],
    [
        "class Car:\n    def __init__(self, make, model):\n        self.make = make\n        self.model = model\n\n    def display_car(self):",
        0.2,
        192,
        0.9,
        1.2,
    ],
    [
        "import pandas as pd\ndata = {'Name': ['Tom', 'Nick', 'John'], 'Age': [20, 21, 19]}\ndf = pd.DataFrame(data).convert_dtypes()\n# eda",
        0.2,
        192,
        0.9,
        1.2,
    ],
    [
        "def factorial(n):\n    if n == 0:\n        return 1\n    else:",
        0.2,
        192,
        0.9,
        1.2,
    ],
    [
        'def fibonacci(n):\n    if n <= 0:\n        raise ValueError("Incorrect input")\n    elif n == 1:\n        return 0\n    elif n == 2:\n        return 1\n    else:',
        0.2,
        192,
        0.9,
        1.2,
    ],
    [
        "import matplotlib.pyplot as plt\nimport numpy as np\nx = np.linspace(0, 10, 100)\n# simple plot",
        0.2,
        192,
        0.9,
        1.2,
    ],
    ["def reverse_string(s: str) -> str:\n    return", 0.2, 192, 0.9, 1.2],
    ["def is_palindrome(word: str) -> bool:\n    return", 0.2, 192, 0.9, 1.2],
    [
        "def bubble_sort(lst: list):\n    n = len(lst)\n    for i in range(n):\n        for j in range(0, n-i-1):",
        0.2,
        192,
        0.9,
        1.2,
    ],
    [
        "def binary_search(arr, low, high, x):\n    if high >= low:\n        mid = (high + low) // 2\n        if arr[mid] == x:\n            return mid\n        elif arr[mid] > x:",
        0.2,
        192,
        0.9,
        1.2,
    ],
]


with gr.Blocks(theme=theme, analytics_enabled=False, css=_styles) as demo:
    with gr.Column():
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                instruction = gr.Textbox(
                    value=random.choice([e[0] for e in examples]),
                    placeholder="Enter your code here",
                    label="Code",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output", language="python", lines=10)
        with gr.Row():
            with gr.Column():
                with gr.Accordion("Advanced settings", open=False):
                    with gr.Row():
                        column_1, column_2 = gr.Column(), gr.Column()
                        with column_1:
                            temperature = gr.Slider(
                                label="Temperature",
                                value=0.2,
                                minimum=0.0,
                                maximum=1.0,
                                step=0.05,
                                interactive=True,
                                info="Higher values produce more diverse outputs",
                            )
                            max_new_tokens = gr.Slider(
                                label="Max new tokens",
                                value=128,
                                minimum=64,  # stay above the min_new_tokens=8 used in generate()
                                maximum=512,
                                step=64,
                                interactive=True,
                                info="Number of tokens to generate",
                            )
                        with column_2:
                            top_p = gr.Slider(
                                label="Top-p (nucleus sampling)",
                                value=0.90,
                                minimum=0.0,
                                maximum=1.0,
                                step=0.05,
                                interactive=True,
                                info="Higher values sample more low-probability tokens",
                            )
                            repetition_penalty = gr.Slider(
                                label="Repetition penalty",
                                value=1.1,
                                minimum=1.0,
                                maximum=2.0,
                                step=0.05,
                                interactive=True,
                                info="Penalize repeated tokens",
                            )
            with gr.Column():
                version = gr.Dropdown(
                    ["smol_llama-101M-GQA-python"],
                    value="smol_llama-101M-GQA-python",
                    label="Version",
                )
        gr.Markdown(disclaimer)
        # Inputs must match both the example-row layout and gradio_interface's
        # signature, so the display-only `version` dropdown is not wired in.
        gr.Examples(
            examples=examples,
            inputs=[
                instruction,
                temperature,
                max_new_tokens,
                top_p,
                repetition_penalty,
            ],
            cache_examples=False,
            fn=gradio_interface,
            outputs=[output],
        )
        gr.Markdown(base_model_info)
        gr.Markdown(formats)

    submit.click(
        gradio_interface,
        inputs=[
            instruction,
            temperature,
            max_new_tokens,
            top_p,
            repetition_penalty,
        ],
        outputs=[output],
        show_progress=True,
    )
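
    # With the default batch=False, each click processes a single prompt;
    # when the app is busy, requests wait in the queue configured below.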


demo.queue(max_size=10).launch(
    debug=True,
    server_port=DEFAULT_PORT,
    max_threads=utils.get_workers(),
)