Spaces:
Running
Running
#!/usr/bin/env python3
"""
HTTP Tunnel over SSH using Paramiko with improved monitoring and reliability.
This script establishes an HTTP tunnel over SSH using the Paramiko library
with additional monitoring capabilities and automatic reconnection.
"""
import socket | |
import select | |
import threading | |
import logging | |
import paramiko | |
import time | |
class SSHTunnel:
    """
    Forward a local TCP port to an HTTP port reachable through an SSH server.

    The tunnel listens on a local port, and for every accepted connection it
    opens a ``direct-tcpip`` channel over a Paramiko SSH transport and copies
    bytes in both directions. Background daemon threads monitor the SSH
    session, send keep-alive packets, and rebuild the tunnel when it drops.

    The class is also a context manager: entering it calls :meth:`start`,
    leaving it calls ``stop(reconnect=False)``.
    """

    # Seconds accept() may block before the acceptor re-checks whether it
    # should exit. Closing a socket from another thread does not reliably
    # interrupt a blocking accept(), so the acceptor polls instead.
    _ACCEPT_TIMEOUT = 1.0
    # Pause after an unexpected accept() failure so that a persistent error
    # cannot turn the acceptor into a busy loop.
    _ACCEPT_ERROR_BACKOFF = 0.5

    def __init__(self, ssh_host, remote_port=80, local_port=80, ssh_port=22,
                 username=None, password=None, key_filename=None,
                 reconnect_interval=30, keep_alive_interval=15,
                 remote_host='127.0.0.1', bind_address=''):
        """
        Store the tunnel configuration; no network activity happens here.

        Args:
            ssh_host (str): The SSH server hostname or IP address
            remote_port (int): The remote HTTP port to forward to (default: 80)
            local_port (int): The local port to listen on (default: 80)
            ssh_port (int): The SSH server port (default: 22)
            username (str): The SSH username
            password (str): The SSH password (optional if using key_filename)
            key_filename (str): Path to private key file (optional if using password)
            reconnect_interval (int): Seconds to wait before reconnecting (default: 30)
            keep_alive_interval (int): Seconds between keep-alive packets (default: 15)
            remote_host (str): Destination host, as seen from the SSH server,
                that channels connect to (default: '127.0.0.1')
            bind_address (str): Local address to listen on. The default ''
                listens on every interface; pass '127.0.0.1' to accept local
                connections only.
        """
        self.ssh_host = ssh_host
        self.remote_port = remote_port
        self.local_port = local_port
        self.ssh_port = ssh_port
        self.username = username
        self.password = password
        self.key_filename = key_filename
        self.reconnect_interval = reconnect_interval
        self.keep_alive_interval = keep_alive_interval
        self.remote_host = remote_host
        self.bind_address = bind_address

        self.server_socket = None
        self.transport = None
        self.client = None
        self.is_running = False
        self.should_reconnect = True
        self.last_activity_time = time.time()
        self.connection_status = "disconnected"
        self.connection_error = None
        self.threads = []
        self.last_check_time = 0
        self.check_interval = 10  # Check tunnel every 10 seconds

        # Bookkeeping for the background workers.
        self._threads_lock = threading.Lock()
        self._monitor_thread = None
        self._keepalive_thread = None

        # Set up logging. The logger is shared by name between instances, so
        # only attach a handler when none exists yet; otherwise each new
        # instance would duplicate every log line.
        self.logger = logging.getLogger('ssh_tunnel')
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(
                logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            )
            self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _spawn(self, target, args=()):
        """Start a daemon thread for ``target`` and record it in ``self.threads``."""
        thread = threading.Thread(target=target, args=args, daemon=True)
        thread.start()
        with self._threads_lock:
            # Forget finished threads so the list does not grow by one entry
            # for every connection ever served.
            self.threads = [t for t in self.threads if t.is_alive()]
            self.threads.append(thread)
        return thread

    def _ensure_worker(self, current, target):
        """Return ``current`` if it is still alive, otherwise start a new worker."""
        if current is not None and current.is_alive():
            return current
        return self._spawn(target)

    def _quiet_close(self, resource, label):
        """Close ``resource`` on a best-effort basis, logging failures at debug level."""
        if resource is None:
            return
        try:
            resource.close()
        except Exception as e:
            self.logger.debug("Ignoring error while closing %s: %s", label, e)

    def _close_resources(self):
        """Detach and close the listening socket, the transport and the SSH client."""
        # Clear the attributes first so other threads stop using the objects
        # before they are closed.
        server_socket, self.server_socket = self.server_socket, None
        transport, self.transport = self.transport, None
        client, self.client = self.client, None
        self._quiet_close(server_socket, "server socket")
        self._quiet_close(transport, "SSH transport")
        self._quiet_close(client, "SSH client")

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def start(self):
        """
        Start the SSH tunnel with monitoring threads.

        Attempts one connection, then makes sure the monitor and keep-alive
        threads are running (they are started even when the first connection
        attempt fails, so the monitor can keep retrying).

        Returns:
            bool: True if the tunnel is running when this call returns,
                False otherwise.
        """
        if self.is_running:
            self.logger.warning("Tunnel is already running")
            return True

        self.should_reconnect = True
        self.connect()

        # Reuse workers that are still alive so repeated start() calls do not
        # pile up duplicate monitor / keep-alive threads.
        self._monitor_thread = self._ensure_worker(
            self._monitor_thread, self._monitor_connection
        )
        self._keepalive_thread = self._ensure_worker(
            self._keepalive_thread, self._send_keepalive
        )
        return self.is_running

    def connect(self):
        """
        Establish the SSH connection and start the server socket.

        Any resources left over from a previous session are released first.
        On failure everything created by this attempt is closed again and the
        reason is recorded in ``connection_status`` / ``connection_error``.

        Returns:
            bool: True if connection was successful, False otherwise.
        """
        # A previous session may have been marked as down without being torn
        # down; a listener left open would make the bind below fail.
        self._close_resources()

        client = None
        listener = None
        try:
            client = paramiko.SSHClient()
            # SECURITY NOTE: AutoAddPolicy accepts unknown host keys without
            # verification, which leaves the first connection open to a
            # man-in-the-middle. Kept for backward compatibility; consider
            # loading known host keys and rejecting unknown ones.
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

            self.logger.info("Connecting to SSH server %s:%s", self.ssh_host, self.ssh_port)
            connect_kwargs = {
                'hostname': self.ssh_host,
                'port': self.ssh_port,
                'username': self.username,
            }
            if self.password:
                connect_kwargs['password'] = self.password
            if self.key_filename:
                connect_kwargs['key_filename'] = self.key_filename
            client.connect(**connect_kwargs)

            transport = client.get_transport()
            transport.set_keepalive(self.keep_alive_interval)

            # Listening socket for incoming local connections.
            listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                listener.bind((self.bind_address, self.local_port))
                listener.listen(5)
            except OSError as e:
                message = f"Could not bind to port {self.local_port}: {str(e)}"
                self.logger.error(message)
                self.connection_status = "error"
                self.connection_error = message
                self.is_running = False
                self._quiet_close(listener, "server socket")
                self._quiet_close(client, "SSH client")
                return False
            # Let accept() wake up periodically so the acceptor notices stop().
            listener.settimeout(self._ACCEPT_TIMEOUT)

            # Publish the new session only once every piece is in place.
            self.client = client
            self.transport = transport
            self.server_socket = listener
            self.is_running = True
            self.connection_status = "connected"
            self.connection_error = None
            self.last_activity_time = time.time()
            self.logger.info(
                "SSH tunnel established. Forwarding local port %s "
                "to remote port %s via %s",
                self.local_port, self.remote_port, self.ssh_host,
            )

            # Acceptor thread for incoming connections.
            self._spawn(self._handle_connections)
            return True
        except Exception as e:
            self.logger.error("Failed to connect to SSH server: %s", e)
            self.connection_status = "error"
            self.connection_error = str(e)
            self.is_running = False
            # Release whatever this attempt managed to create.
            self._close_resources()
            self._quiet_close(listener, "server socket")
            self._quiet_close(client, "SSH client")
            return False

    def stop(self, reconnect=False):
        """
        Stop the SSH tunnel.

        Args:
            reconnect (bool): If True, the tunnel will attempt to reconnect after stopping.
        """
        self.is_running = False
        if not reconnect:
            # Tell the background loops to exit before tearing things down so
            # the monitor does not try to rebuild the tunnel meanwhile.
            self.should_reconnect = False
        self.connection_status = "disconnected"
        self._close_resources()

        if not reconnect:
            self.logger.info("SSH tunnel stopped permanently")
        else:
            self.logger.info("SSH tunnel stopped, will attempt to reconnect")

    # ------------------------------------------------------------------
    # Monitoring
    # ------------------------------------------------------------------
    def _monitor_connection(self):
        """
        Monitor the connection and reconnect if necessary.

        Runs until ``should_reconnect`` becomes False. Every pass it
        optionally performs an active probe (at most once per
        ``check_interval``), reconnects when the tunnel is down, and restarts
        the tunnel when the transport is inactive or no activity has been
        seen for twice ``reconnect_interval``.
        """
        while self.should_reconnect:
            now = time.time()
            # Only probe the tunnel every check_interval seconds.
            if now - self.last_check_time >= self.check_interval:
                self.last_check_time = now
                if not self._check_tunnel_active() and self.is_running:
                    self.logger.warning(
                        "Active check detected tunnel is down despite is_running=True"
                    )
                    self.stop(reconnect=True)
                    continue

            transport = self.transport
            if not self.is_running:
                self.logger.info("Connection is down, attempting to reconnect...")
                if not self.connect():
                    self.logger.warning(
                        "Reconnection failed, waiting %s seconds...",
                        self.reconnect_interval,
                    )
                    time.sleep(self.reconnect_interval)
            elif transport is not None and not transport.is_active():
                self.logger.warning("Transport is no longer active")
                self.stop(reconnect=True)
            # Activity timeout, in case keep-alives stop going through.
            elif time.time() - self.last_activity_time > self.reconnect_interval * 2:
                self.logger.warning("No activity detected, reconnecting...")
                self.stop(reconnect=True)

            time.sleep(5)  # Check connection status every 5 seconds

    def _check_tunnel_active(self):
        """
        Actively check if the tunnel is working by opening a test channel.

        Returns:
            bool: True if the tunnel is active, False otherwise
        """
        # Work on a local reference: stop() may clear the attribute at any time.
        transport = self.transport
        if not transport or not self.client:
            return False
        try:
            if transport.is_active():
                # Bound the wait so a half-dead link cannot block the caller
                # for Paramiko's much longer default channel-open timeout.
                test_channel = transport.open_channel(
                    "session", timeout=self.check_interval
                )
                if test_channel is not None:
                    test_channel.close()
                    return True
            return False
        except Exception as e:
            self.logger.warning("Active tunnel check failed: %s", e)
            return False

    def _send_keepalive(self):
        """
        Send keepalive packets to maintain the SSH connection.

        Runs until ``should_reconnect`` becomes False, sending one ignore
        packet every ``keep_alive_interval`` seconds while the tunnel is up.
        A successful send counts as activity.
        """
        while self.should_reconnect:
            transport = self.transport
            if self.is_running and transport and transport.is_active():
                try:
                    transport.send_ignore()
                    self.last_activity_time = time.time()
                    self.logger.debug("Sent keepalive packet")
                except Exception as e:
                    self.logger.warning("Failed to send keepalive: %s", e)
            time.sleep(self.keep_alive_interval)

    def check_status(self):
        """
        Check the current status of the tunnel.

        If at least ``check_interval`` seconds have passed since the last
        active probe, the tunnel is probed again. A tunnel found dead is torn
        down completely (so that a later reconnect can bind the local port
        again) and reported as disconnected.

        Returns:
            dict: ``is_running``, ``status``, ``error`` and ``last_activity``
                (seconds since the last recorded activity).
        """
        now = time.time()
        if now - self.last_check_time >= self.check_interval:
            self.last_check_time = now
            if not self._check_tunnel_active() and self.is_running:
                self.logger.warning("Status check: tunnel appears down but marked as running")
                # Release the listener and the SSH session instead of only
                # flipping the flag; a listener left open would block the
                # next connect() from binding the same port.
                self.stop(reconnect=True)
                self.connection_error = "Tunnel appears to be down"

        return {
            "is_running": self.is_running,
            "status": self.connection_status,
            "error": self.connection_error,
            "last_activity": time.time() - self.last_activity_time,
        }

    # ------------------------------------------------------------------
    # Data path
    # ------------------------------------------------------------------
    def _handle_connections(self):
        """
        Handle incoming connections on the local port.

        Accepts connections on the listener that was current when this thread
        started and hands each one to its own forwarding thread. The loop ends
        once the tunnel stops or the listener is replaced by a newer session.
        """
        listener = self.server_socket
        if listener is None:
            return

        def still_current():
            # True while this thread's listener is the active one.
            return self.is_running and self.server_socket is listener

        try:
            while still_current():
                try:
                    client_socket, client_addr = listener.accept()
                except socket.timeout:
                    # Periodic wake-up; loop around to re-check the exit condition.
                    continue
                except OSError as e:
                    if not still_current():
                        break  # listener was closed by stop()
                    self.logger.error("Socket error: %s", e)
                    time.sleep(self._ACCEPT_ERROR_BACKOFF)
                    continue

                try:
                    # The forwarding loop expects a blocking socket.
                    client_socket.settimeout(None)
                    self.logger.info(
                        "New connection from %s:%s", client_addr[0], client_addr[1]
                    )
                    self.last_activity_time = time.time()
                    self._spawn(self._forward_traffic, (client_socket, client_addr))
                except Exception as e:
                    # No forwarding thread took ownership, so close it here.
                    self._quiet_close(client_socket, "client socket")
                    if self.is_running:
                        self.logger.error("Error accepting connection: %s", e)
        except Exception as e:
            self.logger.error("Error in connection handler: %s", e)
            if still_current():
                self.stop(reconnect=True)

    def _forward_traffic(self, client_socket, client_addr):
        """
        Forward traffic between the local client and the remote server.

        Both the client socket and the channel are closed when this returns.

        Args:
            client_socket (socket.socket): The socket for the local client connection
            client_addr (tuple): The address of the local client
        """
        channel = None
        try:
            # Local reference: stop() may clear the attribute concurrently.
            transport = self.transport
            if transport is None:
                self.logger.error("Cannot forward traffic: SSH transport is not available")
                return

            # Open a channel to the remote host through the SSH transport.
            remote_addr = (self.remote_host, self.remote_port)
            channel = transport.open_channel("direct-tcpip", remote_addr, client_addr)
            if channel is None:
                self.logger.error(
                    "Failed to open channel to %s:%s", remote_addr[0], remote_addr[1]
                )
                return

            self.logger.info("Channel opened to %s:%s", remote_addr[0], remote_addr[1])
            self.last_activity_time = time.time()

            # Copy data in both directions until either side closes.
            while True:
                readable, _, _ = select.select([client_socket, channel], [], [], 1)
                if client_socket in readable:
                    data = client_socket.recv(4096)
                    if not data:
                        break
                    # sendall(): send() may write only part of the buffer,
                    # which would silently drop the remainder.
                    channel.sendall(data)
                    self.last_activity_time = time.time()
                if channel in readable:
                    data = channel.recv(4096)
                    if not data:
                        break
                    client_socket.sendall(data)
                    self.last_activity_time = time.time()
        except Exception as e:
            self.logger.error("Error forwarding traffic: %s", e)
        finally:
            client_socket.close()
            if channel is not None:
                channel.close()
            self.logger.info(
                "Connection from %s:%s closed", client_addr[0], client_addr[1]
            )

    # ------------------------------------------------------------------
    # Context manager protocol
    # ------------------------------------------------------------------
    def __enter__(self):
        """
        Enter the context manager by starting the tunnel.

        Returns:
            SSHTunnel: The SSH tunnel instance.
        """
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Exit the context manager, stopping the tunnel permanently.
        """
        self.stop(reconnect=False)