Данная статья является гайдом по использованию кастомных агрегаторов в Spark SQL API. Она “выросла” из моих заметок, которые я делал себе с начала работы со Spark. Сейчас, по мере накопления опыта, мне все это кажется уж слишком наивным и простым, но в свое время мне это показалось чертовски удобным/изящным/заслуживающим внимания, поэтому и решил опубликовать, тем более на Хабре про это еще вроде не писали.

Статья ориентирована в первую очередь на тех, кто только начинает работать со Spark, поэтому и помечена как “tutorial”. Если же вы опытный и у вас есть какие-либо интересные кейсы по использованию кастомных агрегаторов - делитесь в комментариях!

Ниже мы будем говорить о user-defined aggregations functions (UDAF) org.apache.spark.sql.expressions.Aggregator, которые могут быть использованы для DataSet’ов с целью агрегации группы элементов в одно значение каким-угодно-пользователю образом.

Давайте сначала разберем пример из официальной документации, реализующий простое суммирование.

case class Data(i: Int)

val customSummer =  new Aggregator[Data, Int, Int] {
     def zero: Int = 0
     def reduce(b: Int, a: Data): Int = b + a.i
     def merge(b1: Int, b2: Int): Int = b1 + b2
     def finish(r: Int): Int = r
     def bufferEncoder: Encoder[Int] = Encoders.scalaInt
     def outputEncoder: Encoder[Int] = Encoders.scalaInt
   }.toColumn()

val ds: Dataset[Data] = ...
val aggregated = ds.select(customSummer)

Case class “Data” нам нужен потому, что UDAF используются для DataSet’ов.  Для создания своего агрегатора нам надо определить 6 функций с предопределенным названием:

  • zero - оно же нулевое, или начальное значение; должна удовлетворять требованию: “нечто” + zero = “нечто”

  • reduсe - функция reduсe, выполняющая нашу агрегацию в буфер

  • merge - функция для мёржа буферов

  • finish - финишная обработка для получения целевого значения; в конкретном примере ее по сути нет, но часто бывает полезной (увидим в примерах ниже)

  • два энкодера для буфера и выходного значения

Обратите внимание, что по сути агрегация производится в 3 этапа:

  1. сначала в буферы

  2. потом идет агрегация буферов

  3. потом финальная обработка

и поэтому, даже если вам не нужен какой-либо этап, все равно надо определять данные функции. С другой стороны, вряд ли вы будете использовать UDAF для простого суммирования, а для сложных агрегаций такое разбиение на этапы бывает очень полезным. Также следует обратить внимание, что необходимо при создании агрегатора определить типы входного, буферного, и выходного значения. Входным типом обычно является case class для элементов DataSet’а

Давайте рассмотрим учебный “а-ля word count”-пример, реализованный на агрегаторах. Пусть у нас есть простой CSV-файл с твитами пользователей:

userId,tweet
f6e8252f,cat dog frog cat dog frog frog dog dog bird
f6e8252f,cat dog frog dog dog bird
f6e8252f,cat dog
29d89ee4,frog frog dog bird
29d89ee4,frog cat dog frog frog dog dog bird
29d89ee4,frog bird

Определим для каждого пользователя самое употребялемое им слово. Естественно задачу можно решить многими способами, но мы в учебных целях покажем, как ее решить с помощью UDAF.

Обратите внимание на указание типов <IN,BUF,OUT>:

  • входной тип у нашего агрегатора - case class Tweet, описываюший наши данные

  • для буфера мы используем тип Map[String, Int]

  • ну и выходной тип - просто String, так как возвращается самое популярное слово.

// Даем алиас типу для удобства и создаем case class
type myMap = Map[String, Int]
case class Tweet(userID: String, tweet: String)

// Определяем сам агрегатор
val FavoriteWordAggregator = new Aggregator[Tweet, myMap, String] {
// Вспомогательная функция “сложения” двух Map, которые будем ниже использовать для подсчета слов
  def addMap(map1: myMap, map2: myMap): myMap = {
    map1 ++ map2.map{ case (k,v) => k -> (v + map1.getOrElse(k,0)) }
  }
  def zero: myMap = Map.empty[String, Int]
  def reduce(accum: myMap, a: Tweet): myMap = {
    val aMap = a.tweet.split(" ").groupBy(identity).mapValues(_.size)
    addMap(accum, aMap)
  }
  def merge(a: myMap, b: myMap): myMap = addMap(a, b)
// Обратите внимание: здесь у нас в отличии от предыдущего примера есть финишная обработка - выбираем самое популярное слово
  def finish(map: myMap): String = map.toList.sortBy(-_._2).head._1
  def bufferEncoder: Encoder[myMap] = ExpressionEncoder()
  def outputEncoder: Encoder[String]= Encoders.STRING
}.toColumn

// Ну и теперь можем “в одну строчку” реализовать требуемую логику
ds.groupByKey(_.userID)
  .agg(FavoriteWordAggregator.name("favoriteWord"))
  .withColumnRenamed("value", "userId")
  .show()

В итоге получаем такой результат:

+--------+---------------+
|userId  |favoriteWord   |
+--------+---------------+
|29d89ee4|frog           |
|f6e8252f|dog            |
+--------+---------------+

Рассмотрим другой, более сложный пример, который на мой взгляд лучше может показать удобство UDAF. Пусть у нас есть кликстримы пользователей примерно в таком виде:

userId,eventId,eventType,eventTime,attributes
f6e8252f-b5cc-48a4-b348-29d89ee4fa9e,3fce7a72-6aa5-4a81-90a2-2060bd003bf4,app_open,2020-12-13 15:37:31,"{'campaign_id': 332, 'channel_id': 'Facebook Ads'}"
f6e8252f-b5cc-48a4-b348-29d89ee4fa9e,b49802f8-62a6-4248-ad48-3b5a0ac36ac3,search_product,2020-12-13 15:45:34,
f6e8252f-b5cc-48a4-b348-29d89ee4fa9e,08b51fee-d62d-4b21-9352-0df7909e4560,view_product_details,2020-12-13 15:48:07,
f6e8252f-b5cc-48a4-b348-29d89ee4fa9e,bdde7a28-c32d-4770-bb6a-0760339cf83d,purchase,2020-12-13 16:00:20,{'purchase_id': 'e3e70682-c209-4cac-a29f-6fbed82c07cd'}
f6e8252f-b5cc-48a4-b348-29d89ee4fa9e,be7741d8-56c9-47cb-896f-621d8d68027c,app_close,2020-12-13 16:04:20,
f6e8252f-b5cc-48a4-b348-29d89ee4fa9e,ee2c31f1-d02c-4bac-80d4-2131e7f391c5,app_open,2020-12-26 08:09:50,"{'campaign_id': 859, 'channel_id': 'VK Ads'}"
f6e8252f-b5cc-48a4-b348-29d89ee4fa9e,2a2fa632-8f20-42c7-b5e6-07afc38b2b9f,search_product,2020-12-26 08:16:23
...

То есть для каждого пользователя мы отслеживаем его действия - заход на сайт, поиск продукта, просмотр деталей продукта, оформление заказа, закрытие сайта и т.п. Причем для определенных действий у нас могут быть дополнительные данные, помещенные в поле attributes: 

  • для первичного входа это, например рекламная кампания (campaign_id) и канал привлечения (channel_id)

  • при оформлении заказа генерируется purchase_id, благодаря которому можем связать данный лид с данными из других систем. 

Для оценки эффективности компаний и каналов нас могут попросить сделать выгрузку в примерно таком виде, на основании которого простыми группировками аналитики отдела маркетинга смогут ответить на различные вопросы по оценке рекламных компаний, каналов продвижения и т.п.

root
   |-- sessionId
   |-- campaign_id
   |-- channel_id
   |-- purchase_id

 Здесь у нас стоит 2 задачи:

  • нам надо проанализировать поток логов и “наиграть” sessionId

  • “вытащить” из JSON-подобных текстовых полей нужные нам атрибуты в колонки, причем первые два содержатся в атрибутах с eventType = “app_open”, а последний - в eventType = “purchase”

попутно преобразовать все к простой структуре, где одной sessionId соответствует 1 строка данных

Первая задача хоть и не относится к теме статьи, но затронем ее немного. Один из вариантов ее решения - применить оконную функцию. Идея в следующем:

  • Группируем по userId и сортируем по eventTime

  • Помечаем все eventType = “app_open” (то есть запись, которую )

  • Суммируем нарастающим итогом с применением оконной функции в сортировке по eventTime, получаем номер сессии пользователя

  • Уникальный идентификатор всей сессии можно сгенерировать, например конкатенируя userId и номер сессии

В коде это будет выглядеть примерно так

row_df = spark.read
  .options(Map("header" -> "true", "inferSchema" -> "true"))
  .csv(path) 

val startSessionFlg: Column = {
  when ($"eventType" === "app_open", 1)
    .otherwise(0)
    .as("startSessionFlg")
}

val window = Window.partitionBy("userId").orderBy("eventTime")

val df = row_df
  .withColumn("startSessionFlg", startSessionFlg) 
  .withColumn("userSessionNum",sum($"startSessionFlg").over(window))
  .withColumn("sessionId", concat($"userId", lit("-"), $"userSessionNum"))
  // попутно преобразуем JSON-строку в колонку MapType
  .withColumn("mapAttributes", from_json($"attributes",MapType(StringType, StringType)))
  // оставим только нужные для последующей обработке колонки
  .select($"sessionId", $"mapAttributes")

Итого, мы имеем на текущий момент такой DataFrame

root
 |-- sessionId: string (nullable = true)
 |-- mapAttributes: map (nullable = true)
 |    |-- key: string
 |    |-- value: string (valueContainsNull = true)

+--------------------------------------+-----------------------------------------------------+
|sessionId                             |mapAttributes                                        |
+--------------------------------------+-----------------------------------------------------+
|f6e8252f-b5cc-48a4-b348-29d89ee4fa9e-1|[campaign_id -> 478, channel_id -> Twitter Ads]      |
|f6e8252f-b5cc-48a4-b348-29d89ee4fa9e-1|null                                                 |
|f6e8252f-b5cc-48a4-b348-29d89ee4fa9e-1|null                                                 |
|f6e8252f-b5cc-48a4-b348-29d89ee4fa9e-1|[purchase_id -> d4713d60-c8a7-4639-ab11-67b367a9c378]|
|f6e8252f-b5cc-48a4-b348-29d89ee4fa9e-2|[campaign_id -> 332, channel_id -> Facebook Ads]     |
|f6e8252f-b5cc-48a4-b348-29d89ee4fa9e-2|null                                                 |
|f6e8252f-b5cc-48a4-b348-29d89ee4fa9e-2|null                                                 |
|f6e8252f-b5cc-48a4-b348-29d89ee4fa9e-2|[purchase_id -> e3e70682-c209-4cac-a29f-6fbed82c07cd]|
|f6e8252f-b5cc-48a4-b348-29d89ee4fa9e-2|null                                                 |
|f6e8252f-b5cc-48a4-b348-29d89ee4fa9e-3|[campaign_id -> 859, channel_id -> VK Ads]           |
|f6e8252f-b5cc-48a4-b348-29d89ee4fa9e-3|null                                                 |
|f6e8252f-b5cc-48a4-b348-29d89ee4fa9e-3|null                                                 |
…

А напомним, что наша задача - получить структуру со следующей схемой, где одной сессии соответствует одна строка. Налицо необходимость агрегации по sessionId и какой-то хитрой обработке колонки mapAttributes.

Как и хотели выше, сделаем это с помощью UDAF, который будет “собирать” данные колонки mapAttributes в одну Map

// Дадим алиас типу для удобства
type StringMap = Map[String, String] 
// Простенький case class для DataSet
case class SessionAttrs(sessionID: String, mapAttributes: StringMap)

// сам агрегатор
val MapAggregator = new Aggregator[SessionAttrs, StringMap, StringMap] {
  // начальное (“нулевое”) значение
  def zero: StringMap = Map.empty[String, String]
  // функция, описывающая агрегацию (в “буферы”)
  // в нашем случае мы “складываем” Map’ы
  def reduce(accum: StringMap, a: SessionAttrs): StringMap = accum ++ a.mapAttributes
  // функция для мержа значений “буферов”
  def merge(map1: StringMap, map2: StringMap): StringMap = map1 ++ map2
  // финишная предобработка, если нужна (в нашем случае - не нужна)
  def finish(result: StringMap): StringMap = result
  // используемые энкодеры для промежуточного буфера и выхода
  def bufferEncoder: Encoder[StringMap] = ExpressionEncoder()
  def outputEncoder: Encoder[StringMap] = ExpressionEncoder()
}.toColumn

// трансформируем DataFrame в DataSet
val ds = df
  .na.drop() // не забываем перед агрегацией, чтобы не было NullPointerException
  .as[SessionAttrs]

// Агрегация с использованием агрегатора
val ds1 = ds.groupByKey(_.sessionID)
  .agg(MapAggregator.name("attrs"))

ds1.show(false)

+--------------------------------------+-----------------------------------------------------------------------------------------------------+
|value                                 |attrs                                                                                                |
+--------------------------------------+-----------------------------------------------------------------------------------------------------+
|ba192cc2-f3e8-4871-9024-426da37bfafc-1|[campaign_id -> 559, channel_id -> Twitter Ads, purchase_id -> 9558867f-5ba9-4faf-ba02-4204f7c1bd87] |
|f6e8252f-b5cc-48a4-b348-29d89ee4fa9e-3|[campaign_id -> 859, channel_id -> VK Ads, purchase_id -> 82e2e662-f728-44fa-8248-5e3a0a5d2f34]      |
|f6e8252f-b5cc-48a4-b348-29d89ee4fa9e-1|[campaign_id -> 478, channel_id -> Twitter Ads, purchase_id -> d4713d60-c8a7-4639-ab11-67b367a9c378] |
|f6e8252f-b5cc-48a4-b348-29d89ee4fa9e-2|[campaign_id -> 332, channel_id -> Facebook Ads, purchase_id -> e3e70682-c209-4cac-a29f-6fbed82c07cd]|
+--------------------------------------+-----------------------------------------------------------------------------------------------------+

Данные успешно агрегированы! Теперь уже дело техники - вытащить данные в разные колонки:

// Ну и наконец, вытащим атрибуты в отдельные колонки
ds1
 .withColumn("campaignId", element_at($"attrs", "campaign_id"))
 .withColumn("channelId", element_at($"attrs", "channel_id"))
 .withColumn("purchaseId", element_at($"attrs", "purchase_id"))
 .withColumnRenamed("value", "sessionId")
 .select($"sessionId", $"campaignId", $"channelId", $"purchaseId")
 .orderBy("sessionId")
 .show(false)

+--------------------------------------+----------+------------+------------------------------------+
|sessionId                             |campaignId|channelId   |purchaseId                          |
+--------------------------------------+----------+------------+------------------------------------+
|ba192cc2-f3e8-4871-9024-426da37bfafc-1|559       |Twitter Ads |9558867f-5ba9-4faf-ba02-4204f7c1bd87|
|f6e8252f-b5cc-48a4-b348-29d89ee4fa9e-1|478       |Twitter Ads |d4713d60-c8a7-4639-ab11-67b367a9c378|
|f6e8252f-b5cc-48a4-b348-29d89ee4fa9e-2|332       |Facebook Ads|e3e70682-c209-4cac-a29f-6fbed82c07cd|
|f6e8252f-b5cc-48a4-b348-29d89ee4fa9e-3|859       |VK Ads      |82e2e662-f728-44fa-8248-5e3a0a5d2f34|
+--------------------------------------+----------+------------+------------------------------------+

Просто? Вроде просто! Но несмотря на это, я все же сторонник подхода: “Если задача может быть решена стандартными средствами, то лучше решать это стандартными средствами”, поэтому не переусердствуйте)

Надеюсь, данный туториал будет кому-нибудь полезен :-)

А вы используете UDAF? Поделитесь примерами в комментариях.

Комментарии (3)


  1. sshikov
    01.11.2021 19:07

    На мой взгляд, тут можно было бы добавить упоминание об очень похожих операциях над RDD. treeReduce/treeAggregate по сути, очень похожи на UDAF, и основаны на тех же самых фундаментальных вещах — свертке (которая выполняется на партиции), и merge (слияние результатов на двух партициях в один). Но и так очень даже неплохо.


    1. Ninil Автор
      01.11.2021 20:59

      Спасибо за комментарий! Я хотел изначально, но потом отказался по двум причинам:

      • Объем для туториала и так получился не очень маленький

      • Все же сейчас RDD используется все реже и реже. В своей практике я, пожалую уже года 3-4 его не использовал.


      1. sshikov
        01.11.2021 21:56

        Да, я тоже стараюсь оставаться в рамках Dataset. Пожалуй, у нас в коде осталось одно место, где есть treeAggregate — это построение блум фильтра. Просто потому, что когда его писали, пример на RDD попался на глаза )