How to use a list comprehension on an array column in PySpark?

I have a pyspark dataframe that looks like this.

+--------------------+-------+--------------------+
|              ID    |country|               attrs|
+--------------------+-------+--------------------+
|ffae10af            |     US|[1,2,3,4...]        |
|3de27656            |     US|[1,7,2,4...]        |
|75ce4e58            |     US|[1,2,1,4...]        |
|908df65c            |     US|[1,8,3,0...]        |
|f0503257            |     US|[1,2,3,2...]        |
|2tBxD6j             |     US|[1,2,3,4...]        |
|33811685            |     US|[1,5,3,5...]        |
|aad21639            |     US|[7,8,9,4...]        |
|e3d9e3bb            |     US|[1,10,9,4...]       |
|463f6f69            |     US|[12,2,13,4...]      |
+--------------------+-------+--------------------+

I also have a set that looks like this

reference_set = (1,2,100,500,821)

What I want to do is create a new column in the dataframe holding the filtered list, using something like this list comprehension: [attr for attr in attrs if attr in reference_set]

So my final dataframe should look something like this:

+--------------------+-------+--------------------+
|              ID    |country|      filtered_attrs|
+--------------------+-------+--------------------+
|ffae10af            |     US|[1,2]               |
|3de27656            |     US|[1,2]               |
|75ce4e58            |     US|[1,2]               |
|908df65c            |     US|[1]                 |
|f0503257            |     US|[1,2]               |
|2tBxD6j             |     US|[1,2]               |
|33811685            |     US|[1]                 |
|aad21639            |     US|[]                  |
|e3d9e3bb            |     US|[1]                 |
|463f6f69            |     US|[2]                 |
+--------------------+-------+--------------------+

How can I do this? As I'm new to PySpark, I can't come up with the logic.

Edit: I posted an approach below; if there's a more efficient way of doing this, please let me know.

2 answers

  • answered 2022-01-13 05:30 Shashank Setty

    I managed to use the filter function paired with a UDF to make this work.

    from pyspark.sql.functions import col, udf
    from pyspark.sql.types import ArrayType, LongType

    def filter_items(item):
        return item in reference_set

    # Declaring the element type (assumed to be bigint here) makes the result a
    # proper array column instead of the default string representation.
    custom_udf = udf(lambda attributes: list(filter(filter_items, attributes)),
                     ArrayType(LongType()))
    processed_df = df.withColumn('filtered_attrs', custom_udf(col('attrs')))
    

    This gives me the required output
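
    For reference, here's a sketch of how the same filtering could be done without a Python UDF, assuming Spark 3.1+ where pyspark.sql.functions.filter accepts a Python lambda (the lambda is compiled into a SQL expression, so it avoids the serialization overhead of a UDF):

    from pyspark.sql import functions as F

    reference_set = (1, 2, 100, 500, 821)

    # Keep only the elements of attrs that appear in reference_set;
    # the lambda is evaluated as a higher-order function inside the JVM.
    filtered_df = df.withColumn(
        'filtered_attrs',
        F.filter('attrs', lambda x: x.isin(*reference_set))
    )
    filtered_df.show(truncate=False)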

  • answered 2022-01-13 05:55 Mohana B C

You can use the built-in function array_intersect.

    from pyspark.sql.functions import array_intersect, lit, split

    # Sample dataframe
    df = spark.createDataFrame([('ffae10af', 'US', [1, 2, 3, 4])], ["ID", "Country", "attrs"])

    reference_set = {1, 2, 100, 500, 821}

    # Turn the set into a comma-separated string so it can be added to the
    # dataframe as an array column
    set_to_string = ",".join([str(x) for x in reference_set])

    df.withColumn('reference_set', split(lit(set_to_string), ',').cast('array<bigint>')) \
        .withColumn('filtered_attrs', array_intersect('attrs', 'reference_set')) \
        .show(truncate=False)
    
    +--------+-------+------------+---------------------+--------------+
    |ID      |Country|attrs       |reference_set        |filtered_attrs|
    +--------+-------+------------+---------------------+--------------+
    |ffae10af|US     |[1, 2, 3, 4]|[1, 2, 100, 500, 821]|[1, 2]        |
    +--------+-------+------------+---------------------+--------------+
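
    Note that array_intersect returns the common elements without duplicates, so if repeated values in attrs need to be preserved, the UDF (or a higher-order filter) approach would be the way to go. As a small variant, the reference array can also be built directly from literals instead of splitting a string, along these lines:

    from pyspark.sql import functions as F

    # Build the reference array column directly from literal values
    ref_array = F.array(*[F.lit(x) for x in reference_set])

    df.withColumn('filtered_attrs', F.array_intersect('attrs', ref_array)) \
        .show(truncate=False)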
    
