ggml : add n_as argument to ggml_mul_mat_id
This commit is contained in:
parent
7372b62271
commit
ee8fb399aa
6 changed files with 17 additions and 14 deletions
|
@ -8244,6 +8244,8 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
|
|||
|
||||
const struct ggml_tensor * ids = src0;
|
||||
const int32_t id = dst->op_params[0];
|
||||
const int32_t n_as = dst->op_params[1];
|
||||
|
||||
const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
|
||||
|
||||
std::vector<char> ids_host(ggml_nbytes(ids));
|
||||
|
@ -8272,7 +8274,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
|
|||
|
||||
const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
|
||||
|
||||
GGML_ASSERT(row_id >= 0 && row_id < ids->ne[0]);
|
||||
GGML_ASSERT(row_id >= 0 && row_id < n_as);
|
||||
|
||||
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue