ggml : sync latest ggml_mul_mat_id
This commit is contained in:
parent
a3eefe95a8
commit
861cd67899
4 changed files with 114 additions and 75 deletions
73
ggml-cuda.cu
73
ggml-cuda.cu
|
@ -1,13 +1,15 @@
|
|||
#include <algorithm>
|
||||
#include <assert.h>
|
||||
#include <atomic>
|
||||
#include <cinttypes>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cinttypes>
|
||||
#include <float.h>
|
||||
#include <limits>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <atomic>
|
||||
#include <assert.h>
|
||||
#include <vector>
|
||||
|
||||
|
||||
#if defined(GGML_USE_HIPBLAS)
|
||||
#include <hip/hip_runtime.h>
|
||||
|
@ -8234,36 +8236,51 @@ static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
|
|||
}
|
||||
#endif
|
||||
|
||||
static void ggml_cuda_mul_mat_id(const ggml_tensor * _src0, const ggml_tensor * _src1, ggml_tensor * dst) {
|
||||
static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
#if 0
|
||||
//#ifdef CUDA_USE_TENSOR_CORES
|
||||
// const bool use_tensor_cores = true;
|
||||
//#else
|
||||
// const bool use_tensor_cores = false;
|
||||
//#endif
|
||||
|
||||
ggml_cuda_mul_mat_id_cublas(dst);
|
||||
|
||||
// TODO: mmq/mmv support
|
||||
#else
|
||||
const struct ggml_tensor * ids = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
const int id = dst->op_params[0];
|
||||
|
||||
int32_t * ids_dev = (int32_t *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
|
||||
|
||||
int32_t a_id;
|
||||
CUDA_CHECK(cudaMemcpyAsync(&a_id, ids_dev + id, sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
|
||||
CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
|
||||
|
||||
GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
|
||||
const struct ggml_tensor * src0 = dst->src[a_id + 2];
|
||||
|
||||
ggml_cuda_mul_mat(src0, src1, dst);
|
||||
#endif
|
||||
|
||||
(void) _src0;
|
||||
(void) _src1;
|
||||
const struct ggml_tensor * ids = src0;
|
||||
const int32_t id = dst->op_params[0];
|
||||
const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
|
||||
|
||||
std::vector<char> ids_host(ggml_nbytes(ids));
|
||||
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
|
||||
CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
|
||||
|
||||
const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
|
||||
const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
|
||||
|
||||
ggml_tensor_extra_gpu src1_row_extra;
|
||||
ggml_tensor_extra_gpu dst_row_extra;
|
||||
|
||||
ggml_tensor src1_row = *src1;
|
||||
ggml_tensor dst_row = *dst;
|
||||
|
||||
src1_row.ne[1] = 1;
|
||||
dst_row.ne[1] = 1;
|
||||
|
||||
src1_row.extra = &src1_row_extra;
|
||||
dst_row.extra = &dst_row_extra;
|
||||
|
||||
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
|
||||
//int32_t row_id;
|
||||
//CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
|
||||
//CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
|
||||
|
||||
const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
|
||||
|
||||
GGML_ASSERT(row_id >= 0 && row_id < ids->ne[0]);
|
||||
|
||||
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
|
||||
|
||||
src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
|
||||
dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
|
||||
|
||||
ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue