pradeepsengarr committed on
Commit 6161368 · verified · 1 Parent(s): 37fedd6

Update app.py

Files changed (1)
  1. app.py +11 -8
app.py CHANGED
@@ -33,7 +33,6 @@
     # st.warning("Please enter a query!")
 
 
-
 import os
 import streamlit as st
 import torch
@@ -52,7 +51,9 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 # Load the model and tokenizer with token authentication
 MODEL_NAME = "google/gemma-2b-it"
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
-model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=hf_token, torch_dtype=torch.float16 if device == "cuda" else torch.float32, device_map="auto")
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME, token=hf_token, torch_dtype=torch.float16 if device == "cuda" else torch.float32, device_map="auto"
+)
 
 # Streamlit UI
 st.title("Gemma-2B Code Assistant")
@@ -60,10 +61,12 @@ user_input = st.text_area("Enter your coding query:")
 
 if st.button("Generate Code"):
     if user_input:
-        inputs = tokenizer(user_input, return_tensors="pt").to(device)
-        output = model.generate(**inputs, max_new_tokens=100)
-        response = tokenizer.decode(output[0], skip_special_tokens=True)
-        st.write(response)
+        with st.spinner("⏳ Generating response... Please wait!"):
+            inputs = tokenizer(user_input, return_tensors="pt").to(device)
+            output = model.generate(**inputs, max_new_tokens=100)
+            response = tokenizer.decode(output[0], skip_special_tokens=True)
+
+            st.subheader("📝 Generated Code:")
+            st.code(response, language="python")  # Display code properly
     else:
-        st.warning("Please enter a query!")
-
+        st.warning("⚠️ Please enter a query!")
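For reference, the full app.py after this commit can be pieced together from the hunks above. The sketch below does that, with one assumption flagged inline: the diff never shows how hf_token is obtained, so reading it from an HF_TOKEN environment variable is a guess, not the repo's actual code.

import os

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: the token-loading code sits outside the hunks shown above;
# pulling it from the environment is a common pattern, not necessarily
# what this app.py actually does.
hf_token = os.environ.get("HF_TOKEN")

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model and tokenizer with token authentication
MODEL_NAME = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=hf_token,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,  # half precision only on GPU
    device_map="auto",  # requires the accelerate package
)

# Streamlit UI
st.title("Gemma-2B Code Assistant")
user_input = st.text_area("Enter your coding query:")

if st.button("Generate Code"):
    if user_input:
        with st.spinner("⏳ Generating response... Please wait!"):
            inputs = tokenizer(user_input, return_tensors="pt").to(device)
            output = model.generate(**inputs, max_new_tokens=100)
            response = tokenizer.decode(output[0], skip_special_tokens=True)

            st.subheader("📝 Generated Code:")
            st.code(response, language="python")  # render the answer as a code block
    else:
        st.warning("⚠️ Please enter a query!")

To try it locally: install streamlit, torch, transformers, and accelerate, set the token in the environment, then launch with streamlit run app.py. Note that google/gemma-2b-it is a gated model on the Hub, so the token must belong to an account that has accepted the Gemma license.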
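One possible follow-up, offered as a suggestion rather than anything in this commit: gemma-2b-it is an instruction-tuned chat model, and transformers provides tokenizer.apply_chat_template to format prompts the way the model was trained. A variant of the generation step using it might look like this:

messages = [{"role": "user", "content": user_input}]
# apply_chat_template tokenizes by default and returns input ids as a tensor
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(device)
output = model.generate(input_ids, max_new_tokens=100)
# slice off the prompt tokens so only the newly generated text is decoded
response = tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)

Besides matching the model's expected prompt format, this trims the prompt from the decoded output, so the app would show only the generated code instead of echoing the query.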