# app.py for a Hugging Face Space
# Make sure to add 'gradio', 'transformers', 'torch' (or 'tensorflow'/'flax'),
# and 'huggingface_hub' to your requirements.txt file in the Hugging Face Space repository.
#
# The model is gated on the Hub, so a Hugging Face token may be required:
# accept the model terms once, then set the token as the HF_TOKEN secret of the Space
# (see the login step below).
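# A minimal requirements.txt for this Space might look like the following
# (an illustrative sketch, not pinned by this script; 'accelerate' is assumed
# here because device_map="auto" typically requires it):
#
#   gradio
#   transformers
#   torch
#   accelerate
#   huggingface_hub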

import gradio as gr
import torch # PyTorch backend (this script uses torch.cuda and torch.no_grad directly)
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download, login # Hub file download and authentication helpers
import json # Import json library
import os # Import os library for path joining
# --- Hugging Face login ---
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    print("Warning: HF_TOKEN is not set; downloads may fail for gated models.")
# --- Configuration ---
MODEL_NAME = "google/txgemma-2b-predict"
PROMPT_FILENAME = "tdc_prompts.json"
MODEL_CACHE = "model_cache" # Optional: define a cache directory
MAX_EXAMPLES = 100 # Limit the number of examples loaded from the JSON
EXAMPLE_SMILES = "C1=CC=CC=C1" # Default SMILES for examples (Benzene)

# --- Load Model, Tokenizer, and Prompts ---
print(f"Loading model: {MODEL_NAME}...")
tdc_prompts_data = None # Initialize as None
examples_list = [] # Initialize empty list for examples
try:
    # Check if GPU is available and use it, otherwise use CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=MODEL_CACHE)
    print("Tokenizer loaded.")

    # Load the model
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=MODEL_CACHE,
        device_map="auto" # Automatically distribute model across available devices (GPU/CPU)
    )
    print("Model loaded.")

    # Download and load the prompts JSON file
    print(f"Downloading {PROMPT_FILENAME}...")
    prompts_file_path = hf_hub_download(
        repo_id=MODEL_NAME,
        filename=PROMPT_FILENAME,
        cache_dir=MODEL_CACHE,
        # force_download=True, # Uncomment to force redownload if needed
    )
    print(f"{PROMPT_FILENAME} downloaded to: {prompts_file_path}")

    # Load the JSON data
    with open(prompts_file_path, 'r') as f:
        tdc_prompts_data = json.load(f)
    print(f"Loaded prompts data from {PROMPT_FILENAME}.")

    # --- Prepare examples for Gradio ---
    # Updated logic: Parse the dictionary format from tdc_prompts.json
    # The JSON is expected to be a dictionary where values are prompt templates.
    if isinstance(tdc_prompts_data, dict):
        print(f"Processing {len(tdc_prompts_data)} prompts from dictionary...")
        count = 0
        for prompt_template in tdc_prompts_data.values():
            if count >= MAX_EXAMPLES:
                break
            if isinstance(prompt_template, str):
                # Replace the placeholder with the example SMILES string
                example_prompt = prompt_template.replace("{Drug SMILES}", EXAMPLE_SMILES)
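                # For example, a template containing "...{Drug SMILES}..." becomes
                # "...C1=CC=CC=C1..." (illustrative wording; the actual TDC prompt text varies)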
                # Add to examples list with default parameters
                examples_list.append([example_prompt, 100, 0.7]) # Default max_tokens=100, temp=0.7
                count += 1
            else:
                print(f"Warning: Skipping non-string value in prompts dictionary: {prompt_template}")
        print(f"Prepared {len(examples_list)} examples for Gradio.")

    else:
        print(f"Warning: Expected {PROMPT_FILENAME} to contain a dictionary, but found {type(tdc_prompts_data)}. Cannot load examples.")
        # examples_list remains empty


except Exception as e:
    print(f"Error loading model, tokenizer, or prompts: {e}")
    # Ensure examples_list is empty on error during setup
    examples_list = []
    raise gr.Error(f"Failed during setup. Check logs for details. Error: {e}")


# --- Prediction Function ---
def predict(prompt, max_new_tokens=100, temperature=0.7):
    """
    Generates text based on the input prompt using the loaded model.

    Args:
        prompt (str): The input text prompt.
        max_new_tokens (int): The maximum number of new tokens to generate.
        temperature (float): Controls the randomness of the generation. Lower is more deterministic.

    Returns:
        str: The generated text.
    """
    print(f"Received prompt: {prompt}")
    print(f"Generation parameters: max_new_tokens={max_new_tokens}, temperature={temperature}")

    try:
        # Prepare the input for the model
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device) # Move inputs to the model's device

        # Generate text
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=int(max_new_tokens), # Ensure it's an integer
                temperature=float(temperature),   # Ensure it's a float
                do_sample=float(temperature) > 0, # Only sample if temperature > 0
                pad_token_id=tokenizer.eos_token_id # Set pad token id
            )

        # Decode the generated tokens
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Generated text (raw): {generated_text}")

        # Remove the prompt from the beginning of the generated text
        if generated_text.startswith(prompt):
            prompt_length = len(prompt)
            result_text = generated_text[prompt_length:].lstrip()
        else:
            # Handle cases where the model might slightly alter the prompt start.
            # This is a basic check; more robust checks might be needed.
            common_prefix = os.path.commonprefix([prompt, generated_text])
            # Check if a significant portion of the prompt is at the start,
            # using a threshold relative to the prompt length (80%).
            if len(prompt) > 0 and len(common_prefix) / len(prompt) > 0.8:
                result_text = generated_text[len(common_prefix):].lstrip()
            else:
                result_text = generated_text # Assume the prompt is not included or was significantly altered

        print(f"Generated text (processed): {result_text}")
        return result_text

    except Exception as e:
        print(f"Error during prediction: {e}")
        return f"An error occurred during generation: {e}"

# --- Gradio Interface ---
print("Creating Gradio interface...")
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
        # 🤖 TXGemma-2B-Predict Text Generation

        Enter a prompt below or select an example, and the model ({MODEL_NAME}) will generate text based on it.
        Adjust the parameters for different results. Examples loaded from `{PROMPT_FILENAME}`.
        Example prompts use the SMILES string `{EXAMPLE_SMILES}` (Benzene) as a placeholder.
        """
    )
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Your Prompt",
                placeholder="Enter your text prompt here, potentially including a specific Drug SMILES string...",
                lines=5
            )
            with gr.Row():
                 max_tokens_slider = gr.Slider(
                    minimum=10,
                    maximum=500, # Adjust max limit if needed
                    value=100,
                    step=10,
                    label="Max New Tokens",
                    info="Maximum number of tokens to generate after the prompt."
                 )
                 temperature_slider = gr.Slider(
                    minimum=0.0, # Allow deterministic generation
                    maximum=1.5,
                    value=0.7,
                    step=0.05, # Finer control for temperature
                    label="Temperature",
                    info="Controls randomness (0=deterministic, >0=random)."
                 )
            submit_button = gr.Button("Generate Text", variant="primary")
        with gr.Column(scale=3):
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                interactive=False # Output is not editable by user
            )

    # --- Connect Components ---
    submit_button.click(
        fn=predict,
        inputs=[prompt_input, max_tokens_slider, temperature_slider],
        outputs=output_text,
        api_name="predict" # Name for API endpoint if needed
    )

    # Use the loaded examples if available
    if examples_list:
        gr.Examples(
            examples=examples_list,
            # Ensure inputs match the order expected by the 'predict' function and the structure of examples_list
            inputs=[prompt_input, max_tokens_slider, temperature_slider],
            outputs=output_text,
            fn=predict, # The function to run when an example is clicked
            cache_examples=False # Caching might be slow/problematic for LLMs
        )
    else:
        gr.Markdown("_(Could not load examples from JSON file or file format was incorrect.)_")


# --- Launch the App ---
print("Launching Gradio app...")
# queue() enables handling multiple users concurrently
# Set share=True if you need a public link, otherwise False or omit
demo.queue().launch(debug=True) # Set debug=False for production