ggml : ggml_mul better broadcast support
This commit is contained in:
parent
f67bc3c363
commit
3ec7941bad
2 changed files with 60 additions and 56 deletions
84
ggml.c
84
ggml.c
|
@ -3776,6 +3776,12 @@ static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct g
|
|||
(t1->ne[3]%t0->ne[3] == 0);
|
||||
}
|
||||
|
||||
static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||
|
||||
return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1);
|
||||
}
|
||||
|
||||
static inline int ggml_up32(int n) {
|
||||
return (n + 31) & ~31;
|
||||
}
|
||||
|
@ -4658,11 +4664,15 @@ struct ggml_tensor * ggml_mul_impl(
|
|||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
bool inplace) {
|
||||
GGML_ASSERT(a->ne[0] == b->ne[0] && ggml_can_repeat(b, a));
|
||||
// TODO: support less-strict constraint
|
||||
// GGML_ASSERT(ggml_can_repeat(b, a));
|
||||
GGML_ASSERT(ggml_can_repeat_rows(b, a));
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
if (!inplace && (a->grad || b->grad)) {
|
||||
// TODO: support backward pass for broadcasting
|
||||
GGML_ASSERT(ggml_are_same_shape(a, b));
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
|
@ -7960,22 +7970,14 @@ static void ggml_compute_forward_mul_f32(
|
|||
const struct ggml_tensor * src0,
|
||||
const struct ggml_tensor * src1,
|
||||
struct ggml_tensor * dst) {
|
||||
const int nr = ggml_nrows(src0);
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
const int64_t ne11 = src1->ne[1];
|
||||
const int64_t ne12 = src1->ne[2];
|
||||
const int64_t ne13 = src0->ne[3];
|
||||
|
||||
GGML_ASSERT(ne00 == ne10 && ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
|
||||
GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
|
||||
|
||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||
return;
|
||||
}
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
if (src1->backend == GGML_BACKEND_CUDA) {
|
||||
if (ith == 0) {
|
||||
|
@ -7985,6 +7987,17 @@ static void ggml_compute_forward_mul_f32(
|
|||
}
|
||||
#endif
|
||||
|
||||
const int64_t nr = ggml_nrows(src0);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
const int64_t ne11 = src1->ne[1];
|
||||
const int64_t ne12 = src1->ne[2];
|
||||
const int64_t ne13 = src1->ne[3];
|
||||
|
||||
const size_t nb00 = src0->nb[0];
|
||||
const size_t nb01 = src0->nb[1];
|
||||
const size_t nb02 = src0->nb[2];
|
||||
|
@ -8002,47 +8015,50 @@ static void ggml_compute_forward_mul_f32(
|
|||
|
||||
GGML_ASSERT( nb0 == sizeof(float));
|
||||
GGML_ASSERT(nb00 == sizeof(float));
|
||||
GGML_ASSERT(ne00 == ne10);
|
||||
|
||||
if (nb10 == sizeof(float) && ggml_are_same_shape(src0, src1)) {
|
||||
for (int ir = ith; ir < nr; ir += nth) {
|
||||
// src0, src1 and dst are same shape => same indices
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
if (nb10 == sizeof(float)) {
|
||||
for (int64_t ir = ith; ir < nr; ir += nth) {
|
||||
// src0 and dst are same shape => same indices
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
const int64_t i13 = i03 % ne13;
|
||||
const int64_t i12 = i02 % ne12;
|
||||
const int64_t i11 = i01 % ne11;
|
||||
|
||||
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
||||
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
|
||||
|
||||
#ifdef GGML_USE_ACCELERATE
|
||||
UNUSED(ggml_vec_mul_f32);
|
||||
|
||||
vDSP_vmul(
|
||||
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
|
||||
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
|
||||
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
|
||||
ne00);
|
||||
vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
|
||||
#else
|
||||
ggml_vec_mul_f32(ne00,
|
||||
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
|
||||
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
|
||||
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
|
||||
ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
|
||||
#endif
|
||||
// }
|
||||
// }
|
||||
}
|
||||
} else {
|
||||
// src1 is not contiguous
|
||||
for (int ir = ith; ir < nr; ir += nth) {
|
||||
for (int64_t ir = ith; ir < nr; ir += nth) {
|
||||
// src0 and dst are same shape => same indices
|
||||
// src1 is broadcastable across src0 and dst in i1, i2, i3
|
||||
const int i03 = ir/(ne02*ne01);
|
||||
const int i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
const int i13 = i03 % ne13;
|
||||
const int i12 = i02 % ne12;
|
||||
const int i11 = i01 % ne11;
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
const int64_t i13 = i03 % ne13;
|
||||
const int64_t i12 = i02 % ne12;
|
||||
const int64_t i11 = i01 % ne11;
|
||||
|
||||
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
||||
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
for (int i0 = 0; i0 < ne00; i0++) {
|
||||
|
||||
for (int64_t i0 = 0; i0 < ne00; i0++) {
|
||||
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
|
||||
|
||||
dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
|
||||
|
|
32
llama.cpp
32
llama.cpp
|
@ -1000,6 +1000,12 @@ static void llama_model_load_internal(
|
|||
}
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CUDA
|
||||
#else
|
||||
#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CPU
|
||||
#endif
|
||||
|
||||
// prepare memory for the weights
|
||||
size_t vram_total = 0;
|
||||
{
|
||||
|
@ -1016,19 +1022,19 @@ static void llama_model_load_internal(
|
|||
{
|
||||
ggml_backend backend_output;
|
||||
if (n_gpu_layers > int(n_layer)) {
|
||||
backend_output = GGML_BACKEND_CUDA;
|
||||
backend_output = LLAMA_BACKEND_OFFLOAD;
|
||||
} else {
|
||||
backend_output = GGML_BACKEND_CPU;
|
||||
}
|
||||
|
||||
model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output);
|
||||
model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output);
|
||||
}
|
||||
|
||||
const int i_gpu_start = n_layer - n_gpu_layers;
|
||||
|
||||
model.layers.resize(n_layer);
|
||||
for (uint32_t i = 0; i < n_layer; ++i) {
|
||||
const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : GGML_BACKEND_CUDA;
|
||||
const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
|
@ -1047,7 +1053,7 @@ static void llama_model_load_internal(
|
|||
layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend);
|
||||
layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend);
|
||||
|
||||
if (backend == GGML_BACKEND_CUDA) {
|
||||
if (backend == LLAMA_BACKEND_OFFLOAD) {
|
||||
vram_total +=
|
||||
ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) +
|
||||
ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.attention_norm) +
|
||||
|
@ -1213,13 +1219,7 @@ static bool llama_eval_internal(
|
|||
cur = ggml_rms_norm(ctx0, inpL);
|
||||
|
||||
// cur = cur*attention_norm(broadcasted)
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
|
||||
#else
|
||||
cur = ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].attention_norm, cur),
|
||||
cur);
|
||||
#endif
|
||||
}
|
||||
|
||||
// self-attention
|
||||
|
@ -1327,13 +1327,7 @@ static bool llama_eval_internal(
|
|||
cur = ggml_rms_norm(ctx0, inpFF);
|
||||
|
||||
// cur = cur*ffn_norm(broadcasted)
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
|
||||
#else
|
||||
cur = ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].ffn_norm, cur),
|
||||
cur);
|
||||
#endif
|
||||
}
|
||||
|
||||
struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
|
||||
|
@ -1371,13 +1365,7 @@ static bool llama_eval_internal(
|
|||
inpL = ggml_rms_norm(ctx0, inpL);
|
||||
|
||||
// inpL = inpL*norm(broadcasted)
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
inpL = ggml_mul(ctx0, inpL, model.norm);
|
||||
#else
|
||||
inpL = ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.norm, inpL),
|
||||
inpL);
|
||||
#endif
|
||||
|
||||
embeddings = inpL;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue