chat-1 / app.py
metastable-void
renamed
2e5890c unverified
raw
history blame
3.34 kB
#!/usr/bin/env python
import os
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
DESCRIPTION = "# chat-1"
if not torch.cuda.is_available():
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "32768"))
if torch.cuda.is_available():
model_id = "vericava/llm-jp-3-1.8b-instruct-lora-vericava7-llama"
my_pipeline=pipeline(
model=model_id,
)
my_pipeline.tokenizer.chat_template = "{{bos_token}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '\\n\\n### 前の投稿:\\n' + message['content'] + '' }}{% elif message['role'] == 'system' %}{{ '以下は、SNS上の投稿です。あなたはSNSの投稿生成botとして、次に続く投稿を考えなさい。説明はせず、投稿の内容のみを鉤括弧をつけずに答えよ。' }}{% elif message['role'] == 'assistant' %}{{ '\\n\\n### 次の投稿:\\n' + message['content'] + eos_token }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '\\n\\n### 次の投稿:\\n' }}{% endif %}{% endfor %}"
@spaces.GPU
@torch.inference_mode()
def generate(
message: str,
chat_history: list[tuple[str, str]],
max_new_tokens: int = 1024,
temperature: float = 0.7,
top_p: float = 0.95,
top_k: int = 50,
repetition_penalty: float = 1.0,
) -> Iterator[str]:
messages = [
{"role": "system", "content": "あなたはSNSの投稿生成botで、次に続く投稿を考えてください。"},
{"role": "user", "content": message},
]
output = my_pipeline(
messages,
)[-1]["generated_text"][-1]["content"]
yield output
demo = gr.ChatInterface(
fn=generate,
type="tuples",
additional_inputs_accordion=gr.Accordion(label="詳細設定", open=False),
additional_inputs=[
gr.Slider(
label="Max new tokens",
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
value=DEFAULT_MAX_NEW_TOKENS,
),
gr.Slider(
label="Temperature",
minimum=0.1,
maximum=4.0,
step=0.1,
value=0.7,
),
gr.Slider(
label="Top-p (nucleus sampling)",
minimum=0.05,
maximum=1.0,
step=0.05,
value=0.95,
),
gr.Slider(
label="Top-k",
minimum=1,
maximum=1000,
step=1,
value=50,
),
gr.Slider(
label="Repetition penalty",
minimum=1.0,
maximum=2.0,
step=0.05,
value=1.0,
),
],
stop_btn=None,
examples=[
["サマリーを作る男の人,サマリーマン。"],
["やばい場所にクリティカルな配線ができてしまったので掲示した。"],
["にゃん"],
["Wikipedia の情報は入っているのかもしれない"],
],
description=DESCRIPTION,
css_paths="style.css",
fill_height=True,
)
if __name__ == "__main__":
demo.launch()