from functools import partial
from typing import Any, Dict, List, Optional, Sequence

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import root_validator


class CTransformers(LLM):
    """C Transformers LLM models.

    To use, you should have the ``ctransformers`` python package installed.
    See https://github.com/marella/ctransformers

    Example:
        .. code-block:: python

            from langchain_community.llms import CTransformers

            llm = CTransformers(model="/path/to/ggml-gpt-2.bin", model_type="gpt2")
    """

    client: Any  #: :meta private:

    model: str
    """The path to a model file or directory or the name of a Hugging Face Hub
    model repo."""

    model_type: Optional[str] = None
    """The model type."""

    model_file: Optional[str] = None
    """The name of the model file in repo or directory."""

    config: Optional[Dict[str, Any]] = None
    """The config parameters.
    See https://github.com/marella/ctransformers#config"""

    lib: Optional[str] = None
    """The path to a shared library or one of `avx2`, `avx`, `basic`."""

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model,
            "model_type": self.model_type,
            "model_file": self.model_file,
            "config": self.config,
        }

    @property
    def _llm_type(self) -> str:
        """Return the type of llm."""
        return "ctransformers"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the ``ctransformers`` package is installed."""
        try:
            from ctransformers import AutoModelForCausalLM
        except ImportError:
            raise ImportError(
                "Could not import `ctransformers` package. "
                "Please install it with `pip install ctransformers`"
            )

        config = values["config"] or {}
        values["client"] = AutoModelForCausalLM.from_pretrained(
            values["model"],
            model_type=values["model_type"],
            model_file=values["model_file"],
            lib=values["lib"],
            **config,
        )
        return values
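
    # Illustrative sketch of the ``config`` dict forwarded above. Key names
    # follow the ctransformers README (assumed; the supported set can vary by
    # version):
    #
    #     CTransformers(
    #         model="/path/to/model.bin",
    #         config={"max_new_tokens": 256, "temperature": 0.7, "gpu_layers": 50},
    #     )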

    def _call(
        self,
        prompt: str,
        stop: Optional[Sequence[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate text from a prompt.

        Args:
            prompt: The prompt to generate text from.
            stop: A list of sequences to stop generation when encountered.

        Returns:
            The generated text.

        Example:
            .. code-block:: python

                response = llm.invoke("Tell me a joke.")
        """
        text = []
        _run_manager = run_manager or CallbackManagerForLLMRun.get_noop_manager()
        # Stream tokens from the model, emitting each one to the callback
        # manager before joining them into the final result.
        for chunk in self.client(prompt, stop=stop, stream=True):
            text.append(chunk)
            _run_manager.on_llm_new_token(chunk, verbose=self.verbose)
        return "".join(text)

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Asynchronously call out to the CTransformers generate method.

        Very helpful when streaming (like with websockets!)

        Args:
            prompt: The prompt to pass into the model.
            stop: A list of strings to stop generation when encountered.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = await llm.ainvoke("Once upon a time, ")
        """
        text_callback = None
        if run_manager:
            text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose)

        text = ""
        # Generation itself runs synchronously; only the token callback is
        # awaited, so each new token reaches async consumers as it is produced.
        for token in self.client(prompt, stop=stop, stream=True):
            if text_callback:
                await text_callback(token)
            text += token
        return text
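

# A minimal usage sketch, not part of the library module above. The model repo,
# model_type, and config values are illustrative assumptions; any GGML model
# supported by ctransformers should work the same way.
if __name__ == "__main__":
    from langchain_core.callbacks import StreamingStdOutCallbackHandler

    llm = CTransformers(
        model="marella/gpt-2-ggml",  # assumed example repo; substitute your own
        model_type="gpt2",
        config={"max_new_tokens": 64, "temperature": 0.8},
        # Streams each generated token to stdout via on_llm_new_token.
        callbacks=[StreamingStdOutCallbackHandler()],
    )
    print(llm.invoke("Tell me a joke."))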