from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the model and tokenizer
model_name = "Qwen/Qwen1.5-0.5B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


def generate_text(text):
    # Build the chat messages, optionally prepending a system prompt.
    DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
    use_system_prompt = True

    messages = []
    if use_system_prompt:
        messages.append({"role": "system", "content": DEFAULT_SYSTEM_PROMPT})
    messages.append({"role": "user", "content": text})

    # Render the messages with the model's chat template and append the
    # generation prompt so the model answers as the assistant.
    prompt = tokenizer.apply_chat_template(
        conversation=messages,
        add_generation_prompt=True,
        tokenize=False,
    )

    # Tokenize the rendered prompt, keeping the attention mask.
    input_datas = tokenizer(
        prompt,
        add_special_tokens=True,
        return_tensors="pt",
    )

    # Generate text, passing the attention mask; max_length caps the total
    # sequence length (prompt + completion).
    generated_ids = model.generate(
        input_ids=input_datas.input_ids,
        attention_mask=input_datas.attention_mask,
        max_length=10000,
    )

    # Decode only the newly generated tokens, skipping the prompt portion.
    generated_text = tokenizer.decode(
        generated_ids[0][input_datas.input_ids.size(1):],
        skip_special_tokens=True,
    )
    return generated_text
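
# Note: apply_chat_template renders the messages into Qwen's ChatML-style prompt,
# roughly of the following form (an illustrative sketch, not verified output):
#   <|im_start|>system
#   You are a helpful assistant.<|im_end|>
#   <|im_start|>user
#   <your text><|im_end|>
#   <|im_start|>assistant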

app = Flask(__name__)

@app.route('/')
def predict():
    # Read the prompt text from the "param" query parameter.
    param_value = request.args.get('param', '')
    # Run the model inference on the provided text.
    text = generate_text(param_value)
    return f"{text}"

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
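
# Example client call (a sketch, assuming the server is running locally on port 7860):
#   curl "http://localhost:7860/?param=Hello"
# The endpoint returns the model's generated reply as plain text.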