Matplotlib rcParams ignored when plotting
Asked Answered
M

1

7

I am working on an interactive plotting script using matplotlib version 3.2.2 and tkinter.

When the script is run, the first window looks like this: enter image description here Furthermore, the rcParams are updated and plotted once the Plot figure button is clicked:

enter image description here

If I now hit the button Change plot settings and change for instance the markersize parameter -> Plot figure, the plot is updated like so:

enter image description here

if I, however, try changing the label sizes to 20 px and then validate that the rcParams['axes.labelsize'] is changed, they are. But the size of the x and y labels are never updated in the actual plot.

The plot title (text input field all the way to the top in the plot window) fontsize can be changed after it has been plotted.

The minimal code:

"""
This is a script for interactively plotting a scatterplot and changing the plot params.
"""

import numpy as np
import matplotlib as mpl
import matplotlib.style
import random

mpl.use('TkAgg')
import numpy as np
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
from tkinter import *
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from pandas.api.types import is_numeric_dtype


def center_tk_window(window, height, width):
    # Helper method for centering windows

    screen_width = window.winfo_screenwidth()
    screen_height = window.winfo_screenheight()
    x_coordinate = int((screen_width / 2) - (width / 2))
    y_coordinate = int((screen_height / 2) - (height / 2))
    window.geometry("{}x{}+{}+{}".format(width, height, x_coordinate, y_coordinate))


def plot_data(data, chosen_columns, ax=None, initial_box=None, fig=None):
    if fig is None:
        fig = Figure()
    if ax is None:
        # Create a new subplot
        ax = fig.add_subplot(111)

    # Selected x-coordinates
    x_data = data[chosen_columns['x_col']]

    # Selected y-coordinates
    if 'y_col' in chosen_columns:
        y_data = data[chosen_columns['y_col']]

    filled_markers = ('o', 'v', '^', '<', '>', '8', 's', 'p', '*', 'h', 'H', 'D', 'd', 'P', 'X')

    # Category column
    if 'category_col' in chosen_columns:
        category_data = data[chosen_columns['category_col']]

        # Plotting it all
        sns.scatterplot(ax=ax, x=x_data, y=y_data, hue=category_data, style=category_data,
                        markers=filled_markers
                        )
        # Shrink current axis's height by 20% on the bottom
        if initial_box is None:
            initial_box = ax.get_position()

        ax.set_position([initial_box.x0, initial_box.y0 + initial_box.height * 0.2,
                         initial_box.width, initial_box.height * 0.80])
        # Put a legend below current axis
        ax.legend(bbox_to_anchor=(0.5, -0.15), ncol=6)
    else:  # Normal scatterplot without any categorical values
        sns.scatterplot(ax=ax, x=x_data, y=y_data)

    ax.set_ylabel(chosen_columns['y_col'])
    ax.set_xlabel(chosen_columns['x_col'])
    return fig, ax, initial_box


class GenericPlot:
    def __init__(self, data):

        # Parameters window selection
        self.canvas = None
        self.fig = None
        self.ax = None
        self.chosen_columns = None
        self.initial_box = None
        self.updated_rc_params = None
        self.set_plot_params(data)

        # Plot window
        self.save_plot_bool = False
        self.plot_window = Tk()
        self.interactive_plot(data)
        self.plot_window.mainloop()

    def set_plot_params(self, data):

        def plot_with_settings():
            def format_input(input):
                if input == '':
                    # Use the default value
                    return mpl.rcParams[input]
                if ',' in input:
                    return float(input.replace(',', '.'))
                else:
                    return float(input)

            # Figure size
            figure_params = {}
            if figsize_width.get() != '' and figsize_height.get() != '':
                figure_params['figsize'] = (format_input(figsize_width.get()), format_input(figsize_height.get()))

            # label sizes
            axes_params = {}
            if label_size.get() != '':
                axes_params['labelsize'] = format_input(label_size.get())
            if title_size.get() != '':
                axes_params['titlesize'] = format_input(title_size.get())

            legend_params = {}
            if legend_title_fontsize.get() != '': legend_params['title_fontsize'] = format_input(
                legend_title_fontsize.get())
            legend_additional = {'loc': 'upper center',
                                 'fancybox': False,
                                 'shadow': False
                                 }
            legend_params.update(legend_additional)

            if marker_size.get() != '': lines_params = {'markersize': format_input(marker_size.get())}
            legend_params['markerscale'] = format_input(legend_markerscale.get())

            mpl.rc('figure', **figure_params)
            mpl.rc('axes', **axes_params)
            mpl.rc('lines', **lines_params)
            mpl.rc('legend', **legend_params)

            self.updated_rc_params = mpl.rcParams

            # Update canvas if the params were changed after it was drawn:
            if self.ax is not None:
                self.ax.clear()
                mpl.rcParams.update(self.updated_rc_params)
                self.fig, self.ax, _ = plot_data(data, self.chosen_columns, self.ax,
                                                 self.initial_box, self.fig)
                self.canvas.draw()
            custom_params_window.destroy()  # Close the tk window

        # Create a new window
        custom_params_window = Tk()
        center_tk_window(custom_params_window, 300, 400)  # window, height, width

        custom_params_window.title('Set plot parameters')

        # Set up GUI
        custom_params_window.columnconfigure(0, weight=1)
        custom_params_window.columnconfigure(1, weight=1)

        n_rows = 8
        for r in range(n_rows):
            custom_params_window.rowconfigure(r, weight=1)

        row_num = 0

        # Figsize
        Label(custom_params_window, text="Figure width (px)").grid(row=row_num, column=0, sticky="e")
        figsize_width = Entry(custom_params_window)
        placeholder_width = self.updated_rc_params['figure.figsize'][
            0] if self.updated_rc_params is not None else 7.0
        figsize_width.insert(0, placeholder_width)
        figsize_width.grid(row=row_num, column=1)

        row_num += 1
        Label(custom_params_window, text="Figure height (px)").grid(row=row_num, column=0, sticky="e")
        figsize_height = Entry(custom_params_window)
        placeholder_height = self.updated_rc_params['figure.figsize'][
            1] if self.updated_rc_params is not None else 6.0
        figsize_height.insert(0, placeholder_height)
        figsize_height.grid(row=row_num, column=1)

        # User input label size
        row_num += 1
        Label(custom_params_window, text="Label sizes (px)").grid(row=row_num, column=0, sticky="e")
        label_size = Entry(custom_params_window)
        placeholder_label_size = self.updated_rc_params[
            'axes.labelsize'] if self.updated_rc_params is not None else 10.0
        label_size.insert(0, placeholder_label_size)
        label_size.grid(row=row_num, column=1)

        # User input title size
        row_num += 1
        Label(custom_params_window, text="Title font size (px)").grid(row=row_num, column=0, sticky="e")
        title_size = Entry(custom_params_window)
        placeholder_axes_titlesize = self.updated_rc_params[
            'axes.titlesize'] if self.updated_rc_params is not None else 14.0
        title_size.insert(0, placeholder_axes_titlesize)
        title_size.grid(row=row_num, column=1)

        print(" self.updated_rc_params STATUS:", self.updated_rc_params)
        # Marker_size
        row_num += 1
        Label(custom_params_window, text="Marker size (px)").grid(row=row_num, column=0, sticky="e")
        marker_size = Entry(custom_params_window)
        placeholder_legend_markersize = self.updated_rc_params[
            'lines.markersize'] if self.updated_rc_params is not None else 6.0
        marker_size.insert(0, placeholder_legend_markersize)
        marker_size.grid(row=row_num, column=1)

        # Legend markerscale
        row_num += 1
        Label(custom_params_window, text="Legend markerscale\n(Relative size to marker size) ").grid(
            row=row_num, column=0,
            sticky="e")
        legend_markerscale = Entry(custom_params_window)
        placeholder_legend_markerscale = self.updated_rc_params[
            'legend.markerscale'] if self.updated_rc_params is not None else 1.0
        legend_markerscale.insert(0, placeholder_legend_markerscale)
        legend_markerscale.grid(row=row_num, column=1)

        # Legend title size
        row_num += 1
        Label(custom_params_window, text="Legend title font size").grid(row=row_num, column=0, sticky="e")
        legend_title_fontsize = Entry(custom_params_window)
        placeholder_legend_title_size = self.updated_rc_params[
            'legend.title_fontsize'] if self.updated_rc_params is not None else 1.0
        legend_title_fontsize.insert(0, placeholder_legend_title_size)
        legend_title_fontsize.grid(row=row_num, column=1)

        row_num += 1
        Button(custom_params_window, text="Plot figure", command=lambda: plot_with_settings(), height=2,
               width=8).grid(row=row_num, column=0)

        custom_params_window.mainloop()

    def interactive_plot(self, data):
        """
        Input :
            window : tkinter window
            data   : DataFrame object
        """

        def close_plot_window():
            self.plot_window.destroy()

        def set_save_plot_bool():
            self.save_plot_bool = True
            self.plot_window.destroy()

        center_tk_window(self.plot_window, 750, 600)

        # Drop-down variables (3 drop-downs)
        dropdown_choice_x = StringVar(self.plot_window)  # Variable holding the dropdown selection for the x column
        dropdown_choice_y = StringVar(self.plot_window)  # Variable holding the dropdown selection for the y column
        dropdown_choice_category = StringVar(
            self.plot_window)  # Variable holding the dropdown selection for the category column

        # Create set of column names in the dataset
        choices = data.columns.values

        # Find numeric and string columns
        string_columns = []
        numeric_columns = []
        [numeric_columns.append(col) if is_numeric_dtype(data[col]) else string_columns.append(col) for col in
         data.columns]

        if len(numeric_columns) < 1:
            raise Exception("Unable to plot, there are too few numerical columns.")

        if len(numeric_columns) == 1:
            raise Exception(
                "Unable to create scatter plot- need more than two numerical columns in the imported dataset.")

        # GUI setup
        self.plot_window.columnconfigure(0, weight=1)
        self.plot_window.columnconfigure(1, weight=1)

        n_rows = 6
        for r in range(n_rows):
            self.plot_window.rowconfigure(r, weight=1)

        def update_ax_title(title):
            self.ax.set_title(title.get())
            self.canvas.draw()

        title = StringVar()
        title.trace("w", lambda name, index, mode, title=title: update_ax_title(title))

        # Set title
        Label(self.plot_window, text="Set plot title:").grid(row=0, column=0, sticky="e")
        e = Entry(self.plot_window, textvariable=title, width=23)
        e.grid(row=0, column=1, sticky="w")

        # Drop-down 1: x-value selection
        if len(numeric_columns) >= 1:
            x_values_column = numeric_columns[0]  # Select the first numeric column as the default x values to plot
            dropdown_choice_x.set(x_values_column)  # Set the default option in the dropdown with the first column
            Label(self.plot_window, text="Select x column:").grid(row=1, column=0, sticky="e")
            choices_numeric = numeric_columns  # Only show numeric columns in the drop-down for x and y
            dropdown_menu_x = OptionMenu(self.plot_window, dropdown_choice_x, *choices_numeric)
            dropdown_menu_x.grid(row=1, column=1, sticky="w")
            dropdown_menu_x.config(width=16)

            self.chosen_columns = {'x_col': x_values_column}

        # Drop-down 2: y-value selection
        if len(numeric_columns) >= 2:
            y_values_column = numeric_columns[1]  # Select the second alternative in the dropdown list for the y values
            dropdown_choice_y.set(y_values_column)  # Set the default option in the dropdown with the first column
            l2 = Label(self.plot_window, text="Select y column:")
            l2.grid(row=2, column=0, sticky='e')
            dropdown_menu_y = OptionMenu(self.plot_window, dropdown_choice_y, *choices_numeric)
            dropdown_menu_y.config(width=16)
            dropdown_menu_y.grid(row=2, column=1, sticky='w')

            self.chosen_columns = {'x_col': x_values_column,
                                   'y_col': y_values_column}

        if len(data.columns) > 2:  # There exist a third columns as well -> include drop-down for category selection
            # Drop-down 3: Category selections
            category_column = string_columns[0] if (len(string_columns) > 0) else numeric_columns[2]
            dropdown_choice_category.set(
                category_column)  # Set the default option in the dropdown with the first column
            l3 = Label(self.plot_window, text="Select category column:")
            l3.grid(row=3, column=0, sticky='e')
            dropdown_menu_category = OptionMenu(self.plot_window, dropdown_choice_category, *choices, 'Set title above')
            dropdown_menu_category.config(width=16)
            dropdown_menu_category.grid(row=3, column=1, sticky='w')

            self.chosen_columns = {'x_col': x_values_column,
                                   'y_col': y_values_column,
                                   'category_col': category_column}

        # Plot the initially selected columns
        self.fig, self.ax, self.initial_box = plot_data(data, self.chosen_columns)
        self.canvas = FigureCanvasTkAgg(self.fig, master=self.plot_window)
        self.canvas.get_tk_widget().grid(row=4, columnspan=2, rowspan=True)
        self.canvas.draw()

        def change_dropdown_x(*args):
            # This function is triggered once a dropdown selection is made

            selected_x_col = dropdown_choice_x.get()
            self.chosen_columns['x_col'] = selected_x_col
            # Create a new plot now
            self.ax.clear()  # Clearing the previous plot
            self.fig, self.ax, _ = plot_data(data, self.chosen_columns, self.ax, self.initial_box,
                                             self.fig)
            self.canvas.draw()

        # chosen columns might not be updated...
        def change_dropdown_y(*args):
            # This function is triggered once a dropdown selection is made
            selected_y_col = dropdown_choice_y.get()
            self.chosen_columns['y_col'] = selected_y_col
            # Create a new plot now
            self.ax.clear()  # Clearing the previous plot
            self.fig, self.ax, _ = plot_data(data, self.chosen_columns, self.ax, self.initial_box,
                                             self.fig)
            self.canvas.draw()

        def change_dropdown_category(*args):
            # This function is triggered once a dropdown selection is made
            selected_category = dropdown_choice_category.get()
            self.chosen_columns['category_col'] = selected_category
            # Create a new plot now
            self.ax.clear()  # Clearing the previous plot
            self.fig, self.ax, _ = plot_data(data, self.chosen_columns, self.ax, self.initial_box,
                                             self.fig)
            self.canvas.draw()

        # Link functions to change dropdown
        dropdown_choice_x.trace('w',
                                lambda *args: change_dropdown_x(
                                    *args))
        dropdown_choice_y.trace('w',
                                lambda *args: change_dropdown_y(
                                    *args))
        dropdown_choice_category.trace('w', lambda *args: change_dropdown_category(
            *args))

        def change_settings():
            self.plot_params_type = 'customize'
            self.set_plot_params(data)
            # self.plot_window.destroy()

        # Save and close buttons
        Button(self.plot_window, text="<- Change plot settings", command=change_settings, height=2, width=20).grid(
            row=5, columnspan=2)
        Button(self.plot_window, text="CLOSE", command=close_plot_window, height=2, width=8).grid(row=6, column=0)
        Button(self.plot_window, text="SAVE PLOT ->", command=set_save_plot_bool, height=2, width=8).grid(row=6,
                                                                                                          column=1)


# Create dummy data to plot
df = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list('ABCD'))

# Add a category column to the DataFrame
labels = ['q', 'r', 's', 't']
df['Category'] = [labels[random.randint(0, len(labels) - 1)] for i in range(100)]

GenericPlot(df)



I have tried to change the x and y label sizes in the function called update_ax_title (just to debug:

And run the function

def update_ax_title(title):
    self.ax.set_title(title.get())  # Correct size (5.0)
    self.ax.set_xlabel('gdsgsdgsdgsdgdsg')  # Incorrect size...
    print(mpl.rcParams['axes.labelsize'])  # prints 5.0
    print(mpl.rcParams['axes.titlesize'])  # prints 5.0
    self.canvas.draw()

Only the title size is updated, even though the rcParams are updated globally. The x and y label sizes change after running specifically self.ax.set_xlabel('gdsgsdgsdgsdgdsg',fontsize=5)

How can this issue be solved? Thanks!

Marchesa answered 17/7, 2020 at 11:25 Comment(2)
I cannot reproduce. Can you create a minimal reproducible example and include the matplotlib version you are using?Gunflint
It seems to be a problem with seaborn. You have to use its functions and search for the seaborn way of doing those things.Assizes
P
3

Although I don't fully understand why it doesn't work, you can solve it by applying the labelsize manually to each axis using set_size:

 ...
            # Update canvas if the params were changed after it was drawn:
            if self.ax is not None:
                self.ax.clear()
                mpl.rcParams.update(self.updated_rc_params)

                # NEW!
                self.ax.xaxis.label.set_size(self.updated_rc_params['axes.labelsize'])
                self.ax.yaxis.label.set_size(self.updated_rc_params['axes.labelsize'])
                # End of New

                self.fig, self.ax, _ = plot_data(data, self.chosen_columns, self.ax,
                                                 self.initial_box, self.fig)
                self.canvas.draw()
            custom_params_window.destroy()  # Close the tk window
 ...
Priapitis answered 22/7, 2020 at 10:47 Comment(0)

© 2022 - 2024 — McMap. All rights reserved.