sycl: fix ib in dmmv
Signed-off-by: zhentaoyu <zhentao.yu@intel.com>
This commit is contained in:
parent
3ecfbcfaf1
commit
bd960a67dc
3 changed files with 5 additions and 5 deletions
|
@ -421,7 +421,7 @@ template <typename src_t, typename dst_t>
|
|||
static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int64_t work_group_size = item_ct1.get_local_range(2);
|
||||
const int64_t global_id = item_ct1.get_local_id(2) + item_ct1.get_group(2) * work_group_size;
|
||||
const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
|
||||
|
||||
// make each work-item deal with more elements since the SYCL global range cannot exceed max int
|
||||
for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) {
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
#include "presets.hpp"
|
||||
|
||||
|
||||
static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||
static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
|
||||
const sycl::half *x = (const sycl::half *)vx;
|
||||
|
||||
// automatic half -> float type cast if dfloat == float
|
||||
|
@ -12,7 +12,7 @@ static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 &
|
|||
v.y() = x[ib + iqs + 1];
|
||||
}
|
||||
|
||||
static void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||
static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
|
||||
const float * x = (const float *) vx;
|
||||
|
||||
// automatic half -> float type cast if dfloat == float
|
||||
|
|
|
@ -19,7 +19,7 @@ static void im2col_kernel(
|
|||
int64_t pelements, int64_t CHW, int s0, int s1, int p0, int p1, int d0, int d1,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int64_t work_group_size = item_ct1.get_local_range(2);
|
||||
const int64_t global_id = item_ct1.get_local_id(2) + item_ct1.get_group(2) * work_group_size;
|
||||
const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
|
||||
|
||||
// make each work-item deal with more elements since the SYCL global range cannot exceed max int
|
||||
for (int64_t i = global_id; i < pelements; i += work_group_size * item_ct1.get_group_range(2)) {
|
||||
|
@ -95,7 +95,7 @@ void ggml_sycl_op_im2col(
|
|||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
||||
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue