Computing number of business days between start/end columns

I have two DataFrames:

• facts, with columns: data, start_date and end_date
• holidays, with column: holiday_date

What I want is a way to produce another DataFrame that has columns: data, start_date, end_date and num_holidays.

Where num_holidays is computed as: the number of days between start and end that are not weekends or holidays (as listed in the holidays table).
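For a single row, that definition boils down to something like this plain-Python sketch (the function name is mine):

```python
from datetime import date, timedelta

def num_holidays(start: date, end: date, holidays: set) -> int:
    """Count days in [start, end] that are neither weekends nor holidays."""
    count = 0
    d = start
    while d <= end:
        # weekday() is 0..4 for Mon..Fri, 5 for Sat, 6 for Sun
        if d.weekday() < 5 and d not in holidays:
            count += 1
        d += timedelta(days=1)
    return count

print(num_holidays(date(2022, 5, 8), date(2022, 5, 14), {date(2022, 5, 10)}))  # 4
```

The question is how to do this efficiently over a whole DataFrame rather than one row at a time.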

The solution is here if we wanted to do this in T-SQL (the code below uses SQL Server's DATEDIFF and DATENAME). The crux is this part of the code:

--Calculate and return the number of workdays using the input parameters.
--This is the meat of the function.
--This is really just one formula with a couple of parts that are listed on separate lines for documentation purposes.
RETURN (
    SELECT
        (DATEDIFF(dd, @StartDate, @EndDate) + 1)
        --Subtract 2 days for each full weekend
        - (DATEDIFF(wk, @StartDate, @EndDate) * 2)
        --If StartDate is a Sunday, subtract 1
        - (CASE WHEN DATENAME(dw, @StartDate) = 'Sunday'
                THEN 1
                ELSE 0
           END)
        --If EndDate is a Saturday, subtract 1
        - (CASE WHEN DATENAME(dw, @EndDate) = 'Saturday'
                THEN 1
                ELSE 0
           END)
        --Subtract all holidays
        - (SELECT COUNT(*) FROM [dbo].[tblHolidays]
           WHERE [HolDate] BETWEEN @StartDate AND @EndDate)
)
END
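For reference, the arithmetic in that formula translates to plain Python roughly as below. DATEDIFF(wk, ...) counts week boundaries crossed, which I'm treating as the number of Sundays in (start, end]; the function name is mine:

```python
from datetime import date, timedelta

def workdays_formula(start: date, end: date, holidays: set) -> int:
    total = (end - start).days + 1
    # DATEDIFF(wk, start, end): week boundaries (Sundays) crossed
    sundays_crossed = sum(
        1 for i in range(1, total)
        if (start + timedelta(days=i)).weekday() == 6
    )
    return (total
            - 2 * sundays_crossed                 # subtract 2 days per full weekend
            - (1 if start.weekday() == 6 else 0)  # StartDate falls on a Sunday
            - (1 if end.weekday() == 5 else 0)    # EndDate falls on a Saturday
            - sum(1 for h in holidays if start <= h <= end))

print(workdays_formula(date(2022, 5, 8), date(2022, 5, 14), {date(2022, 5, 10)}))  # 4
```

Note that, like the T-SQL original, this subtracts a holiday even when it falls on a weekend, which double-counts that day.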

I'm new to PySpark and was wondering what's an efficient way to do this. I can post the UDF I'm writing if it helps, though I'm going slowly because I feel it's the wrong approach:

• Is there a better way than writing a UDF that reads the holidays table into a DataFrame and joins with it to count the holidays? Can I even join inside a UDF?
• Is there a way to write a pandas_udf instead? Would it be fast enough?
• Are there optimizations I can apply, like caching the holidays table somehow on every worker?
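On the pandas_udf question: NumPy's np.busday_count already counts business days in a half-open interval [start, end) and accepts a holidays array, so the per-batch logic could look like the sketch below. Wrapping it with @pandas_udf and collecting the holidays table to the driver first are assumptions on my part, and the function name is mine:

```python
import numpy as np
import pandas as pd

def count_business_days(start: pd.Series, end: pd.Series,
                        holidays: np.ndarray) -> pd.Series:
    s = start.values.astype('datetime64[D]')
    # busday_count uses a half-open interval, so add a day to include end_date
    e = end.values.astype('datetime64[D]') + np.timedelta64(1, 'D')
    return pd.Series(np.busday_count(s, e, holidays=holidays))

starts = pd.Series(pd.to_datetime(['2022-05-08', '2022-05-08']))
ends = pd.Series(pd.to_datetime(['2022-05-14', '2022-05-21']))
hols = np.array(['2022-05-10'], dtype='datetime64[D]')
print(count_business_days(starts, ends, hols).tolist())  # [4, 9]
```

In Spark this would be registered with @F.pandas_udf('int'), with the holidays array captured in the closure after collecting it from the holidays DataFrame.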

Something like this may work:

from pyspark.sql import functions as F

df_facts = spark.createDataFrame(
    [('data1', '2022-05-08', '2022-05-14'),
     ('data1', '2022-05-08', '2022-05-21')],
    ['data', 'start_date', 'end_date']
)
df_holidays = spark.createDataFrame([('2022-05-10',)], ['holiday_date'])

# One row per calendar day in [start_date, end_date]
df = df_facts.withColumn('exploded', F.explode(F.sequence(F.to_date('start_date'), F.to_date('end_date'))))
# Drop Sundays (dayofweek 1) and Saturdays (dayofweek 7)
df = df.filter(~F.dayofweek('exploded').isin([1, 7]))
# Drop holidays via a broadcast anti-join
df = df.join(F.broadcast(df_holidays), df.exploded == F.to_date(df_holidays.holiday_date), 'anti')
# Count the surviving days per original row
df = df.groupBy('data', 'start_date', 'end_date').agg(F.count('exploded').alias('num_holidays'))

df.show()
# +-----+----------+----------+------------+
# | data|start_date|  end_date|num_holidays|
# +-----+----------+----------+------------+
# |data1|2022-05-08|2022-05-14|           4|
# |data1|2022-05-08|2022-05-21|           9|
# +-----+----------+----------+------------+