now then?

This commit is contained in:
Henri Vasserman 2023-04-28 18:50:07 +03:00
parent 759510534c
commit f19ee3b2ec
No known key found for this signature in database
GPG key ID: 2995FC0F58B1A986

14
ggml.c
View file

@ -8239,6 +8239,10 @@ static void ggml_compute_forward_mul_mat_f32(
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
#if !defined(GGML_USE_CUBLAS)
const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
#endif
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
#if defined(GGML_USE_CUBLAS)
@ -8256,10 +8260,7 @@ static void ggml_compute_forward_mul_mat_f32(
// copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
#else
const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
#if defined(GGML_USE_CLBLAST)
#elif defined(GGML_USE_CLBLAST)
// zT = y * xT
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
ne11, ne01, ne10,
@ -8273,7 +8274,6 @@ static void ggml_compute_forward_mul_mat_f32(
1.0f, y, ne10,
x, ne00,
0.0f, d, ne01);
#endif
#endif
}
}
@ -8468,7 +8468,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
#endif
#if defined(GGML_USE_CUBLAS)
const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
@ -8735,6 +8734,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
#if defined(GGML_USE_CUBLAS)
@ -8745,7 +8746,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
CUDA_CHECK(cudaGetLastError());
#elif defined(GGML_USE_CLBLAST)
const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
#else
{
size_t id = 0;