ggml : add n_as argument to ggml_mul_mat_id

This commit is contained in:
slaren 2023-12-09 12:42:25 +01:00
parent 7372b62271
commit ee8fb399aa
6 changed files with 17 additions and 14 deletions

View file

@ -8244,6 +8244,8 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
const struct ggml_tensor * ids = src0;
const int32_t id = dst->op_params[0];
const int32_t n_as = dst->op_params[1];
const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
std::vector<char> ids_host(ggml_nbytes(ids));
@ -8272,7 +8274,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(row_id >= 0 && row_id < ids->ne[0]);
GGML_ASSERT(row_id >= 0 && row_id < n_as);
const struct ggml_tensor * src0_row = dst->src[row_id + 2];