diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 590c1f6ac..a22da55a3 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1617,6 +1617,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( sumf_m += d8i * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values } + return d*sumf_d - dmin*sumf_m; + #else uint16_t aux16[2]; @@ -1626,7 +1628,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( aux16[0] = a[0] & 0x0f0f; aux16[1] = (a[0] >> 4) & 0x0f0f; - const float d = bq4_K->d[0]; + const float dall = bq4_K->d[0]; const float dmin = bq4_K->d[1]; const float d8_1 = bq8_1[0].d; @@ -1649,9 +1651,10 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]); sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]); + return dall * sumf_d - dmin * sumf_m; + #endif - return d*sumf_d - dmin*sumf_m; #else return 0.0f; // only to satisfy the compiler #endif // __CUDA_ARCH__ >= MIN_CC_DP4A @@ -1743,10 +1746,11 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( const int im = step/8; // = 0 for iqs = 0, 1, = 1 for iqs = 2, 3 const int in = step%8; // 0, 4, 0, 4 const int vh = (*((const int *)(bq5_K->qh + in))) >> im; - const int v1 = __vsub4(((vh << 4) & 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f), 0x10101010); - const int v2 = __vsub4(((vh << 2) & 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f), 0x10101010); - const int v3 = __vsub4(((vh >> 0) & 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f), 0x10101010); - const int v4 = __vsub4(((vh >> 2) & 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f), 0x10101010); + + const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f); + const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f); + const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f); + const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f); const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1]) + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]);