Add support for multithread LHS conversion
This commit is contained in:
parent
119d3bf986
commit
f4eb1b3854
2 changed files with 26 additions and 19 deletions
|
@ -114,30 +114,37 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|||
size_t sr = kernel->get_sr();
|
||||
size_t bl = k_q4_0_block_size;
|
||||
|
||||
const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, bl, mr, kr, sr);
|
||||
// Calculate number of columns to be processed per thread
|
||||
const size_t num_m_per_thread = kai_roundup(m, nth) / nth;
|
||||
const size_t m_start = ith * num_m_per_thread;
|
||||
size_t m_to_process = num_m_per_thread;
|
||||
if ((m_start + m_to_process) > m) {
|
||||
m_to_process = m - m_start;
|
||||
}
|
||||
|
||||
if (ith == 0) {
|
||||
if(m_start < m) {
|
||||
// Transform LHS
|
||||
const size_t src_stride = src1->nb[1];
|
||||
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1]));
|
||||
void * dst_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
|
||||
const size_t src_stride = src1->nb[1];
|
||||
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1]));
|
||||
const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, bl, mr, kr, sr);
|
||||
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
|
||||
|
||||
lhs_info->pack_func(m, k, bl, mr, kr, sr, 0, src_ptr, src_stride, dst_ptr);
|
||||
lhs_info->pack_func(m_to_process, k, bl, mr, kr, sr, m_start, src_ptr, src_stride, lhs_packed_ptr);
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
// Perform the operation
|
||||
const size_t dst_stride = dst->nb[1];
|
||||
|
||||
// Perform the operation
|
||||
const size_t dst_stride = dst->nb[1];
|
||||
const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, k_q4_0_block_size, mr, kr, sr);
|
||||
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, k_q4_0_block_size);
|
||||
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
|
||||
|
||||
const void * lhs_ptr = static_cast<const void *>(lhs_packed + lhs_packed_offset);
|
||||
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
|
||||
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
|
||||
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
|
||||
const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
|
||||
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
|
||||
|
||||
kernel->run_kernel(m, n_to_process, k, k_q4_0_block_size, lhs_ptr, rhs_ptr, dst_ptr,
|
||||
dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
|
||||
dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue