ggml : change ggml_scale to take a float instead of tensor (#4573)
* ggml : change ggml_scale to take a float instead of tensor * ggml : fix CPU implementation * tests : fix test-grad0 ggml-ci
This commit is contained in:
parent
769a7bc85e
commit
afefa319f1
12 changed files with 82 additions and 205 deletions
14
ggml-cuda.cu
14
ggml-cuda.cu
|
@ -7700,17 +7700,9 @@ inline void ggml_cuda_op_scale(
|
|||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
float scale;
|
||||
// HACK: support for ggml backend interface
|
||||
if (src1->backend == GGML_BACKEND_CPU) {
|
||||
scale = ((float *) src1->data)[0];
|
||||
} else {
|
||||
// TODO: pass pointer to kernel instead of copying to host
|
||||
CUDA_CHECK(cudaMemcpy(&scale, src1->data, sizeof(float), cudaMemcpyDeviceToHost));
|
||||
}
|
||||
const float scale = ((float *) dst->op_params)[0];
|
||||
|
||||
scale_f32_cuda(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
@ -7757,8 +7749,6 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
|
|||
const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU;
|
||||
const bool dst_on_device = dst->backend == GGML_BACKEND_GPU;
|
||||
|
||||
const bool src1_stays_on_host = use_src1 && dst->op == GGML_OP_SCALE;
|
||||
|
||||
// dd = data device
|
||||
float * src0_ddf = nullptr;
|
||||
float * src1_ddf = nullptr;
|
||||
|
@ -7779,7 +7769,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
|
|||
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
|
||||
}
|
||||
|
||||
if (use_src1 && !src1_stays_on_host) {
|
||||
if (use_src1) {
|
||||
if (src1_on_device) {
|
||||
src1_ddf = (float *) src1_extra->data_device[g_main_device];
|
||||
} else {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue