I absolutely loved @raphael's answer.
Here is a version with circles. Furthermore, I've refactored and trimmed the code a bit to make it more modular.
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
class MulticolorCircles:
    """
    Legend handle that draws a row of circles, one per face color.

    An instance is passed to ``legend`` as a handle. The registered handler
    (``MulticolorHandler``) calls the instance with the handlebox width and
    height to obtain the artist to draw.

    For different shapes, override the ``get_patch`` method, and add the new
    class to the handler map, e.g. via::

        ax_r.legend(ax_r_handles, ax_r_labels,
                    handlelength=CONF.LEGEND_ICON_SIZE,
                    borderpad=1.2, labelspacing=1.2,
                    handler_map={MulticolorCircles: MulticolorHandler})
    """

    def __init__(self, face_colors, edge_colors=None, face_alpha=1,
                 radius_factor=1):
        """
        :param face_colors: Iterable of matplotlib color specs; one circle is
            drawn per entry.
        :param edge_colors: ``None`` (no edges, i.e. ``"none"``), a single
            color string applied to every circle, or an iterable with exactly
            one color per face color.
        :param face_alpha: Alpha in [0, 1] applied to every face color.
        :param radius_factor: Positive multiplier applied to the computed
            circle radius.
        :raises AssertionError: if ``face_alpha`` or ``radius_factor`` is
            out of range.
        :raises ValueError: if ``edge_colors`` is an iterable whose length
            differs from the number of face colors.
        """
        assert 0 <= face_alpha <= 1, f"Invalid face_alpha: {face_alpha}"
        assert radius_factor > 0, "radius_factor must be positive"
        self.rad_factor = radius_factor
        # ``mcolors.to_rgba`` is the function that the legacy
        # ``colorConverter.to_rgba`` wrapper delegates to.
        self.fc = [mcolors.to_rgba(fc, alpha=face_alpha)
                   for fc in face_colors]
        self.ec = self._normalize_edge_colors(edge_colors, len(self.fc))
        self.N = len(self.fc)

    @staticmethod
    def _normalize_edge_colors(edge_colors, n):
        """Return a list of exactly ``n`` edge colors (see ``__init__``)."""
        if edge_colors is None:
            return ["none"] * n
        if isinstance(edge_colors, str):
            # A bare string would otherwise be zipped character by character.
            return [edge_colors] * n
        # Copy to a list so one-shot iterables survive repeated __call__s.
        edge_colors = list(edge_colors)
        if len(edge_colors) != n:
            # zip() would silently drop circles while spacing still uses n.
            raise ValueError(
                f"Expected {n} edge colors, got {len(edge_colors)}")
        return edge_colors

    def get_patch(self, width, height, idx, fc, ec):
        """
        Build the circle for slot ``idx`` of a ``width`` x ``height`` box.

        The width is split into ``self.N`` equal chunks. The circle touches
        the left edge of chunk ``idx`` and rests on y=0. Override this method
        to draw a different shape.

        :param idx: Zero-based slot index.
        :param fc: Face color for this slot.
        :param ec: Edge color for this slot.
        :returns: The patch created by ``plt.Circle``.
        """
        w_chunk = width / self.N
        # NOTE(review): the radius is capped by ``height`` rather than
        # ``height / 2``, so the circle (spanning y in [0, 2 * radius]) can
        # rise above ``height``. The extra legend padding used in the
        # examples may be compensating for this.
        radius = min(w_chunk / 2, height) * self.rad_factor
        xy = (w_chunk * idx + radius, radius)
        return plt.Circle(xy, radius, facecolor=fc, edgecolor=ec)

    def __call__(self, width, height):
        """
        Build all circles for a box of the given size.

        :returns: A ``PatchCollection`` that keeps each circle's own face and
            edge colors (``match_original=True``).
        """
        patches = [self.get_patch(width, height, i, fc, ec)
                   for i, (fc, ec) in enumerate(zip(self.fc, self.ec))]
        return PatchCollection(patches, match_original=True)
class MulticolorHandler:
    """
    Legend handler for callable handles such as ``MulticolorCircles``.

    Register it via ``handler_map`` when calling ``legend``. The handle object
    itself is responsible for producing the artist.
    """

    @staticmethod
    def legend_artist(legend, orig_handle, fontsize, handlebox):
        """
        Create the artist for ``orig_handle`` and attach it to ``handlebox``.

        ``orig_handle`` is called with the handlebox's width and height. The
        resulting artist is added to ``handlebox`` and returned. ``legend``
        and ``fontsize`` belong to the handler signature but are not used.
        """
        artist = orig_handle(handlebox.width, handlebox.height)
        handlebox.add_artist(artist)
        return artist
Sample usage and image. Note that some of the legend handles use radius_factor=0.5,
because the true size would be too small.
ax_handles, ax_labels = ax.get_legend_handles_labels()

# Extra (label, handle) legend entries appended after the axes' own entries.
# The "dot" entries scale their circle radius by LEGEND_DOT_RATIO.
_extra_legend_entries = [
    (AUDIOSET_LABEL,
     MulticolorCircles([AUDIOSET_COLOR], face_alpha=LEGEND_SHADOW_ALPHA)),
    (FRAUNHOFER_LABEL,
     MulticolorCircles([FRAUNHOFER_COLOR], face_alpha=LEGEND_SHADOW_ALPHA)),
    (TRAIN_SOURCE_NORMAL_LABEL,
     MulticolorCircles(SHADOW_COLORS["source"],
                       face_alpha=LEGEND_SHADOW_ALPHA)),
    (TRAIN_TARGET_NORMAL_LABEL,
     MulticolorCircles(SHADOW_COLORS["target"],
                       face_alpha=LEGEND_SHADOW_ALPHA)),
    (TEST_SOURCE_ANOMALY_LABEL,
     MulticolorCircles(DOT_COLORS["anomaly_source"],
                       radius_factor=LEGEND_DOT_RATIO)),
    (TEST_TARGET_ANOMALY_LABEL,
     MulticolorCircles(DOT_COLORS["anomaly_target"],
                       radius_factor=LEGEND_DOT_RATIO)),
]
ax_labels.extend(label for label, _ in _extra_legend_entries)
ax_handles.extend(handle for _, handle in _extra_legend_entries)
del _extra_legend_entries

ax.legend(ax_handles, ax_labels, handlelength=LEGEND_ICON_SIZE,
          borderpad=1.1, labelspacing=1.1,
          handler_map={MulticolorCircles: MulticolorHandler})