yakine committed on
Commit 15e9cb3 · verified · 1 Parent(s): 9802a6c

Update app.py

Files changed (1)
  1. app.py +18 -12
app.py CHANGED
@@ -2,17 +2,23 @@ import streamlit as st
  from transformers import AutoTokenizer, AutoModelForCausalLM
  import os
  hf_token = os.getenv('HF_API_TOKEN')
- # Load the Llama 3.1 model and tokenizer
- model_name = "meta-llama/Meta-Llama-3.1-8B"
- tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", token=hf_token)
-
- # Streamlit app interface
- st.title("Llama 3.1 Text Generator")
- prompt = st.text_area("Enter a prompt:", "Once upon a time")
-
- if st.button("Generate"):
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-     outputs = model.generate(**inputs, max_length=512, top_p=0.9, temperature=0.8)
-     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     st.write(generated_text)
+ import streamlit as st
+
+ from transformers import pipeline
+
+ # Load the model
+ generator = pipeline("text-generation", model="meta-llama/Meta-Llama-3.1-8B")
+
+ # Create an API route in Streamlit
+ @st.cache_resource
+ def predict(inputs):
+     return generator(inputs, max_length=512, top_p=0.9, temperature=0.8)[0]['generated_text']
+
+ @st.cache_resource
+ def predict_endpoint():
+     inputs = st.experimental_get_query_params().get('inputs', [''])[0]
+     return predict(inputs)
+
+ st.experimental_set_query_params(result=predict_endpoint())
+
+ st.title("Llama3.1 API is Running")
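
For context, a minimal client sketch of how the query-parameter route added in this commit might be exercised. The Space URL below is a placeholder assumption, not taken from the commit, and note that Streamlit serves an HTML app rather than a JSON API, so the generated text is not returned as a structured payload.

# Hypothetical client call against the deployed Space (URL is an assumption).
import requests

SPACE_URL = "https://<user>-<space>.hf.space"  # placeholder; substitute the actual Space URL

resp = requests.get(SPACE_URL, params={"inputs": "Once upon a time"}, timeout=120)

# Streamlit returns the app's HTML shell here; the query-parameter approach in this
# commit only surfaces the result in the browser URL, not as a machine-readable response.
print(resp.status_code, len(resp.text))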