RuGPT3 - коллекция генеративных моделей от Сбер.

Проводим автоматическое тестирование циклическим перебором вариантов.

Работаем в Colab, тестируем Small, Mediub, Large.

Параметры генерации совершенно неоптимизированы - это первый заход, чтобы посмотреть исходную ситуацию и сравнивать по мере улучшения.

Обработка реплик тоже не проводится. Ни абзацы, ни предложения - ничего, только удаляются знаки переноса строки, если их много подряд.

Алгоритм тестирования

полный последовательный перебор
10 вопросов
3 модели: Small, Medium, Large
Несколько наборов параметров внутри каждой модели
Реплики сохраняем в гугл-таблице

Параметры

max_length
Максимальная длина
В тестовых примерах применяется набор max_length = [20, 50, 100, 200]

Greedy search
Аргмаксная генерация. «Жадная генерация»
Каждый раз выбирается токен, у которого максимальная вероятность.

Beam search
num_beams - кол-во путей с наибольшими неочевидными итоговыми вероятностными сочетаниями
В тестовых примерах num_beams = 5
num_return_sequences - кол-во лучших вариантов генерации на вывод
В тестовых примерах num_return_sequences = 3

Temperature sampling
«Температура»
Можно трактовать как случайный выбор с учетом распределения вероятностей.
Чем ближе к нулю, тем больше похоже на Greedy Search или «жадную генерацию».
В тестовых примерах применяется набор temperature = [0.8, 1.3, 2.0]

Nucleus sampling

Для запрета сэмплирования совсем некорректных токенов вводятcя top-k или top-p ограничения.

В этом случае генерация тоже происходит случайным образом, но заранее отсекаются все маловероятные токены.

top_k = n - определяется n слов, которые обладают наибольшей вероятностью из условного распределения вероятностей всех слов, что сужает выбор для модели и отбрасывает максимально неподходящие слова сразу. Это позволяет контролировать разнообразие генерируемого текста и избежать слишком случайного выбора.

top_p = n - определяется n слов, чья вероятностная масса вместе равна n%, то есть ограничения сэмплирования происходят динамически, исходя из начального набора. Другими словами, определяет суммарную вероятность набора наиболее вероятных токенов.

Будут рассмотрены токены, пока их суммарная вероятность не достигнет n. Это также позволяет контролировать разнообразие генерируемого текста и избежать слишком случайного выбора.

В тестовых примерах top_k  =  20, top_p = 0.8

Пишем код и запускаем генерацию

Начало кода python
!pip install transformers==4.24.0
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import requests

Спрашиваем ChatGPT 10 самых распространенных вопросов на школьных экзаменах

10 вопросов
text_massive = [
  'Какова формула для вычисления площади прямоугольника?',
  'Какие основные фазы клеточного деления?',
  'Кто написал роман "Преступление и наказание"?',
  'Какие основные элементы таблицы Mendeleev?',
  'Какая формула используется для вычисления площади треугольника?',
  'Какие основные законы Ньютона?',
  'Какой газ является основным компонентом атмосферы Земли?',
  'Какой орган отвечает за фильтрацию крови в организме человека?',
  'Какое число называется пи?',
  'В каком году родился Александр Сергеевич Пушкин?'
  ]

Устанавливаем данные для гугл-таблицы и пропишем фразы в соответствующих листах

Установки для Google
# идентификатор файла Google
google_id = "1Oa-vu9wBkG7APHzHPidB-5_PANZqSlYt61jUmrIOJwk"

# макрос для записи данных в ячейку Google
url = '_makros_' 

sheets_massive = ['Лист1', 'Лист2','Лист3','Лист4','Лист5','Лист6','Лист7','Лист8','Лист9','Лист10']

range = 'A2'

for index, key in enumerate(text_massive):
  value = key
  google_sheet = sheets_massive[index]
  data = '?id=' + google_id + '&sheet=' + google_sheet + '&range=' + range + '&value=' + value
  response = requests.get(url + data)
  print(google_sheet, range, 'status_code:', response.status_code)

print('Finished')

Указываем модель и столбец

Задание модели и столбца
# model
model_name_or_path = "sberbank-ai/rugpt3small_based_on_gpt2"

# column
columns = 'C'

Запускаем генерацию

Генерация реплик и сохранение данных в таблицу
# automatic circle

# Small

print('Started automatic circle ...')
print()

print(model_name_or_path)
print()

for index0, key in enumerate(text_massive):
  text = key
  google_sheet = sheets_massive[index0]

  print(index0, google_sheet)
  print(key)

  # Токенизация
  input_ids = tokenizer.encode(text, return_tensors="pt").to(DEVICE)
  print(input_ids)
  print()

  # Greedy search

  print()
  print('Greedy search')
  print()

  def out(key):
    out = model.generate(input_ids,
                        do_sample=False,
                        max_length=key)
    return(out)

  generated_text_massive = []
  list_length = [20, 50, 100, 200]

  for key in list(list_length):
    print(key)

    # Декодирование токенов
    generated_text = list(map(tokenizer.decode, out(key)))[0]
    generated_text_massive.append(generated_text)
    print()

  print('Greedy search generation is finished');
  print()

  # Save to Google_sheet

  rows = [6,7,8,9]

  for index, key in enumerate(generated_text_massive):
    range = columns + str(rows[index])
    value = key.replace('\n\n\n','')
    data = '?id=' + google_id + '&sheet=' + google_sheet + '&range=' + range + '&value=' + value
    response = requests.get(url + data)
    print(list_length[index], 'status_code:', response.status_code)

  print('Greedy search total is finished')
  print()

  # Beam search

  print()
  print('Beam search')
  print()

  def out(key):
    out = model.generate(input_ids,
                        do_sample=False,
                        num_beams=5,
                        max_length=key)
    return(out)

  generated_text_massive = []
  list_length = [20, 50, 100, 200]

  for key in list(list_length):
    print(key)

    # Декодирование токенов
    generated_text = list(map(tokenizer.decode, out(key)))[0]
    generated_text_massive.append(generated_text)
    print()

  print('Beam search generation is finished');
  print()

  # Save to Google_sheet

  rows = [12, 13, 14, 15]

  for index, key in enumerate(generated_text_massive):
    range = columns + str(rows[index])
    value = key.replace('\n\n\n','')
    data = '?id=' + google_id + '&sheet=' + google_sheet + '&range=' + range + '&value=' + value
    response = requests.get(url + data)
    print(list_length[index], 'status_code:', response.status_code)

  print('Beam search total is finished')
  print()

  print()
  print('Beam search with num_return_sequences')
  print()

  generated_text_massive = []
  generated_text_massive_list = []
  list_length = [20, 50, 100, 200]

  for key in list(list_length):
    print(key)
    out = model.generate(input_ids,
                        do_sample=False,
                        max_length=key,
                        num_beams=5,
                        num_return_sequences = 3)

    # Декодирование токенов
    generated_text = list(map(tokenizer.decode, out))[0]
    generated_text_massive.append(generated_text)
    generated_text_massive_list.append(list(map(tokenizer.decode, out)))
    print()

  print('Beam search with num_return_sequences generation is finished')
  print()

  # Save to Google_sheet

  rows = [17, 18, 19]

  for index, key in enumerate(generated_text_massive_list[0]):
    range = columns + str(rows[index])
    value = key.replace('\n\n\n','')
    data = '?id=' + google_id + '&sheet=' + google_sheet + '&range=' + range + '&value=' + value
    response = requests.get(url + data)
    print('status_code:', response.status_code)

  rows = [21, 22, 23]

  for index, key in enumerate(generated_text_massive_list[1]):
    range = columns + str(rows[index])
    value = key.replace('\n\n\n','')
    data = '?id=' + google_id + '&sheet=' + google_sheet + '&range=' + range + '&value=' + value
    response = requests.get(url + data)
    print('status_code:', response.status_code)

  rows = [25, 26, 27]

  for index, key in enumerate(generated_text_massive_list[2]):
    range = columns + str(rows[index])
    value = key.replace('\n\n\n','')
    data = '?id=' + google_id + '&sheet=' + google_sheet + '&range=' + range + '&value=' + value
    response = requests.get(url + data)
    print('status_code:', response.status_code)

  rows = [29, 30, 31]

  for index, key in enumerate(generated_text_massive_list[3]):
    range = columns + str(rows[index])
    value = key.replace('\n\n\n','')
    data = '?id=' + google_id + '&sheet=' + google_sheet + '&range=' + range + '&value=' + value
    response = requests.get(url + data)
    print('status_code:', response.status_code)

  print('Beam search total is finished')
  print()

  # Temperature sampling

  print()
  print('Temperature sampling');
  print()

  def out(key, temperature):
    out = model.generate(input_ids,
                        do_sample=True,
                        temperature=temperature,
                        max_length=key)
    return(out)

  generated_text_massive = []
  list_length = [20, 50, 100, 200]
  temperature_massive = [1.3, 0.8, 2.0]

  for temperature in list(temperature_massive):
    print('temperature:', temperature)

    for key in list(list_length):
      print(key)

      # Декодирование токенов
      generated_text = list(map(tokenizer.decode, out(key, temperature)))[0]
      generated_text_massive.append(generated_text)
      print()

  print('Temperature sampling generation is finished');
  print()

  # Save to Google_sheet

  rows = [35, 36, 37, 38, 41, 42, 43, 44, 47, 48, 49, 50]
  list_length = list_length + list_length + list_length


  for index, key in enumerate(generated_text_massive):
    range = columns + str(rows[index])
    value = key.replace('\n\n\n','')
    data = '?id=' + google_id + '&sheet=' + google_sheet + '&range=' + range + '&value=' + value
    response = requests.get(url + data)
    print(list_length[index], 'status_code:', response.status_code)

  print('Temperature sampling total is finished')
  print()

  # Nucleus search

  print()
  print('Nucleus search');
  print()

  def out(key):
    out = model.generate(input_ids,
                        do_sample=True,
                        temperature=1.3,
                        top_k=20,
                        top_p=0.8,
                        max_length=key)
    return(out)

  generated_text_massive = []
  list_length = [20, 50, 100, 200]

  for key in list(list_length):
    print(key)

    # Декодирование токенов
    generated_text = list(map(tokenizer.decode, out(key)))[0]
    generated_text_massive.append(generated_text)
    print()

  print('Nucleus search generation is finished');
  print()

  # Save to Google_sheet

  rows = [53, 54, 55, 56]

  for index, key in enumerate(generated_text_massive):
    range = columns + str(rows[index])
    value = key.replace('\n\n\n','')
    data = '?id=' + google_id + '&sheet=' + google_sheet + '&range=' + range + '&value=' + value
    response = requests.get(url + data)
    print(list_length[index], 'status_code:', response.status_code)

  print('Nucleus search total is finished')
  print()
  print('--- --- ---')
  print()

print('Finished')
print()

Конечно, код возможно оптимизировать, обернуть повторы в циклы и функции и прочее. В данном случае не стали на этом фокусироваться, так как вычисления пока не затратные и задача решается с приемлемой скоростью.

Результат

Таблица с репликами выложена здесь.

Понимаем, что есть много чего подбирать, оптимизировать, обрабатывать, добавлять, mask, padding и так далее. В данном случае это принципиально чистый исходный заход для фиксации начального состояния. Развитие ситуации предполагается в следующих статьях.

Примечания

  1. Если хотите протестировать вопросы, модели и параметры - укажите в комментариях.
    Сделаем тестирование и выложим ссылку на таблицу с результатами.
    Желательно формировать блоки по 10 вопросов, чтобы ничего не менять в коде.

  2. Если видите какие-то критические ошибки - пожалуйста, укажите в комментариях.
    Если можете предложить способы улучшения генерации - пожалуйста, укажите в комментариях.

Спасибо за прочтение статьи :-)

Хорошего дня :-)

Комментарии (2)


  1. holodoz
    02.07.2023 06:41
    +8

    Сэкономлю время тем, кто хочет узнать, насколько плохи ответы. Они очень плохи, лежат в диапазоне от

    "Какие основные законы Ньютона?

    В Ньютоне нет законов Ньютона.

    Какие"

    до

    Какие основные законы Ньютона?
    Научный термин Ньютон можно сказать с легкостью - это закон всемирного тяготения. Только, кроме ньютона, в Ньютоном есть и обратный полюс. (Если бы Вы хотели написать, что в принципе Ньютон – это только теория, то мы тут без Вас обойдётесь. )

    Как думаете, куда бы девушка (муж), ушла?:))
    а сколько вам лет
    Мне кажется у нее уже всё к свадьбе готово
    ну... я бы на работу)))
    Да у меня все впереди)))
    а куда?... до свидания
    а зачем вы мне мозги пудрите?....)))))

    Могу ли я найти вторую жизнь?
    на счет двух и той девушки - можно думать, но возможно второй нет - ищите себе другое тело-жизнь одна

    Помогите перевести ((
    В смысле переводим с иностранным?

    Какая разница между мужчиной и женщиной и мужчина мужчина и женщина?)


    1. avdosev
      02.07.2023 06:41

      Не удивлен, недавно сравнивал rugpt3 и gpt3 для своих целей. И даже старенький gpt3 показывает качественно более интересные результаты. Видимо, с ру датасетами в то время были проблемы.