rewrite platform and device selection

This commit is contained in:
Henri Vasserman 2023-05-13 22:04:46 +03:00
parent bb5c3e2c70
commit b8fb5cdf5c
No known key found for this signature in database
GPG key ID: 2995FC0F58B1A986

View file

@ -143,7 +143,7 @@ __kernel void dequantize_row_q8_0(__global struct block_q8_0* x, __global float*
do { \
cl_int err_ = (err); \
if (err_ != CL_SUCCESS) { \
fprintf(stderr, "OpenCL %s error %d at %s:%d\n", name, err_, __FILE__, __LINE__); \
fprintf(stderr, "ggml_opencl: %s error %d at %s:%d\n", name, err_, __FILE__, __LINE__); \
exit(1); \
} \
} while (0)
@ -152,6 +152,7 @@ static cl_platform_id platform;
// OpenCL device in use; assigned in ggml_cl_init(), either from the
// GGML_OPENCL_DEVICE selection or queried back from the created context.
static cl_device_id device;
// OpenCL context created in ggml_cl_init() for the chosen platform/device.
static cl_context context;
// Command queue created on `context` / `device` in ggml_cl_init().
static cl_command_queue queue;
// CL_TRUE when `queue` was created with CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE;
// CL_FALSE when that attempt returned CL_INVALID_PROPERTY and ggml_cl_init()
// recreated the queue with no properties set.
static cl_bool out_of_order;
// Program built from the `clblast_dequant` source via build_program_from_source().
static cl_program program;
// One kernel handle per quantization format (q4_0, q4_1, q5_0, q5_1, q8_0).
// NOTE(review): presumably the dequantize_row_* kernels taken from `program`;
// their creation is not visible here — confirm.
static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q5_0, kernel_q5_1, kernel_q8_0;
// OpenCL memory objects.
// NOTE(review): allocation and use are outside this view; presumably the
// A operand, quantized B, dequantized B and result C of a matmul — confirm.
static cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c;
@ -188,34 +189,122 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
void ggml_cl_init(void) {
cl_int err = 0;
char * GGML_CLBLAST_PLATFORM = getenv("GGML_CLBLAST_PLATFORM");
char * GGML_CLBLAST_DEVICE = getenv("GGML_CLBLAST_DEVICE");
int plat_num = (GGML_CLBLAST_PLATFORM == NULL ? 0 : atoi(GGML_CLBLAST_PLATFORM));
int dev_num = (GGML_CLBLAST_DEVICE == NULL ? 0 : atoi(GGML_CLBLAST_DEVICE));
printf("\nInitializing CLBlast (First Run)...");
printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num);
cl_uint num_platforms;
clGetPlatformIDs(0, NULL, &num_platforms);
cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id));
clGetPlatformIDs(num_platforms, platforms, NULL);
platform = platforms[plat_num];
char platform_buffer[1024];
clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL);
cl_uint num_devices;
clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices);
cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id));
clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL);
device = devices[dev_num];
char device_buffer[1024];
clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL);
printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer);
context = clCreateContext(NULL, 1, &device, NULL, NULL, &err);
CL_CHECK(err, "clCreateContext");
queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err);
CL_CHECK(err, "clCreateCommandQueue");
free(platforms);
free(devices);
enum { NPLAT = 16, NDEV = 16 };
char text_buffer[1024] = {0};
platform = NULL;
char * GGML_OPENCL_PLATFORM = getenv("GGML_OPENCL_PLATFORM");
if (GGML_OPENCL_PLATFORM != NULL) {
cl_platform_id platforms[NPLAT];
cl_uint num_platforms;
err = clGetPlatformIDs(NPLAT, platforms, &num_platforms);
CL_CHECK(err, "clGetPlatformIDs");
unsigned plat_num;
if (sscanf(GGML_OPENCL_PLATFORM, " %u", &plat_num) == 1) {
if (plat_num >= num_platforms) {
fprintf(stderr, "ggml_opencl: There is no platform %d\n", plat_num);
exit(1);
} else {
platform = platforms[plat_num];
clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(text_buffer), &text_buffer, NULL);
}
} else {
for (unsigned i = 0; i < num_platforms; i++) {
clGetPlatformInfo(platforms[i], CL_PLATFORM_NAME, sizeof(text_buffer), &text_buffer, NULL);
if (strstr(text_buffer, GGML_OPENCL_PLATFORM) != NULL) {
platform = platforms[i];
break;
}
}
}
if (platform == NULL) {
fprintf(stderr, "ggml_opencl: no platform matching '%s' was found.\n", GGML_OPENCL_PLATFORM);
exit(1);
} else {
fprintf(stderr, "ggml_opencl: selecting platform: '%s'\n", text_buffer);
}
}
text_buffer[0] = 0;
device = NULL;
char * GGML_OPENCL_DEVICE = getenv("GGML_OPENCL_DEVICE");
if (GGML_OPENCL_DEVICE != NULL) {
cl_device_id devices[16];
cl_uint num_devices;
clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, NDEV, devices, &num_devices);
unsigned dev_num;
if (sscanf(GGML_OPENCL_DEVICE, " %u", &dev_num) == 1) {
if (dev_num >= num_devices) {
fprintf(stderr, "ggml_opencl: There is no device %d\n", dev_num);
exit(1);
} else {
device = devices[dev_num];
clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(text_buffer), &text_buffer, NULL);
}
} else {
for (unsigned i = 0; i < num_devices; i++) {
clGetDeviceInfo(devices[i], CL_DEVICE_NAME, sizeof(text_buffer), &text_buffer, NULL);
if (strstr(text_buffer, GGML_OPENCL_DEVICE) != NULL) {
device = devices[i];
break;
}
}
}
if (device == NULL) {
fprintf(stderr, "ggml_opencl: no device matching '%s' was found.\n", GGML_OPENCL_DEVICE);
exit(1);
} else {
fprintf(stderr, "ggml_opencl: selecting device: '%s'\n", text_buffer);
}
}
cl_context_properties *properties = platform == NULL ? NULL : (cl_context_properties[]){
(intptr_t)CL_CONTEXT_PLATFORM, (intptr_t)platform, 0
};
if (device != NULL) {
context = clCreateContext(properties, 1, &device, NULL, NULL, &err);
CL_CHECK(err, "clCreateContext");
} else {
context = clCreateContextFromType(properties, CL_DEVICE_TYPE_GPU, NULL, NULL, &err);
if (err == CL_DEVICE_NOT_AVAILABLE || err == CL_DEVICE_NOT_FOUND) {
context = clCreateContextFromType(properties, CL_DEVICE_TYPE_DEFAULT, NULL, NULL, &err);
if (err == CL_DEVICE_NOT_AVAILABLE || err == CL_DEVICE_NOT_FOUND) {
context = clCreateContextFromType(properties, CL_DEVICE_TYPE_ALL, NULL, NULL, &err);
}
}
CL_CHECK(err, "clCreateContextFromType");
}
if (device == NULL) {
err = clGetContextInfo(context, CL_CONTEXT_DEVICES, sizeof(&device), &device, NULL);
CL_CHECK(err, "clGetContextInfo");
if (platform == NULL) {
err = clGetDeviceInfo(device, CL_DEVICE_PLATFORM, sizeof(&platform), &platform, NULL);
CL_CHECK(err, "clGetDeviceInfo");
}
}
if (platform != NULL) {
clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(text_buffer), &text_buffer, NULL);
fprintf(stderr, "ggml_opencl: using platform: '%s'\n", text_buffer);
}
if (device != NULL) {
clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(text_buffer), &text_buffer, NULL);
fprintf(stderr, "ggml_opencl: using device: '%s'\n", text_buffer);
}
out_of_order = CL_TRUE;
queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err);
if (err == CL_INVALID_PROPERTY) {
out_of_order = CL_FALSE;
queue = clCreateCommandQueue(context, device, 0, &err);
}
CL_CHECK(err, "clCreateCommandQueue");
program = build_program_from_source(context, device, clblast_dequant);