Solution to produce flattened hierarchy columns from parent-child relational data.
Let's assume we have the following DataFrame:
+---+---------+
|id |parent_id|
+---+---------+
|A |null |
|B |A |
|C |A |
|D |B |
|E |D |
+---+---------+
We need to produce a flattened hierarchy like the one below:
+---------+---------+---------+
|level1_id|level2_id|level3_id|
+---------+---------+---------+
|A |null |null |
|B |A |null |
|C |A |null |
|D |B |A |
|E |D |B |
+---------+---------+---------+
PySpark script to flatten hierarchy
from pyspark.sql.session import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import *
from pyspark.sql import DataFrame
# Local session that uses every available core; getOrCreate() reuses a
# session that is already running instead of starting a second one.
spark = SparkSession.builder.master('local[*]').appName('test_app').getOrCreate()
def flatten_hierarchy(
    df: DataFrame, node_id: str, node_parent_id: str,
    max_hier_level: int, hierarchy_label: str = None, other_fields: list = None
) -> DataFrame:
    """Flatten a parent-child table into one set of columns per ancestor level.

    Every row of ``df`` describes one node. In the result, level 1 is the
    node itself, level 2 its parent, level 3 its grandparent, and so on up
    to ``max_hier_level``. Ancestors are looked up with left joins, so a
    level that does not exist for a node is ``null``.

    Args:
        df: Source data with one row per node.
        node_id: Name of the column holding the node identifier.
        node_parent_id: Name of the column holding the parent identifier.
        max_hier_level: Number of levels to produce (the node itself counts
            as level 1). Must be at least 1.
        hierarchy_label: Optional label inserted into every generated column
            name: ``level<n>_<label>_id`` instead of ``level<n>_id``.
        other_fields: Optional list (or tuple) of additional source columns
            to repeat for every level as ``level<n>[_<label>]_<field>``.
            Names that are not columns of ``df`` are ignored.

    Returns:
        A DataFrame containing only the generated level columns, ordered by
        level number and, within a level, alphabetically by column name.
        Source columns that are neither the id, the parent id nor listed in
        ``other_fields`` are not part of the result.

    Raises:
        ValueError: If ``max_hier_level`` is smaller than 1.
    """
    if max_hier_level < 1:
        raise ValueError(f'max_hier_level must be >= 1, got {max_hier_level!r}')

    label = f'_{hierarchy_label}' if hierarchy_label else ''
    fields = list(other_fields) if isinstance(other_fields, (list, tuple)) else []

    def id_column(level):
        """Return the name of the id column for ``level``."""
        return f'level{level}{label}_id'

    def field_column(level, fld):
        """Return the name of the column carrying ``fld`` for ``level``."""
        return f'level{level}{label}_{fld}'

    def level_frame(level):
        """Return ``df`` with its columns renamed to describe ``level``."""
        renamed = (
            df.withColumnRenamed(node_id, id_column(level))
            .withColumnRenamed(node_parent_id, id_column(level + 1))
        )
        for fld in fields:
            renamed = renamed.withColumnRenamed(fld, field_column(level, fld))
        return renamed

    df_hierarchy = level_frame(1)
    for level in range(2, max_hier_level + 1):
        next_level = id_column(level + 1)
        next_level_tmp = f'level_{level + 1}_tmp'
        # The left join keeps nodes whose ancestor chain ends before `level`.
        df_hierarchy = df_hierarchy.join(level_frame(level), id_column(level), 'left')
        # Re-create the freshly joined parent column under a new alias.
        # NOTE(review): presumably a guard against ambiguous references in
        # the repeated self-joins; kept exactly as in the original.
        df_hierarchy = (
            df_hierarchy.select('*', df_hierarchy[next_level].alias(next_level_tmp))
            .drop(next_level)
            .withColumnRenamed(next_level_tmp, next_level)
        )

    # Build the output column list explicitly instead of sorting whatever the
    # joins produced: this keeps level 10 after level 9, leaves out the
    # look-ahead column of level max_hier_level + 1, and never references
    # pass-through source columns that the self-joins have duplicated.
    available = set(df_hierarchy.columns)
    ordered = []
    for level in range(1, max_hier_level + 1):
        wanted = {id_column(level)} | {field_column(level, fld) for fld in fields}
        ordered.extend(name for name in sorted(wanted) if name in available)
    return df_hierarchy.select(ordered)
Usage options
Default options
Flattened level columns without a label and without additional columns.
# Sample hierarchy: A is the root, B and C are children of A,
# D is a child of B and E is a child of D.
df = spark.createDataFrame(
    data=[('A', None), ('B', 'A'), ('C', 'A'), ('D', 'B'), ('E', 'D')],
    schema=['id', 'parent_id'],
)
df.show(truncate=False)

# Three levels, no label and no additional columns.
df_test = flatten_hierarchy(df, 'id', 'parent_id', max_hier_level=3)
df_test.show(truncate=False)
+---------+---------+---------+
|level1_id|level2_id|level3_id|
+---------+---------+---------+
|A |null |null |
|B |A |null |
|C |A |null |
|E |D |B |
|D |B |A |
+---------+---------+---------+
With additional options
With label for all flattened level columns, along with additional columns from the source data.
# Same hierarchy as before, with a display name attached to every node.
df = spark.createDataFrame(
    data=[
        ('A', None, 'Node-A'),
        ('B', 'A', 'Node-B'),
        ('C', 'A', 'Node-C'),
        ('D', 'B', 'Node-D'),
        ('E', 'D', 'Node-E'),
    ],
    schema=['id', 'parent_id', 'name'],
)
df.show(truncate=False)
+---+---------+------+
|id |parent_id|name |
+---+---------+------+
|A |null |Node-A|
|B |A |Node-B|
|C |A |Node-C|
|D |B |Node-D|
|E |D |Node-E|
+---+---------+------+
# Label every level column with "test" and repeat each ancestor's name.
df_test = flatten_hierarchy(
    df, 'id', 'parent_id', 3,
    hierarchy_label='test', other_fields=['name'],
)
df_test.show(truncate=False)
+--------------+----------------+--------------+----------------+--------------+----------------+
|level1_test_id|level1_test_name|level2_test_id|level2_test_name|level3_test_id|level3_test_name|
+--------------+----------------+--------------+----------------+--------------+----------------+
|A |Node-A |null |null |null |null |
|B |Node-B |A |Node-A |null |null |
|D |Node-D |B |Node-B |A |Node-A |
|E |Node-E |D |Node-D |B |Node-B |
|C |Node-C |A |Node-A |null |null |
+--------------+----------------+--------------+----------------+--------------+----------------+