Update app.py

app.py CHANGED
@@ -32,7 +32,6 @@
 # else:
 #     st.warning("Please enter a query!")
 
-
 import os
 import streamlit as st
 import torch
@@ -48,11 +47,14 @@ if not hf_token:
 # Check if CUDA is available
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Load
+# Load model in 8-bit mode (less RAM usage)
 MODEL_NAME = "google/gemma-2b-it"
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
 model = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME,
+    MODEL_NAME,
+    token=hf_token,
+    device_map="auto",
+    load_in_8bit=True  # 👈 This reduces RAM usage!
 )
 
 # Streamlit UI
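The key change in this hunk is loading the weights in 8-bit via bitsandbytes, which cuts the model's memory footprint to roughly half of fp16. Two caveats worth flagging: newer transformers releases deprecate passing load_in_8bit=True directly to from_pretrained in favor of a BitsAndBytesConfig object, and bitsandbytes 8-bit loading generally requires a CUDA GPU, so it will fail on a CPU-only Space even though the script computes a cpu fallback device. A minimal standalone sketch of the same load using the config-object API (assuming transformers, accelerate, and bitsandbytes are installed; the HF_TOKEN secret name is an assumption, not something shown in this diff):

# Sketch only: the same 8-bit load expressed with BitsAndBytesConfig.
# Assumes transformers + accelerate + bitsandbytes and a CUDA GPU;
# the HF_TOKEN secret name is assumed, not taken from the diff.
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_NAME = "google/gemma-2b-it"
hf_token = os.environ.get("HF_TOKEN")

quant_config = BitsAndBytesConfig(load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=hf_token,
    device_map="auto",                # let accelerate place weights on the GPU
    quantization_config=quant_config, # 8-bit weights instead of fp16/fp32
)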
@@ -63,10 +65,10 @@ if st.button("Generate Code"):
     if user_input:
         with st.spinner("⏳ Generating response... Please wait!"):
             inputs = tokenizer(user_input, return_tensors="pt").to(device)
-            output = model.generate(**inputs, max_new_tokens=
+            output = model.generate(**inputs, max_new_tokens=50)
             response = tokenizer.decode(output[0], skip_special_tokens=True)
 
             st.subheader("📝 Generated Code:")
-            st.code(response, language="python")
+            st.code(response, language="python")
     else:
         st.warning("⚠️ Please enter a query!")
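Here max_new_tokens=50 caps each response at 50 generated tokens, which keeps the Space responsive but will truncate longer code snippets. One subtlety: with device_map="auto" the weights may not sit on the hand-computed device variable, so sending inputs to model.device is the safer pattern. A hedged sketch of the generation step under that assumption (the example prompt is invented):

# Sketch of the generation step; model.device reflects where
# device_map="auto" actually placed the weights.
prompt = "Write a Python function that reverses a string."  # hypothetical query
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=50)         # 50-token cap, as in the diff
response = tokenizer.decode(output[0], skip_special_tokens=True)
print(response)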