Проверка небольших чисел на простоту - популярная подзадача в спортивном программировании. И тест Миллера-Рабина, пожалуй, наиболее популярный из простых алгоритмов для этого.

У меня давно было желание с ним поиграться, стараясь оптимизировать различными способами. Например, векторизовать и посмотреть, станет ли быстрее.

Дисклеймер. Я - Java разработчик, на C++ писал в основном только в университете. Так что местами мой код мог вдруг получиться каким-нибудь странным.
Сразу отмечу, что у меня не было цели написать портируемый код. Моей целью было развлечься и посмотреть, что я смогу сделать, а не предложить миру универсальное решение.

Какую именно задачу решаем

Цель: уметь понимать, является ли знаковое 32-битное целое число простым, и желательно быстро.

Оптимальным решением по скорости проверки в данном случае было заранее вычислить все простые числа, соответствующие нашим требованиям, и положить их в массив/хеш-таблицу, чтобы потом делать там поиск. Всего их вышло бы около 105 миллионов. Этого я делать не буду, чтобы избежать долгих предварительных вычислений.

Вообще, хочется иметь аналог такой функции, только быстрый:

bool testNaive(int32_t n) {
    if (n <= 2) return n == 2;

    for (int32_t m = 2, boundary = sqrt(n); m <= boundary; m++) {
        if (n % m == 0) return false;
    }

    return true;
}

В коде выше приведён классический тест на простоту, который перебирает потенциальных кандидатов в делители числа n. Работает он заO(\sqrt n), то есть медленно.

Тест Миллера-Рабина

Тест Миллера-Рабина - вероятностный тест для проверки чисел на простоту. Но заранее зная верхнюю границу для проверяемого числа, этот тест несложно превратить в детерминированный, об этом чуть позже.

Тест простоты числа n состоит из нескольких раундов. На каждом раунде мы получаем один из двух ответов:

  • число составное;

  • число может быть простое, фиг знает.

Чем больше раундов мы получаем ответ "может быть простое", тем больше шанс, что число действительно простое.

Сам же раунд состоит из следующих действий:

  • взять число 1 < a < n, да такое, которое ещё не использовалось в предыдущих раундах;

  • проверить, является ли a "свидетелем простоты" числа n. Если не является, то n - составное, а если является, то n называют сильно псевдопростым по модулю a.

Далее я буду использовать имена переменных, как на википедии, чтобы было проще сверять. Проверка свидетельства простоты в тесте Миллера-Рабина происходит по следующему алгоритму:

  • В первую очередь нужно отсеять чётные числа, допустимы только нечётные n.

  • n нужно представить в виде 2^s\cdot{t}+1. Стоит заметить, что s всегда больше 0.

  • Выполнить следующий код. Внешне он отличается от псевдокода с википедии, но делает ровно то же самое. true означает сильную псевдопростоту n по модулю a:

int32_t s = __builtin_ctz(n - 1); // Число нулей в младших битах.
int32_t t = (n - 1) >> s;

int32_t x = pow_mod(a, t, n);     // Возведение в степень по модулю n.
if (x == 1) return true;

for (int i = 1; x != n - 1; i++) {
    if (i == s) return false;

    x = mul_mod(x, x, n);         // Возведение в квадрат по модулю n.

    if (x == 1) return false;
}

return true;

По какому принципу выбирать числа a? Лично я поступлю очень просто, хотя не факт что эффективно. На википедии есть ссылки на списки сильно псевдопростых чисел по модулю 2 и по модулю 3, так что мы могли бы сделать проверку для 2 и 3, и среди чисел, прошедших обе проверки, отсеять те, которые содержатся в обоих списках...

Но нет, на самом деле таких чисел довольно таки много. А вот если взять 2, 3 и 5, то в диапазоне от 1 до 231 останется только 4 сильно псевдопростых числа:

  • 25326001

  • 161304001

  • 960946321

  • 1157839381

Их уже легко можно захардкодить, получив в итоге следующий код:

bool testMillerRabinInteger(int32_t n) {
    if ((n & 1) == 0) return n == 2;
    if (n < 9) return n > 1; // 3, 5 и 7.

    int32_t s = __builtin_ctz(n - 1);
    int32_t t = (n - 1) >> s;

    int32_t primes[3] = {2, 3, 5};

    for (int32_t a : primes) {
        int32_t x = pow_mod(a, t, n);

        if (x == 1) continue;

        for (int i = 1; x != n - 1; i++) {
            if (i == s) return false;

            x = mul_mod(x, x, n);

            if (x == 1) return false;
        }
    }

    switch (n) {
        case 25326001:
        case 161304001:
        case 960946321:
        case 1157839381:
            return false;

        default:
            return true;
    }
}

Это уже полностью детерминированный тест на простоту, бери и пользуйся! Осталось расписать упомянутые тут mul_mod и pow_mod.

Для произведения по модулю достаточно расширить как множители, так и модуль, до 64-битных целых, чтобы защититься от переполнения в произведении, получив
(int32_t) (((int64_t) a * b) % m).

Для возведения в степень будем использовать разновидность быстрого возведения в степень, чуть более простую в реализации:

int32_t pow_mod(int64_t n, int32_t power, int64_t m) {
    int64_t result = 1;

    while (power) {
        if (power & 1) result = (result * n) % m;

        n = (n * n) % m;

        power >>= 1;
    }

    return (int32_t) result;
}

Включаем -O3, тестируем - всё работает! Рассчитываем на то, что компилятор хорошо оптимизировал этот код и самим ничего доделывать не придётся (если бы...).

База готова, вступление на этом закончено. Перейдём ко вкусненькому.

Что такое векторизация

Векторизация кода - это его параллелизация на уровне данных. Это примерно то, что делают видеокарты - один поток исполнения, но много независимых наборов данных, обрабатываемых одновременно.

Для CPU векторизация реализуется специальными наборами инструкций, например SSE и AVX. Обычно программистам не стоит об этом беспокоиться - всю работу на себя должен брать компилятор, главное указать ему ключи оптимизации. Основная цель компиляторной векторизации - циклы по большим массивам данных.

В нашем случае больших циклов нет, а данных мало, но зато есть тело цикла, которое выполняется до 3-х раз для 3-х различных значений параметра a. Такие паттерны компилятор сам векторизовать не будет, даже не надейтесь. Моя цель - убрать этот цикл, сделав "обычный" код, проводящий вычисления над всеми нужными значениями a одновременно.

Итого, нужно решить 2 подзадачи:

  • возведение в степень по модулю;

  • цикл с возведениями в квадрат.

Поскольку в среднем s сильно меньше t (s=O(\log n),\:t = O(n)), то можно сказать, что возведение в степень должно занимать гораздо больше времени исполнения, а значит сосредоточиться в первую очередь стоит на нём.

Векторные типы и инструкции

В более-менее современных CPU архитектуры x86 есть 2 набора векторных регистров - 128-битные XMM и 256-битные YMM. Бывают ещё 512-битные регистры в AVX-512, но они не так распространены.

Напрямую с регистрами мы работать не будем. C++ - язык высокоуровневый, в нём есть специальные типы данных для векторов, как и библиотечные функции на этих типах, которые соответствуют инструкциям процессора. Такие функции называются интринсиками. Чтобы почитать о них подробнее, придётся включать VPN, ведь документация запрещена на территории РФ (самим Intel, никакого экстремизма).

В рамках данной статьи я буду компилироваться GCC с 17-й версией языка (если это вдруг важно). Подключив immintrin.h, можно получить доступ ко множеству типов __m128* и __m256*, отвечающим различным типам векторов, например:

  • __m128i - это вектор из 4-х значений типа int32_t;

  • __m256i - это вектор из 4-х значений типа int64_t;

  • __m256d - это вектор из 4-х значений типа double.

Здесь видно первое преимущество векторов перед нашим циклом - они способны одновременно обрабатывать не 3 единицы данных, а 4. Это здорово, мы без каких-либо потерь можем добавить 4-е значение a, например, проверяя 2, 3, 5 и 7. Если снова заглянуть в OEIS, то будет видно, что все сильно псевдопростые числа по этим модулям, вплоть до 2147483648 (231), являются истинно простыми, т.е. тот switch в конце больше не будет нужен!

Для векторизованного возведения в степень нам нужно найти функции произведения и взятия остатка от деления, всё должно быть просто. Помним, что обе эти операции выполнялись над 64-битными целыми, а значит работать будем с __m256i.

И тут мы попадаем в первую серьёзную ловушку:

  • _mm256_mullo_epi64 - интринсик для умножения длинных целых, не компилируется:
    error: inlining failed in call to 'always_inline' '__m256i _mm256_mullo_epi64(__m256i, __m256i)': target specific option mismatch.

    Упс, кажется он доступен только для AVX-512. Мой процессор такого не умеет. Так что не удивительно, да и компилировался я только с -mavx -mavx2.

  • Векторизованного вычисления остатка от деления не существует в принципе. А обычное деление длинных целых, называемое _mm256_div_epi64, которое мне выдала документация - это даже не инструкция CPU, а кастомная функция, вообще отсутствующая в GCC. Придётся импровизировать.

Я не первый человек, который задался вопросом векторизованного деления целых чисел. Универсальный ответ, который дают люди из интернета - пользуйтесь делением с плавающей точкой, пока вам хватает точности. А когда перестанет хватать - напишите дополнительный код исправления ошибки. Так что отвлечёмся на минутку и подумаем, сколько точности нам даст double, если решить использовать именно его.

Числа с плавающей точкой

Как устроены числа типа double? Наверное не все хорошо помнят, так что освежим эту информацию. Любое такое число представлено в виде 3-х частей:

  • 1 бит для знака;

  • 52 бита для "мантиссы";

  • 11 бит для "порядка", или же смещения плавающей точки.

double - это 53 значащих бита (мантисса + всегда присутствующая рядом с ней единица) и известное положение плавающей точки, которое может быть далеко за пределами значащих бит. Если пробовать кодировать только целые числа, то становится видно, что любое 53-битное беззнаковое число (54-битное знаковое соотв.) можно закодировать без потери точности.

64-битный int64_t мы использовали для того, чтобы не было переполнения при умножении 32-битных целых. А значит double можно использовать для того, чтобы избежать переполнения при умножении 26-битных чисел (53 / 2). Если ограничиться проверкой на простоту только 26-битных чисел, то этого будет легко достичь арифметикой с плавающей точкой. Так и поступим. Будем работать с четвёрками значений типа double, то есть с __m256d.

Возведение в степень по модулю

Проблема с умножением решена. Что же с делением с остатком? Для него нам придётся воспользоваться следующей формулой:

a\pmod{m} = a - \lfloor \frac{a}{m} \rfloor \cdot m

Реализовав эту формулу, мы получим следующую функцию вычисления произведения по модулю, сразу для 4-х значений типа double (в комментариях - невекторизованный аналог):

__m256d mul_mod(__m256d a, __m256d b, __m256d m) {
    __m256d c = _mm256_mul_pd(a, b);     // double c = a * b;

    __m256d tmp = _mm256_div_pd(c, m);   // double tmp = c / m;
    tmp = _mm256_floor_pd(tmp);          // tmp = floor(tmp);
    tmp = _mm256_mul_pd(tmp, m);         // tmp = tmp * m;

    return _mm256_sub_pd(c, tmp);        // return c - tmp;
}

Код возведения в степень же почти не будет отличаться от невекторизованного варианта:

__m256d pow_mod(__m256d n, int32_t power, __m256d m) {
    __m256d result = _mm256_set1_pd(1);  // result = {1, 1, 1, 1};

    while (power) {
        if (power & 1) result = mul_mod(result, n, m);

        n = mul_mod(n, n, m);

        power >>= 1;
    }

    return result;
}

Цикл с квадратами

Если попробовать сейчас собрать воедино то, что есть, то получим следующий код:

bool testMillerRabinVectorized0(int32_t n) {
    if ((n & 1) == 0) return n == 2;
    if (n < 9) return n > 1;

    int32_t s = __builtin_ctz(n - 1);
    int32_t t = (n - 1) >> s;

    __m256d primes = _mm256_set_pd(2, 3, 5, 7);

    __m256d n_pd = _mm256_set1_pd(n);
    __m256d x = pow_mod(primes, t, n_pd);

    // ?
}

На месте вопроса раньше стоял цикл, внутри которого во первых встречаются выражения return / continue, а во вторых для разных a там может быть разное число итераций:

if (x == 1) continue;

for (int i = 1; x != n - 1; i++) {
    if (i == s) return false;

    x = mul_mod(x, x, n);

    if (x == 1) return false;
}

Нужно привести этот код к виду, пригодному для векторизации. Парадоксально, но для этого нам потребуется выбросить условие выхода при x == 1 внутри цикла, оставим только выход по числу итераций и по x == n - 1. Сейчас объясню.

Сделаем так: будем присваивать x = 0, если это псеводпростое по модулю a. Если x станет равным 0 для всех a, то число n - простое, иначе - нет. Опишем это следующим псевдокодом:

if (x == 1)     x = 0;
if (x == n - 1) x = 0;

if (x == 0 forall a)     // x - сильно псевдопростое по всем модулям a.
    return true;

for (int i = 1; i < s; i++) {
    x = mul_mod(x, x, n);

    if (x == n - 1) x = 0;

    if (x == 0 forall a) // x - сильно псевдопростое по всем модулям a.
        return true;
}

return false;

Если на какой-то итерации x стал равным 1, то он останется 1 и на следующих итерациях, ведь 1 * 1 = 1. То же самое с 0, так как 0 * 0 = 0. В итоге return false на последней строке сработает, если:

  • для какого-то a мы получили x == 1;

  • для какого-то a мы получили i == s.

Это именно те условия выхода, которые нам нужны. Векторизованная версия этого кода будет выглядеть так:

__m256d n_minus_one = _mm256_set1_pd(n - 1);

x = blend_zero(x, _mm256_set1_pd(1)); // if (x == 1)     x = 0;
x = blend_zero(x, n_minus_one);       // if (x == n - 1) x = 0;

if (all_zero(x)) return true;

for (int i = 1; i < s; i++) {
    x = mul_mod(x, x, n_pd);

    x = blend_zero(x, n_minus_one);   // if (x == n - 1) x = 0;

    if (all_zero(x)) return true;
}

return false;

Что делает all_zero должно быть понятно из названия. blend_zero же делает покомпонентное сравнение элементов первого вектора с элементами второго вектора и проставляет в первом 0 там, где совпало, иначе оставляет оригинальное значение, где не совпало. Реализуются эти методы следующим образом:

const __m256d ZERO = _mm256_setzero_pd();

bool all_zero(__m256d a) {
    __m256d mask_pd = _mm256_cmp_pd(a, ZERO, _CMP_NEQ_OQ);

    return 0 == _mm256_movemask_pd(mask_pd);
}

__m256d blend_zero(__m256d a, __m256d b) {
    __m256d mask_pd = _mm256_cmp_pd(a, b, _CMP_EQ_OQ);

    return _mm256_blendv_pd(a, ZERO, mask_pd);
}

Маска типа __m256d содержит единичные биты (64 штуки, на весь double) в тех элементах, для которых условие в cmp выполнилось (в частности _CMP_NEQ_OQ или _CMP_EQ_OQ).

movemask преобразует вектор-маску в обычную битовую маску, где каждому элементу уже будет соответствовать ровно один бит, а не 64. То есть all_zero буквально проверяет, что "нет элементов, не равных 0". Да, двойное отрицание, и что?

_mm256_blendv_pd - составляет вектор из двух, беря элемент второго вектора, если в маске 1, либо первого вектора если в маске 0. Думаю, должно быть понятно, а если непонятно, то в документации есть псевдокод.

Кажется, что первая версия должна быть готова:

bool testMillerRabinVectorized(int32_t n) {
    if ((n & 1) == 0) return n == 2;
    if (n < 9) return n > 1;

    int32_t s = __builtin_ctz(n - 1);
    int32_t t = (n - 1) >> s;

    __m256d primes = _mm256_set_pd(2, 3, 5, 7);

    __m256d n_pd = _mm256_set1_pd(n);
    __m256d x = pow_mod(primes, t, n_pd);

    __m256d n_minus_one = _mm256_set1_pd(n - 1);

    x = blend_zero(x, _mm256_set1_pd(1));
    x = blend_zero(x, n_minus_one);

    if (all_zero(x)) return true;

    for (int i = 1; i < s; i++) {
        x = mul_mod(x, x, n_pd);

        x = blend_zero(x, n_minus_one);

        if (all_zero(x)) return true;
    }

    return false;
}

Сравниваем производительность

Методика тестирования у меня будет максимально тупая. Берём большой массив целых чисел и для каждого из них проверяем, является ли оно простым. У меня вышел следующий код:

void measure(std::function<bool(int32_t)> test, uint32_t count) {
    auto start = std::chrono::high_resolution_clock::now();

    int32_t n = 0;

    for (uint32_t i = 1; i < count; i++) {
        n += test(i);
    }

    auto end = std::chrono::high_resolution_clock::now();

    std::cout << (end - start).count() * 1e-9d << " seconds" << std::endl;
    std::cout << n << " primes found" << std::endl << std::endl;
}

Запустим его для имеющихся реализаций, помня, что верхняя граница у нас - 226, и скрестим пальцы:

std::cout << "testNaive:" << std::endl;
measure(&testNaive, 1 << 26);

std::cout << "testMillerRabinInteger:" << std::endl;
measure(&testMillerRabinInteger, 1 << 26);

std::cout << "testMillerRabinVectorized:" << std::endl;
measure(&testMillerRabinVectorized, 1 << 26);
testNaive:
33.9298 seconds
3957809 primes found

testMillerRabinInteger:
4.77004 seconds
3957809 primes found

testMillerRabinVectorized:
4.74518 seconds
3957809 primes found

Реальность полна разочарований, чуда не произошло. Обе функции работают примерно с одной скоростью, в рамках погрешности. Может быть векторизованная версия чуть-чуть быстрее в среднем, но это неважно.

Потрясающее совпадение! Также это не плохая новость - ручное вычисление остатка от деления, оказывается, не слишком и медленное. Остановлюсь ли я на этом? Конечно же, нет, мне нужно однозначно обогнать невекторизованный код, так что начнём тюнинг.

Суперскалярность

Начну с оптимизации, которая принесёт не самый большой профит, но зато она очень показательна и больше её вставить некуда.

Суперскалярность - это тоже параллелизация на уровне данных, но уже не совсем явная в коде, а неявная в процессоре. Ядро CPU способно выполнять несколько инструкций одновременно, при условии, что они независимы. Так что код, который обрабатывает независимые данные, способен работать быстрее кода, который обрабатывает зависимые данные.

Рассмотрим пример - функцию pow_mod для целых чисел, повторю её код:

int32_t pow_mod(int64_t n, int32_t power, int64_t m) {
    int64_t result = 1;

    while (power) {
        if (power & 1) result = (result * n) % m;

        n = (n * n) % m;

        power >>= 1;
    }

    return (int32_t) result;
}

В этом коде есть независимые вычисления над result и n, но вот беда - они находятся в разных скоупах, поэтому и компилятору, и процессору было бы сложно что-то с этим сделать. Что если скоупы объединить?

int32_t pow_mod(int64_t n, int32_t power, int64_t m) {
    int64_t result = 1;

    while (power) {
        if (power & 1) {
            result = (result * n) % m;
            n = (n * n) % m;
        } else {
            n = (n * n) % m;
        }

        power >>= 1;
    }

    return (int32_t) result;
}

Запустим тест и проверим. Для testMillerRabinInteger до оптимизации было 4.78585 секунд, после стало 4.25464 секунд, то есть заметно быстрее. Хотя, может, дело и не только в суперскалярности, может, бранч-предиктор стал тоже себя иначе вести, кто его знает. Главное, что код стал быстрее.

Так стоп, что я наделал, я же должен был оптимизировать векторизованный код, теперь всё стало только хуже! Быстро делаем то же самое для векторизованной версии pow_mod, потирая руки:

__m256d pow_mod(__m256d n, int32_t power, __m256d m) {
    __m256d result = _mm256_set1_pd(1);

    while (power) {
        if (power & 1) {
            result = mul_mod(result, n, m);
            n = mul_mod(n, n, m);
        } else {
            n = mul_mod(n, n, m);
        }

        power >>= 1;
    }

    return result;
}

И... никакой заметной разницы. Вообще.

Почему? Может, компилятор не заинлайнил вызов mul_mod? Вроде бы нет, если сделать ручной инлайн, то производительность не меняется. Как было примерно 4.7 секунд, так и осталось.

Может, компилятор не захотел переупорядочить инструкции в коде во время компиляции? Проверим, вручную сгруппировав независимые операции:

__m256d pow_mod(__m256d n, int32_t power, __m256d m) {
    __m256d result = _mm256_set1_pd(1);

    while (power) {
        if (power & 1) {
            __m256d c1 = _mm256_mul_pd(result, n);
            __m256d c2 = _mm256_mul_pd(n, n);

            __m256d tmp1 = _mm256_div_pd(c1, m);
            __m256d tmp2 = _mm256_div_pd(c2, m);

            tmp1 = _mm256_floor_pd(tmp1);
            tmp2 = _mm256_floor_pd(tmp2);

            tmp1 = _mm256_mul_pd(tmp1, m);
            tmp2 = _mm256_mul_pd(tmp2, m);

            result = _mm256_sub_pd(c1, tmp1);
            n = _mm256_sub_pd(c2, tmp2);
        } else {
            n = mul_mod(n, n, m);
        }

        power >>= 1;
    }

    return result;
}

Невероятно, теперь я получаю примерно 4.22 - 4.28 секунд, примерно как в суперскалярной версии невекторизованного алгоритма! А значит, что компилятор не всегда способен эффективно переупорядочить инструкции. Особенно если речь идёт об интринсиках, но это так, скорее предположение.

Обе версии алгоритма теперь работают немного быстрее, при этом среди них всё ещё нет явного лидера.

Деление и умножение

Что проще, умножать или делить? Очевидно, что умножать. Но в нашем алгоритме очень много делений, нужно сократить их количество. Для этого вспомним весьма банальное свойство:

\frac{a}{b} = a \cdot \frac{1}{b}

Деление - это умножение на обратное. Если посмотреть на код внимательно, то станет видно, что делим мы всегда на один и тот же вектор, а именно n_pd (внутри дополнительных функций встречается под именем m). Вместо этого предлагаю ввести второе значение:

__m256d n_pd  = _mm256_set1_pd(n);       // m
__m256d n_inv = _mm256_set1_pd(1.0 / n); // m_inv

И каждый раз, когда встречается _mm256_div_pd(c, m), писать вместо этого
_mm256_mul_pd(c, m_inv). Интуиция говорит, что должно стать быстрее, но насколько?

Сделаем и запустим. До исправления - 4.27096 секунды, после - 3.05808. Разница почти 30%, впечатляет! Вот уж действительно, деление - это очень дорого.

А знаете, какой алгоритм тоже использует деление? testMillerRabinInteger. Может если и в нём сделать умножение, то и он станет быстрее?

double mod_double(double n, double m, double m_inv) {
    double tmp = n * m_inv;
    tmp = floor(tmp);
    return n - tmp * m;
}

double pow_mod_double(double n, int32_t power, double m, double m_inv) {
    double result = 1;

    while (power) {
        if (power & 1) result = mod_double(result * n, m, m_inv);

        n = mod_double(n * n, m, m_inv);

        power >>= 1;
    }

    return result;
}

bool testMillerRabinDouble(int32_t n) {
    if ((n & 1) == 0) return n == 2;
    if (n < 9) return n > 1;

    int32_t s = __builtin_ctz(n - 1);
    int32_t t = (n - 1) >> s;

    double m = n,
           m_inv = 1.0 / m;

    int32_t primes[3] = {2, 3, 5};

    for (int32_t a : primes) {
        double x = pow_mod_double(a, t, m, m_inv);

        if (x == 1.0) continue;

        for (int i = 1; x != n - 1; i++) {
            if (i == s) return false;

            x = mod_double(x * x, m, m_inv);

            if (x == 1.0) return false;
        }
    }

    switch (n) {
        case 25326001:    // Остальные убрал, потому что они больше чем 1<<26.
            return false;

        default:
            return true;
    }
}

Тестируем:

testMillerRabinInteger:
4.2803 seconds
3957809 primes found

testMillerRabinDouble:
4.00873 seconds
3957809 primes found

Вот это поворот, вышло-то быстрее и для невекторизованного кода! То есть, это насколько целочисленные * и % долгие, если их можно заменить вещественными *, *, floor, ещё раз * и -, и всё равно окажется эффективнее.

У меня есть гипотеза. При умножении 64-битных целых чисел процессор фактически возвращает нам 128-битное произведение, разбитое на 2 регистра (старшая и младшая половины), один из которых нам не нужен. Аналогично с делением: процессор в один регистр кладёт частное, а в другой - остаток. Он просто делает много ненужной нам побочной работы. Уверен, есть и другие причины.

Итого, имеет векторизованный код, который проверяет "все" числа примерно за 3 секунды, и невекторизованный код, который делает то же самое примерно за 4 секунды. Кажется, что это победа. Для 26-битных чисел.

Как умножать если double недостаточно

Что делать для чисел, у которых от 27 до 31 значащих бит? Можно ли векторизовать их проверку на простоту, не прибегая к AVX-512? Конечно можно, но сразу скажу, что на эффективность этого решения рассчитывать не стоит. Тем не менее, я его реализую, чтобы хотя-бы узнать, насколько плохо всё выйдет.

Вся задача сводится к тому, чтобы реализовать функцию mul_mod, которая может перемножать 31-битные числа по модулю другого 31-битного числа, т.е. перестать ограничиваться 26-ю битами, как в текущей реализации:

__m256d mul_mod(__m256d a, __m256d b, __m256d m, __m256d m_inv) {
    __m256d c = _mm256_mul_pd(a, b);

    __m256d tmp = _mm256_mul_pd(c, m_inv);
    tmp = _mm256_floor_pd(tmp);
    tmp = _mm256_mul_pd(tmp, m);

    return _mm256_sub_pd(c, tmp);
}

Перед тем, как продолжить, я хочу немного отвлечься, вспомнив одну из своих любимых подзадач в спортивном программировании. Она здесь важна, правда:

Даны 2 числа a и b типа int64_t, нужно найти их произведение по модулю 1012. Желательно эффективно.

Всем желающим предлагаю отвлечься и подумать - как бы вы написали подобный алгоритм? На всякий случай напомню, что верхняя граница знаковых 64-битных чисел равна примерно 9*1018, этот факт легко забыть.

Я бы решал следующим образом. Без ограничения общности допустим, что и a, и b меньше 1012 (и неотрицательные). Если это не так, то вместо них можно взять соответствующие остатки от деления. В таком случае значение b всегда можно представить в виде
b = b1 * 1000000 + b2, где b1 < 1000000 и b2 < 1000000.

Чем хорош миллион, так это тем, что при перемножении числа <1012 и числа <106 мы получаем число <1018, то есть помещающееся в знаковое 64-битное. Такие произведения можно делать в коде без всяких опасений, главное, правильным образом преобразовать формулу для a * b:

(a \cdot b) \pmod{10^{12}}\\ = (a \cdot (b_1\cdot 10^6 + b_2)) \pmod{10^{12}}\\ = (a \cdot b_1 \cdot 10^6) \pmod{10^{12}} + (a \cdot b_2) \pmod{10^{12}} \\ = ((a\cdot b_1) \pmod{10^{12}} \cdot 10^6)\pmod{10^{12}} + (a \cdot b_2) \pmod{10^{12}} \\ = ((a\cdot b_1) \pmod{10^{12}} \cdot 10^6 + (a \cdot b_2)) \pmod{10^{12}}

В этой формуле ни одно из произведений не переполняет int64_t, а значит она годится для реализации в коде, и потребует это всего лишь нескольких строк.

Какое отношение это имеет к решаемой задаче? Прямое. Нам аналогично нужно перемножить два 32-битных числа, но для вычисления произведения есть всего-лишь 53 бита. И я предлагаю разбить один из множителей (b) на пару b_hi и b_lo - старшие и младшие 16 бит числа, если представлять его его в виде int32_t. А вместо миллиона тогда будет константа 1 << 16. В этом случае каждое произведение будет не более, чем 48-битным, то есть без потерь может быть вычислено с помощью произведения в double.

Полноценная векторизация

Введём вспомогательную функцию mod, тело которой нам уже должно быть знакомо:

__m256d mod(__m256d a, __m256d m, __m256d m_inv) {
    __m256d tmp = _mm256_mul_pd(a, m_inv);
    tmp = _mm256_floor_pd(tmp);
    tmp = _mm256_mul_pd(tmp, m);

    return _mm256_sub_pd(a, tmp);
}

Здесь всё то же самое, что раньше было в mul_mod, кроме собственно произведения. Теперь задача простая - как раз написать mul_mod, используя ранее выведенную формулу.

Смотрите внимательно, надеюсь не запутаетесь:

// Маска для извлечения 2-х младших байт для 32-битных целых.
const __m128i MASK = _mm_set1_epi32((1 << 16) - 1);

// Вектор, который содержит значения "2 в 16-й степени".
const __m256d K = _mm256_set1_pd(1 << 16);

__m256d mul_mod(__m256d a, __m256d b, __m256d m, __m256d m_inv) {
    // Конвертируем вектор [double] в вектор [int32_t].
    __m128i b_epi32 = _mm256_cvtpd_epi32(b);
    // Извлекаем старшие 2 байта через сдвиг вправо на 16 бит.
    __m128i b_epi32_hi = _mm_srli_epi32(b_epi32, 16);
    // Извлекаем младшие 2 байта через конъюнкцию с маской.
    __m128i b_epi32_lo = _mm_and_si128(b_epi32, MASK);

    // Конвертируем вектора [int32_t] обратно в [double].
    __m256d b_hi = _mm256_cvtepi32_pd(b_epi32_hi);
    __m256d b_lo = _mm256_cvtepi32_pd(b_epi32_lo);

    // tmp1 = ((a * b_hi) % m) * K;
    __m256d tmp1 = _mm256_mul_pd(a, b_hi);
    tmp1 = mod(tmp1, m, m_inv);
    tmp1 = _mm256_mul_pd(tmp1, K);

    // tmp2 = a * b_lo;
    __m256d tmp2 = _mm256_mul_pd(a, b_lo);

    // return (tmp1 + tmp2) % m;
    tmp2 = _mm256_add_pd(tmp1, tmp2);
    return mod(tmp2, m, m_inv);
}

Далее, при желании, этот код можно инлайнить куда нужно и производить в нём ручные суперскалярные оптимизации. Большого эффекта они не принесут, зато код превратят в ещё большую лапшу, так что приводить я его не буду.

Время финального теста. На этот раз я выставлю верхнюю границу проверяемого интервала в 1u << 31, чтобы проверить вообще все 31-битные целые. Вот сколько времени это заняло:

testMillerRabinInteger:
160.792 seconds
105097565 primes found

testMillerRabinVectorized:
408.16 seconds
105097565 primes found

В 2.5 раза дольше невекторизованного целочисленного алгоритма. Довольно таки позорно, но хотя-бы посчитанное количество простых чисел совпало, а значит код скорее всего верный.

Выводы

Если очень нужно, то из ручной векторизации кода можно выжать более высокую производительность. Но шансы этого становятся тем ниже, чем больше количество целочисленной арифметики в вашем коде. Всё-таки векторизация рассчитана, в первую очередь, либо на целочисленную агрегацию 32-битных чисел, либо на арифметику с плавающей точкой, и использование её не совсем по назначению (как делаю я) приведёт скорее к страданиям.

Также интересным оказалось то, насколько оператор % тормозной. То есть я и так знал, что он медленный, но не осознавал масштабов. The more you know.

Комментарии (10)


  1. ky0
    21.04.2024 17:09
    +1

    Круто! Вам бы в репозиторий prime95 зайти - может, мы сидим уже который год и бездарно тратим лишние мегаватты.


    1. ibessonov Автор
      21.04.2024 17:09

      Спасибо! Загляну, узнаю, что за репозиторий


  1. encyclopedist
    21.04.2024 17:09
    +2

    Если вы многократно делите на одно и то же число, то можно серьёзно ускорить деление без перехода на числа с плавающё точкой. Деление заменяется на умножение и сдвиги., а взятие остатка можно сделать сразу, не вычисляя частное и затем вычитая. См. https://github.com/lemire/fastmod и https://lemire.me/blog/2019/02/08/faster-remainders-when-the-divisor-is-a-constant-beating-compilers-and-libdivide/


    1. ibessonov Автор
      21.04.2024 17:09

      Спасибо за ссылку! Я видел информацию о подобных алгоритмах, и мне они не подходят, поскольку в данном случае делитель - не константный, а зависит от тестируемого числа n.


      1. encyclopedist
        21.04.2024 17:09
        +1

        С помощью первой функции вы вычисляете множитель на основе делителя, это нужно сделать один раз для каждого n:

        uint64_t M = computeM_u32(d);
        

        B затем для каждого взятия остатка вы вызываете вторую функцию:

        fastmod_u32(a,M,d);
        


        1. ibessonov Автор
          21.04.2024 17:09
          +1

          Это здорово, нужно будет попробовать, главное сперва решить проблему 64-битного умножения. Спасибо!


  1. gchebanov
    21.04.2024 17:09
    +1

    О, я примерно тем же путём шел когда решал https://highload.fun/tasks/10

    Правда у меня в итоге векторизованная где-то в 2 раза быстрее не векторизованной, но и другие трюки использовал. Обязательно поделитесь результатом.


    1. ibessonov Автор
      21.04.2024 17:09
      +1

      С удовольствием бы посмотрел на ваш код! Мне, как видите, воображения не хватило.
      Хотя тут, наверное, можно применить и другой подход - одновременно проверять 4/8 различных чисел из входного потока. Даже не представляю, чем пользовались люди из топа таблицы лидеров.
      EDIT: невекторизованная целочисленная версия - Your best score: 248,606, это 23-е место. Если поднажать, наверное смогу стать 22-м =)
      EDIT2: интересно, какая часть этого времени ушла на fread, невыгодно же по одной записи дёргать наверное. Это вы мне интересную задачу подкинули...


      1. gchebanov
        21.04.2024 17:09
        +1

        Да, проверяю по 4 числа одновременно. Главное заранее отсекать большую часть работы (5/6) проверкой на делимость на маленькие делители, довольно красиво сделано, но видимо можно сделать лучше, я использовал для больших libdivide_u32_do_vec256, а для маленьких (2,3,5,7) сообразил трюк с _mm256_shuffle_epi8. Тоже в восторге от результатов первой тройки. Ах да, еще нужно чтение через mmap организовать, в справке от сайта есть пример.


  1. YouDontKnowMe
    21.04.2024 17:09

    Я понимаю, что речь в основном шла на векторизацию, но можно было бы упомянуть о быстром делении через сдвиги или оптимизацию монтгомери, ну или о великой оптимизации с использованием хэш-таблиц, чтобы вызывать лишь 1 свидетеля