Added another check to find a GPU.

This commit is contained in:
Henri Vasserman 2023-05-19 00:35:46 +03:00
parent 225305d32c
commit 962e2a9cd9
No known key found for this signature in database
GPG key ID: 2995FC0F58B1A986

View file

@ -181,13 +181,15 @@ void ggml_cl_init(void) {
char text_buffer[1024] = {0};
platform = NULL;
device = NULL;
cl_platform_id platforms[NPLAT];
cl_uint num_platforms;
err = clGetPlatformIDs(NPLAT, platforms, &num_platforms);
CL_CHECK(err, "clGetPlatformIDs");
char * GGML_OPENCL_PLATFORM = getenv("GGML_OPENCL_PLATFORM");
if (GGML_OPENCL_PLATFORM != NULL) {
cl_platform_id platforms[NPLAT];
cl_uint num_platforms;
err = clGetPlatformIDs(NPLAT, platforms, &num_platforms);
CL_CHECK(err, "clGetPlatformIDs");
unsigned plat_num;
if (sscanf(GGML_OPENCL_PLATFORM, " %u", &plat_num) == 1) {
if (plat_num >= num_platforms) {
@ -214,12 +216,12 @@ void ggml_cl_init(void) {
}
}
device = NULL;
char * GGML_OPENCL_DEVICE = getenv("GGML_OPENCL_DEVICE");
if (GGML_OPENCL_DEVICE != NULL) {
cl_device_id devices[NDEV];
cl_uint num_devices;
clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, NDEV, devices, &num_devices);
CL_CHECK(err, "clGetDeviceIDs");
unsigned dev_num;
if (sscanf(GGML_OPENCL_DEVICE, " %u", &dev_num) == 1) {
@ -244,6 +246,30 @@ void ggml_cl_init(void) {
exit(1);
} else {
fprintf(stderr, "ggml_opencl: selecting device: '%s'\n", text_buffer);
if (platform == NULL) {
err = clGetDeviceInfo(device, CL_DEVICE_PLATFORM, sizeof(&platform), &platform, NULL);
CL_CHECK(err, "clGetDeviceInfo");
}
}
}
if (platform == NULL) {
cl_device_id devices[NDEV];
cl_uint num_devices;
for (unsigned i = 0; i < num_platforms; i++) {
clGetDeviceIDs(platforms[i], CL_DEVICE_TYPE_GPU, NDEV, devices, &num_devices);
CL_CHECK(err, "clGetDeviceIDs");
if (num_devices > 0) {
platform = platforms[i];
device = devices[0];
if (num_devices > 1) {
fprintf(stderr, "ggml_opencl: platform has more than 1 GPU, selecting the first.\n");
}
fprintf(stderr, "ggml_opencl: autodetected GPU.\n");
break;
}
}
}
@ -263,25 +289,25 @@ void ggml_cl_init(void) {
}
}
CL_CHECK(err, "clCreateContextFromType");
}
if (device == NULL) {
err = clGetContextInfo(context, CL_CONTEXT_DEVICES, sizeof(&device), &device, NULL);
CL_CHECK(err, "clGetContextInfo");
if (platform == NULL) {
err = clGetDeviceInfo(device, CL_DEVICE_PLATFORM, sizeof(&platform), &platform, NULL);
CL_CHECK(err, "clGetDeviceInfo");
}
}
if (platform != NULL) {
clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(text_buffer), &text_buffer, NULL);
fprintf(stderr, "ggml_opencl: using platform: '%s'\n", text_buffer);
}
if (device != NULL) {
clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(text_buffer), &text_buffer, NULL);
fprintf(stderr, "ggml_opencl: using device: '%s'\n", text_buffer);
}
GGML_ASSERT(platform != NULL);
clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(text_buffer), &text_buffer, NULL);
fprintf(stderr, "ggml_opencl: using platform: '%s'\n", text_buffer);
GGML_ASSERT(device != NULL);
clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(text_buffer), &text_buffer, NULL);
fprintf(stderr, "ggml_opencl: using device: '%s'\n", text_buffer);
queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err);
if (err == CL_INVALID_PROPERTY || err == CL_INVALID_VALUE) {