I had actually thought about the same problem a while ago. I didn't come up with the greatest of solutions, but I have a hack that works OK. Unfortunately, it's much easier to implement if you use `dodge=True`.
The idea is to collect the `PathCollection` objects created by `swarmplot`. If `dodge=True`, you'll get `N_cat*N_hues + N_hues` collections (the `N_hues` extra ones are used to create the legend). You can simply iterate through that list. Since we want every point of a given hue to share the same marker, we use a stride of `N_hues` to get all the collections corresponding to each hue. After that, you are free to set the paths of each collection to whatever `Path` object you choose. Refer to the documentation for `Path` to learn how to create paths.
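To make the stride concrete, here is a small pure-Python sketch of the index layout. It assumes exactly the ordering described above (category by category, one collection per hue, then one legend collection per hue). For `tips` there are 4 days and 2 sexes, so `n_cat=4` and `n_hues=2`. The helper name `hue_groups` is hypothetical.

```python
def hue_groups(n_cat, n_hues):
    """Return, for each hue index, the collection indices belonging to it.

    Assumes the dodge=True layout: collections ordered category by
    category with one collection per hue, followed by n_hues legend
    collections.
    """
    total = n_cat * n_hues + n_hues
    indices = list(range(total))
    # Starting at hue_idx and stepping by n_hues picks that hue in every
    # category, and finally its legend collection.
    return [indices[hue_idx::n_hues] for hue_idx in range(n_hues)]


groups = hue_groups(n_cat=4, n_hues=2)
print(groups)
```

Each group ends with that hue's legend collection, which is why striding through `ax.collections` also updates the legend markers.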
To simplify things, I created some dummy scatter plots beforehand to get some premade `Path`s that I can use. Of course, any `Path` should work.
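If you'd rather skip the dummy scatter plots, a sketch like the following builds the same kind of marker `Path` directly. It relies on `Axes.scatter` converting its `marker` argument via `MarkerStyle(marker).get_path().transformed(MarkerStyle(marker).get_transform())`, which is what current matplotlib versions do, but treat that as an implementation detail. `marker_path` and `diamond_mk` are hypothetical names for illustration.

```python
from matplotlib.markers import MarkerStyle
from matplotlib.path import Path


def marker_path(marker):
    """Return a Path for `marker`, built the way ax.scatter builds it:
    the MarkerStyle path transformed by the MarkerStyle transform."""
    style = MarkerStyle(marker)
    return style.get_path().transformed(style.get_transform())


triangle_up_mk = marker_path("^")
square_mk = marker_path("s")

# Any Path works as well, e.g. a hand-made diamond in marker units.
# closed=True makes the last code CLOSEPOLY.
diamond_mk = Path([(0, 0.5), (0.5, 0), (0, -0.5), (-0.5, 0), (0, 0.5)], closed=True)
```

These objects can be passed to `set_paths([...])` exactly like the ones pulled out of the dummy scatter plots.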
    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt

    sns.set(style="whitegrid")
    tips = sns.load_dataset("tips")
    fig, ax = plt.subplots(1, 1)

    # dummy plots, just to get the Path objects
    a = ax.scatter([1, 2], [3, 4], marker='s')
    b = ax.scatter([1, 2], [3, 4], marker='^')
    square_mk, = a.get_paths()
    triangle_up_mk, = b.get_paths()
    a.remove()
    b.remove()

    ax = sns.swarmplot(x="day", y="total_bill", hue="sex", data=tips, size=8, ax=ax, dodge=True)
    N_hues = len(pd.unique(tips.sex))
    c = ax.collections
    # one collection per (category, hue), then N_hues legend collections
    for a in c[::N_hues]:   # first hue, including its legend entry
        a.set_paths([triangle_up_mk])
    for a in c[1::N_hues]:  # second hue, including its legend entry
        a.set_paths([square_mk])

    # update legend; reuse the labels seaborn attached to the legend
    # artists so they stay in hue order
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles, labels)
    plt.show()
UPDATE: A solution that "works" with `dodge=False`.
If you use `dodge=False`, you'll get `N_cat + N_hues` collections: one per category, plus one per hue (2 here) for the legend. The problem is that the points of all the different colors are jumbled together within each category's collection.
A possible, but ugly, solution is to loop through each element of a collection and build a list of `Path` objects based on the color of each element.
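The inner matching loop can be pulled out into a small function. This is a sketch, not the exact code of this answer: plain tuples stand in for the RGBA rows returned by `get_facecolors()` and strings stand in for `Path` objects. Unlike a silent skip, it raises if a point's color has no marker, so the returned list always lines up one-to-one with the points. `paths_by_color` is a hypothetical name.

```python
def paths_by_color(facecolors, unique_colors, markers):
    """Pick one marker per point by matching its face color.

    facecolors: one RGBA color per point in a collection.
    unique_colors, markers: parallel sequences; markers[i] is used for
    points whose color equals unique_colors[i] exactly.
    """
    paths = []
    for color in facecolors:
        for marker, candidate in zip(markers, unique_colors):
            if tuple(color) == tuple(candidate):
                paths.append(marker)
                break
        else:
            # No break: this color has no assigned marker.
            raise ValueError(f"no marker for color {tuple(color)}")
    return paths
```

Inside a loop over `ax.collections`, it would be used as `collection.set_paths(paths_by_color(collection.get_facecolors(), unique_colors, markers))`.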
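    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt

    sns.set(style="whitegrid")
    tips = sns.load_dataset("tips")
    fig, ax = plt.subplots(1, 1)

    # dummy plots, just to get the Path objects (same trick as above)
    a = ax.scatter([1, 2], [3, 4], marker='s')
    b = ax.scatter([1, 2], [3, 4], marker='^')
    square_mk, = a.get_paths()
    triangle_up_mk, = b.get_paths()
    a.remove()
    b.remove()

    ax = sns.swarmplot(x="day", y="total_bill", hue="sex", data=tips, size=8, ax=ax, dodge=False)
    collections = ax.collections
    # assumes the first category contains points of every hue
    unique_colors = np.unique(collections[0].get_facecolors(), axis=0)
    markers = [triangle_up_mk, square_mk]  # must be at least as long as unique_colors
    for collection in collections:
        paths = []
        for current_color in collection.get_facecolors():
            for possible_marker, possible_color in zip(markers, unique_colors):
                if np.array_equal(current_color, possible_color):
                    paths.append(possible_marker)
                    break
        collection.set_paths(paths)

    # update legend; reuse the labels seaborn attached to the legend
    # artists so they stay in hue order
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles, labels)
    plt.show()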