
352 points ferriswil | 1 comments
jart ◴[] No.41893178[source]
It's a very crude approximation, e.g. 1.75 * 2.5 == 3 (although it seems better as the numbers get closer to 0).

I tried implementing this for AVX512 with tinyBLAS in llamafile.

    // Approximate the element-wise product by working on the float bit fields
    // directly: xor the signs, add the exponents, and average the mantissas.
    // (Requires <immintrin.h> and AVX-512F.)
    inline __m512 lmul512(__m512 x, __m512 y) {
        __m512i sign_mask = _mm512_set1_epi32(0x80000000);
        __m512i exp_mask = _mm512_set1_epi32(0x7F800000);
        __m512i mant_mask = _mm512_set1_epi32(0x007FFFFF);
        __m512i exp_bias = _mm512_set1_epi32(127);
        // pull apart the sign, exponent, and mantissa fields of both inputs
        __m512i x_bits = _mm512_castps_si512(x);
        __m512i y_bits = _mm512_castps_si512(y);
        __m512i sign_x = _mm512_and_si512(x_bits, sign_mask);
        __m512i sign_y = _mm512_and_si512(y_bits, sign_mask);
        __m512i exp_x = _mm512_srli_epi32(_mm512_and_si512(x_bits, exp_mask), 23);
        __m512i exp_y = _mm512_srli_epi32(_mm512_and_si512(y_bits, exp_mask), 23);
        __m512i mant_x = _mm512_and_si512(x_bits, mant_mask);
        __m512i mant_y = _mm512_and_si512(y_bits, mant_mask);
        // sign of the product is the xor of the signs
        __m512i sign_result = _mm512_xor_si512(sign_x, sign_y);
        // add the biased exponents and remove one copy of the bias
        __m512i exp_result = _mm512_sub_epi32(_mm512_add_epi32(exp_x, exp_y), exp_bias);
        // approximate the product mantissa as the average of the two mantissas
        __m512i mant_result = _mm512_srli_epi32(_mm512_add_epi32(mant_x, mant_y), 1);
        // reassemble the fields into a float
        __m512i result_bits = _mm512_or_si512(
            _mm512_or_si512(sign_result, _mm512_slli_epi32(exp_result, 23)), mant_result);
        return _mm512_castsi512_ps(result_bits);
    }
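
For reference, a scalar sketch of the same bit trick (the lmul1 helper below is hypothetical, not something in llamafile) shows where the 3 above comes from:

    // Scalar model of lmul512: xor signs, add exponents, average mantissas.
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    static float lmul1(float x, float y) {   // hypothetical helper, for illustration only
        uint32_t xb, yb;
        std::memcpy(&xb, &x, sizeof xb);
        std::memcpy(&yb, &y, sizeof yb);
        uint32_t sign = (xb ^ yb) & 0x80000000u;
        uint32_t expo = ((xb >> 23) & 0xFFu) + ((yb >> 23) & 0xFFu) - 127u;
        uint32_t mant = ((xb & 0x7FFFFFu) + (yb & 0x7FFFFFu)) >> 1;
        uint32_t rb = sign | (expo << 23) | mant;
        float r;
        std::memcpy(&r, &rb, sizeof r);
        return r;
    }

    int main() {
        // 1.75 = (1+0.75)*2^0 and 2.5 = (1+0.25)*2^1: the exponents add to 2^1 and
        // the mantissas average to 0.5, so the approximation returns 1.5*2 = 3.
        std::printf("%g\n", lmul1(1.75f, 2.5f));
    }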
Then I used it for Llama-3.2-3B-Instruct.F16.gguf and it output gibberish. So you would probably have to design and train your model specifically to use this multiplication approximation in order for it to work. Or maybe I'd have to tune the model so that only certain layers and/or operations use the approximation. However, the speed was decent: prefill only dropped from 850 tokens per second to 200 tok/sec on my Threadripper, and prediction speed was totally unaffected, staying at 34 tok/sec. I like how the code above generates vpternlog ops. So if anyone ever designs an LLM architecture and releases weights on Hugging Face that use this algorithm, we'll be able to run them reasonably fast without special hardware.
replies(1): >>41893810 #
raluk ◴[] No.41893810[source]
Your kernel seems to be incorrect for 1.75 * 2.5. From the paper we have 1.75 == (1+0.75)*2^0 and 2.5 == (1+0.25)*2^1, so the result is (1+0.75+0.25+2^-4)*2^1 == 4.125 (the correct result is 4.375).
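
For concreteness, a quick numeric check of that formula in plain doubles (variable names made up for this comment), next to what the kernel's mantissa averaging gives:

    // The paper's rule (add the mantissa fractions plus 2^-4) vs. the kernel's
    // mantissa average, for 1.75 * 2.5.
    #include <cstdio>

    int main() {
        double xm = 0.75, ym = 0.25;                         // mantissa fractions
        int e = 0 + 1;                                       // exponents add to 2^1
        double paper   = (1 + xm + ym + 1.0 / 16) * (1 << e); // 4.125
        double average = (1 + (xm + ym) / 2) * (1 << e);       // 3, what lmul512 computes
        std::printf("%g %g %g\n", paper, average, 1.75 * 2.5); // 4.125 3 4.375
    }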
replies(1): >>41894250 #
raluk ◴[] No.41894250[source]
Extra: I am not sure if this is clear from the paper, but in the example of 1.75 * 2.5 we can also represent 1.75 as (1-0.125) * 2. This gives good approximations for numbers that are close to but less than a power of 2. This way abs(a*b) in (1+a)*(1+b) is always small and strictly less than 0.25.

Another example: if we have 1.9 * 1.9, then we need to account for overflow in (0.9 + 0.9), and this seems to induce similar overhead as expressing the numbers as (1-0.05)*2.
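
To put rough numbers on the size of the dropped a*b term under each representation (plain arithmetic, nothing beyond the examples above):

    // Magnitude of the dropped a*b term in (1+a)*(1+b)*2^e for the two examples.
    #include <cstdio>
    #include <cmath>

    int main() {
        // 1.75 * 2.5: direct mantissas (0.75, 0.25) vs. 1.75 re-expressed as (1 - 0.125) * 2
        std::printf("%g %g\n", std::fabs(0.75 * 0.25),     // 0.1875
                               std::fabs(-0.125 * 0.25));  // 0.03125
        // 1.9 * 1.9: 0.9 + 0.9 overflows the linear term and drops a cross term of 0.81;
        // with 1.9 = (1 - 0.05) * 2 both terms stay tiny
        std::printf("%g %g\n", std::fabs(0.9 * 0.9),       // 0.81
                               std::fabs(-0.05 * -0.05));  // 0.0025
    }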