File size: 4,518 Bytes
9f212f1
2c2fb5c
 
45b4dc5
 
9f212f1
 
 
2c2fb5c
 
45b4dc5
 
 
 
 
 
 
2c2fb5c
 
 
 
 
 
 
 
 
 
45b4dc5
2c2fb5c
45b4dc5
2c2fb5c
 
 
 
 
 
 
 
 
45b4dc5
2c2fb5c
45b4dc5
2c2fb5c
 
 
 
 
 
9f212f1
2c2fb5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f212f1
 
2c2fb5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import requests
import gradio as gr
from transformers import pipeline
from smolagents import Tool

class TextGenerationTool(Tool):
    """smolagents Tool that generates text from a prompt.

    Generation runs either through a locally cached ``transformers``
    pipeline or (for supported models, when ``use_api`` is set) through
    the Hugging Face Inference API.
    """

    name = "text_generator"
    description = "This is a tool for text generation. It takes a prompt as input and returns the generated text."

    inputs = {
        "text": {
            "type": "string",
            "description": "The prompt for text generation"
        }
    }
    output_type = "string"

    # Short aliases for the supported text-generation checkpoints.
    models = {
        "orca": "microsoft/Orca-2-13b",
        "gpt2-dolly": "lgaalves/gpt2-dolly",
        "gpt2": "gpt2",
        "bloom": "bigscience/bloom-560m",
        "openchat": "openchat/openchat_3.5"
    }

    # Seconds to wait for an Inference API response. Without an explicit
    # timeout, requests.post() can block indefinitely on a stalled server.
    API_TIMEOUT = 30

    def __init__(self, default_model="gpt2", use_api=False):
        """Initialize with a default model and API preference.

        Args:
            default_model: Key into ``models`` (or a raw Hugging Face model
                name) used when ``generate_text`` gets no ``model_key``.
            use_api: If True, route supported models through the Hugging
                Face Inference API instead of a local pipeline.
        """
        super().__init__()
        self.default_model = default_model
        self.use_api = use_api
        self._pipelines = {}  # model_name -> cached transformers pipeline

        # NOTE(review): the env var name is 'HF_token' (mixed case) in the
        # original code — kept as-is so existing deployments keep working.
        self.token = os.environ.get('HF_token')
        if self.token is None and use_api:
            print("Warning: HF_token environment variable not set. API calls will fail.")

    def forward(self, text: str):
        """smolagents entry point: generate text for the given prompt."""
        return self.generate_text(text)

    def generate_text(self, prompt, model_key=None, max_length=500, temperature=0.7):
        """Generate text based on the prompt using the specified or default model.

        Args:
            prompt: Input text to continue.
            model_key: Optional key into ``models``; unknown keys fall back
                to the default model.
            max_length: Maximum total token length for local generation.
            temperature: Sampling temperature for local generation.

        Returns:
            The generated text, or an error string on API failure.
        """
        model_key = model_key or self.default_model
        # Unknown keys fall back to the default; if the default itself is
        # not a known alias, treat it as a raw Hugging Face model name
        # (the original indexed self.models and could raise KeyError here).
        model_name = self.models.get(
            model_key,
            self.models.get(self.default_model, self.default_model),
        )

        # Only "openchat" is routed through the hosted Inference API.
        if self.use_api and model_key == "openchat":
            return self._generate_via_api(prompt, model_name)

        # Otherwise use a local pipeline.
        return self._generate_via_pipeline(prompt, model_name, max_length, temperature)

    def _generate_via_pipeline(self, prompt, model_name, max_length, temperature):
        """Generate text using a locally loaded (and cached) pipeline."""
        # Build the pipeline once per model name; reuse on later calls.
        if model_name not in self._pipelines:
            self._pipelines[model_name] = pipeline(
                "text-generation",
                model=model_name,
                token=self.token
            )

        generator = self._pipelines[model_name]

        result = generator(
            prompt,
            max_length=max_length,
            num_return_sequences=1,
            temperature=temperature
        )

        # transformers returns a list of dicts with 'generated_text';
        # degrade gracefully for any other shape.
        if isinstance(result, list) and len(result) > 0:
            if isinstance(result[0], dict) and 'generated_text' in result[0]:
                return result[0]['generated_text']
            return result[0]

        return str(result)

    def _generate_via_api(self, prompt, model_name):
        """Generate text by calling the Hugging Face Inference API.

        Returns an error string (rather than raising) on any failure, so
        agent loops keep running.
        """
        if not self.token:
            return "Error: HF_token not set. Cannot use API."

        api_url = f"https://api-inference.huggingface.co/models/{model_name}"
        headers = {"Authorization": f"Bearer {self.token}"}
        payload = {"inputs": prompt}

        try:
            # timeout added: the original call could hang forever on a
            # non-responsive endpoint.
            response = requests.post(
                api_url, headers=headers, json=payload, timeout=self.API_TIMEOUT
            )
            response.raise_for_status()  # Raise exception for HTTP errors

            result = response.json()

            # Handle both list-of-dicts and plain-dict response formats.
            if isinstance(result, list) and len(result) > 0:
                if isinstance(result[0], dict) and 'generated_text' in result[0]:
                    return result[0]['generated_text']
            elif isinstance(result, dict) and 'generated_text' in result:
                return result['generated_text']

            # Fall back to returning the raw response.
            return str(result)

        except Exception as e:
            return f"Error generating text: {str(e)}"

# For standalone testing
if __name__ == "__main__":
    # Build the tool with the small default model so a quick local run works.
    tool = TextGenerationTool(default_model="gpt2")

    # Run one generation round-trip and show the result.
    prompt = "Once upon a time in a digital world,"
    generated = tool(prompt)

    print(f"Prompt: {prompt}")
    print(f"Generated text:\n{generated}")