Reduced the number of shmem_calloc calls by storing the receive signals inside the symmetric communication structure
This commit is contained in:
parent
3f2769bf26
commit
c8d67705fe
1 changed file with 17 additions and 17 deletions
|
@@ -51,20 +51,14 @@ struct ggml_openshmem_context * ggml_openshmem_init(void) {
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
ctx->symmetric_buffer_size = OPENSHMEM_SYMMETRIC_BUFFER_SIZE;
|
ctx->symmetric_buffer_size = OPENSHMEM_SYMMETRIC_BUFFER_SIZE;
|
||||||
ctx->symmetric_comm_structure_size = OPENSHMEM_SYMMETRIC_BUFFER_SIZE + sizeof(int64_t) + sizeof(int64_t);
|
ctx->symmetric_comm_structure_size = OPENSHMEM_SYMMETRIC_BUFFER_SIZE + sizeof(int64_t) + sizeof(int64_t) + sizeof(uint64_t) + sizeof(uint64_t);
|
||||||
ctx->symmetric_comm_structure = (uint8_t*)shmem_calloc(1, ctx->n_pes*ctx->symmetric_comm_structure_size);
|
ctx->symmetric_comm_structure = (uint8_t*)shmem_calloc(1, ctx->n_pes*ctx->symmetric_comm_structure_size);
|
||||||
|
|
||||||
/*
|
|
||||||
* uint8_t signal_byte[shmem_npes()];
|
|
||||||
*/
|
|
||||||
ctx->recv_signal = (uint64_t*)shmem_calloc(1, ctx->n_pes*sizeof(uint64_t));
|
|
||||||
|
|
||||||
return ctx;
|
return ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_openshmem_free(struct ggml_openshmem_context * ctx) {
|
void ggml_openshmem_free(struct ggml_openshmem_context * ctx) {
|
||||||
shmem_free(ctx->symmetric_comm_structure);
|
shmem_free(ctx->symmetric_comm_structure);
|
||||||
shmem_free(ctx->recv_signal);
|
|
||||||
free(ctx);
|
free(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -127,18 +121,20 @@ static void ggml_openshmem_tensor_send(struct ggml_openshmem_context * ctx, stru
|
||||||
|
|
||||||
const int64_t symmetric_comm_structure_size =
|
const int64_t symmetric_comm_structure_size =
|
||||||
ctx->symmetric_comm_structure_size;
|
ctx->symmetric_comm_structure_size;
|
||||||
|
|
||||||
|
uint64_t * my_recv_signal =
|
||||||
|
((uint64_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe);
|
||||||
|
uint64_t * dst_recv_signal =
|
||||||
|
((uint64_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe)+sizeof(uint64_t);
|
||||||
|
|
||||||
uint8_t * dst_symmetric_comm_structure =
|
uint8_t * dst_symmetric_comm_structure =
|
||||||
((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe);
|
((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe)+sizeof(uint64_t)+sizeof(uint64_t);
|
||||||
int64_t * dst_symmetric_comm_offset =
|
int64_t * dst_symmetric_comm_offset =
|
||||||
(int64_t*)(dst_symmetric_comm_structure);
|
(int64_t*)(dst_symmetric_comm_structure);
|
||||||
int64_t * dst_symmetric_comm_length =
|
int64_t * dst_symmetric_comm_length =
|
||||||
((int64_t*)dst_symmetric_comm_offset)+sizeof(int64_t);
|
((int64_t*)dst_symmetric_comm_offset)+sizeof(int64_t);
|
||||||
uint8_t * dst_symmetric_comm_buffer =
|
uint8_t * dst_symmetric_comm_buffer =
|
||||||
((uint8_t*)dst_symmetric_comm_length)+sizeof(int64_t);
|
((uint8_t*)dst_symmetric_comm_length)+sizeof(int64_t);
|
||||||
uint64_t * dst_recv_signal =
|
|
||||||
ctx->recv_signal+dst_pe;
|
|
||||||
uint64_t * my_recv_signal =
|
|
||||||
ctx->recv_signal+ctx->pe;
|
|
||||||
|
|
||||||
const int64_t nelements = ggml_nelements(t);
|
const int64_t nelements = ggml_nelements(t);
|
||||||
int64_t xmt_size = 0;
|
int64_t xmt_size = 0;
|
||||||
|
@@ -223,18 +219,22 @@ static void ggml_openshmem_tensor_send(struct ggml_openshmem_context * ctx, stru
|
||||||
|
|
||||||
static void ggml_openshmem_tensor_recv(struct ggml_openshmem_context * ctx, struct ggml_tensor * t, int src_pe) {
|
static void ggml_openshmem_tensor_recv(struct ggml_openshmem_context * ctx, struct ggml_tensor * t, int src_pe) {
|
||||||
|
|
||||||
|
const int64_t symmetric_comm_structure_size =
|
||||||
|
ctx->symmetric_comm_structure_size;
|
||||||
|
|
||||||
|
uint64_t * src_recv_signal =
|
||||||
|
((uint64_t*)ctx->symmetric_comm_structure)+(symmetric_comm_structure_size*src_pe);
|
||||||
|
uint64_t * my_recv_signal =
|
||||||
|
((uint64_t*)ctx->symmetric_comm_structure)+(symmetric_comm_structure_size*src_pe)+sizeof(uint64_t);
|
||||||
|
|
||||||
uint8_t * src_symmetric_comm_structure =
|
uint8_t * src_symmetric_comm_structure =
|
||||||
((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*src_pe);
|
((uint8_t*)ctx->symmetric_comm_structure)+(symmetric_comm_structure_size*src_pe)+sizeof(uint64_t)+sizeof(uint64_t);
|
||||||
int64_t * src_symmetric_comm_offset =
|
int64_t * src_symmetric_comm_offset =
|
||||||
(int64_t*)(src_symmetric_comm_structure);
|
(int64_t*)(src_symmetric_comm_structure);
|
||||||
int64_t * src_symmetric_comm_length =
|
int64_t * src_symmetric_comm_length =
|
||||||
((int64_t*)src_symmetric_comm_offset)+sizeof(int64_t);
|
((int64_t*)src_symmetric_comm_offset)+sizeof(int64_t);
|
||||||
uint8_t * src_symmetric_comm_buffer =
|
uint8_t * src_symmetric_comm_buffer =
|
||||||
((uint8_t*)src_symmetric_comm_length)+sizeof(int64_t);
|
((uint8_t*)src_symmetric_comm_length)+sizeof(int64_t);
|
||||||
uint64_t * src_recv_signal =
|
|
||||||
ctx->recv_signal+src_pe;
|
|
||||||
uint64_t * my_recv_signal =
|
|
||||||
ctx->recv_signal+ctx->pe;
|
|
||||||
|
|
||||||
int64_t total_loop_count = 0;
|
int64_t total_loop_count = 0;
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue