reduced the number of shmem_calloc calls

This commit is contained in:
ct-clmsn 2023-12-21 18:58:46 -05:00
parent 3f2769bf26
commit c8d67705fe

View file

@ -51,20 +51,14 @@ struct ggml_openshmem_context * ggml_openshmem_init(void) {
*
*/
ctx->symmetric_buffer_size = OPENSHMEM_SYMMETRIC_BUFFER_SIZE;
ctx->symmetric_comm_structure_size = OPENSHMEM_SYMMETRIC_BUFFER_SIZE + sizeof(int64_t) + sizeof(int64_t);
ctx->symmetric_comm_structure_size = OPENSHMEM_SYMMETRIC_BUFFER_SIZE + sizeof(int64_t) + sizeof(int64_t) + sizeof(uint64_t) + sizeof(uint64_t);
ctx->symmetric_comm_structure = (uint8_t*)shmem_calloc(1, ctx->n_pes*ctx->symmetric_comm_structure_size);
/*
* uint8_t signal_byte[shmem_npes()];
*/
ctx->recv_signal = (uint64_t*)shmem_calloc(1, ctx->n_pes*sizeof(uint64_t));
return ctx;
}
void ggml_openshmem_free(struct ggml_openshmem_context * ctx) {
shmem_free(ctx->symmetric_comm_structure);
shmem_free(ctx->recv_signal);
free(ctx);
}
@ -127,18 +121,20 @@ static void ggml_openshmem_tensor_send(struct ggml_openshmem_context * ctx, stru
const int64_t symmetric_comm_structure_size =
ctx->symmetric_comm_structure_size;
uint64_t * my_recv_signal =
((uint64_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe);
uint64_t * dst_recv_signal =
((uint64_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe)+sizeof(uint64_t);
uint8_t * dst_symmetric_comm_structure =
((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe);
((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe)+sizeof(uint64_t)+sizeof(uint64_t);
int64_t * dst_symmetric_comm_offset =
(int64_t*)(dst_symmetric_comm_structure);
int64_t * dst_symmetric_comm_length =
((int64_t*)dst_symmetric_comm_offset)+sizeof(int64_t);
uint8_t * dst_symmetric_comm_buffer =
((uint8_t*)dst_symmetric_comm_length)+sizeof(int64_t);
uint64_t * dst_recv_signal =
ctx->recv_signal+dst_pe;
uint64_t * my_recv_signal =
ctx->recv_signal+ctx->pe;
const int64_t nelements = ggml_nelements(t);
int64_t xmt_size = 0;
@ -223,18 +219,22 @@ static void ggml_openshmem_tensor_send(struct ggml_openshmem_context * ctx, stru
static void ggml_openshmem_tensor_recv(struct ggml_openshmem_context * ctx, struct ggml_tensor * t, int src_pe) {
const int64_t symmetric_comm_structure_size =
ctx->symmetric_comm_structure_size;
uint64_t * src_recv_signal =
((uint64_t*)ctx->symmetric_comm_structure)+(symmetric_comm_structure_size*src_pe);
uint64_t * my_recv_signal =
((uint64_t*)ctx->symmetric_comm_structure)+(symmetric_comm_structure_size*src_pe)+sizeof(uint64_t);
uint8_t * src_symmetric_comm_structure =
((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*src_pe);
((uint8_t*)ctx->symmetric_comm_structure)+(symmetric_comm_structure_size*src_pe)+sizeof(uint64_t)+sizeof(uint64_t);
int64_t * src_symmetric_comm_offset =
(int64_t*)(src_symmetric_comm_structure);
int64_t * src_symmetric_comm_length =
((int64_t*)src_symmetric_comm_offset)+sizeof(int64_t);
uint8_t * src_symmetric_comm_buffer =
((uint8_t*)src_symmetric_comm_length)+sizeof(int64_t);
uint64_t * src_recv_signal =
ctx->recv_signal+src_pe;
uint64_t * my_recv_signal =
ctx->recv_signal+ctx->pe;
int64_t total_loop_count = 0;