How to use the `pos` argument in `networkx` to create a flowchart-style Graph? (Python 3)

I am trying to create a linear network graph in Python (preferably with matplotlib and networkx, although I would also be interested in bokeh), similar in concept to the one below.

[Image: concept sketch of the desired flowchart-style graph]

How can this plot be constructed efficiently in Python using networkx (via the `pos` argument?)? I want to use this for more complicated examples, so hard-coding the positions for this simple example won't be useful. Does networkx have a solution for this?

pos (dictionary, optional) – A dictionary with nodes as keys and positions as values. If not specified a spring layout positioning will be computed. See networkx.layout for functions that compute node positions.
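Since `pos` is just a plain dictionary from node to an `(x, y)` pair, one way to avoid hard-coding it is to compute it from an ordered list of layers. The helper `layered_pos` below is hypothetical (not part of networkx); it is a minimal sketch that places layer `i` at `x = i * dx` and centres each layer's nodes vertically around `y = 0`.

```python
def layered_pos(layers, dx=1.0, dy=1.0):
    """Hypothetical helper: map each node to (x, y) for a left-to-right layered plot.

    `layers` is a list of lists of node names, ordered from left to right.
    """
    pos = {}
    for i, layer in enumerate(layers):
        offset = (len(layer) - 1) / 2.0  # centre this column on y = 0
        for j, node in enumerate(layer):
            pos[node] = (i * dx, (offset - j) * dy)
    return pos


pos = layered_pos([["sample_0", "sample_1"], ["Group 1"]])
print(pos)
```

The result can then be passed straight to `nx.draw(G, pos=pos)`, provided every node in `G` appears in one of the layers.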

I haven't seen any tutorials on how to achieve this in networkx, which is why I believe this question will be a useful resource for the community. I've gone through the networkx tutorials extensively and nothing like this is covered there. The built-in networkx layouts make this type of network impossible to interpret without careful use of the pos argument, which I believe is my only option. None of the precomputed layouts in the https://networkx.github.io/documentation/networkx-1.9/reference/drawing.html documentation seem to handle this type of network structure well.

Simple Example:

(A) Each outer key is an iteration in the graph, moving from left to right (e.g. iteration 0 represents the samples, iteration 1 has Groups 1-3, as does iteration 2, iteration 3 has Groups 1-2, etc.). (B) The inner dictionary maps each group at that iteration to the weights of the previous iteration's groups that merged to form it. For example, iteration 3 has Group 1 and Group 2; in iteration 4, all of iteration 3's Group 2 has gone into iteration 4's Group 2, but iteration 3's Group 1 has been split up. The weights for each current group always sum to 1.

My data for the connections and their weights in the plot above:

D_iter_current_previous =    {
        1: {
            "Group 1":{"sample_0":0.5, "sample_1":0.5, "sample_2":0, "sample_3":0, "sample_4":0},
            "Group 2":{"sample_0":0, "sample_1":0, "sample_2":1, "sample_3":0, "sample_4":0},
            "Group 3":{"sample_0":0, "sample_1":0, "sample_2":0, "sample_3":0.5, "sample_4":0.5}
            },
        2: {
            "Group 1":{"Group 1":1, "Group 2":0, "Group 3":0},
            "Group 2":{"Group 1":0, "Group 2":1, "Group 3":0},
            "Group 3":{"Group 1":0, "Group 2":0, "Group 3":1}
            },
        3: {
            "Group 1":{"Group 1":0.25, "Group 2":0, "Group 3":0.75},
            "Group 2":{"Group 1":0.25, "Group 2":0.75, "Group 3":0}
            },
        4: {
            "Group 1":{"Group 1":1, "Group 2":0},
            "Group 2":{"Group 1":0.25, "Group 2":0.75}
            }
        }
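Because the description above says the incoming weights of every current group sum to 1, a quick sanity check on this structure can catch typos before plotting. A minimal sketch, using an excerpt of the dictionary (iterations 3 and 4 only); `unnormalised_groups` is a hypothetical helper name.

```python
import math

# Excerpt of D_iter_current_previous: iterations 3 and 4 from the question
D_excerpt = {
    3: {
        "Group 1": {"Group 1": 0.25, "Group 2": 0, "Group 3": 0.75},
        "Group 2": {"Group 1": 0.25, "Group 2": 0.75, "Group 3": 0},
    },
    4: {
        "Group 1": {"Group 1": 1, "Group 2": 0},
        "Group 2": {"Group 1": 0.25, "Group 2": 0.75},
    },
}


def unnormalised_groups(d, tol=1e-9):
    """Return (iteration, group, total) for each current group whose weights do not sum to 1."""
    bad = []
    for iter_n, current in d.items():
        for group, previous_weights in current.items():
            total = sum(previous_weights.values())
            if not math.isclose(total, 1.0, abs_tol=tol):
                bad.append((iter_n, group, total))
    return bad


print(unnormalised_groups(D_excerpt))  # -> []
```

Note that only the incoming weights are normalised: iteration 3's Group 1 sends a total of 1.25 onward (1 to Group 1 and 0.25 to Group 2 in iteration 4).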

This is what I got when I built the graph in networkx:

import networkx as nx
import matplotlib.pyplot as plt

# Create Directed Graph
G = nx.DiGraph()

# Iterate through all connections
for iter_n, D_current_previous in D_iter_current_previous.items():
    for current_group, D_previous_weights in D_current_previous.items():
        for previous_group, weight in D_previous_weights.items():
            if weight > 0:
                # Define connections using `|__|` as a delimiter for the names
                previous_node = "%d|__|%s"%(iter_n - 1, previous_group)
                current_node = "%d|__|%s"%(iter_n, current_group)
                connection = (previous_node, current_node)
                G.add_edge(*connection, weight=weight)

# Draw Graph with labels and width thickness
nx.draw(G, with_labels=True, width=[G[u][v]['weight'] for u,v in G.edges()])

[Image: output of the code above, drawn with the default spring layout]
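Since the node names above already encode the iteration before the `|__|` delimiter, `pos` can be derived from the names instead of being hard-coded. A minimal sketch, using only iterations 3 and 4 of the data: x is the iteration, and nodes within a column are ordered by sorting their names (an assumption; plain string sorting would put "Group 10" before "Group 2").

```python
import networkx as nx

# Excerpt of the question's data (iterations 3 and 4 only)
D = {
    3: {"Group 1": {"Group 1": 0.25, "Group 2": 0, "Group 3": 0.75},
        "Group 2": {"Group 1": 0.25, "Group 2": 0.75, "Group 3": 0}},
    4: {"Group 1": {"Group 1": 1, "Group 2": 0},
        "Group 2": {"Group 1": 0.25, "Group 2": 0.75}},
}

# Same graph construction as in the question
G = nx.DiGraph()
for iter_n, current in D.items():
    for current_group, previous_weights in current.items():
        for previous_group, weight in previous_weights.items():
            if weight > 0:
                G.add_edge(f"{iter_n - 1}|__|{previous_group}",
                           f"{iter_n}|__|{current_group}", weight=weight)

# Collect the group names per iteration, using the part before "|__|"
columns = {}
for node in G.nodes:
    iter_n, name = node.split("|__|")
    columns.setdefault(int(iter_n), []).append(name)

pos = {}
for iter_n, names in columns.items():
    for rank, name in enumerate(sorted(names)):
        # x = iteration; y = negative rank, so "Group 1" ends up at the top
        pos[f"{iter_n}|__|{name}"] = (iter_n, -rank)
```

Drawing then works as before, e.g. `nx.draw(G, pos, with_labels=True, width=[G[u][v]['weight'] for u, v in G.edges()])`.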

Note: The only other approach I could think of would be to create a scatter plot in matplotlib, with each tick representing an iteration (5 including the initial samples), and then connect the points to each other with lines of different weights. This would be pretty messy code, especially lining up the edges of the markers with the connections. However, I'm not sure whether this or networkx is the best way to do it, or whether there is a tool (e.g. bokeh or plotly) that is designed for this type of plot.

Posted by Reive on 1/10, 2016 at 0:39. Comments (3):
Inanimate: Did you have a look at the networkx tutorial? Plotting networkx graphs in matplotlib is really easy. At what point are you having problems?
Reive: I know how to use networkx, but the layouts turn this into a randomized mess. I could use the pos argument, but customizing the positional dictionary is awkward. I don't think there are any tutorials on linear networks.
Nebulous: I think there's a way to do this using the graphviz layouts, but I'm not sure how that interacts with the more recent versions of networkx (the API has changed a little). If I have some time, I'll try to get a working install up and running. In the meantime, try looking at this answer: the prog='dot' argument is crucial, and may well get you where you want to go. Note that you need to install graphviz and pygraphviz in addition to networkx and matplotlib.
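A sketch of the Graphviz route suggested in the last comment, assuming pygraphviz and the Graphviz binaries are installed. `nx.nx_agraph.graphviz_layout` accepts `prog` and an `args` string; `-Grankdir=LR` is Graphviz's graph attribute for laying ranks out left to right. The helper name `dot_layout` and the fallback to `multipartite_layout` over topological generations are my own additions, not part of the comment.

```python
import networkx as nx

G = nx.DiGraph()
G.add_edges_from([("sample_0", "Group 1.1"), ("sample_1", "Group 1.1"),
                  ("sample_2", "Group 2.1"), ("Group 1.1", "Group 1.2"),
                  ("Group 2.1", "Group 2.2")])


def dot_layout(graph):
    """Hypothetical helper: left-to-right hierarchical layout via Graphviz 'dot'."""
    try:
        # Needs pygraphviz and the 'dot' executable; rankdir=LR puts ranks in columns
        return nx.nx_agraph.graphviz_layout(graph, prog="dot", args="-Grankdir=LR")
    except (ImportError, ValueError, OSError):
        # Assumed fallback: use each node's topological generation as its column
        for layer, nodes in enumerate(nx.topological_generations(graph)):
            for node in nodes:
                graph.nodes[node]["layer"] = layer
        return nx.multipartite_layout(graph, subset_key="layer")


pos = dot_layout(G)
```

Either way the result is a `pos` dictionary that can be passed to `nx.draw(G, pos, with_labels=True)`.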

Networkx has decent plotting facilities for exploratory data analysis, but it is not the tool for making publication-quality figures, for various reasons that I don't want to go into here. I therefore rewrote that part of the code base from scratch and made a stand-alone drawing module called netgraph, which can be found here (like the original, it is purely based on matplotlib). The API is very similar and well documented, so it should not be too hard to adapt to your purposes.

Building on that I get the following result:

[Image: resulting layered network drawn with netgraph]

I chose colour to denote edge strength because it lets you (1) indicate negative values and (2) distinguish small values better. However, you can also pass an edge width to netgraph instead (see netgraph.draw_edges()).
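The colour-coding idea is not specific to netgraph. As a rough networkx-only equivalent, `nx.draw_networkx_edges` accepts a list of numbers as `edge_color` and maps them through `edge_cmap` between `edge_vmin` and `edge_vmax`. The toy graph and positions below are made up for illustration; the width scaling `1 + 4 * w` is an arbitrary choice.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the sketch runs without a display
import matplotlib.pyplot as plt
import networkx as nx

G = nx.DiGraph()
G.add_weighted_edges_from([("a", "c", 0.25), ("b", "c", 0.75), ("c", "d", 1.0)])
pos = {"a": (0, 1), "b": (0, -1), "c": (1, 0), "d": (2, 0)}

# Weights in the same order that draw_networkx_edges iterates G.edges()
weights = [w for _, _, w in G.edges(data="weight")]

fig, ax = plt.subplots()
nx.draw_networkx_nodes(G, pos, ax=ax)
nx.draw_networkx_labels(G, pos, ax=ax)
# Colour encodes the weight on a fixed [0, 1] scale; width encodes it as well
nx.draw_networkx_edges(G, pos, ax=ax, edge_color=weights,
                       edge_cmap=plt.cm.viridis, edge_vmin=0.0, edge_vmax=1.0,
                       width=[1 + 4 * w for w in weights])
```

In an interactive session, drop the `matplotlib.use("Agg")` line and call `plt.show()` at the end.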

The different order of the branches is a result of your data structure (a dict), which indicates no inherent order. You would have to amend your data structure and the function _parse_input() below to fix that issue.
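On Python 3.7+ plain dicts preserve insertion order, so in `_parse_input()` below much of the shuffling comes from the `set` used to collect the initial sources, whose iteration order is arbitrary. A possible order-preserving replacement for that block, shown on a trimmed excerpt of the data:

```python
# Trimmed excerpt of iteration 1 from the question
input_dict = {
    1: {
        "Group 1": {"sample_0": 0.5, "sample_1": 0.5, "sample_2": 0},
        "Group 2": {"sample_0": 0, "sample_1": 0, "sample_2": 1},
    }
}

# dict.fromkeys de-duplicates while keeping first-seen order, unlike set()
sources = list(dict.fromkeys(s for v in input_dict[1].values() for s in v))
print(sources)  # ['sample_0', 'sample_1', 'sample_2']
```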

Code:

import itertools
import numpy as np
import matplotlib.pyplot as plt
import netgraph

def plot_layered_network(weight_matrices,
                         distance_between_layers=2,
                         distance_between_nodes=1,
                         layer_labels=None,
                         **kwargs):
    """
    Convenience function to plot layered network.

    Arguments:
    ----------
        weight_matrices: [w1, w2, ..., wn]
            list of weight matrices defining the connectivity between layers;
            each weight matrix is a 2-D ndarray with rows indexing source and columns indexing targets;
            the number of sources has to match the number of targets in the last layer

        distance_between_layers: int

        distance_between_nodes: int

        layer_labels: [str1, str2, ..., strn+1]
            labels of layers

        **kwargs: passed to netgraph.draw()

    Returns:
    --------
        ax: matplotlib axis instance

    """
    nodes_per_layer = _get_nodes_per_layer(weight_matrices)

    node_positions = _get_node_positions(nodes_per_layer,
                                         distance_between_layers,
                                         distance_between_nodes)

    w = _combine_weight_matrices(weight_matrices, nodes_per_layer)

    ax = netgraph.draw(w, node_positions, **kwargs)

    if layer_labels is not None:
        ax.set_xticks(distance_between_layers*np.arange(len(weight_matrices)+1))
        ax.set_xticklabels(layer_labels)
        ax.xaxis.set_ticks_position('bottom')

    return ax

def _get_nodes_per_layer(weight_matrices):
    nodes_per_layer = []
    for w in weight_matrices:
        sources, targets = w.shape
        nodes_per_layer.append(sources)
    nodes_per_layer.append(targets)
    return nodes_per_layer

def _get_node_positions(nodes_per_layer,
                        distance_between_layers,
                        distance_between_nodes):
    x = []
    y = []
    for ii, n in enumerate(nodes_per_layer):
        x.append(distance_between_nodes * np.arange(0., n))
        y.append(ii * distance_between_layers * np.ones((n)))
    x = np.concatenate(x)
    y = np.concatenate(y)
    return np.c_[y,x]

def _combine_weight_matrices(weight_matrices, nodes_per_layer):
    total_nodes = np.sum(nodes_per_layer)
    w = np.full((total_nodes, total_nodes), np.nan, float)  # NaN everywhere outside the inter-layer blocks

    a = 0
    b = nodes_per_layer[0]
    for ii, ww in enumerate(weight_matrices):
        w[a:a+ww.shape[0], b:b+ww.shape[1]] = ww
        a += nodes_per_layer[ii]
        b += nodes_per_layer[ii+1]

    return w

def test():
    w1 = np.random.rand(4,5) #< 0.50
    w2 = np.random.rand(5,6) #< 0.25
    w3 = np.random.rand(6,3) #< 0.75

    import string
    node_labels = dict(zip(range(18), list(string.ascii_lowercase)))

    fig, ax = plt.subplots(1,1)
    plot_layered_network([w1,w2,w3],
                         layer_labels=['start', 'step 1', 'step 2', 'finish'],
                         ax=ax,
                         node_size=20,
                         node_edge_width=2,
                         node_labels=node_labels,
                         edge_width=5,
    )
    plt.show()
    return

def test_example(input_dict):
    weight_matrices, node_labels = _parse_input(input_dict)
    fig, ax = plt.subplots(1,1)
    plot_layered_network(weight_matrices,
                         layer_labels=['', '1', '2', '3', '4'],
                         distance_between_layers=10,
                         distance_between_nodes=8,
                         ax=ax,
                         node_size=300,
                         node_edge_width=10,
                         node_labels=node_labels,
                         edge_width=50,
    )
    plt.show()
    return

def _parse_input(input_dict):
    weight_matrices = []
    node_labels = []

    # initialise sources
    sources = set()
    for v in input_dict[1].values():
        for s in v.keys():
            sources.add(s)
    sources = list(sources)

    for ii in range(len(input_dict)):
        inner_dict = input_dict[ii+1]
        targets = list(inner_dict.keys())

        w = np.full((len(sources), len(targets)), np.nan, float)
        for kk, s in enumerate(sources):
            for jj, t in enumerate(targets):
                try:
                    w[kk,jj] = inner_dict[t][s]
                except KeyError:
                    pass

        weight_matrices.append(w)
        node_labels.append(sources)
        sources = targets

    node_labels.append(targets)
    node_labels = list(itertools.chain.from_iterable(node_labels))
    node_labels = dict(enumerate(node_labels))

    return weight_matrices, node_labels

# --------------------------------------------------------------------------------
# script
# --------------------------------------------------------------------------------

if __name__ == "__main__":

    # test()

    input_dict =   {
        1: {
            "Group 1":{"sample_0":0.5, "sample_1":0.5, "sample_2":0, "sample_3":0, "sample_4":0},
            "Group 2":{"sample_0":0, "sample_1":0, "sample_2":1, "sample_3":0, "sample_4":0},
            "Group 3":{"sample_0":0, "sample_1":0, "sample_2":0, "sample_3":0.5, "sample_4":0.5}
            },
        2: {
            "Group 1":{"Group 1":1, "Group 2":0, "Group 3":0},
            "Group 2":{"Group 1":0, "Group 2":1, "Group 3":0},
            "Group 3":{"Group 1":0, "Group 2":0, "Group 3":1}
            },
        3: {
            "Group 1":{"Group 1":0.25, "Group 2":0, "Group 3":0.75},
            "Group 2":{"Group 1":0.25, "Group 2":0.75, "Group 3":0}
            },
        4: {
            "Group 1":{"Group 1":1, "Group 2":0},
            "Group 2":{"Group 1":0.25, "Group 2":0.75}
            }
        }

    test_example(input_dict)

    pass
Lumbago answered 5/10, 2016 at 0:0. Comments (10):
Reive: Hey, thanks for this, I want to try it out. Can you install netgraph with conda or pip?
Lumbago: @Reive I haven't bothered setting it up for that yet, and I have to hand in my PhD in 3 weeks, so I won't until then. Just download the .py file (it is just one file) and drop it in your working directory for the time being. It only relies on matplotlib, so as long as you have that installed everything should work.
Reive: Cool, I'll try it out in a few! Good luck with your dissertation.
Lumbago: Just don't forget to accept the answer when you have -- I procrastinated -- I mean: worked -- hard for those imaginary internet points. ;-)
Reive: Just accepted it now :/ Sorry, I thought accepting the bounty did both.
Reive: btw, I looked at your source code: kwargs.setdefault('edge_color', weights). I had no idea you could do that...
Lumbago: It's a neat dictionary trick. Also check out dictionary.get(key, default).
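For reference, the two dictionary methods mentioned above: `dict.setdefault(key, default)` inserts the default only when the key is missing (handy for filling in keyword arguments a caller did not pass), while `dict.get(key, default)` reads without modifying the dict. `draw_options` is a hypothetical function for illustration.

```python
def draw_options(**kwargs):
    # Only sets edge_color if the caller did not pass one
    kwargs.setdefault("edge_color", "black")
    return kwargs


print(draw_options())                  # {'edge_color': 'black'}
print(draw_options(edge_color="red"))  # {'edge_color': 'red'}

settings = {"node_size": 300}
print(settings.get("node_size", 100))  # 300
print(settings.get("edge_width", 1))   # 1 (key missing, so the default is returned)
print(settings)                        # {'node_size': 300}, get() did not add a key
```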
Pegasus: Your code sample does not work with the current pip version of netgraph. The API seems to have changed: in ax = netgraph.draw(w, node_positions, **kwargs), node_positions is an np.array but should be a dict. Could you update the code?
Lumbago: @moritzschaefer: I try to maintain my code and my code documentation. I don't have the time to maintain all my Stack Overflow answers as well. netgraph is on a new major version and the API has changed (slightly) in a backwards-incompatible way. As described in the docs, node_positions is now a dictionary mapping node IDs to 2-tuples of floats (or equivalent). If you provide a weight matrix, the node IDs have to be integers corresponding to the indices of the matrix (i.e. starting at zero). Happy to help if you post a new question with an MWE; happy to accept edits to this answer.
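Following the format described in the comment above, the position array returned by `_get_node_positions()` (one `(x, y)` row per node, rows in weight-matrix order) can be converted into a dict keyed by integer node IDs starting at zero. A minimal sketch of just the conversion; whether the rest of the answer's code still matches the current netgraph API is not verified here.

```python
import numpy as np

# Same shape as _get_node_positions() returns: one (x, y) row per node
positions_array = np.c_[[0.0, 0.0, 2.0], [0.0, 1.0, 0.0]]

# Node ID = row index (= weight-matrix index); value = 2-tuple of floats
node_positions = {ii: (float(x), float(y)) for ii, (x, y) in enumerate(positions_array)}
print(node_positions)  # {0: (0.0, 0.0), 1: (0.0, 1.0), 2: (2.0, 0.0)}
```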
Somatic: Why is an answer written in netgraph getting the highest vote on a networkx question? It's completely irrelevant.
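For readers who want to stay within networkx, the layout idea in `plot_layered_network()` above (one column per layer, nodes numbered layer by layer) can be reproduced without netgraph. A minimal sketch: `layered_digraph` is a hypothetical helper that takes the same list of weight matrices (rows = sources, columns = targets) and returns a `DiGraph` plus a `pos` dict; zero and NaN entries are treated as "no edge", which is an assumption.

```python
import numpy as np
import networkx as nx


def layered_digraph(weight_matrices):
    """Build a DiGraph and a pos dict from [w1, ..., wn] (rows = sources, cols = targets)."""
    sizes = [weight_matrices[0].shape[0]] + [w.shape[1] for w in weight_matrices]
    starts = np.concatenate([[0], np.cumsum(sizes)])  # first node ID of each layer
    G = nx.DiGraph()
    pos = {}
    for layer, n in enumerate(sizes):
        for k in range(n):
            node = int(starts[layer]) + k
            G.add_node(node, layer=layer)
            pos[node] = (layer, -k)  # x = layer index, nodes stacked downwards
    for layer, w in enumerate(weight_matrices):
        # NaN -> 0, then keep only non-zero weights as edges
        for i, j in zip(*np.nonzero(np.nan_to_num(w))):
            G.add_edge(int(starts[layer]) + int(i), int(starts[layer + 1]) + int(j),
                       weight=float(w[i, j]))
    return G, pos


w1 = np.array([[0.5, 0.0], [0.5, 1.0]])
w2 = np.array([[1.0], [0.0]])
G, pos = layered_digraph([w1, w2])
```

The result can be drawn with `nx.draw(G, pos, width=[3 * G[u][v]['weight'] for u, v in G.edges()])`.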

There is a way to do this using multipartite_layout() in NetworkX. You can also plot the edges and edge labels separately to get the weight-dependent edge opacity shown in your concept sketch.

The following reprex, which is available at this link, demonstrates how to create the multipartite weighted directed graph (flowchart) using only matplotlib and NetworkX.

import networkx as nx
import matplotlib.pyplot as plt

D_iter_current_previous =    {
        1: {
            "Group 1":{"sample_0":0.5, "sample_1":0.5, "sample_2":0, "sample_3":0, "sample_4":0},
            "Group 2":{"sample_0":0, "sample_1":0, "sample_2":1, "sample_3":0, "sample_4":0},
            "Group 3":{"sample_0":0, "sample_1":0, "sample_2":0, "sample_3":0.5, "sample_4":0.5}
            },
        2: {
            "Group 1":{"Group 1":1, "Group 2":0, "Group 3":0},
            "Group 2":{"Group 1":0, "Group 2":1, "Group 3":0},
            "Group 3":{"Group 1":0, "Group 2":0, "Group 3":1}
            },
        3: {
            "Group 1":{"Group 1":0.25, "Group 2":0, "Group 3":0.75},
            "Group 2":{"Group 1":0.25, "Group 2":0.75, "Group 3":0}
            },
        4: {
            "Group 1":{"Group 1":1, "Group 2":0},
            "Group 2":{"Group 1":0.25, "Group 2":0.75}
            }
        }

# Create a NetworkX directed graph
g = nx.DiGraph()
g.add_nodes_from(['sample_0', 'sample_1', 'sample_2', 'sample_3', 'sample_4'], subset=0)
g.add_nodes_from(['Group 1.1', 'Group 2.1', 'Group 3.1'], subset=1)
g.add_nodes_from(['Group 1.2', 'Group 2.2', 'Group 3.2'], subset=2)
g.add_nodes_from(['Group 1.3', 'Group 2.3'], subset=3)
g.add_nodes_from(['Group 1.4', 'Group 2.4'], subset=4)

# Add Title to Plot
plt.title('Multipartite Weighted Directed Graph')

# Create a list of the edges and alphas (opacity)
edges = []
alphas = []
for subset, subset_stuff in D_iter_current_previous.items():
    for node, prev_nodes in subset_stuff.items():
        for prev_node, weight in prev_nodes.items():
            if subset > 1:
                edges.append((f'{prev_node}.{subset-1}', f'{node}.{subset}'))
                alphas.append(weight)
            else:
                edges.append((f'{prev_node}', f'{node}.{subset}'))
                alphas.append(weight)

# Make a dict of the edge labels
edge_labels = dict(zip(edges, alphas))

# Draw the nodes with the multipartite layout
pos = nx.multipartite_layout(g, align='vertical')
nx.draw(g, pos, with_labels=True, node_size=1500, font_size=8)

# Draw the edges with their corresponding alphas
nx.draw_networkx_edges(g, pos, edgelist=edges, alpha=alphas, arrows=True, node_size=1500)

# Draw the edge labels with their corresponding alphas
for edge, alpha in edge_labels.items():
    if alpha > 0:
        nx.draw_networkx_edge_labels(g, pos, edge_labels={edge: float(alpha)}, alpha=float(alpha), font_size=7)

# Show the plot
plt.show()

[Image: resulting NetworkX graph]
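The node lists above are typed out by hand; since the question wants to avoid hard-coding, the nodes and their `subset` attributes can instead be derived from the dictionary itself, using the same naming scheme (samples keep their names, groups get `.<iteration>` appended). A minimal sketch on iterations 1 and 2 of the data; `node_name` is a hypothetical helper, and unlike the answer it only adds edges with positive weight.

```python
import networkx as nx

# Excerpt of D_iter_current_previous (iterations 1 and 2)
D = {
    1: {
        "Group 1": {"sample_0": 0.5, "sample_1": 0.5, "sample_2": 0, "sample_3": 0, "sample_4": 0},
        "Group 2": {"sample_0": 0, "sample_1": 0, "sample_2": 1, "sample_3": 0, "sample_4": 0},
        "Group 3": {"sample_0": 0, "sample_1": 0, "sample_2": 0, "sample_3": 0.5, "sample_4": 0.5},
    },
    2: {
        "Group 1": {"Group 1": 1, "Group 2": 0, "Group 3": 0},
        "Group 2": {"Group 1": 0, "Group 2": 1, "Group 3": 0},
        "Group 3": {"Group 1": 0, "Group 2": 0, "Group 3": 1},
    },
}


def node_name(name, iteration):
    # Samples (iteration 0) keep their own name; groups get ".<iteration>" appended
    return name if iteration == 0 else f"{name}.{iteration}"


g = nx.DiGraph()
for iteration, current in D.items():
    for node, prev_nodes in current.items():
        g.add_node(node_name(node, iteration), subset=iteration)
        for prev_node, weight in prev_nodes.items():
            g.add_node(node_name(prev_node, iteration - 1), subset=iteration - 1)
            if weight > 0:
                g.add_edge(node_name(prev_node, iteration - 1),
                           node_name(node, iteration), weight=weight)

pos = nx.multipartite_layout(g, subset_key="subset", align="vertical")
```

The same drawing calls as above then apply, with `edgelist=list(g.edges())` and the alphas taken from each edge's `weight` attribute.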

Enthronement answered 23/3, 2024 at 21:26. Comments (0).
