I am currently developing my first complete system with PySpark and I am running into some strange memory-related issues. In one of the stages, I want to follow a split-apply-combine strategy to modify a DataFrame. That is, I want to apply a function to each of the groups defined by a given column and then combine the results. The problem is that the function I want to apply is the prediction method of a fitted model that "speaks" the Pandas idiom, i.e., it is vectorized and takes a Pandas Series as input.
I have therefore designed an iterative strategy that traverses the groups and manually applies a scalar pandas_udf (PandasUDFType.SCALAR) to each of them. The combination step is done with incremental calls to DataFrame.unionByName(). I decided not to use the GROUPED_MAP type of pandas_udf because the docs state that memory must be managed by the user, and that special care is needed whenever one of the groups might be too large to fit in memory or to be represented as a Pandas DataFrame.
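The incremental combination step can be written as a fold over the per-group DataFrames. This is only a minimal sketch: union_all is a hypothetical helper name, and it assumes every DataFrame in the list has the same column names. It relies on nothing beyond DataFrame.unionByName().

```python
from functools import reduce


def union_all(frames):
    """Combine DataFrames pairwise with DataFrame.unionByName().

    `union_all` is a hypothetical helper name; `frames` is any non-empty
    iterable of DataFrames that share the same column names.
    """
    frames = list(frames)
    if not frames:
        raise ValueError("union_all() needs at least one DataFrame")
    # Equivalent to frames[0].unionByName(frames[1]).unionByName(frames[2])...
    return reduce(lambda left, right: left.unionByName(right), frames)
```

The fold builds the same chain of unions as an explicit loop, so it is not expected to change the memory behaviour by itself; it only removes the special handling of the first group.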
The main problem is that all the processing seems to run fine, but at the end I want to write the final DataFrame to a Parquet file, and it is at this point that I receive a lot of Java-side errors about DataFrameWriter, or out-of-memory exceptions.
I have tried the code on both Windows and Linux machines. The only way I have managed to avoid the errors is to increase the --driver-memory value. The minimum value that works differs between platforms and depends on the size of the problem, which makes me suspect a memory leak.
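When a script is started with plain Python rather than spark-submit, one way to pass --driver-memory is the PYSPARK_SUBMIT_ARGS environment variable, which PySpark reads when it launches the JVM. The following is a minimal sketch: driver_memory_args is a hypothetical helper name and "4g" is only an example value, not a recommendation.

```python
import os


def driver_memory_args(driver_memory):
    """Build a PYSPARK_SUBMIT_ARGS value that sets --driver-memory.

    `driver_memory_args` is a hypothetical helper. The value conventionally
    ends with the "pyspark-shell" token expected by PySpark's launcher.
    """
    return f"--driver-memory {driver_memory} pyspark-shell"


# This has to run before the first SparkSession / SparkContext is created,
# because the driver JVM heap size is fixed when the JVM starts.
os.environ["PYSPARK_SUBMIT_ARGS"] = driver_memory_args("4g")  # example value
```

Setting the variable after the session already exists has no effect on the running JVM, which is why the assignment is placed at the very top of the script in this sketch.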
The problem did not occur until I started using pandas_udf. I suspect there is a memory leak somewhere in the pyarrow serialization that takes place under the hood when a pandas_udf is used.
I have created a minimal reproducible example. If I run this script directly with Python, it produces the error. With spark-submit and a much larger driver memory, it is possible to make it work.
import pyspark
import pyspark.sql.functions as F
import pyspark.sql.types as spktyp

# Dummy pandas_udf -------------------------------------------------------------
@F.pandas_udf(spktyp.DoubleType())
def predict(x):
    return x + 100.0

# Initialization ---------------------------------------------------------------
spark = pyspark.sql.SparkSession.builder.appName(
    "mre").master("local[3]").getOrCreate()
sc = spark.sparkContext

# Generate a dataframe ---------------------------------------------------------
out_path = "out.parquet"
z = 105
m = 750000
schema = spktyp.StructType(
    [spktyp.StructField("ID", spktyp.DoubleType(), True)]
)
df = spark.createDataFrame(
    [(float(i),) for i in range(m)],
    schema
)
for j in range(z):
    df = df.withColumn(
        f"N{j}",
        F.col("ID") + float(j)
    )
df = df.withColumn(
    "X",
    F.array(
        F.lit("A"),
        F.lit("B"),
        F.lit("C"),
        F.lit("D"),
        F.lit("E")
    ).getItem(
        (F.rand()*3).cast("int")
    )
)

# Set the column names for grouping, input and output --------------------------
group_col = "X"
in_col = "N0"
out_col = "EP"

# Extract different group ids in grouping variable -----------------------------
rows = df.select(group_col).distinct().collect()
groups = [row[group_col] for row in rows]
print(f"Groups: {groups}")

# Split and treat the first id -------------------------------------------------
first, *others = groups
cur_df = df.filter(F.col(group_col) == first)
result = cur_df.withColumn(
    out_col,
    predict(in_col)
)

# Traverse the remaining group ids ---------------------------------------------
for i, other in enumerate(others):
    cur_df = df.filter(F.col(group_col) == other)
    new_df = cur_df.withColumn(
        out_col,
        predict(in_col)
    )
    # Incremental union --------------------------------------------------------
    result = result.unionByName(new_df)

# Save to disk -----------------------------------------------------------------
result.write.mode("overwrite").parquet(out_path)
Shockingly (at least for me), the problem seems to vanish if I add a call to repartition() just before the write statement:
result = result.repartition(result.rdd.getNumPartitions())
result.write.mode("overwrite").parquet(out_path)
With this line in place, I can lower the driver memory setting considerably, and the script runs fine. I can barely understand how all these factors are related, although I suspect that lazy evaluation of the code and pyarrow serialization might be involved.
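To see what the extra repartition() actually changes, it may help to compare the partition count and the physical plan of the DataFrame before and after the call. The following is a minimal sketch: describe is a hypothetical helper name, and it only uses DataFrame.rdd.getNumPartitions() and DataFrame.explain(), neither of which writes any output data.

```python
def describe(df, label):
    """Print the partition count and physical plan of a DataFrame.

    `describe` is a hypothetical helper; it returns the partition count so
    the caller can compare values before and after a transformation.
    """
    num_partitions = df.rdd.getNumPartitions()
    print(f"{label}: {num_partitions} partitions")
    df.explain()  # prints the physical plan to stdout
    return num_partitions
```

It could be called as describe(result, "before") and describe(result.repartition(n), "after"), where n is whatever partition count is being tested. Whether the difference in the two plans explains the memory behaviour is an open question here, not something the sketch demonstrates.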
This is the current environment I am using for development:
arrow-cpp 0.13.0 py36hee3af98_1 conda-forge
asn1crypto 0.24.0 py36_1003 conda-forge
astroid 2.2.5 py36_0
atomicwrites 1.3.0 py_0 conda-forge
attrs 19.1.0 py_0 conda-forge
blas 1.0 mkl
boost-cpp 1.68.0 h6a4c333_1000 conda-forge
brotli 1.0.7 he025d50_1000 conda-forge
ca-certificates 2019.3.9 hecc5488_0 conda-forge
certifi 2019.3.9 py36_0 conda-forge
cffi 1.12.3 py36hb32ad35_0 conda-forge
chardet 3.0.4 py36_1003 conda-forge
colorama 0.4.1 py36_0
cryptography 2.6.1 py36hb32ad35_0 conda-forge
dill 0.2.9 py36_0
docopt 0.6.2 py36_0
entrypoints 0.3 py36_0
falcon 1.4.1.post1 py36hfa6e2cd_1000 conda-forge
fastavro 0.21.21 py36hfa6e2cd_0 conda-forge
flake8 3.7.7 py36_0
future 0.17.1 py36_1000 conda-forge
gflags 2.2.2 ha925a31_0
glog 0.3.5 h6538335_1
hug 2.5.2 py36hfa6e2cd_0 conda-forge
icc_rt 2019.0.0 h0cc432a_1
idna 2.8 py36_1000 conda-forge
intel-openmp 2019.3 203
isort 4.3.17 py36_0
lazy-object-proxy 1.3.1 py36hfa6e2cd_2
libboost 1.67.0 hd9e427e_4
libprotobuf 3.7.1 h1a1b453_0 conda-forge
lz4-c 1.8.1.2 h2fa13f4_0
mccabe 0.6.1 py36_1
mkl 2018.0.3 1
mkl_fft 1.0.6 py36hdbbee80_0
mkl_random 1.0.1 py36h77b88f5_1
more-itertools 4.3.0 py36_1000 conda-forge
ninabrlong 0.1.0 dev_0 <develop>
nose 1.3.7 py36_1002 conda-forge
nose-exclude 0.5.0 py_0 conda-forge
numpy 1.15.0 py36h9fa60d3_0
numpy-base 1.15.0 py36h4a99626_0
openssl 1.1.1b hfa6e2cd_2 conda-forge
pandas 0.23.3 py36h830ac7b_0
parquet-cpp 1.5.1 2 conda-forge
pip 19.0.3 py36_0
pluggy 0.11.0 py_0 conda-forge
progressbar2 3.38.0 py_1 conda-forge
py 1.8.0 py_0 conda-forge
py4j 0.10.7 py36_0
pyarrow 0.13.0 py36h8c67754_0 conda-forge
pycodestyle 2.5.0 py36_0
pycparser 2.19 py36_1 conda-forge
pyflakes 2.1.1 py36_0
pygam 0.8.0 py_0 conda-forge
pylint 2.3.1 py36_0
pyopenssl 19.0.0 py36_0 conda-forge
pyreadline 2.1 py36_1
pysocks 1.6.8 py36_1002 conda-forge
pyspark 2.4.1 py_0
pytest 4.5.0 py36_0 conda-forge
pytest-runner 4.4 py_0 conda-forge
python 3.6.6 hea74fb7_0
python-dateutil 2.8.0 py36_0
python-hdfs 2.3.1 py_0 conda-forge
python-mimeparse 1.6.0 py_1 conda-forge
python-utils 2.3.0 py_1 conda-forge
pytz 2019.1 py_0
re2 2019.04.01 vc14h6538335_0 [vc14] conda-forge
requests 2.21.0 py36_1000 conda-forge
requests-kerberos 0.12.0 py36_0
scikit-learn 0.20.1 py36hb854c30_0
scipy 1.1.0 py36hc28095f_0
setuptools 41.0.0 py36_0
six 1.12.0 py36_0
snappy 1.1.7 h777316e_3
sqlite 3.28.0 he774522_0
thrift-cpp 0.12.0 h59828bf_1002 conda-forge
typed-ast 1.3.1 py36he774522_0
urllib3 1.24.2 py36_0 conda-forge
vc 14.1 h0510ff6_4
vs2015_runtime 14.15.26706 h3a45250_0
wcwidth 0.1.7 py_1 conda-forge
wheel 0.33.1 py36_0
win_inet_pton 1.1.0 py36_0 conda-forge
wincertstore 0.2 py36h7fe50ca_0
winkerberos 0.7.0 py36_1
wrapt 1.11.1 py36he774522_0
xz 5.2.4 h2fa13f4_4
zlib 1.2.11 h62dcd97_3
zstd 1.3.3 hfe6a214_0
Any hint or help would be much appreciated.