File size: 4,429 Bytes
f0dff07
 
 
9ab7a40
f0dff07
 
 
 
 
97befb1
59be267
f0dff07
c99eaf2
f0dff07
 
 
 
 
 
 
 
245e479
f0dff07
ff3e19f
f602cdc
59be267
1ff7179
b0097b1
59be267
 
 
b0097b1
97befb1
59be267
 
 
5a02dd0
 
97befb1
f0dff07
 
 
 
 
 
 
 
 
 
 
 
1ff7179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97befb1
6c44471
97befb1
 
f0dff07
97befb1
b0097b1
5a02dd0
0e35e11
9d7e24a
 
b0097b1
f0dff07
920b6db
f0dff07
920b6db
 
f0dff07
 
 
 
 
 
 
 
 
 
 
 
 
0e35e11
f0dff07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e35e11
f0dff07
 
 
 
97befb1
 
 
 
f0dff07
920b6db
 
 
f0dff07
 
 
920b6db
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#!/usr/bin/env python

import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
from peft import PeftModel

# Header rendered above the chat box. A CPU warning is appended when no CUDA
# device is visible, since the model/pipeline below are only created with CUDA.
DESCRIPTION = "# 真空ジェネレータ\n<p>Imitate 真空 (@vericava)'s posts interactively</p>" + (
    ""
    if torch.cuda.is_available()
    else "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
)

# Bounds for the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# Read from the environment (default 32768). Not referenced elsewhere in this
# file — presumably intended as a prompt-length cap; TODO confirm or remove.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "32768"))


# Build the tokenizer, base model + LoRA adapter, and text-generation pipeline.
# Only done when a CUDA device is visible; otherwise these globals are never
# defined and generate() cannot run (DESCRIPTION warns about this).
if torch.cuda.is_available():
    model_id = "vericava/llm-jp-3-1.8b-instruct-lora-vericava17"  # LoRA adapter repo
    base_model_id = "llm-jp/llm-jp-3-1.8b-instruct"
    tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
    # Custom prompt layout: BOS, then the system instruction (plus the system
    # message content), each user message under "### 前の投稿" (previous post),
    # each assistant message under "### 次の投稿" (next post) followed by EOS,
    # and an open "### 次の投稿" header as the generation prompt.
    tokenizer.chat_template = "{{bos_token}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '\\n\\n### 前の投稿:\\n' + message['content'] + '' }}{% elif message['role'] == 'system' %}{{ '以下は、SNS上の投稿です。あなたはSNSの投稿生成botとして、次に続く投稿を考えなさい。説明はせず、投稿の内容のみを鉤括弧をつけずに答えよ。' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '\\n\\n### 次の投稿:\\n' + message['content'] + eos_token }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '\\n\\n### 次の投稿:\\n' }}{% endif %}{% endfor %}"
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        trust_remote_code=True,
    )
    model.load_adapter(model_id)
    # from_pretrained places weights on the CPU when no device is given; move
    # the base model and the freshly attached adapter to the GPU so generation
    # actually runs there (module-level .to("cuda") is the ZeroGPU pattern).
    model.to("cuda")
    my_pipeline = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        do_sample=True,
        num_beams=1,
    )

def _jp_season(month: int) -> str:
    """Return the Japanese season for a 1-based month: Dec-Feb 冬, Mar-May 春, Jun-Aug 夏, Sep-Nov 秋."""
    if month in (12, 1, 2):
        return "冬"
    if month <= 5:
        return "春"
    if month <= 8:
        return "夏"
    return "秋"


def _jp_clock(hour: int, minute: int) -> str:
    """Format a 24-hour time as Japanese 12-hour text, e.g. (15, 5) -> '午後3時5分'.

    Noon and midnight render with hour 0 ('午後0時…' / '午前0時…').
    """
    meridiem = "午前" if hour < 12 else "午後"
    return f"{meridiem}{hour % 12}時{minute}分"


@spaces.GPU
@torch.inference_mode()
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    """Generate one post in the style of 真空 replying to *message*.

    The system prompt states the current Japanese season and clock time (JST)
    and that the model is the user 真空; *message* is sent as the previous post.

    Args:
        message: The post to reply to.
        chat_history: Required by the Gradio ChatInterface signature but not
            used — only the current message is sent to the model.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept when sampling.
        repetition_penalty: Penalty applied to already-generated tokens.

    Yields:
        The generated post text, as a single complete reply.

    Raises:
        NameError: If the module-level ``my_pipeline`` was never created
            (no CUDA device at import time).
    """
    from datetime import datetime, timedelta, timezone

    now = datetime.now(timezone(timedelta(hours=9), "JST"))
    season = _jp_season(now.month)
    clock = _jp_clock(now.hour, now.minute)

    messages = [
        {"role": "system", "content": "なお今は日本の" + season + "で、時刻は" + clock + "であるものとする。また、あなたは真空という名前のユーザであるとする。"},
        {"role": "user", "content": message},
    ]

    # Forward every sampling control the UI exposes; previously top_p, top_k
    # and repetition_penalty were silently dropped. Slider values may arrive as
    # floats, and transformers requires an int top_k / max_new_tokens.
    output = my_pipeline(
        messages,
        temperature=temperature,
        max_new_tokens=int(max_new_tokens),
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
    )
    print(output)  # debug trace of the raw pipeline output (goes to stdout/Space logs)
    # For chat input the pipeline returns the conversation with the assistant
    # reply appended; take the content of that final message.
    yield output[-1]["generated_text"][-1]["content"]

# Chat UI wired to generate(). Gradio passes the additional_inputs values to fn
# after (message, history), in list order, so the sliders below must stay
# aligned with generate's trailing parameters: max_new_tokens, temperature,
# top_p, top_k, repetition_penalty. The slider `value`s are what the UI sends,
# so they (not generate's keyword defaults) are the effective defaults.
demo = gr.ChatInterface(
    fn=generate,
    type="tuples",  # history format matching generate's list[tuple[str, str]] annotation
    additional_inputs_accordion=gr.Accordion(label="詳細設定", open=False),  # "Advanced settings", collapsed by default
    additional_inputs=[
        # -> max_new_tokens
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        # -> temperature
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=1.0,
        ),
        # -> top_p
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        # -> top_k
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        # -> repetition_penalty
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.5,
        ),
    ],
    stop_btn=None,  # no stop button; generate yields one complete reply
    # Sample input posts offered as clickable examples.
    examples=[
        ["サマリーを作る男の人,サマリーマン。"],
        ["やばい場所にクリティカルな配線ができてしまったので掲示した。"],
        ["にゃん"],
        ["Wikipedia の情報は入っているのかもしれない"],
    ],
    description=DESCRIPTION,
    css_paths="style.css",  # presumably a stylesheet shipped next to this file — confirm
    fill_height=True,
)

if __name__ == "__main__":
    # Start the Gradio server only when this file is run directly, not on import.
    demo.launch()