File size: 2,298 Bytes
02a4337
26f89cf
 
 
 
 
 
 
 
 
02a4337
 
26f89cf
9599706
02a4337
26f89cf
9599706
 
 
 
 
02a4337
9599706
02a4337
9599706
 
 
 
 
02a4337
26f89cf
 
 
02a4337
9599706
 
26f89cf
 
9599706
 
02a4337
9599706
 
 
 
 
 
26f89cf
9599706
 
 
 
 
 
26f89cf
02a4337
9599706
02a4337
 
 
 
 
 
26f89cf
9599706
26f89cf
9599706
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import streamlit as st

# 1. Page configuration. Streamlit requires set_page_config() to be the first
#    Streamlit command executed by the script, so it runs before any other
#    `st.*` call.
st.set_page_config(
    page_title="Khmer Text Summarization",
    # Memo emoji. The previous literal was the emoji's UTF-8 bytes mis-decoded
    # through a single-byte code page (mojibake), not a valid emoji.
    page_icon="📝",
    layout="wide",
    initial_sidebar_state="expanded",
)

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# 2. Model identifier: handed to load_tokenizer_and_model(), which passes it to
#    both `from_pretrained` calls to select the tokenizer and seq2seq model.
MODEL_ID = "songhieng/khmer-mt5-summarization"

# 3. Tokenizer and model loading. The loader is wrapped in st.cache_resource
#    so script reruns reuse the loaded pair instead of loading it again.
@st.cache_resource
def load_tokenizer_and_model(model_id):
    """Load the tokenizer and seq2seq model for ``model_id``.

    Args:
        model_id: Identifier passed unchanged to both ``from_pretrained``
            calls.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # Tuple elements are evaluated left to right: tokenizer first, then model.
    return (
        AutoTokenizer.from_pretrained(model_id, use_fast=True),
        AutoModelForSeq2SeqLM.from_pretrained(model_id),
    )


tokenizer, model = load_tokenizer_and_model(MODEL_ID)

# 4. App header
# The emoji below was previously stored as mojibake (UTF-8 bytes mis-decoded
# through a single-byte code page); it is restored to the intended character.
st.title("📝 Khmer Text Summarization")
st.write("Paste your Khmer text below and click **Summarize** to get a concise summary.")

# 5. Sidebar summarization settings
st.sidebar.header("Summarization Settings")
max_length = st.sidebar.slider("Maximum summary length", 50, 300, 150, step=10)
min_length = st.sidebar.slider("Minimum summary length", 10, 100, 30, step=5)
num_beams = st.sidebar.slider("Beam search width", 1, 10, 4, step=1)

# The slider ranges overlap (minimum goes up to 100, maximum down to 50), so
# the user can ask for min_length > max_length. Clamp instead of forwarding an
# unsatisfiable pair of constraints to generate(), and tell the user.
if min_length > max_length:
    st.sidebar.warning(
        f"Minimum length ({min_length}) exceeds maximum length ({max_length}); "
        f"using {max_length} for both."
    )
    min_length = max_length

# 6. Text input
# Placeholder is Khmer for "Please type Khmer text here…" (restored from
# mojibake).
user_input = st.text_area(
    "Enter Khmer text here…",
    height=300,
    placeholder="សូមវាយអត្ថបទខ្មែរនៅទីនេះ…",
)

# 7. Summarize button
if st.button("Summarize"):
    if not user_input.strip():
        # Whitespace-only input counts as empty.
        st.warning("⚠️ Please enter some text to summarize.")
    else:
        with st.spinner("Generating summary…"):
            # Tokenize the input text into PyTorch tensors.
            # NOTE(review): truncation=True without max_length relies on the
            # tokenizer's configured model_max_length; confirm it is set for
            # this checkpoint, otherwise pass max_length explicitly.
            inputs = tokenizer(
                user_input,
                return_tensors="pt",
                truncation=True,
                padding="longest",
            )

            generation_kwargs = {
                "max_length": max_length,
                "min_length": min_length,
                "num_beams": num_beams,
            }
            # length_penalty and early_stopping are beam-search options; only
            # pass them when beam search is actually in use (num_beams > 1).
            if num_beams > 1:
                generation_kwargs.update(length_penalty=2.0, early_stopping=True)

            # Generate the summary token ids.
            summary_ids = model.generate(**inputs, **generation_kwargs)

            # Decode the first (only) sequence, dropping special tokens.
            summary = tokenizer.decode(
                summary_ids[0],
                skip_special_tokens=True,
            )
        st.subheader("🔖 Summary:")
        st.write(summary)