From d05fcad5d12bc76013411efc2da721cd26a374a8 Mon Sep 17 00:00:00 2001 From: ct-clmsn Date: Thu, 21 Dec 2023 19:35:33 -0500 Subject: [PATCH] cleaned up pointer arithmetic; rm'd a member variable of the oshmem context struct --- ggml-oshmem.c | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/ggml-oshmem.c b/ggml-oshmem.c index 0cdf4bdc5..e9385edef 100644 --- a/ggml-oshmem.c +++ b/ggml-oshmem.c @@ -20,7 +20,6 @@ struct ggml_openshmem_context { int64_t symmetric_buffer_size; int64_t symmetric_comm_structure_size; uint8_t * symmetric_comm_structure; - uint64_t * recv_signal; }; void ggml_openshmem_backend_init(void) { @@ -74,7 +73,7 @@ void ggml_openshmem_eval_init( UNUSED(ctx); uint8_t * dst_symmetric_comm_structure = - ((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe); + ((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe)+sizeof(uint64_t)+sizeof(uint64_t); int64_t * dst_symmetric_comm_offset = (int64_t*)(dst_symmetric_comm_structure); @@ -123,12 +122,12 @@ static void ggml_openshmem_tensor_send(struct ggml_openshmem_context * ctx, stru ctx->symmetric_comm_structure_size; uint64_t * my_recv_signal = - ((uint64_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe); + ((uint64_t*)ctx->symmetric_comm_structure)+(symmetric_comm_structure_size*ctx->pe); uint64_t * dst_recv_signal = - ((uint64_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe)+sizeof(uint64_t); + ((uint64_t*)my_recv_signal)+sizeof(uint64_t); uint8_t * dst_symmetric_comm_structure = - ((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe)+sizeof(uint64_t)+sizeof(uint64_t); + ((uint8_t*)dst_recv_signal)+sizeof(uint64_t); int64_t * dst_symmetric_comm_offset = (int64_t*)(dst_symmetric_comm_structure); int64_t * dst_symmetric_comm_length = @@ -225,10 +224,10 @@ static void ggml_openshmem_tensor_recv(struct ggml_openshmem_context * ctx, stru uint64_t * src_recv_signal = ((uint64_t*)ctx->symmetric_comm_structure)+(symmetric_comm_structure_size*src_pe); uint64_t * my_recv_signal = - ((uint64_t*)ctx->symmetric_comm_structure)+(symmetric_comm_structure_size*src_pe)+sizeof(uint64_t); + ((uint64_t*)src_recv_signal)+sizeof(uint64_t); uint8_t * src_symmetric_comm_structure = - ((uint8_t*)ctx->symmetric_comm_structure)+(symmetric_comm_structure_size*src_pe)+sizeof(uint64_t)+sizeof(uint64_t); + ((uint8_t*)my_recv_signal)+sizeof(uint64_t); int64_t * src_symmetric_comm_offset = (int64_t*)(src_symmetric_comm_structure); int64_t * src_symmetric_comm_length =