cleaned up pointer arithmetic; rm'd a member variable of the oshmem context struct

This commit is contained in:
ct-clmsn 2023-12-21 19:35:33 -05:00
parent c8d67705fe
commit d05fcad5d1

View file

@ -20,7 +20,6 @@ struct ggml_openshmem_context {
int64_t symmetric_buffer_size;
int64_t symmetric_comm_structure_size;
uint8_t * symmetric_comm_structure;
uint64_t * recv_signal;
};
void ggml_openshmem_backend_init(void) {
@ -74,7 +73,7 @@ void ggml_openshmem_eval_init(
UNUSED(ctx);
uint8_t * dst_symmetric_comm_structure =
((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe);
((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe)+sizeof(uint64_t)+sizeof(uint64_t);
int64_t * dst_symmetric_comm_offset =
(int64_t*)(dst_symmetric_comm_structure);
@ -123,12 +122,12 @@ static void ggml_openshmem_tensor_send(struct ggml_openshmem_context * ctx, stru
ctx->symmetric_comm_structure_size;
uint64_t * my_recv_signal =
((uint64_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe);
((uint64_t*)ctx->symmetric_comm_structure)+(symmetric_comm_structure_size*ctx->pe);
uint64_t * dst_recv_signal =
((uint64_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe)+sizeof(uint64_t);
((uint64_t*)my_recv_signal)+sizeof(uint64_t);
uint8_t * dst_symmetric_comm_structure =
((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe)+sizeof(uint64_t)+sizeof(uint64_t);
((uint8_t*)dst_recv_signal)+sizeof(uint64_t);
int64_t * dst_symmetric_comm_offset =
(int64_t*)(dst_symmetric_comm_structure);
int64_t * dst_symmetric_comm_length =
@ -225,10 +224,10 @@ static void ggml_openshmem_tensor_recv(struct ggml_openshmem_context * ctx, stru
uint64_t * src_recv_signal =
((uint64_t*)ctx->symmetric_comm_structure)+(symmetric_comm_structure_size*src_pe);
uint64_t * my_recv_signal =
((uint64_t*)ctx->symmetric_comm_structure)+(symmetric_comm_structure_size*src_pe)+sizeof(uint64_t);
((uint64_t*)src_recv_signal)+sizeof(uint64_t);
uint8_t * src_symmetric_comm_structure =
((uint8_t*)ctx->symmetric_comm_structure)+(symmetric_comm_structure_size*src_pe)+sizeof(uint64_t)+sizeof(uint64_t);
((uint8_t*)my_recv_signal)+sizeof(uint64_t);
int64_t * src_symmetric_comm_offset =
(int64_t*)(src_symmetric_comm_structure);
int64_t * src_symmetric_comm_length =