ggml : do not dereference src0 if NULL

This commit is contained in:
Georgi Gerganov 2023-06-24 11:36:10 +03:00
parent 4e3b9f2f9c
commit c2ccd541e9
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 5 additions and 11 deletions

View file

@ -2542,7 +2542,7 @@ void ggml_cuda_free_scratch() {
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
ggml_cuda_func_t func; ggml_cuda_func_t func;
const bool any_on_device = tensor->backend == GGML_BACKEND_GPU const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
|| tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT || (tensor->src0 != nullptr && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT))
|| (tensor->src1 != nullptr && tensor->src1->backend == GGML_BACKEND_GPU); || (tensor->src1 != nullptr && tensor->src1->backend == GGML_BACKEND_GPU);
switch (tensor->op) { switch (tensor->op) {

8
ggml.c
View file

@ -14335,7 +14335,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
if (skip_cpu) { if (skip_cpu) {
return; return;
} }
GGML_ASSERT(tensor->src0->backend == GGML_BACKEND_CPU); GGML_ASSERT(tensor->src0 == NULL || tensor->src0->backend == GGML_BACKEND_CPU);
GGML_ASSERT(tensor->src1 == NULL || tensor->src1->backend == GGML_BACKEND_CPU); GGML_ASSERT(tensor->src1 == NULL || tensor->src1->backend == GGML_BACKEND_CPU);
#endif // GGML_USE_CUBLAS #endif // GGML_USE_CUBLAS
@ -16032,9 +16032,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
/*.wdata =*/ cgraph->work ? cgraph->work->data : NULL, /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
}; };
if (node->src0) {
ggml_compute_forward(&params, node); ggml_compute_forward(&params, node);
}
// COMPUTE // COMPUTE
if (node->n_tasks > 1) { if (node->n_tasks > 1) {
@ -16070,9 +16068,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} }
params.type = GGML_TASK_COMPUTE; params.type = GGML_TASK_COMPUTE;
if (node->src0) {
ggml_compute_forward(&params, node); ggml_compute_forward(&params, node);
}
// wait for thread pool // wait for thread pool
if (node->n_tasks > 1) { if (node->n_tasks > 1) {
@ -16127,9 +16123,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} }
params.type = GGML_TASK_FINALIZE; params.type = GGML_TASK_FINALIZE;
if (node->src0) {
ggml_compute_forward(&params, node); ggml_compute_forward(&params, node);
}
// wait for thread pool // wait for thread pool
if (node->n_tasks > 1) { if (node->n_tasks > 1) {