Привет! Меня зовут Дарина, и я занимаюсь фундаментальными исследованиями в MTS AI. Основной фокус нашей работы сейчас — обучение больших языковых моделей, их тестирование и оптимизация.
Сегодня хочу сделать обзор на недавно вышедшую статью LLEMMA: an open language model for mathematics. Расскажу про обучение модели, новый датасет Proof-Pile-2 и в конце сравню ее с ChatGPT и GPT-4 на ЕГЭ заданиях по профильной математике.
Введение
За последнее время было выпущено много больших языковых моделей, которые умеют поддержать диалог, решить математическую задачку, помочь составить презентацию и т.д. Однако, если обучать или дообучать модель на определенную сферу знаний, то это принесет больше пользы и к тому же на это будет потрачено меньше ресурсов. Например, модель Galactica, обученная на научных данных, превосходит более модели GPT3 и BLOOM. Также недавно выпущенная CodeLlama показывает результаты лучше, чем ее базовая модель Llama 2.
Авторы статьи решили создать открытую языковую модель Llemma, умеющую решать математические задачи. Также они собрали датасет Proof-Pile-2, на котором обучались. Модели, код для обучения и датасет были выложены в открытый доступ на GitHub и HuggingFace.
Датасет
В качестве основного датасета был сформирован Proof-Pile-2, который состоит из датасетов поменьше:
AlgebraicStack. Авторы создали датасет из 11 миллиардов токенов исходного кода на 17 языках, связанных с математикой. Сюда входят языки программирования для доказательств теорем: Lean, Isabelle, Coq и другие. Из популярных языков программирования включены Python, Matlab, C и C++. Данные были взяты и отфильтрованы из the Stack, публичных гитхаб-репозиториев с помощью GitHub API. Отдельно данные для Lean и Isabelle были получены, соответственно, из Mathlib, архива формальных доказательств и стандартной библиотеки Isabelle.
OpenWebMath. Датасет, состоящий из 15 миллиардов токенов включает в себя веб-страницы с математическим контентом.
Научные статьи из arXiv. Часть датасета RedPajama, состоящая из 29 миллиардов токенов. RedPajama - это датасет, воспроизводящий датасет для обучения Llama.
Proof-Pile-2 актуален на апрель 2023 года.
Также помимо Proof-Pile-2 модель была обучена на датасете the Pile и подмножестве GitHub из датасета RedPajama.
Модель и обучение
Каждая модель Llemma была инициализирована с CodeLlama - decoder-only трансформер, обученный на 500B токенов кода. Авторы продолжили обучение две CodeLLama модели 7B и 34B.
Llemma 7B обучалась на 200B токенах 23 тысячи A100-часов. Llemma 34B обучалась на 50B токенах 47 тысяч A100-часов. Обе модели имеют контекст 4096 токенов.
Во время обучения использовались tensor и data parallelism, а также ZeRO Stage 1. Не буду вдаваться в подробности что это такое, но советую почитать эту статью для понимания. Не обошлось также без FlashAttention2 и RoPE.
Оценка
Chain-of-thought промптинг
Сначала авторы статьи решили оценить способность модели решать математические задачи используя chain of thought reasoning. CoT - это промпт, который включает в себя обоснование данного ответа. Например, как показано на картинке ниже.
Для оценки использовались следующие датасеты:
MATH - датасет, включающий в себя 12.5 тысяч задач из соревнований по математике среди старших школ. Модели подается проблема, а ее ответ генерируется в виде Latex решения. В статье использовали 4-shot промптинг.
GSM8k - датасет из 8.5 тысячи математических задач уровня средней школы, написанными людьми. Оценка проводилась с помощью 8-shot промптинга.
OCWCourses - коллекция задач уровня бакалавриата полученных из OpenCourseWare от MIT.
MMLU-STEM - подмножества 18 предметных областей из 57 бенчмарка MMLU. Использовался 4-shot промптинг.
SAT - созданный авторами статьи, датасет, состоящий из 32 математических вопросов.
Целью было оценить Llemma как базовую модель, сравнивая ее с подобными, которые не дообучались (fine-tuning) на математических данных. В качестве главного конкурента была выбрана Minerva от Google Research, которая продолжила обучение PaLM. Самым главным преимуществом Llemma является ее открытость. Minerva же закрытая модель с закрытым датасетом.
Дополнительно модель сравнивалась с CodeLlama и Llama 2. В качестве метрики была выбрана точность совпадения строк или его SymPy эквивалента. Ниже можно увидеть результаты.
Также оценка проводилась с помощью majority voting или maj@k. Это способ выбора самого популярного ответа среди k сгенерированных ответов, вместо greedy decoding, который просто выбирает самый вероятный. На рисунке ниже показан пример.
Как можно заметить, это довольно долгий процесс оценки, поэтому авторы решили сделать оценку только для Minerva и Llemma. Для бенчмарка MATH k равнялся 256, а для GSM8k и OCW k = 100. В случае SAT и MATH генерировалось всего 16 сэмплов. После выбора самого популярного ответа, использовался nucleus sampling с p = 0.95. Результаты можно увидеть ниже.
Proof assistant
Proof assistant - это интерактивная программа для доказательств теорем. Обычно такая программа имеет свой собственный язык. Как уже было выше сказано, авторы создали датасет AlgebraicStack, который включает в себя 1.5 миллиарда токенов таких языков и решили проверить свою модель на двух задачах:
Informal-to-formal. Перевод из неформально описанной задачи и ее решения на языке Latex в формальный язык Isabelle. Чтобы оценить правильность, был использован бенчмарк miniF2F.
Formal-to-formal. Генерация продолжения доказательства на основе предыдущих шагов для программы Lean 4. Иными словами, Copilot для этого языка. Также использовался бенчмарк miniF2F.
Результаты получились такими:
Llemma превосходит все сравниваемые модели для бенчмарка miniF2F.
Влияние пропорций данных
Авторы статьи решили поэкспериментировать с пропорциями между AlgebraicStack, OpenWebMath и статьями из arXiv. Пропорции данных подразумевают, что мы каким-то датасетам отдаем больше предпочтения и обучаем c большим количеством шагов. Такие пропорции перебирались руками, и проводилось короткое обучение модели. Тестировался этот эксперимент на датасете MATH по темам математики, а в качестве метрики взяли перплексию. Ниже можно увидеть результаты.
Лучшим соотношение оказалось 2:4:1.
Тестирование
Я решила проверить Llemma 7B на 21 задании ЕГЭ профильной математики. С помощью Яндекс Переводчика я перевела вариант на английский и глазами проверила ответы. Вариант на английском лежит тут. Я использовала 1-shot промптинг, взяв только первую задачу из промпта для бенчмарка MATH, используемого в статье. Из-за вычислительных ограничений я не тестировала Llemma 34B.
Для сравнения я также протестировала на этих заданиях GPT-4 и ChatGPT (gpt-3.5-turbo). Получились такие результаты:
Модель |
Кол-во верных ответов |
Llemma 7B |
6 |
ChatGPT |
11 |
GPT-4 |
12 |
Перебор параметров генерации Llemma для достижения лучшего результата занял довольно долгое время. Она повторялась, пыталась генерировать сама вопросы и ответы на них, ошибалась в знаках. В итоге у меня так и не получилось найти золотую середину, где недостатков совсем не было. Я остановилась на параметрах repetition_penalty=1.1 и temperature=0.2.
GPT-3 и ChatGPT справились без дополнительного промпта, не повторялись и подробно объясняли свой ответ.
Заключение
В этом материале я рассмотрела модели Llemma 7B и 34B, продолжающие обучение CodeLlama на математических данных. Рассказала, как обучались модели и какие данные для этого использовались. Были рассмотрены методы оценки модели на основных математических бенчмарках, а также как модель справляется с генерацией кода для математических доказательств.
Дополнительно я провела свою оценку на 21 задаче из ЕГЭ по профильной математике, сравнивая модель с ChatGPT и GPT-4. В результате, Llemma показала худшие результаты. В защиту хотелось бы сказать, что тестирование проводилось на маленькой модели. Также цель авторов статьи заключалась в создании открытой базовой модели, которую лучше дообучить на определенной сфере математики.
rPman
Когда проводится оценка модели, считают процент успешных ответов, но делают ли что-нибудь с теми ответами, которые не верные. Проводится ли анализ, закономерности и т.п. Можно ли сделать какие то предположения о том что вот такая то модель плохо работает вот с такими разделами математики, или когда вопрос формулируют в таком стиле и т.п.?
А если тот же вопрос задать другой модели, той же мощности? А если задать тот же вопрос этой же модели несколько раз (с учетом рандомизации генерации ответа, т.е. temp>0)? Играясь с llama, меняя немного запросы (не принципиально по смыслу но меняя формулировки) очень часто можно добиться верного ответа. Может нужно изучить условия, при которых ответ становится лучше?
p.s. даже в лучшем случае по бенчмаркам выходит что получить верный ответ чуть чуть больше чем в половине случаев... грустно
darinka666 Автор
Авторы статьи сделали такую оценку только для экспериментов по изменению пропорций данных. Каким-то образом упустила этот абзац, сейчас добавила в статью.
Было бы неплохо, конечно, узнать все результаты по группам, чтобы сделать выводы и предположения, вы правы.
Если под мощностью имеется в виду кол-во параметров, то Llemma 7B сравнивалась с CodeLlama 7b и Minerva 8B.
В главе "Оценка" как раз рассказан метод majority voting, когда генерируется несколько ответов, а далее выбирается самый популярный.