diff --git a/ggml-oshmem.c b/ggml-oshmem.c index c49250e02..4c4923979 100644 --- a/ggml-oshmem.c +++ b/ggml-oshmem.c @@ -70,18 +70,29 @@ int ggml_openshmem_pe(struct ggml_openshmem_context * ctx) { } void ggml_openshmem_eval_init( - struct ggml_openshmem_context * ctx_openshmem, + struct ggml_openshmem_context * ctx, int * n_tokens, int * n_past, int * n_threads) { - UNUSED(ctx_openshmem); + UNUSED(ctx); + + uint8_t * dst_symmetric_comm_structure = + ((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe); + int64_t * dst_symmetric_comm_offset = + (int64_t*)(dst_symmetric_comm_structure); // synchronize the worker node parameters with the root node shmem_barrier_all(); - shmem_broadcast(SHMEM_TEAM_WORLD, n_tokens, n_tokens, 1, 0); - shmem_broadcast(SHMEM_TEAM_WORLD, n_past, n_tokens, 1, 0); - shmem_broadcast(SHMEM_TEAM_WORLD, n_threads, n_tokens, 1, 0); + memcpy(dst_symmetric_comm_offset, n_tokens, sizeof(int)); + memcpy(dst_symmetric_comm_offset+sizeof(int), n_past, sizeof(int)); + memcpy(dst_symmetric_comm_offset+sizeof(int)+sizeof(int), n_past, sizeof(int)); + + shmem_int32_broadcast(SHMEM_TEAM_WORLD, (int*)dst_symmetric_comm_offset, (int*)dst_symmetric_comm_offset, 3, 0); + + memcpy(n_tokens, dst_symmetric_comm_offset, sizeof(int)); + memcpy(n_past, dst_symmetric_comm_offset+sizeof(int), sizeof(int)); + memcpy(n_threads, dst_symmetric_comm_offset+sizeof(int)+sizeof(int), sizeof(int)); shmem_quiet(); } @@ -139,16 +150,16 @@ static void ggml_openshmem_tensor_send(struct ggml_openshmem_context * ctx, stru default: GGML_ASSERT(false && "not implemented"); } - int64_t count[2] = { (xmt_size / OPENSHMEM_SYMMETRIC_BUFFER_SIZE), 1 }; - const int64_t total_loop_count = count[ count[0] == 0 ]; - - int64_t xmt_amount [2] = { OPENSHMEM_SYMMETRIC_BUFFER_SIZE, xmt_size - (OPENSHMEM_SYMMETRIC_BUFFER_SIZE * count[0]) }; + int64_t init_segments = (xmt_size / OPENSHMEM_SYMMETRIC_BUFFER_SIZE); + int64_t xmt_amount [2] = { OPENSHMEM_SYMMETRIC_BUFFER_SIZE, xmt_size - (OPENSHMEM_SYMMETRIC_BUFFER_SIZE * init_segments) }; int64_t xmt_byte_offset = 0; int64_t xmt_byte_amount = 0; + const int64_t total_loop_count = init_segments + !( xmt_amount[1] < 1); + memcpy(dst_symmetric_comm_offset, &total_loop_count, sizeof(int64_t)); - shmem_put_signal( + shmem_int64_put_signal( dst_symmetric_comm_offset, dst_symmetric_comm_offset, sizeof(int64_t), @@ -173,7 +184,7 @@ static void ggml_openshmem_tensor_send(struct ggml_openshmem_context * ctx, stru memcpy(dst_symmetric_comm_length, &xmt_byte_amount, sizeof(int64_t)); memcpy(dst_symmetric_comm_buffer, ((uint8_t*)t->data)+xmt_byte_offset, xmt_byte_amount); - shmem_put_signal( + shmem_uint8_put_signal( dst_symmetric_comm_structure, dst_symmetric_comm_structure, symmetric_comm_structure_size, @@ -220,7 +231,7 @@ static void ggml_openshmem_tensor_recv(struct ggml_openshmem_context * ctx, stru (*my_recv_signal) = 0; memcpy(src_symmetric_comm_offset, &total_loop_count, sizeof(int64_t)); - shmem_put_signal(src_symmetric_comm_structure, src_symmetric_comm_structure, 0, src_recv_signal, 1, SHMEM_SIGNAL_SET, src_pe); + shmem_uint8_put_signal(src_symmetric_comm_structure, src_symmetric_comm_structure, 0, src_recv_signal, 1, SHMEM_SIGNAL_SET, src_pe); for(int32_t i = 0; i < total_loop_count; ++i) { shmem_wait_until(my_recv_signal, SHMEM_CMP_EQ, 1); @@ -232,7 +243,7 @@ static void ggml_openshmem_tensor_recv(struct ggml_openshmem_context * ctx, stru (*src_symmetric_comm_length) ); - shmem_put_signal(src_symmetric_comm_structure, src_symmetric_comm_structure, 0, src_recv_signal, 1, SHMEM_SIGNAL_SET, src_pe); + shmem_uint8_put_signal(src_symmetric_comm_structure, src_symmetric_comm_structure, 0, src_recv_signal, 1, SHMEM_SIGNAL_SET, src_pe); } shmem_fence();