// Librechat-copy / manager.ts
// vaibhavard
// origin
// 91a7855
import { CallToolResultSchema } from '@modelcontextprotocol/sdk/types.js';
import type { JsonSchemaType } from 'librechat-data-provider';
import type { Logger } from 'winston';
import type * as t from './types/mcp';
import { formatToolContent } from './parsers';
import { MCPConnection } from './connection';
import { CONSTANTS } from './enum';
/**
 * Singleton registry of MCP (Model Context Protocol) server connections.
 *
 * Obtain it with {@link MCPManager.getInstance}. Servers are connected by
 * {@link MCPManager.initializeMCP}. Only connections that report themselves connected are
 * registered, so only those are visible to the tool-mapping and tool-calling methods.
 */
export class MCPManager {
  private static instance: MCPManager | null = null;
  private readonly connections: Map<string, MCPConnection> = new Map();
  private readonly logger: Logger;

  /** Console-backed fallback used when the caller does not supply a logger. */
  private static getDefaultLogger(): Logger {
    return {
      error: console.error,
      warn: console.warn,
      info: console.info,
      debug: console.debug,
    } as Logger;
  }

  private constructor(logger?: Logger) {
    this.logger = logger ?? MCPManager.getDefaultLogger();
  }

  /**
   * Returns the process-wide manager, creating it on first use.
   *
   * @param logger - Only honoured on the call that creates the instance; ignored afterwards.
   */
  public static getInstance(logger?: Logger): MCPManager {
    if (!MCPManager.instance) {
      MCPManager.instance = new MCPManager(logger);
    }
    return MCPManager.instance;
  }

  /**
   * Connects to every configured MCP server concurrently and registers the ones that succeed.
   *
   * Per-server failures, including timeouts, are logged and never rejected to the caller.
   * Once every server has settled, a per-server ✓/✗ summary is logged.
   *
   * @param mcpServers - Server name → connection config.
   * @param connectionTimeoutMs - Upper bound on one server's whole connect/retry sequence.
   *   Defaults to 30 minutes, the value previously hard-coded here.
   */
  public async initializeMCP(
    mcpServers: t.MCPServers,
    connectionTimeoutMs = 30 * 60 * 1000,
  ): Promise<void> {
    this.logger.info('[MCP] Initializing servers');
    const entries = Object.entries(mcpServers);
    // Indices (into `entries`) of servers that ended up registered.
    const initializedServers = new Set<number>();

    await Promise.allSettled(
      entries.map(async ([serverName, config], i) => {
        const connection = new MCPConnection(serverName, config, this.logger);
        connection.on('connectionChange', (state) => {
          this.logger.info(`[MCP][${serverName}] Connection state: ${state}`);
        });

        // Set on timeout or failure so initializeServer stops scheduling further retries.
        let cancelled = false;
        let timer: ReturnType<typeof setTimeout> | undefined;
        const connectionTimeout = new Promise<never>((_, reject) => {
          timer = setTimeout(() => {
            cancelled = true;
            reject(new Error('Connection timeout'));
          }, connectionTimeoutMs);
        });
        const connectionAttempt = this.initializeServer(connection, serverName, () => cancelled);

        try {
          await Promise.race([connectionAttempt, connectionTimeout]);
        } catch (error) {
          cancelled = true;
          this.logger.error(`[MCP][${serverName}] Initialization failed`, error);
          // After a timeout the attempt may still be in flight. Release whatever it holds once it
          // settles. This is not awaited because a hung attempt might never settle.
          const cleanup = () => this.disconnectQuietly(connection, serverName);
          void connectionAttempt.then(cleanup, cleanup);
          throw error;
        } finally {
          // Without this, the (long) timeout timer outlives a successful connection.
          clearTimeout(timer);
        }

        if (!connection.isConnected()) {
          await this.disconnectQuietly(connection, serverName);
          return;
        }
        initializedServers.add(i);
        this.connections.set(serverName, connection);
        await this.logServerDetails(connection, serverName);
      }),
    );

    // Derived from the registered set so that servers which "resolved" without connecting
    // are counted as failures too.
    const failedCount = entries.length - initializedServers.size;
    this.logger.info(`[MCP] Initialized ${initializedServers.size}/${entries.length} server(s)`);
    if (failedCount > 0) {
      this.logger.warn(`[MCP] ${failedCount}/${entries.length} server(s) failed to initialize`);
    }
    entries.forEach(([serverName], index) => {
      if (initializedServers.has(index)) {
        this.logger.info(`[MCP][${serverName}] ✓ Initialized`);
      } else {
        this.logger.info(`[MCP][${serverName}] ✗ Failed`);
      }
    });
    if (initializedServers.size === entries.length) {
      this.logger.info('[MCP] All servers initialized successfully');
    } else if (initializedServers.size === 0) {
      this.logger.error('[MCP] No servers initialized');
    }
  }

  /**
   * Calls `connection.connect()` up to `maxAttempts` times, waiting `baseDelayMs * attempt`
   * between attempts.
   *
   * Every attempt counts, whether `connect()` throws or resolves without the connection
   * reporting connected. This is what bounds the loop.
   *
   * @param isCancelled - Checked before each attempt; when true, retrying stops with an error.
   * @throws The last `connect()` error, or an Error if the connection never reported connected.
   */
  private async initializeServer(
    connection: MCPConnection,
    serverName: string,
    isCancelled: () => boolean = () => false,
    maxAttempts = 3,
    baseDelayMs = 2000,
  ): Promise<void> {
    let lastError: unknown;
    for (let attempt = 1; attempt <= maxAttempts; attempt++) {
      if (isCancelled()) {
        throw new Error(`[MCP][${serverName}] Connection attempt cancelled`);
      }
      try {
        await connection.connect();
        if (connection.isConnected()) {
          return;
        }
        lastError = undefined;
      } catch (error) {
        lastError = error;
      }
      if (attempt < maxAttempts) {
        await new Promise((resolve) => setTimeout(resolve, baseDelayMs * attempt));
      }
    }
    this.logger.error(`[MCP][${serverName}] Failed after ${maxAttempts} attempts`);
    throw lastError ?? new Error(`[MCP][${serverName}] Not connected after ${maxAttempts} attempts`);
  }

  /** Logs a registered server's capabilities and tool names. Failures here are non-fatal. */
  private async logServerDetails(connection: MCPConnection, serverName: string): Promise<void> {
    try {
      const serverCapabilities = connection.client.getServerCapabilities();
      this.logger.info(
        `[MCP][${serverName}] Capabilities: ${JSON.stringify(serverCapabilities)}`,
      );
      if (!serverCapabilities?.tools) {
        return;
      }
      const tools = await connection.client.listTools();
      if (tools.tools.length) {
        this.logger.info(
          `[MCP][${serverName}] Available tools: ${tools.tools
            .map((tool) => tool.name)
            .join(', ')}`,
        );
      }
    } catch (error) {
      this.logger.warn(`[MCP][${serverName}] Could not list capabilities/tools`, error);
    }
  }

  /**
   * Best-effort teardown of a connection that will not be registered. Errors are only logged.
   * NOTE(review): assumes MCPConnection.disconnect() tolerates a never-connected connection;
   * any error it raises is swallowed here.
   */
  private async disconnectQuietly(connection: MCPConnection, serverName: string): Promise<void> {
    try {
      await connection.disconnect();
    } catch (error) {
      this.logger.debug(`[MCP][${serverName}] Error cleaning up unregistered connection`, error);
    }
  }

  /** Returns the registered connection for `serverName`, if any. */
  public getConnection(serverName: string): MCPConnection | undefined {
    return this.connections.get(serverName);
  }

  /** Returns the live registry map. Note: this is the internal map, not a copy. */
  public getAllConnections(): Map<string, MCPConnection> {
    return this.connections;
  }

  /**
   * Adds every tool of every connected server to `availableTools` (mutated in place).
   * Each tool is keyed as `<toolName><mcp_delimiter><serverName>`.
   * A server whose tool fetch fails is logged and skipped.
   */
  public async mapAvailableTools(availableTools: t.LCAvailableTools): Promise<void> {
    for (const [serverName, connection] of this.connections.entries()) {
      try {
        if (connection.isConnected() !== true) {
          this.logger.warn(`Connection ${serverName} is not connected. Skipping tool fetch.`);
          continue;
        }
        const tools = await connection.fetchTools();
        for (const tool of tools) {
          const name = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`;
          availableTools[name] = {
            type: 'function',
            ['function']: {
              name,
              description: tool.description,
              parameters: tool.inputSchema as JsonSchemaType,
            },
          };
        }
      } catch (error) {
        // Previously logged "Not connected", which was misleading, and dropped the error.
        this.logger.error(`[MCP][${serverName}] Error fetching tools`, error);
      }
    }
  }

  /**
   * Appends a manifest entry for every tool of every connected server to `manifestTools`.
   * A server whose tool fetch fails is logged and skipped.
   */
  public async loadManifestTools(manifestTools: t.LCToolManifest): Promise<void> {
    for (const [serverName, connection] of this.connections.entries()) {
      try {
        if (connection.isConnected() !== true) {
          this.logger.warn(`Connection ${serverName} is not connected. Skipping tool fetch.`);
          continue;
        }
        const tools = await connection.fetchTools();
        for (const tool of tools) {
          const pluginKey = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`;
          manifestTools.push({
            name: tool.name,
            pluginKey,
            description: tool.description ?? '',
            icon: connection.iconPath,
          });
        }
      } catch (error) {
        this.logger.error(`[MCP][${serverName}] Error fetching tools`, error);
      }
    }
  }

  /**
   * Invokes `toolName` on the registered server `serverName` via an MCP `tools/call` request.
   * The result is formatted for `provider`.
   *
   * @throws Error if no connection is registered under `serverName`; request errors propagate.
   */
  public async callTool(
    serverName: string,
    toolName: string,
    provider: t.Provider,
    toolArguments?: Record<string, unknown>,
  ): Promise<t.FormattedToolResponse> {
    const connection = this.connections.get(serverName);
    if (!connection) {
      throw new Error(
        `No connection found for server: ${serverName}. Please make sure to use MCP servers available under 'Connected MCP Servers'.`,
      );
    }
    const result = await connection.client.request(
      {
        method: 'tools/call',
        params: {
          name: toolName,
          arguments: toolArguments,
        },
      },
      CallToolResultSchema,
    );
    return formatToolContent(result, provider);
  }

  /**
   * Disconnects and unregisters `serverName`, if it is registered.
   * The entry is removed even when `disconnect()` throws; the error still propagates.
   */
  public async disconnectServer(serverName: string): Promise<void> {
    const connection = this.connections.get(serverName);
    if (!connection) {
      return;
    }
    try {
      await connection.disconnect();
    } finally {
      this.connections.delete(serverName);
    }
  }

  /**
   * Disconnects every registered server concurrently, then empties the registry.
   * Individual disconnect failures are logged rather than aborting the rest.
   */
  public async disconnectAll(): Promise<void> {
    await Promise.all(
      Array.from(this.connections.entries(), async ([serverName, connection]) => {
        try {
          await connection.disconnect();
        } catch (error) {
          this.logger.error(`[MCP][${serverName}] Error disconnecting`, error);
        }
      }),
    );
    this.connections.clear();
  }

  /** Disconnects everything and drops the singleton so the next getInstance() builds a fresh one. */
  public static async destroyInstance(): Promise<void> {
    if (MCPManager.instance) {
      try {
        await MCPManager.instance.disconnectAll();
      } finally {
        MCPManager.instance = null;
      }
    }
  }
}