Ghosthash commited on
Commit
86265d8
·
verified ·
1 Parent(s): 047b49f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -1,12 +1,20 @@
1
  import streamlit as st
2
- from transformers import pipeline
3
 
4
- pipe = pipeline("text-generation", model="yubi12/autotrain-crypto-degen-tweet-generator-v1", device=1)
5
- text = st.text_input('enter text here')
 
 
 
 
 
 
6
 
7
  if text:
8
  messages = [
9
  {"role": "user", "content": text},
10
  ]
11
- out = pipe(text)
12
- st.json(out)
 
 
 
1
  import streamlit as st
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
+ model_path = "yubi12/autotrain-crypto-degen-tweet-generator-v1"
5
+
6
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
7
+ model = AutoModelForCausalLM.from_pretrained(
8
+ model_path,
9
+ device_map="auto",
10
+ torch_dtype='auto'
11
+ ).eval()
12
 
13
  if text:
14
  messages = [
15
  {"role": "user", "content": text},
16
  ]
17
+ input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True, add_generation_prompt=True, return_tensors='pt')
18
+ output_ids = model.generate(input_ids.to('cuda'))
19
+ response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
20
+ print(response)