Это третья статья по теме реализации масштабируемой системы для выполнения задач распределенного машинного обучения на GPU с использованием Java, Kotlin, Spring и Spark. Список всех статей:

  1. Варианты использования Java ML библиотек совместно со Spring, Docker, Spark, Rapids, CUDA

  2. Масштабируемая Big Data система в Kubernetes с использованием Spark и Cassandra

  3. Использование Kotlin и WebFlux для выполнения задач ML в Apache Spark на GPU

О чем данная статья

В предыдущей статье для создания Spark Driver приложения использовался сервлетный стек Spring (Boot 2.7.11) и JDK 8.

На дворе вторая половина 2023 года, у многих в проде уже используется Boot 3+ (а то и 3.1+), совсем скоро должна выйти новая LTS версия Java, и, мягко говоря, Boot 2+ и JDK8 устарели. Использовались они намеренно, так как для задач тренировки моделей машинного обучения на GPU в среде Spark частью системы является ускоритель вычислений на GPU NVidia Rapids. Поддержка JDK 17 появилась только в релизе v23.06.0 от 27.06.23, с ее выходом появилась возможность перейти на актуальную LTS версию Java, а с ней - на Spring Boot 3+.

В данной статье описывается миграция с Boot 2 и JDK 8 До Boot 3 и JDK 17, со Spring Web на Spring WebFlux, в конце сравниваются Web и WebFlux версии по потреблению аппаратных ресурсов и скорости выполнения.

JDK8, Spring boot 2.7.11 → JDK17, Spring Boot 3.1.1

Для миграции достаточно поднять версии Rapids до 23.06.0, JDK до 17, Spring Boot до 3.1.1. Нюансов не так уж и много:

  1. Конфликт логеров Slf4j и Log4j при использовании Spark: из зависимости spring boot starter web исключаем spring boot starter logging:

pom.xml
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
    <version>${spring.boot.version}</version>
    <exclusions>
        <exclusion>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-logging</artifactId>
        </exclusion>
        <exclusion>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-tomcat</artifactId>
        </exclusion>
    </exclusions>
</dependency>

  1. Запускать Spark Driver на JDK 17 необходимо со следующими параметрами (приведено для Dockerfile):

Application Dockerfile
ENV JAVA_OPTS='--add-opens=java.base/java.lang=ALL-UNNAMED \
               --add-opens=java.base/java.lang.invoke=ALL-UNNAMED \
               --add-opens=java.base/java.lang.reflect=ALL-UNNAMED \
               --add-opens=java.base/java.io=ALL-UNNAMED \
               --add-opens=java.base/java.net=ALL-UNNAMED \
               --add-opens=java.base/java.nio=ALL-UNNAMED \
               --add-opens=java.base/java.util=ALL-UNNAMED \
               --add-opens=java.base/java.util.concurrent=ALL-UNNAMED \
               --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED \
               --add-opens=java.base/sun.nio.ch=ALL-UNNAMED \
               --add-opens=java.base/sun.nio.cs=ALL-UNNAMED \
               --add-opens=java.base/sun.security.action=ALL-UNNAMED \
               --add-opens=java.base/sun.util.calendar=ALL-UNNAMED \
               --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED'

  1. В связи с переходом на Hibernate 6, при использовании JSOB и BYTEA полей в сущностях Postgres придется немного отрефакторить Entity:

При этом, использовавшийся ранее CustomPostgresDialect оказывается не нужным и его можно удалить, заменив на org.hibernate.dialect.PostgreSQLDialect:

application.yml
spring:
  ...
  jpa:
    database-platform: com.mlwebservice.config.CustomPostgresDialect  # <== delete
    database-platform: org.hibernate.dialect.PostgreSQLDialect        # <== add

Ранее использовавшийся CustomPostgresDialect
package com.mlwebservice.config

import com.vladmihalcea.hibernate.type.array.IntArrayType
import com.vladmihalcea.hibernate.type.array.StringArrayType
import com.vladmihalcea.hibernate.type.json.JsonBinaryType
import com.vladmihalcea.hibernate.type.json.JsonNodeBinaryType
import com.vladmihalcea.hibernate.type.json.JsonNodeStringType
import com.vladmihalcea.hibernate.type.json.JsonStringType
import org.hibernate.dialect.PostgreSQL10Dialect
import java.sql.Types

class CustomPostgresDialect : PostgreSQL10Dialect() {
    init {
        registerHibernateType(Types.OTHER, StringArrayType::class.qualifiedName)
        registerHibernateType(Types.OTHER, IntArrayType::class.qualifiedName)
        registerHibernateType(Types.OTHER, JsonStringType::class.qualifiedName)
        registerHibernateType(Types.OTHER, JsonBinaryType::class.qualifiedName)
        registerHibernateType(Types.OTHER, JsonNodeBinaryType::class.qualifiedName)
        registerHibernateType(Types.OTHER, JsonNodeStringType::class.qualifiedName)
    }
}

Не считая докерфайлов и действий по добавлению новой версии jar'ника Rapids в директорию с jar-файлами для отправки в Spark executors и в образ executor’а, это все, что необходимо выполнить. Актуальную версию можно взять в соответствующей ветке репозитория.

На этом можно было бы и закончить, но любопытство ведь берет свое, и появился вопрос - а заработает ли на реактивном стеке, и будет ли эффект?

Сделаем ML реактивным: Spring Web → Spring WebFlux

Зависимости

Изменений при таком переходе изначально должно быть больше, но так же есть нюансы в виде управления зависимостями. Так, Netty, необходимый для Project Reactor (WebFlux) используется самим Spark и драйвером Cassandra, поэтому изначально конфликтовали. Решается путем задания трех зависимостей в самом начале списка зависимостей:

pom.xml: Зависимости Netty
<dependencies>
    <!-- Netty -->
    <dependency>
        <groupId>io.netty</groupId>
        <artifactId>netty-all</artifactId>
        <version>4.1.74.Final</version>
    </dependency>
    <dependency>
        <groupId>io.netty</groupId>
        <artifactId>netty-codec-http</artifactId>
        <version>4.1.74.Final</version>
    </dependency>
    <dependency>
        <groupId>io.netty</groupId>
        <artifactId>netty-resolver-dns</artifactId>
        <version>4.1.74.Final</version>
    </dependency>

    <!-- Spring -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-webflux</artifactId>
        <version>${spring.boot.version}</version>
        <exclusions>
            <exclusion>
                <artifactId>log4j-to-slf4j</artifactId>
                <groupId>org.apache.logging.log4j</groupId>
            </exclusion>
        </exclusions>
    </dependency>
    ...
</dependencies>

Spring Data тоже заменяется на реактивную версию:

pom.xml: R2DBC и Spring Data Cassandra Reactive
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-cassandra-reactive</artifactId>
    <version>${spring.boot.version}</version>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-r2dbc</artifactId>
    <version>${spring.boot.version}</version>
</dependency>
<dependency>
    <groupId>io.r2dbc</groupId>
    <artifactId>r2dbc-postgresql</artifactId>
    <version>0.8.13.RELEASE</version>
</dependency>

И добавляются несколько библиотек для работы Kotlin в среде WebFlux:

pom.xml: Kotlin dependencies
<dependency>
    <groupId>org.jetbrains.kotlin</groupId>
    <artifactId>kotlin-stdlib</artifactId>
    <version>${kotlin.version}</version>
</dependency>
<dependency>
    <groupId>org.jetbrains.kotlin</groupId>
    <artifactId>kotlin-reflect</artifactId>
    <version>${kotlin.version}</version>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>org.jetbrains.kotlinx</groupId>
    <artifactId>kotlinx-coroutines-reactor</artifactId>
    <version>1.7.2</version>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>io.projectreactor.kotlin</groupId>
    <artifactId>reactor-kotlin-extensions</artifactId>
    <version>1.2.2</version>
    <scope>runtime</scope>
</dependency>

Кстати, сам Kotlin тоже поднял с версии 1.8.21 до 1.9.0.

Для логирования HTTP запросов-ответов добавляем Zalando Logbook:

pom.xml: Zalando Logbook
<dependency>
    <groupId>org.zalando</groupId>
    <artifactId>logbook-spring-boot-autoconfigure</artifactId>
    <version>3.2.0</version>
</dependency>
<dependency>
    <groupId>org.zalando</groupId>
    <artifactId>logbook-netty</artifactId>
    <version>3.2.0</version>
</dependency>

pom.xml (полная версия для WebFlux)
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.mlwebservice</groupId>
    <artifactId>MLWebService</artifactId>
    <version>1.0.0-SNAPSHOT</version>

    <properties>
        <java.version>17</java.version>
        <spring.boot.version>3.1.1</spring.boot.version>
        <scala.version>2.12</scala.version>
        <spark.version>3.3.2</spark.version>
        <lombok.version>1.18.24</lombok.version>
        <org.mapstruct.version>1.4.2.Final</org.mapstruct.version>
        <kotlin.version>1.9.0</kotlin.version>
        <jackson.version>2.13.5</jackson.version>
    </properties>

    <distributionManagement>
        <repository>
            <id>XGBoost4J Snapshot Repo</id>
            <name>XGBoost4J Snapshot Repo</name>
            <url>https://s3-us-west-2.amazonaws.com/xgboost-maven-repo/snapshot/</url>
        </repository>
    </distributionManagement>

    <dependencies>
        <!-- Netty -->
        <dependency>
            <groupId>io.netty</groupId>
            <artifactId>netty-all</artifactId>
            <version>4.1.74.Final</version>
        </dependency>
        <dependency>
            <groupId>io.netty</groupId>
            <artifactId>netty-codec-http</artifactId>
            <version>4.1.74.Final</version>
        </dependency>
        <dependency>
            <groupId>io.netty</groupId>
            <artifactId>netty-resolver-dns</artifactId>
            <version>4.1.74.Final</version>
        </dependency>

        <!-- Spring -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-webflux</artifactId>
            <version>${spring.boot.version}</version>
            <exclusions>
                <exclusion>
                    <groupId>org.springframework.boot</groupId>
                    <artifactId>spring-boot-starter-logging</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-core</artifactId>
            <version>${jackson.version}</version>
        </dependency>
        <dependency>
            <groupId>com.fasterxml.jackson.module</groupId>
            <artifactId>jackson-module-kotlin</artifactId>
            <version>${jackson.version}</version>
        </dependency>
        <dependency>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-annotations</artifactId>
            <version>${jackson.version}</version>
        </dependency>
        <dependency>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-databind</artifactId>
            <version>${jackson.version}</version>
        </dependency>

        <!-- Spring Data -->
        <dependency>
            <groupId>org.springframework.data</groupId>
            <artifactId>spring-data-commons</artifactId>
            <version>${spring.boot.version}</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-cassandra-reactive</artifactId>
            <version>${spring.boot.version}</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-r2dbc</artifactId>
            <version>${spring.boot.version}</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-jpa</artifactId>
            <version>${spring.boot.version}</version>
        </dependency>
        <dependency>
            <groupId>org.postgresql</groupId>
            <artifactId>postgresql</artifactId>
            <scope>runtime</scope>
            <version>42.6.0</version>
        </dependency>
        <dependency>
            <groupId>io.r2dbc</groupId>
            <artifactId>r2dbc-postgresql</artifactId>
            <version>0.8.13.RELEASE</version>
        </dependency>
        <dependency>
            <groupId>com.vladmihalcea</groupId>
            <artifactId>hibernate-types-60</artifactId>
            <version>2.21.1</version>
        </dependency>

        <!-- Cassandra -->
        <dependency>
            <groupId>com.datastax.oss</groupId>
            <artifactId>java-driver-core</artifactId>
            <version>4.13.0</version>
        </dependency>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>2.12.15</version>
        </dependency>
        <dependency>
            <groupId>com.datastax.spark</groupId>
            <artifactId>spark-cassandra-connector_2.12</artifactId>
            <version>3.3.0</version>
        </dependency>

        <dependency>
            <groupId>com.typesafe</groupId>
            <artifactId>config</artifactId>
            <version>1.4.2</version>
        </dependency>

        <!-- Spark -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-hive_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.antlr</groupId>
            <artifactId>antlr4-runtime</artifactId>
            <version>4.8</version>
            <scope>runtime</scope>
        </dependency>

        <!-- GXBoost -->
        <dependency>
            <groupId>ml.dmlc</groupId>
            <artifactId>xgboost4j-spark-gpu_${scala.version}</artifactId>
            <version>1.7.5</version>
        </dependency>
        <dependency>
            <groupId>ml.dmlc</groupId>
            <artifactId>xgboost4j-gpu_${scala.version}</artifactId>
            <version>1.7.5</version>
        </dependency>

        <!-- Kubernetes -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-kubernetes_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>org.codehaus.janino</groupId>
            <artifactId>commons-compiler</artifactId>
            <version>3.0.16</version>
        </dependency>
        <dependency>
            <groupId>org.codehaus.janino</groupId>
            <artifactId>janino</artifactId>
            <version>3.0.16</version>
        </dependency>

        <!-- Rapids -->
        <dependency>
            <groupId>com.nvidia</groupId>
            <artifactId>rapids-4-spark_${scala.version}</artifactId>
            <version>23.06.0</version>
        </dependency>

        <!-- Lombok -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>${lombok.version}</version>
        </dependency>

        <!-- Logging -->
        <dependency>
            <groupId>org.zalando</groupId>
            <artifactId>logbook-spring-webflux</artifactId>
            <version>3.1.0</version>
        </dependency>
        <dependency>
            <groupId>org.zalando</groupId>
            <artifactId>logbook-spring-boot-autoconfigure</artifactId>
            <version>3.2.0</version>
        </dependency>
        <dependency>
            <groupId>org.zalando</groupId>
            <artifactId>logbook-netty</artifactId>
            <version>3.2.0</version>
        </dependency>

        <!-- Utils -->
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-lang3</artifactId>
            <version>3.12.0</version>
        </dependency>

        <!-- Kotlin -->
        <dependency>
            <groupId>org.jetbrains.kotlin</groupId>
            <artifactId>kotlin-stdlib</artifactId>
            <version>${kotlin.version}</version>
        </dependency>
        <dependency>
            <groupId>org.jetbrains.kotlin</groupId>
            <artifactId>kotlin-reflect</artifactId>
            <version>${kotlin.version}</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>org.jetbrains.kotlinx</groupId>
            <artifactId>kotlinx-coroutines-reactor</artifactId>
            <version>1.7.2</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>io.projectreactor.kotlin</groupId>
            <artifactId>reactor-kotlin-extensions</artifactId>
            <version>1.2.2</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>org.jetbrains.kotlinx.spark</groupId>
            <artifactId>kotlin-spark-api_3.3.1_${scala.version}</artifactId>
            <version>1.2.3</version>
        </dependency>
        <dependency>
            <groupId>org.jetbrains.kotlin</groupId>
            <artifactId>kotlin-test</artifactId>
            <version>${kotlin.version}</version>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <build>
        <finalName>service</finalName>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <version>3.0.6</version>
                <configuration>
                    <mainClass>com.mlwebservice.MLWebServiceApplication</mainClass>
                </configuration>
                <executions>
                    <execution>
                        <goals>
                            <goal>repackage</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>

            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.11.0</version>
                <executions>
                    <execution>
                        <id>compile</id>
                        <phase>compile</phase>
                        <goals>
                            <goal>compile</goal>
                        </goals>
                    </execution>
                    <execution>
                        <id>testCompile</id>
                        <phase>test-compile</phase>
                        <goals>
                            <goal>testCompile</goal>
                        </goals>
                    </execution>
                </executions>
                <configuration>
                    <source>${java.version}</source>
                    <target>${java.version}</target>
                    <annotationProcessorPaths>
                        <path>
                            <groupId>org.projectlombok</groupId>
                            <artifactId>lombok</artifactId>
                            <version>${lombok.version}</version>
                        </path>
                    </annotationProcessorPaths>
                </configuration>
            </plugin>

            <plugin>
                <groupId>org.jetbrains.kotlin</groupId>
                <artifactId>kotlin-maven-plugin</artifactId>
                <version>${kotlin.version}</version>
                <executions>
                    <execution>
                        <id>compile</id>
                        <phase>process-sources</phase>
                        <goals>
                            <goal>compile</goal>
                        </goals>
                        <configuration>
                            <jvmTarget>${java.version}</jvmTarget>
                            <sourceDirs>
                                <source>src/main/java</source>
                                <source>src/main/kotlin</source>
                                <source>target/generated-sources/annotations</source>
                            </sourceDirs>
                        </configuration>
                    </execution>
                    <execution>
                        <id>test-compile</id>
                        <phase>test-compile</phase>
                        <goals>
                            <goal>test-compile</goal>
                        </goals>
                        <configuration>
                            <jvmTarget>${java.version}</jvmTarget>
                            <sourceDirs>
                                <source>src/main/java</source>
                                <source>src/main/kotlin</source>
                                <source>target/generated-sources/annotations</source>
                            </sourceDirs>
                        </configuration>
                    </execution>
                </executions>
                <configuration>
                    <jvmTarget>${java.version}</jvmTarget>
                    <sourceDirs>
                        <source>src/main/java</source>
                        <source>src/main/kotlin</source>
                        <source>target/generated-sources/annotations</source>
                    </sourceDirs>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>

Main класс

Модифицируем Main класс приложения, необходимо добавить аннотации @EnableWebFlux и @EnableR2dbcRepositories, указать тип приложения REACTIVE

Main class
package com.mlwebservice;

import org.springframework.boot.WebApplicationType;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration;
import org.springframework.boot.autoconfigure.gson.GsonAutoConfiguration;
import org.springframework.boot.builder.SpringApplicationBuilder;
import org.springframework.data.r2dbc.repository.config.EnableR2dbcRepositories;
import org.springframework.web.reactive.config.EnableWebFlux;

import java.net.InetAddress;
import java.net.UnknownHostException;

@EnableWebFlux
@EnableR2dbcRepositories
@SpringBootApplication(exclude = {
        GsonAutoConfiguration.class,
        CassandraAutoConfiguration.class
})
public class MLWebServiceApplication {
    public static void main(String[] args) {
        new SpringApplicationBuilder(MLWebServiceApplication.class)
                .web(WebApplicationType.REACTIVE)
                .run(args);
        );
    }
}

Spring Data → R2DBC

Так как в сущности БД используется JSONB поле (с его отображением в приложении в виде JsonNode), необходима конфигурация R2DBC с кастомными конвертерами:

Jsonb converters
package com.mlwebservice.config

import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.ObjectMapper
import io.r2dbc.postgresql.codec.Json
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.core.convert.converter.Converter
import org.springframework.data.convert.ReadingConverter
import org.springframework.data.convert.WritingConverter
import org.springframework.data.r2dbc.convert.R2dbcCustomConversions
import org.springframework.data.r2dbc.dialect.PostgresDialect

@Configuration
open class R2dbcConfiguration(private val objectMapper: ObjectMapper) {

    @Bean
    open fun customConversions() : R2dbcCustomConversions {
        val converters = listOf<Converter<*, *>>(
            JsonNodeWritingConverter(objectMapper),
            JsonNodeReadingConverter(objectMapper)
        )
        return R2dbcCustomConversions.of(PostgresDialect.INSTANCE, converters);
    }
}

@WritingConverter
class JsonNodeWritingConverter(private val objectMapper: ObjectMapper) : Converter<JsonNode, Json> {
    override fun convert(source: JsonNode): Json {
        return Json.of(objectMapper.writeValueAsString(source));
    }
}

@ReadingConverter
class JsonNodeReadingConverter(private val objectMapper: ObjectMapper) : Converter<Json, JsonNode> {
    override fun convert(source: Json): JsonNode? {
        return objectMapper.readTree(source.asString());
    }
}

Далее следует удалить из упомянутой выше сущности ModelEntity лишние аннотации, в итоге должно получиться:

ModelEntity
package com.mlwebservice.persist.entity

import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.node.ObjectNode
import org.springframework.data.annotation.CreatedDate
import org.springframework.data.annotation.Id
import org.springframework.data.annotation.LastModifiedDate
import org.springframework.data.relational.core.mapping.Column
import org.springframework.data.relational.core.mapping.Table
import java.time.LocalDateTime
import java.util.*

@Table(name = "models", schema = "instrument_data")
data class ModelEntity constructor(
    @Id
    val id: Long? = null,
  
    @Column("model")
    val model: ByteArray,

    @Column("created_at")
    val createdAt: LocalDateTime,

    @Column("last_trained_at")
    val lastTrainedAt: LocalDateTime,

    @Column("task_id")
    val taskId: UUID,

    @Column("parameters")
    val parameters: JsonNode
)
// конструкторы и прочее необходимое

Сам репозиторий сущностей теперь наследуется от R2dbcRepository:

@Repository
interface ModelRepository : R2dbcRepository<ModelEntity, Long>

Методы сохранения и загрузки модели трансформируются для работы в WebFlux:

методы для работы с моделями данных

метод загрузки модели из БД

internal inline fun <reified T> loadModel(modelId: Long): T {
    val optional = modelRepository.findById(modelId)

    val entity = optional.get()
    val modelByteArray = entity.model

    val byteArrayInputStream = ByteArrayInputStream(modelByteArray)
    val modelObject = ObjectInputStream(byteArrayInputStream).use { it.readObject() }

    if (modelObject is T) {
        return modelObject
    } else {
        throw ServiceException.withMessage("Model id $modelId has incorrect format")
    }
}

модифицируется до:

internal inline fun <reified T> loadModel(modelId: Long): Mono<T> =
        modelRepository.findById(modelId)
            .map { modelEntity: ModelEntity ->
                ByteArrayInputStream(modelEntity.model)
            }
            .publishOn(Schedulers.boundedElastic())
            .map { byteArrayInputStream: ByteArrayInputStream ->
                ObjectInputStream(byteArrayInputStream).use { it.readObject() }
            }
            .flatMap { modelObject ->
                if (modelObject is T) {
                    Mono.just(modelObject)
                } else {
                    Mono.error(ServiceException.withMessage("Model id $modelId has incorrect format"))
                }
            }

а метод сохранения:

fun saveModel(
    model : PredictionModel<Vector, XGBoostRegressionModel>,
    taskId : UUID,
    modelParameters : AnalyticsRequest.ModelParameters
) {
    val byteArrayOutputStream = ByteArrayOutputStream()
    ObjectOutputStream(byteArrayOutputStream).use { it.writeObject(model) }
    val modelByteArray: ByteArray = byteArrayOutputStream.toByteArray()
    val jsonParams : JsonNode = objectMapper.convertValue(modelParameters, JsonNode::class.java)

    val entity = ModelEntity(modelByteArray, taskId, jsonParams)
    modelRepository.save(entity)
    log.info("Model for task id {} saved. Parameters map: {}, jsonNode: {}",
        taskId, modelParameters, jsonParams)
}

модифицируется до:

fun saveModel(
        model: PredictionModel<Vector, XGBoostRegressionModel>,
        taskId: UUID,
        modelParameters: AnalyticsRequest.ModelParameters
    ): Mono<Void> =
        Mono.fromCallable {
            val jsonParams: JsonNode = objectMapper.convertValue(modelParameters, JsonNode::class.java)

            val byteArrayOutputStream = ByteArrayOutputStream()
            ObjectOutputStream(byteArrayOutputStream).use { objectOutputStream ->
                objectOutputStream.writeObject(model)
            }
            val modelByteArray: ByteArray = byteArrayOutputStream.toByteArray()

            ModelEntity(modelByteArray, taskId, jsonParams)
        }
            .subscribeOn(Schedulers.boundedElastic())
            .flatMap { entity ->
                modelRepository.save(entity)
                    .doOnSuccess {
                        log.info(
                            "Model for task id {} saved. Parameters map: {}, jsonNode: {}",
                            taskId, modelParameters, entity.parameters.toString()
                        )
                    }
                    .then()
            }

Cassandra

Репозитории Cassandra строились на основе взаимодействия со спарковой сессией. Переработать методы довольно просто. Так, метод получения датасета в базовом абстрактном репозитории:

cassandraDataset web
fun cassandraDataset(keyspace: String, table: String): Dataset<Row> {
    val cassandraDataset: Dataset<Row> = sparkSession.read()
        .format("org.apache.spark.sql.cassandra")
        .option("keyspace", keyspace)
        .option("table", table)
        .load()

    cassandraDataset.createOrReplaceTempView(table)
    return cassandraDataset
}

модифицируется до:

cassandraDataset webflux
fun cassandraDataset(keyspace: String, table: String): Mono<Dataset<Row>> =
    Mono.fromCallable {
        val cassandraDataset: Dataset<Row> = sparkSession.read()
            .format("org.apache.spark.sql.cassandra")
            .option("keyspace", keyspace)
            .option("table", table)
            .load()

        cassandraDataset.createOrReplaceTempView(table)
        cassandraDataset
    }

метод сохранения датасета:

saveDataSet web
open fun saveDataSet(dataset: Dataset<Row>) {
    dataset.write()
        .format("org.apache.spark.sql.cassandra")
        .mode("append")
        .option("confirm.truncate", "false")
        .option("keyspace", keyspace)
        .option("table", table)
        .save();
}

модифицируется до:

saveDataSet webflux
open fun saveDataSet(dataset: Dataset<Row>): Mono<Void> =
    Mono.fromRunnable {
        dataset.write()
            .format("org.apache.spark.sql.cassandra")
            .mode("append")
            .option("confirm.truncate", "false")
            .option("keyspace", keyspace)
            .option("table", table)
            .save()
    }

метод получения базового датасета с определенными оффсетами:

getBaseDataSet web
fun getBaseDataSet(
    ticker: String,
    taskNumber : UUID,
    dateStart : LocalDate,
    dateEnd : LocalDate,
    currentOffset : Int,
    batchSize : Int
): Dataset<Row> {
    val filteredDataset = cassandraDataset(table)
        .filter(
            functions.col("ticker").equalTo(ticker)
                .and(functions.col("task_number").equalTo(taskNumber.toString()))
                .and(functions.col("datetime").between(dateStart, dateEnd))
        )

    val offsetDataset = filteredDataset.withColumn(
        "row_number",
        functions.row_number().over(orderBy("datetime"))
    )

    return offsetDataset
        .filter(functions.col("row_number")
            .between(currentOffset + 1, currentOffset + batchSize))
        .drop("row_number")
}

модифицируется до:

getBaseDataSet webflux
fun getBaseDataSet(
    ticker: String,
    taskNumber: UUID,
    dateStart: LocalDate,
    dateEnd: LocalDate,
    currentOffset: Int,
    batchSize: Int
): Mono<Dataset<Row>> =
    cassandraDataset(table)
        .map { dataset ->
            dataset
                .filter(
                    functions.col("ticker").equalTo(ticker)
                        .and(functions.col("task_number").equalTo(taskNumber.toString()))
                        .and(functions.col("datetime").between(dateStart, dateEnd))
                ).withColumn(
                    "row_number",
                    functions.row_number().over(orderBy("datetime"))
                )
                .filter(
                    functions.col("row_number")
                        .between(currentOffset + 1, currentOffset + batchSize)
                )
                .drop("row_number")
        }

Остальные репозитории конкретных таблиц переписываются по такому же принципу.

В сервисе работы с данными следует упомянуть метод объединения датасетов (теперь же репозитории возвращают реактивные Mono<Dataset<Row>>):

getMainDataset web
fun getMainDataset(
    ticker : String,
    taskNumber : UUID,
    dateStart : LocalDate,
    dateEnd : LocalDate
) : Dataset<Row> {
    val timeSeries = timeSeriesRepository.getDataset(ticker, taskNumber, dateStart, dateEnd).`as`("ts")
    val emaDataSet = emaRepository.getEmaDataSet(ticker, dateStart, dateEnd).`as`("ema")
    val stochasticDataset = stochasticRepository.getStochasticDataSet(ticker, dateStart, dateEnd).`as`("stoch")
    val bBandsDataset = bBandIndicatorRepository.getBBandsDataSet(ticker, dateStart, dateEnd).`as`("bb")
    val macdDataset = macdRepository.getMacdDataSet(ticker, dateStart, dateEnd).`as`("macd")
    val rsiDataset = rsiRepository.getRsiDataSet(ticker, dateStart, dateEnd).`as`("rsi")
    val smaDataset = smaRepository.getSmaDataSet(ticker, dateStart, dateEnd).`as`("sma")
    val willrDataset = willrRepository.getWillrDataSet(ticker, dateStart, dateEnd).`as`("willr")

    return combineDatasets(
        timeSeries, emaDataSet, stochasticDataset, bBandsDataset, macdDataset, rsiDataset, smaDataset, willrDataset
    )
}

модифицируется до:

getMainDataset webflux
fun getMainDataset(
    ticker: String,
    taskNumber: UUID,
    dateStart: LocalDate,
    dateEnd: LocalDate
): Mono<Dataset<Row>> {
    val timeSeriesMono = timeSeriesRepository.getDataset(ticker, taskNumber, dateStart, dateEnd)
        .map { dataset -> dataset.alias("ts") }
    val emaDataSetMono = emaRepository.getEmaDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("ema") }
    val stochasticDatasetMono = stochasticRepository.getStochasticDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("stoch") }
    val bBandsDatasetMono = bBandIndicatorRepository.getBBandsDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("bb") }
    val macdDatasetMono = macdRepository.getMacdDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("macd") }
    val rsiDatasetMono = rsiRepository.getRsiDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("rsi") }
    val smaDatasetMono = smaRepository.getSmaDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("sma") }
    val willrDatasetMono = willrRepository.getWillrDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("willr") }

    return Mono.zip(
        timeSeriesMono, emaDataSetMono, stochasticDatasetMono, bBandsDatasetMono,
        macdDatasetMono, rsiDatasetMono, smaDatasetMono, willrDatasetMono
    ).map { tuple ->
        combineDatasets(tuple.t1, tuple.t2, tuple.t3, tuple.t4, tuple.t5, tuple.t6, tuple.t7, tuple.t8)
    }
}

здесь получаются 8 датасетов в Mono-обертках, обертки объединяются в один Mono посредством .zip() и передаются на исполнение в метод комбинации датасетов, который не менялся.

Сервис StockAnalyticsService

Метод выполнения предикта с помощью сохраненной модели:

predictWithExistingModel web
fun predictWithExistingModel(
    ticker : String,
    taskNumber : UUID,
    dateStart : LocalDate,
    dateEnd : LocalDate,
    modelId : Long
): StockPredictDto {
    val model: PredictionModel<Vector, XGBoostRegressionModel> = modelService.loadModel(modelId)
    val data = dataReaderService.getMainDataset(ticker, taskNumber, dateStart, dateEnd)

    var predictions = model.transform(data)
    predictions = predictions.select("dateTime", "prediction")
    return StockPredictDto.fromDataset(predictions)
}

модифицируется до:

predictWithExistingModel webflux
fun predictWithExistingModel(
    ticker: String,
    taskNumber: UUID,
    dateStart: LocalDate,
    dateEnd: LocalDate,
    modelId: Long
): Mono<StockPredictDto> =
    modelService.loadModel<PredictionModel<Vector, XGBoostRegressionModel>>(modelId)
        .flatMap { model ->
            dataReaderService.getMainDataset(ticker, taskNumber, dateStart, dateEnd)
                .map { data ->
                    val predictions = model.transform(data)
                        .select("dateTime", "prediction")
                    StockPredictDto.fromDataset(predictions)
                }
        }

Метод обучения модели:

trainModel web
fun trainModel(
    ticker : String,
    taskNumber : UUID,
    dateStart : LocalDate,
    dateEnd : LocalDate,
    evalPivotPoint : Long,
    offset : Long,
    modelParameters : AnalyticsRequest.ModelParameters
) : ModelTrainResultResponse {
    val pivot = dateEnd.minusDays(evalPivotPoint)

    val tdf = dataReaderService.getDatasetWithLabel(ticker, taskNumber, dateStart, pivot, offset)
    val edf = dataReaderService.getDatasetWithLabel(ticker, taskNumber, pivot, dateEnd, offset)
        .selectExpr(*allColumns)

    val modelParams = createModelParams(modelParameters)
    val regressor = xgBoostRegressor(modelParams)

    val model: PredictionModel<Vector, XGBoostRegressionModel> = regressor.fit(tdf)
    val predictions = model.transform(edf)

    combinedDataRepository.saveData(tdf.selectExpr(*allColumns).unionAll(edf), ticker, taskNumber)
    modelService.saveModel(model, taskNumber, modelParameters)

    val result = predictions.withColumn("error", col("prediction").minus(col(labelName)))
    return ModelTrainResultResponse(ModelTrainResult.listFromDataset(result.selectExpr(*resultExp)))
}

модифицируется до:

trainModel webflux
fun trainModel(
        ticker: String,
        taskNumber: UUID,
        dateStart: LocalDate,
        dateEnd: LocalDate,
        evalPivotPoint: Long,
        offset: Long,
        modelParameters: AnalyticsRequest.ModelParameters
    ): Mono<ModelTrainResultResponse> =
        Mono.just(dateEnd.minusDays(evalPivotPoint))
            .flatMap { pivot: LocalDate ->
                dataReaderService.getDatasetWithLabel(ticker, taskNumber, dateStart, pivot, offset)
                    .zipWith(dataReaderService.getDatasetWithLabel(ticker, taskNumber, pivot, dateEnd, offset))
            }
            .flatMap { tuple: Tuple2<Dataset<Row>, Dataset<Row>> ->
                val tdf = tuple.t1
                val edf = tuple.t2

                val modelParams = createModelParams(modelParameters)
                val regressor = xgBoostRegressor(modelParams)

                Mono.fromCallable { regressor.fit(tdf) }
                    .flatMap { model: XGBoostRegressionModel ->
                        val predictions = model.transform(edf)

                        val saveDataMono = combinedDataRepository.saveData(
                            tdf.selectExpr(*allColumns).unionAll(edf),
                            ticker,
                            taskNumber
                        )

                        modelService.saveModel(model, taskNumber, modelParameters)
                            .then(saveDataMono)
                            .thenReturn(predictions)
                    }
            }
            .map { predictions: Dataset<Row> ->
                val result = predictions.withColumn("error", col("prediction").minus(col(labelName)))
                ModelTrainResultResponse(ModelTrainResult.listFromDataset(result.selectExpr(*resultExp)))
            }

здесь tdf и edf обернуты в Mono, поэтому объединяются в кортеж из двух элементов Mono<Tuple2>, далее в оборачиваем в Callable функцию regressor.fit(tdf), которая будет выполнена асинхронно и вернет результат в виде model: XGBoostRegressionModel. В функции flatMap она используется с эвалюирующим датасетом для получения предиктов, затем сохраняется в БД с помощью описанного выше метода saveModel. Остальная логика очевидна.

Наибольшую сложность вызывает метод инкрементального обучения (да, на данной модели инкремент не работает и требуется замена XGBoost на другую модель, но цель была трансформировать логику под реактивную среду и получить работающий пример, который далее можно использовать для инкрементального обучения модели).

Исходный метод:

incrementTrainModel web
fun incrementTrainModel(
    ticker : String,
    taskNumber : UUID,
    dateStart : LocalDate,
    dateEnd : LocalDate,
    evalPivotPoint : Long,
    offset : Long,
    batchSize : Int,
    modelParameters : AnalyticsRequest.ModelParameters
) : ModelTrainResultResponse {
    val pivot = dateEnd.minusDays(evalPivotPoint)
    var currentBatchOffset = 0
    var i = 0

    val modelParams = createModelParams(modelParameters)
    val regressor = xgBoostRegressor(modelParams)

    var model: PredictionModel<Vector, XGBoostRegressionModel>? = null
    var predictions: Dataset<Row>? = null

    var tdf: Dataset<Row>?
    do {
        log.info("Iteration {}: currentOffset {}", i, currentBatchOffset)
        tdf = dataReaderService.getDatasetWithLabel(
            ticker, taskNumber, dateStart, pivot, offset, currentBatchOffset, batchSize
        )
        if (tdf.isEmpty) break

        model = regressor.fit(tdf)
        combinedDataRepository.saveData(tdf.selectExpr(*allColumns), ticker, taskNumber)

        currentBatchOffset += batchSize
        i++
    } while (tdf?.isEmpty == false)

    val edf = dataReaderService.getDatasetWithLabel(
        ticker, taskNumber, pivot, dateEnd, offset, 0, 100).selectExpr(*allColumns)
    if (model != null) {
        predictions = model.transform(edf)
    }
    combinedDataRepository.saveData(edf.selectExpr(*allColumns), ticker, taskNumber)
    modelService.saveModel(model!!, taskNumber, modelParameters)

    val result = predictions!!.withColumn("error", col("prediction").minus(col(labelName)))
    return ModelTrainResultResponse(ModelTrainResult.listFromDataset(result.selectExpr(*resultExp)))
}

модифицируется до:

incrementTrainModel webflux
fun incrementTrainModel(
        ticker: String,
        taskNumber: UUID,
        dateStart: LocalDate,
        dateEnd: LocalDate,
        evalPivotPoint: Long,
        offset: Long,
        batchSize: Int,
        modelParameters: AnalyticsRequest.ModelParameters
    ): Mono<ModelTrainResultResponse> {
        val pivot = dateEnd.minusDays(evalPivotPoint)
        var currentBatchOffset = 0
        var i = 0

        val modelParams = createModelParams(modelParameters)
        val regressor = xgBoostRegressor(modelParams)

        var model: PredictionModel<Vector, XGBoostRegressionModel>? = null
        var tdf: Dataset<Row>? = null

        return Mono.defer {
            dataReaderService.getDatasetWithLabel(
                ticker, taskNumber, dateStart, pivot, offset, currentBatchOffset, batchSize
            )
        }
            .map { dataset ->
                tdf = dataset
                log.info("Iteration {}: currentOffset {}", i, currentBatchOffset)
                if (tdf?.isEmpty == true) {
                    log.warn(
                        "tdf is empty, no more data for learning, Iteration {}: currentOffset {}",
                        i, currentBatchOffset
                    )
                    Mono.empty()
                } else {
                    model = regressor.fit(tdf)
                    log.info("model trained, Iteration {}: currentOffset {}", i, currentBatchOffset)
                    currentBatchOffset += batchSize
                    i++
                    combinedDataRepository.saveData(tdf!!.selectExpr(*allColumns), ticker, taskNumber)
                        .thenReturn(currentBatchOffset + batchSize)
                }
            }
            .repeat { tdf?.isEmpty == false }
            .then(dataReaderService
                .getDatasetWithLabel(ticker, taskNumber, pivot, dateEnd, offset, 0, 100)
                .flatMap { edf ->
                    log.info("Got edf")
                    combinedDataRepository.saveData(edf.selectExpr(*allColumns), ticker, taskNumber)
                        .then(modelService.saveModel(model!!, taskNumber, modelParameters))
                        .thenReturn(model!!.transform(edf))
                }.map { predictions ->
                    log.info("Predictions stage")
                    val result = predictions?.withColumn(
                        "error", col("prediction")
                            .minus(col(labelName))
                    )
                    ModelTrainResultResponse(ModelTrainResult.listFromDataset(result!!.selectExpr(*resultExp)))
                })
            .doOnError { exception ->
                log.error("Error while increment learning; taskNumber = {}", taskNumber, exception)
                ModelTrainResultResponse()
            }
    }

В отличие от Java, лямбда-выражения в Kotlin не требуют от переменных, чтобы они были effectively final, поэтому переменные currentBatchOffset, i, model и tdf могут изменяться в ходе выполнения основного стрима.

Здесь функция получения датасета обертывается в Mono.defer(). Особенность данного подхода в том, что выполнение функции откладывается до момента подписки на данный Mono. А подписка будет повторяться методом .repeat() до тех пор, пока не выполнится условие tdf?.isEmpty == false.

Когда очередной tdf будет пустым, выполнится логика в then: из кассандры будет получен датасет edf, который сохранится в таблице скомбинированных данных, так же будут получены предикты модели и сохранена сама модель. Затем из предиктов подготовится результат метода. В случае ошибки вернется пустой результат метода.

Не сказать, что это идеальное исполнение метода, но как пример сойдет.

Подробно реализацию можно посмотреть в отдельной ветке репозитория.

Сравнение двух реализаций

Как известно, реактивный стек отличается от сервлетного тем, что для выполнения одной и той же логики зачастую требуется меньше ресурсов. В некоторых случаях может возрасти скорость выполнения алгоритма.

Тестирование происходило по следующей методике:

  1. Сервис поднимается в Docker-контейнере с 4 CPU и 4 Gb памяти, использует Spark Executor (v. 3.3.2, JDK 17), так же в Docker контейнере, который подключается к Standalone-мастеру Spark в виртуальной машине. Все работает на одной машине под управлением Windows 10 Pro, задачи тренировки моделей выполняются на GPU NVidia 4090.

  2. В течении 10 минут производятся запросы методов: обучения новой модели (POST /analytics - для сокращения “1 запрос”), получения предиктов с помощью сохраненной модели (GET /analytics - для сокращения “2 запрос”) и инкрементального обучения (POST /analytics/increment - для сокращения “3 запрос”) с batch_size = 50 записей, во время которого делается 12 итераций над 6 сотнями записей в таблицах Cassandra. Первый цикл на “не прогретом” драйвере (первые запросы всегда выполняются дольше), далее два одинаковых цикла по одному запросу каждого метода на “прогретом драйвере” и в четвертом цикле запускаются одновременно 1, 2, 3 методы.

  3. Driver работает в режиме Spark Cluster, используется одна Spark Session на все время работы приложения;

  4. Изначальные параметры запуска JVM одинаковые: первоначальный размер кучи 512 Мб, максимальный размер не указан, GC по умолчанию (G1).

Результаты потребления ресурсов:

Максимальное потребление CPU

Среднее потребление CPU

Максимальное потребление памяти, Gb

Среднее потребление памяти, Gb

Количество Stop the world за 10 минут

Spring Web

3,4

1,5

4

2

4

Spring Webflux

3,4

1,1

1

0,5

0

С указанными выше параметрами для сервлетного стека наблюдалось 4 stop the world от G1 GC, при этом один раз результатом выполнения предиктов из сохраненной модели стала ошибка сервера.

На графике видно, что потребление памяти растет линейно до момента, когда свободного места для кучи уже нет и ее необходимо чистить.

График потребления ресурсов Web приложения с параметрами JVM -Xms512m
График потребления ресурсов Web приложения с параметрами JVM -Xms512m

У реактивного стека другая картина: после первых запросов стабильные ~0.5 Гб памяти. По потреблению CPU разница не настолько большая.

График потребления ресурсов WebFlux приложения с параметрами JVM -Xms512m
График потребления ресурсов WebFlux приложения с параметрами JVM -Xms512m

Скорость выполнения запросов:

Сравнительная таблица Web и WebFlux версия приложения с параметрами JVM -Xms512m
Сравнительная таблица Web и WebFlux версия приложения с параметрами JVM -Xms512m

Топ-5 классов по потреблению памяти:

Учитывая, что весь результирующий датасет занимает около 55 Мб, такой объем аллоцированной памяти вызывает вопросы. Анализ стектрейсов показал, что в большинстве случаев источником и причиной является Spark и Rapids, которые строят план запросов, обмениваются данными между БД, экзекуторами и драйвером, подготавливают массивы данных для загрузки в GPU и вычитывают результат из него. Потратив некоторое время на изучение вопроса оптимизации использования памяти, могу сделать вывод, что это штатное поведение системы в такой конфигурации, и надо научиться с этим жить при использовании сервлетного стека.

Первые попытки жить с этим в привели к изменению параметров запуска JVM для сервлетного стека на следующие: -Xms512m -Xmx3g -XX:GCTimeRatio=19 (жесткое указание того, что система может потратить до 5% времени на сборку мусора - (1 / (1+19))) -XX:+UseZGC. Учитывая, что реактивному стеку достаточно в среднем 512 Мб памяти, и что Z GC потребляет несколько больше памяти, чем G1 GC, планка максимального размера кучи снизилась до 3 Гб.

График потребления ресурсов Web приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC
График потребления ресурсов Web приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC

Потребление CPU незначительно снизилось, наблюдается схожее потребление памяти, но stop the world уже не зафиксировано. Судя по графику, после завершения работы методов POST /analytics и GET /analytics куча очищается, но при работе POST /analytics/increment куча очищается только к моменту приближения к своему максимальному размеру. Логика, которая могла бы вести к утечке памяти, отсутствует, причина такого высокого потребления памяти остается не выясненной.

Результаты переключения GC в таблице потребления ресурсов:


Максимальное потребление CPU

Среднее потребление CPU

Максимальное потребление памяти, Gb

Среднее потребление памяти, Gb

Количество Stop the world

Spring Web

3,4

1,5

3

1,5

0

Spring Webflux

3,4

1,1

1

0,5

0

и скорости выполнения запросов:

Сравнительная таблица скорости выполнения запросов Web приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC относительно WebFlux приложения с G1 GC и параметрами JVM -Xms512m
Сравнительная таблица скорости выполнения запросов Web приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC относительно WebFlux приложения с G1 GC и параметрами JVM -Xms512m

Стало интересно, что будет, если для G1 GC установить максимальный размер кучи как для Z GC и установить жесткий предел времени выполнения на сборку мусора. В этом случае оказалось, что память заполняется как и раньше, но stop the world стало больше, так как доступной памяти меньше, и, соответственно, заполняется она быстрее. Потребление ресурсов осталось примерно на том же уровне:

График потребления ресурсов Web приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseG1GC
График потребления ресурсов Web приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseG1GC


Максимальное потребление CPU

Среднее потребление CPU

Максимальное потребление памяти, Gb

Среднее потребление памяти, Gb

Количество Stop the world

Spring Web

3,4

1,5

3

1,5

6

Spring Webflux

3,4

1,1

1

0,5

0

Скорость выполнения запросов возросла, но не существенно.

Сравнительная таблица скорости выполнения запросов Web приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseG1GC относительно WebFlux приложения с G1 GC и параметрами JVM -Xms512m
Сравнительная таблица скорости выполнения запросов Web приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseG1GC относительно WebFlux приложения с G1 GC и параметрами JVM -Xms512m

Так же попробовал в сервлетном стеке использовать Parallel GC с параметрами -Xms512m -Xmx4g -XX:GCTimeRatio=19 -XX:+UseParallelGC. Результаты самые худшие, за 10 минут удалось прогнать только 2 цикла. Если первые два метода выполнялись примерно за то же время без отклонений, то метод инкрементального обучения выполнялся в первый раз 3мин 32с, что хуже примерно на 1,5 минуты среднего результата сервлетного стека, а второй запрос подвис и выполнялся 8мин 10с. Результаты в таблицах не фиксировал.

График потребления ресурсов Web версии приложения с Pasrallel GC
График потребления ресурсов Web версии приложения с Pasrallel GC

Напоследок применил настройки JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC к версии WebFlux, которая оказалась на данный момент самой оптимальной по потреблению ресурсов и скорости обработки запросов. Сравнительные таблицы версии с дефолтными параметрами и G1 GC и версии с кастомными параметрами JVM с Z GC ниже.

График потребления ресурсов:

График потребления ресурсов WebFlux приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC
График потребления ресурсов WebFlux приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC

Таблица потребления ресурсов:


Максимальное потребление CPU

Среднее потребление CPU

Максимальное потребление памяти, Gb

Среднее потребление памяти, Gb

Количество Stop the world

Spring Webflux G1 GC

3,4

1,1

1

0,5

0

Spring Webflux Z GC

3,4

1,4

2,84

1,5

0

Таблица скорости выполнения запросов:

Сравнительная таблица скорости выполнения запросов WebFlux приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC относительно WebFlux приложения с G1 GC и параметрами JVM -Xms512m
Сравнительная таблица скорости выполнения запросов WebFlux приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC относительно WebFlux приложения с G1 GC и параметрами JVM -Xms512m

Итоговые таблицы со всеми версиями и различными значениями параметров конфигурации JVM представлены ниже. За основу взяты результаты для WebFlux на G1 GC с одним параметром JVM минимального размера хипа 512m.

Сводные таблицы по потреблению ресурсов
Сводные таблицы по потреблению ресурсов
Сводные таблицы времени выполнения запросов
Сводные таблицы времени выполнения запросов

Вывод

Подводя черту после написания третьей статьи на тему построения системы распределенного машинного обучения на Java и Kotlin, самый большой вывод, который напрашивается - построить подобную систему сложно, много неизвестных, необходимо выполнить много исследований, но добиться работающего решения вполне реально, было бы желание.

Если так случилось, что нужно выполнять задачи ML на JVM стеке технологий, учите Python и не занимайтесь фигней, а руководству продайте альтернативную систему отличным выбором в качестве основы будет Kotlin и Spring Webflux (как альтернатива - Web с Z GC), и, естественно, Apache Spark. По окончанию работ над любым приложением стоит провести проверку профилировщиком, так как с очень высокой вероятностью при дефолтных параметрах JVM работа приложения не будет оптимальной.

Другой вопрос - является ли данная система эффективной с точки зрения производительности и потребления ресурсов? Без тестов на альтернативной системе (например, Python + Dask) объективно ответить на данный вопрос я затрудняюсь. Возможно, в будущем попробую поднять такую систему и написать альтернативную логику на питоне, тогда будет с чем сравнить и о чем написать очередную статью.

Комментарии (4)


  1. sherbinko
    06.08.2023 09:14

    Как известно, реактивный стек отличается от сервлетного тем, что для выполнения одной и той же логики зачастую требуется меньше ресурсов.

    Это где вы такое взяли? Обычно подразумевается что код в потоках быстрее с точки зрения throughput, но зато реактивный может быть выгоднее с точки зрения latency.
    Кроме того, непонятно какое вообще влияние может оказывать веб-API для обучения на основе GPU ??? Зачем там вообще нужно веб-API?
    Вам надо было зафиксировать какой-нибудь один веб-API и под капотом уже тестировать либо реактивный API, либо "обычный".
    Создаётся впечатление, что пытаетесь протестировать всё сразу, хаотично меняя какие-то параметры без нормального профайлинга и изоляции модулей.
    Хотелось бы видеть, скажем, "реактивный API кассандры быстрее нереактивного на таких-то запросах на 15% потому что это, это и это". Ваше же измерение абстрактного Machine Learning-а в вакууме вообще ни о чём не говорит


    1. Dartya Автор
      06.08.2023 09:14

      Это где вы такое взяли? Обычно подразумевается что код в потоках быстрее с точки зрения throughput, но зато реактивный может быть выгоднее с точки зрения latency.

      Вы правы, говоря о пропускной способности и задержке - действительно, в данном предложении я имел в виду, что зачастую (но не всегда - на всякий случай стоит сделать уточнение) при большом числе параллельных запросов WebFlux может потреблять меньше ресурсов, чем аналогичное приложение на web-стеке, как минимум за счет асинхронности и порождения мЕньшего числа потоков, не говоря уже об уменьшении процессорного времени на переключение и обработку данных потоков. Опять же, многое зависит от настроек JVM, Netty, Undertow и прочих. Наверное, если возник вопрос, стоило это уточнить и дополнительно осветить - спасибо, уточнили.

      Другой вопрос, зачем это в данном приложении, которое что-то абстрактно обучает, что никому не нужно? Видите ли, я, когда начинал этот проект, одним из требований ставил горизонтальную масштабируемость системы, и во многом система держится на Spark и Kubernetes (и даже ускоритель Rapids для Spark Jobs на GPU имеется). Так как частью системы является приложение на Java, и так уж случилось, что для обеспечения многих вспомогательных механизмов использован Spring, использовать WebFlux и сравнить его с сервлетным стеком было естественным желанием.

      Кроме того, непонятно какое вообще влияние может оказывать веб-API для обучения на основе GPU ???

      Я где-то в статье сказал, что оказывает, или что Spring в целом и WebFlux в частности - это только про WEB API? Опять же, результаты перед Вами: перешел на реактивный стек - управление ресурсами стало более эффективным. Почему? Извините, я не нашел 100% ответа на этот вопрос с доказательствами из профайлера и хипдампов, а потратить еще несколько месяцев личного времени на поиск ответов в данный момент не готов. Есть результат, результат выражен в графиках, таблицах, описана методика тестирования - результатом поделился.

      Теперь, пользуясь случаем, позвольте задать Вам вопрос - Вы, как я понял, статью читали внимательно, и сделали для себя определенные выводы. Так же подозреваю, что Вы компетентный и довольно опытный специалист с хорошей инженерной подготовкой. Не могли бы озвучить свою гипотезу по факту более эффективного управления ресурсами в WebFlux версии приложения по сравнению с Web, и в чем может быть причина такого поведения последнего?

      Зачем там вообще нужно веб-API?

      Как интерфейс взаимодействия с внешним миром, предоставляющим возможность запуска Spark Jobs. С таким же успехом можно было сделать и чтение из очереди или запуск задач по расписанию, по получению SMS или e-mail, и др. Так как у меня нет заказчика с его персональными хотелками, я выбрал такой способ взаимодействия приложения с внешним миром.

      Вам надо было зафиксировать какой-нибудь один веб-API и под капотом уже тестировать либо реактивный API, либо "обычный". Создаётся впечатление, что пытаетесь протестировать всё сразу, хаотично меняя какие-то параметры без нормального профайлинга и изоляции модулей.

      Хотелось бы видеть, скажем, "реактивный API кассандры быстрее нереактивного на таких-то запросах на 15% потому что это, это и это". Ваше же измерение абстрактного Machine Learning-а в вакууме вообще ни о чём не говорит

      В этом частично соглашусь, частично нет. Признаюсь, частью с профайлингом я менее всего доволен в данной статье, так как не удалось привести сервлетную версию приложения к подобным результатам в части потребления ресурсов, как у реактивной версии, и не удалось найти ответы на вопросы о причинах такого поведения. Ограничение памяти хипа, времени работы GC и замена G1 на Z - наверное, сравни прикладыванию подорожника, но этого оказалось достаточно, чтобы улучшить работу с ресурсами и предотвратить stop the world (кстати, я не сказал, что стопы оказывают существенное влияние на результаты работы, и заметил их только при подключении профайлером).

      Стоило ли "изолировать" логические модули и отдельно их тестировать и профилировать - вопрос постановки задачи, я такую перед собой не ставил. Почему не ставил - потому что материала и без того наработано на 30+ страниц А4, и по итогу работ имеется результат достаточный для того, чтобы его показать, что я и сделал в выводе.

      Что касается "абстрактного Machine Learning-а в вакууме" - во всех трех статьях по этой теме я не претендовал на то, чтобы разработать новый высокоэффективный алгоритм, новую технологию, библиотеку или фреймворк для Java, и все иное в этом духе - я не обладаю такими компетенциями. Изначально мне было интересно, а есть ли альтернатива Python и Dask в мире ML такая, чтобы за основу можно было взять Java с ее существующими библиотеками и фреймворками ML, NN, и прочих; взять оркестратор контейнеров и запустить параллельное обучение на нескольких GPU. Вот с этим, я считаю, что справился в полной мере - альтернатива есть, и она работает, что бы Вы ни говорили. Насколько она эффективна - большой вопрос, который я не готов раскрывать ни в рамках опубликованной статьи, ни в рамках данного комментария. Но, исходя из имеющегося опыта уже могу сказать, что инкрементальное обучение XGBoost в библиотеке Java не поддерживается, а вот при ресерче для Python я находил пример с работающими параметрами.

      Касательно Ваших пожеланий - действительно, простор для исследований есть, и, возможно, когда-нибудь результаты появятся в отдельной статье, затем найдутся такие же не довольные, я им пообещаю написать еще, напишу, и найдутся другие недовольные, и так по кругу - жизнь во всей ее красе. Но, если Вам действительно что-то интересно и хотелось бы увидеть - не вижу смысла кого-то или чего-то ждать, я предлагаю Вам провести исследование самостоятельно, выложить результаты в статье. Со своей стороны готов даже на Вас подписаться, и, когда вы выложите результаты, обязательно поставлю лайк, приду к Вам в комментарии, поблагодарю за статью и задам уточняющие вопросы, если таковые возникнут.


      1. sherbinko
        06.08.2023 09:14

         при большом числе параллельных запросов WebFlux может потреблять меньше ресурсов

        Откуда у вас "большое число параллельных потоков" в машин лёрнинге? WebFlux используется, когда речь идёт о десятках тысяч одновременных соединений

        Я где-то в статье сказал, что оказывает, или что Spring в целом и WebFlux в частности - это только про WEB API

        WebFlux - это реактивный аналог Spring MVC. Под капотом в обоих случаях вы можете использовать как реактивный так и нереактивный API. Как веб API может повлиять на скорость мне не понятно. Вы же не запускаете обучение миллион раз в секунду.

        Почему? Извините, я не нашел 100% ответа на этот вопрос с доказательствами из профайлера и хипдампов, а потратить еще несколько месяцев личного времени на поиск ответов в данный момент не готов

        Если у вас появляются подозрительные результаты, то вы должны исследовать этот вопрос, ибо есть вероятность что вы что-то делаете не так. Иначе будет как в анекдоте: "стекло протирал? по колёсам бил?"
        WebFlux - более сложный API на его поддежку требуется больше рессурсов. И не смотря на это, вы нашли время чтобы на него переписать. А теперь пишете, что не нашли времени на ответ на вопрос: почему всё таки проседает производительность.

        Не могли бы озвучить свою гипотезу по факту более эффективного управления ресурсами в WebFlux версии приложения по сравнению с Web

        Моя гипотеза - что Spring Flux будет иметь такую же производительность как и Spring MVC. Чтобы точно ответить причину тормозов - надо профайлить. Bottleneck - он вcегда в неожиданном месте.

        но этого оказалось достаточно, чтобы улучшить работу с ресурсами и предотвратить stop the world

        зачем его предотвращать? У вас точно своп отключен?

        Все работает на одной машине под управлением Windows 10 Pro

        Если у вас kubernetes, почему бы не протестировать там


        1. Dartya Автор
          06.08.2023 09:14

          Откуда у вас "большое число параллельных потоков" в машин лёрнинге? WebFlux используется, когда речь идёт о десятках тысяч одновременных соединений

          Отвечая по существу, начиная проект, ставил перед собой цель ответить на вопрос - "Если поставлена задача принимать от пользователя данные "использовать такую-то модель, указать для нее такие-то параметры, взять данные такие-то с такой-то даты, делать предикт на такое-то время вперед" для выполнения его задач на обучение моделей и их дальнейшего использования, как я могу реализовать на JVM-стэке и в контейнеризированной среде масштабируемую систему для обеспечения параллельной работы большого числа клиентов?". В двух предыдущих статьях я отвечал на данный вопрос, и, я полагаю, достиг определенных успехов, чем и поделился с сообществом. Кстати, если найдете наиболее информативные примеры подобной системы, чем я могу предложить, в публичном доступе, скиньте в комментарии, буду признателен.

          Вы с тем же успехом могли бы задать вопрос о необходимости распределенной среды, Kubernetes и, собственно, Java, если все можно делать проще на питоне, и для работы используется только локальная среда. Надеюсь, ответил на этот вопрос и подкинул идей для новых.

          Как веб API может повлиять на скорость мне не понятно

          А мне не понятно, почему Вы делаете привязку к WEB API, если WebFlux - это в первую очередь фреймворк на основе Project Reactor, по умолчанию использующий Netty. Условный пример, который в жизни не применим, но как эксперимент сойдет: сделайте Flux, пусть генерирует в лог сообщения о положении луны и солнца в зависимости от системного времени, делает на основе положения солнца еще какие-нибудь расчеты (придумайте сами), и выводит в лог. Flux есть, реактивная цепочка есть, вывод в лог какой-никакой есть - где здесь WEB API? Затем можете тем же Flux'ом генерировать задачи для спарка.

          Возможно, стоит освежить в памяти принцип работы и архитектуру Spark.

          Чтобы сократить комментарий, предлагаю ознакомиться со следующими материалами:

          https://spark.apache.org/docs/latest/cluster-overview.html, https://blog.knoldus.com/understanding-the-working-of-spark-driver-and-executor/

          Таким образом получается, что у Spark Driver довольно много сетевых взаимодействий с исполнителями. Учитывая, что как минимум логика приложения выполняется асинхронно и в реактивных цепочках, осмелюсь предположить, что именно по этой причине WebFlux наиболее эффективен в части потребления аппаратных ресурсов и скорости выполнения задач.

          WebFlux используется, когда речь идёт о десятках тысяч одновременных соединений

          Соединений или RPS? Можем долго вести дискуссию на предмет преждевременной оптимизации, но, во-первых, речь идет об эксперименте, и эксперимент оказался удачным. Во-вторых, я правильно понимаю, что Вы пытаетесь через комментарии на Хабре запретить людям, не имеющим "десятки тысяч одновременных соединений", использовать реактивный стек, или донести им мысль о том, что им ни в коем случае не нужно этого делать?

          А если число "одновременных соединений", как Вы говорите, (или RPS) будет исчисляться не десяткамми тысяч, а тысячами, скажем, условными 9000, - все, WebFlux не используется? Можно ссылку на авторитетные источники, исследования, возможно Ваши публикации, или, выражаясь Вашим языком, откуда вы взяли, что WebFlux используется только когда речь идет о десятках тысяч одновременных соединений?

          WebFlux - более сложный API на его поддежку требуется больше рессурсов.

          Да, сложнее, больше ли ресурсов (насколько я понял, речь идет о человеческих) - не соглашусь. Если в технологическом стеке компании имеется WebFlux, принято решение о единообразии его использования в ряде задач, имеется система онбординга и в целом персонал обучен и квалифицирован, имеет обширную кодовую базу, как личную, так и компании, большего числа ресурсов относительно сервлетного стека не требуется. И говорю сейчас не об абстрактных компаниях - я в такой на данный момент работаю.

          Если у вас появляются подозрительные результаты, то вы должны исследовать этот вопрос...

          Последовательность действий и трудозатраты таковы, что переписал я за один световой день, а вот на тесты с профайлером и анализом проблемы потратил гораздо больше времени. По моей оценке, и по тому, что мне удалось найти, такими темпами более глубокое изучение, изоляция модулей и их тестирование, прочие манипуляции могут занять в худшем варианте месяц и более моего личного времени, поэтому я поделился тем результатом, который посчитал достаточным для публикации. Достаточность определяется тем, что работают оба варианта, реактивный эффективнее, а сервлетный нужно докручивать, вектор направления так же задал. Не достаточно, нужно еще - либо ждите, либо заводите и профилируйте, репозиторий в открытом доступе.

          И не смотря на это, вы нашли время чтобы на него переписать. А теперь пишете, что не нашли времени на ответ на вопрос: почему всё таки проседает производительность.

          Касательно тона Ваших комментариев - я прошу Вас его сменить, так как Вы мне не заказчик, не руководитель, не мать, не сестра и не любовница, чтобы критиковать меня за потраченное мое личное время, а так же задавать вопросы в духе "Это где вы такое взяли?", "Откуда у вас "большое число параллельных потоков" в машин лёрнинге?", и пр. Я Вам ничего не должен, и Ваши хотелки можете оставить при себе, как я Вам выше писал. В противном случае можете продолжать пытаться самоутверждаться в комментариях сколько угодно, я не буду на Вас реагировать. Судя по Вашему профилю, Вы именно за этим и приходите на Хабр. Призываю использовать корректные выражения в своей речи в дискуссии со мной.

          WebFlux используется, когда речь идёт о десятках тысяч одновременных соединений

          ...

          Моя гипотеза - что Spring Flux будет иметь такую же производительность как и Spring MVC. Чтобы точно ответить причину тормозов - надо профайлить. Bottleneck - он вcегда в неожиданном месте.

          Я Вас услышал, блестящая гипотеза, спасибо. Если еще что-то надумаете, пишите, с удовольствием почитаем.

          зачем его предотвращать? [stop the world - прим.]

          Слишком толсто. Погуглите этот вопрос, уверен, найдете много интересных материалов.

          У вас точно своп отключен?

          Отключен. И, если уж вы вспомнили про Win10Pro, основываясь на Ваших рассуждениях выше и предвидя Ваши вопросы - да, перезагружал, а вот переустановить не пробовал.

          Удобно задавать вопросы, а когда нужно самому ответить, получается как в том анекдоте - "моя не читатель, моя писатель", или как у Марка Твена - "... можно заговорить и развеять все сомнения". Возвращаясь к вопросу о тоне комментариев - как видите, в эту игру можно играть вдвоем.

          Если у вас kubernetes, почему бы не протестировать там

          Потому что кубер на удаленных машинах, у них меньшее число аппаратных ресурсов, и к ним я имею доступ только по RDP, и мне не так комфортно постоянно вести разработку, профайлить и деплоить новые версии в кубер, по сравнению с работой на локальной машине при тех же условиях в том же контейнер-рантайме.

          ---

          Что касается профайлинга и поиска причин - если в будущем вернусь к этому вопросу и будет что полезного рассказать, вернусь в комментарии с подробным указанием причин и решений, как бороться с подобным поведением сервлетного стека в данной среде.