I believe the top answer shows how to execute a command or stored procedure, but not how to get the results back when the stored procedure returns a table. This was the first Google result for me, so I hope this helps someone.
Solution
The reason executing the stored procedure fails in the first place is that Spark wraps the query in parentheses and assigns it an alias (roughly `select * from (query) alias`), and T-SQL doesn't allow an `exec` inside a derived table like that. More details on that here: https://mcmap.net/q/1325885/-how-do-i-set-quot-for-fetch-only-quot-when-querying-ibm-db2-using-the-jdbc-driver-from-spark
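To make the wrapping concrete, here is a hypothetical helper that builds the same shape of query. `wrap_like_spark` and the `SUBQ_0` alias are made up for illustration (Spark generates its own alias name), as are the table and procedure names. A plain `SELECT` is still a valid derived table after wrapping; the `EXEC` is not.

```python
def wrap_like_spark(query: str, alias: str = "SUBQ_0") -> str:
    """Hypothetical imitation of how a JDBC source embeds a user query as a derived table."""
    return f"SELECT * FROM ({query.strip()}) {alias}"


# A plain SELECT survives the wrapping...
print(wrap_like_spark("SELECT id, name FROM dbo.customers"))
# ...but a stored procedure call does not: EXEC is not allowed inside FROM (...)
print(wrap_like_spark("EXEC dbo.my_proc @year = 2023"))
```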
I have two versions of the code, one in PySpark and one in Scala. Both do the same thing using the JDBC driver; I prefer the Scala one. As written, both require the entire result set to fit in driver memory, but you can work around that by processing the rows in batches.
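One way to batch is to pull a fixed number of rows at a time from the row iterator and handle each chunk (for example, turn it into a DataFrame and append it to a table) before fetching the next, so only one batch sits in driver memory at once. A minimal sketch, assuming `rows` is any iterator of row lists, such as a generator that wraps the `resultset.next()` loop; `fake_rows` is only a stand-in.

```python
from itertools import islice
from typing import Iterable, Iterator, List


def batched(rows: Iterable[list], size: int) -> Iterator[List[list]]:
    """Yield consecutive lists of at most `size` rows from `rows`."""
    if size < 1:
        raise ValueError("size must be at least 1")
    it = iter(rows)
    while True:
        batch = list(islice(it, size))
        if not batch:
            return
        yield batch


# Stand-in for rows streamed from a result set
fake_rows = ([i, f"name{i}"] for i in range(5))
batches = list(batched(fake_rows, 2))
```

Because `batched` is a generator, nothing beyond the current batch is pulled from the source until you ask for the next one.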
A lot of the documentation on how to use the driver is here: https://learn.microsoft.com/en-us/sql/connect/jdbc/microsoft-jdbc-driver-for-sql-server?view=sql-server-ver16
Both sets of code:
- create a connection object to the server
- create a SQL statement and execute it
- fetch the result set and its metadata
- extract the column information (names and count; the Scala version also gets the types) from the metadata
- loop through every row of the result set and every field within it, collecting the field values into a row list and appending each row list to the rows list
- convert the resulting list of lists into a DataFrame
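The same steps can be sketched with Python's built-in `sqlite3` module, which stands in for the SQL Server JDBC driver here so the example runs anywhere. SQLite has no stored procedures, so a plain `SELECT` on a made-up `sales` table plays the role of the procedure's result set, and DB-API's `cursor.description` plays the role of the JDBC metadata.

```python
import sqlite3

con = sqlite3.connect(":memory:")
try:
    # hypothetical sample data standing in for the procedure's output
    con.execute("CREATE TABLE sales (region TEXT, amount REAL)")
    con.executemany("INSERT INTO sales VALUES (?, ?)", [("north", 10.5), ("south", 7.0)])

    # make a query statement and execute it
    cursor = con.execute("SELECT region, amount FROM sales ORDER BY region")

    # metadata: DB-API exposes the column names through cursor.description
    columns = [d[0] for d in cursor.description]

    # loop through every row and every field, building a list of lists
    rows = []
    for record in cursor:
        row = []
        for i in range(len(columns)):
            row.append(record[i])
        rows.append(row)
finally:
    con.close()

print(columns)  # ['region', 'amount']
print(rows)     # [['north', 10.5], ['south', 7.0]]
```

The last step would then be `pd.DataFrame(rows, columns=columns)` or `spark.createDataFrame(...)`, as in the code below.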
Python
There's an issue with dates: they stay as py4j `JavaObject`s, although pandas handles them fine. If you call `spark.createDataFrame(pd.DataFrame(...))` directly, Spark complains about the type, but going through an intermediate CSV works.
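An alternative to the temporary CSV, sketched under the assumption that the problem values are py4j `JavaObject`s: convert anything that isn't a plain Python scalar to its string form before building the pandas DataFrame. py4j's `JavaObject` delegates `str()` to the Java `toString()`, and `java.sql.Date.toString()` gives `yyyy-mm-dd`, which `datetime.date.fromisoformat` can parse (or you can cast the column to a date in Spark afterwards). `FakeJavaDate` is only a stand-in so the sketch runs without a JVM.

```python
import datetime
import decimal

# types that pandas/Spark already understand; py4j converts BigDecimal to decimal.Decimal
PLAIN_TYPES = (type(None), bool, int, float, str, decimal.Decimal)


def to_python(value):
    """Pass plain scalars through; stringify anything else (e.g. a py4j JavaObject)."""
    if isinstance(value, PLAIN_TYPES):
        return value
    return str(value)


class FakeJavaDate:
    """Stand-in for a py4j JavaObject wrapping java.sql.Date."""

    def __init__(self, text):
        self.text = text

    def __str__(self):
        return self.text


row = [to_python(v) for v in [1, "abc", None, FakeJavaDate("2023-05-17")]]
parsed = datetime.date.fromisoformat(row[3])
```

In the real loop you would call `to_python(resultset.getObject(i + 1))` instead of appending the raw value.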
I use `getObject` because it returns whatever type the column holds. If you know the column types up front, typed methods such as `getString` and `getFloat` would be better.
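If you do want typed getters, one option is to choose one per column from `metadata.getColumnType(i)`, which returns a `java.sql.Types` constant. The catch is that primitive getters such as `getInt` return `0` (and `getBoolean` returns `false`) for SQL NULL, so you have to check `wasNull()` after each read. A minimal sketch: the mapping covers only a few common types and falls back to `getObject`, and `FakeResultSet` is a stand-in so it runs without a JVM. Note that this adds a `wasNull()` round trip per value rather than removing any.

```python
# java.sql.Types constants mapped to a matching ResultSet getter (not exhaustive)
JDBC_GETTERS = {
    -7: "getBoolean",    # BIT
    -5: "getLong",       # BIGINT
    4: "getInt",         # INTEGER
    5: "getShort",       # SMALLINT
    7: "getFloat",       # REAL
    8: "getDouble",      # DOUBLE
    2: "getBigDecimal",  # NUMERIC
    3: "getBigDecimal",  # DECIMAL
    12: "getString",     # VARCHAR
    -9: "getString",     # NVARCHAR
    91: "getDate",       # DATE
    93: "getTimestamp",  # TIMESTAMP
}


def read_row(resultset, column_types):
    """Read the current row; column_types[i] is metadata.getColumnType(i + 1)."""
    row = []
    for index, jdbc_type in enumerate(column_types, start=1):
        getter = getattr(resultset, JDBC_GETTERS.get(jdbc_type, "getObject"))
        value = getter(index)
        # primitive getters return 0/false for SQL NULL, so ask wasNull() afterwards
        row.append(None if resultset.wasNull() else value)
    return row


class FakeResultSet:
    """Stand-in for a py4j ResultSet positioned on a single row."""

    def __init__(self, values):
        self.values = values
        self.last_null = False

    def _get(self, index, default):
        raw = self.values[index - 1]
        self.last_null = raw is None
        return default if raw is None else raw

    def getInt(self, index):
        return self._get(index, 0)

    def getString(self, index):
        return self._get(index, None)

    def wasNull(self):
        return self.last_null


demo = read_row(FakeResultSet([None, "x"]), [4, 12])
```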
This is the worst possible scenario for py4j, though: every single value in the table is serialized one by one and sent from the JVM over to Python.
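A rough count of the calls that cross the py4j bridge in the fetch loop shows why. Each `next()` and each `getObject()` is a separate call from Python into the JVM; the sketch below ignores the handful of connection and metadata calls, and the 100,000 x 20 table size is just an example.

```python
def py4j_round_trips(n_rows: int, n_cols: int) -> int:
    """Count the per-row JVM calls made by the fetch loop: next() and getObject()."""
    next_calls = n_rows + 1  # one next() per row, plus the final next() that returns False
    get_object_calls = n_rows * n_cols
    return next_calls + get_object_calls


# e.g. a 100,000-row, 20-column result
calls = py4j_round_trips(100_000, 20)
```

That is about 2.1 million calls, which the Scala version avoids because its loop never leaves the JVM.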
```python
%python
import pandas as pd

# parameters
username = ""
password = ""
host = ""
port = ""
database = ""
query = """
exec ...
"""

# construct the JDBC URL
sqlsUrl = f"jdbc:sqlserver://{host}:{port};databaseName={database}"

# get the py4j gateway into the JVM
gateway = sc._gateway
jvm = gateway.jvm

# connect to the server
con = jvm.java.sql.DriverManager.getConnection(sqlsUrl, username, password)
try:
    # create a statement, execute it, and get the result set and its metadata
    statement = con.prepareCall(query)
    statement.execute()
    resultset = statement.getResultSet()
    metadata = resultset.getMetaData()

    # extract column names from the metadata (JDBC columns are 1-indexed)
    columns = [metadata.getColumnLabel(i + 1) for i in range(metadata.getColumnCount())]

    # loop through the result set and build a list of lists
    rows = []
    while resultset.next():
        row = []
        for i in range(len(columns)):
            row.append(resultset.getObject(i + 1))
        rows.append(row)
finally:
    # close the connection even if something above fails
    con.close()

# make a pandas DataFrame and round-trip it through a CSV (works around the date issue)
pd.DataFrame(rows, columns=columns).to_csv("tmp.csv", index=False)
df = spark.createDataFrame(pd.read_csv("tmp.csv"))

# voila
display(df)
```
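One thing that can bite with `statement.execute()` followed by `statement.getResultSet()`: unless the procedure runs `SET NOCOUNT ON`, SQL Server can report a "rows affected" count for each INSERT/UPDATE/DELETE inside it, and JDBC surfaces those counts as results ahead of the table. In that case `execute()` returns `False` and `getResultSet()` returns `None`. A sketch of the standard JDBC loop that skips update counts using `getMoreResults()` and `getUpdateCount()` (which returns -1 once there are no more results); `first_result_set` is a hypothetical helper and `FakeStatement` is a stand-in so it runs without a JVM.

```python
def first_result_set(statement):
    """Advance past update counts to the first ResultSet; return None if there is none."""
    has_result_set = statement.execute()
    while True:
        if has_result_set:
            return statement.getResultSet()
        if statement.getUpdateCount() == -1:
            # neither a result set nor an update count: no more results
            return None
        has_result_set = statement.getMoreResults()


class FakeStatement:
    """Stand-in for a py4j CallableStatement; items are ("count", n) or ("rows", payload)."""

    def __init__(self, results):
        self.results = list(results)
        self.position = 0

    def _current(self):
        return self.results[self.position] if self.position < len(self.results) else None

    def execute(self):
        current = self._current()
        return current is not None and current[0] == "rows"

    def getMoreResults(self):
        self.position += 1
        return self.execute()

    def getUpdateCount(self):
        current = self._current()
        return current[1] if current is not None and current[0] == "count" else -1

    def getResultSet(self):
        current = self._current()
        return current[1] if current is not None and current[0] == "rows" else None


found = first_result_set(FakeStatement([("count", 3), ("rows", "RESULT")]))
missing = first_result_set(FakeStatement([("count", 3)]))
```

In the code above you would replace the `statement.execute()` / `statement.getResultSet()` pair with `resultset = first_result_set(statement)` and check for `None`.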
Scala
```scala
%scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import java.sql.DriverManager

// connection parameters
val username = ""
val password = ""
val host = ""
val port = ""
val database = ""
val sqlsUrl = s"jdbc:sqlserver://$host:$port;databaseName=$database"

// query
val query = """
exec ..
"""

// get connection
val connection = DriverManager.getConnection(sqlsUrl, username, password)

// prepare the statement and execute it; executeQuery returns the result set directly
val statement = connection.prepareCall(query)
val resultSet = statement.executeQuery()

// fetch the structure of the results (JDBC columns are 1-indexed)
val metaData = resultSet.getMetaData()
val indices = (1 to metaData.getColumnCount).toList

// translation of Java classes to Spark types
val columnTypesSpark = Map(
  "java.lang.String" -> StringType,
  "java.lang.Short" -> ShortType,
  "java.lang.Integer" -> IntegerType,
  "java.lang.Long" -> LongType,
  "java.lang.Float" -> FloatType,
  "java.lang.Double" -> DoubleType,
  "java.lang.Boolean" -> BooleanType,
  "java.sql.Date" -> DateType,
  "java.sql.Timestamp" -> TimestampType,
  "java.math.BigDecimal" -> DecimalType(10, 1)) // whatever precision you want

// list out the column types and names in the returned data
val columnTypes = indices.map(i => columnTypesSpark(metaData.getColumnClassName(i)))
val columnNames = indices.map(i => metaData.getColumnLabel(i))

// define the schema
val schema = StructType(columnNames.zip(columnTypes).map { case (name, dataType) => StructField(name, dataType) })

// loop through the result set, building one Row per record
val results: List[Row] = Iterator
  .continually {
    if (resultSet.next()) Some(Row(indices.map(i => resultSet.getObject(i)): _*))
    else None
  }
  .takeWhile(_.isDefined)
  .map(_.get)
  .toList

// close connection
connection.close()

// convert the rows into an RDD and then into a DataFrame
val df = spark.createDataFrame(sc.parallelize(results), schema)
display(df)
```
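The `DecimalType(10, 1)` entry in the Scala map is a one-size-fits-all guess. A hedged alternative is to read `getPrecision(i)` and `getScale(i)` from the metadata for each decimal column (in Scala, `DecimalType(metaData.getPrecision(i), metaData.getScale(i))`). The sketch below shows the idea in Python, building a Spark DDL schema string, which PySpark's `createDataFrame` accepts as its `schema` argument, from (label, class name, precision, scale) tuples like the ones `ResultSetMetaData` reports. The sample tuples are hypothetical.

```python
SPARK_DDL_TYPES = {
    "java.lang.String": "STRING",
    "java.lang.Short": "SMALLINT",
    "java.lang.Integer": "INT",
    "java.lang.Long": "BIGINT",
    "java.lang.Float": "FLOAT",
    "java.lang.Double": "DOUBLE",
    "java.lang.Boolean": "BOOLEAN",
    "java.sql.Date": "DATE",
    "java.sql.Timestamp": "TIMESTAMP",
}


def ddl_schema(columns):
    """columns: (label, class name, precision, scale) per column, in result-set order."""
    fields = []
    for label, class_name, precision, scale in columns:
        if class_name == "java.math.BigDecimal":
            # SQL Server and Spark both cap decimal precision at 38
            spark_type = f"DECIMAL({precision},{scale})"
        else:
            spark_type = SPARK_DDL_TYPES[class_name]  # KeyError flags an unmapped type
        fields.append(f"`{label}` {spark_type}")
    return ", ".join(fields)


# Hypothetical metadata for a three-column result
schema = ddl_schema([
    ("id", "java.lang.Integer", 10, 0),
    ("amount", "java.math.BigDecimal", 12, 2),
    ("created", "java.sql.Timestamp", 23, 3),
])
```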
Long answer. Hope that helps someone.