Plotting multiple different plots in one figure using Seaborn
Asked Answered
U

3

80

I am attempting to recreate the following plot from the book Introduction to Statistical learning using seaborn enter image description here

I specifically want to recreate this using seaborn's lmplot to create the first two plots and boxplot to create the third. The main problem is that lmplot creates a FacetGrid (according to this answer), which forces me to hackily add another matplotlib Axes for the boxplot. I was wondering if there was an easier way to achieve this. Below, I have to do quite a bit of manual manipulation to get the desired plot.

# lmplot is figure-level: it builds its own figure with one facet per
# `variable` value. x/y are passed by keyword because recent seaborn
# releases no longer accept them positionally.
seaborn_grid = sns.lmplot(x='value', y='wage', col='variable', hue='education', data=df_melt, sharex=False)
seaborn_grid.fig.set_figwidth(8)

# Work out where a third axes must go: same size as the first facet, shifted
# right of the second facet by the facet-to-facet spacing.
left, bottom, width, height = seaborn_grid.fig.axes[0].get_position().bounds
left2 = seaborn_grid.fig.axes[1].get_position().bounds[0]
left_diff = left2 - left
# add_axes returns the new Axes, so there is no need to look it up again
# through fig.axes[2] (which also assumed lmplot created exactly two axes).
ax2 = seaborn_grid.fig.add_axes((left2 + left_diff, bottom, width, height))

sns.boxplot(x='education', y='wage', data=df_wage, ax=ax2)
# The boxplot sits next to the wage facets, so its y tick labels and axis
# labels are blanked out; the long education names are rotated to fit.
ax2.set_yticklabels([])
ax2.set_xticklabels(ax2.get_xmajorticklabels(), rotation=30)
ax2.set_ylabel('')
ax2.set_xlabel('')

# Re-anchor the figure legend that lmplot created for `hue`.
leg = seaborn_grid.fig.legends[0]
leg.set_bbox_to_anchor([0, .1, 1.5, 1])

Which yields enter image description here

Sample data for DataFrames:

# Five example rows. The education and wage columns are identical in both
# tables, so they are written once and copied into each dict.
_education = ['1. < HS Grad', '4. College Grad', '3. Some College', '4. College Grad', '2. HS Grad']
_wage = [75.0431540173515, 70.47601964694451, 130.982177377461, 154.68529299563, 75.0431540173515]

# Long-format table: each row pairs a `value` of the named `variable` with a wage.
df_melt = {
    'education': list(_education),
    'value': [18, 24, 45, 43, 50],
    'variable': ['age'] * len(_education),
    'wage': list(_wage),
}

# Education/wage table passed to the boxplot call.
df_wage = {
    'education': list(_education),
    'wage': list(_wage),
}
Uncommercial answered 28/6, 2016 at 17:24 Comment(0)
H
170

One possibility would be to NOT use lmplot(), but directly use regplot() instead. regplot() plots on the axes you pass as an argument with ax=.

You lose the ability to automatically split your dataset according to a certain variable, but if you know beforehand the plots you want to generate, it shouldn't be a problem.

Something like this:

import matplotlib.pyplot as plt
import seaborn as sns

# Axes-level functions draw on whatever Axes they are given, so an ordinary
# matplotlib grid can hold both the regressions and the boxplot.
# One column per melted variable, plus a last column for the boxplot
# (three columns when `variable` has two distinct values).
variable_groups = df_melt.groupby('variable')
fig, axs = plt.subplots(ncols=variable_groups.ngroups + 1)

# Each regression panel gets only the rows of its own variable; passing the
# whole of df_melt to every regplot would draw the same data in every panel.
for ax, (variable, subset) in zip(axs, variable_groups):
    sns.regplot(x='value', y='wage', data=subset, ax=ax)
    ax.set_title(variable)

# The boxplot uses the un-melted table, as in the question.
sns.boxplot(x='education', y='wage', data=df_wage, ax=axs[-1])
Honolulu answered 29/6, 2016 at 6:14 Comment(0)
A
2

Multiple Seaborn subplots (e.g., sns.regplot, sns.boxplot) in one Matplotlib figure (i.e., fig, axs = plt.subplots)

Building off of the suggestion of using two sns.regplot's instead of sns.lmplot in the accepted answer, here is a fully fleshed-out example closely mirroring the reference figure provided in your question.

Seaborn subplots output

The figure above was produced from the following code:

import matplotlib.pyplot as plt
import seaborn as sns

fig, axs = plt.subplots(ncols=3, sharey=True, figsize=(18, 6), dpi=300)

# Every axis label in the figure shares the same padding and font size.
label_style = {"labelpad": 25, "fontsize": 18}

# Panels 1 & 2: "Wage" against "Age" (order-2 polynomial fit) and against
# "Year" (order-1, i.e. linear, fit).
panel_specs = [("Age", 2), ("Year", 1)]
for ax, (variate, order) in zip(axs, panel_specs):
    # Raw observations as grey, semi-transparent points, without a legend.
    # (Add hue="Education" here to colour the points by education level.)
    sns.stripplot(
        data=df,
        x=variate,
        y="Wage",
        ax=ax,
        native_scale=True,
        color="gray",
        zorder=1,
        alpha=0.5,
        legend=False,
    )
    # Fitted curve only (scatter=False), drawn on the same Axes.
    sns.regplot(
        data=df,
        x=variate,
        y="Wage",
        ax=ax,
        scatter=False,
        truncate=False,
        order=order,
        color="deepskyblue",
    )
    ax.set_xlabel(variate, **label_style)

# Panel 3: one "Wage" box per "Education" level, coloured by level.
# (A palette such as palette="Set2" can be passed here as well.)
box_ax = axs[2]
sns.boxplot(
    data=df,
    x="Education",
    y="Wage",
    hue="Education",
    ax=box_ax,
    legend=True,
)

# Axis labels: the y axis is shared, so only the leftmost panel is labelled.
axs[0].set_ylabel("Wage", **label_style)
box_ax.set_xlabel("Education", **label_style)
for ax in axs:
    ax.tick_params(
        axis="both", which="major", labelsize=12, length=5, width=1.0
    )

# Shorten the boxplot tick labels to the leading digit of each level name
# ("1. < HS Grad" -> "1"); the full names are shown in the legend instead.
short_labels = [label.split()[0][0] for label in sorted(education_levels)]
box_ax.set_xticks(np.arange(len(education_levels)))
box_ax.set_xticklabels(short_labels)
box_ax.legend(
    loc="center left", bbox_to_anchor=(1, 0.5), title="Education Level"
)

# NOTE(review): tight_layout() recomputes the subplot spacing, so the wspace
# set on the previous line probably has no visible effect - confirm before
# relying on it.
plt.subplots_adjust(wspace=5.0)
plt.tight_layout()
plt.show()

using data simulated via:

import pandas as pd
import numpy as np

# Simulated example data: wages that rise with age up to a peak and then
# fall off, plus the ordered list of education levels.
np.random.seed(0)

# Simulation parameters
n_samples = 1000
mean_wage = 120
std_dev_wage = 60
age_min = 16
age_max = 80
peak_age = 50

age_data = np.random.uniform(age_min, age_max, n_samples)

# One wage per age, drawn in age order so the random stream is reproducible.
wage_data = np.zeros_like(age_data)
for idx, age in enumerate(age_data):
    # Fraction of the peak mean/spread that applies at this age: grows
    # linearly up to peak_age, then shrinks linearly towards age 100.
    if age <= peak_age:
        scale = age / peak_age
    else:
        scale = (100 - age) / (100 - peak_age)
    # Age-proportional offset added to the mean.
    trend = (age / 100) * 90
    draw = np.random.normal(mean_wage * scale + trend, std_dev_wage * scale)
    # Absolute value keeps every wage non-negative.
    wage_data[idx] = np.abs(draw)

education_levels = [
    "1. < HS Grad",
    "2. HS Grad",
    "3. Some College",
    "4. College Grad",
    "5. Postgraduate",
]

# Assign education levels vs. age by weighted probabilities
def assign_education(age):
    """Draw one education level at random, weighted by the age bracket.

    Brackets are 60 and over, 45 to under 60, 25 to under 45, and under 25;
    each has its own probability vector over the five levels. Exactly one
    draw is taken from NumPy's global random state.
    """
    levels = [
        "1. < HS Grad",
        "2. HS Grad",
        "3. Some College",
        "4. College Grad",
        "5. Postgraduate",
    ]
    # (lower age bound, probabilities) from oldest bracket to youngest; the
    # first bound the age reaches decides the probabilities.
    brackets = (
        (60, [0.05, 0.35, 0.25, 0.3, 0.05]),
        (45, [0.05, 0.25, 0.25, 0.35, 0.1]),
        (25, [0.1, 0.1, 0.3, 0.3, 0.2]),
    )
    # Fallback for ages that reach none of the bounds (under 25).
    probabilities = [0.2, 0.39, 0.3, 0.1, 0.01]
    for lower_bound, bracket_probabilities in brackets:
        if age >= lower_bound:
            probabilities = bracket_probabilities
            break
    return np.random.choice(levels, p=probabilities)


# One randomly assigned education level per simulated age.
education_data = np.array([assign_education(age) for age in age_data])

# The DataFrame below needs a "Year" column, but `year_data` was never
# defined in this snippet (NameError). The printed sample output shows integer
# years between 2016 and 2023, so years are drawn uniformly from that range.
# NOTE(review): the range is an assumption taken from that output; adjust as
# needed. It is drawn last so the age, wage and education draws are unchanged.
year_data = np.random.randint(2016, 2024, size=len(age_data))

df = pd.DataFrame(
    {
        "Education": education_data,
        "Age": age_data,
        "Year": year_data,
        "Wage": wage_data,
    }
)

# Make "Education" an ordered categorical sorted by level name.
# set_categories (unlike reorder_categories) does not raise when a level
# happens to be absent from the sample, e.g. with a small n_samples.
df["Education"] = (
    df["Education"]
    .astype("category")
    .cat.set_categories(sorted(education_levels), ordered=True)
)

# Show the data, the column dtypes and the categorical "Education" values.
print(f"DataFrame:\n{'-'*50}\n{df}\n")
print(f"DataFrame column datatypes:\n{'-'*50}\n{df.dtypes}\n")
print(
    f"DataFrame 'Education' category order:\n{'-'*50}\n{df.Education.values}"
)
DataFrame:
--------------------------------------------------
           Education        Age  Year        Wage
0    4. College Grad  51.124064  2016  157.349244
1    4. College Grad  61.772119  2022  148.226233
2    4. College Grad  54.576856  2023  258.951815
3    3. Some College  50.872524  2019  151.065454
4         2. HS Grad  43.113907  2022  116.458425
..               ...        ...   ...         ...
995       2. HS Grad  22.251288  2022   73.171386
996       2. HS Grad  48.955021  2022   55.975291
997  4. College Grad  76.058369  2016  102.863747
998  3. Some College  30.633379  2022  108.192692
999  4. College Grad  59.337033  2018  214.298984

[1000 rows x 4 columns]

DataFrame column datatypes:
--------------------------------------------------
Education    category
Age           float64
Year            int64
Wage          float64
dtype: object

DataFrame 'Education' category order:
--------------------------------------------------
['4. College Grad', '5. Postgraduate', '4. College Grad', '3. Some College', '2. HS Grad', ..., '2. HS Grad', '3. Some College', '5. Postgraduate', '3. Some College', '4. College Grad']
Length: 1000
Categories (5, object): ['1. < HS Grad' < '2. HS Grad' < '3. Some College' < '4. College Grad' < '5. Postgraduate']
Adelric answered 21/3 at 1:55 Comment(0)
W
0

As of seaborn 0.13.0 (over 7 years after this question was posted), it's still really difficult to add subplots to seaborn figure-level objects without messing with the underlying figure positions. In fact, the method shown in the OP is probably the most readable way to do it.

With that being said, as suggested by Diziet Asahi, if you want to forego seaborn FacetGrids (e.g. lmplot, catplot etc.) altogether and use seaborn Axes-level methods to create an equivalent figure (e.g. regplot instead of lmplot, scatterplot+lineplot instead of relplot etc.) and add more subplots such as boxplot to the figure, you could group your data by the columns you were going to use as cols kwarg in lmplot (and groupby the sub-dataframe by the columns you were going to use as hue kwarg) and draw the plots using data from the sub-dataframes.

As an example, using the data in the OP, we could do the following, which creates a figure somewhat equivalent to lmplot's but adds a boxplot on the right:

# Split the melted data by 'variable' - the role `col='variable'` played in
# lmplot - so each group can be drawn on its own Axes.
groupby_object = df_melt.groupby('variable')
# One subplot per group ...
number_of_columns = groupby_object.ngroups

# ... plus one extra column at the end for the boxplot.
fig, axs = plt.subplots(nrows=1, ncols=number_of_columns + 1, sharey=True)
# zip stops after the last group, which leaves the final Axes untouched here.
for ax, (_, subset) in zip(axs, groupby_object):
    sns.regplot(x='value', y='wage', data=subset, ax=ax)
# The boxplot goes on the last Axes.
sns.boxplot(x='education', y='wage', hue='education', data=df_wage, ax=axs[-1])

The example in the OP uses hue= kwarg to draw different lines of fit by 'education'. To do that, we could groupby the sub-dataframe by the 'education' column again and plot multiple regplots by education on the same Axes. A working example is as follows:

groupby_object = df_melt.groupby('variable')
number_of_columns = groupby_object.ngroups
fig, axs = plt.subplots(1, number_of_columns+1, figsize=(12, 5), sharey=True)

# Fix one colour per education level up front. Otherwise the regression
# colours follow the sorted groupby order while the boxplot colours follow
# order of first appearance, so the same level could get different colours.
edu_levels = sorted(set(df_melt['education']).union(df_wage['education']))
edu_colors = dict(zip(edu_levels, sns.color_palette(n_colors=len(edu_levels))))

labelled = set()  # levels that already have a legend entry
for ax, (variable, g) in zip(axs, groupby_object):
    for level, g1 in g.groupby('education'):
        # Label each level only the first time it is drawn: no duplicate
        # legend entries, and levels missing from the first panel still appear.
        label = None if level in labelled else level
        labelled.add(level)
        sns.regplot(data=g1, x='value', y='wage', label=label, color=edu_colors[level],
                    scatter_kws={'alpha': 0.7}, ax=ax)
    # Title from the group key ('age' -> 'Age') rather than a hard-coded list.
    ax.set_title(str(variable).capitalize())

# Same level order and colours as the regression panels.
sns.boxplot(data=df_wage, x='education', y='wage', hue='education',
            order=edu_levels, hue_order=edu_levels, palette=edu_colors, ax=axs[-1])
axs[-1].set(ylabel='', xlabel='', title='Education')
axs[-1].tick_params(axis='x', labelrotation=30)
# Single figure-level legend placed to the right of the panels.
_ = fig.legend(bbox_to_anchor=(0.92, 0.5), loc="center left")

Using the following sample dataset (I had to create a new dataset since OP's sample is not rich enough to make a proper graph):

import numpy as np
import pandas as pd
# Reproducible 100-row sample, richer than the five rows in the question.
rng = np.random.default_rng(0)
n_rows = 100

# Same labels as the question's sample; '4. College Grad' is listed twice
# there, so it is drawn twice as often as the other levels.
edu_choices = ['1. < HS Grad', '4. College Grad', '3. Some College', '4. College Grad', '2. HS Grad']

# The draws happen in this order (education, wage, value, variable) so the
# generated numbers stay the same for the fixed seed.
edu = rng.choice(edu_choices, size=n_rows)
wage = rng.normal(75, 25, n_rows)
value = rng.normal(30, 20, n_rows)
variable = rng.choice(['age', 'year'], n_rows)

# Long-format table for the regression panels and a wage-by-education table
# for the boxplot; both share the same education and wage columns.
df_melt = pd.DataFrame({'education': edu, 'value': value, 'variable': variable, 'wage': wage})
df_wage = pd.DataFrame({'education': edu, 'wage': wage})

the above code plots the following figure:

result

Wild answered 17/11, 2023 at 22:9 Comment(0)

© 2022 - 2024 — McMap. All rights reserved.