sycl: fix convert and dequantize

Signed-off-by: zhentaoyu <zhentao.yu@intel.com>
This commit is contained in:
zhentaoyu 2024-08-12 09:19:03 +00:00
parent 9a9f7c959c
commit 3ecfbcfaf1
2 changed files with 10 additions and 10 deletions

View file

@ -29,7 +29,7 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_sycl(const void *__restrict__ vx, static void dequantize_block_sycl(const void *__restrict__ vx,
dst_t *__restrict__ y, const int64_t k, dst_t *__restrict__ y, const int64_t k,
dpct::queue_ptr stream) { dpct::queue_ptr stream) {
const int num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE); const int64_t num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
{ {
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
@ -435,7 +435,7 @@ template <typename src_t, typename dst_t>
static void convert_unary_sycl(const void *__restrict__ vx, static void convert_unary_sycl(const void *__restrict__ vx,
dst_t *__restrict__ y, const int64_t k, dst_t *__restrict__ y, const int64_t k,
dpct::queue_ptr stream) { dpct::queue_ptr stream) {
const int num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE; const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
// decrease global range when it exceeds the max int // decrease global range when it exceeds the max int
int local_size = SYCL_DEQUANTIZE_BLOCK_SIZE; int local_size = SYCL_DEQUANTIZE_BLOCK_SIZE;

View file

@ -15,9 +15,9 @@
#include "common.hpp" #include "common.hpp"
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib, static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
const int iqs, dfloat2 &v) { const int iqs, dfloat2 &v) {
const block_q4_0 * x = (const block_q4_0 *) vx; const block_q4_0 * x = (const block_q4_0 *) vx;
@ -40,7 +40,7 @@ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib,
#endif // GGML_SYCL_F16 #endif // GGML_SYCL_F16
} }
static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib, static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib,
const int iqs, dfloat2 &v) { const int iqs, dfloat2 &v) {
const block_q4_1 * x = (const block_q4_1 *) vx; const block_q4_1 * x = (const block_q4_1 *) vx;
@ -64,7 +64,7 @@ static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib,
#endif // GGML_SYCL_F16 #endif // GGML_SYCL_F16
} }
static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib, static __dpct_inline__ void dequantize_q5_0(const void *vx, const int64_t ib,
const int iqs, dfloat2 &v) { const int iqs, dfloat2 &v) {
const block_q5_0 * x = (const block_q5_0 *) vx; const block_q5_0 * x = (const block_q5_0 *) vx;
@ -91,7 +91,7 @@ static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib,
#endif // GGML_SYCL_F16 #endif // GGML_SYCL_F16
} }
static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib, static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib,
const int iqs, dfloat2 &v) { const int iqs, dfloat2 &v) {
const block_q5_1 * x = (const block_q5_1 *) vx; const block_q5_1 * x = (const block_q5_1 *) vx;
@ -118,7 +118,7 @@ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib,
#endif // GGML_SYCL_F16 #endif // GGML_SYCL_F16
} }
static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib, static __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib,
const int iqs, dfloat2 &v) { const int iqs, dfloat2 &v) {
const block_q8_0 * x = (const block_q8_0 *) vx; const block_q8_0 * x = (const block_q8_0 *) vx;
@ -138,7 +138,7 @@ static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib,
} }
template<typename dst_t> template<typename dst_t>
static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32, static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int64_t i = item_ct1.get_group(2); const int64_t i = item_ct1.get_group(2);
@ -168,7 +168,7 @@ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restri
} }
template<typename dst_t> template<typename dst_t>
static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32, static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int64_t i = item_ct1.get_group(2); const int64_t i = item_ct1.get_group(2);