Branchless max()

November 2, 2020
cpp

Recently a coworker told me about a question he’d gotten in an interview: write max() without branches.

Branchless programming strikes me as a similar “game” to code golf (writing a program in as few lines or characters as possible), though occasionally it has real utility. For example, if the branchless form is a transformation too non-obvious for the compiler to find on its own, a hand-written version could outperform a “branchy” implementation on a heavily pipelined CPU, where a mispredicted branch stalls the pipeline.

“Branchy” max

Let’s start with the trivial, “branchy” implementation:

int branchy_max(int a, int b) {
    return a > b ? a : b;
}

Nothing to it, right?

An approximate branchless implementation

For simple enough programs that have conditional computations, the general strategy of branchless programming is to create a value that encodes the condition as either a 1 or 0, then multiply it with the input values in a way that computes the result in both cases of the condition.
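As a minimal sketch of that strategy, a condition encoded as 0 or 1 can pick between two values with plain arithmetic. The helper name select01 is hypothetical and only here for illustration:

```cpp
// Hypothetical helper: returns x when cond is 1 and y when cond is 0.
// Assumes cond is exactly 0 or 1; any other value gives a meaningless result.
constexpr int select01(int cond, int x, int y) {
    return x * cond + y * (1 - cond);
}

// Both products are always computed; the condition only zeroes one of them.
static_assert(select01(1, 7, -3) == 7, "cond == 1 picks x");
static_assert(select01(0, 7, -3) == -3, "cond == 0 picks y");
```

Both cases are evaluated every time, which is exactly what removes the need for a jump.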

The condition in this case is a > b. Notice that a - b is negative exactly when b is greater than a, and zero or positive otherwise. We can detect which it is by shifting the sign bit of the difference down and masking it; that bit is set only if the difference is negative:

auto difference = a - b;
constexpr auto sign_bit = sizeof(decltype(difference)) * 8 - 1;
int b_greater_than_a = (difference >> sign_bit) & 1;

Now we need to find an expression that evaluates to b if the difference is negative, and a if it is positive.

Subtracting the conditional difference from a achieves this: a - (a - b) * b_greater_than_a. If the difference is negative, we have a - (a - b) * 1 == b; otherwise we have a - (a - b) * 0 == a.

Putting it together:

int branchless_max(int a, int b) {
    const auto difference = a - b;
    constexpr auto sign_bit = sizeof(decltype(difference)) * 8 - 1;
    const int b_greater_than_a = (difference >> sign_bit) & 1;
    return a - difference * b_greater_than_a;
}

Accounting for overflow & underflow

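As a thought exercise, this is good enough. But it’s not quite correct across its whole input range: a - b could underflow, say if a is negative and b is a large positive number, or conversely overflow if b is a large negative number and a is positive. (Strictly speaking, signed integer overflow is undefined behavior in C++; the rest of this section assumes the difference simply wraps around, which the standard does not guarantee.)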
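To make the failure concrete, here is a small sketch that computes the sign of the wrapped difference in unsigned arithmetic, where wrap-around is well defined. The helper wrapped_difference_is_negative and the constants int_max and int_min are made-up names for illustration:

```cpp
#include <limits>

// Hypothetical helper: the sign bit of a - b after wrapping to the width of
// int. The subtraction is done on unsigned values so the wrap is well defined.
constexpr bool wrapped_difference_is_negative(int a, int b) {
    const unsigned wrapped = static_cast<unsigned>(a) - static_cast<unsigned>(b);
    return (wrapped >> (std::numeric_limits<unsigned>::digits - 1)) != 0;
}

constexpr int int_max = std::numeric_limits<int>::max();
constexpr int int_min = std::numeric_limits<int>::min();

// No wrap: 3 - 9 is negative, so b is correctly seen as the larger value.
static_assert(wrapped_difference_is_negative(3, 9), "ordinary case");
// Underflow: -2 < int_max, yet the wrapped difference looks non-negative.
static_assert(!wrapped_difference_is_negative(-2, int_max), "underflow flips the sign");
// Overflow: 1 > int_min, yet the wrapped difference looks negative.
static_assert(wrapped_difference_is_negative(1, int_min), "overflow flips the sign");
```

In the last two cases the sign test alone would pick the smaller input.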

We can detect an underflow when a is negative and b is not, yet the difference comes out non-negative; similarly, an overflow happened if and only if a is non-negative and b is negative, yet the difference comes out negative. Then we can adjust the result to force it to the true max input:

int branchless_max_overunder(int a, int b) {
    // Helper to avoid excessive use of '>>'
    constexpr auto is_negative = [](auto v) {
        constexpr auto sign_bit = sizeof(decltype(v)) * 8 - 1;
        return (v >> sign_bit) & 1;
    };

    auto difference = a - b;
    int b_greater_than_a = is_negative(difference);

    int underflow = is_negative(a) & !is_negative(b) & !b_greater_than_a;
    int overflow = (!is_negative(a)) & is_negative(b) & b_greater_than_a;
    int underflow_or_overflow = underflow | overflow;

    return (a - difference * b_greater_than_a) * !underflow_or_overflow
           + underflow * b + overflow * a;
}
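Because signed overflow is undefined behavior in C++, a compiler is not obliged to produce the wrapped difference that the version above inspects. The following is a sketch of one possible variation, not the implementation above: it keeps the same sign tests but does the subtraction on unsigned values, and it selects the result with two products instead of adjusting a by the difference. The name branchless_max_wrapping is hypothetical.

```cpp
#include <limits>

// Hypothetical variant of branchless_max_overunder: same sign tests, but the
// subtraction is done on unsigned values, where wrap-around is well defined.
constexpr int branchless_max_wrapping(int a, int b) {
    constexpr int top_bit = std::numeric_limits<unsigned>::digits - 1;

    const unsigned ua = static_cast<unsigned>(a);
    const unsigned ub = static_cast<unsigned>(b);
    const unsigned difference = ua - ub;  // wraps modulo 2^N instead of overflowing

    // Each flag is 0 or 1: the top bit of the corresponding value.
    const unsigned a_negative = ua >> top_bit;
    const unsigned b_negative = ub >> top_bit;
    const unsigned difference_negative = difference >> top_bit;

    // The subtraction wrapped if the operands' signs differ and the
    // difference ended up with the sign of b rather than the sign of a.
    const unsigned underflow = a_negative & (b_negative ^ 1u) & (difference_negative ^ 1u);
    const unsigned overflow = (a_negative ^ 1u) & b_negative & difference_negative;

    // A wrapped difference has the wrong sign, so flip the flag in that case.
    const int b_greater_than_a =
        static_cast<int>(difference_negative ^ (underflow | overflow));

    // Exactly one of the two products is non-zero, so the sum cannot overflow.
    return a * (1 - b_greater_than_a) + b * b_greater_than_a;
}

static_assert(branchless_max_wrapping(3, 9) == 9, "no wrap");
static_assert(branchless_max_wrapping(-2, std::numeric_limits<int>::max())
                  == std::numeric_limits<int>::max(), "underflow case");
static_assert(branchless_max_wrapping(1, std::numeric_limits<int>::min()) == 1,
              "overflow case");
```

Since the function is constexpr and the checks are static_asserts, the compiler would reject the example if any of those evaluations hit undefined behavior.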

Genericizing & C++20

As a bonus, we can now generalize the parameters and return type to any signed integral types using C++20 concepts:

#include <concepts>

template <std::signed_integral T, std::signed_integral U>
auto generic_branchless_max(T a, U b) {
    // Helper to avoid excessive use of '>>'
    constexpr auto is_negative = [](auto v) {
        constexpr auto sign_bit = sizeof(decltype(v)) * 8 - 1;
        return (v >> sign_bit) & 1;
    };

    auto difference = a - b;
    int b_greater_than_a = is_negative(difference);

    int underflow = is_negative(a) & !is_negative(b) & !b_greater_than_a;
    int overflow = (!is_negative(a)) & is_negative(b) & b_greater_than_a;
    int underflow_or_overflow = underflow | overflow;

    return (a - difference * b_greater_than_a) * !underflow_or_overflow
           + underflow * b + overflow * a;
}

The auto return type is useful here because it allows un-bundling the two arguments into separate template parameter types, so that things like generic_branchless_max(2LL, 2) work, with the return type doing the necessary promotion and sign-extension.
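The promotion mentioned here comes from the usual arithmetic conversions applied to expressions like a - b. As a rough illustration, the hypothetical alias difference_t below names the type of that subtraction for a given pair of operand types; the deduced return type of generic_branchless_max should match it, since every term in the return expression is converted the same way:

```cpp
#include <type_traits>
#include <utility>

// Hypothetical alias: the type of a - b for operands of types T and U.
template <class T, class U>
using difference_t = decltype(std::declval<T>() - std::declval<U>());

// Mixed widths: the narrower operand is converted to the wider type.
static_assert(std::is_same<difference_t<long long, int>, long long>::value,
              "int is converted to long long");
// Types narrower than int are promoted to int before the subtraction.
static_assert(std::is_same<difference_t<short, short>, int>::value,
              "short operands are promoted to int");
```

One consequence is that calling the function with two short arguments would give back an int rather than a short.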

Results

I absolutely do not recommend writing code like this. It’s convoluted and will almost certainly compile to worse code than the obvious version. This is what our original branchless implementation compiles to under LLVM’s -O3:

branchless_max(int, int):                   # @branchless_max(int, int)
        mov     eax, edi
        mov     ecx, edi
        sub     ecx, esi
        mov     edx, ecx
        sar     edx, 31
        and     edx, ecx
        sub     eax, edx
        ret
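Read back into C++, that assembly corresponds roughly to the mask formulation below: sar builds a mask that is all ones when the difference is negative, and replaces the multiplication by 0 or 1, then sub applies the masked difference to a. This is one reading of the instructions, not compiler output. The name masked_max is hypothetical, and the sketch assumes that right-shifting a negative int is an arithmetic shift (guaranteed since C++20, implementation-defined before that).

```cpp
#include <climits>

// Hypothetical C++ rendering of the sar/and/sub sequence.
// Like the original, it ignores overflow in a - b.
constexpr int masked_max(int a, int b) {
    constexpr int shift = sizeof(int) * CHAR_BIT - 1;
    const int difference = a - b;            // sub ecx, esi
    const int mask = difference >> shift;    // sar edx, 31: -1 if negative, else 0
    return a - (difference & mask);          // and edx, ecx; sub eax, edx
}

static_assert(masked_max(3, 9) == 9, "negative difference: a - (a - b) == b");
static_assert(masked_max(9, 3) == 9, "non-negative difference: a - 0 == a");
```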

The code generated when accounting for overflow and underflow is much worse. This is the instantiation with two 64-bit integer types, but the assembly is identical (up to bit shifts) for parameters of other sizes:

auto generic_branchless_max<long long, long long>(long long, long long):  # @auto generic_branchless_max<long long, long long>(long long, long long)
        mov     rax, rdi
        sub     rax, rsi
        mov     r8, rsi
        not     r8
        and     r8, rdi
        mov     rdx, rdi
        not     rdx
        and     rdx, rsi
        and     rdx, rax
        mov     r9, rax
        sar     r9, 63
        and     r9, rax
        not     rax
        and     rax, r8
        mov     rcx, rax
        or      rcx, rdx
        sar     rdx, 63
        and     rdx, rdi
        sub     rdi, r9
        sar     rcx, 63
        not     rcx
        and     rcx, rdi
        sar     rax, 63
        and     rax, rsi
        add     rax, rdx
        add     rax, rcx
        ret

And here is our original, branchy implementation (a > b ? a : b):

max(int, int):                               # @max(int, int)
        mov     eax, esi
        cmp     edi, esi
        cmovg   eax, edi
        ret

Note that the assembly doesn’t actually contain any branches: cmovg is a conditional move, so it copies a over b in eax only when the preceding cmp found a > b, with no jump involved.