# app.py for Hugging Face Space
# Make sure to add 'gradio', 'transformers', 'torch' (or 'tensorflow'/'flax'),
# and 'huggingface_hub' to your requirements.txt file in the Hugging Face Space repository.
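# A minimal requirements.txt for this app might look like the following
# (unpinned; 'accelerate' is also needed because the model is loaded with
# device_map="auto" below; add version pins as your Space requires):
#   gradio
#   transformers
#   torch
#   huggingface_hub
#   accelerate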
import gradio as gr
import torch # Or tensorflow/flax depending on backend
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download, login # Hub file download and authentication
import json # For parsing the prompts JSON file
import os # For environment variables and prompt-prefix handling
# --- Hugging Face Login ---
# TxGemma is a gated model: authenticate when a token is available (set
# HF_TOKEN as a secret in the Space settings once access has been granted).
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
# --- Configuration ---
MODEL_NAME = "google/txgemma-2b-predict"
PROMPT_FILENAME = "tdc_prompts.json"
MODEL_CACHE = "model_cache" # Optional: define a cache directory
MAX_EXAMPLES = 100 # Limit the number of examples loaded from the JSON
EXAMPLE_SMILES = "C1=CC=CC=C1" # Default SMILES for examples (Benzene)
# --- Load Model, Tokenizer, and Prompts ---
print(f"Loading model: {MODEL_NAME}...")
tdc_prompts_data = None # Initialize as None
examples_list = [] # Initialize empty list for examples
try:
# Check if GPU is available and use it, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=MODEL_CACHE)
print("Tokenizer loaded.")
# Load the model
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
cache_dir=MODEL_CACHE,
device_map="auto" # Automatically distribute model across available devices (GPU/CPU)
)
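    # Note: device_map="auto" (which requires the 'accelerate' package) lets
    # transformers place the weights itself, so the `device` variable above is
    # informational only; inputs are moved to model.device inside predict().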
print("Model loaded.")
# Download and load the prompts JSON file
print(f"Downloading {PROMPT_FILENAME}...")
prompts_file_path = hf_hub_download(
repo_id=MODEL_NAME,
filename=PROMPT_FILENAME,
cache_dir=MODEL_CACHE,
# force_download=True, # Uncomment to force redownload if needed
)
print(f"{PROMPT_FILENAME} downloaded to: {prompts_file_path}")
# Load the JSON data
with open(prompts_file_path, 'r') as f:
tdc_prompts_data = json.load(f)
print(f"Loaded prompts data from {PROMPT_FILENAME}.")
# --- Prepare examples for Gradio ---
# Updated logic: Parse the dictionary format from tdc_prompts.json
# The JSON is expected to be a dictionary where values are prompt templates.
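    # Illustrative shape only (the task name and wording here are hypothetical;
    # the real templates ship in the model repo's tdc_prompts.json):
    #   {"some_tdc_task": "Question: ... {Drug SMILES} ... Answer:", ...}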
if isinstance(tdc_prompts_data, dict):
print(f"Processing {len(tdc_prompts_data)} prompts from dictionary...")
count = 0
for prompt_template in tdc_prompts_data.values():
if count >= MAX_EXAMPLES:
break
if isinstance(prompt_template, str):
# Replace the placeholder with the example SMILES string
example_prompt = prompt_template.replace("{Drug SMILES}", EXAMPLE_SMILES)
# Add to examples list with default parameters
examples_list.append([example_prompt, 100, 0.7]) # Default max_tokens=100, temp=0.7
count += 1
else:
print(f"Warning: Skipping non-string value in prompts dictionary: {prompt_template}")
print(f"Prepared {len(examples_list)} examples for Gradio.")
else:
print(f"Warning: Expected {PROMPT_FILENAME} to contain a dictionary, but found {type(tdc_prompts_data)}. Cannot load examples.")
# examples_list remains empty
except Exception as e:
print(f"Error loading model, tokenizer, or prompts: {e}")
# Ensure examples_list is empty on error during setup
examples_list = []
raise gr.Error(f"Failed during setup. Check logs for details. Error: {e}")
# --- Prediction Function ---
def predict(prompt, max_new_tokens=100, temperature=0.7):
"""
Generates text based on the input prompt using the loaded model.
Args:
prompt (str): The input text prompt.
max_new_tokens (int): The maximum number of new tokens to generate.
temperature (float): Controls the randomness of the generation. Lower is more deterministic.
Returns:
str: The generated text.
"""
print(f"Received prompt: {prompt}")
print(f"Generation parameters: max_new_tokens={max_new_tokens}, temperature={temperature}")
try:
# Prepare the input for the model
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) # Move inputs to the model's device
# Generate text
        with torch.no_grad():
            # Sample only when temperature > 0; otherwise decode greedily and
            # omit temperature entirely (it has no effect without sampling and
            # transformers warns if it is passed alongside do_sample=False).
            do_sample = float(temperature) > 0
            gen_kwargs = {
                "max_new_tokens": int(max_new_tokens),  # Ensure it's an integer
                "do_sample": do_sample,
                "pad_token_id": tokenizer.eos_token_id,  # Reuse EOS as the pad token
            }
            if do_sample:
                gen_kwargs["temperature"] = float(temperature)  # Ensure it's a float
            outputs = model.generate(**inputs, **gen_kwargs)
# Decode the generated tokens
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Generated text (raw): {generated_text}")
# Remove the prompt from the beginning of the generated text
if generated_text.startswith(prompt):
prompt_length = len(prompt)
result_text = generated_text[prompt_length:].lstrip()
else:
# Handle cases where the model might slightly alter the prompt start
# This is a basic check; more robust checks might be needed
common_prefix = os.path.commonprefix([prompt, generated_text])
# Check if a significant portion of the prompt is at the start
# Use a threshold relative to prompt length, e.g., 80%
if len(prompt) > 0 and len(common_prefix) / len(prompt) > 0.8:
result_text = generated_text[len(common_prefix):].lstrip()
else:
result_text = generated_text # Assume prompt is not included or significantly altered
print(f"Generated text (processed): {result_text}")
return result_text
except Exception as e:
print(f"Error during prediction: {e}")
return f"An error occurred during generation: {e}"
# --- Gradio Interface ---
print("Creating Gradio interface...")
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(
f"""
# 🤖 TXGemma-2B-Predict Text Generation
Enter a prompt below or select an example, and the model ({MODEL_NAME}) will generate text based on it.
Adjust the parameters for different results. Examples loaded from `{PROMPT_FILENAME}`.
Example prompts use the SMILES string `{EXAMPLE_SMILES}` (Benzene) as a placeholder.
"""
)
with gr.Row():
with gr.Column(scale=2):
prompt_input = gr.Textbox(
label="Your Prompt",
placeholder="Enter your text prompt here, potentially including a specific Drug SMILES string...",
lines=5
)
with gr.Row():
max_tokens_slider = gr.Slider(
minimum=10,
maximum=500, # Adjust max limit if needed
value=100,
step=10,
label="Max New Tokens",
info="Maximum number of tokens to generate after the prompt."
)
temperature_slider = gr.Slider(
minimum=0.0, # Allow deterministic generation
maximum=1.5,
value=0.7,
step=0.05, # Finer control for temperature
label="Temperature",
info="Controls randomness (0=deterministic, >0=random)."
)
submit_button = gr.Button("Generate Text", variant="primary")
with gr.Column(scale=3):
output_text = gr.Textbox(
label="Generated Text",
lines=10,
interactive=False # Output is not editable by user
)
# --- Connect Components ---
submit_button.click(
fn=predict,
inputs=[prompt_input, max_tokens_slider, temperature_slider],
outputs=output_text,
api_name="predict" # Name for API endpoint if needed
)
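    # With api_name="predict", the same function is exposed over the Gradio
    # API and can be called programmatically, e.g. (space id hypothetical):
    #   from gradio_client import Client
    #   Client("your-username/your-space").predict(prompt, 100, 0.7, api_name="/predict")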
# Use the loaded examples if available
if examples_list:
gr.Examples(
examples=examples_list,
# Ensure inputs match the order expected by the 'predict' function and the structure of examples_list
inputs=[prompt_input, max_tokens_slider, temperature_slider],
outputs=output_text,
fn=predict, # The function to run when an example is clicked
cache_examples=False # Caching might be slow/problematic for LLMs
)
else:
gr.Markdown("_(Could not load examples from JSON file or file format was incorrect.)_")
# --- Launch the App ---
print("Launching Gradio app...")
# queue() enables handling multiple users concurrently
# Set share=True if you need a public link, otherwise False or omit
demo.queue().launch(debug=True) # Set debug=False for production
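# Note: a hosted Space serves the app automatically, so share=True is only
# relevant for creating a temporary public link when running locally.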