PySpark: DataFrame - Convert Struct to Array

I have a DataFrame with the following schema:

root
 |-- index: long (nullable = true)
 |-- text: string (nullable = true)
 |-- topicDistribution: struct (nullable = true)
 |    |-- type: long (nullable = true)
 |    |-- values: array (nullable = true)
 |    |    |-- element: double (containsNull = true)
 |-- wiki_index: string (nullable = true)

I need to change it to:

root
 |-- index: long (nullable = true)
 |-- text: string (nullable = true)
 |-- topicDistribution: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- wiki_index: string (nullable = true)

How can I do that?

Thanks a lot.

Candlestand answered 3/12, 2017 at 8:24 Comment(0)

I think you're looking for

df.withColumn("topicDistribution", col("topicDistribution").getField("values"))
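
Note that col comes from pyspark.sql.functions, and Spark also accepts a dot path in place of getField. A minimal sketch, assuming a DataFrame df with the schema above:

from pyspark.sql.functions import col

# getField and a dot path are equivalent ways to pull the inner
# "values" array out of the struct column:
df = df.withColumn("topicDistribution", col("topicDistribution").getField("values"))
# ...which is the same as:
# df = df.withColumn("topicDistribution", col("topicDistribution.values"))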
Wrote answered 3/12, 2017 at 8:34 Comment(3)
This is an interesting use case and solution. However, the topicDistribution column remains of type struct, not array, and I have not yet figured out how to convert between the two types. - Ingulf
How can this be done dynamically? My withColumn should dynamically create all the columns based on the struct's key names. - Chamkis
I don't have code on hand, but you can do something like: 1. struct_keys = ... # go through the schema to figure out the struct's keys; 2. new_cols = [col("yourStruct").getField(kk) for kk in struct_keys]; 3. df.select(*(new_cols + orig_cols)). See the sketch below. - Wrote
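
A runnable version of the dynamic approach sketched in that comment (assuming the struct column is named yourStruct, as in the comment; the field names are read from the schema):

from pyspark.sql.functions import col

# Read the struct's field names straight from the schema.
struct_keys = [f.name for f in df.schema["yourStruct"].dataType.fields]

# One new column per struct field, keeping the field name as the column name.
new_cols = [col("yourStruct").getField(kk).alias(kk) for kk in struct_keys]

# Keep the remaining top-level columns (everything except the struct itself).
orig_cols = [col(c) for c in df.columns if c != "yourStruct"]

df_flat = df.select(*(orig_cols + new_cols))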

I am just adding the full working code for anyone who wants a complete, runnable example.

from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import StructType, StructField, StringType, LongType, ArrayType, DoubleType

if __name__ == '__main__':
    # Create a Spark session
    spark = SparkSession.builder.appName("example").getOrCreate()

    # Define the schema
    schema = StructType([
        StructField("index", LongType(), True),
        StructField("text", StringType(), True),
        StructField("topicDistribution", StructType([
            StructField("type", LongType(), True),
            StructField("values", ArrayType(DoubleType()), True)
        ]), True),
        StructField("wiki_index", StringType(), True)
    ])

    # Sample data
    data = [
        (1, "Sample Text 1", (100, [0.3, 0.5, 0.2]), "wiki_1"),
        (2, "Sample Text 2", (101, [0.1, 0.8, 0.1]), "wiki_2"),
        (3, "Sample Text 3", (102, [0.6, 0.2, 0.2]), "wiki_3")
    ]
    df = spark.createDataFrame(data, schema=schema)

    # Accessing the 'values' array within 'topicDistribution'
    df_values = df.withColumn("topicDistribution", col("topicDistribution.values"))
    df_values.printSchema()
    df_values.show(truncate=False)
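
Running this should print the target schema from the question:

root
 |-- index: long (nullable = true)
 |-- text: string (nullable = true)
 |-- topicDistribution: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- wiki_index: string (nullable = true)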
Lacerta answered 22/12, 2023 at 18:8 Comment(0)
