Remove the write to cl_buffer_c and change it to a write-only buffer; this should work since beta is always zero.

This commit is contained in:
Concedo 2023-04-22 23:19:17 +08:00
parent cd6c121357
commit eb73b4c261

View file

@ -242,7 +242,7 @@ static void ggml_cl_sgemm_wrapper(const enum CBLAS_ORDER order, const enum CBLAS
{ {
cl_size_c = m*n*sizeof(float); cl_size_c = m*n*sizeof(float);
clReleaseMemObject(cl_buffer_c); clReleaseMemObject(cl_buffer_c);
cl_buffer_c = clCreateBuffer(context, CL_MEM_READ_WRITE, cl_size_c, NULL, &err); cl_buffer_c = clCreateBuffer(context, CL_MEM_WRITE_ONLY, cl_size_c, NULL, &err);
if (err != CL_SUCCESS) { if (err != CL_SUCCESS) {
printf("Error creating OpenCL Buffer C: %d\n", err); printf("Error creating OpenCL Buffer C: %d\n", err);
fflush(stdout); fflush(stdout);
@ -263,20 +263,20 @@ static void ggml_cl_sgemm_wrapper(const enum CBLAS_ORDER order, const enum CBLAS
} }
clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, m*k*sizeof(float), host_a, 0, NULL, events); clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, m*k*sizeof(float), host_a, 0, NULL, events);
clEnqueueWriteBuffer(queue, cl_buffer_c, CL_FALSE, 0, m*n*sizeof(float), host_c, 0, NULL, events + 2); //clEnqueueWriteBuffer(queue, cl_buffer_c, CL_FALSE, 0, m*n*sizeof(float), host_c, 0, NULL, events + 2);
if (dequant) { if (dequant) {
err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, events + 1, events + 3); err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, events + 1, events + 2);
if(err < 0) { if(err < 0) {
printf("Error enqueueing OpenCL dequantize kernel: %d\n", err); printf("Error enqueueing OpenCL dequantize kernel: %d\n", err);
fflush(stdout); fflush(stdout);
} }
} }
clWaitForEvents(dequant ? 4 : 3, events); clWaitForEvents(dequant ? 3 : 2, events);
clReleaseEvent(events[0]); clReleaseEvent(events[0]);
clReleaseEvent(events[1]); clReleaseEvent(events[1]);
clReleaseEvent(events[2]); //clReleaseEvent(events[2]);
if (dequant) { if (dequant) {
clReleaseEvent(events[3]); clReleaseEvent(events[2]);
} }
// Call the SGEMM routine. // Call the SGEMM routine.