from pyspark.sql import Row, SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import explode, udf
from pyspark.sql.types import ArrayType, FloatType, StructField, StructType


def explode_col(weight):
    # Split a weight into full 10.0 chunks plus the remainder, e.g.
    # 23.5 -> [10.0, 10.0, 3.5]. The remainder is dropped when the weight
    # divides evenly (note this is an exact float comparison).
    return int(weight // 10) * [10.0] + ([] if weight % 10 == 0 else [weight % 10])


spark = SparkSession.builder.getOrCreate()

data_schema = StructType([
    StructField("feature_1", FloatType()),
    StructField("feature_2", FloatType()),
    StructField("bias_weight", FloatType()),
])

data = [
    Row(0.1, 0.2, 10.32),
    Row(0.32, 1.43, 12.8),
    Row(1.28, 1.12, 0.23),
]

df = spark.createDataFrame(data, data_schema)

# Rescale bias_weight so the column sums to normalizing_constant.
normalizing_constant = 100
sum_bias_weight = df.select(F.sum("bias_weight")).collect()[0][0]
normalizing_factor = normalizing_constant / sum_bias_weight
df = df.withColumn("normalized_bias_weight", df.bias_weight * normalizing_factor)
df = df.drop("bias_weight")
df = df.withColumnRenamed("normalized_bias_weight", "bias_weight")

# Explode each normalized weight into one row per 10.0 chunk (plus remainder),
# duplicating the feature columns for every chunk.
my_udf = udf(explode_col, ArrayType(FloatType()))
df1 = df.withColumn("explode_val", my_udf(df.bias_weight))
df1 = df1.withColumn("explode_val_1", explode(df1.explode_val)).drop("explode_val")
df1 = df1.drop("bias_weight").withColumnRenamed("explode_val_1", "bias_weight")

df1.show()

assert df1.count() == 12
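# Why 12 rows: the raw bias weights sum to 10.32 + 12.8 + 0.23 = 23.35, so the
# normalizing factor is 100 / 23.35 ≈ 4.283 and the normalized weights come out
# to roughly 44.20, 54.82, and 0.99. explode_col turns those into 5, 6, and 1
# chunks respectively, for 12 rows in total (assuming float rounding never lands
# a weight exactly on a multiple of 10, which would drop its remainder row).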