Source code for rustworkx.visualization.matplotlib

# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
#
# NetworkX is distributed with the 3-clause BSD license.
#
#   Copyright (C) 2004-2020, NetworkX Developers
#   Aric Hagberg <hagberg@lanl.gov>
#   Dan Schult <dschult@colgate.edu>
#   Pieter Swart <swart@lanl.gov>
#   All rights reserved.
#
#   Redistribution and use in source and binary forms, with or without
#   modification, are permitted provided that the following conditions are
#   met:
#
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#
#     * Redistributions in binary form must reproduce the above
#       copyright notice, this list of conditions and the following
#       disclaimer in the documentation and/or other materials provided
#       with the distribution.
#
#     * Neither the name of the NetworkX Developers nor the names of its
#       contributors may be used to endorse or promote products derived
#       from this software without specific prior written permission.
#
#   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
#   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
#   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
#   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
#   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
#   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
#   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
#   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
#   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
#   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# This code is forked from networkx's networkx_pylab.py module and adapted to
# work with rustworkx instead. The original source can be found at:
#
# https://github.com/networkx/networkx/blob/80b1afa2ae50314a8312998c214a8c1a356adcf1/networkx/drawing/nx_pylab.py

"""Draw a rustworkx graph with matplotlib."""

from collections.abc import Iterable
from itertools import islice, cycle
from numbers import Number

import numpy as np

import rustworkx


__all__ = [
    "mpl_draw",
]


[docs]def mpl_draw(graph, pos=None, ax=None, arrows=True, with_labels=False, **kwds): r"""Draw a graph with Matplotlib. .. note:: Matplotlib is an optional dependency and will not be installed with rustworkx by default. If you intend to use this function make sure that you install matplotlib with either ``pip install matplotlib`` or ``pip install 'rustworkx[mpl]'`` :param graph: A rustworkx graph, either a :class:`~rustworkx.PyGraph` or a :class:`~rustworkx.PyDiGraph`. :param dict pos: An optional dictionary (or a :class:`~rustworkx.Pos2DMapping` object) with nodes as keys and positions as values. If not specified a spring layout positioning will be computed. See `layout_functions` for functions that compute node positions. :param matplotlib.Axes ax: An optional Matplotlib Axes object to draw the graph in. :param bool arrows: For :class:`~rustworkx.PyDiGraph` objects if ``True`` draw arrowheads. (defaults to ``True``) Note, that the Arrows will be the same color as edges. :param str arrowstyle: An optional string for directed graphs to choose the style of the arrowsheads. See :class:`matplotlib.patches.ArrowStyle` for more options. By default the value is set to ``'-\|>'``. :param int arrow_size: For directed graphs, choose the size of the arrow head's length and width. See :class:`matplotlib.patches.FancyArrowPatch` attribute and constructor kwarg ``mutation_scale`` for more info. Defaults to 10. :param bool with_labels: Set to ``True`` to draw labels on the nodes. Edge labels will only be drawn if the ``edge_labels`` parameter is set to a function. Defaults to ``False``. :param list node_list: An optional list of node indices in the graph to draw. If not specified all nodes will be drawn. :param list edge_list: An option list of edges in the graph to draw. If not specified all edges will be drawn :param int|list node_size: Optional size of nodes. If an array is specified it must be the same length as node_list. Defaults to 300 :param node_color: Optional node color. Can be a single color or a sequence of colors with the same length as node_list. Color can be string or rgb (or rgba) tuple of floats from 0-1. If numeric values are specified they will be mapped to colors using the ``cmap`` and ``vmin``,``vmax`` parameters. See :func:`matplotlib.scatter` for more details. Defaults to ``'#1f78b4'``) :param str node_shape: The optional shape node. The specification is the same as the :func:`matplotlib.pyplot.scatter` function's ``marker`` kwarg, valid options are one of ``['s', 'o', '^', '>', 'v', '<', 'd', 'p', 'h', '8']``. Defaults to ``'o'`` :param float alpha: Optional value for node and edge transparency :param matplotlib.colors.Colormap cmap: An optional Matplotlib colormap object for mapping intensities of nodes :param float vmin: Optional minimum value for node colormap scaling :param float vmax: Optional minimum value for node colormap scaling :param float|sequence linewidths: An optional line width for symbol borders. If a sequence is specified it must be the same length as node_list. Defaults to 1.0 :param float|sequence width: An optional width to use for edges. Can either be a float or sequence of floats. If a sequence is specified it must be the same length as node_list. Defaults to 1.0 :param str|sequence edge_color: color or array of colors (default='k') Edge color. Can be a single color or a sequence of colors with the same length as edge_list. Color can be string or rgb (or rgba) tuple of floats from 0-1. If numeric values are specified they will be mapped to colors using the ``edge_cmap`` and ``edge_vmin``, ``edge_vmax`` parameters. :param matplotlib.colors.Colormap edge_cmap: An optional Matplotlib colormap for mapping intensities of edges. :param float edge_vmin: Optional minimum value for edge colormap scaling :param float edge_vmax: Optional maximum value for node colormap scaling :param str style: An optional string to specify the edge line style. For example, ``'-'``, ``'--'``, ``'-.'``, ``':'`` or words like ``'solid'`` or ``'dashed'``. See the :class:`matplotlib.patches.FancyArrowPatch` attribute and kwarg ``linestyle`` for more details. Defaults to ``'solid'``. :param func labels: An optional callback function that will be passed a node payload and return a string label for the node. For example:: labels=str could be used to just return a string cast of the node's data payload. Or something like:: labels=lambda node: node['label'] could be used if the node payloads are dictionaries. :param func edge_labels: An optional callback function that will be passed an edge payload and return a string label for the edge. For example:: edge_labels=str could be used to just return a string cast of the edge's data payload. Or something like:: edge_labels=lambda edge: edge['label'] could be used if the edge payloads are dictionaries. If this is set edge labels will be drawn in the visualization. :param int font_size: An optional fontsize to use for text labels, By default a value of 12 is used for nodes and 10 for edges. :param str font_color: An optional font color for strings. By default ``'k'`` (ie black) is set. :param str font_weight: An optional string used to specify the font weight. By default a value of ``'normal'`` is used. :param str font_family: An optional font family to use for strings. By default ``'sans-serif'`` is used. :param str label: An optional string label to use for the graph legend. :param str connectionstyle: An optional value used to create a curved arc of rounding radius rad. For example, ``connectionstyle='arc3,rad=0.2'``. See :class:`matplotlib.patches.ConnectionStyle` and :class:`matplotlib.patches.FancyArrowPatch` for more info. By default this is set to ``"arc3"``. :returns: A matplotlib figure for the visualization if not running with an interactive backend (like in jupyter) or if ``ax`` is not set. :rtype: matplotlib.figure.Figure For Example: .. jupyter-execute:: import matplotlib.pyplot as plt import rustworkx as rx from rustworkx.visualization import mpl_draw G = rx.generators.directed_path_graph(25) mpl_draw(G) plt.draw() """ try: import matplotlib.pyplot as plt # type: ignore except ImportError as e: raise ImportError( "matplotlib needs to be installed prior to running " "rustworkx.visualization.mpl_draw(). You can install " "matplotlib with:\n'pip install matplotlib'" ) from e if ax is None: cf = plt.gcf() else: cf = ax.get_figure() cf.set_facecolor("w") if ax is None: if cf.axes: ax = cf.gca() else: ax = cf.add_axes((0, 0, 1, 1)) draw_graph(graph, pos=pos, ax=ax, arrows=arrows, with_labels=with_labels, **kwds) ax.set_axis_off() plt.draw_if_interactive() if not plt.isinteractive() or ax is None: return cf
def draw_graph(graph, pos=None, arrows=True, with_labels=False, **kwds): r"""Draw the graph using Matplotlib. Draw the graph with Matplotlib with options for node positions, labeling, titles, and many other drawing features. See draw() for simple drawing without labels or axes. Parameters ---------- graph: A rustworkx :class:`~rustworkx.PyDiGraph` or :class:`~rustworkx.PyGraph` pos : dictionary, optional A dictionary with nodes as keys and positions as values. If not specified a spring layout positioning will be computed. See :mod:`rustworkx.drawing.layout` for functions that compute node positions. Notes ----- For directed graphs, arrows are drawn at the head end. Arrows can be turned off with keyword arrows=False. """ try: import matplotlib.pyplot as plt except ImportError as e: raise ImportError( "matplotlib needs to be installed prior to running " "rustworkx.visualization.mpl_draw(). You can install " "matplotlib with:\n'pip install matplotlib'" ) from e valid_node_kwds = { "node_list", "node_size", "node_color", "node_shape", "alpha", "cmap", "vmin", "vmax", "ax", "linewidths", "edgecolors", "label", } valid_edge_kwds = { "edge_list", "width", "edge_color", "style", "alpha", "arrowstyle", "arrow_size", "edge_cmap", "edge_vmin", "edge_vmax", "ax", "label", "node_size", "node_list", "node_shape", "connectionstyle", "min_source_margin", "min_target_margin", } valid_label_kwds = { "labels", "font_size", "font_color", "font_family", "font_weight", "alpha", "bbox", "ax", "horizontalalignment", "verticalalignment", } valid_edge_label_kwds = { "edge_labels", "font_size", "font_color", "font_family", "font_weight", "alpha", "bbox", "ax", "rotate", "horizontalalignment", "verticalalignment", } valid_kwds = valid_node_kwds | valid_edge_kwds | valid_label_kwds | valid_edge_label_kwds if any([k not in valid_kwds for k in kwds]): invalid_args = ", ".join([k for k in kwds if k not in valid_kwds]) raise ValueError(f"Received invalid argument(s): {invalid_args}") label_fn = kwds.pop("labels", None) if label_fn: kwds["labels"] = {x: label_fn(graph[x]) for x in graph.node_indices()} edge_label_fn = kwds.pop("edge_labels", None) if edge_label_fn: kwds["edge_labels"] = { (x[0], x[1]): edge_label_fn(x[2]) for x in graph.weighted_edge_list() } node_kwds = {k: v for k, v in kwds.items() if k in valid_node_kwds} edge_kwds = {k: v for k, v in kwds.items() if k in valid_edge_kwds} if isinstance(edge_kwds.get("alpha"), list): del edge_kwds["alpha"] label_kwds = {k: v for k, v in kwds.items() if k in valid_label_kwds} edge_label_kwds = {k: v for k, v in kwds.items() if k in valid_edge_label_kwds} if pos is None: pos = rustworkx.spring_layout(graph) # default to spring layout draw_nodes(graph, pos, **node_kwds) draw_edges(graph, pos, arrows=arrows, **edge_kwds) if with_labels: draw_labels(graph, pos, **label_kwds) if edge_label_fn: draw_edge_labels(graph, pos, **edge_label_kwds) plt.draw_if_interactive() def draw_nodes( graph, pos, node_list=None, node_size=300, node_color="#1f78b4", node_shape="o", alpha=None, cmap=None, vmin=None, vmax=None, ax=None, linewidths=None, edgecolors=None, label=None, ): """Draw the nodes of the graph. This draws only the nodes of the graph. :param graph: A rustworkx graph, either a :class:`~rustworkx.PyGraph` or a :class:`~rustworkx.PyDiGraph`. :param dict pos: A dictionary with nodes as keys and positions as values. Positions should be sequences of length 2. :param Axes ax: An optional Matplotlib Axes object, if specified it will draw the graph in the specified Matplotlib axes. :param list node_list: If specified only draw the specified node indices. If not specified all nodes in the graph will be drawn. :param float|array node_size: Size of nodes. If an array it must be the same length as node_list. Defaults to 300 node_color : color or array of colors (default='#1f78b4') Node color. Can be a single color or a sequence of colors with the same length as node_list. Color can be string or rgb (or rgba) tuple of floats from 0-1. If numeric values are specified they will be mapped to colors using the cmap and vmin,vmax parameters. See matplotlib.scatter for more details. node_shape : string (default='o') The shape of the node. Specification is as matplotlib.scatter marker, one of 'so^>v<dph8'. alpha : float or array of floats (default=None) The node transparency. This can be a single alpha value, in which case it will be applied to all the nodes of color. Otherwise, if it is an array, the elements of alpha will be applied to the colors in order (cycling through alpha multiple times if necessary). cmap : Matplotlib colormap (default=None) Colormap for mapping intensities of nodes vmin,vmax : floats or None (default=None) Minimum and maximum for node colormap scaling linewidths : [None | scalar | sequence] (default=1.0) Line width of symbol border edgecolors : [None | scalar | sequence] (default = node_color) Colors of node borders label : [None | string] Label for legend Returns ------- matplotlib.collections.PathCollection `PathCollection` of the nodes. """ try: import matplotlib as mpl import matplotlib.collections # type: ignore import matplotlib.pyplot as plt except ImportError as e: raise ImportError( "matplotlib needs to be installed prior to running " "rustworkx.visualization.mpl_draw(). You can install " "matplotlib with:\n'pip install matplotlib'" ) from e if ax is None: ax = plt.gca() if node_list is None: node_list = graph.node_indices() # empty node_list, no drawing if len(node_list) == 0: return mpl.collections.PathCollection(None) try: xy = np.asarray([pos[v] for v in node_list]) except KeyError as e: raise IndexError(f"Node {e} has no position.") from e if isinstance(alpha, Iterable): node_color = apply_alpha(node_color, alpha, node_list, cmap, vmin, vmax) alpha = None node_collection = ax.scatter( xy[:, 0], xy[:, 1], s=node_size, c=node_color, marker=node_shape, cmap=cmap, vmin=vmin, vmax=vmax, alpha=alpha, linewidths=linewidths, edgecolors=edgecolors, label=label, ) ax.tick_params( axis="both", which="both", bottom=False, left=False, labelbottom=False, labelleft=False, ) node_collection.set_zorder(2) return node_collection def draw_edges( graph, pos, edge_list=None, width=1.0, edge_color="k", style="solid", alpha=None, arrowstyle=None, arrow_size=10, edge_cmap=None, edge_vmin=None, edge_vmax=None, ax=None, arrows=True, label=None, node_size=300, node_list=None, node_shape="o", connectionstyle="arc3", min_source_margin=0, min_target_margin=0, ): r"""Draw the edges of the graph. This draws only the edges of the graph. Parameters ---------- graph: A rustworkx graph pos : dictionary A dictionary with nodes as keys and positions as values. Positions should be sequences of length 2. edge_list : collection of edge tuples (default=graph.edge_list()) Draw only specified edges width : float or array of floats (default=1.0) Line width of edges edge_color : color or array of colors (default='k') Edge color. Can be a single color or a sequence of colors with the same length as edge_list. Color can be string or rgb (or rgba) tuple of floats from 0-1. If numeric values are specified they will be mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters. style : string (default=solid line) Edge line style e.g.: '-', '--', '-.', ':' or words like 'solid' or 'dashed'. (See `matplotlib.patches.FancyArrowPatch`: `linestyle`) alpha : float or None (default=None) The edge transparency edge_cmap : Matplotlib colormap, optional Colormap for mapping intensities of edges edge_vmin,edge_vmax : floats, optional Minimum and maximum for edge colormap scaling ax : Matplotlib Axes object, optional Draw the graph in the specified Matplotlib axes. arrows : bool, optional (default=True) For directed graphs, if True set default to drawing arrowheads. Otherwise set default to no arrowheads. Ignored if `arrowstyle` is set. Note: Arrows will be the same color as edges. arrowstyle : str (default='-\|>' if directed else '-') For directed graphs and `arrows==True` defaults to '-\|>', otherwise defaults to '-'. See `matplotlib.patches.ArrowStyle` for more options. arrow_size : int (default=10) For directed graphs, choose the size of the arrow head's length and width. See `matplotlib.patches.FancyArrowPatch` for attribute ``mutation_scale`` for more info. node_size : scalar or array (default=300) Size of nodes. Though the nodes are not drawn with this function, the node size is used in determining edge positioning. node_list : list, optional (default=graph.node_indices()) This provides the node order for the `node_size` array (if it is an array). node_shape : string (default='o') The marker used for nodes, used in determining edge positioning. Specification is as a `matplotlib.markers` marker, e.g. one of 'so^>v<dph8'. label : None or string Label for legend min_source_margin : int (default=0) The minimum margin (gap) at the begining of the edge at the source. min_target_margin : int (default=0) The minimum margin (gap) at the end of the edge at the target. Returns ------- list of matplotlib.patches.FancyArrowPatch `FancyArrowPatch` instances of the directed edges Notes ----- For directed graphs, arrows are drawn at the head end. Arrows can be turned off with keyword arrows=False or by passing an arrowstyle without an arrow on the end. Be sure to include `node_size` as a keyword argument; arrows are drawn considering the size of nodes. """ try: import matplotlib as mpl import matplotlib.colors # type: ignore import matplotlib.patches # type: ignore import matplotlib.path # type: ignore import matplotlib.pyplot as plt except ImportError as e: raise ImportError( "matplotlib needs to be installed prior to running " "rustworkx.visualization.mpl_draw(). You can install " "matplotlib with:\n'pip install matplotlib'" ) from e if arrowstyle is None: if isinstance(graph, rustworkx.PyDiGraph) and arrows: arrowstyle = "-|>" else: arrowstyle = "-" if ax is None: ax = plt.gca() if edge_list is None: edge_list = graph.edge_list() if len(edge_list) == 0: # no edges! return [] if node_list is None: node_list = list(graph.node_indices()) # FancyArrowPatch handles color=None different from LineCollection if edge_color is None: edge_color = "k" # set edge positions edge_pos = set() for e in edge_list: edge_pos.add((tuple(pos[e[0]]), tuple(pos[e[1]]))) # Check if edge_color is an array of floats and map to edge_cmap. # This is the only case handled differently from matplotlib if ( np.iterable(edge_color) and (len(edge_color) == len(edge_pos)) and np.all([isinstance(c, Number) for c in edge_color]) ): if edge_cmap is not None: assert isinstance(edge_cmap, mpl.colors.Colormap) else: edge_cmap = plt.get_cmap() if edge_vmin is None: edge_vmin = min(edge_color) if edge_vmax is None: edge_vmax = max(edge_color) color_normal = mpl.colors.Normalize(vmin=edge_vmin, vmax=edge_vmax) edge_color = [edge_cmap(color_normal(e)) for e in edge_color] # Note: Waiting for someone to implement arrow to intersection with # marker. Meanwhile, this works well for polygons with more than 4 # sides and circle. def to_marker_edge(marker_size, marker): if marker in "s^>v<d": # `large` markers need extra space return np.sqrt(2 * marker_size) / 2 else: return np.sqrt(marker_size) / 2 # Draw arrows with `matplotlib.patches.FancyarrowPatch` arrow_collection = [] mutation_scale = arrow_size # scale factor of arrow head base_connectionstyle = mpl.patches.ConnectionStyle(connectionstyle) # Fallback for self-loop scale. Left outside of _connectionstyle so it is # only computed once max_nodesize = np.array(node_size).max() # FancyArrowPatch doesn't handle color strings arrow_colors = mpl.colors.colorConverter.to_rgba_array(edge_color, alpha) for i, edge in enumerate(edge_pos): x1, y1 = edge[0][0], edge[0][1] x2, y2 = edge[1][0], edge[1][1] shrink_source = 0 # space from source to tail shrink_target = 0 # space from head to target if np.iterable(node_size): # many node sizes source, target = edge_list[i][:2] source_node_size = node_size[node_list.index(source)] target_node_size = node_size[node_list.index(target)] shrink_source = to_marker_edge(source_node_size, node_shape) shrink_target = to_marker_edge(target_node_size, node_shape) else: shrink_source = shrink_target = to_marker_edge(node_size, node_shape) if shrink_source < min_source_margin: shrink_source = min_source_margin if shrink_target < min_target_margin: shrink_target = min_target_margin if len(arrow_colors) == len(edge_pos): arrow_color = arrow_colors[i] elif len(arrow_colors) == 1: arrow_color = arrow_colors[0] else: # Cycle through colors arrow_color = arrow_colors[i % len(arrow_colors)] if np.iterable(width): if len(width) == len(edge_pos): line_width = width[i] else: line_width = width[i % len(width)] else: line_width = width # radius of edges if tuple(reversed(edge)) in edge_pos: rad = 0.25 else: rad = 0.0 arrow = mpl.patches.FancyArrowPatch( (x1, y1), (x2, y2), arrowstyle=arrowstyle, shrinkA=shrink_source, shrinkB=shrink_target, mutation_scale=mutation_scale, color=arrow_color, linewidth=line_width, connectionstyle=connectionstyle + f", rad = {rad}", linestyle=style, zorder=1, ) # arrows go behind nodes arrow_collection.append(arrow) ax.add_patch(arrow) edge_pos = np.asarray(tuple(edge_pos)) # compute view mirustworkx = np.amin(np.ravel(edge_pos[:, :, 0])) maxx = np.amax(np.ravel(edge_pos[:, :, 0])) miny = np.amin(np.ravel(edge_pos[:, :, 1])) maxy = np.amax(np.ravel(edge_pos[:, :, 1])) w = maxx - mirustworkx h = maxy - miny def _connectionstyle(posA, posB, *args, **kwargs): # check if we need to do a self-loop if np.all(posA == posB): # Self-loops are scaled by view extent, except in cases the extent # is 0, e.g. for a single node. In this case, fall back to scaling # by the maximum node size selfloop_ht = 0.005 * max_nodesize if h == 0 else h # this is called with _screen space_ values so covert back # to data space data_loc = ax.transData.inverted().transform(posA) v_shift = 0.1 * selfloop_ht h_shift = v_shift * 0.5 # put the top of the loop first so arrow is not hidden by node path = [ # 1 data_loc + np.asarray([0, v_shift]), # 4 4 4 data_loc + np.asarray([h_shift, v_shift]), data_loc + np.asarray([h_shift, 0]), data_loc, # 4 4 4 data_loc + np.asarray([-h_shift, 0]), data_loc + np.asarray([-h_shift, v_shift]), data_loc + np.asarray([0, v_shift]), ] ret = mpl.path.Path(ax.transData.transform(path), [1, 4, 4, 4, 4, 4, 4]) # if not, fall back to the user specified behavior else: ret = base_connectionstyle(posA, posB, *args, **kwargs) return ret # update view padx, pady = 0.05 * w, 0.05 * h corners = (mirustworkx - padx, miny - pady), (maxx + padx, maxy + pady) ax.update_datalim(corners) ax.autoscale_view() ax.tick_params( axis="both", which="both", bottom=False, left=False, labelbottom=False, labelleft=False, ) return arrow_collection def draw_labels( graph, pos, labels=None, font_size=12, font_color="k", font_family="sans-serif", font_weight="normal", alpha=None, bbox=None, horizontalalignment="center", verticalalignment="center", ax=None, clip_on=True, ): """Draw node labels on the graph. Parameters ---------- graph: A rustworkx graph pos : dictionary A dictionary with nodes as keys and positions as values. Positions should be sequences of length 2. labels : dictionary (default={n: n for n in graph}) Node labels in a dictionary of text labels keyed by node. Node-keys in labels should appear as keys in `pos`. If needed use: `{n:lab for n,lab in labels.items() if n in pos}` font_size : int (default=12) Font size for text labels font_color : string (default='k' black) Font color string font_weight : string (default='normal') Font weight font_family : string (default='sans-serif') Font family alpha : float or None (default=None) The text transparency bbox : Matplotlib bbox, (default is Matplotlib's ax.text default) Specify text box properties (e.g. shape, color etc.) for node labels. horizontalalignment : string (default='center') Horizontal alignment {'center', 'right', 'left'} verticalalignment : string (default='center') Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'} ax : Matplotlib Axes object, optional Draw the graph in the specified Matplotlib axes. clip_on : bool (default=True) Turn on clipping of node labels at axis boundaries Returns ------- dict `dict` of labels keyed on the nodes """ try: import matplotlib.pyplot as plt except ImportError as e: raise ImportError( "matplotlib needs to be installed prior to running " "rustworkx.visualization.mpl_draw(). You can install " "matplotlib with:\n'pip install matplotlib'" ) from e if ax is None: ax = plt.gca() if labels is None: labels = {n: n for n in graph.node_indices()} text_items = {} # there is no text collection so we'll fake one for n, label in labels.items(): (x, y) = pos[n] if not isinstance(label, str): label = str(label) # this makes "1" and 1 labeled the same t = ax.text( x, y, label, size=font_size, color=font_color, family=font_family, weight=font_weight, alpha=alpha, horizontalalignment=horizontalalignment, verticalalignment=verticalalignment, transform=ax.transData, bbox=bbox, clip_on=clip_on, ) text_items[n] = t ax.tick_params( axis="both", which="both", bottom=False, left=False, labelbottom=False, labelleft=False, ) return text_items def draw_edge_labels( graph, pos, edge_labels=None, label_pos=0.5, font_size=10, font_color="k", font_family="sans-serif", font_weight="normal", alpha=None, bbox=None, horizontalalignment="center", verticalalignment="center", ax=None, rotate=True, clip_on=True, ): """Draw edge labels. Parameters ---------- graph: A rustworkx graph pos : dictionary A dictionary with nodes as keys and positions as values. Positions should be sequences of length 2. edge_labels : dictionary (default={}) Edge labels in a dictionary of labels keyed by edge two-tuple. Only labels for the keys in the dictionary are drawn. label_pos : float (default=0.5) Position of edge label along edge (0=head, 0.5=center, 1=tail) font_size : int (default=10) Font size for text labels font_color : string (default='k' black) Font color string font_weight : string (default='normal') Font weight font_family : string (default='sans-serif') Font family alpha : float or None (default=None) The text transparency bbox : Matplotlib bbox, optional Specify text box properties (e.g. shape, color etc.) for edge labels. Default is {boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)}. horizontalalignment : string (default='center') Horizontal alignment {'center', 'right', 'left'} verticalalignment : string (default='center') Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'} ax : Matplotlib Axes object, optional Draw the graph in the specified Matplotlib axes. rotate : bool (deafult=True) Rotate edge labels to lie parallel to edges clip_on : bool (default=True) Turn on clipping of edge labels at axis boundaries Returns ------- dict `dict` of labels keyed by edge """ try: import matplotlib.pyplot as plt except ImportError as e: raise ImportError( "matplotlib needs to be installed prior to running " "rustworkx.visualization.mpl_draw(). You can install " "matplotlib with:\n'pip install matplotlib'" ) from e if ax is None: ax = plt.gca() if edge_labels is None: labels = {(u, v): d for u, v, d in graph.weighted_edge_list()} else: labels = edge_labels text_items = {} for (n1, n2), label in labels.items(): (x1, y1) = pos[n1] (x2, y2) = pos[n2] (x, y) = ( x1 * label_pos + x2 * (1.0 - label_pos), y1 * label_pos + y2 * (1.0 - label_pos), ) if (n2, n1) in labels.keys(): # loop dy = np.abs(y2 - y1) if n2 > n1: y -= 0.25 * dy else: y += 0.25 * dy if rotate: # in degrees angle = np.arctan2(y2 - y1, x2 - x1) / (2.0 * np.pi) * 360 # make label orientation "right-side-up" if angle > 90: angle -= 180 if angle < -90: angle += 180 # transform data coordinate angle to screen coordinate angle xy = np.array((x, y)) trans_angle = ax.transData.transform_angles(np.array((angle,)), xy.reshape((1, 2)))[0] else: trans_angle = 0.0 # use default box of white with white border if bbox is None: bbox = dict(boxstyle="round", ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)) if not isinstance(label, str): label = str(label) # this makes "1" and 1 labeled the same t = ax.text( x, y, label, size=font_size, color=font_color, family=font_family, weight=font_weight, alpha=alpha, horizontalalignment=horizontalalignment, verticalalignment=verticalalignment, rotation=trans_angle, transform=ax.transData, bbox=bbox, zorder=1, clip_on=clip_on, ) text_items[(n1, n2)] = t ax.tick_params( axis="both", which="both", bottom=False, left=False, labelbottom=False, labelleft=False, ) return text_items def apply_alpha(colors, alpha, elem_list, cmap=None, vmin=None, vmax=None): """Apply an alpha (or list of alphas) to the colors provided. Parameters ---------- colors : color string or array of floats (default='r') Color of element. Can be a single color format string, or a sequence of colors with the same length as node_list. If numeric values are specified they will be mapped to colors using the cmap and vmin,vmax parameters. See matplotlib.scatter for more details. alpha : float or array of floats Alpha values for elements. This can be a single alpha value, in which case it will be applied to all the elements of color. Otherwise, if it is an array, the elements of alpha will be applied to the colors in order (cycling through alpha multiple times if necessary). elem_list : array of rustworkx objects The list of elements which are being colored. These could be nodes, edges or labels. cmap : matplotlib colormap Color map for use if colors is a list of floats corresponding to points on a color mapping. vmin, vmax : float Minimum and maximum values for normalizing colors if a colormap is used Returns ------- rgba_colors : numpy ndarray Array containing RGBA format values for each of the node colours. """ try: import matplotlib as mpl import matplotlib.colors # call as mpl.colors import matplotlib.cm # type: ignore except ImportError as e: raise ImportError( "matplotlib needs to be installed prior to running " "rustworkx.visualization.mpl_draw(). You can install " "matplotlib with:\n'pip install matplotlib'" ) from e # If we have been provided with a list of numbers as long as elem_list, # apply the color mapping. if len(colors) == len(elem_list) and isinstance(colors[0], Number): mapper = mpl.cm.ScalarMappable(cmap=cmap) mapper.set_clim(vmin, vmax) rgba_colors = mapper.to_rgba(colors) # Otherwise, convert colors to matplotlib's RGB using the colorConverter # object. These are converted to numpy ndarrays to be consistent with the # to_rgba method of ScalarMappable. else: try: rgba_colors = np.array([mpl.colors.colorConverter.to_rgba(colors)]) except ValueError: rgba_colors = np.array([mpl.colors.colorConverter.to_rgba(color) for color in colors]) # Set the final column of the rgba_colors to have the relevant alpha values try: # If alpha is longer than the number of colors, resize to the number of # elements. Also, if rgba_colors.size (the number of elements of # rgba_colors) is the same as the number of elements, resize the array, # to avoid it being interpreted as a colormap by scatter() if len(alpha) > len(rgba_colors) or rgba_colors.size == len(elem_list): rgba_colors = np.resize(rgba_colors, (len(elem_list), 4)) rgba_colors[1:, 0] = rgba_colors[0, 0] rgba_colors[1:, 1] = rgba_colors[0, 1] rgba_colors[1:, 2] = rgba_colors[0, 2] rgba_colors[:, 3] = list(islice(cycle(alpha), len(rgba_colors))) except TypeError: rgba_colors[:, -1] = alpha return rgba_colors