vulkan: implement GGML_OP_OPT_STEP_ADAMW
This commit is contained in:
parent
095f8d17ac
commit
9526033b71
3 changed files with 168 additions and 0 deletions
|
@ -259,6 +259,7 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_timestep_embedding_f32;
|
||||
vk_pipeline pipeline_pool2d_f32;
|
||||
vk_pipeline pipeline_rwkv_wkv6_f32;
|
||||
vk_pipeline pipeline_opt_step_adamw_f32;
|
||||
|
||||
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
|
||||
|
@ -2173,6 +2174,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
for (auto &c : compiles) {
|
||||
c.wait();
|
||||
}
|
||||
|
@ -5329,6 +5332,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
return ctx->device->pipeline_rwkv_wkv6_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_opt_step_adamw_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_leaky_relu_f32;
|
||||
|
@ -5936,6 +5944,111 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|||
);
|
||||
}
|
||||
|
||||
// Record (or dry-run) a GGML_OP_OPT_STEP_ADAMW dispatch.
//
// dst carries five sources (all contiguous F32; shapes checked below):
//   src[0] x  -- parameters, updated in place
//   src[1] g  -- gradients (read-only in the shader)
//   src[2] gm -- first-moment accumulator m, updated in place
//   src[3] gv -- second-moment accumulator v, updated in place
//   src[4] p  -- exactly 7 floats of hyper-parameters, consumed in the order
//                expected by opt_step_adamw.comp
//
// pc.KX is expected to hold the element count of x; the shader runs one
// invocation per element (workgroup size {512,1,1} set at pipeline creation).
static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
    const ggml_tensor * x = dst->src[0];
    const ggml_tensor * g = dst->src[1];
    const ggml_tensor * gm = dst->src[2];
    const ggml_tensor * gv = dst->src[3];
    const ggml_tensor * p = dst->src[4];

    GGML_ASSERT(x->type == GGML_TYPE_F32);
    GGML_ASSERT(g->type == GGML_TYPE_F32);
    GGML_ASSERT(gm->type == GGML_TYPE_F32);
    GGML_ASSERT(gv->type == GGML_TYPE_F32);
    GGML_ASSERT(p->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->buffer != nullptr);
    GGML_ASSERT(ggml_is_contiguous(x));
    GGML_ASSERT(ggml_is_contiguous(g));
    GGML_ASSERT(ggml_is_contiguous(gm));
    GGML_ASSERT(ggml_is_contiguous(gv));
    GGML_ASSERT(ggml_is_contiguous(p));
    GGML_ASSERT(ggml_are_same_shape(x, g));
    GGML_ASSERT(ggml_are_same_shape(x, gm));
    GGML_ASSERT(ggml_are_same_shape(x, gv));
    GGML_ASSERT(ggml_nelements(p) == 7);

    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW);
    GGML_ASSERT(pipeline != nullptr);

    // Dry-run pass: only register how many descriptor sets this dispatch
    // will need; no command is recorded.
    if (dryrun) {
        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
        return;
    }

    ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer->context;
    ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer->context;
    ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer->context;
    ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context;
    ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context;

    ggml_vk_sync_buffers(subctx);

    vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr;
    size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0;
    bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false;

    // On UMA devices, try to resolve each tensor's device buffer directly
    // from its host pointer first; any tensor not found falls back to its
    // backend buffer below.
    if (ctx->device->uma) {
        ggml_vk_host_get(ctx->device, x->data, d_X, x_offset);
        ggml_vk_host_get(ctx->device, g->data, d_G, g_offset);
        ggml_vk_host_get(ctx->device, gm->data, d_GM, gm_offset);
        ggml_vk_host_get(ctx->device, gv->data, d_GV, gv_offset);
        ggml_vk_host_get(ctx->device, p->data, d_P, p_offset);

        X_uma = d_X != nullptr;
        G_uma = d_G != nullptr;
        GM_uma = d_GM != nullptr;
        GV_uma = d_GV != nullptr;
        P_uma = d_P != nullptr;
    }

    // Non-UMA (or UMA lookup miss): use the tensor's device buffer and its
    // offset within that buffer (including any view offset).
    if (!X_uma) {
        d_X = x_buf_ctx->dev_buffer;
        x_offset = vk_tensor_offset(x) + x->view_offs;
    }
    if (!G_uma) {
        d_G = g_buf_ctx->dev_buffer;
        g_offset = vk_tensor_offset(g) + g->view_offs;
    }
    if (!GM_uma) {
        d_GM = gm_buf_ctx->dev_buffer;
        gm_offset = vk_tensor_offset(gm) + gm->view_offs;
    }
    if (!GV_uma) {
        d_GV = gv_buf_ctx->dev_buffer;
        gv_offset = vk_tensor_offset(gv) + gv->view_offs;
    }
    if (!P_uma) {
        d_P = p_buf_ctx->dev_buffer;
        p_offset = vk_tensor_offset(p) + p->view_offs;
    }

    const uint64_t x_size = ggml_nbytes(x);
    const uint64_t g_size = ggml_nbytes(g);
    const uint64_t gm_size = ggml_nbytes(gm);
    const uint64_t gv_size = ggml_nbytes(gv);
    const uint64_t p_size = ggml_nbytes(p);

    // One invocation per element of x; y/z stay 1.
    std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(x), 1, 1 };

    // Binding order must match the shader: 0=x, 1=grad, 2=gradm, 3=gradv, 4=params.
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
        vk_subbuffer{ d_X, x_offset, x_size },
        vk_subbuffer{ d_G, g_offset, g_size },
        vk_subbuffer{ d_GM, gm_offset, gm_size },
        vk_subbuffer{ d_GV, gv_offset, gv_size },
        vk_subbuffer{ d_P, p_offset, p_size },
    }, sizeof(vk_op_push_constants), &pc, elements);
}
|
||||
|
||||
// Entry point for GGML_OP_OPT_STEP_ADAMW: forwards to the f32 dispatch
// helper. Only KX (the element count of the parameter tensor, src[0]) is
// meaningful in the push constants; the remaining fields stay zero.
static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
    const uint32_t n_elements = (uint32_t)ggml_nelements(dst->src[0]);

    ggml_vk_op_f32_opt_step_adamw(ctx, subctx, dst, { n_elements, 0, 0.0f, 0.0f }, dryrun);
}
|
||||
|
||||
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
int * op_params = (int *)dst->op_params;
|
||||
|
||||
|
@ -7100,6 +7213,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
break;
|
||||
default:
|
||||
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
|
||||
|
@ -7322,6 +7436,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|||
case GGML_OP_RWKV_WKV6:
|
||||
ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
|
||||
|
||||
break;
|
||||
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
|
||||
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
|
@ -7409,6 +7528,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
buf = tensor->buffer;
|
||||
|
||||
break;
|
||||
|
@ -8346,6 +8466,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
@ -8951,6 +9072,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|||
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
|
||||
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
|
||||
tensor->src[4], tensor->src[5]);
|
||||
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
|
||||
tensor_clone = ggml_opt_step_adamw(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2],
|
||||
tensor->src[3], tensor->src[4]);
|
||||
}
|
||||
else {
|
||||
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
||||
|
|
42
ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp
Normal file
42
ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp
Normal file
|
@ -0,0 +1,42 @@
|
|||
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

// One element of the parameter tensor per invocation; 512 threads per workgroup.
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) buffer X {A_TYPE x[];};               // parameters, updated in place
layout (binding = 1) readonly buffer G {A_TYPE grad[];};   // gradients
layout (binding = 2) buffer GM {A_TYPE gradm[];};          // first moment m, updated in place
layout (binding = 3) buffer GV {A_TYPE gradv[];};          // second moment v, updated in place
layout (binding = 4) readonly buffer P {float params[7];}; // AdamW hyper-parameters

void main() {
    // Linearize the 3D dispatch; 262144 = 512 (local_size_x) * 512 dispatched
    // workgroups per y row, so each z slice covers 512*512 elements.
    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    // p.KX holds the total element count (from vk_op_push_constants); guard
    // the tail invocations of the last workgroup.
    if (i >= p.KX) {
        return;
    }

    // Hyper-parameter layout (fixed order, 7 floats):
    const float alpha  = params[0]; // learning rate
    const float beta1  = params[1]; // first-moment decay
    const float beta2  = params[2]; // second-moment decay
    const float eps    = params[3]; // denominator epsilon
    const float wd     = params[4]; // decoupled weight decay
    const float beta1h = params[5]; // presumably the bias-correction factor for m -- computed host-side, confirm against the CPU ggml_opt_step_adamw
    const float beta2h = params[6]; // presumably the bias-correction factor for v -- computed host-side, confirm against the CPU ggml_opt_step_adamw

    // Exponential moving averages of the gradient and its square.
    const float gi = grad[i];
    const float gmi = gradm[i]*beta1 + gi*(1.0f - beta1);
    const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2);

    gradm[i] = gmi;
    gradv[i] = gvi;

    // Bias-corrected step: m-hat over sqrt(v-hat) + eps.
    const float mh = gmi*beta1h;
    const float vh = sqrt(gvi*beta2h) + eps;

    // Decoupled weight decay (AdamW), then the Adam update.
    x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
}
|
|
@ -500,6 +500,8 @@ void process_shaders() {
|
|||
|
||||
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
for (auto &c : compiles) {
|
||||
c.wait();
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue