File size: 5,562 Bytes
9f212f1
2c2fb5c
 
45b4dc5
 
9f212f1
 
 
2c2fb5c
 
45b4dc5
 
 
 
 
 
 
2c2fb5c
 
 
1a37ce8
 
 
 
 
2c2fb5c
 
1a37ce8
2c2fb5c
45b4dc5
2c2fb5c
 
 
 
 
1a37ce8
 
 
2c2fb5c
45b4dc5
2c2fb5c
45b4dc5
2c2fb5c
 
 
 
 
 
9f212f1
2c2fb5c
 
 
 
 
 
 
 
 
1a37ce8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c2fb5c
 
1a37ce8
 
 
 
 
 
 
 
 
2c2fb5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import requests
import gradio as gr
from transformers import pipeline
from smolagents import Tool

class TextGenerationTool(Tool):
    """smolagents tool that generates text from a prompt.

    Generation runs through a locally cached ``transformers`` text-generation
    pipeline by default, or through the Hugging Face Inference API when the
    tool is constructed with ``use_api=True``.
    """

    name = "text_generator"
    description = "This is a tool for text generation. It takes a prompt as input and returns the generated text."
    
    inputs = {
        "text": {
            "type": "string",
            "description": "The prompt for text generation"
        }
    }
    output_type = "string"
    
    # Available text generation models: short key -> Hugging Face model id.
    models = {
        "distilgpt2": "distilgpt2",  # Smaller model, may work without auth
        "gpt2-small": "sshleifer/tiny-gpt2",  # Tiny model for testing (despite the key, not GPT-2 small)
        "opt-125m": "facebook/opt-125m",  # Small, open model
        "bloom-560m": "bigscience/bloom-560m",
        "gpt2": "gpt2"  # Original GPT-2
    }
    
    def __init__(self, default_model="distilgpt2", use_api=False):
        """Initialize with a default model and API preference.

        Args:
            default_model: Key into ``models`` used when no model is requested.
                A value that is not a registered key is used as a raw
                Hugging Face model id.
            use_api: If True, generate through the Hugging Face Inference API
                instead of a local pipeline.
        """
        super().__init__()
        self.default_model = default_model
        self.use_api = use_api
        # Model id -> loaded pipeline, populated lazily on first use.
        self._pipelines = {}
        
        # Check for API token (both spellings of the variable are accepted).
        self.token = os.environ.get('HF_TOKEN') or os.environ.get('HF_token')
        # An empty token is treated as missing everywhere else (`if self.token`),
        # so warn about it too.
        if not self.token:
            print("Warning: No Hugging Face token found. Set HF_TOKEN environment variable for authenticated requests.")
    
    def forward(self, text: str):
        """Process the input prompt and generate text (smolagents entry point)."""
        return self.generate_text(text)
    
    def generate_text(self, prompt, model_key=None, max_length=500, temperature=0.7):
        """Generate text based on the prompt using the specified or default model.

        Args:
            prompt: Text to continue.
            model_key: Key into ``models``. Unknown keys fall back to the
                default model. Defaults to ``self.default_model``.
            max_length: Maximum total length (prompt included) passed to the
                local pipeline. Not used by the API path.
            temperature: Sampling temperature for the local pipeline. ``None``
                or a value ``<= 0`` selects greedy decoding. Not used by the
                API path.

        Returns:
            The generated text, or an error message string on failure.
        """
        model_key = model_key or self.default_model
        # Resolve the default lazily: previously `self.models[self.default_model]`
        # was evaluated on every call and raised KeyError for an unregistered
        # default even when `model_key` itself was valid.
        default_name = self.models.get(self.default_model, self.default_model)
        model_name = self.models.get(model_key, default_name)
        
        # Honour the API preference. The old check also required a model key
        # ("openchat") that is not registered, which made this path unreachable.
        if self.use_api:
            return self._generate_via_api(prompt, model_name)
        
        # Otherwise use local pipeline
        return self._generate_via_pipeline(prompt, model_name, max_length, temperature)
    
    def _generate_via_pipeline(self, prompt, model_name, max_length, temperature):
        """Generate text using a local ``transformers`` pipeline.

        Pipelines are created on first use and cached per model id. If a model
        fails to load, generation is retried once with ``sshleifer/tiny-gpt2``.

        Returns:
            The generated text, or an error message string on failure.
        """
        fallback_model = "sshleifer/tiny-gpt2"
        try:
            # Get or create the pipeline
            if model_name not in self._pipelines:
                # Use token if available, otherwise try without it
                kwargs = {"token": self.token} if self.token else {}
                try:
                    self._pipelines[model_name] = pipeline(
                        "text-generation",
                        model=model_name,
                        **kwargs
                    )
                except Exception as e:
                    print(f"Error loading model {model_name}: {str(e)}")
                    if model_name == fallback_model:
                        # Nothing left to fall back to; the outer handler
                        # turns this into an error string.
                        raise
                    # Fall back to the tiny-gpt2 model
                    print("Falling back to tiny-gpt2 model...")
                    return self._generate_via_pipeline(prompt, fallback_model, max_length, temperature)
            
            generator = self._pipelines[model_name]
            
            # `temperature` only takes effect when sampling is enabled; without
            # `do_sample=True` transformers decodes greedily and ignores it.
            gen_kwargs = {"max_length": max_length, "num_return_sequences": 1}
            if temperature is not None and temperature > 0:
                gen_kwargs.update(do_sample=True, temperature=temperature)
            result = generator(prompt, **gen_kwargs)
        
            # Extract and return the generated text (always as a string,
            # matching `output_type`).
            if isinstance(result, list) and result:
                first = result[0]
                if isinstance(first, dict) and 'generated_text' in first:
                    return first['generated_text']
                return str(first)
            
            return str(result)
        except Exception as e:
            return f"Error generating text: {str(e)}\n\nPlease try a different model or prompt."
    
    def _generate_via_api(self, prompt, model_name, timeout=60):
        """Generate text by calling the Hugging Face Inference API.

        Args:
            prompt: Text sent as the request's ``inputs`` field.
            model_name: Hugging Face model id appended to the API URL.
            timeout: Seconds to wait for the HTTP response before giving up.

        Returns:
            The generated text, the raw response as a string if its shape is
            not recognised, or an error message string on failure.
        """
        if not self.token:
            return "Error: HF_token not set. Cannot use API."
        
        api_url = f"https://api-inference.huggingface.co/models/{model_name}"
        headers = {"Authorization": f"Bearer {self.token}"}
        payload = {"inputs": prompt}
        
        try:
            # Without a timeout, requests can block indefinitely on a stalled connection.
            response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
            response.raise_for_status()  # Raise exception for HTTP errors
            result = response.json()
        except Exception as e:
            return f"Error generating text: {str(e)}"
        
        # Handle different response formats
        if isinstance(result, list) and result:
            first = result[0]
            if isinstance(first, dict) and 'generated_text' in first:
                return first['generated_text']
        elif isinstance(result, dict) and 'generated_text' in result:
            return result['generated_text']
        
        # Fall back to returning the raw response
        return str(result)

# For standalone testing
if __name__ == "__main__":
    # Manual smoke test: build the tool with GPT-2 as its default model,
    # push a single prompt through it and show what comes back.
    demo_tool = TextGenerationTool(default_model="gpt2")

    sample_prompt = "Once upon a time in a digital world,"
    generated = demo_tool(sample_prompt)

    print(f"Prompt: {sample_prompt}")
    print(f"Generated text:\n{generated}")