ggml : update ggml_nbytes() to handle non-contiguous tensors

This commit is contained in:
Georgi Gerganov 2023-06-01 21:27:03 +03:00
parent 17930fbcb7
commit f67c2d8cab
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

9
ggml.c
View file

@ -3732,7 +3732,14 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
size_t ggml_nbytes(const struct ggml_tensor * tensor) { size_t ggml_nbytes(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type]; // this should handle cases where the tensor is not contiguous in memory
// probaby just:
//
// return tensor->ne[3]*tensor->nb[3]
//
// is enough, but just in case, adding the second part
return MAX(tensor->ne[3]*tensor->nb[3], (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type]);
} }
int ggml_blck_size(enum ggml_type type) { int ggml_blck_size(enum ggml_type type) {