de-duplicate ggml_compute_forward_dup code, taking care of contiguous tensors of the same type.

With this we can duplicate tensors of any type as long as they are contiguous.
parent 38675e537c
commit c1a8893de3

1 changed file with 36 additions and 22 deletions

 ggml.c | 58 ++++++++++++++++++++++++++++++++++++----------------------
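The idea behind the change: once src0 and dst are both contiguous and of the same type, duplication no longer depends on the element type at all; it reduces to a flat byte copy. A minimal standalone sketch of that idea (plain C with a hypothetical tensor struct, not ggml's actual API):

```c
#include <assert.h>
#include <stdio.h>
#include <string.h>

// Hypothetical stand-in for a contiguous tensor; ggml's real struct
// also carries dimensions, byte strides (nb) and a type enum.
struct tensor {
    void * data;
    int    ne;        // number of elements
    size_t type_size; // bytes per element
};

// Same shape of checks as ggml_compute_forward_dup_same_cont:
// equal element counts and equal types, then one flat memcpy.
static void dup_same_cont(const struct tensor * src, struct tensor * dst) {
    assert(dst->ne == src->ne);
    assert(dst->type_size == src->type_size);
    memcpy(dst->data, src->data, (size_t) src->ne * src->type_size);
}

int main(void) {
    float a[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    float b[4] = {0};
    struct tensor src = { a, 4, sizeof(float) };
    struct tensor dst = { b, 4, sizeof(float) };
    dup_same_cont(&src, &dst);
    printf("%.1f %.1f %.1f %.1f\n", b[0], b[1], b[2], b[3]); // 1.0 2.0 3.0 4.0
    return 0;
}
```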
@@ -6638,6 +6638,36 @@ void ggml_set_param(
 
 // ggml_compute_forward_dup
 
+static void ggml_compute_forward_dup_same_cont(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+    GGML_ASSERT(src0->type == dst->type);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb0 = dst->nb[0];
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by elements
+    const int ne = ggml_nelements(dst);
+    const int dr = (ne + nth - 1) / nth;
+    const int ie0 = dr * ith;
+    const int ie1 = MIN(ie0 + dr, ne);
+
+    memcpy(
+        ((char *) dst->data + ie0*nb0),
+        ((char *) src0->data + ie0*nb00),
+        (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
+
+}
 static void ggml_compute_forward_dup_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
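The helper splits the flat element range across threads with a ceiling division: every thread gets at most dr elements and the last range is clamped to ne. A small self-contained check of that arithmetic (the values of ne and nth are illustrative, not from the commit):

```c
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

// Prints the half-open element range [ie0, ie1) each thread copies,
// using the same ceil-division split as ggml_compute_forward_dup_same_cont.
int main(void) {
    const int ne  = 10; // total elements (illustrative value)
    const int nth = 4;  // number of threads (illustrative value)

    for (int ith = 0; ith < nth; ith++) {
        const int dr  = (ne + nth - 1) / nth; // elements per thread, rounded up
        const int ie0 = dr * ith;             // first element for this thread
        const int ie1 = MIN(ie0 + dr, ne);    // clamp the last range to ne
        printf("thread %d: [%d, %d)\n", ith, ie0, ie1);
    }
    return 0;
}
```

With ne = 10 and nth = 4 this prints [0, 3), [3, 6), [6, 9) and [9, 10): the ranges are disjoint and cover every element exactly once.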
@@ -6672,17 +6702,7 @@ static void ggml_compute_forward_dup_f16(
     const int nth = params->nth; // number of threads
 
     if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
-        // parallelize by elements
-        const int ne = ggml_nelements(dst);
-        const int dr = (ne + nth - 1) / nth;
-        const int ie0 = dr * ith;
-        const int ie1 = MIN(ie0 + dr, ne);
-
-        memcpy(
-            ((char *) dst->data + ie0*nb0),
-            ((char *) src0->data + ie0*nb00),
-            (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
-
+        ggml_compute_forward_dup_same_cont(params, src0, dst);
         return;
     }
 
@@ -6971,17 +6991,7 @@ static void ggml_compute_forward_dup_f32(
     const int nth = params->nth; // number of threads
 
     if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
-        // parallelize by elements
-        const int ne = ggml_nelements(dst);
-        const int dr = (ne + nth - 1) / nth;
-        const int ie0 = dr * ith;
-        const int ie1 = MIN(ie0 + dr, ne);
-
-        memcpy(
-            ((char *) dst->data + ie0*nb0),
-            ((char *) src0->data + ie0*nb00),
-            (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
-
+        ggml_compute_forward_dup_same_cont(params, src0, dst);
         return;
     }
 
@@ -7236,6 +7246,10 @@ static void ggml_compute_forward_dup(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        struct ggml_tensor * dst) {
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
+        ggml_compute_forward_dup_same_cont(params, src0, dst);
+        return;
+    }
    switch (src0->type) {
        case GGML_TYPE_F16:
            {
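Since ggml_compute_forward_dup now short-circuits before the type switch, every contiguous same-type duplication, quantized types included, ends up in the thread-partitioned memcpy. A self-contained simulation of those per-thread slice copies (type_size stands in for GGML_TYPE_SIZE[src0->type], and the loop plays the role of the thread pool; names and values are illustrative):

```c
#include <assert.h>
#include <stdio.h>
#include <string.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    enum { NE = 10, NTH = 4 };
    const size_t type_size = sizeof(float); // contiguous f32 example

    float src[NE], dst[NE];
    for (int i = 0; i < NE; i++) { src[i] = (float) i; dst[i] = -1.0f; }

    // Each "thread" ith copies its half-open element range [ie0, ie1),
    // offset in bytes exactly as the helper does with nb0 and nb00.
    for (int ith = 0; ith < NTH; ith++) {
        const int dr  = (NE + NTH - 1) / NTH;
        const int ie0 = dr * ith;
        const int ie1 = MIN(ie0 + dr, NE);
        if (ie0 < ie1) { // skip empty ranges (possible when NE < NTH)
            memcpy((char *) dst + ie0*type_size,
                   (char *) src + ie0*type_size,
                   (size_t)(ie1 - ie0) * type_size);
        }
    }

    // The disjoint slices jointly reproduce the whole buffer.
    for (int i = 0; i < NE; i++) { assert(dst[i] == src[i]); }
    printf("all %d elements duplicated\n", NE);
    return 0;
}
```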