Chris4K committed on
Commit
1a37ce8
·
verified ·
1 Parent(s): 28afd39

Update text_generator.py

Browse files
Files changed (1) hide show
  1. text_generator.py +46 -32
text_generator.py CHANGED
@@ -18,14 +18,14 @@ class TextGenerationTool(Tool):
18
 
19
  # Available text generation models
20
  models = {
21
- "orca": "microsoft/Orca-2-13b",
22
- "gpt2-dolly": "lgaalves/gpt2-dolly",
23
- "gpt2": "gpt2",
24
- "bloom": "bigscience/bloom-560m",
25
- "openchat": "openchat/openchat_3.5"
26
  }
27
 
28
- def __init__(self, default_model="gpt2", use_api=False):
29
  """Initialize with a default model and API preference."""
30
  super().__init__()
31
  self.default_model = default_model
@@ -33,9 +33,9 @@ class TextGenerationTool(Tool):
33
  self._pipelines = {}
34
 
35
  # Check for API token
36
- self.token = os.environ.get('HF_token')
37
- if self.token is None and use_api:
38
- print("Warning: HF_token environment variable not set. API calls will fail.")
39
 
40
  def forward(self, text: str):
41
  """Process the input prompt and generate text."""
@@ -56,31 +56,45 @@ class TextGenerationTool(Tool):
56
 
57
  def _generate_via_pipeline(self, prompt, model_name, max_length, temperature):
58
  """Generate text using a local pipeline."""
59
- # Get or create the pipeline
60
- if model_name not in self._pipelines:
61
- self._pipelines[model_name] = pipeline(
62
- "text-generation",
63
- model=model_name,
64
- token=self.token
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  )
66
 
67
- generator = self._pipelines[model_name]
68
-
69
- # Generate text
70
- result = generator(
71
- prompt,
72
- max_length=max_length,
73
- num_return_sequences=1,
74
- temperature=temperature
75
- )
76
-
77
- # Extract and return the generated text
78
- if isinstance(result, list) and len(result) > 0:
79
- if isinstance(result[0], dict) and 'generated_text' in result[0]:
80
- return result[0]['generated_text']
81
- return result[0]
82
-
83
- return str(result)
84
 
85
  def _generate_via_api(self, prompt, model_name):
86
  """Generate text by calling the Hugging Face API."""
 
18
 
19
  # Available text generation models
20
  models = {
21
+ "distilgpt2": "distilgpt2", # Smaller model, may work without auth
22
+ "gpt2-small": "sshleifer/tiny-gpt2", # Tiny model for testing
23
+ "opt-125m": "facebook/opt-125m", # Small, open model
24
+ "bloom-560m": "bigscience/bloom-560m",
25
+ "gpt2": "gpt2" # Original GPT-2
26
  }
27
 
28
+ def __init__(self, default_model="distilgpt2", use_api=False):
29
  """Initialize with a default model and API preference."""
30
  super().__init__()
31
  self.default_model = default_model
 
33
  self._pipelines = {}
34
 
35
  # Check for API token
36
+ self.token = os.environ.get('HF_TOKEN') or os.environ.get('HF_token')
37
+ if self.token is None:
38
+ print("Warning: No Hugging Face token found. Set HF_TOKEN environment variable for authenticated requests.")
39
 
40
  def forward(self, text: str):
41
  """Process the input prompt and generate text."""
 
56
 
57
  def _generate_via_pipeline(self, prompt, model_name, max_length, temperature):
58
  """Generate text using a local pipeline."""
59
+ try:
60
+ # Get or create the pipeline
61
+ if model_name not in self._pipelines:
62
+ # Use token if available, otherwise try without it
63
+ try:
64
+ kwargs = {"token": self.token} if self.token else {}
65
+ self._pipelines[model_name] = pipeline(
66
+ "text-generation",
67
+ model=model_name,
68
+ **kwargs
69
+ )
70
+ except Exception as e:
71
+ print(f"Error loading model {model_name}: {str(e)}")
72
+ # Fall back to tiny-gpt2 if available
73
+ if model_name != "sshleifer/tiny-gpt2":
74
+ print("Falling back to tiny-gpt2 model...")
75
+ return self._generate_via_pipeline(prompt, "sshleifer/tiny-gpt2", max_length, temperature)
76
+ else:
77
+ raise e
78
+
79
+ generator = self._pipelines[model_name]
80
+
81
+ # Generate text
82
+ result = generator(
83
+ prompt,
84
+ max_length=max_length,
85
+ num_return_sequences=1,
86
+ temperature=temperature
87
  )
88
 
89
+ # Extract and return the generated text
90
+ if isinstance(result, list) and len(result) > 0:
91
+ if isinstance(result[0], dict) and 'generated_text' in result[0]:
92
+ return result[0]['generated_text']
93
+ return result[0]
94
+
95
+ return str(result)
96
+ except Exception as e:
97
+ return f"Error generating text: {str(e)}\n\nPlease try a different model or prompt."
 
 
 
 
 
 
 
 
98
 
99
  def _generate_via_api(self, prompt, model_name):
100
  """Generate text by calling the Hugging Face API."""