File size: 924 Bytes
f94b825
 
 
 
 
 
6a67ea5
f94b825
 
 
6a67ea5
 
f94b825
 
c6459ea
f94b825
 
e573614
f94b825
 
2b86b0b
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import streamlit as st
from transformers import T5Tokenizer, T5ForConditionalGeneration

def _cache_resource(func):
    """Cache the return value of *func* across Streamlit script reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    expensive loaders must be cached. ``st.cache_resource`` is used when this
    Streamlit version provides it. Otherwise this falls back to the legacy
    ``st.cache`` with output-mutation hashing disabled, so the returned
    object is not re-hashed on every rerun.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def load_tokenizer():
    """Load the T5 tokenizer named "jokes-tokenizer" (hub id or local path)."""
    return T5Tokenizer.from_pretrained("jokes-tokenizer")


@_cache_resource
def load_model():
    """Load the T5 seq2seq model named "jokes-model" (hub id or local path).

    Returns:
        The ``T5ForConditionalGeneration`` instance, cached across reruns.
    """
    model = T5ForConditionalGeneration.from_pretrained("jokes-model")
    return model


# Module-level handles used by the inference code below; both loads are cached.
tokenizer = load_tokenizer()
model = load_model()

def infer(input_ids, max_length=None):
    """Generate an answer for already-tokenized input and decode it to text.

    Uses the module-level ``model`` and ``tokenizer``.

    Args:
        input_ids: Token ids for the prompt, as returned by
            ``tokenizer(..., return_tensors="pt").input_ids``. Only the first
            generated sequence is decoded.
        max_length: Upper bound (in tokens) on the generated sequence. ``None``
            keeps ``model.generate``'s own default.

    Returns:
        The decoded text of the first generated sequence, with special tokens
        removed.
    """
    generate_kwargs = {}
    if max_length is not None:
        generate_kwargs["max_length"] = max_length
    output_sequences = model.generate(input_ids=input_ids, **generate_kwargs)
    return tokenizer.decode(output_sequences[0], skip_special_tokens=True)


st.title("Stupid jokes with transformers")
st.write("Write a question you want to see a funny answer for.")
sent = st.text_area("Text", height=100)

# Skip empty and whitespace-only submissions instead of running the model.
question = sent.strip()
if question:
    max_source_length = 64
    max_target_length = 32
    prefix = "Answer the following question in a funny way: "

    # Truncate the prompt (prefix + question) to max_source_length tokens.
    input_ids = tokenizer(
        prefix + question,
        max_length=max_source_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids
    # Cap the answer at max_target_length tokens, not generate's default.
    generated_sequence = infer(input_ids, max_length=max_target_length)

    st.write(generated_sequence)