# summarization / app.py
# Last commit 50a4735 ("Update app.py", verified) by zoya23 — 2.21 kB
import streamlit as st
from datasets import load_dataset
from langchain.chains import LLMChain
from langchain.prompts import FewShotChatMessagePromptTemplate
from langchain.prompts import FewShotPromptTemplate
from langchain.prompts import PromptTemplate
from langchain.prompts.example_selector import LengthBasedExampleSelector
from langchain_huggingface import HuggingFaceEndpoint, HuggingFacePipeline
from transformers import pipeline
# Load few-shot examples (using knkarthick/dialogsum as an example dataset).
@st.cache_data
def load_examples(n_examples: int = 5) -> list[dict[str, str]]:
    """Return the first ``n_examples`` dialogue/summary pairs from dialogsum.

    Only a small slice of the training split is fetched (5 rows by default)
    to keep start-up fast. Each row is mapped to
    ``{"input": <dialogue>, "output": <summary>}``.

    The result is memoized by ``st.cache_data``, keyed on ``n_examples``,
    so Streamlit reruns do not reload the dataset.
    """
    dataset = load_dataset("knkarthick/dialogsum", split=f"train[:{n_examples}]")
    return [{"input": row["dialogue"], "output": row["summary"]} for row in dataset]


examples = load_examples()
# Hugging Face checkpoint used for summarization. It is run locally through
# the transformers "summarization" pipeline and downloaded on first use.
MODEL_ID = "t5-small"

# Word budget shared by the few-shot examples and the user's dialogue.
# LengthBasedExampleSelector subtracts the input's length from this budget
# and keeps only as many examples as still fit, so long dialogues get fewer
# examples.
# NOTE(review): t5-small's tokenizer is configured for 512-token inputs; a
# 1000-word prompt likely exceeds that — consider lowering this value.
MAX_PROMPT_WORDS = 1000


@st.cache_resource
def _load_summarizer() -> HuggingFacePipeline:
    """Wrap a local transformers summarization pipeline as a LangChain LLM.

    Cached with ``st.cache_resource`` so the model is loaded once and shared
    across reruns, instead of being rebuilt on every button click.
    """
    return HuggingFacePipeline(pipeline=pipeline("summarization", model=MODEL_ID))


# Renders one dialogue/summary pair. The few-shot prompt uses it to format
# each example, and the selector uses it to measure example length.
example_template = PromptTemplate(
    input_variables=["input", "output"],
    template="Dialogue:\n{input}\nSummary: {output}",
)

# Plain-string few-shot prompt (prefix + selected examples + suffix).
# A seq2seq summarizer consumes a single text input, so a string template is
# used rather than a chat-message template.
few_shot_prompt = FewShotPromptTemplate(
    example_selector=LengthBasedExampleSelector(
        examples=examples,
        example_prompt=example_template,
        max_length=MAX_PROMPT_WORDS,
    ),
    example_prompt=example_template,
    input_variables=["input"],
    prefix="You are a helpful assistant that summarizes dialogues. Examples:",
    suffix="Now summarize this:\n{input}",
)

# Streamlit UI
st.title("💬 Dialogue Summarizer using Few-Shot Prompt + T5 (via Langchain)")
input_text = st.text_area("📝 Paste your conversation:")

if st.button("Generate Summary"):
    if input_text.strip():
        # Build the full prompt. The user's dialogue comes last (the suffix),
        # after whichever examples fit in the word budget.
        prompt_text = few_shot_prompt.format(input=input_text)
        with st.expander("📋 Generated Prompt"):
            # st.code shows the text verbatim, even if the dialogue contains
            # backticks that would break a hand-built markdown fence.
            st.code(prompt_text, language=None)

        # invoke() takes the prompt string and returns the summary as a str.
        # NOTE(review): t5-small likely summarizes the whole prompt, examples
        # included — verify output quality before relying on it.
        with st.spinner("Summarizing..."):
            summary = _load_summarizer().invoke(prompt_text)
        st.success("✅ Summary:")
        st.write(summary)
    else:
        st.warning("Please enter some text.")