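"""Text generation tool for smolagents.

Wraps several Hugging Face text-generation models behind a single Tool,
running them either through local transformers pipelines or (for openchat)
through the hosted Inference API.
"""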
import os

import requests
from smolagents import Tool
from transformers import pipeline

class TextGenerationTool(Tool):
    name = "text_generator"
    description = "This is a tool for text generation. It takes a prompt as input and returns the generated text."
    inputs = {
        "text": {
            "type": "string",
            "description": "The prompt for text generation",
        }
    }
    output_type = "string"

    # Available text generation models, keyed by a short alias
    models = {
        "orca": "microsoft/Orca-2-13b",
        "gpt2-dolly": "lgaalves/gpt2-dolly",
        "gpt2": "gpt2",
        "bloom": "bigscience/bloom-560m",
        "openchat": "openchat/openchat_3.5",
    }
    def __init__(self, default_model="gpt2", use_api=False):
        """Initialize with a default model and API preference."""
        super().__init__()
        self.default_model = default_model
        self.use_api = use_api
        self._pipelines = {}  # Cache of loaded pipelines, keyed by model name

        # Check for an API token
        self.token = os.environ.get("HF_token")
        if self.token is None and use_api:
            print("Warning: HF_token environment variable not set. API calls will fail.")
    def forward(self, text: str):
        """Process the input prompt and generate text."""
        return self.generate_text(text)
    def generate_text(self, prompt, model_key=None, max_length=500, temperature=0.7):
        """Generate text from the prompt using the specified or default model."""
        # Determine which model to use; unknown keys fall back to the default
        model_key = model_key or self.default_model
        model_name = self.models.get(model_key, self.models[self.default_model])

        # Only the openchat model is routed through the hosted API;
        # everything else runs through a local pipeline
        if self.use_api and model_key == "openchat":
            return self._generate_via_api(prompt, model_name)
        return self._generate_via_pipeline(prompt, model_name, max_length, temperature)
    def _generate_via_pipeline(self, prompt, model_name, max_length, temperature):
        """Generate text using a local pipeline."""
        # Get or create (and cache) the pipeline for this model
        if model_name not in self._pipelines:
            self._pipelines[model_name] = pipeline(
                "text-generation",
                model=model_name,
                token=self.token,
            )
        generator = self._pipelines[model_name]

        # Generate text; do_sample=True is required for temperature to take effect
        result = generator(
            prompt,
            max_length=max_length,
            num_return_sequences=1,
            do_sample=True,
            temperature=temperature,
        )

        # Extract and return the generated text
        if isinstance(result, list) and len(result) > 0:
            if isinstance(result[0], dict) and "generated_text" in result[0]:
                return result[0]["generated_text"]
            return str(result[0])
        return str(result)
    def _generate_via_api(self, prompt, model_name):
        """Generate text by calling the Hugging Face Inference API."""
        if not self.token:
            return "Error: HF_token not set. Cannot use API."

        api_url = f"https://api-inference.huggingface.co/models/{model_name}"
        headers = {"Authorization": f"Bearer {self.token}"}
        payload = {"inputs": prompt}
        # Generation options could be passed via a "parameters" field here,
        # e.g. {"inputs": prompt, "parameters": {"max_new_tokens": 250}}

        try:
            response = requests.post(api_url, headers=headers, json=payload, timeout=60)
            response.raise_for_status()  # Raise an exception for HTTP errors
            result = response.json()

            # Handle the different response formats the API can return
            if isinstance(result, list) and len(result) > 0:
                if isinstance(result[0], dict) and "generated_text" in result[0]:
                    return result[0]["generated_text"]
            elif isinstance(result, dict) and "generated_text" in result:
                return result["generated_text"]

            # Fall back to returning the raw response
            return str(result)
        except Exception as e:
            return f"Error generating text: {e}"
# For standalone testing
if __name__ == "__main__":
    # Create an instance of the TextGenerationTool
    text_generator = TextGenerationTool(default_model="gpt2")

    # Test with a simple prompt
    test_prompt = "Once upon a time in a digital world,"
    result = text_generator(test_prompt)
    print(f"Prompt: {test_prompt}")
    print(f"Generated text:\n{result}")