metal : fix out-of-bounds access + style changes

This commit is contained in:
Georgi Gerganov 2023-07-27 10:10:51 +03:00
parent b5472ea0ad
commit 1038d1d2bc
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -7,6 +7,11 @@
#import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#undef MIN
#undef MAX
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#ifdef GGML_METAL_NDEBUG
#define metal_printf(...)
#else
@ -372,8 +377,8 @@ void ggml_metal_graph_find_concurrency(
int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
int nodes_unused[GGML_MAX_NODES];
for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
for (int i = 0; i < GGML_MAX_NODES; i++) { ctx->concur_list[i] = 0; }
for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
ctx->concur_list_len = 0;
int n_left = gf->n_nodes;
@ -386,21 +391,33 @@ void ggml_metal_graph_find_concurrency(
for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
if (nodes_unused[i]) {
// if the requirements for gf->nodes[i] are satisfied
int exe_flag=1;
int exe_flag = 1;
// scan all srcs
for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
if (src_cur) {
// if is leaf nodes it's satisfied.
if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
// TODO: ggml_is_leaf()
if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {
continue;
}
// otherwise this src should be the output from previous nodes.
int is_found = 0;
// scan 2*search_depth back because we inserted barrier.
for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
//for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
is_found = 1;
break;
}
}
if (is_found == 0) {
exe_flag = 0;
break;
}
if (is_found == 0) {exe_flag = 0; break;}
}
}
if (exe_flag) {
@ -416,9 +433,9 @@ void ggml_metal_graph_find_concurrency(
if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
continue;
} else {
exe_flag = 0;
}
exe_flag = 0;
}
}
}
@ -435,7 +452,9 @@ void ggml_metal_graph_find_concurrency(
ctx->concur_list[level_pos + concurrency] = -1;
ctx->concur_list_len++;
// jump all sorted nodes at nodes_bak
while (!nodes_unused[n_start]) {n_start++;}
while (!nodes_unused[n_start]) {
n_start++;
}
level_pos += concurrency + 1;
}