de-duplicate ggml_forward_dup code taking care of contiguous tensors of same type.
with this we can duplicate tensor of any typ as long as they are contiguous.
This commit is contained in:
parent
38675e537c
commit
c1a8893de3
1 changed files with 36 additions and 22 deletions
58
ggml.c
58
ggml.c
|
@ -6638,6 +6638,36 @@ void ggml_set_param(
|
|||
|
||||
// ggml_compute_forward_dup
|
||||
|
||||
static void ggml_compute_forward_dup_same_cont(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
struct ggml_tensor * dst) {
|
||||
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(src0->type == dst->type);
|
||||
|
||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t nb00 = src0->nb[0];
|
||||
const size_t nb0 = dst->nb[0];
|
||||
|
||||
const int ith = params->ith; // thread index
|
||||
const int nth = params->nth; // number of threads
|
||||
|
||||
// parallelize by elements
|
||||
const int ne = ggml_nelements(dst);
|
||||
const int dr = (ne + nth - 1) / nth;
|
||||
const int ie0 = dr * ith;
|
||||
const int ie1 = MIN(ie0 + dr, ne);
|
||||
|
||||
memcpy(
|
||||
((char *) dst->data + ie0*nb0),
|
||||
((char *) src0->data + ie0*nb00),
|
||||
(ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
|
||||
|
||||
}
|
||||
static void ggml_compute_forward_dup_f16(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
|
@ -6672,17 +6702,7 @@ static void ggml_compute_forward_dup_f16(
|
|||
const int nth = params->nth; // number of threads
|
||||
|
||||
if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
|
||||
// parallelize by elements
|
||||
const int ne = ggml_nelements(dst);
|
||||
const int dr = (ne + nth - 1) / nth;
|
||||
const int ie0 = dr * ith;
|
||||
const int ie1 = MIN(ie0 + dr, ne);
|
||||
|
||||
memcpy(
|
||||
((char *) dst->data + ie0*nb0),
|
||||
((char *) src0->data + ie0*nb00),
|
||||
(ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
|
||||
|
||||
ggml_compute_forward_dup_same_cont(params, src0, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -6971,17 +6991,7 @@ static void ggml_compute_forward_dup_f32(
|
|||
const int nth = params->nth; // number of threads
|
||||
|
||||
if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
|
||||
// parallelize by elements
|
||||
const int ne = ggml_nelements(dst);
|
||||
const int dr = (ne + nth - 1) / nth;
|
||||
const int ie0 = dr * ith;
|
||||
const int ie1 = MIN(ie0 + dr, ne);
|
||||
|
||||
memcpy(
|
||||
((char *) dst->data + ie0*nb0),
|
||||
((char *) src0->data + ie0*nb00),
|
||||
(ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
|
||||
|
||||
ggml_compute_forward_dup_same_cont(params, src0, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -7236,6 +7246,10 @@ static void ggml_compute_forward_dup(
|
|||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
struct ggml_tensor * dst) {
|
||||
if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
|
||||
ggml_compute_forward_dup_same_cont(params, src0, dst);
|
||||
return;
|
||||
}
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue