Vec load quantized values

This commit is contained in:
Aidan 2024-07-02 12:38:17 +01:00
parent 6c7c937927
commit 504a47abf8
2 changed files with 9 additions and 4 deletions

View file

@ -351,4 +351,10 @@ static __dpct_inline__ float warp_reduce_max(float x,
return x;
}
// Helper for vec loading aligned data
template <typename Tp, int n>
inline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) {
return *reinterpret_cast<const sycl::vec<Tp, n>*>(aligned_ptr);
}
#endif // GGML_SYCL_COMMON_HPP

View file

@ -327,8 +327,6 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
scales_local[tid] = x[i].scales[tid];
item_ct1.barrier(sycl::access::fence_space::local_space);
const uint8_t * q = x[i].qs + 32*il + n*ir;
uint8_t sc, m;
get_scale_min_k4(is + 0, scales_local, sc, m);
const float d1 = dall * sc;
@ -337,9 +335,10 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
const float d2 = dall * sc;
const float m2 = dmin * m;
sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(x[i].qs + 32*il + n*ir);
for (int l = 0; l < n; ++l) {
y[l + 0] = d1 * (q[l] & 0xF) - m1;
y[l +32] = d2 * (q[l] >> 4) - m2;
y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
y[l +32] = d2 * (q_vec[l] >> 4) - m2;
}
#else
const int tid = item_ct1.get_local_id(2);