Spaces:
Runtime error
Runtime error
File size: 5,094 Bytes
ed4d993 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
import logging
from typing import Any, Dict, List, Optional
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
# Module-level logger named after this module (``__name__``).
# NOTE(review): not referenced anywhere in the visible part of this file.
logger = logging.getLogger(__name__)
def clean_url(url: str) -> str:
    """Return ``url`` without trailing slashes or a trailing ``/api`` segment.

    The result is meant to be used as a base onto which ``/api/v1/...`` is
    appended, so all of the following normalise to ``http://host:5000``::

        http://host:5000
        http://host:5000/
        http://host:5000/api
        http://host:5000/api/

    Args:
        url: The endpoint as supplied by the user.

    Returns:
        The endpoint with trailing ``/`` characters and one trailing ``/api``
        path segment removed.
    """
    # Strip slashes first; otherwise "…/api/" would only lose its slash and
    # the caller would end up requesting "…/api/api/v1/generate".
    base = url.rstrip("/")
    # "://api" means the *host* is literally named "api" (e.g. "http://api");
    # that is not a path segment and must be left alone.
    if base.endswith("/api") and not base.endswith("://api"):
        base = base[: -len("/api")].rstrip("/")
    return base
class KoboldApiLLM(LLM):
    """Kobold API language model.

    It includes several fields that can be used to control the text
    generation process.

    To use this class, instantiate it with the required parameters and invoke
    it with a prompt to generate text. For example:

        kobold = KoboldApiLLM(endpoint="http://localhost:5000")
        result = kobold.invoke("Write a story about a dragon.")

    This will send a POST request to the Kobold API with the provided prompt
    and generate text.
    """

    endpoint: str
    """The API endpoint to use for generating text."""

    use_story: Optional[bool] = False
    """Whether or not to use the story from the KoboldAI GUI when generating text."""

    use_authors_note: Optional[bool] = False
    """Whether to use the author's note from the KoboldAI GUI when generating text.

    This has no effect unless use_story is also enabled.
    """

    use_world_info: Optional[bool] = False
    """Whether to use the world info from the KoboldAI GUI when generating text."""

    use_memory: Optional[bool] = False
    """Whether to use the memory from the KoboldAI GUI when generating text."""

    max_context_length: Optional[int] = 1600
    """Maximum number of tokens to send to the model.

    minimum: 1
    """

    max_length: Optional[int] = 80
    """Number of tokens to generate.

    maximum: 512
    minimum: 1
    """

    rep_pen: Optional[float] = 1.12
    """Base repetition penalty value.

    minimum: 1
    """

    rep_pen_range: Optional[int] = 1024
    """Repetition penalty range.

    minimum: 0
    """

    rep_pen_slope: Optional[float] = 0.9
    """Repetition penalty slope.

    minimum: 0
    """

    temperature: Optional[float] = 0.6
    """Temperature value.

    exclusiveMinimum: 0
    """

    tfs: Optional[float] = 0.9
    """Tail free sampling value.

    maximum: 1
    minimum: 0
    """

    top_a: Optional[float] = 0.9
    """Top-a sampling value.

    minimum: 0
    """

    top_p: Optional[float] = 0.95
    """Top-p sampling value.

    maximum: 1
    minimum: 0
    """

    top_k: Optional[int] = 0
    """Top-k sampling value.

    minimum: 0
    """

    typical: Optional[float] = 0.5
    """Typical sampling value.

    maximum: 1
    minimum: 0
    """

    @property
    def _llm_type(self) -> str:
        """Return the identifier string for this LLM implementation."""
        return "koboldai"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the API and return the output.

        Args:
            prompt: The prompt to use for generation.
            stop: A list of strings to stop generation when encountered.
                They are sent to the API as ``stop_sequence`` and, if the
                returned text still ends with one of them, trimmed off.
                Empty strings are ignored when trimming.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not forwarded to
                the API.

        Returns:
            The generated text, stripped of surrounding whitespace.

        Raises:
            requests.HTTPError: If the API answers with an error status
                (via ``response.raise_for_status()``).
            ValueError: If the JSON body does not contain a string at
                ``results[0].text``.

        Example:
            .. code-block:: python

                from langchain_community.llms import KoboldApiLLM

                llm = KoboldApiLLM(endpoint="http://localhost:5000")
                llm.invoke("Write a story about dragons.")
        """
        data: Dict[str, Any] = {
            "prompt": prompt,
            "use_story": self.use_story,
            "use_authors_note": self.use_authors_note,
            "use_world_info": self.use_world_info,
            "use_memory": self.use_memory,
            "max_context_length": self.max_context_length,
            "max_length": self.max_length,
            "rep_pen": self.rep_pen,
            "rep_pen_range": self.rep_pen_range,
            "rep_pen_slope": self.rep_pen_slope,
            "temperature": self.temperature,
            "tfs": self.tfs,
            "top_a": self.top_a,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "typical": self.typical,
        }

        if stop is not None:
            data["stop_sequence"] = stop

        response = requests.post(
            f"{clean_url(self.endpoint)}/api/v1/generate", json=data
        )
        response.raise_for_status()
        json_response = response.json()

        # Any shape other than {"results": [{"text": <str>}, ...]} is reported
        # uniformly as ValueError, including bodies that are not a JSON object
        # at all (which would otherwise surface as a stray TypeError).
        try:
            text = json_response["results"][0]["text"]
        except (KeyError, IndexError, TypeError) as err:
            raise ValueError(
                f"Unexpected response format from Kobold API: {json_response}"
            ) from err
        if not isinstance(text, str):
            raise ValueError(
                f"Unexpected response format from Kobold API: {json_response}"
            )

        text = text.strip()
        for sequence in stop or ():
            # An empty stop string is a suffix of everything, and slicing
            # with ``[:-0]`` would throw away the whole completion; skip it.
            if sequence and text.endswith(sequence):
                text = text[: -len(sequence)].rstrip()
        return text
|