Multiple Seaborn subplots (e.g., `sns.regplot`, `sns.boxplot`) in one Matplotlib figure (i.e., `fig, axs = plt.subplots()`)
Building on the accepted answer's suggestion of using two `sns.regplot` calls instead of `sns.lmplot`, here is a fully fleshed-out example that closely mirrors the reference figure provided in your question.
The figure above was produced from the following code:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Three side-by-side panels sharing the "Wage" y-axis.
# NOTE: expects the DataFrame ``df`` produced by the data-simulation code
# (columns "Age", "Year", "Wage" and an ordered categorical "Education"),
# so that code must be run before this one.
fig, axs = plt.subplots(ncols=3, sharey=True, figsize=(18, 6), dpi=300)

# Plots 1 & 2: polynomial (order 2) regression of "Wage" on "Age" and
# linear (order 1) regression of "Wage" on "Year", each drawn over the
# raw data points.
for ax, (variate, order) in zip(axs[:2], [("Age", 2), ("Year", 1)]):
    sns.stripplot(
        x=variate,
        y="Wage",
        data=df,
        ax=ax,
        native_scale=True,  # keep numeric x positions so the fit lines up
        color="gray",
        zorder=1,  # draw points beneath the regression line
        alpha=0.5,
        legend=False,
    )
    sns.regplot(
        x=variate,
        y="Wage",
        data=df,
        ax=ax,
        scatter=False,  # the points were already drawn by stripplot
        truncate=False,
        order=order,
        color="deepskyblue",
    )
    ax.set_xlabel(variate, labelpad=25, fontsize=18)

# Plot 3: boxplot of "Wage" for each "Education" level
sns.boxplot(
    x="Education",
    y="Wage",
    data=df,
    hue="Education",
    ax=axs[2],
    legend=True,
)

# Axis labels and tick styling for readability
axs[0].set_ylabel("Wage", labelpad=25, fontsize=18)
axs[2].set_xlabel("Education", labelpad=25, fontsize=18)
for ax in axs:
    ax.tick_params(
        axis="both", which="major", labelsize=12, length=5, width=1.0
    )

# The boxes are positioned in the categorical order of df["Education"],
# so derive the ticks from those categories. Each label is shortened to
# its leading digit (e.g. "1. < HS Grad" -> "1"); the legend carries the
# full names.
education_categories = list(df["Education"].cat.categories)
axs[2].set_xticks(np.arange(len(education_categories)))
axs[2].set_xticklabels([label.split()[0][0] for label in education_categories])
axs[2].legend(
    loc="center left", bbox_to_anchor=(1, 0.5), title="Education Level"
)

# tight_layout() recomputes all subplot spacing (including wspace), so a
# preceding subplots_adjust(wspace=...) would be silently overridden.
plt.tight_layout()
plt.show()
The example data were simulated as follows:
import pandas as pd
import numpy as np
# Simulate example data whose wage trend and education mix depend on age.
# A fixed seed keeps every run (and the printed DataFrame) reproducible.
np.random.seed(0)

# ---- Simulation parameters -------------------------------------------
n_samples = 1000
mean_wage = 120
std_dev_wage = 60
age_min = 16
age_max = 80
peak_age = 50

age_data = np.random.uniform(age_min, age_max, n_samples)

# ---- Wages ------------------------------------------------------------
# Each wage is drawn from a normal distribution whose mean and spread
# scale up linearly until ``peak_age`` and shrink linearly after it, plus
# an age-proportional offset. np.abs keeps the drawn wages non-negative.
wage_data = np.zeros_like(age_data)
for idx, person_age in enumerate(age_data):
    offset = (person_age / 100) * 90
    if person_age <= peak_age:
        factor = person_age / peak_age
    else:
        factor = (100 - person_age) / (100 - peak_age)
    center = mean_wage * factor
    spread = std_dev_wage * factor
    wage_data[idx] = np.abs(np.random.normal(center + offset, spread))

# ---- Education levels (numeric prefixes give their natural order) -----
education_levels = [
    "1. < HS Grad",
    "2. HS Grad",
    "3. Some College",
    "4. College Grad",
    "5. Postgraduate",
]
# Assign education levels vs. age by weighted probabilities
def assign_education(age):
education_levels = [
"1. < HS Grad",
"2. HS Grad",
"3. Some College",
"4. College Grad",
"5. Postgraduate",
]
if age >= 60:
weights = [0.05, 0.35, 0.25, 0.3, 0.05]
elif 45 <= age < 60:
weights = [0.05, 0.25, 0.25, 0.35, 0.1]
elif 25 <= age < 45:
weights = [0.1, 0.1, 0.3, 0.3, 0.2]
else:
weights = [0.2, 0.39, 0.3, 0.1, 0.01]
return np.random.choice(education_levels, p=weights)
# Draw an age-dependent education level for every simulated person.
education_data = np.array([assign_education(age) for age in age_data])

# Survey year for each sample, uniform over 2016-2023 inclusive
# (randint's upper bound is exclusive).
year_data = np.random.randint(2016, 2024, n_samples)

# Build "Education" directly as an ordered categorical with an explicit
# category list. Unlike astype("category") + reorder_categories(), this
# does not fail when some level happens never to be drawn, and it
# guarantees plots/comparisons use the order 1 < 2 < ... < 5.
df = pd.DataFrame(
    {
        "Education": pd.Categorical(
            education_data, categories=sorted(education_levels), ordered=True
        ),
        "Age": age_data,
        "Year": year_data,
        "Wage": wage_data,
    }
)

print(f"DataFrame:\n{'-'*50}\n{df}\n")
print(f"DataFrame column datatypes:\n{'-'*50}\n{df.dtypes}\n")
print(
    f"DataFrame 'Education' category order:\n{'-'*50}\n{df.Education.values}"
)
DataFrame:
--------------------------------------------------
Education Age Year Wage
0 4. College Grad 51.124064 2016 157.349244
1 4. College Grad 61.772119 2022 148.226233
2 4. College Grad 54.576856 2023 258.951815
3 3. Some College 50.872524 2019 151.065454
4 2. HS Grad 43.113907 2022 116.458425
.. ... ... ... ...
995 2. HS Grad 22.251288 2022 73.171386
996 2. HS Grad 48.955021 2022 55.975291
997 4. College Grad 76.058369 2016 102.863747
998 3. Some College 30.633379 2022 108.192692
999 4. College Grad 59.337033 2018 214.298984
[1000 rows x 4 columns]
DataFrame column datatypes:
--------------------------------------------------
Education category
Age float64
Year int64
Wage float64
dtype: object
DataFrame 'Education' category order:
--------------------------------------------------
['4. College Grad', '5. Postgraduate', '4. College Grad', '3. Some College', '2. HS Grad', ..., '2. HS Grad', '3. Some College', '5. Postgraduate', '3. Some College', '4. College Grad']
Length: 1000
Categories (5, object): ['1. < HS Grad' < '2. HS Grad' < '3. Some College' < '4. College Grad' < '5. Postgraduate']