次の方法で共有


ユーザー定義集計関数 - Scala

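この記事では、Scala で UDAF を実装する例と、それを Apache Spark SQL から使用できるように登録する方法を紹介します。 詳細については、「ユーザー定義集計関数 (UDAF)」を参照してください。 なお、ここで使用している `UserDefinedAggregateFunction` は Spark 3.0 以降非推奨となっているため、新しいコードでは `Aggregator` を実装して `functions.udaf` で登録する方法を推奨します。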

UserDefinedAggregateFunction を実装する

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

/**
 * Geometric-mean aggregate: the n-th root of the product of the n non-null
 * input values, computed as `exp((log x1 + ... + log xn) / n)`.
 *
 * The buffer keeps the running sum of natural logarithms instead of the raw
 * product. A raw product of many values overflows to `Infinity` (or
 * underflows to `0.0`) even when the mean itself is an ordinary number.
 * For example, 200 values of `1000.0` have a product of 1e600, which is
 * beyond the range of a Double.
 *
 * Edge cases:
 *  - null inputs are skipped, like the built-in SQL aggregates (`avg`, `sum`);
 *  - a group with no non-null inputs evaluates to null;
 *  - a `0.0` input makes the result `0.0` (log(0) is -Infinity);
 *  - a negative input makes the result `NaN` (log of a negative is NaN).
 *    The geometric mean is only meaningful for non-negative values.
 *
 * NOTE: `UserDefinedAggregateFunction` is deprecated as of Spark 3.0. New
 * code should prefer `org.apache.spark.sql.expressions.Aggregator`
 * registered with `functions.udaf`.
 */
class GeometricMean extends UserDefinedAggregateFunction {
  // A single DOUBLE input column.
  override def inputSchema: StructType =
    StructType(StructField("value", DoubleType) :: Nil)

  // Aggregation state: the number of non-null values seen so far, and the
  // sum of their natural logarithms.
  override def bufferSchema: StructType = StructType(
    StructField("count", LongType) ::
    StructField("logSum", DoubleType) :: Nil
  )

  // The type of the final aggregate value.
  override def dataType: DataType = DoubleType

  // The same input always produces the same result.
  override def deterministic: Boolean = true

  // Empty state: no values seen. The log-sum starts at 0.0, which is the log
  // of the empty product 1.0.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0.0
  }

  // Folds one input row into the state. Null values are skipped, so they
  // count neither towards n nor towards the log-sum.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + 1L
      buffer(1) = buffer.getDouble(1) + math.log(input.getDouble(0))
    }
  }

  // Combines two partial states (for example, from different partitions).
  // Both the counts and the log-sums are additive.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
  }

  // Final value: null when no non-null input was seen, otherwise the
  // exponential of the mean logarithm.
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(0)
    if (count == 0L) null
    else math.exp(buffer.getDouble(1) / count)
  }
}

UDAF を Spark SQL に登録する

// Make the aggregate callable from Spark SQL (and from expr(...) strings)
// under the function name "gm".
val geometricMeanUdaf = new GeometricMean
spark.udf.register("gm", geometricMeanUdaf)

UDAF を使用する

// Sample data for the examples: ids 1 to 19 (spark.range excludes its end
// value), each assigned to one of three groups by id % 3. Both DataFrames are
// exposed to Spark SQL as temporary views ("ids" and "simple").
import org.apache.spark.sql.functions._

val ids = spark.range(1, 20)
val df = {
  ids.createOrReplaceTempView("ids")
  spark.sql("select id, id % 3 as group_id from ids")
}
df.createOrReplaceTempView("simple")
-- Per-group geometric mean in SQL: GROUP BY the group column and call the UDAF registered as gm.
SELECT group_id, gm(id) FROM simple GROUP BY group_id
// Alternatively, call the aggregate through the DataFrame API.

// A UDAF instance is applied to Column arguments to build an aggregate Column.
val gm = new GeometricMean

// Rows are grouped once; each agg(...) call below yields its own DataFrame.
val byGroup = df.groupBy("group_id")

// Geometric mean of "id" per group, using the instance directly.
byGroup.agg(gm(col("id")).as("GeometricMean")).show()

// The same aggregate, invoked through the name it was registered under in Spark SQL.
byGroup.agg(expr("gm(id) as GeometricMean")).show()