import os
import json
import subprocess
from threading import Thread
import requests
import random
import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from huggingface_hub import HfApi
from datetime import datetime
# Install flash-attn at startup; merge with os.environ so PATH/CUDA variables stay available and only the skip-build flag is added
subprocess.run('pip install flash-attn --no-build-isolation', env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
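# Space configuration is read from environment variables / secrets set on the Hugging Face Space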
MODEL_ID = os.environ.get("MODEL_ID")
CHAT_TEMPLATE = os.environ.get("CHAT_TEMPLATE")
MODEL_NAME = MODEL_ID.split("/")[-1]
CONTEXT_LENGTH = int(os.environ.get("CONTEXT_LENGTH"))
COLOR = os.environ.get("COLOR")
EMOJI = os.environ.get("EMOJI")
DESCRIPTION = os.environ.get("DESCRIPTION")
DISCORD_WEBHOOK = os.environ.get("DISCORD_WEBHOOK")
TOKEN = os.environ.get("TOKEN")
api = HfApi()
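# Send the prompt and the final response to a Discord channel (as two embeds) via the configured webhook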
def send_discord(i, o):
    url = DISCORD_WEBHOOK
    embed1 = {
        "description": i,
        "title": "Input"
    }
    embed2 = {
        "description": o,
        "title": "Output"
    }
    data = {
        "content": "https://huggingface.co/spaces/speakleash/Bielik-7B-Instruct-v0.1",
        "username": "Bielik Logger",
        "embeds": [
            embed1, embed2
        ],
    }
    headers = {
        "Content-Type": "application/json"
    }
    result = requests.post(url, json=data, headers=headers)
    if 200 <= result.status_code < 300:
        print(f"Webhook sent {result.status_code}")
    else:
        print(f"Not sent with {result.status_code}, response:\n{result.json()}")
# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
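# 4-bit quantization config; note it is defined here but not passed to from_pretrained below,
# so the model currently loads in its native precision (presumably kept for optional use)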
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype='auto',
    attn_implementation="flash_attention_2",
)
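# Streaming generation: tokenize, truncate to CONTEXT_LENGTH, then run model.generate on a
# background thread while a TextIteratorStreamer yields partial output token by token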
@spaces.GPU()
def generate(instruction, stop_tokens, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask
    if input_ids.shape[1] > CONTEXT_LENGTH:
        # Keep only the most recent CONTEXT_LENGTH tokens and truncate the mask to match
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]
    generate_kwargs = dict(
        {"input_ids": input_ids.to(device), "attention_mask": attention_mask.to(device)},
        streamer=streamer,
        do_sample=True if temperature else False,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        if new_token in stop_tokens:
            break
        yield "".join(outputs)
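# Build the full prompt from the system prompt, chat history and new message according to the
# chat template selected via the CHAT_TEMPLATE environment variable, then stream the reply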
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    repetition_penalty = float(repetition_penalty)
    print('LLL', [message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p])
    # Format history with a given chat template
if CHAT_TEMPLATE == "ChatML":
stop_tokens = ["<|endoftext|>", "<|im_end|>"]
instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
for human, assistant in history:
instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant
instruction += '\n<|im_start|>user\n' + message + '\n<|im_end|>\n<|im_start|>assistant\n'
elif CHAT_TEMPLATE == "Mistral Instruct":
stop_tokens = ["", "[INST]", "[INST] ", "", "[/INST]", "[/INST] "]
instruction = '[INST] ' + system_prompt
for human, assistant in history:
instruction += human + ' [/INST] ' + assistant + '[INST]'
instruction += ' ' + message + ' [/INST]'
elif CHAT_TEMPLATE == "Bielik":
stop_tokens = [""]
prompt_builder = ["[INST] "]
if system_prompt:
prompt_builder.append(f"<>\n{system_prompt}\n<>\n\n")
for human, assistant in history:
prompt_builder.append(f"{human} [/INST] {assistant}[INST] ")
prompt_builder.append(f"{message} [/INST]")
instruction = ''.join(prompt_builder)
else:
raise Exception("Incorrect chat template, select 'ChatML' or 'Mistral Instruct'")
    print(instruction)
    for output_text in generate(instruction, stop_tokens, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
        yield output_text
    # Log once the stream has finished, using the final response text
    send_discord(instruction, output_text)
    hfapi = HfApi()
    day = datetime.now().strftime("%Y-%m-%d")
    timestamp = datetime.now().timestamp()
    dd = {
        'message': message,
        'history': history,
        'system_prompt': system_prompt,
        'temperature': temperature,
        'max_new_tokens': max_new_tokens,
        'top_k': top_k,
        'repetition_penalty': repetition_penalty,
        'top_p': top_p,
        'instruction': instruction,
        'output': output_text,
        'precision': 'auto ' + str(model.dtype),
    }
    hfapi.upload_file(
        path_or_fileobj=json.dumps(dd, indent=2, ensure_ascii=False).encode('utf-8'),
        path_in_repo=f"{day}/{timestamp}.json",
        repo_id="speakleash/bielik-logs",
        repo_type="dataset",
        commit_message="X",
        token=TOKEN,
        run_as_future=True
    )
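# JS executed on page load: a consent alert (in Polish) stating that the model is experimental,
# that illegal/harmful use and submitting private data are forbidden, and that dialogue data is
# collected and may be redistributed under a CC-BY (or similar) licence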
on_load="""
async()=>{
alert("Przed skorzystaniem z usługi użytkownicy muszą wyrazić zgodę na następujące warunki:\\n\\nProszę pamiętać, że przedstawiony tutaj model jest narzędziem eksperymentalnym, które wciąż jest rozwijane i doskonalone.\\n\\nW trakcie procesu tworzenia modelu podjęto środki mające na celu zminimalizowanie ryzyka generowania treści wulgarnych, niedozwolonych lub nieodpowiednich. Niemniej jednak, w rzadkich przypadkach, niepożądane treści mogą zostać wygenerowane. Jeśli napotkają Państwo na jakiekolwiek treści uznane za nieodpowiednie lub naruszające zasady, prosimy o kontakt w celu zgłoszenia tego faktu. Dzięki Państwa informacjom będziemy mogli podejmować dalsze działania mające na celu poprawę i rozwój modelu, tak aby był on bezpieczny i przyjazny dla użytkowników.\\n\\nNie wolno używać modelu do celów nielegalnych, szkodliwych, brutalnych, rasistowskich lub seksualnych. Proszę nie przesyłać żadnych prywatnych informacji. Serwis gromadzi dane dialogowe użytkownika i zastrzega sobie prawo do ich rozpowszechniania na podstawie licencji Creative Commons Uznanie autorstwa (CC-BY) lub podobnej.");
}
"""
def vote(chatbot, data: gr.LikeData):
    day = datetime.now().strftime("%Y-%m-%d")
    timestamp = datetime.now().timestamp()
    api.upload_file(
        path_or_fileobj=json.dumps({"history": chatbot, 'index': data.index, 'liked': data.liked}, indent=2, ensure_ascii=False).encode('utf-8'),
        path_in_repo=f"liked/{day}/{timestamp}.json",
        repo_id="speakleash/bielik-logs",
        repo_type="dataset",
        commit_message="L",
        token=TOKEN,
        run_as_future=True
    )
# Create Gradio interface
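# Shuffle the example prompts so each page load shows them in a different order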
def update_examples():
    exs = [
        ["Kim jesteś?"],
        ["Ile to jest 9+2-1?"],
        ["Napisz mi coś miłego."]
    ]
    random.shuffle(exs)
    return gr.Dataset(samples=exs)
with gr.Blocks(js=on_load) as demo:
    chatbot = gr.Chatbot(label="Chatbot", likeable=True, render=False)
    chatbot.like(vote, [chatbot], None)
    chat = gr.ChatInterface(
        predict,
        chatbot=chatbot,
        title=EMOJI + " " + MODEL_NAME + " - online chat demo",
        description=DESCRIPTION,
        examples=[
            ["Kim jesteś?"],
            ["Ile to jest 9+2-1?"],
            ["Napisz mi coś miłego."]
        ],
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Textbox("", label="System prompt", render=False),
            gr.Slider(0, 1, 0.6, label="Temperature", render=False),
            gr.Slider(128, 4096, 1024, label="Max new tokens", render=False),
            gr.Slider(1, 80, 40, step=1, label="Top K sampling", render=False),
            gr.Slider(0, 2, 1.1, label="Repetition penalty", render=False),
            gr.Slider(0, 1, 0.95, label="Top P sampling", render=False),
        ],
        theme=gr.themes.Soft(primary_hue=COLOR),
    )
    demo.load(update_examples, None, chat.examples_handler.dataset)
demo.queue(max_size=20).launch()