diff --git a/ggml-oshmem.c b/ggml-oshmem.c index 86e759693..0cdf4bdc5 100644 --- a/ggml-oshmem.c +++ b/ggml-oshmem.c @@ -51,20 +51,14 @@ struct ggml_openshmem_context * ggml_openshmem_init(void) { * */ ctx->symmetric_buffer_size = OPENSHMEM_SYMMETRIC_BUFFER_SIZE; - ctx->symmetric_comm_structure_size = OPENSHMEM_SYMMETRIC_BUFFER_SIZE + sizeof(int64_t) + sizeof(int64_t); + ctx->symmetric_comm_structure_size = OPENSHMEM_SYMMETRIC_BUFFER_SIZE + sizeof(int64_t) + sizeof(int64_t) + sizeof(uint64_t) + sizeof(uint64_t); ctx->symmetric_comm_structure = (uint8_t*)shmem_calloc(1, ctx->n_pes*ctx->symmetric_comm_structure_size); - /* - * uint8_t signal_byte[shmem_npes()]; - */ - ctx->recv_signal = (uint64_t*)shmem_calloc(1, ctx->n_pes*sizeof(uint64_t)); - return ctx; } void ggml_openshmem_free(struct ggml_openshmem_context * ctx) { shmem_free(ctx->symmetric_comm_structure); - shmem_free(ctx->recv_signal); free(ctx); } @@ -127,18 +121,20 @@ static void ggml_openshmem_tensor_send(struct ggml_openshmem_context * ctx, stru const int64_t symmetric_comm_structure_size = ctx->symmetric_comm_structure_size; + + uint64_t * my_recv_signal = + ((uint64_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe); + uint64_t * dst_recv_signal = + ((uint64_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe)+sizeof(uint64_t); + uint8_t * dst_symmetric_comm_structure = - ((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe); + ((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe)+sizeof(uint64_t)+sizeof(uint64_t); int64_t * dst_symmetric_comm_offset = (int64_t*)(dst_symmetric_comm_structure); int64_t * dst_symmetric_comm_length = ((int64_t*)dst_symmetric_comm_offset)+sizeof(int64_t); uint8_t * dst_symmetric_comm_buffer = ((uint8_t*)dst_symmetric_comm_length)+sizeof(int64_t); - uint64_t * dst_recv_signal = - ctx->recv_signal+dst_pe; - uint64_t * my_recv_signal = - ctx->recv_signal+ctx->pe; const int64_t nelements = ggml_nelements(t); int64_t xmt_size = 0; @@ -223,18 +219,22 @@ static void ggml_openshmem_tensor_send(struct ggml_openshmem_context * ctx, stru static void ggml_openshmem_tensor_recv(struct ggml_openshmem_context * ctx, struct ggml_tensor * t, int src_pe) { + const int64_t symmetric_comm_structure_size = + ctx->symmetric_comm_structure_size; + + uint64_t * src_recv_signal = + ((uint64_t*)ctx->symmetric_comm_structure)+(symmetric_comm_structure_size*src_pe); + uint64_t * my_recv_signal = + ((uint64_t*)ctx->symmetric_comm_structure)+(symmetric_comm_structure_size*src_pe)+sizeof(uint64_t); + uint8_t * src_symmetric_comm_structure = - ((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*src_pe); + ((uint8_t*)ctx->symmetric_comm_structure)+(symmetric_comm_structure_size*src_pe)+sizeof(uint64_t)+sizeof(uint64_t); int64_t * src_symmetric_comm_offset = (int64_t*)(src_symmetric_comm_structure); int64_t * src_symmetric_comm_length = ((int64_t*)src_symmetric_comm_offset)+sizeof(int64_t); uint8_t * src_symmetric_comm_buffer = ((uint8_t*)src_symmetric_comm_length)+sizeof(int64_t); - uint64_t * src_recv_signal = - ctx->recv_signal+src_pe; - uint64_t * my_recv_signal = - ctx->recv_signal+ctx->pe; int64_t total_loop_count = 0;