# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
#
# NetworkX is distributed with the 3-clause BSD license.
#
# Copyright (C) 2004-2020, NetworkX Developers
# Aric Hagberg <hagberg@lanl.gov>
# Dan Schult <dschult@colgate.edu>
# Pieter Swart <swart@lanl.gov>
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
#
# * Neither the name of the NetworkX Developers nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# This code is forked from networkx's networkx_pylab.py module and adapted to
# work with rustworkx instead. The original source can be found at:
#
# https://github.com/networkx/networkx/blob/80b1afa2ae50314a8312998c214a8c1a356adcf1/networkx/drawing/nx_pylab.py
"""Draw a rustworkx graph with matplotlib."""
from collections.abc import Iterable
from itertools import islice, cycle
from numbers import Number
import numpy as np
import rustworkx
__all__ = [
"mpl_draw",
]
def mpl_draw(graph, pos=None, ax=None, arrows=True, with_labels=False, **kwds):
    r"""Draw a graph with Matplotlib.

    .. note::

        Matplotlib is an optional dependency and will not be installed with
        rustworkx by default. If you intend to use this function make sure that
        you install matplotlib with either ``pip install matplotlib`` or
        ``pip install 'rustworkx[mpl]'``

    :param graph: A rustworkx graph, either a :class:`~rustworkx.PyGraph` or a
        :class:`~rustworkx.PyDiGraph`.
    :param dict pos: An optional dictionary (or
        a :class:`~rustworkx.Pos2DMapping` object) with nodes as keys and
        positions as values. If not specified a spring layout positioning will
        be computed. See `layout_functions` for functions that compute
        node positions.
    :param matplotlib.Axes ax: An optional Matplotlib Axes object to draw the
        graph in.
    :param bool arrows: For :class:`~rustworkx.PyDiGraph` objects if ``True``
        draw arrowheads. (defaults to ``True``) Note, that the Arrows will
        be the same color as edges.
    :param str arrowstyle: An optional string for directed graphs to choose
        the style of the arrowheads. See
        :class:`matplotlib.patches.ArrowStyle` for more options. By default the
        value is set to ``'-\|>'``.
    :param int arrow_size: For directed graphs, choose the size of the arrow
        head's length and width. See
        :class:`matplotlib.patches.FancyArrowPatch` attribute and constructor
        kwarg ``mutation_scale`` for more info. Defaults to 10.
    :param bool with_labels: Set to ``True`` to draw labels on the nodes. Edge
        labels will only be drawn if the ``edge_labels`` parameter is set to a
        function. Defaults to ``False``.
    :param list node_list: An optional list of node indices in the graph to
        draw. If not specified all nodes will be drawn.
    :param list edge_list: An optional list of edges in the graph to draw. If
        not specified all edges will be drawn
    :param int|list node_size: Optional size of nodes. If an array is
        specified it must be the same length as node_list. Defaults to 300
    :param node_color: Optional node color. Can be a single color or
        a sequence of colors with the same length as node_list. Color can be
        string or rgb (or rgba) tuple of floats from 0-1. If numeric values
        are specified they will be mapped to colors using the ``cmap`` and
        ``vmin``,``vmax`` parameters. See :func:`matplotlib.scatter` for more
        details. Defaults to ``'#1f78b4'``
    :param str node_shape: The optional shape node. The specification is the
        same as the :func:`matplotlib.pyplot.scatter` function's ``marker``
        kwarg, valid options are one of
        ``['s', 'o', '^', '>', 'v', '<', 'd', 'p', 'h', '8']``. Defaults to
        ``'o'``
    :param float alpha: Optional value for node and edge transparency
    :param matplotlib.colors.Colormap cmap: An optional Matplotlib colormap
        object for mapping intensities of nodes
    :param float vmin: Optional minimum value for node colormap scaling
    :param float vmax: Optional maximum value for node colormap scaling
    :param float|sequence linewidths: An optional line width for symbol
        borders. If a sequence is specified it must be the same length as
        node_list. Defaults to 1.0
    :param float|sequence width: An optional width to use for edges. Can
        either be a float or sequence of floats. If a sequence is specified
        it must be the same length as edge_list. Defaults to 1.0
    :param str|sequence edge_color: color or array of colors (default='k')
        Edge color. Can be a single color or a sequence of colors with the same
        length as edge_list. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the ``edge_cmap`` and ``edge_vmin``,
        ``edge_vmax`` parameters.
    :param matplotlib.colors.Colormap edge_cmap: An optional Matplotlib
        colormap for mapping intensities of edges.
    :param float edge_vmin: Optional minimum value for edge colormap scaling
    :param float edge_vmax: Optional maximum value for edge colormap scaling
    :param str style: An optional string to specify the edge line style.
        For example, ``'-'``, ``'--'``, ``'-.'``, ``':'`` or words like
        ``'solid'`` or ``'dashed'``. See the
        :class:`matplotlib.patches.FancyArrowPatch` attribute and kwarg
        ``linestyle`` for more details. Defaults to ``'solid'``.
    :param func labels: An optional callback function that will be passed a
        node payload and return a string label for the node. For example::

            labels=str

        could be used to just return a string cast of the node's data payload.
        Or something like::

            labels=lambda node: node['label']

        could be used if the node payloads are dictionaries.
    :param func edge_labels: An optional callback function that will be passed
        an edge payload and return a string label for the edge. For example::

            edge_labels=str

        could be used to just return a string cast of the edge's data payload.
        Or something like::

            edge_labels=lambda edge: edge['label']

        could be used if the edge payloads are dictionaries. If this is set
        edge labels will be drawn in the visualization.
    :param int font_size: An optional fontsize to use for text labels, By
        default a value of 12 is used for nodes and 10 for edges.
    :param str font_color: An optional font color for strings. By default
        ``'k'`` (ie black) is set.
    :param str font_weight: An optional string used to specify the font weight.
        By default a value of ``'normal'`` is used.
    :param str font_family: An optional font family to use for strings. By
        default ``'sans-serif'`` is used.
    :param str label: An optional string label to use for the graph legend.
    :param str connectionstyle: An optional value used to create a curved arc
        of rounding radius rad. For example,
        ``connectionstyle='arc3,rad=0.2'``. See
        :class:`matplotlib.patches.ConnectionStyle` and
        :class:`matplotlib.patches.FancyArrowPatch` for more info. By default
        this is set to ``"arc3"``.

    :returns: The matplotlib figure that was drawn on when matplotlib is not
        in interactive mode (``matplotlib.pyplot.isinteractive()`` is
        ``False``). In interactive mode ``None`` is returned.
    :rtype: matplotlib.figure.Figure
    :raises ImportError: If matplotlib is not installed.

    For Example:

    .. jupyter-execute::

        import matplotlib.pyplot as plt

        import rustworkx as rx
        from rustworkx.visualization import mpl_draw

        G = rx.generators.directed_path_graph(25)
        mpl_draw(G)
        plt.draw()
    """
    try:
        import matplotlib.pyplot as plt  # type: ignore
    except ImportError as e:
        raise ImportError(
            "matplotlib needs to be installed prior to running "
            "rustworkx.visualization.mpl_draw(). You can install "
            "matplotlib with:\n'pip install matplotlib'"
        ) from e

    # Draw on the caller's figure when an Axes is supplied, otherwise on the
    # current pyplot figure.
    cf = plt.gcf() if ax is None else ax.get_figure()
    cf.set_facecolor("w")
    if ax is None:
        # Reuse the figure's current axes if it has any, otherwise create one
        # that fills the whole figure.
        ax = cf.gca() if cf.axes else cf.add_axes((0, 0, 1, 1))

    draw_graph(graph, pos=pos, ax=ax, arrows=arrows, with_labels=with_labels, **kwds)
    ax.set_axis_off()
    plt.draw_if_interactive()

    # ``ax`` is always bound by this point, so the only condition that decides
    # whether the figure is handed back is matplotlib's interactive mode.
    if not plt.isinteractive():
        return cf
    return None
def draw_graph(graph, pos=None, arrows=True, with_labels=False, **kwds):
    r"""Draw nodes, edges and (optionally) labels of a graph with Matplotlib.

    The keyword arguments are validated against the options understood by
    :func:`draw_nodes`, :func:`draw_edges`, :func:`draw_labels` and
    :func:`draw_edge_labels`, and each of those functions is then called with
    the subset of options it accepts.

    Parameters
    ----------
    graph: A rustworkx :class:`~rustworkx.PyDiGraph` or
        :class:`~rustworkx.PyGraph`
    pos : dictionary, optional
        A dictionary with nodes as keys and positions as values.
        If not specified ``rustworkx.spring_layout(graph)`` is used.
    arrows : bool, optional (default=True)
        Forwarded to :func:`draw_edges`.
    with_labels : bool, optional (default=False)
        If true, node labels are drawn with :func:`draw_labels`.
    **kwds
        Drawing options. ``labels`` and ``edge_labels`` are callbacks which
        are applied to node payloads / edge payloads to build the label
        dictionaries. Edge labels are only drawn when ``edge_labels`` is set.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If an unrecognised keyword argument is passed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError(
            "matplotlib needs to be installed prior to running "
            "rustworkx.visualization.mpl_draw(). You can install "
            "matplotlib with:\n'pip install matplotlib'"
        ) from e

    node_options = {
        "node_list",
        "node_size",
        "node_color",
        "node_shape",
        "alpha",
        "cmap",
        "vmin",
        "vmax",
        "ax",
        "linewidths",
        "edgecolors",
        "label",
    }
    edge_options = {
        "edge_list",
        "width",
        "edge_color",
        "style",
        "alpha",
        "arrowstyle",
        "arrow_size",
        "edge_cmap",
        "edge_vmin",
        "edge_vmax",
        "ax",
        "label",
        "node_size",
        "node_list",
        "node_shape",
        "connectionstyle",
        "min_source_margin",
        "min_target_margin",
    }
    node_label_options = {
        "labels",
        "font_size",
        "font_color",
        "font_family",
        "font_weight",
        "alpha",
        "bbox",
        "ax",
        "horizontalalignment",
        "verticalalignment",
    }
    edge_label_options = {
        "edge_labels",
        "font_size",
        "font_color",
        "font_family",
        "font_weight",
        "alpha",
        "bbox",
        "ax",
        "rotate",
        "horizontalalignment",
        "verticalalignment",
    }
    accepted = node_options | edge_options | node_label_options | edge_label_options

    # Reject anything none of the drawing helpers understands.
    unknown = [name for name in kwds if name not in accepted]
    if unknown:
        invalid_args = ", ".join(unknown)
        raise ValueError(f"Received invalid argument(s): {invalid_args}")

    # Turn the label callbacks into the dictionaries the helpers expect.
    label_fn = kwds.pop("labels", None)
    if label_fn:
        kwds["labels"] = {index: label_fn(graph[index]) for index in graph.node_indices()}
    edge_label_fn = kwds.pop("edge_labels", None)
    if edge_label_fn:
        kwds["edge_labels"] = {
            (source, target): edge_label_fn(payload)
            for source, target, payload in graph.weighted_edge_list()
        }

    def select(allowed):
        # Options restricted to the names a particular helper accepts.
        return {name: value for name, value in kwds.items() if name in allowed}

    node_kwds = select(node_options)
    edge_kwds = select(edge_options)
    # A list of alphas is only passed on to the node drawing; it is dropped
    # from the edge options.
    if isinstance(edge_kwds.get("alpha"), list):
        del edge_kwds["alpha"]
    label_kwds = select(node_label_options)
    edge_label_kwds = select(edge_label_options)

    if pos is None:
        # No layout supplied: fall back to a spring layout.
        pos = rustworkx.spring_layout(graph)

    draw_nodes(graph, pos, **node_kwds)
    draw_edges(graph, pos, arrows=arrows, **edge_kwds)
    if with_labels:
        draw_labels(graph, pos, **label_kwds)
    if edge_label_fn:
        draw_edge_labels(graph, pos, **edge_label_kwds)
    plt.draw_if_interactive()
def draw_nodes(
    graph,
    pos,
    node_list=None,
    node_size=300,
    node_color="#1f78b4",
    node_shape="o",
    alpha=None,
    cmap=None,
    vmin=None,
    vmax=None,
    ax=None,
    linewidths=None,
    edgecolors=None,
    label=None,
):
    """Draw only the nodes of the graph as a Matplotlib scatter plot.

    :param graph: A rustworkx graph, either a :class:`~rustworkx.PyGraph` or a
        :class:`~rustworkx.PyDiGraph`.
    :param dict pos: A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.
    :param list node_list: If specified only draw the specified node indices.
        If not specified all nodes in the graph will be drawn.
    :param float|array node_size: Size of nodes, passed to ``ax.scatter`` as
        ``s``. Defaults to 300
    :param node_color: A single color or a sequence of colors, passed to
        ``ax.scatter`` as ``c``. Defaults to ``'#1f78b4'``.
    :param str node_shape: The marker used for the nodes, passed to
        ``ax.scatter`` as ``marker``. Defaults to ``'o'``.
    :param alpha: The node transparency. A single value is passed straight to
        ``ax.scatter``. An iterable of values is merged into ``node_color``
        with :func:`apply_alpha` instead.
    :param cmap: Matplotlib colormap for mapping intensities of nodes.
    :param vmin: Minimum for node colormap scaling.
    :param vmax: Maximum for node colormap scaling.
    :param Axes ax: An optional Matplotlib Axes object to draw in. If not
        specified the current pyplot axes are used.
    :param linewidths: Line width of the symbol border, passed to
        ``ax.scatter``.
    :param edgecolors: Colors of the node borders, passed to ``ax.scatter``.
    :param str label: Label for legend.

    :returns: The ``PathCollection`` of the nodes (an empty one if there is
        nothing to draw).
    :rtype: matplotlib.collections.PathCollection
    :raises ImportError: If matplotlib is not installed.
    :raises IndexError: If a node to draw has no entry in ``pos``.
    """
    try:
        import matplotlib as mpl
        import matplotlib.collections  # type: ignore
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError(
            "matplotlib needs to be installed prior to running "
            "rustworkx.visualization.mpl_draw(). You can install "
            "matplotlib with:\n'pip install matplotlib'"
        ) from e

    if ax is None:
        ax = plt.gca()

    nodes = graph.node_indices() if node_list is None else node_list

    # Nothing to draw: hand back an empty collection.
    if len(nodes) == 0:
        return mpl.collections.PathCollection(None)

    try:
        coords = np.asarray([pos[node] for node in nodes])
    except KeyError as e:
        raise IndexError(f"Node {e} has no position.") from e

    if isinstance(alpha, Iterable):
        # Per-node alphas are baked into the RGBA colors, so scatter() must
        # not receive a separate alpha on top.
        node_color = apply_alpha(node_color, alpha, nodes, cmap, vmin, vmax)
        alpha = None

    scatter_options = dict(
        s=node_size,
        c=node_color,
        marker=node_shape,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        alpha=alpha,
        linewidths=linewidths,
        edgecolors=edgecolors,
        label=label,
    )
    node_collection = ax.scatter(coords[:, 0], coords[:, 1], **scatter_options)

    # Hide ticks and tick labels on both axes.
    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

    node_collection.set_zorder(2)
    return node_collection
def _connectionstyle_with_rad(connectionstyle, rad):
    """Return ``connectionstyle`` with ``rad`` appended when that is safe.

    ``rad`` is only appended to a string ``arc3`` style that does not already
    specify its own ``rad``. Anything else (a user supplied ``rad``, another
    connection style, or a non-string style object) is returned untouched so
    the caller's choice is respected.
    """
    if not isinstance(connectionstyle, str):
        return connectionstyle
    # Same tokenisation matplotlib applies to style strings.
    name, *args = connectionstyle.replace(" ", "").split(",")
    if name.lower() != "arc3":
        return connectionstyle
    if any(arg.split("=")[0] == "rad" for arg in args):
        return connectionstyle
    return connectionstyle + f", rad = {rad}"


def draw_edges(
    graph,
    pos,
    edge_list=None,
    width=1.0,
    edge_color="k",
    style="solid",
    alpha=None,
    arrowstyle=None,
    arrow_size=10,
    edge_cmap=None,
    edge_vmin=None,
    edge_vmax=None,
    ax=None,
    arrows=True,
    label=None,
    node_size=300,
    node_list=None,
    node_shape="o",
    connectionstyle="arc3",
    min_source_margin=0,
    min_target_margin=0,
):
    r"""Draw the edges of the graph.

    This draws only the edges of the graph.

    Parameters
    ----------
    graph: A rustworkx graph
    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.
    edge_list : collection of edge tuples (default=graph.edge_list())
        Draw only specified edges. Edges whose endpoints map to the same
        (source position, target position) pair are drawn once.
    width : float or array of floats (default=1.0)
        Line width of edges. A sequence is cycled through if it is shorter
        than the number of drawn edges.
    edge_color : color or array of colors (default='k')
        Edge color. Can be a single color or a sequence of colors with the same
        length as edge_list. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the edge_cmap and edge_vmin,edge_vmax
        parameters.
    style : string (default=solid line)
        Edge line style e.g.: '-', '--', '-.', ':'
        or words like 'solid' or 'dashed'.
        (See `matplotlib.patches.FancyArrowPatch`: `linestyle`)
    alpha : float or None (default=None)
        The edge transparency
    edge_cmap : Matplotlib colormap, optional
        Colormap for mapping intensities of edges
    edge_vmin,edge_vmax : floats, optional
        Minimum and maximum for edge colormap scaling
    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.
    arrows : bool, optional (default=True)
        For directed graphs, if True set default to drawing arrowheads.
        Otherwise set default to no arrowheads. Ignored if `arrowstyle` is set.
        Note: Arrows will be the same color as edges.
    arrowstyle : str (default='-\|>' if directed else '-')
        For directed graphs and `arrows==True` defaults to '-\|>',
        otherwise defaults to '-'.
        See `matplotlib.patches.ArrowStyle` for more options.
    arrow_size : int (default=10)
        For directed graphs, choose the size of the arrow head's length and
        width. See `matplotlib.patches.FancyArrowPatch` for attribute
        ``mutation_scale`` for more info.
    node_size : scalar or array (default=300)
        Size of nodes. Though the nodes are not drawn with this function, the
        node size is used in determining edge positioning.
    node_list : list, optional (default=graph.node_indices())
        This provides the node order for the `node_size` array (if it is an
        array).
    node_shape : string (default='o')
        The marker used for nodes, used in determining edge positioning.
        Specification is as a `matplotlib.markers` marker, e.g. one of
        'so^>v<dph8'.
    label : None or string
        Label for legend. Accepted for interface compatibility; it is not
        used by this function.
    connectionstyle : str (default='arc3')
        Connection style of the arrows, see
        `matplotlib.patches.ConnectionStyle`. For an ``arc3`` style without an
        explicit ``rad``, edges that have a reverse edge between the same
        positions are curved with ``rad = 0.25`` so both stay visible and all
        other edges use ``rad = 0.0``. Any other style is passed to
        matplotlib unchanged.
    min_source_margin : int (default=0)
        The minimum margin (gap) at the beginning of the edge at the source.
    min_target_margin : int (default=0)
        The minimum margin (gap) at the end of the edge at the target.

    Returns
    -------
    list of matplotlib.patches.FancyArrowPatch
        `FancyArrowPatch` instances of the drawn edges

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If ``node_size`` is a sequence and an edge endpoint is missing from
        ``node_list``.

    Notes
    -----
    For directed graphs, arrows are drawn at the head end. Arrows can be
    turned off with keyword arrows=False or by passing an arrowstyle without
    an arrow on the end.

    Be sure to include `node_size` as a keyword argument; arrows are
    drawn considering the size of nodes.
    """
    try:
        import matplotlib as mpl
        import matplotlib.colors  # type: ignore
        import matplotlib.patches  # type: ignore
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError(
            "matplotlib needs to be installed prior to running "
            "rustworkx.visualization.mpl_draw(). You can install "
            "matplotlib with:\n'pip install matplotlib'"
        ) from e

    if arrowstyle is None:
        if isinstance(graph, rustworkx.PyDiGraph) and arrows:
            arrowstyle = "-|>"
        else:
            arrowstyle = "-"

    if ax is None:
        ax = plt.gca()

    if edge_list is None:
        edge_list = graph.edge_list()

    if len(edge_list) == 0:  # no edges!
        return []

    if node_list is None:
        node_list = list(graph.node_indices())

    # FancyArrowPatch handles color=None different from LineCollection
    if edge_color is None:
        edge_color = "k"

    # Map each distinct (source position, target position) pair to the first
    # edge that produced it. Keeping the endpoints next to the positions means
    # the per-node size lookup below cannot drift out of step with edge_list
    # when duplicate (e.g. parallel) edges are collapsed.
    edge_endpoints = {}
    for edge in edge_list:
        source, target = edge[0], edge[1]
        edge_endpoints.setdefault((tuple(pos[source]), tuple(pos[target])), (source, target))
    num_edges = len(edge_endpoints)

    # Check if edge_color is an array of floats and map to edge_cmap.
    # This is the only case handled differently from matplotlib
    if (
        np.iterable(edge_color)
        and (len(edge_color) == num_edges)
        and np.all([isinstance(c, Number) for c in edge_color])
    ):
        if edge_cmap is not None:
            assert isinstance(edge_cmap, mpl.colors.Colormap)
        else:
            edge_cmap = plt.get_cmap()
        if edge_vmin is None:
            edge_vmin = min(edge_color)
        if edge_vmax is None:
            edge_vmax = max(edge_color)
        color_normal = mpl.colors.Normalize(vmin=edge_vmin, vmax=edge_vmax)
        edge_color = [edge_cmap(color_normal(e)) for e in edge_color]

    # Note: Waiting for someone to implement arrow to intersection with
    # marker.  Meanwhile, this works well for polygons with more than 4
    # sides and circle.
    def to_marker_edge(marker_size, marker):
        if marker in "s^>v<d":  # `large` markers need extra space
            return np.sqrt(2 * marker_size) / 2
        else:
            return np.sqrt(marker_size) / 2

    sizes_per_node = np.iterable(node_size)
    if sizes_per_node:
        # Position of each node in node_list (first occurrence wins, like
        # list.index), built once instead of scanning the list per edge.
        node_index = {}
        for idx, node in enumerate(node_list):
            node_index.setdefault(node, idx)
    else:
        uniform_shrink = to_marker_edge(node_size, node_shape)

    widths_per_edge = np.iterable(width)

    # Only two connection styles are ever needed; build them once.
    straight_style = _connectionstyle_with_rad(connectionstyle, 0.0)
    curved_style = _connectionstyle_with_rad(connectionstyle, 0.25)

    # Draw arrows with `matplotlib.patches.FancyarrowPatch`
    arrow_collection = []
    mutation_scale = arrow_size  # scale factor of arrow head

    # FancyArrowPatch doesn't handle color strings
    arrow_colors = mpl.colors.colorConverter.to_rgba_array(edge_color, alpha)

    for i, (edge, (source, target)) in enumerate(edge_endpoints.items()):
        x1, y1 = edge[0][0], edge[0][1]
        x2, y2 = edge[1][0], edge[1][1]

        # Space left between the node centres and the ends of the arrow.
        if sizes_per_node:
            try:
                source_idx = node_index[source]
                target_idx = node_index[target]
            except KeyError as e:
                raise ValueError(f"Node {e} is not in node_list.") from e
            shrink_source = to_marker_edge(node_size[source_idx], node_shape)
            shrink_target = to_marker_edge(node_size[target_idx], node_shape)
        else:
            shrink_source = shrink_target = uniform_shrink
        shrink_source = max(shrink_source, min_source_margin)
        shrink_target = max(shrink_target, min_target_margin)

        # A single color, one color per edge, or a shorter list that is
        # cycled through are all covered by the modulo.
        arrow_color = arrow_colors[i % len(arrow_colors)]
        line_width = width[i % len(width)] if widths_per_edge else width

        # Curve the edge when the opposite direction is drawn as well, so the
        # two arrows do not overlap.
        has_reverse = (edge[1], edge[0]) in edge_endpoints

        arrow = mpl.patches.FancyArrowPatch(
            (x1, y1),
            (x2, y2),
            arrowstyle=arrowstyle,
            shrinkA=shrink_source,
            shrinkB=shrink_target,
            mutation_scale=mutation_scale,
            color=arrow_color,
            linewidth=line_width,
            connectionstyle=curved_style if has_reverse else straight_style,
            linestyle=style,
            zorder=1,
        )  # arrows go behind nodes

        arrow_collection.append(arrow)
        ax.add_patch(arrow)

    # compute view: bounding box of all edge endpoints
    edge_pos = np.asarray(list(edge_endpoints))
    minx = np.amin(edge_pos[:, :, 0])
    maxx = np.amax(edge_pos[:, :, 0])
    miny = np.amin(edge_pos[:, :, 1])
    maxy = np.amax(edge_pos[:, :, 1])
    w = maxx - minx
    h = maxy - miny

    # update view, padding the bounding box by 5% on every side
    padx, pady = 0.05 * w, 0.05 * h
    corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady)
    ax.update_datalim(corners)
    ax.autoscale_view()

    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

    return arrow_collection
def draw_labels(
    graph,
    pos,
    labels=None,
    font_size=12,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    clip_on=True,
):
    """Place a text label on each node position.

    Parameters
    ----------
    graph: A rustworkx graph
    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.
    labels : dictionary, optional
        Text labels keyed by node. Every key must also be a key of `pos`.
        When omitted each node index of the graph is labelled with itself.
        Non-string labels are converted with ``str``.
    font_size : int (default=12)
        Font size for text labels
    font_color : string (default='k' black)
        Font color string
    font_family : string (default='sans-serif')
        Font family
    font_weight : string (default='normal')
        Font weight
    alpha : float or None (default=None)
        The text transparency
    bbox : Matplotlib bbox, optional
        Text box properties (e.g. shape, color etc.) passed to ``ax.text``.
    horizontalalignment : string (default='center')
        Horizontal alignment {'center', 'right', 'left'}
    verticalalignment : string (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline',
        'center_baseline'}
    ax : Matplotlib Axes object, optional
        Draw in the specified Matplotlib axes, or in the current pyplot axes
        when omitted.
    clip_on : bool (default=True)
        Turn on clipping of node labels at axis boundaries

    Returns
    -------
    dict
        The created Matplotlib text objects keyed on the nodes
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError(
            "matplotlib needs to be installed prior to running "
            "rustworkx.visualization.mpl_draw(). You can install "
            "matplotlib with:\n'pip install matplotlib'"
        ) from e

    if ax is None:
        ax = plt.gca()

    if labels is None:
        labels = {index: index for index in graph.node_indices()}

    # Matplotlib has no text collection, so the Text artists are gathered in
    # a plain dictionary.
    drawn = {}
    for node, caption in labels.items():
        (x, y) = pos[node]
        # Convert so that "1" and 1 end up labelled the same way.
        text = caption if isinstance(caption, str) else str(caption)
        drawn[node] = ax.text(
            x,
            y,
            text,
            size=font_size,
            color=font_color,
            family=font_family,
            weight=font_weight,
            alpha=alpha,
            horizontalalignment=horizontalalignment,
            verticalalignment=verticalalignment,
            transform=ax.transData,
            bbox=bbox,
            clip_on=clip_on,
        )

    # Hide ticks and tick labels on both axes.
    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

    return drawn
def draw_edge_labels(
    graph,
    pos,
    edge_labels=None,
    label_pos=0.5,
    font_size=10,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    rotate=True,
    clip_on=True,
):
    """Place a text label along each labelled edge.

    Parameters
    ----------
    graph: A rustworkx graph
    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.
    edge_labels : dictionary, optional
        Labels keyed by edge two-tuple; only these edges get a label. When
        omitted, every edge of the graph is labelled with its payload.
        Non-string labels are converted with ``str``.
    label_pos : float (default=0.5)
        Position of edge label along edge (0=head, 0.5=center, 1=tail)
    font_size : int (default=10)
        Font size for text labels
    font_color : string (default='k' black)
        Font color string
    font_family : string (default='sans-serif')
        Font family
    font_weight : string (default='normal')
        Font weight
    alpha : float or None (default=None)
        The text transparency
    bbox : Matplotlib bbox, optional
        Specify text box properties (e.g. shape, color etc.) for edge labels.
        Default is {boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)}.
    horizontalalignment : string (default='center')
        Horizontal alignment {'center', 'right', 'left'}
    verticalalignment : string (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline',
        'center_baseline'}
    ax : Matplotlib Axes object, optional
        Draw in the specified Matplotlib axes, or in the current pyplot axes
        when omitted.
    rotate : bool (default=True)
        Rotate edge labels to lie parallel to edges
    clip_on : bool (default=True)
        Turn on clipping of edge labels at axis boundaries

    Returns
    -------
    dict
        The created Matplotlib text objects keyed by edge
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError(
            "matplotlib needs to be installed prior to running "
            "rustworkx.visualization.mpl_draw(). You can install "
            "matplotlib with:\n'pip install matplotlib'"
        ) from e

    if ax is None:
        ax = plt.gca()

    if edge_labels is None:
        labels = {(u, v): payload for u, v, payload in graph.weighted_edge_list()}
    else:
        labels = edge_labels

    # use default box of white with white border
    if bbox is None:
        bbox = dict(boxstyle="round", ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0))

    drawn = {}
    for (n1, n2), caption in labels.items():
        (x1, y1) = pos[n1]
        (x2, y2) = pos[n2]
        x = x1 * label_pos + x2 * (1.0 - label_pos)
        y = y1 * label_pos + y2 * (1.0 - label_pos)

        # When the reverse edge is labelled too, shift the two labels apart
        # vertically so they do not sit on top of each other.
        if (n2, n1) in labels:
            dy = np.abs(y2 - y1)
            if n2 > n1:
                y -= 0.25 * dy
            else:
                y += 0.25 * dy

        if not rotate:
            trans_angle = 0.0
        else:
            # in degrees
            angle = np.arctan2(y2 - y1, x2 - x1) / (2.0 * np.pi) * 360
            # make label orientation "right-side-up"
            if angle > 90:
                angle -= 180
            elif angle < -90:
                angle += 180
            # transform data coordinate angle to screen coordinate angle
            anchor = np.array((x, y))
            trans_angle = ax.transData.transform_angles(
                np.array((angle,)), anchor.reshape((1, 2))
            )[0]

        # Convert so that "1" and 1 end up labelled the same way.
        text = caption if isinstance(caption, str) else str(caption)
        drawn[(n1, n2)] = ax.text(
            x,
            y,
            text,
            size=font_size,
            color=font_color,
            family=font_family,
            weight=font_weight,
            alpha=alpha,
            horizontalalignment=horizontalalignment,
            verticalalignment=verticalalignment,
            rotation=trans_angle,
            transform=ax.transData,
            bbox=bbox,
            zorder=1,
            clip_on=clip_on,
        )

    # Hide ticks and tick labels on both axes.
    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

    return drawn
def apply_alpha(colors, alpha, elem_list, cmap=None, vmin=None, vmax=None):
    """Combine colors and alpha value(s) into an array of RGBA colors.

    Parameters
    ----------
    colors : color string or array of floats
        Color of the elements. Either a single color, a sequence of colors,
        or a sequence of numbers with the same length as `elem_list`. Numbers
        are mapped to colors through `cmap`, `vmin` and `vmax`.
    alpha : float or array of floats
        Alpha values for elements. A single value is applied to every color.
        A sequence is applied to the colors in order, cycling through it as
        often as needed.
    elem_list : sequence
        The elements being colored; only its length is used.
    cmap : matplotlib colormap
        Color map used when `colors` holds numbers.
    vmin, vmax : float
        Minimum and maximum values for normalizing colors if a colormap is used

    Returns
    -------
    rgba_colors : numpy ndarray
        Array containing RGBA format values for the elements.
    """
    try:
        import matplotlib as mpl
        import matplotlib.colors  # call as mpl.colors
        import matplotlib.cm  # type: ignore
    except ImportError as e:
        raise ImportError(
            "matplotlib needs to be installed prior to running "
            "rustworkx.visualization.mpl_draw(). You can install "
            "matplotlib with:\n'pip install matplotlib'"
        ) from e

    num_elems = len(elem_list)

    if len(colors) == num_elems and isinstance(colors[0], Number):
        # One number per element: run them through the colormap.
        scalar_map = mpl.cm.ScalarMappable(cmap=cmap)
        scalar_map.set_clim(vmin, vmax)
        rgba_colors = scalar_map.to_rgba(colors)
    else:
        # Otherwise interpret `colors` as one color first and, failing that,
        # as a sequence of colors. ndarrays are used to stay consistent with
        # ScalarMappable.to_rgba.
        to_rgba = mpl.colors.colorConverter.to_rgba
        try:
            rgba_colors = np.array([to_rgba(colors)])
        except ValueError:
            rgba_colors = np.array([to_rgba(color) for color in colors])

    # Write the alpha value(s) into the last column.
    try:
        # Expand to one row per element when there are more alphas than
        # colors, or when the flat size equals the number of elements (which
        # scatter() would otherwise interpret as values for a colormap).
        # Every row then takes the RGB of the first color.
        if len(alpha) > len(rgba_colors) or rgba_colors.size == num_elems:
            rgba_colors = np.resize(rgba_colors, (num_elems, 4))
            rgba_colors[1:, :3] = rgba_colors[0, :3]
        rgba_colors[:, 3] = list(islice(cycle(alpha), len(rgba_colors)))
    except TypeError:
        # `alpha` is a single value (it has no len()).
        rgba_colors[:, -1] = alpha
    return rgba_colors