Статья является продолжением «Пишем агента на Kotlin: KOSMOS», но может читаться независимо. Мотивация к написанию — сохранить читателю время на возню с фреймворками для решения относительно простой задачи.
Автор подразумевает у читателя теоретическое понимание того, что такое агент. Иначе лучше прочесть хотя бы начало предыдущей части.
Как и везде, в программирование важен маркетинг, поэтому обертку над HTTP-запросами в цикле называют революцией:
From Python to Kotlin: How JetBrains Revolutionized AI Agent Development. — reddit, medium.
Но в этом нет ничего революционного. Ниже хочу показать, как самостоятельно написать аналог Koog или Langchain4j. У вас не будет всех их фичей, зато будет очень простая и расширяемая система.
Содержание
Введение
Проблемы использования фреймворков
- Мета проблемы
- Сложная ментальная модель
- Запутанный синтаксис
Реализация агента на основе графов
- Упрощенная реализация Агента на основе графов
- Детальная реализация Агента на основе графов
- Добавление RAG
Когда использовать фреймворк, а не самописное решение?
Предисловие для тех, кто читал первую статью
Единственная часть, которую было сложно расширять и поддерживать в прошлой статье — сам агент. Тут мы напишем такое решение, чтобы агент собирался, как конструктор, и чтобы любую часть можно было легко вынести и переиспользовать в других агентах.
Если нет времени читать, можете глянуть PR с реализацией агента на графах и PR с добавлением RAG.
Предисловие для мобильных разработчиков
С 2015, когда я начал карьеру, постоянно появляются библиотеки-решения для организации UI-архитектуры на основе паттернов: MVC, MVP, MVVM, TEA, VIPER, Flux/Redux. Я пробовал все паттерны из перечисленных и смело могу сказать, что особой разницы нет, пока вся команда придерживается одного подхода. Но каждый раз использование чьей-то библиотеки приводило к страданиям. Потому что находился кейс, который библиотека упускала и не давала решить легко. Были баги, которые не починить без форка. Код библиотек запутан и сложен. Всегда проще было самому написать основу и жить с ней.
То же самое и с фреймворками по написанию ИИ-агентов. Лучше разберитесь с основами и напишите под себя легковесное решение, которое вы будете понимать и контролировать.
Предисловие для бэкендеров
Я встречал java-бэкенд разработчиков, которые несколько лет работали с Postgres, но не знают, как взять лок. Разгадка — фреймворк Hibernate. Такой разработчик может ходить в базу в цикле или внутри транзакций к базе выполнять HTTP-запросы. Конечно, есть исключения, но при прочих равных, фреймворк лишает понимания, задерживает развитие, способствует написанию неоптимального кода и даже негативно влияет на экологию (сколько энергии было потрачено на разработку и компиляцию этого фреймворка? SW).
То же самое и с фреймворками по написанию ИИ-агентов. Под капотом происходит всего лишь вызов нескольких ручек HTTP, а ваш фреймворк падает с out of memory.
Проблемы использования фреймворков
Приведу несколько соображений на примере Koog. Если вы и так понимаете, что фреймворк — это усложнение на ровном месте, пропускайте.
Frameworks are one of the hugest anti-patterns in software development. — Peter Krumins
Мета проблемы
Существуют проблемы, не относящиеся к сложности непосредственного использования API фреймворка. Вот несколько примеров:
Мы взяли Koog в KMP-проект, и он сломался о колено по конфликту версий kotlinx-datetime (issue, PR висит с начал сентября). Это решить сложнее, чем кажется, так как апи kotlinx-datetime еще в альфе и меняется, не заботясь об обратной совместимости.
Посмотрите, сколько всего вы затащите с Koog — libs.versions.toml. Всё что начинается с «0.», а это 7 библиотек на момент написания статьи, может пойти по пути из пункта 1.
Чтобы использовать Гигачат, придется писать клиента самому. Гигачат умеет работать с примерами для функций (тулов) — few_shot_examples. У Koog ни Annotation-based tools, ни Class-based tools это не умеют.
Из-за того, что фреймворк пытается поддерживать все популярные LLM, он допускает случайные ошибки в конкретных реализациях. В примере до этого фреймворк не учел возможности некоторых API, вроде Гигачат. В других случаях он просто падает в рантайме (Issue).
Фреймворки могут скрывать грязные приемы внутри. Вот тут Koog подхачивает промпты пользователя, чтобы стратегия работала, как полагается: «Don't chat with plain text! Call one of the available tools, instead: ${tools.joinToString(", ")». Я не хочу, чтобы фреймворк менял промпты, потому что это и так делает API LLM. Видимо, тут следствие проблемы из пункта 4.
Сложная ментальная модель
Агент — это про цикл взаимодействия между LLM, пользователем системы и вызовом функций (тулов). Где-то рядом можно прикрутить трейсинг, кеши и RAG. И в общем-то всё.
Но Koog настолько сложен, что падает с OOM на компиляции (Issue).
Если вы захотите быстренько разобраться с Koog, вам придется погрузиться в их концепцию и терминологию:
Agent, LLM, Message, Prompt, Attachment, System prompt, Context, Session,
Event, EventContext, EventHandler,
OpenTelemetry, LoggingSpanExporter, Sampler,
AgentMemory, Concept, Fact, FactType, Subject,
ToolArgs, ToolSet, Class-based tool, Function-based tool
Strategy, Graph, Subgraph, Node, Edge, Conditions
LLMEmbedder, JVMTextDocumentEmbedder, EmbeddingBasedDocumentStorage
McpTool, McpToolDescriptorParser, McpToolRegistryProvider, ProcessBuilder
Модель можно значительно упростить:
Всё, что касается эвентов, можно реализовать на стороне пользователя. Нужен лишь способ передать callback на события перехода между нодами графа.
Зачем нам и граф, и сабграф. Уже выглядит грязно, ведь хочется, чтобы любой граф мог выступить в качестве сабграфа.
Зачем и граф, и стратегия? Можно было бы оставить только граф.
Всё что касается памяти, пользователь может решить сам — дополнить контекст в своем туле или в ноде графа.
Запутанный синтаксис
Вот пример создания ребра графа между sourceNode и targetNode в Koog.
edge(sourceNode forwardTo targetNode onCondition {input -> input.length > 10})
Что не так с этим кодом?
Мы не можем читать код слева направо. До того, как выполнится
forwardTo, запускаетсяonCondition.Мы не можем доверять тому, что infix функции выполнятся в том порядке, в котором мы ожидаем. Смысла в infix-функциях тут нет никакого.
Излишний синтаксис. Зачем нам и
edge, иforwardTo— это бесцельное дублирование.
На мой взгляд, код ниже читается лучше:
sourceNode.edgeTo {input -> if (input.length > 10) transformerNode else null}
Реализация агента на основе графов
Цикл работы агента ровно такой же, как и в прошлой статье (рисунок из нее), но реализация будет на основе графов.
По горячим следам прошлого параграфа, давайте сделаем набросок того, как должен выглядеть агент:
val agent = buildGraph {
nodeInput.edgeTo(nodeLLM)
nodeLLM.edgeTo { context ->
if (context.isToolUse()) nodeToolUse else nodeFinish
}
nodeToolUse.edgeTo(nodeLLM)
}
Каждый Node должен уметь принять Input и вернуть Output. Например, nodeLLM может выглядеть как-то так:
val nodeLLM = suspend fun (input: String, context: AgentContext): Pair {
val request = buildRequestFrom(input, context)
// Запрос к АПИ
val response = GigaHttpClient.chat(request)
// Добавление истории в контекст
val newContext = context.copyWith(appendToHistory = response.messages)
return response.messages.last to newContext
}
А nodeInput — это просто ожидание System.in от юзера:
val nodeInput = suspend fun (input: String, context: AgentContext): Pair {
println("> ")
val userMessage = readlnOrNull()
val newContext = context.copyWith(appendToHistory = userMessage)
return userMessage to newContext
}
В AgentContext мы можем положить всё, что нужно между нодами. Например, историю общения с агентом и предыдущий output.
Цикл общения с агентом можно построить двумя способами. Первый — сохранять контекст вовне:
val agent = ...
var seed = AgentContext("Агент готов") // тут будет сохраняться история
while (true) {
val result = graph.start(seed)
println("Agent said: $result")
seed = agent.currentContext
}
Второй способ — можно сделать граф зацикленным (заменить nodeFinish на nodeInput в buildGraph выше).
Пока всё должно выглядеть очень просто, давайте реализуем Node.
Упрощенная реализация Агента на основе графов
Давайте напишем первую реализацию, чтобы нащупать нужные абстракции.
В контексте агента важны input и history — ввод для вершины графа (Node) и история. Удобно иметь историю в контексте. Например, если агент упадет, мы можем запустить нового с уже имеющейся историей. Всю мутабельность можно спрятать на этом уровне в будущем.
data class AgentContext(
val input: String,
val history: List
)
Контекст меняется, переходя от одной вершины графа к другой. Давайте опишем вершины и переходы:
interface Node {
val name: String // для логов
suspend fun execute(ctx: AgentContext): AgentContext
}
/** Create new [Node] implementation based on [op] */
fun Node(
name: String,
op: suspend (AgentContext) -> AgentContext,
): Node = object : Node {
override val name: String = "Node $name; ${Integer.toHexString(hashCode())}"
override suspend fun execute(ctx: AgentContext) = op(ctx)
}
/** Ребра графа */
sealed interface Transition {
class Static(val target: Node) : Transition
class Dynamic(val router: suspend (AgentContext) -> Node) : Transition
}
Ребра (Transition) могут быть статическими и динамическими. Пример динамического перехода был в предыдущем разделе:
nodeLLM.edgeTo { context ->
if (context.isToolUse()) nodeToolUse else nodeFinish
}
Теперь надо решить, где хранить ребра графа (переходы). Не хочется, чтобы Node был мутабельным — бывшие Clojure-коллеги не поймут. Да и сами посудите, вдруг понадобится переиспользовать один Node в разных графах. Мутабельность всё испортит.
Пусть пока мутабельным будет Graph, чуть позже мы проведем рефакторинг:
class Graph {
val transitions = HashMap<Node, ArrayList<Transition>>()
val nodeEnter: Node = Node("enter") { it }
fun Node.edgeTo(target: Node): Node {
registerTransition(this, Transition.Static(target))
return target
}
fun Node.edgeTo(router: suspend (AgentContext) -> Node) {
registerTransition(this, Transition.Dynamic(router))
}
private fun registerTransition(from: Node, transition: Transition) {
val bucket = transitions.getOrPut(from) { arrayListOf() }
bucket += transition
}
}
Теперь мы можем создать агента:
suspend fun main() {
val nodeInput = Node("NodeInput") { ctx ->
val userMessage = readln()
ctx.copy(
input = userMessage,
history = ArrayList(ctx.history).apply { add(userMessage) }
)
}
val nodeLLM = Node("NodeLLM") { ctx ->
val llmResult = "I can't do much, just a mock"
ctx.copy(
input = llmResult,
history = ArrayList(ctx.history).apply { add(llmResult) }
)
}
val agent = Graph().apply {
nodeEnter.edgeTo(nodeInput)
nodeInput.edgeTo(nodeLLM)
nodeLLM.edgeTo(nodeEnter)
}
/*
agent.run(AgentContext("start")) { node, ctx ->
println(node.name + ": " + ctx.input)
}
*/
}
Чтобы граф можно было «запустить», реализуем функцию Graph.run(я быврал BFS, можно в будущем это контралировать через контекст):
suspend fun Graph.nextNodes(node: Node, ctx: AgentContext): List<Node> {
val registered = transitions[node] as? List<Transition> ?: emptyList()
if (registered.isEmpty()) return emptyList()
val next = ArrayList<Node>(registered.size)
for (transition in registered) {
when (transition) {
is Transition.Static -> next.add(transition.target)
is Transition.Dynamic -> next.add(transition.router(ctx))
}
}
return next
}
suspend fun Graph.run(
seed: AgentContext,
onStep: (Node, AgentContext) -> Unit
): AgentContext {
val queue = ArrayDeque<Pair<Node, AgentContext>>()
.apply { add(nodeEnter to seed) }
var lastCtx: AgentContext = seed
while (queue.isNotEmpty() && currentCoroutineContext().isActive) {
val (node, ctx) = queue.removeFirst()
val outCtx = node.execute(ctx)
onStep(node, outCtx)
lastCtx = outCtx
val nextNodes = nextNodes(node, outCtx)
if (nextNodes.isNotEmpty()) {
for (child in nextNodes) {
queue.add(child to outCtx)
}
}
}
return lastCtx
}
Вот и всё, мы реализовали core Koog. Можем запускаться и смотреть на результат.
Детальная реализация Агента на основе графов
Решение выше — всего лишь набросок, но уже функциональный. Внутри Node можно реализовать что угодно — например, построение другого графа или даже нескольких графов, которые можно запустить параллельно через обычный async. Имея callback в функции Graph.run, мы можем повесить метрики. Sequence abstraction (map, filter, reduce) легко реализуется через Node и Transition.
Чего не хватает:
Где-то нужно хранить тулы, system prompt, текущую модель и т.п..
Нет обработки ошибок и retry.
Нельзя использовать граф как сабграф (Node).
Отсутствие полиморфизма (параметрического, т.е. нет дженериков). Не переводить же String input в Json и обратно в String на каждом шагу.
Реализация графа не иммутабельна, а детали реализации торчат наружу (нет инкапсуляции).
Как будем решать?
Все настройки можно положить в AgentContext.
Вынесем абстракцию GraphRunner, которая будет думать о retry. Контекст запуска будем хранить в отдельной сущности GraphRuntime.
Граф реализует интерфейс Node.
Сделаем input в AgentContext дженериком. А историю можем хранить в DTO моделях Гигачата. Понадобится другая LLM-модель — напишем конвертер из Гигачат-моделей в целевые.
Вынесем создание графа в билдер, а граф оставим иммутабельным. Описание графа будет доступно в GraphRuntime.
Прежде чем начнем, реализации функций (тулов) можно найти в предыдущей статье или на гитхабе. Там же есть DTO и Ktor клиенты для Гигачата. Продублирую здесь:
Клиент и DTO для Гигачат
import com.fasterxml.jackson.annotation.JsonProperty
import java.util.*
object GigaResponse {
data class Token(
@JsonProperty("access_token") val accessToken: String,
@JsonProperty("expires_at") val expiresAt: Date
)
sealed interface Chat {
data class Ok(val choices: List<Choice>, val created: Long, val model: String) : Chat
data class Error(val status: Int, val message: String) : Chat
}
data class Choice(
val message: Message,
val index: Int,
@JsonProperty("finish_reason")
val finishReason: String
)
data class Message(
val content: String,
val role: GigaMessageRole,
@JsonProperty("function_call")
val functionCall: FunctionCall? = null,
@JsonProperty("functions_state_id")
val functionsStateId: String?
)
data class FunctionCall(
val name: String,
val arguments: Map<String, Any>
)
}
object GigaRequest {
data class Chat(
val model: String = "GigaChat-Max",
val messages: List<Message>,
@JsonProperty("function_call")
val functionCall: String = "auto",
val functions: List<Function>? = null,
)
data class Message(
val role: GigaMessageRole,
val content: String, // Could be String or FunctionCall object
@JsonProperty("functions_state_id")
val functionsStateId: String? = null
)
data class Function(
val name: String,
val description: String,
val parameters: Parameters
)
data class Parameters(
val type: String,
val properties: Map<String, Property>
)
data class Property(
val type: String,
val description: String? = null
)
}
@Suppress("EnumEntryName")
enum class GigaMessageRole { system, user, assistant, function }
const val MAX_TOKENS = 8192
enum class GigaModel(val alias: String, val maxTokens: Int) {
Lite("GigaChat-2", MAX_TOKENS),
Pro("GigaChat-Pro", MAX_TOKENS),
Max("GigaChat-Max", MAX_TOKENS),
}
fun String.toSystemPromptMessage() = GigaRequest.Message(
role = GigaMessageRole.system,
content = this
)
import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.engine.cio.*
import io.ktor.client.request.*
import io.ktor.client.request.forms.*
import io.ktor.http.*
object GigaAuth {
suspend fun requestToken(apiKey: String): String {
val client = HttpClient(CIO) {
gigaDefaults()
}
val response = client.submitForm(
url = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth",
formParameters = Parameters.build {
append("scope", "GIGACHAT_API_PERS")
}
) {
header("Content-Type", "application/x-www-form-urlencoded")
header("Authorization", "Basic $apiKey")
}.body<GigaResponse.Token>()
client.close()
return response.accessToken
}
}
import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.engine.cio.*
import io.ktor.client.plugins.auth.*
import io.ktor.client.plugins.auth.providers.*
import io.ktor.client.plugins.logging.LogLevel
import io.ktor.client.plugins.logging.Logging
import io.ktor.client.request.*
import io.ktor.http.*
class GigaChatAPI(private val auth: GigaAuth) {
private val client = HttpClient(CIO) {
var token = "" // get form env, or cache, or db
val gigaKey = System.getenv("GIGA_KEY")
gigaDefaults()
install(Auth) {
bearer {
loadTokens {
BearerTokens(token, "")
}
refreshTokens {
token = auth.requestToken(gigaKey)
BearerTokens(token, "")
}
}
}
install(Logging) {
val envLevel = LogLevel.INFO
level = envLevel
sanitizeHeader { it.equals(HttpHeaders.Authorization, true) }
}
}
suspend fun message(body: GigaRequest.Chat): GigaResponse.Chat {
val response = client.post("https://gigachat.devices.sberbank.ru/api/v1/chat/completions") {
setBody(body)
}
return when {
response.status.isSuccess() -> response.body<GigaResponse.Chat.Ok>()
else -> response.body<GigaResponse.Chat.Error>()
}
}
fun clear() = client.close()
}
import com.fasterxml.jackson.databind.DeserializationFeature
import io.ktor.client.*
import io.ktor.client.engine.cio.*
import io.ktor.client.plugins.*
import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.serialization.jackson.*
import java.security.cert.X509Certificate
import java.util.*
import javax.net.ssl.X509TrustManager
fun HttpClientConfig<CIOEngineConfig>.gigaDefaults() {
this.defaultRequest {
header(HttpHeaders.ContentType, "application/json")
header(HttpHeaders.Accept, "application/json")
header("RqUID", UUID.randomUUID().toString())
}
install(HttpTimeout) {
requestTimeoutMillis = 40000
}
install(ContentNegotiation) {
jackson { this.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES) }
}
engine {
https {
trustManager = object : X509TrustManager {
override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?) {}
override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?) {}
override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf()
}
}
}
}
Описание функций (тулов)
@Target(AnnotationTarget.PROPERTY)
@Retention(AnnotationRetention.RUNTIME)
annotation class InputParamDescription(val value: String)
interface ToolSetup<Input> {
val name: String
val description: String
operator fun invoke(input: Input): String
}
class BadInputException(msg: String) : Exception(msg)
Пример реализации тула:
object ToolRunBashCommand : ToolSetup<ToolRunBashCommand.Input> {
override val name = "RunBashCommand"
override val description = "Executes a bash command and returns its output"
override fun invoke(input: Input): String {
val process = ProcessBuilder("bash", "-c", input.command)
.redirectErrorStream(true)
.start()
val output = process.inputStream.bufferedReader().use(BufferedReader::readText)
val exitCode = process.waitFor()
if (exitCode != 0) throw RuntimeException("Command failed with exit code $exitCode")
return output.trim()
}
data class Input(
@InputParamDescription("The bash command to run, e.g., 'ls', 'echo Hello', './gradlew tasks'")
val command: String
)
}
Маппинг на модели гигачата:
import com.dumch.tool.InputParamDescription
import com.dumch.tool.ToolSetup
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import kotlin.reflect.KCallable
import kotlin.reflect.full.declaredMembers
import kotlin.reflect.full.findAnnotation
interface GigaToolSetup {
val fn: GigaRequest.Function
operator fun invoke(functionCall: GigaResponse.FunctionCall): GigaRequest.Message
}
val gigaJsonMapper = jacksonObjectMapper()
inline fun <reified Input> ToolSetup<Input>.toGiga(): GigaToolSetup {
val toolSetup = this
return object : GigaToolSetup {
override val fn: GigaRequest.Function = GigaRequest.Function(
name = toolSetup.name,
description = toolSetup.description,
parameters = GigaRequest.Parameters(
"object",
properties = HashMap<String, GigaRequest.Property>().apply {
val clazz = Input::class
for (kProperty: KCallable<*> in clazz.declaredMembers) {
val annotation = kProperty.findAnnotation<InputParamDescription>() ?: continue
val description = annotation.value
val type = kProperty.returnType.toString().substringAfterLast(".").lowercase()
val gigaProperty = GigaRequest.Property(type, description)
put(kProperty.name, gigaProperty)
}
}
)
)
override fun invoke(
functionCall: GigaResponse.FunctionCall,
): GigaRequest.Message {
return try {
val input: Input = gigaJsonMapper.convertValue(functionCall.arguments, Input::class.java)
val toolResult = toolSetup.invoke(input)
val gigaResult = gigaJsonMapper.writeValueAsString(
mapOf("result" to toolResult)
)
GigaRequest.Message(
role = GigaMessageRole.function,
content = gigaResult,
)
} catch (e: Exception) {
e.toGigaToolMessage()
}
}
}
}
fun Exception.toGigaToolMessage(): GigaRequest.Message {
return GigaRequest.Message(
role = GigaMessageRole.function,
content = """{"result": "${message ?: toString()}"}""",
)
}
Приступаем к агенту.
Node с дженериками
interface Node<IN, OUT> {
val name: String
suspend fun execute(ctx: AgentContext<IN>, runtime: GraphRuntime): AgentContext<OUT>
}
/**
* Create new [Node] implementation based on [op]
*/
fun <IN, OUT> Node(
name: String,
op: suspend (AgentContext<IN>) -> AgentContext<OUT>,
): Node<IN, OUT> = object : Node<IN, OUT> {
override val name: String = "Node $name; ${Integer.toHexString(hashCode())}"
override suspend fun execute(ctx: AgentContext<IN>, runtime: GraphRuntime) = op(ctx)
}Реализация AgentContext
data class AgentContext<I>(
val input: I,
val settings: AgentSettings,
val history: List<GigaRequest.Message>,
val tools: List<GigaRequest.Function>,
val systemPrompt: String,
) {
inline fun <reified O> map(
settings: AgentSettings = this.settings,
history: List<GigaRequest.Message> = this.history,
activeTools: List<GigaRequest.Function> = this.tools,
systemPrompt: String = this.systemPrompt,
transform: (I) -> O = { it as O },
): AgentContext<O> = AgentContext(input = transform(input), settings, history, activeTools, systemPrompt)
}
data class AgentSettings(
val model: String,
val temperature: Float,
val tools: Map<String, GigaToolSetup>
)Реализация Graph с Builder
class Graph<IN, OUT> internal constructor(
label: String,
private val enter: Node<IN, *>,
private val exit: Node<OUT, OUT>,
private val retryPolicy: RetryPolicy,
private val definition: GraphDefinition,
) : Node<IN, OUT> {
private val runner = GraphRunner()
override val name: String = "$label::graph"
@Suppress("UNCHECKED_CAST")
override suspend fun execute(ctx: AgentContext<IN>, runtime: GraphRuntime): AgentContext<OUT> {
val result = runner.run(
start = enter as Node<Any?, Any?>,
seed = ctx as AgentContext<Any?>,
runtime = runtime,
definition = definition, // ребра передадим в Runner
stopPredicate = { node, _ -> node === exit }
)
return result as AgentContext<OUT>
}
suspend fun start(
seed: AgentContext<IN>,
maxSteps: Int = 1000,
onStep: ((step: StepInfo, node: Node<Any?, Any?>, ctx: AgentContext<Any?>) -> Unit)? = null,
): AgentContext<OUT> {
val runtime = GraphRuntime(
retryPolicy = retryPolicy,
maxSteps = maxSteps,
onStep = onStep,
)
return execute(seed, runtime)
}
}
class GraphBuilder<IN, OUT> internal constructor(
private val graphName: String,
private val retryPolicy: RetryPolicy,
) {
val nodeInput: Node<IN, IN> = Node("$graphName::enter") { it }
val nodeFinish: Node<OUT, OUT> = Node("$graphName::exit") { it }
private val transitions: MutableMap<Node<*, *>, MutableList<Transition<*>>> = mutableMapOf()
fun <IN, OUT, OUT2> Node<IN, OUT>.edgeTo(target: Node<OUT, OUT2>): Node<OUT, OUT2> {
registerTransition(this, Transition.Static(target))
return target
}
fun <IN, OUT> Node<IN, OUT>.edgeTo(router: suspend (AgentContext<OUT>) -> Node<OUT, *>): Unit {
registerTransition(this, Transition.Dynamic(router))
}
private fun <OUT> registerTransition(from: Node<*, OUT>, transition: Transition<OUT>) {
val bucket = transitions.getOrPut(from) { mutableListOf() }
bucket += transition
}
internal fun build(): Graph<IN, OUT> = Graph(
graphName,
nodeInput,
nodeFinish,
retryPolicy,
GraphDefinition(transitions.mapValues { it.value.toList() }),
)
}
// Вынесем еще и абстракцию для хранения ребер
internal class GraphDefinition(
private val transitions: Map<Node<*, *>, List<Transition<*>>>,
) {
@Suppress("UNCHECKED_CAST")
suspend fun nextNodes(node: Node<Any?, Any?>, ctx: AgentContext<Any?>): List<Node<Any?, *>> {
val registered = transitions[node] as? List<Transition<Any?>> ?: emptyList()
if (registered.isEmpty()) return emptyList()
val next = ArrayList<Node<Any?, *>>(registered.size)
for (transition in registered) {
when (transition) {
is Transition.Static -> next.addOrWarn(transition.target as Node<Any?, *>)
is Transition.Dynamic -> next.addOrWarn(transition.router(ctx) as Node<Any?, *>)
}
}
return next
}
private fun MutableCollection<Node<Any?, *>>.addOrWarn(node: Node<Any?, *>) {
if (contains(node)) {
add(node)
}
}
}
internal sealed interface Transition<OUT> {
class Static<OUT>(val target: Node<OUT, *>) : Transition<OUT>
class Dynamic<OUT>(val router: suspend (AgentContext<OUT>) -> Node<OUT, *>) : Transition<OUT>
}
// Ниже всего лишь бойлерплейт для делегатов (by).
fun <I, O> buildGraph(
name: String = "Graph",
retryPolicy: RetryPolicy = RetryPolicy(),
configure: GraphBuilder<I, O>.() -> Unit
): Graph<I, O> {
val builder = GraphBuilder<I, O>(name, retryPolicy)
builder.configure()
return builder.build()
}
fun <I, O> graph(
name: String? = null,
retryPolicy: RetryPolicy = RetryPolicy(),
configure: GraphBuilder<I, O>.() -> Unit
): ReadOnlyProperty<Any?, Graph<I, O>> = GraphDelegate(name, retryPolicy, configure)
private class GraphDelegate<I, O>(
private val nameHint: String?,
private val retryPolicy: RetryPolicy,
private val configure: GraphBuilder<I, O>.() -> Unit,
) : ReadOnlyProperty<Any?, Graph<I, O>> {
private var cached: Graph<I, O>? = null
override fun getValue(thisRef: Any?, property: KProperty<*>): Graph<I, O> {
return cached ?: build(property.name).also { cached = it }
}
private fun build(propertyName: String): Graph<I, O> {
val name = nameHint ?: propertyName
val builder = GraphBuilder<I, O>(name, retryPolicy)
builder.configure()
return builder.build()
}
}Реализация GraphRunner и GraphRuntime
internal class GraphRunner {
suspend fun run(
start: Node<Any?, Any?>,
seed: AgentContext<Any?>,
runtime: GraphRuntime,
definition: GraphDefinition,
stopPredicate: ((Node<Any?, Any?>, AgentContext<Any?>) -> Boolean)? = null,
): AgentContext<Any?> {
val queue = ArrayDeque<Frame>().apply { add(Frame(start, seed, 0)) }
val leaves = mutableListOf<AgentContext<*>>()
var lastCtx: AgentContext<Any?> = seed
try {
while (queue.isNotEmpty() && currentCoroutineContext().isActive) {
if (runtime.counter.get() >= runtime.maxSteps) {
error("Graph maxSteps (${runtime.maxSteps}) reached — potential loop")
}
val frame = queue.removeFirst()
val outCtx = executeWithRetry(frame.node, frame.ctx, runtime)
val stepInfo = StepInfo(currentGraphIndex = frame.depth, index = runtime.counter.get())
runtime.onStep?.invoke(stepInfo, frame.node, outCtx)
lastCtx = outCtx
if (stopPredicate?.invoke(frame.node, outCtx) == true) return outCtx
val nextNodes = definition.nextNodes(frame.node, outCtx)
if (nextNodes.isEmpty()) {
leaves += outCtx
} else {
for (child in nextNodes) {
@Suppress("UNCHECKED_CAST")
queue.add(Frame(child as Node<Any?, Any?>, outCtx, frame.depth + 1))
}
}
runtime.counter.incrementAndGet()
}
} catch (cancel: CancellationException) {
throw GraphCancellation(lastCtx, cancel)
}
@Suppress("UNCHECKED_CAST")
return leaves.lastOrNull() as? AgentContext<Any?> ?: lastCtx
}
private suspend fun executeWithRetry(
node: Node<Any?, Any?>,
inCtx: AgentContext<Any?>,
runtime: GraphRuntime,
): AgentContext<Any?> {
val policy = runtime.retryPolicy
var attempt = 0
var lastError: Throwable? = null
while (attempt < policy.maxAttempts) {
attempt++
try {
return node.execute(inCtx, runtime)
} catch (t: Throwable) {
if (t is CancellationException) throw t
lastError = t
val shouldRetry = policy.shouldRetry(t, inCtx, node, attempt)
val attemptsLeft = policy.maxAttempts - attempt
if (!shouldRetry || attemptsLeft <= 0) break
}
}
throw lastError ?: IllegalStateException("Unknown failure in node ${node.name}")
}
private data class Frame(
val node: Node<Any?, Any?>,
val ctx: AgentContext<Any?>,
val depth: Int,
)
}
class GraphCancellation(
val lastContext: AgentContext<*>,
cause: CancellationException? = null
) : CancellationException(cause?.message) {
init {
initCause(cause)
}
}
data class StepInfo(
/**
* Sequential index of the executed node within the run (starting from 0).
*/
val index: Int,
/**
* Sequential index of the executed node within the current graph run (starting from 0).
*/
val currentGraphIndex: Int,
)
class GraphRuntime private constructor(
val retryPolicy: RetryPolicy,
val maxSteps: Int,
val onStep: ((step: StepInfo, node: Node<Any?, Any?>, ctx: AgentContext<Any?>) -> Unit)? = null,
val counter: AtomicInteger
) {
constructor(
retryPolicy: RetryPolicy,
maxSteps: Int,
onStep: ((step: StepInfo, node: Node<Any?, Any?>, ctx: AgentContext<Any?>) -> Unit)? = null,
): this(retryPolicy, maxSteps, onStep, counter = AtomicInteger())
}
data class RetryPolicy(
val maxAttempts: Int = 2,
val shouldRetry: suspend (
error: Throwable,
ctx: AgentContext<*>,
node: Node<*, *>?,
attempt: Int
) -> Boolean = { _, _, _, _ -> true }
)Дефолтные реализации Node
object NodesCommon {
val stringToReq: Node<String, GigaRequest.Chat> = Node("String->Request") { ctx ->
val usrMsg = GigaRequest.Message(GigaMessageRole.user, ctx.input)
val history = ArrayList(ctx.history).apply {
if (isEmpty()) add(ctx.systemPrompt.toSystemPromptMessage())
add(usrMsg)
}
ctx.map(history = history) { ctx.toGigaRequest(history) }
}
val respToString: Node<GigaResponse.Chat, String> = Node("Response->String") { ctx ->
when (val input = ctx.input) {
is GigaResponse.Chat.Error -> ctx.map { input.message }
is GigaResponse.Chat.Ok -> ctx.map { input.choices.last().message.content }
}
}
val toolUse: Node<GigaResponse.Chat, GigaRequest.Chat> = Node("toolUse") { ctx ->
val fnCallMessages = fnCallMessages(ctx)
val history = ArrayList(ctx.history).apply { addAll(fnCallMessages) }
ctx.map(history = history) { ctx.toGigaRequest(history) }
}
private suspend fun fnCallMessages(ctx: AgentContext<GigaResponse.Chat>): List<GigaRequest.Message> {
val fnCallMessages = (ctx.input as GigaResponse.Chat.Ok).choices.mapNotNull { choice ->
val msg = choice.message
if (msg.functionCall != null && msg.functionsStateId != null) {
executeTool(ctx.settings, msg.functionCall)
} else null
}
return fnCallMessages
}
private suspend fun executeTool(
settings: AgentSettings,
functionCall: GigaResponse.FunctionCall,
): GigaRequest.Message {
val tools = settings.tools
val fn: GigaToolSetup = tools[functionCall.name] ?: return GigaRequest.Message(
GigaMessageRole.function, """{"result":"no such function ${functionCall.name}"}"""
)
return fn.invoke(functionCall)
}
}
fun <T> AgentContext<T>.toGigaRequest(history: List<GigaRequest.Message>): GigaRequest.Chat {
val ctx = this
return GigaRequest.Chat(
model = ctx.settings.model,
messages = history,
functions = ctx.tools,
)
}
class NodesLLM(llmApi: GigaChatAPI) {
val chat: Node<GigaRequest.Chat, GigaResponse.Chat> = Node("llmCall") { ctx ->
val response = withContext(Dispatchers.IO) {
llmApi.message(ctx.input)
}
val history = ArrayList(ctx.history).apply {
if (response is GigaResponse.Chat.Ok) {
addAll(response.choices.mapNotNull { it.toMessage() })
}
}
ctx.map(history = history) { response }
}
/**
* Restores the last message, and a system prompt. Other messages are transformed into TLDR
*/
val summarize: Node<GigaResponse.Chat, GigaResponse.Chat> = Node("llmSummarize") { ctx ->
val conversation = ArrayList(ctx.history)
val summaryResponse: GigaResponse.Chat = withContext(Dispatchers.IO) {
conversation.add(GigaRequest.Message(
role = GigaMessageRole.user,
content = "Резюмируй разговор",
))
val request = ctx.toGigaRequest(conversation)
.copy(functions = emptyList())
llmApi.message(request)
}
val msg: GigaRequest.Message = when (summaryResponse) {
is GigaResponse.Chat.Error -> {
throw IOException(summaryResponse.message)
}
is GigaResponse.Chat.Ok -> summaryResponse.choices.mapNotNull { it.toMessage() }.last()
}
val newHistory = listOf(ctx.systemPrompt.toSystemPromptMessage(), ctx.history.last(), msg)
ctx.map(history = newHistory) { summaryResponse }
}
private fun GigaResponse.Choice.toMessage(): GigaRequest.Message? {
val msg = this.message
val content: String = when {
msg.content.isNotBlank() -> msg.content
msg.functionCall != null -> gigaJsonMapper.writeValueAsString(
mapOf("name" to msg.functionCall.name, "arguments" to msg.functionCall.arguments)
)
else -> return null
}
return GigaRequest.Message(
role = msg.role,
content = content,
functionsStateId = msg.functionsStateId
)
}
}
И, наконец, агент с вызовом тулов, суммаризацией истории и хранением контекста (истории) будет выглядеть так:
import com.dumch.agent.engine.*
import com.dumch.agent.node.NodesCommon
import com.dumch.agent.node.NodesLLM
import com.dumch.giga.*
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.async
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import java.util.concurrent.atomic.AtomicReference
import kotlin.coroutines.cancellation.CancellationException
import kotlin.math.ceil
class GraphBasedAgent(
private val model: String,
private val llmApi: GigaChatAPI,
private val tools: Map<String, GigaToolSetup> = GigaAgent.tools
) {
private val nodesLLM = NodesLLM(llmApi)
// Make sure summarization only happens after all tool requests from LLM are answered
private val nodeSummarize: Node<GigaResponse.Chat, String> by graph(name = "Go to user") {
nodeInput.edgeTo { ctx -> if (ctx.historyIsTooBig()) nodesLLM.summarize else NodesCommon.respToString }
nodesLLM.summarize.edgeTo(NodesCommon.respToString)
NodesCommon.respToString.edgeTo(nodeFinish)
}
private val settings = AgentSettings(
model = model,
temperature = 0.7f,
tools = tools
)
private val allFunctions: List<GigaRequest.Function> = settings.tools.values.map { it.fn }
private val initialCtx = AgentContext(
input = "",
settings = settings,
history = emptyList(),
tools = allFunctions,
systemPrompt = SYSTEM_PROMPT
)
private val _ctx: MutableStateFlow<AgentContext<String>> = MutableStateFlow(initialCtx)
val currentContext: StateFlow<AgentContext<String>> = _ctx
private val runningJob = AtomicReference<Deferred<*>>()
fun cancelActiveJob() {
runningJob.get()?.cancel(CancellationException("Cleared by force"))
}
/** Execute one job at a time */
suspend fun execute(input: String): String {
cancelActiveJob()
val ctx = currentContext.value.copy(input = input)
val result: Deferred<AgentContext<String>> = coroutineScope {
async { buildGraph().start(ctx) { _, _, _ -> } }
}
runningJob.set(result)
val newContext = result.await()
_ctx.emit(newContext)
return newContext.input
}
private fun buildGraph(): Graph<String, String> = buildGraph(name = "Agent") {
nodeInput.edgeTo(NodesCommon.stringToReq)
NodesCommon.stringToReq.edgeTo(nodesLLM.chat)
nodesLLM.chat.edgeTo { ctx ->
when (val output = ctx.input) {
is GigaResponse.Chat.Error -> nodeSummarize
is GigaResponse.Chat.Ok -> if (isToolUse(output)) NodesCommon.toolUse else nodeSummarize
}
}
NodesCommon.toolUse.edgeTo(nodesLLM.chat)
nodeSummarize.edgeTo(nodeFinish)
}
private fun isToolUse(input: GigaResponse.Chat.Ok): Boolean = input.choices.any { it.message.functionCall != null }
private fun AgentContext<GigaResponse.Chat>.historyIsTooBig(
threshold: Double = HISTORY_SUMMARIZE_THRESHOLD,
): Boolean {
val model = GigaModel.entries.firstOrNull { it.alias == settings.model }
val contextWindow = model?.maxTokens ?: MAX_TOKENS
val estimatedTokens = systemPrompt.estimateTokenCount() +
history.sumOf { it.content.estimateTokenCount() }
return estimatedTokens >= contextWindow * threshold
}
private fun String.estimateTokenCount(): Int = ceil(length / APPROX_CHARS_PER_TOKEN).toInt()
}
private const val HISTORY_SUMMARIZE_THRESHOLD = 0.8
private const val APPROX_CHARS_PER_TOKEN = 4.0
private val SYSTEM_PROMPT = """
Ты программист-помощник, 10 лет пишешь код на Kotlin, Android и Backend. Стараешься писать простой и поддерживаемый код.
""".trimIndent()
Использование:
private const val AGENT_ALIAS = "?"
suspend fun main() {
val agent = GraphBasedAgent(
model = GigaModel.Max.alias,
llmApi = GigaChatAPI(GigaAuth),
)
userInputFlow().collect { text ->
val result = agent.execute(text)
println(AGENT_ALIAS + result)
}
}
private fun userInputFlow(): Flow = flow {
println("Type `exit` to quit")
while (true) {
print("> ")
val input = readlnOrNull() ?: break
if (input.lowercase() == "exit") break
emit(input)
println("\n")
}
}
Изменения по сравнению с версией агента из предыдущей статьи — то есть реализация на основе графа и базовые Nodes — собраны в PR.
Если нужен openai вместо gigacode, можно взять openai-kotlin. В предыдущей статье писал, как легко адаптировать anthropic через их sdk. Библиотеки «композируются», в отличие от фреймворков.
Frameworks do not compose. — Tomas Petricek, article.
Добавление RAG
Обычно под RAG имеется в виду следующий алгоритм:
Запрос к API, чтобы перевести текст в вектор (например, запрос пользователя) .
Поиск по векторной базе данных похожих текстов (например, по документам компании).
Прикрепление похожих текстов к промпту.
На хабре есть статья с формальными определениями и примерами.
RAG можно найти в документации Koog в подкатегории Advanced Usage. И я уверен, что с Koog задача действительно потребует advanced-усилий, ведь не понятно, что они используют под капотом, есть ли там кеши, retry, какая база данных будет использоваться, можно ли не тащить ненужные зависимости в проект, будут ли они добавлять промпты, чтобы захачить свои проблемы. На всё есть ответы, но с этим надо разбираться (об этих проблемах с примерами и ссылками писал в этой же статье выше).
Давайте реализуем RAG в рамках имеющегося решения. Абстракции, относящиеся к агенту, трогать не будем — всё решим на уровне Node.
Нам понадобится ручка с модельками для перевода текстов в вектора. Реализуем на доступном всем Гигачат.
Ручка и модельки
object GigaResponse {
// ... предыдущий код
data class Embeddings(
val data: List<Embedding>,
val model: String,
@JsonProperty("object") val objectType: String,
)
data class Embedding(
val embedding: List<Double>,
val index: Int,
@JsonProperty("object") val objectType: String? = null,
)
}
object GigaRequest {
// ... предыдущий код
data class Embeddings(
val model: String = "Embeddings",
val input: List<String>,
)
}
class GigaChatAPI(private val auth: GigaAuth) {
// ... предыдущий код
suspend fun embeddings(body: GigaRequest.Embeddings): GigaResponse.Embeddings {
val response = client.post("https://gigachat.devices.sberbank.ru/api/v1/embeddings") {
setBody(body)
}
return when {
response.status.isSuccess() -> response.body<GigaResponse.Embeddings>()
response.status == HttpStatusCode.Unauthorized || response.status == HttpStatusCode.Forbidden -> TODO("Auth exception")
else -> TODO("unexpected error")
}
}
}Добавляем векторную базу и наивную реализацию:
implementation("org.apache.lucene:lucene-core:9.9.2")
Можно было бы решить и плагином для SQL, но для целей статьи так быстрее:
Обертка над векторной базой
object VectorDB {
private const val INDEX_PATH = "build/rag_index"
init {
val isInitialized = File(INDEX_PATH).exists() // naive way to check initialization
if (!isInitialized) {
val dir = FSDirectory.open(Paths.get(INDEX_PATH))
IndexWriter(dir, IndexWriterConfig()).use { }
}
}
fun insert(data: List<String>, embeddings: List<List<Double>>) {
val dir = FSDirectory.open(Paths.get(INDEX_PATH))
IndexWriter(dir, IndexWriterConfig()).use { writer ->
data.indices.forEach { idx ->
val doc = Document()
doc.add(StoredField("text", data[idx]))
doc.add(KnnFloatVectorField("embedding", toFloatArray(embeddings[idx])))
writer.addDocument(doc)
}
}
}
fun getAllTexts(): List<String> {
val dir = FSDirectory.open(Paths.get(INDEX_PATH))
DirectoryReader.open(dir).use { reader ->
val list = mutableListOf<String>()
for (i in 0 until reader.maxDoc()) {
val doc = reader.document(i)
doc.get("text")?.let { list.add(it) }
}
return list
}
}
fun searchSimilar(embedding: List<Double>, limit: Int = 5): List<String> {
val dir = FSDirectory.open(Paths.get(INDEX_PATH))
DirectoryReader.open(dir).use { reader ->
val searcher = IndexSearcher(reader)
val query = KnnFloatVectorQuery("embedding", toFloatArray(embedding), limit)
val topDocs = searcher.search(query, limit)
val texts = mutableListOf<String>()
topDocs.scoreDocs.forEach { sd ->
searcher.doc(sd.doc).get("text")?.let { texts.add(it) }
}
return texts
}
}
fun clearAllData() {
val dir = FSDirectory.open(Paths.get(INDEX_PATH))
IndexWriter(dir, IndexWriterConfig()).use { writer ->
writer.deleteAll()
}
}
private fun toFloatArray(list: List<Double>): FloatArray {
val size = min(list.size, MAX_DIM)
val arr = FloatArray(size)
for (i in 0 until size) {
arr[i] = list[i].toFloat()
}
return arr
}
private const val MAX_DIM = 1024
}Имеющихся реализаций хватит, чтобы пощупать RAG руками:
suspend fun main() {
val vectorDb = VectorDB
// Настройка базы
vectorDb.clearAllData() // осторожнее с последующими запусками, тут — чистка
val api = GigaChatAPI(GigaAuth)
val knownFacts = listOf(
"RAG is an AI technique that combines a search engine with a large language model (LLM) — Google AI overview",
"Perhaps the biggest and the most obvious problem with frameworks is that they cannot be composed. — Tomas Petricek",
"Use frameworks only for applications with a short development lifespan, and avoid frameworks for systems you intend to keep for multiple years. — Mathias Verraes",
"Inversion of control is a common feature of frameworks, but it's something that comes at a price. " +
"It tends to be hard to understand and leads to problems when you are trying to debug. " +
"So on the whole I prefer to avoid it unless I need it. — Martin Fowler"
)
val factsEmbeddings = api.embeddings(GigaRequest.Embeddings(input = knownFacts))
vectorDb.insert(knownFacts, factsEmbeddings.data.map { it.embedding })
// Использование базы для поиска схожих строк
val input = "Фреймворк — хорошо или плохо? Есть ли причины не использовать фреймворки?"
val embedding = api.embeddings(GigaRequest.Embeddings(input = listOf(input)))
val result = vectorDb.searchSimilar(embedding.data.first().embedding, limit = 3)
// Ожидаю, что напечатаются 3 цитаты про фреймворки.
println(result.joinToString(prefix = "Found:\n", separator = "\n"))
}
Теперь в классе, где мы описываем граф, можно добавить Node:
class GraphBasedAgent(...) {
private val nodeAppendAdditionalData: Node = Node("appendActualInformation") { ctx ->
val additionalMessage = appendActualInformation(ctx.input)
if (additionalMessage == null) {
ctx
} else {
val history = ArrayList(ctx.history).apply { add(additionalMessage) }
ctx.map(history = history)
}
}
private fun buildGraph(): Graph = buildGraph(name = "Agent") {
nodeInput.edgeTo(nodeAppendAdditionalData)
nodeAppendAdditionalData.edgeTo(NodesCommon.stringToReq)
NodesCommon.stringToReq.edgeTo(nodesLLM.chat)
...
}
private suspend fun appendActualInformation(userText: String): GigaRequest.Message? {
if (userText.isBlank()) return null
val embedding = llmApi.embeddings(GigaRequest.Embeddings(input = listOf(userText)))
val result = VectorDB.searchSimilar(embedding.data.first().embedding, limit = 2)
.joinToString(
prefix = "Найденные в локальном хранилище данные:\n",
separator = "\n",
)
return GigaRequest.Message(
role = GigaMessageRole.user,
content = result,
)
}
}
Вот и весь RAG. Для удобства вынес в PR на гитхабе. В production-ready приложении добавятся обработка ошибок, ретраи, и, может быть, слой с репозиторием.
Когда использовать фреймворк, а не самописное решение?
Когда иначе невозможно. Примеры: разработка мобильных приложений.
Когда фреймворк становится де-факто стандартом. Пример: Spring для бэкенда на Java.
Когда приложение короткоживущее. Пример: MVP, POC, ~разовый аутсорс.
Use a framework for applications with a short expected development lifespan. — Mathias Verraes, X.
В заключение
Весь код, необходимый для того, чтобы переписать агента с циклов (первая статья) на графы — в PR на гитхабе. Примерно такой же код я использую в двух других проектах. Отличия — в деталях, которые опустил в статье для упрощения материала. Читателю не составит труда адаптировать код или даже написать свою реализацию.
Надеюсь, кому-то статья окажется полезной. Обратная связь и критика приветствуются.