Spaces:
Sleeping
Sleeping
import pydot | |
import re | |
from keras.models import Model | |
from keras.layers import Layer, InputLayer | |
from pygments.lexers import graphviz | |
# May be necessary to manually add Graphviz to PATH, e.g.
# import os
# os.environ["PATH"] += os.pathsep + r'C:\Program Files\Graphviz\bin'
# Name patterns (matched case-insensitively) of source layers whose outgoing
# edges are added with style 'invis' instead of being drawn.
_INVISIBLE_SOURCE_PATTERNS = ('sequential', r'operators__.getitem_[0-9]+$')


def _inbound_layers_of(inbound_node):
    """Return the layers feeding a single inbound node, always as a list."""
    inbound_layers = getattr(inbound_node, 'inbound_layers', None)
    if inbound_layers is None:
        # NOTE(review): node objects without 'inbound_layers' are assumed to
        # expose 'parent_nodes', each carrying the producing layer as
        # 'operation' (this looks like the newer Keras node layout) - confirm.
        inbound_layers = [getattr(parent, 'operation', None)
                          for parent in getattr(inbound_node, 'parent_nodes', ())]
    if isinstance(inbound_layers, (list, tuple)):
        return list(inbound_layers)
    # A single layer is reported directly rather than inside a container.
    return [inbound_layers]


def visualize_model(model, layer_labels = None, layer_colors = None, groupings = None, exclude_input_layer = False,
                    verbose = False, output_filename = 'model_graph.png'):
    """
    Creates a visual graph of a keras model. There is an option to group certain layers into subgraphs
    (argument 'groupings').

    Args:
        model: A Keras Model instance.
        layer_labels (optional): List of labels, indexed like model.layers. Defaults to layer names.
        layer_colors (optional): List of fill colors, indexed like model.layers. Defaults to white
            for all layers.
        groupings (optional): Dictionary specifying groups of layers. Each key is a group name,
            and its value is a list of layer names belonging to that group. A layer listed in
            several groups is placed in the first matching one.
        exclude_input_layer (optional): Boolean indicating whether to exclude the input layer from the graph.
        verbose (boolean, optional): Whether to print every edge that is added. Defaults to False.
        output_filename (optional): Name of the output file for saving the generated graph.

    Output:
        Image file with name 'output_filename'. Nothing is returned; success or failure of the
        write is reported with print().

    Raises:
        ValueError: If 'model' is not a Keras Model, or if 'layer_labels' / 'layer_colors' has no
            entry for a layer that is drawn.
    """
    if not isinstance(model, Model):
        raise ValueError("model should be a Keras model instance")

    layers = model.layers

    # Default labels and colors if not provided
    if not layer_labels:
        layer_labels = [layer.name for layer in layers]
    if not layer_colors:
        layer_colors = ['white'] * len(layers)

    # Create a directed graph
    graph = pydot.Dot(graph_type = 'digraph', rankdir = 'LR')

    # Create one node per layer, placed in its group's cluster when it has one
    subgraphs = {}
    node_ids = {}  # id(layer) -> graph node name, only for layers that are drawn
    for i, layer in enumerate(layers):
        # Exclude the input layer if specified
        if exclude_input_layer and isinstance(layer, InputLayer):
            continue

        if i >= len(layer_labels) or i >= len(layer_colors):
            raise ValueError(
                f"layer_labels and layer_colors need an entry for every layer; "
                f"none found for layer {i} ('{layer.name}')")

        layer_id = str(id(layer))
        node_ids[id(layer)] = layer_id
        node = pydot.Node(layer_id, label = layer_labels[i], style = 'filled',
                          fillcolor = layer_colors[i], shape = 'box')

        # First group (in dictionary order) that lists this layer, if any
        group_name = None
        if groupings:
            group_name = next(
                (group for group, members in groupings.items() if layer.name in members), None)

        if group_name:
            if group_name not in subgraphs:
                subgraphs[group_name] = pydot.Cluster(group_name, label = group_name,
                                                      style = 'dashed', fontsize = 24)
            subgraphs[group_name].add_node(node)
        else:
            graph.add_node(node)

    # Add subgraphs to the main graph
    for subgraph in subgraphs.values():
        graph.add_subgraph(subgraph)

    # Add edges based on layer connections
    drawn_edges = set()  # (src_id, dest_id) pairs already added
    for layer in layers:
        dest_id = node_ids.get(id(layer))
        if dest_id is None:
            # Layer was excluded above, so it has no node to connect to
            continue

        # Custom or non-standard layers without '_inbound_nodes' get no edges
        for inbound_node in getattr(layer, '_inbound_nodes', ()):
            # Each inbound node contributes its own inbound layers
            for inbound_layer in _inbound_layers_of(inbound_node):
                if not isinstance(inbound_layer, Layer):
                    continue
                src_id = node_ids.get(id(inbound_layer))
                if src_id is None or (src_id, dest_id) in drawn_edges:
                    # Source is not drawn, or this connection is already in the graph
                    continue
                drawn_edges.add((src_id, dest_id))

                hide_edge = any(re.search(pattern, inbound_layer.name, flags = re.IGNORECASE)
                                for pattern in _INVISIBLE_SOURCE_PATTERNS)
                if hide_edge:
                    graph.add_edge(pydot.Edge(src_id, dest_id, style = 'invis'))
                else:
                    graph.add_edge(pydot.Edge(src_id, dest_id))
                if verbose:
                    print(f"Added edge from {inbound_layer.name} to {layer.name}")

    graph.set_graph_defaults(sep = '+125,125')

    # Writing is best-effort: failures are reported on stdout, not raised
    try:
        graph.write_png(output_filename)
    except FileNotFoundError as e:
        print(f'\nFailed to create network visualization using pydot and graphviz. Please ensure that '
              'the output filename is valid, and graphviz is installed and included in the system PATH variable. '
              f'Original error: {e}')
    except Exception as e:
        print(f'\nFailed to create network visualization using pydot and graphviz. Original error: {e}')
    else:
        print(f'Model visualization saved to {output_filename}')