swarmplot with hue affecting marker beyond color
I'm trying to make my swarmplot easier to read in black and white and for people who are colorblind, by having the hue affect not just the color but also another geometric aspect of the marker.

MWE

import seaborn as sns
import matplotlib.pyplot as plt
sns.set(style="whitegrid")
tips = sns.load_dataset("tips")

fig, ax = plt.subplots(1,1)
ax = sns.swarmplot(x="day", y="total_bill", hue="sex",data=tips,size=8,ax=ax)
plt.show()

Result: (image not reproduced)

Desired result (the left one): (image not reproduced)

Sausauce asked 18/10/2018 at 16:45
I had actually thought about the same problem a while ago. I did not come up with a great solution, but I have a hack that works OK. Unfortunately, it is much easier to implement if you use dodge=True.

The idea is to collect the PathCollection objects created by swarmplot. With dodge=True you get N_cat*N_hues+N_hues collections (the N_hues extra ones are used to create the legend), so you can simply iterate through that list. Since all collections belonging to the same hue should get the same marker, we use a stride of N_hues to pick out the collections of each hue. After that, you are free to replace the paths of those collections with whatever Path object you choose. Refer to the documentation for Path to learn how to create paths.
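To see why a stride of N_hues selects a single hue, here is a toy sketch that uses plain strings in place of the PathCollection objects. It assumes the ordering described above (one collection per category and hue, category by category, followed by one legend collection per hue); that ordering reflects seaborn's internals at the time and may differ in other versions.

```python
# Stand-ins for ax.collections: one entry per (category, hue) pair in
# category-major order, followed by one legend entry per hue level.
categories = ["Thur", "Fri", "Sat", "Sun"]
hues = ["Male", "Female"]
n_hues = len(hues)

collections = [f"{cat}/{hue}" for cat in categories for hue in hues]
collections += [f"legend/{hue}" for hue in hues]

# A stride of n_hues starting at offset i picks every collection of hue i,
# including that hue's legend entry at the end.
by_hue = {hue: collections[i::n_hues] for i, hue in enumerate(hues)}

for hue, members in by_hue.items():
    print(hue, members)
```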

To simplify things, I created some dummy scatter plots beforehand to get some premade Paths that I can use. Of course, any Path object should work.

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
sns.set(style="whitegrid")
tips = sns.load_dataset("tips")

fig, ax = plt.subplots(1,1)
# dummy plots, just to get the Path objects
a = ax.scatter([1,2],[3,4], marker='s')
b = ax.scatter([1,2],[3,4], marker='^')
square_mk, = a.get_paths()
triangle_up_mk, = b.get_paths()
a.remove()
b.remove()

ax = sns.swarmplot(x="day", y="total_bill", hue="sex",data=tips,size=8,ax=ax, dodge=True)
N_hues = len(pd.unique(tips.sex))

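c = ax.collections
# collections are ordered category by category, hue by hue, so a stride of
# N_hues starting at 0 selects every collection of the first hue level...
for a in c[::N_hues]:
    a.set_paths([triangle_up_mk])
# ...and starting at 1 selects every collection of the second hue level
for a in c[1::N_hues]:
    a.set_paths([square_mk])
# update legend; reuse the labels seaborn attached so they stay in hue order
_, labels = ax.get_legend_handles_labels()
ax.legend(c[-N_hues:], labels)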

plt.show()

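The dummy-scatter trick is not the only way to get marker Paths. The sketch below builds them directly from matplotlib.markers.MarkerStyle, using the same get_path()/get_transform() calls that the later answers in this thread use, and also shows a hand-built matplotlib.path.Path. The helper name marker_path and the diamond shape are illustrative additions, not part of the original answer.

```python
import numpy as np
from matplotlib.markers import MarkerStyle
from matplotlib.path import Path


def marker_path(marker):
    """Return the Path for a named marker with the marker's own transform applied."""
    style = MarkerStyle(marker)
    return style.get_path().transformed(style.get_transform())


# Equivalent to the Paths harvested from the dummy scatter plots above.
triangle_up_mk = marker_path("^")
square_mk = marker_path("s")

# Any hand-built Path works too, e.g. a diamond in unit marker coordinates.
diamond_mk = Path(
    [(0.0, 0.5), (0.5, 0.0), (0.0, -0.5), (-0.5, 0.0), (0.0, 0.5)],
    [Path.MOVETO, Path.LINETO, Path.LINETO, Path.LINETO, Path.CLOSEPOLY],
)
```

Any of these can be passed to set_paths in place of the dummy-scatter Paths.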

UPDATE: a solution that "works" with dodge=False.

If you use dodge=False, you get N_cat+2 collections: one per category, plus two for the legend (one per hue level). The problem is that points of different colors are mixed together within each of these collections.

A possible, but ugly, solution is to loop through each element of every collection and build a list of Path objects based on the color of each element.

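import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
sns.set(style="whitegrid")
tips = sns.load_dataset("tips")

fig, ax = plt.subplots(1,1)
# dummy plots, just to get the Path objects (same as in the first snippet)
a = ax.scatter([1,2],[3,4], marker='s')
b = ax.scatter([1,2],[3,4], marker='^')
square_mk, = a.get_paths()
triangle_up_mk, = b.get_paths()
a.remove()
b.remove()

ax = sns.swarmplot(x="day", y="total_bill", hue="sex",data=tips,size=8,ax=ax, dodge=False)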

collections = ax.collections
unique_colors = np.unique(collections[0].get_facecolors(), axis=0)
markers = [triangle_up_mk, square_mk]  # this array must be at least as large as the number of unique colors
for collection in collections:
    paths = []
    for current_color in collection.get_facecolors():
        for possible_marker,possible_color in zip(markers, unique_colors):
            if np.array_equal(current_color,possible_color):
                paths.append(possible_marker)
                break
    collection.set_paths(paths)
# update legend; reuse the labels seaborn attached so they stay in hue order
_, labels = ax.get_legend_handles_labels()
ax.legend(collections[-2:], labels)

plt.show()

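The nested color-matching loop in the dodge=False code can also be vectorized with np.unique(..., return_inverse=True), which gives each point the index of its color among the sorted unique colors. This is a sketch on made-up RGBA values rather than real seaborn output; like the original, it compares colors exactly and assigns markers in np.unique's sorted color order.

```python
import numpy as np

# Toy RGBA face colors for one collection: two hue colors, interleaved.
blue = [0.2, 0.3, 0.7, 1.0]
orange = [0.9, 0.5, 0.2, 1.0]
facecolors = np.array([blue, orange, orange, blue, orange])

markers = ["^", "s"]  # needs at least as many entries as there are unique colors

# return_inverse maps every row to the index of its color in the sorted unique array.
unique_colors, color_index = np.unique(facecolors, axis=0, return_inverse=True)
color_index = color_index.ravel()  # guards against NumPy versions returning a non-1-D inverse
point_markers = [markers[i] for i in color_index]
print(point_markers)
```

With real collections you would index a list of Path objects instead of marker strings and pass the result to collection.set_paths.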

Martyr answered 20/10/2018 at 19:18
The following hack makes it easy to get different markers in swarmplots (or, more generally, in any categorical scatter plot). It can be used as is: just copy it to the top of an existing plotting script.

The idea is to link the color of each scatter point to a marker: every point automatically gets a marker from a specified list, chosen according to its color. As a consequence, this only works for plots in which the groups are distinguished by color.
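The core of the CM class in the code that follows is a rule that assigns markers to colors in the order the colors are first seen, cycling through the marker list when there are more colors than markers. A stripped-down sketch of just that rule, with no Matplotlib patching (the function name markers_for_colors is hypothetical):

```python
import numpy as np


def markers_for_colors(colors, markers, seen):
    """Assign a marker to each color row in first-seen order (hypothetical helper).

    `seen` is a list that persists across calls, so a color keeps its marker
    even when it reappears in a later scatter call.
    """
    result = []
    for color in np.asarray(colors):
        for idx, known in enumerate(seen):
            if np.array_equal(known, color):
                break
        else:  # color not seen before: register it
            seen.append(color)
            idx = len(seen) - 1
        result.append(markers[idx % len(markers)])
    return result


seen_colors = []
first = markers_for_colors([[1, 0, 0], [0, 0, 1], [1, 0, 0]], ["o", "^"], seen_colors)
second = markers_for_colors([[0, 0, 1], [0, 1, 0]], ["o", "^"], seen_colors)
print(first, second)
```

Keeping the seen colors outside the function mirrors why CM stores self.colors: separate scatter calls (one per category) still map the same hue color to the same marker.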

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

############## Begin hack ##############
class CM():
    def __init__(self, markers=["o"]):
        self.marker = np.array(markers)
        self.colors = []

    def get_markers_for_colors(self, c):
        for _co in c:
            if not any((_co == x).all() for x in self.colors):
                self.colors.append(_co)
        ind = np.array([np.where((self.colors == row).all(axis=1)) \
                        for row in c]).flatten()
        return self.marker[ind % len(self.marker)]

    def get_legend_handles(self, **kwargs):
        return [plt.Line2D([0],[0], ls="none", marker=m, color=c, mec="none", **kwargs) \
                for m,c in zip(self.marker, self.colors)]

from matplotlib.axes._axes import Axes
import matplotlib.markers as mmarkers
cm = CM(plt.Line2D.filled_markers)
old_scatter = Axes.scatter
def new_scatter(self, *args, **kwargs):
    sc = old_scatter(self, *args, **kwargs)
    c = kwargs.get("c", None)
    if isinstance(c, np.ndarray):
        m = cm.get_markers_for_colors(c)
        paths = []
        for _m in m:
            marker_obj = mmarkers.MarkerStyle(_m)
            paths.append(marker_obj.get_path().transformed(
                        marker_obj.get_transform()))
        sc.set_paths(paths)
    return sc

Axes.scatter = new_scatter
############## End hack. ##############
# Copy and paste into your file ######


## Code ###

sns.set(style="whitegrid")
tips = sns.load_dataset("tips")

fig, ax = plt.subplots(1,1)
## Optionally specify own markers:
#cm.marker = np.array(["^", "s"])
ax = sns.swarmplot(x="day", y="total_bill", hue="sex",data=tips,size=8,ax=ax)

## Optionally adjust legend:
_,l = ax.get_legend_handles_labels()
ax.legend(cm.get_legend_handles(markersize=8),l)

plt.show()

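Because this hack replaces Axes.scatter globally, every later scatter call in the same session goes through new_scatter. One way to limit that, offered here as an assumption rather than part of the original answer, is to install the patch only inside a context manager (the names patched_scatter and recording_wrapper are hypothetical) and restore the original method on exit:

```python
from contextlib import contextmanager

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.axes import Axes


@contextmanager
def patched_scatter(make_wrapper):
    """Temporarily replace Axes.scatter (hypothetical helper).

    make_wrapper receives the original method and returns the replacement.
    """
    original = Axes.scatter
    Axes.scatter = make_wrapper(original)
    try:
        yield
    finally:
        Axes.scatter = original  # restore even if plotting raises


calls = []


def recording_wrapper(old_scatter):
    """Example wrapper: record whether a color array was passed, then delegate."""
    def new_scatter(self, *args, **kwargs):
        calls.append(kwargs.get("c") is not None)
        return old_scatter(self, *args, **kwargs)
    return new_scatter


fig, ax = plt.subplots()
with patched_scatter(recording_wrapper):
    ax.scatter([1, 2], [3, 4], c=[0.1, 0.9])  # intercepted by new_scatter
ax.scatter([1, 2], [3, 4])  # after the block: original method, not recorded
print(calls)
```

To apply this to the hack, new_scatter would become the inner function of such a wrapper and sns.swarmplot would be called inside the with block.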

Syllogistic answered 22/10/2018 at 18:08
Thanks to @ImportanceOfBeingErnest for the solution. I tried to edit their answer to fix some minor issues, but in the end they suggested that I post my own answer.

This solution is the same as theirs, but it does not change the behavior of a normal scatter call when no marker list (co2mk) is passed. It is also simpler to apply, and it fixes the bug where the legend loses its title.
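The lost title comes from calling ax.legend(handles, labels), which builds a fresh legend without the previous title. Besides calling set_title afterwards (as fixlegend in the code does), the saved title can be passed back through legend's title keyword. A minimal sketch on made-up data, with the labels "Male"/"Female", the title "sex" and the markers chosen only to mirror the tips example:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

fig, ax = plt.subplots()
ax.scatter([1, 2], [3, 4], color="C0", label="Male")
ax.scatter([1, 2], [5, 6], color="C1", label="Female")
ax.legend(title="sex")

# Rebuild the legend with marker-specific proxy handles, keeping the old title.
old_title = ax.get_legend().get_title().get_text()
_, labels = ax.get_legend_handles_labels()
markers = ["o", "P"]
handles = [
    Line2D([0], [0], ls="none", marker=m, color=f"C{i}", mec="none", markersize=8)
    for i, m in enumerate(markers)
]
ax.legend(handles, labels, title=old_title)
print(ax.get_legend().get_title().get_text())
```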

The code below produces a two-panel figure with panels titled "Original" and "Hacked" (image not reproduced).

import seaborn as sns
import matplotlib.pyplot as plt

############## Begin hack ##############
from matplotlib.axes._axes import Axes
from matplotlib.markers import MarkerStyle
from seaborn import color_palette
from numpy import ndarray

def GetColor2Marker(markers):
    palette = color_palette()
    mkcolors = [(palette[i]) for i in range(len(markers))]
    return dict(zip(mkcolors,markers))

def fixlegend(ax,markers,markersize=8,**kwargs):
    # Fix Legend
    legtitle =  ax.get_legend().get_title().get_text()
    _,l = ax.get_legend_handles_labels()
    palette = color_palette()
    mkcolors = [(palette[i]) for i in range(len(markers))]
    newHandles = [plt.Line2D([0],[0], ls="none", marker=m, color=c, mec="none", markersize=markersize,**kwargs) \
                for m,c in zip(markers, mkcolors)]
    ax.legend(newHandles,l)
    leg = ax.get_legend()
    leg.set_title(legtitle)

old_scatter = Axes.scatter
def new_scatter(self, *args, **kwargs):
    colors = kwargs.get("c", None)
    co2mk = kwargs.pop("co2mk",None)
    FinalCollection = old_scatter(self, *args, **kwargs)
    if co2mk is not None and isinstance(colors, ndarray):
        Color2Marker = GetColor2Marker(co2mk)
        paths=[]
        for col in colors:
            mk=Color2Marker[tuple(col)]
            marker_obj = MarkerStyle(mk)
            paths.append(marker_obj.get_path().transformed(marker_obj.get_transform()))
        FinalCollection.set_paths(paths)
    return FinalCollection
Axes.scatter = new_scatter
############## End hack. ##############


# Example Test 
sns.set(style="whitegrid")
tips = sns.load_dataset("tips")

# To test robustness
tips.loc[(tips['sex']=="Male") & (tips['day']=="Fri"),'sex']='Female'
tips.loc[(tips['sex']=="Female") & (tips['day']=="Sat"),'sex']='Male'

Markers = ["o","P"]

fig, axs = plt.subplots(1,2,figsize=(14,5))
axs[0] = sns.swarmplot(x="day", y="total_bill", hue="sex",data=tips,size=8,ax=axs[0])
axs[0].set_title("Original")
axs[1] = sns.swarmplot(x="day", y="total_bill", hue="sex",data=tips,size=8,ax=axs[1],co2mk=Markers)
axs[1].set_title("Hacked")
fixlegend(axs[1],Markers)

plt.show()
Sausauce answered 24/10/2018 at 23:46

© 2022 - 2024 — McMap. All rights reserved.