import os
import requests
import gradio as gr
from transformers import pipeline
from smolagents import Tool
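
# Dependencies implied by the imports above (a sketch, not a pinned list):
#     pip install transformers requests gradio smolagents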

class TextGenerationTool(Tool):
    name = "text_generator"
    description = "This is a tool for text generation. It takes a prompt as input and returns the generated text."
    inputs = {
        "text": {
            "type": "string",
            "description": "The prompt for text generation",
        }
    }
    output_type = "string"

    # Available text generation models, keyed by a short alias
    models = {
        "orca": "microsoft/Orca-2-13b",
        "gpt2-dolly": "lgaalves/gpt2-dolly",
        "gpt2": "gpt2",
        "bloom": "bigscience/bloom-560m",
        "openchat": "openchat/openchat_3.5",
    }

    def __init__(self, default_model="gpt2", use_api=False):
        """Initialize with a default model and API preference."""
        super().__init__()
        self.default_model = default_model
        self.use_api = use_api
        self._pipelines = {}  # Cache of local pipelines, created lazily per model

        # Check for API token
        self.token = os.environ.get("HF_token")
        if self.token is None and use_api:
            print("Warning: HF_token environment variable not set. API calls will fail.")

    def forward(self, text: str):
        """Process the input prompt and generate text."""
        return self.generate_text(text)

    def generate_text(self, prompt, model_key=None, max_length=500, temperature=0.7):
        """Generate text based on the prompt using the specified or default model."""
        # Determine which model to use, falling back to the default for unknown keys
        model_key = model_key or self.default_model
        model_name = self.models.get(model_key, self.models[self.default_model])

        # Only the openchat model is routed through the hosted Inference API
        if self.use_api and model_key == "openchat":
            return self._generate_via_api(prompt, model_name)

        # Otherwise use a local pipeline
        return self._generate_via_pipeline(prompt, model_name, max_length, temperature)

    def _generate_via_pipeline(self, prompt, model_name, max_length, temperature):
        """Generate text using a local pipeline."""
        # Get or create the pipeline (cached so model weights are loaded only once)
        if model_name not in self._pipelines:
            self._pipelines[model_name] = pipeline(
                "text-generation",
                model=model_name,
                token=self.token,
            )
        generator = self._pipelines[model_name]

        # Generate text; do_sample=True is needed for temperature to have any effect
        result = generator(
            prompt,
            max_length=max_length,
            num_return_sequences=1,
            do_sample=True,
            temperature=temperature,
        )

        # Extract and return the generated text as a string, matching output_type
        if isinstance(result, list) and len(result) > 0:
            if isinstance(result[0], dict) and "generated_text" in result[0]:
                return result[0]["generated_text"]
            return str(result[0])
        return str(result)

    def _generate_via_api(self, prompt, model_name):
        """Generate text by calling the Hugging Face Inference API."""
        if not self.token:
            return "Error: HF_token not set. Cannot use API."

        api_url = f"https://api-inference.huggingface.co/models/{model_name}"
        headers = {"Authorization": f"Bearer {self.token}"}
        payload = {"inputs": prompt}

        try:
            response = requests.post(api_url, headers=headers, json=payload, timeout=60)
            response.raise_for_status()  # Raise an exception for HTTP errors
            result = response.json()

            # Handle the two response shapes the API may return
            if isinstance(result, list) and len(result) > 0:
                if isinstance(result[0], dict) and "generated_text" in result[0]:
                    return result[0]["generated_text"]
            elif isinstance(result, dict) and "generated_text" in result:
                return result["generated_text"]

            # Fall back to returning the raw response
            return str(result)
        except Exception as e:
            return f"Error generating text: {str(e)}"

# For standalone testing
if __name__ == "__main__":
    # Create an instance of the TextGenerationTool
    text_generator = TextGenerationTool(default_model="gpt2")

    # Test with a simple prompt; Tool instances are callable, which invokes forward()
    test_prompt = "Once upon a time in a digital world,"
    result = text_generator(test_prompt)

    print(f"Prompt: {test_prompt}")
    print(f"Generated text:\n{result}")