hotfix for 70b broadcast issues
This commit is contained in:
parent
9731682ad6
commit
48c27a9ce1
2 changed files with 10 additions and 6 deletions
|
@ -1783,13 +1783,22 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
|
||||||
|
|
||||||
|
|
||||||
bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
|
bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
|
||||||
|
const int64_t ne02 = src0->ne[2];
|
||||||
|
const int64_t ne03 = src0->ne[3];
|
||||||
|
|
||||||
const int64_t ne10 = src1->ne[0];
|
const int64_t ne10 = src1->ne[0];
|
||||||
|
const int64_t ne12 = src1->ne[2];
|
||||||
|
const int64_t ne13 = src1->ne[3];
|
||||||
|
|
||||||
const int64_t ne0 = dst->ne[0];
|
const int64_t ne0 = dst->ne[0];
|
||||||
const int64_t ne1 = dst->ne[1];
|
const int64_t ne1 = dst->ne[1];
|
||||||
|
|
||||||
|
// TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
|
||||||
|
// ref: https://github.com/ggerganov/ggml/pull/224
|
||||||
|
|
||||||
// TODO: find the optimal values for these
|
// TODO: find the optimal values for these
|
||||||
if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
|
if (ne02 == ne12 && ne03 == ne13 &&
|
||||||
|
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
|
||||||
src1->type == GGML_TYPE_F32 &&
|
src1->type == GGML_TYPE_F32 &&
|
||||||
dst->type == GGML_TYPE_F32 &&
|
dst->type == GGML_TYPE_F32 &&
|
||||||
((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_GPU)) {
|
((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_GPU)) {
|
||||||
|
|
5
ggml.c
5
ggml.c
|
@ -10423,11 +10423,6 @@ static void ggml_compute_forward_mul_mat(
|
||||||
|
|
||||||
#if defined(GGML_USE_CLBLAST)
|
#if defined(GGML_USE_CLBLAST)
|
||||||
if (ggml_cl_can_mul_mat(src0, src1, dst)) {
|
if (ggml_cl_can_mul_mat(src0, src1, dst)) {
|
||||||
// TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
|
|
||||||
// ref: https://github.com/ggerganov/ggml/pull/224
|
|
||||||
GGML_ASSERT(ne02 == ne12);
|
|
||||||
GGML_ASSERT(ne03 == ne13);
|
|
||||||
|
|
||||||
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
||||||
ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue