sivan22 commited on
Commit
eb46f15
·
verified ·
1 Parent(s): f3a8fa4

Delete chat_gemini.py

Browse files
Files changed (1) hide show
  1. chat_gemini.py +0 -260
chat_gemini.py DELETED
@@ -1,260 +0,0 @@
1
- import json
2
- from random import choices
3
- import string
4
- from langchain.tools import BaseTool
5
- import requests
6
- from dotenv import load_dotenv
7
- from dataclasses import dataclass
8
- from langchain_core.language_models.chat_models import BaseChatModel
9
- from typing import (
10
- Any,
11
- Callable,
12
- Dict,
13
- List,
14
- Literal,
15
- Mapping,
16
- Optional,
17
- Sequence,
18
- Type,
19
- Union,
20
- cast,
21
- )
22
- from langchain_core.callbacks import (
23
- CallbackManagerForLLMRun,
24
- )
25
- from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
26
- from langchain_core.exceptions import OutputParserException
27
- from langchain_core.language_models import LanguageModelInput
28
- from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
29
- from langchain_core.messages import (
30
- AIMessage,
31
- BaseMessage,
32
- HumanMessage,
33
- ToolMessage,
34
- SystemMessage,
35
- )
36
- from langchain_core.outputs import ChatGeneration, ChatResult
37
- from langchain_core.runnables import Runnable
38
- from langchain_core.tools import BaseTool
39
-
40
-
41
class ChatGemini(BaseChatModel):
    """LangChain chat model backed by the Gemini ``generateContent`` REST API.

    Incoming LangChain messages are translated to Gemini's ``contents`` wire
    format: a ``SystemMessage`` becomes the out-of-band ``system_instruction``,
    ``ToolMessage``s become ``functionResponse`` parts, and an ``AIMessage``
    carrying ``tool_calls`` is replayed as a ``functionCall`` part. Responses
    containing ``functionCall`` parts are returned as an ``AIMessage`` with
    structured ``tool_calls`` so LangChain's agent executor can dispatch them.
    """

    # Google API key, sent as the ``key`` query-string parameter.
    api_key: str
    # ``generateContent`` endpoint of the target Gemini model.
    base_url: str = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent"
    # Extra request parameters merged into the JSON payload.
    # (Was annotated ``Any``; a typed mutable default is copied per-instance
    # by pydantic, so sharing is not a concern.)
    model_kwargs: Dict[str, Any] = {}

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "gemini"

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat response using the Gemini API.

        Handles both plain-text responses and function calls. Function calls
        are surfaced as ``tool_calls`` on the returned ``AIMessage``, each
        with:

        - ``name``: the function to call
        - ``id``: the function name plus a random suffix (Gemini does not
          supply a call id)
        - ``args``: the function arguments

        Args:
            messages: List of input messages (LangChain message objects, or
                plain ``{"role": ..., "content": ...}`` dicts).
            stop: Optional list of stop sequences (currently unused).
            run_manager: Optional callback manager.
            **kwargs: Extra fields merged verbatim into the request body.

        Returns:
            ChatResult with a single AIMessage generation.

        Raises:
            Exception: On transport/HTTP errors (cause chained) or when the
                API returns no usable candidate.
        """
        # --- Convert messages to Gemini's "contents" format -----------------
        gemini_messages: List[Dict[str, Any]] = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                # Gemini takes the system prompt out-of-band, not as a turn.
                kwargs["system_instruction"] = {"parts": [{"text": msg.content}]}
                continue
            if isinstance(msg, ToolMessage):
                # Tool results are replayed as functionResponse parts.
                gemini_messages.append(
                    {
                        "role": "model",
                        "parts": [{
                            "functionResponse": {
                                "name": msg.name,
                                "response": {"name": msg.name, "content": msg.content},
                            }}],
                    }
                )
                continue
            if isinstance(msg, HumanMessage):
                role, content = "user", msg.content
            elif isinstance(msg, AIMessage):
                role, content = "model", msg.content
            elif not isinstance(msg, BaseMessage):
                # Tolerate plain dict messages.
                role = "user" if msg["role"] == "human" else "model"
                content = msg["content"]
            else:
                # Previously an unrecognized BaseMessage subtype left
                # role/content unbound; fail loudly instead.
                raise ValueError(f"Unsupported message type: {type(msg)}")

            if isinstance(msg, AIMessage) and msg.tool_calls:
                # Replay a prior assistant tool call in Gemini's format.
                call = msg.tool_calls[0]
                parts = [{"functionCall": {"name": call["name"], "args": call["args"]}}]
            else:
                parts = [{"text": content}]
            gemini_messages.append({"role": role, "parts": parts})

        # --- Build and send the request -------------------------------------
        headers = {"Content-Type": "application/json"}
        params = {"key": self.api_key}  # API key travels in the query string.
        data = {
            "contents": gemini_messages,
            "generationConfig": {
                "maxOutputTokens": 2048,
            },
            **kwargs,
        }

        try:
            response = requests.post(
                self.base_url,
                headers=headers,
                params=params,
                json=data,
            )
            response.raise_for_status()
            result = response.json()
        except requests.RequestException as e:
            # Narrowed from a blanket ``except Exception`` and chained with
            # ``from e`` so the original traceback is preserved.
            raise Exception(f"Error calling Gemini API: {str(e)}") from e

        # --- Parse the response ---------------------------------------------
        candidates = result.get("candidates") or []
        if not candidates or "parts" not in candidates[0].get("content", {}):
            # Raised outside the try so it is no longer re-wrapped; the
            # message matches what callers previously observed.
            raise Exception("Error calling Gemini API: No response generated")

        parts = candidates[0]["content"]["parts"]
        tool_calls = []
        content = ""
        for part in parts:
            if "text" in part:
                content += part["text"]
            if "functionCall" in part:
                function_call = part["functionCall"]
                tool_calls.append({
                    "name": function_call["name"],
                    # Gemini doesn't provide a unique id, so synthesize one.
                    "id": function_call["name"] + random_string(5),
                    "args": function_call["args"],
                    "type": "tool_call",
                })

        message = (
            AIMessage(content=content, tool_calls=tool_calls)
            if tool_calls
            else AIMessage(content=content)
        )
        return ChatResult(generations=[ChatGeneration(message=message)])

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
        *,
        tool_choice: Optional[Union[dict, str, Literal["auto", "any"], bool]] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Args:
            tools: A list of tool definitions to bind to this chat model;
                each is converted with :func:`convert_to_gemini_tool`.
            tool_choice: Which tool the model should call. **This parameter
                is currently ignored** (it is not forwarded to the API).
            kwargs: Any additional parameters are passed directly to
                ``self.bind(**kwargs)``.
        """
        formatted_tools = {
            "function_declarations": [convert_to_gemini_tool(tool) for tool in tools]
        }
        return super().bind(tools=formatted_tools, **kwargs)
204
def convert_to_gemini_tool(
    tool: BaseTool,
    *,
    strict: Optional[bool] = None,
) -> Dict[str, Any]:
    """Convert a tool-like object to a Gemini tool schema.

    Gemini tool schema reference:
    https://ai.google.dev/gemini-api/docs/function-calling#function_calling_mode

    Args:
        tool:
            A LangChain ``BaseTool``. (The original annotation was the
            degenerate ``Union[BaseTool]``.)
        strict:
            If True, model output is guaranteed to exactly match the JSON
            Schema provided in the function definition. If None, ``strict``
            is omitted from the definition.

    Returns:
        A dict version of the passed-in tool which is compatible with the
        Gemini tool-calling API.

    Raises:
        ValueError: If ``tool`` is not a ``BaseTool``.
    """
    if not isinstance(tool, BaseTool):
        raise ValueError(f"Unsupported tool type: {type(tool)}")

    # Extract the tool's (pydantic v1 style) JSON schema.
    schema = (
        tool.args_schema.schema()
        if tool.args_schema
        else {"type": "object", "properties": {}}
    )

    # Convert each field to Gemini's flat property schema. BUGFIX: the
    # original used the JSON-schema "title" (the field name) as the
    # description; prefer the real "description", falling back to "title".
    properties: Dict[str, Any] = {}
    for key, value in schema.get("properties", {}).items():
        properties[key] = {
            "type": value.get("type", "string"),
            "description": value.get("description", value.get("title", "")),
        }

    # Build the function declaration.
    function_def: Dict[str, Any] = {
        "name": tool.name,
        "description": tool.description,
        "parameters": {
            "type": "object",
            "properties": properties,
            "required": schema.get("required", []),
        },
    }

    if strict is not None:
        function_def["strict"] = strict

    return function_def
258
def random_string(length: int) -> str:
    """Return a pseudo-random alphanumeric string of ``length`` characters."""
    alphabet = string.ascii_letters + string.digits
    picked = choices(alphabet, k=length)
    return "".join(picked)
260
-