Fixed a small segmentation bug; switched to using type-specific OpenSHMEM calls
This commit is contained in:
parent
0de3b02353
commit
fa49c150d0
1 changed file with 24 additions and 13 deletions
|
@ -70,18 +70,29 @@ int ggml_openshmem_pe(struct ggml_openshmem_context * ctx) {
|
|||
}
|
||||
|
||||
void ggml_openshmem_eval_init(
|
||||
struct ggml_openshmem_context * ctx_openshmem,
|
||||
struct ggml_openshmem_context * ctx,
|
||||
int * n_tokens,
|
||||
int * n_past,
|
||||
int * n_threads) {
|
||||
UNUSED(ctx_openshmem);
|
||||
UNUSED(ctx);
|
||||
|
||||
uint8_t * dst_symmetric_comm_structure =
|
||||
((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe);
|
||||
int64_t * dst_symmetric_comm_offset =
|
||||
(int64_t*)(dst_symmetric_comm_structure);
|
||||
|
||||
// synchronize the worker node parameters with the root node
|
||||
shmem_barrier_all();
|
||||
|
||||
shmem_broadcast(SHMEM_TEAM_WORLD, n_tokens, n_tokens, 1, 0);
|
||||
shmem_broadcast(SHMEM_TEAM_WORLD, n_past, n_tokens, 1, 0);
|
||||
shmem_broadcast(SHMEM_TEAM_WORLD, n_threads, n_tokens, 1, 0);
|
||||
memcpy(dst_symmetric_comm_offset, n_tokens, sizeof(int));
|
||||
memcpy(dst_symmetric_comm_offset+sizeof(int), n_past, sizeof(int));
|
||||
memcpy(dst_symmetric_comm_offset+sizeof(int)+sizeof(int), n_past, sizeof(int));
|
||||
|
||||
shmem_int32_broadcast(SHMEM_TEAM_WORLD, (int*)dst_symmetric_comm_offset, (int*)dst_symmetric_comm_offset, 3, 0);
|
||||
|
||||
memcpy(n_tokens, dst_symmetric_comm_offset, sizeof(int));
|
||||
memcpy(n_past, dst_symmetric_comm_offset+sizeof(int), sizeof(int));
|
||||
memcpy(n_threads, dst_symmetric_comm_offset+sizeof(int)+sizeof(int), sizeof(int));
|
||||
|
||||
shmem_quiet();
|
||||
}
|
||||
|
@ -139,16 +150,16 @@ static void ggml_openshmem_tensor_send(struct ggml_openshmem_context * ctx, stru
|
|||
default: GGML_ASSERT(false && "not implemented");
|
||||
}
|
||||
|
||||
int64_t count[2] = { (xmt_size / OPENSHMEM_SYMMETRIC_BUFFER_SIZE), 1 };
|
||||
const int64_t total_loop_count = count[ count[0] == 0 ];
|
||||
|
||||
int64_t xmt_amount [2] = { OPENSHMEM_SYMMETRIC_BUFFER_SIZE, xmt_size - (OPENSHMEM_SYMMETRIC_BUFFER_SIZE * count[0]) };
|
||||
int64_t init_segments = (xmt_size / OPENSHMEM_SYMMETRIC_BUFFER_SIZE);
|
||||
int64_t xmt_amount [2] = { OPENSHMEM_SYMMETRIC_BUFFER_SIZE, xmt_size - (OPENSHMEM_SYMMETRIC_BUFFER_SIZE * init_segments) };
|
||||
int64_t xmt_byte_offset = 0;
|
||||
int64_t xmt_byte_amount = 0;
|
||||
|
||||
const int64_t total_loop_count = init_segments + !( xmt_amount[1] < 1);
|
||||
|
||||
memcpy(dst_symmetric_comm_offset, &total_loop_count, sizeof(int64_t));
|
||||
|
||||
shmem_put_signal(
|
||||
shmem_int64_put_signal(
|
||||
dst_symmetric_comm_offset,
|
||||
dst_symmetric_comm_offset,
|
||||
sizeof(int64_t),
|
||||
|
@ -173,7 +184,7 @@ static void ggml_openshmem_tensor_send(struct ggml_openshmem_context * ctx, stru
|
|||
memcpy(dst_symmetric_comm_length, &xmt_byte_amount, sizeof(int64_t));
|
||||
memcpy(dst_symmetric_comm_buffer, ((uint8_t*)t->data)+xmt_byte_offset, xmt_byte_amount);
|
||||
|
||||
shmem_put_signal(
|
||||
shmem_uint8_put_signal(
|
||||
dst_symmetric_comm_structure,
|
||||
dst_symmetric_comm_structure,
|
||||
symmetric_comm_structure_size,
|
||||
|
@ -220,7 +231,7 @@ static void ggml_openshmem_tensor_recv(struct ggml_openshmem_context * ctx, stru
|
|||
(*my_recv_signal) = 0;
|
||||
|
||||
memcpy(src_symmetric_comm_offset, &total_loop_count, sizeof(int64_t));
|
||||
shmem_put_signal(src_symmetric_comm_structure, src_symmetric_comm_structure, 0, src_recv_signal, 1, SHMEM_SIGNAL_SET, src_pe);
|
||||
shmem_uint8_put_signal(src_symmetric_comm_structure, src_symmetric_comm_structure, 0, src_recv_signal, 1, SHMEM_SIGNAL_SET, src_pe);
|
||||
|
||||
for(int32_t i = 0; i < total_loop_count; ++i) {
|
||||
shmem_wait_until(my_recv_signal, SHMEM_CMP_EQ, 1);
|
||||
|
@ -232,7 +243,7 @@ static void ggml_openshmem_tensor_recv(struct ggml_openshmem_context * ctx, stru
|
|||
(*src_symmetric_comm_length)
|
||||
);
|
||||
|
||||
shmem_put_signal(src_symmetric_comm_structure, src_symmetric_comm_structure, 0, src_recv_signal, 1, SHMEM_SIGNAL_SET, src_pe);
|
||||
shmem_uint8_put_signal(src_symmetric_comm_structure, src_symmetric_comm_structure, 0, src_recv_signal, 1, SHMEM_SIGNAL_SET, src_pe);
|
||||
}
|
||||
|
||||
shmem_fence();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue