Spaces:
Sleeping
Sleeping
import gradio as gr | |
import json | |
import re | |
from bs4 import BeautifulSoup | |
from transformers import AutoTokenizer | |
from vllm import LLM, SamplingParams | |
# Load model and tokenizer | |
MODEL_NAME = "tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3" | |
SYS_CONTENT = ( | |
"あなたは実績のある日本語意味解析ソフトウェアです。質問には正確に具体的に回答できます。" | |
"次のJSONのパターンと抽出のルールに従って,JSONのパターンの値を埋めて返して下さい。" | |
"JSONのパターン:{ \"subject\": \"...\", \"when\": \"...\", \"where\": \"...\", \"what\": [...], , \"orgs\": [...]}" | |
"抽出のルール:入力される文章について,どの住所(address)の誰(who)が,いつ(when),どこで(where),どうした(what)と書いてますか?: " | |
"[1]何(subject)には組織名,会社名を住所を含めて入れてください。" | |
"組織名,会社名の直後の括弧に住所がある場合,例えば「日経新聞社(東京都千代田区)」とある場合は,「日経新聞社(東京都千代田区)」と括弧と住所がついた語を一緒に抽出してください。subjectが複数ある場合は半角カンマで区切って下さい。" | |
"人名は絶対に入れないで下さい。例えば「大谷翔平氏」などは入れてはいけません。" | |
"[2]どこ(where)にはwhatが起きた具体的な住所や地名や施設名を入れてください。" | |
"[3]どうした(what)には文章の短い要約を3つの箇条書きで書き,リストにしてください。" | |
"[4]orgsには提示された文章にでてきた会社名や組織名をすべて列挙してリストにして下さい。" | |
"会社名や組織名の直後の括弧に住所がある場合,例えば「三井物産(東京都千代田区)」とある場合は,「三井物産(東京都千代田区)」と括弧と住所がついた語を一緒に抽出してください[4]orgsには提示された文章にでてきた会社名や組織名をすべて列挙してリストにして下さい。会社名や組織名の直後の括弧に住所がある場合,例えば「三井物産(東京都千代田区)」とある場合は,「三井物産(東京都千代田区)」と括弧と住所がついた語を一緒に抽出してください。" | |
"[5]もしも該当の情報が提示された文章になければそのJSONの要素の値を空にしてください。" | |
) | |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) | |
llm = LLM( | |
model=MODEL_NAME, | |
tensor_parallel_size=1, | |
) | |
def preprocess_text(text: str) -> str: | |
# HTMLタグの削除 | |
soup = BeautifulSoup(text, 'html.parser') | |
text = soup.get_text() | |
# 独自タグの削除 (<...> </...>) | |
text = re.sub(r'<[^>]+>', '', text) | |
# 改行、タブ、余分な空白(半角・全角)の削除 | |
text = re.sub(r'[\n\t]', '', text) | |
text = re.sub(r'[\s ]+', ' ', text) # 連続する空白を1つの半角スペースに置換 | |
text = text.strip() | |
return text | |
def inference(content: str, max_tokens: int, temperature: float, top_p: float): | |
sampling_params = SamplingParams( | |
temperature=temperature, | |
top_p=top_p, | |
max_tokens=max_tokens, | |
stop="<|eot_id|>" | |
) | |
# 入力テキストの前処理 | |
processed_content = preprocess_text(content) | |
message = [ | |
{ | |
"role": "system", | |
"content": SYS_CONTENT | |
}, | |
{ | |
"role": "user", | |
"content": processed_content, | |
}, | |
] | |
try: | |
prompt = tokenizer.apply_chat_template( | |
message, tokenize=False, add_generation_prompt=True | |
) | |
output = llm.generate(prompt, sampling_params) | |
result_text = output[0].outputs[0].text | |
# JSONを抽出 | |
json_pattern = r'\{[^{}]*\}' | |
match = re.search(json_pattern, result_text) | |
if not match: | |
return "エラー: 生成されたテキストからJSONが見つかりませんでした。" | |
try: | |
json_data = json.loads(match.group()) | |
return json.dumps(json_data, ensure_ascii=False, indent=2) | |
except json.JSONDecodeError as e: | |
return f"JSONパースエラー: {str(e)}" | |
except Exception as e: | |
return f"生成エラー: {str(e)}" | |
# Gradioインターフェースの作成 | |
demo = gr.Interface( | |
fn=inference, | |
inputs=[ | |
gr.Textbox(label="入力テキスト", lines=10), | |
gr.Number(label="最大トークン数", value=512), | |
gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.3, step=0.1), | |
gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.1), | |
], | |
outputs=gr.Textbox(label="解析結果", lines=10), | |
title="意味解析エンジン", | |
description="テキストを入力すると、5W(Who, What, When, Where, How)の形式で情報を抽出します.テキスト内に混入した改行や空白,独自タグ等を削除する整形処理を入れてますが,きちんとテストしていません.エラーが出る場合は事前に整形してからテキストを入れて下さい.", | |
) | |
if __name__ == "__main__": | |
demo.launch() | |