__dp4a -> ggml_cuda_dp4a

Johannes Gäßler 2024-06-30 20:20:54 +02:00
parent 0480dab44a
commit a92595aa93
3 changed files with 96 additions and 82 deletions

View file

@@ -3,6 +3,7 @@
#include "ggml.h"
#include "ggml-cuda.h"
#include <cstdint>
#include <memory>
#if defined(GGML_USE_HIPBLAS)
@@ -280,33 +281,6 @@ static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigne
return c;
}
static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
c = __builtin_amdgcn_sdot4(a, b, c, false);
#elif defined(RDNA3)
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
#elif defined(__gfx1010__) || defined(__gfx900__)
int tmp1;
int tmp2;
asm("\n \
v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
v_add3_u32 %0, %1, %2, %0 \n \
v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
v_add3_u32 %0, %1, %2, %0 \n \
"
: "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
: "v"(a), "v"(b)
);
#else
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
#endif
return c;
}
#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
// __shfl_xor() for half2 was added in ROCm 5.6
static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
@@ -479,6 +453,46 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
}
#endif // CUDART_VERSION < 12000
static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
c = __builtin_amdgcn_sdot4(a, b, c, false);
#elif defined(RDNA3)
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
#elif defined(__gfx1010__) || defined(__gfx900__)
int tmp1;
int tmp2;
asm("\n \
v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
v_add3_u32 %0, %1, %2, %0 \n \
v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
v_add3_u32 %0, %1, %2, %0 \n \
"
: "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
: "v"(a), "v"(b)
);
#else
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
#endif
return c;
#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if __CUDA_ARCH__ >= MIN_CC_DP4A
return __dp4a(a, b, c);
#else // __CUDA_ARCH__ >= MIN_CC_DP4A
const int8_t * a8 = (const int8_t *) &a;
const int8_t * b8 = (const int8_t *) &b;
return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}
// TODO: move to ggml-common.h
static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
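For reference (not part of the diff): on every path, the new ggml_cuda_dp4a helper computes a 4-way dot product of packed signed 8-bit lanes with a 32-bit accumulator. A minimal host-side sketch of those semantics, mirroring the scalar fallback above (the name dp4a_ref and the main() driver are illustrative only):

```cpp
#include <cstdint>
#include <cstdio>

// Scalar reference for ggml_cuda_dp4a: reinterpret a and b as four packed
// int8 lanes, multiply lane-wise, and add the four products to c.
static int dp4a_ref(int a, int b, int c) {
    const int8_t * a8 = (const int8_t *) &a;
    const int8_t * b8 = (const int8_t *) &b;
    return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
}

int main() {
    const int a = 0x01020304; // lanes (4, 3, 2, 1) on a little-endian host
    const int b = 0x01010101; // lanes (1, 1, 1, 1)
    printf("%d\n", dp4a_ref(a, b, 10)); // 10 + 4*1 + 3*1 + 2*1 + 1*1 = 20
    return 0;
}
```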

View file

@@ -72,7 +72,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
const int v = (get_int_from_uint8(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
const int u = Q_q8[k_KQ_0/WARP_SIZE];
const int sumi = __dp4a(v, u, 0);
const int sumi = ggml_cuda_dp4a(v, u, 0);
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
@@ -120,7 +120,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
const int v = (get_int_from_uint8_aligned(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
const int u = Q_q8[k_KQ_0/WARP_SIZE];
const int sumi = __dp4a(v, u, 0);
const int sumi = ggml_cuda_dp4a(v, u, 0);
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
@@ -179,7 +179,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
const int u = Q_q8[k_KQ_0/WARP_SIZE];
const int sumi = __dp4a(v, u, 0);
const int sumi = ggml_cuda_dp4a(v, u, 0);
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
@@ -234,7 +234,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
const int u = Q_q8[k_KQ_0/WARP_SIZE];
const int sumi = __dp4a(v, u, 0);
const int sumi = ggml_cuda_dp4a(v, u, 0);
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {

View file

@@ -60,8 +60,8 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp
const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
// SIMD dot product of quantized values
sumi = __dp4a(vi0, u[2*i+0], sumi);
sumi = __dp4a(vi1, u[2*i+1], sumi);
sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
}
const float2 ds8f = __half22float2(ds8);
@@ -88,8 +88,8 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
// SIMD dot product of quantized values
sumi = __dp4a(vi0, u[2*i+0], sumi);
sumi = __dp4a(vi1, u[2*i+1], sumi);
sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
}
#ifdef GGML_CUDA_F16
@@ -126,14 +126,14 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp
vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
}
const float2 ds8f = __half22float2(ds8);
@@ -161,14 +161,14 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
}
#ifdef GGML_CUDA_F16
@@ -202,7 +202,7 @@ template <typename T, int vdr> static __device__ __forceinline__ T vec_dot_q8_0_
#pragma unroll
for (int i = 0; i < vdr; ++i) {
// SIMD dot product of quantized values
sumi = __dp4a(v[i], u[i], sumi);
sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
}
return d8_0*d8_1 * ((T) sumi);
@@ -220,7 +220,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
#pragma unroll
for (int i = 0; i < vdr; ++i) {
// SIMD dot product of quantized values
sumi = __dp4a(v[i], u[i], sumi);
sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
}
#ifdef GGML_CUDA_F16
@@ -259,13 +259,13 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
const int vi = (v >> (2*i)) & 0x03030303;
sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
sumf_d += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
// fill int with 4x m
int m = sc >> 4;
m |= m << 8;
m |= m << 16;
sumf_m += d8[i] * __dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
sumf_m += d8[i] * ggml_cuda_dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
}
const float2 dm2f = __half22float2(dm2);
@@ -294,8 +294,8 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
#pragma unroll
for (int i = i0; i < i0 + QI8_1/2; ++i) {
const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303;
sumi_d = __dp4a(vi, u[i], sumi_d); // SIMD dot product
sumi_m = __dp4a(0x01010101, u[i], sumi_m);
sumi_d = ggml_cuda_dp4a(vi, u[i], sumi_d); // SIMD dot product
sumi_m = ggml_cuda_dp4a(0x01010101, u[i], sumi_m);
}
sumf_d += dm2f.x * sumi_d;
@@ -339,7 +339,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
const int vi = __vsubss4(vil, vih);
sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product
}
return d3 * sumf;
@@ -363,7 +363,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
#pragma unroll
for (int i = i0; i < i0 + QI8_1/2; ++i) {
const int vi = __vsubss4((v[i/2] >> (4*(i%2))) & 0x0F0F0F0F, 0x04040404);
sumi_sc = __dp4a(vi, u[i], sumi_sc); // SIMD dot product
sumi_sc = ggml_cuda_dp4a(vi, u[i], sumi_sc); // SIMD dot product
}
sumi += sumi_sc * scales[i0 / (QI8_1/2)];
@@ -392,8 +392,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;
const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u
const int dot1 = ggml_cuda_dp4a(v1i, u[2*i+1], ggml_cuda_dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+1], ggml_cuda_dp4a(0x01010101, u[2*i+0], 0)); // sum of u
sumf_d += d8[i] * (dot1 * sc[i]);
sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
@@ -423,7 +423,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
#pragma unroll
for (int j = 0; j < QI8_1; ++j) {
sumi_d = __dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product
sumi_d = ggml_cuda_dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product
}
const float2 ds8f = __half22float2(ds8[i]);
@@ -464,8 +464,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
const int v0i = vl0i | vh0i;
const int v1i = vl1i | vh1i;
const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 0)); // sum of u
const int dot1 = ggml_cuda_dp4a(v0i, u[2*i+0], ggml_cuda_dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+0], ggml_cuda_dp4a(0x01010101, u[2*i+1], 0)); // sum of u
sumf_d += d8[i] * (dot1 * sc[i]);
sumf_m += d8[i] * (dot2 * m[i]);
@@ -496,7 +496,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
#pragma unroll
for (int j = 0; j < QI8_1; ++j) {
sumi_d = __dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product
sumi_d = ggml_cuda_dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product
}
const float2 ds8f = __half22float2(ds8[i]);
@@ -535,7 +535,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32
sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product
}
return d*sumf;
@@ -558,11 +558,11 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
#pragma unroll
for (int i = i0; i < i0 + 2; ++i) {
sumi_d.x = __dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product
sumi_d.x = __dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product
sumi_d.x = ggml_cuda_dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product
sumi_d.x = ggml_cuda_dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product
sumi_d.y = __dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product
sumi_d.y = __dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
sumi_d.y = ggml_cuda_dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product
sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
}
sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y);
@@ -857,12 +857,12 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0);
sumi = __dp4a(grid0, u0, sumi);
sumi = ggml_cuda_dp4a(grid0, u0, sumi);
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1);
sumi = __dp4a(grid1, u1, sumi);
sumi = ggml_cuda_dp4a(grid1, u1, sumi);
}
const int ls = aux32 >> 28;
@@ -900,11 +900,11 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
if (l0 < 4) {
sumi0 = __dp4a(grid_l, u0, sumi0);
sumi0 = __dp4a(grid_h, u1, sumi0);
sumi0 = ggml_cuda_dp4a(grid_l, u0, sumi0);
sumi0 = ggml_cuda_dp4a(grid_h, u1, sumi0);
} else {
sumi1 = __dp4a(grid_l, u0, sumi1);
sumi1 = __dp4a(grid_h, u1, sumi1);
sumi1 = ggml_cuda_dp4a(grid_l, u0, sumi1);
sumi1 = ggml_cuda_dp4a(grid_h, u1, sumi1);
}
}
const int sumi = (sumi0*ls0 + sumi1*ls1 + (sumi0 + sumi1)/2)/4;
@@ -950,11 +950,11 @@ static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
if (l0 < 4) {
sumi0 = __dp4a(grid_l, u0, sumi0);
sumi0 = __dp4a(grid_h, u1, sumi0);
sumi0 = ggml_cuda_dp4a(grid_l, u0, sumi0);
sumi0 = ggml_cuda_dp4a(grid_h, u1, sumi0);
} else {
sumi1 = __dp4a(grid_l, u0, sumi1);
sumi1 = __dp4a(grid_h, u1, sumi1);
sumi1 = ggml_cuda_dp4a(grid_l, u0, sumi1);
sumi1 = ggml_cuda_dp4a(grid_h, u1, sumi1);
}
}
const int sumi = (sumi0*ls0 + sumi1*ls1 + (sumi0 + sumi1)/2)/4;
@@ -991,8 +991,8 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
sumi = __dp4a(grid_l, u0, sumi);
sumi = __dp4a(grid_h, u1, sumi);
sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
sumi = ggml_cuda_dp4a(grid_h, u1, sumi);
}
const int ls = aux32 >> 28;
@@ -1036,8 +1036,8 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
sumi = __dp4a(grid_l, u0, sumi);
sumi = __dp4a(grid_h, u1, sumi);
sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
sumi = ggml_cuda_dp4a(grid_h, u1, sumi);
}
sumi *= 1 + 2*((bq3->scales[iqs/4] >> ((iqs << 1) & 0x04)) & 0x0F);
@@ -1069,8 +1069,8 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);
sumi = __dp4a(grid0, u0, sumi);
sumi = __dp4a(grid1, u1, sumi);
sumi = ggml_cuda_dp4a(grid0, u0, sumi);
sumi = ggml_cuda_dp4a(grid1, u1, sumi);
}
const float d1q = __half2float(bq1->d) * (((qh >> 11) & 0x0E) + 1);
@@ -1104,13 +1104,13 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);
sumi[l0/4] = __dp4a(grid0, u0, sumi[l0/4]);
sumi[l0/4] = __dp4a(grid1, u1, sumi[l0/4]);
sumi[l0/4] = ggml_cuda_dp4a(grid0, u0, sumi[l0/4]);
sumi[l0/4] = ggml_cuda_dp4a(grid1, u1, sumi[l0/4]);
const float delta = -1.0f + IQ1M_DELTA - (qhl & 0x08) * (2.0f*IQ1M_DELTA/0x08);
int sumy = 0;
sumy = __dp4a(u0, 0x01010101, sumy);
sumy = __dp4a(u1, 0x01010101, sumy);
sumy = ggml_cuda_dp4a(u0, 0x01010101, sumy);
sumy = ggml_cuda_dp4a(u1, 0x01010101, sumy);
sumf[l0/4] += delta*sumy;
}
@@ -1160,8 +1160,8 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
const int2 v = get_int_from_table_16(aux_q4);
sumi = __dp4a(v.x, q8[l + 0], sumi);
sumi = __dp4a(v.y, q8[l + 4], sumi);
sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
}
const float d = __half2float(bq4->d) * __low2float(bq8_1->ds);
@@ -1187,8 +1187,8 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);
sumi = __dp4a(v.x, u0, sumi);
sumi = __dp4a(v.y, u1, sumi);
sumi = ggml_cuda_dp4a(v.x, u0, sumi);
sumi = ggml_cuda_dp4a(v.y, u1, sumi);
}
const int ls = ((bq4->scales_l[iqs/8] >> (iqs & 0x04)) & 0x0F) | (((bq4->scales_h >> (iqs/2)) & 0x03) << 4);
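Design note on the change above: the per-architecture handling (the AMD dot-product builtins, the GCN inline-asm path, and the pre-MIN_CC_DP4A scalar fallback) now lives inside ggml_cuda_dp4a, so the vec_dot helpers no longer call __dp4a directly and the HIP override of __dp4a in common.cuh could be dropped. A minimal usage sketch in the same spirit as the call sites above, assuming it is compiled within ggml-cuda so common.cuh provides the wrapper (the kernel name and launch shape are illustrative, not part of this commit):

```cuda
#include "common.cuh" // assumption: the ggml-cuda header that defines ggml_cuda_dp4a

// Illustrative kernel: dot product of two packed-int8 buffers,
// 4 int8 values per int, n4 ints per buffer.
__global__ void dot_q8_example(const int * x, const int * y, int * out, const int n4) {
    int sum = 0;
    for (int i = blockIdx.x*blockDim.x + threadIdx.x; i < n4; i += gridDim.x*blockDim.x) {
        sum = ggml_cuda_dp4a(x[i], y[i], sum); // 4 int8 multiply-adds per call
    }
    atomicAdd(out, sum); // combine per-thread partial sums
}
```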