How to cache a plot in streamlit?

I have built a dashboard in Streamlit where you can select a client_ID and have SHAP plots displayed (Waterfall and Force plots) to interpret the prediction of credit default for this client.

I also want to display a SHAP summary plot for the whole training dataset. The latter does not change every time you make a new prediction and takes a long time to plot, so I want to cache it. I guess the best approach would be to use st.cache, but I have not been able to make it work.

Below is the code I have unsuccessfully tried in main.py: I first define the function whose output (fig) I want to cache, then pass that output to st.pyplot. It works without the st.cache decorator, but as soon as I add it and rerun the app, the function summary_plot_all runs indefinitely.

IN:

@st.cache
def summary_plot_all():
    fig, axes = plt.subplots(nrows=1, ncols=1)
    shap.summary_plot(shapvs[1], prep_train.iloc[:, :-1].values,
                      prep_train.columns, max_display=50)
    return fig

st.pyplot(summary_plot_all())

OUT (displayed in the Streamlit app):

Running summary_plot_all().

Does anyone know what's wrong, or a better way to cache a plot in Streamlit?

Package versions:
streamlit==0.84.1
matplotlib==3.4.2
shap==0.39.0
Maymaya asked 27/8, 2021 at 9:2

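Try telling st.cache not to hash the matplotlib figure it returns, by passing a hash_funcs entry that maps the Figure type to a function returning None: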

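import matplotlib.figure
import matplotlib.pyplot as plt
import shap
import streamlit as st

# shapvs and prep_train are defined earlier, as in the question.
# Hash every Figure to a constant so st.cache skips hashing the returned figure.
@st.cache(hash_funcs={matplotlib.figure.Figure: lambda _: None})
def summary_plot_all():
    fig, axes = plt.subplots(nrows=1, ncols=1)
    shap.summary_plot(shapvs[1], prep_train.iloc[:, :-1].values,
                      prep_train.columns, max_display=50)
    return fig

st.pyplot(summary_plot_all())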

Check this Streamlit GitHub issue.

Wasteful answered 15/10, 2021 at 8:45

You can also serialize your plot to PNG bytes so that st.cache_data can cache it. Here's a demo:

import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from io import BytesIO


@st.cache_data
def generate_plot():
    print("generating plot...")  # printed to the terminal, not to the app

    x = np.linspace(0, 10, 100)
    y = np.sin(x)
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(x, y)
    ax.set_title('Sine Wave')
    ax.set_xlabel('X-axis')
    ax.set_ylabel('Y-axis')
    ax.grid()

    # Serialize to PNG bytes, which st.cache_data can pickle and store cheaply
    buf = BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)  # free the figure; only the bytes are cached
    return buf.getvalue()


st.title('Slow Plot Generator')
# "generating plot..." appears only once in the terminal, because the second
# call (and every later rerun) is served from the cache
st.image(generate_plot(), caption='Sine Wave Plot', use_column_width=True)
st.image(generate_plot(), caption='Sine Wave Plot', use_column_width=True)
Citizen answered 31/10, 2024 at 14:56
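Applying the PNG approach to the SHAP summary plot from the question mainly requires turning an already drawn figure into bytes and then releasing it. The following is a minimal sketch under stated assumptions, not code from any of the answers: the helper name `figure_to_png` is hypothetical, the bar chart of random data only stands in for the slow SHAP plot, and the script forces matplotlib's non-interactive Agg backend so it also runs outside Streamlit.

```python
from io import BytesIO

import matplotlib

matplotlib.use("Agg")  # headless backend, so no GUI window is needed
import matplotlib.pyplot as plt
import numpy as np


def figure_to_png(fig, dpi=100):
    """Hypothetical helper: serialize `fig` to PNG bytes, then close it."""
    buf = BytesIO()
    fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
    # Closing drops pyplot's reference to the figure; only the bytes survive,
    # which is what a cache such as st.cache_data would store.
    plt.close(fig)
    return buf.getvalue()


if __name__ == "__main__":
    # Stand-in for the slow summary plot: mean absolute value per feature
    # of random data (illustrative only, not real SHAP output).
    rng = np.random.default_rng(0)
    values = rng.normal(size=(200, 3))
    fig, ax = plt.subplots(figsize=(6, 3))
    ax.barh(["feature_a", "feature_b", "feature_c"], np.abs(values).mean(axis=0))
    ax.set_xlabel("mean |value|")
    png = figure_to_png(fig)
    print(len(png) > 0, png[:4])
```

Inside the app, an `@st.cache_data` function could call `shap.summary_plot(..., show=False)` so that SHAP leaves the figure open instead of trying to show it, return `figure_to_png(plt.gcf())`, and the cached bytes could then be displayed with `st.image`.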