File size: 3,412 Bytes
9a76509
45d388a
 
 
 
 
59cd1f7
 
 
 
 
 
 
 
 
 
 
 
45d388a
ac716aa
9a76509
45d388a
9a76509
45d388a
 
9a76509
45d388a
 
 
 
 
 
9a76509
 
ac716aa
45d388a
ac716aa
 
 
 
 
 
 
 
 
 
 
 
 
 
59cd1f7
ac716aa
 
 
9f3a452
9a76509
ac716aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a76509
ac716aa
 
9a76509
ac716aa
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Tải model và tokenizer khi ứng dụng khởi động
model_name = "Qwen/Qwen2.5-0.5B"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="eager"  # Tránh cảnh báo sdpa
    )
    print("Model and tokenizer loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    raise

# Hàm sinh văn bản (dùng cho cả UI và API)
def generate_text(prompt, max_length=100):
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            inputs["input_ids"],
            max_length=max_length,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_k=50,
            top_p=0.95
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error: {str(e)}"

# Hàm hiển thị thông tin API
def get_api_info():
    # Trên Hugging Face Spaces, API URL sẽ dựa trên tên Space
    # Khi chạy local, ta giả định port 7860
    base_url = "http://localhost:7860" if gr.context.local else "https://<your-space-name>.hf.space"
    return (
        "Welcome to Qwen2.5-0.5B API!\n"
        f"API Base URL: {base_url}\n"
        "Endpoints:\n"
        f"- GET {base_url}/api/health_check (Check API status)\n"
        f"- POST {base_url}/api/generate (Generate text)\n"
        "To use the generate API, send a POST request with JSON:\n"
        '{"prompt": "your prompt", "max_length": 150}'
    )

# Hàm kiểm tra sức khỏe (dành cho API)
def health_check():
    return "Qwen2.5-0.5B API is running!"

# Tạo giao diện Gradio
with gr.Blocks(title="Qwen2.5-0.5B Text Generator") as demo:
    gr.Markdown("# Qwen2.5-0.5B Text Generator")
    gr.Markdown("Enter a prompt below or use the API!")
    
    # Hiển thị thông tin API
    gr.Markdown("### API Information")
    api_info = gr.Textbox(label="API Details", value=get_api_info(), interactive=False)
    
    # Giao diện sinh văn bản
    gr.Markdown("### Generate Text")
    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="Type something...")
        max_length_input = gr.Slider(50, 500, value=100, step=10, label="Max Length")
    
    generate_button = gr.Button("Generate")
    output_text = gr.Textbox(label="Generated Text", interactive=False)
    
    # Liên kết button với hàm generate_text
    generate_button.click(
        fn=generate_text,
        inputs=[prompt_input, max_length_input],
        outputs=output_text
    )

# Định nghĩa API endpoints với Gradio
demo = gr.Interface(
    fn=generate_text,
    inputs=["text", "number"],
    outputs="text",
    title="Qwen2.5-0.5B API",
    api_name="/generate"  # API endpoint: /api/generate
).queue()

# Thêm endpoint health check
health_interface = gr.Interface(
    fn=health_check,
    inputs=None,
    outputs="text",
    api_name="/health_check"  # API endpoint: /api/health_check
)

# Kết hợp giao diện và API
app = gr.mount_gradio_app(demo, health_interface)

# Chạy ứng dụng
demo.launch(server_name="0.0.0.0", server_port=7860)