Chris4K commited on
Commit
45b4dc5
·
verified ·
1 Parent(s): 7e8ebae

Update text_generator.py

Browse files
Files changed (1) hide show
  1. text_generator.py +13 -6
text_generator.py CHANGED
@@ -1,14 +1,20 @@
1
  import os
2
  import requests
3
  import gradio as gr
4
- from transformers import pipeline, Tool
 
5
 
6
  class TextGenerationTool(Tool):
7
  name = "text_generator"
8
  description = "This is a tool for text generation. It takes a prompt as input and returns the generated text."
9
 
10
- inputs = {"text": {"type": "text", "description": "The prompt for text generation"}}
11
- outputs = {"text": {"type": "text", "description": "The generated text"}}
 
 
 
 
 
12
 
13
  # Available text generation models
14
  models = {
@@ -19,8 +25,9 @@ class TextGenerationTool(Tool):
19
  "openchat": "openchat/openchat_3.5"
20
  }
21
 
22
- def __init__(self, default_model="orca", use_api=False):
23
  """Initialize with a default model and API preference."""
 
24
  self.default_model = default_model
25
  self.use_api = use_api
26
  self._pipelines = {}
@@ -30,9 +37,9 @@ class TextGenerationTool(Tool):
30
  if self.token is None and use_api:
31
  print("Warning: HF_token environment variable not set. API calls will fail.")
32
 
33
- def __call__(self, prompt: str):
34
  """Process the input prompt and generate text."""
35
- return self.generate_text(prompt)
36
 
37
  def generate_text(self, prompt, model_key=None, max_length=500, temperature=0.7):
38
  """Generate text based on the prompt using the specified or default model."""
 
1
  import os
2
  import requests
3
  import gradio as gr
4
+ from transformers import pipeline
5
+ from smolagents import Tool
6
 
7
  class TextGenerationTool(Tool):
8
  name = "text_generator"
9
  description = "This is a tool for text generation. It takes a prompt as input and returns the generated text."
10
 
11
+ inputs = {
12
+ "text": {
13
+ "type": "string",
14
+ "description": "The prompt for text generation"
15
+ }
16
+ }
17
+ output_type = "string"
18
 
19
  # Available text generation models
20
  models = {
 
25
  "openchat": "openchat/openchat_3.5"
26
  }
27
 
28
+ def __init__(self, default_model="gpt2", use_api=False):
29
  """Initialize with a default model and API preference."""
30
+ super().__init__()
31
  self.default_model = default_model
32
  self.use_api = use_api
33
  self._pipelines = {}
 
37
  if self.token is None and use_api:
38
  print("Warning: HF_token environment variable not set. API calls will fail.")
39
 
40
+ def forward(self, text: str):
41
  """Process the input prompt and generate text."""
42
+ return self.generate_text(text)
43
 
44
  def generate_text(self, prompt, model_key=None, max_length=500, temperature=0.7):
45
  """Generate text based on the prompt using the specified or default model."""