vulkan: implement GGML_OP_OPT_STEP_ADAMW
This commit is contained in:
parent
095f8d17ac
commit
9526033b71
3 changed files with 168 additions and 0 deletions
|
@ -259,6 +259,7 @@ struct vk_device_struct {
|
||||||
vk_pipeline pipeline_timestep_embedding_f32;
|
vk_pipeline pipeline_timestep_embedding_f32;
|
||||||
vk_pipeline pipeline_pool2d_f32;
|
vk_pipeline pipeline_pool2d_f32;
|
||||||
vk_pipeline pipeline_rwkv_wkv6_f32;
|
vk_pipeline pipeline_rwkv_wkv6_f32;
|
||||||
|
vk_pipeline pipeline_opt_step_adamw_f32;
|
||||||
|
|
||||||
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
|
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
|
||||||
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
|
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
|
||||||
|
@ -2173,6 +2174,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
||||||
|
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
|
||||||
for (auto &c : compiles) {
|
for (auto &c : compiles) {
|
||||||
c.wait();
|
c.wait();
|
||||||
}
|
}
|
||||||
|
@ -5329,6 +5332,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||||
return ctx->device->pipeline_rwkv_wkv6_f32;
|
return ctx->device->pipeline_rwkv_wkv6_f32;
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
case GGML_OP_OPT_STEP_ADAMW:
|
||||||
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||||
|
return ctx->device->pipeline_opt_step_adamw_f32;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||||
return ctx->device->pipeline_leaky_relu_f32;
|
return ctx->device->pipeline_leaky_relu_f32;
|
||||||
|
@ -5936,6 +5944,111 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
|
||||||
|
const ggml_tensor * x = dst->src[0];
|
||||||
|
const ggml_tensor * g = dst->src[1];
|
||||||
|
const ggml_tensor * gm = dst->src[2];
|
||||||
|
const ggml_tensor * gv = dst->src[3];
|
||||||
|
const ggml_tensor * p = dst->src[4];
|
||||||
|
|
||||||
|
GGML_ASSERT(x->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(g->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(gm->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(gv->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(p->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(dst->buffer != nullptr);
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(x));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(g));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(gm));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(gv));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(p));
|
||||||
|
GGML_ASSERT(ggml_are_same_shape(x, g));
|
||||||
|
GGML_ASSERT(ggml_are_same_shape(x, gm));
|
||||||
|
GGML_ASSERT(ggml_are_same_shape(x, gv));
|
||||||
|
GGML_ASSERT(ggml_nelements(p) == 7);
|
||||||
|
|
||||||
|
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW);
|
||||||
|
GGML_ASSERT(pipeline != nullptr);
|
||||||
|
|
||||||
|
if (dryrun) {
|
||||||
|
ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer->context;
|
||||||
|
ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer->context;
|
||||||
|
ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer->context;
|
||||||
|
ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context;
|
||||||
|
ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context;
|
||||||
|
|
||||||
|
ggml_vk_sync_buffers(subctx);
|
||||||
|
|
||||||
|
vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr;
|
||||||
|
size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0;
|
||||||
|
bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false;
|
||||||
|
|
||||||
|
if (ctx->device->uma) {
|
||||||
|
ggml_vk_host_get(ctx->device, x->data, d_X, x_offset);
|
||||||
|
ggml_vk_host_get(ctx->device, g->data, d_G, g_offset);
|
||||||
|
ggml_vk_host_get(ctx->device, gm->data, d_GM, gm_offset);
|
||||||
|
ggml_vk_host_get(ctx->device, gv->data, d_GV, gv_offset);
|
||||||
|
ggml_vk_host_get(ctx->device, p->data, d_P, p_offset);
|
||||||
|
|
||||||
|
X_uma = d_X != nullptr;
|
||||||
|
G_uma = d_G != nullptr;
|
||||||
|
GM_uma = d_GM != nullptr;
|
||||||
|
GV_uma = d_GV != nullptr;
|
||||||
|
P_uma = d_P != nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!X_uma) {
|
||||||
|
d_X = x_buf_ctx->dev_buffer;
|
||||||
|
x_offset = vk_tensor_offset(x) + x->view_offs;
|
||||||
|
}
|
||||||
|
if (!G_uma) {
|
||||||
|
d_G = g_buf_ctx->dev_buffer;
|
||||||
|
g_offset = vk_tensor_offset(g) + g->view_offs;
|
||||||
|
}
|
||||||
|
if (!GM_uma) {
|
||||||
|
d_GM = gm_buf_ctx->dev_buffer;
|
||||||
|
gm_offset = vk_tensor_offset(gm) + gm->view_offs;
|
||||||
|
}
|
||||||
|
if (!GV_uma) {
|
||||||
|
d_GV = gv_buf_ctx->dev_buffer;
|
||||||
|
gv_offset = vk_tensor_offset(gv) + gv->view_offs;
|
||||||
|
}
|
||||||
|
if (!P_uma) {
|
||||||
|
d_P = p_buf_ctx->dev_buffer;
|
||||||
|
p_offset = vk_tensor_offset(p) + p->view_offs;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint64_t x_size = ggml_nbytes(x);
|
||||||
|
const uint64_t g_size = ggml_nbytes(g);
|
||||||
|
const uint64_t gm_size = ggml_nbytes(gm);
|
||||||
|
const uint64_t gv_size = ggml_nbytes(gv);
|
||||||
|
const uint64_t p_size = ggml_nbytes(p);
|
||||||
|
|
||||||
|
std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(x), 1, 1 };
|
||||||
|
|
||||||
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
|
||||||
|
vk_subbuffer{ d_X, x_offset, x_size },
|
||||||
|
vk_subbuffer{ d_G, g_offset, g_size },
|
||||||
|
vk_subbuffer{ d_GM, gm_offset, gm_size },
|
||||||
|
vk_subbuffer{ d_GV, gv_offset, gv_size },
|
||||||
|
vk_subbuffer{ d_P, p_offset, p_size },
|
||||||
|
}, sizeof(vk_op_push_constants), &pc, elements);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
|
||||||
|
const size_t n = ggml_nelements(dst->src[0]);
|
||||||
|
|
||||||
|
ggml_vk_op_f32_opt_step_adamw(
|
||||||
|
ctx, subctx, dst,
|
||||||
|
{ (uint32_t)n, 0, 0.0f, 0.0f },
|
||||||
|
dryrun
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||||
int * op_params = (int *)dst->op_params;
|
int * op_params = (int *)dst->op_params;
|
||||||
|
|
||||||
|
@ -7100,6 +7213,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||||
case GGML_OP_RWKV_WKV6:
|
case GGML_OP_RWKV_WKV6:
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
|
case GGML_OP_OPT_STEP_ADAMW:
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
|
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
|
||||||
|
@ -7322,6 +7436,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||||
case GGML_OP_RWKV_WKV6:
|
case GGML_OP_RWKV_WKV6:
|
||||||
ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
|
ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
|
||||||
|
|
||||||
|
break;
|
||||||
|
|
||||||
|
case GGML_OP_OPT_STEP_ADAMW:
|
||||||
|
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
|
||||||
|
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
|
@ -7409,6 +7528,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
||||||
case GGML_OP_RWKV_WKV6:
|
case GGML_OP_RWKV_WKV6:
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
case GGML_OP_REPEAT:
|
case GGML_OP_REPEAT:
|
||||||
|
case GGML_OP_OPT_STEP_ADAMW:
|
||||||
buf = tensor->buffer;
|
buf = tensor->buffer;
|
||||||
|
|
||||||
break;
|
break;
|
||||||
|
@ -8346,6 +8466,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
case GGML_OP_POOL_2D:
|
case GGML_OP_POOL_2D:
|
||||||
case GGML_OP_RWKV_WKV6:
|
case GGML_OP_RWKV_WKV6:
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
|
case GGML_OP_OPT_STEP_ADAMW:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
|
@ -8951,6 +9072,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||||
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
|
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
|
||||||
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
|
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
|
||||||
tensor->src[4], tensor->src[5]);
|
tensor->src[4], tensor->src[5]);
|
||||||
|
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
|
||||||
|
tensor_clone = ggml_opt_step_adamw(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2],
|
||||||
|
tensor->src[3], tensor->src[4]);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
||||||
|
|
42
ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp
Normal file
42
ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#include "generic_head.comp"
|
||||||
|
#include "types.comp"
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
|
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (binding = 0) buffer X {A_TYPE x[];};
|
||||||
|
layout (binding = 1) readonly buffer G {A_TYPE grad[];};
|
||||||
|
layout (binding = 2) buffer GM {A_TYPE gradm[];};
|
||||||
|
layout (binding = 3) buffer GV {A_TYPE gradv[];};
|
||||||
|
layout (binding = 4) readonly buffer P {float params[7];};
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||||
|
|
||||||
|
if (i >= p.KX) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float alpha = params[0];
|
||||||
|
const float beta1 = params[1];
|
||||||
|
const float beta2 = params[2];
|
||||||
|
const float eps = params[3];
|
||||||
|
const float wd = params[4];
|
||||||
|
const float beta1h = params[5];
|
||||||
|
const float beta2h = params[6];
|
||||||
|
|
||||||
|
const float gi = grad[i];
|
||||||
|
const float gmi = gradm[i]*beta1 + gi*(1.0f - beta1);
|
||||||
|
const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2);
|
||||||
|
|
||||||
|
gradm[i] = gmi;
|
||||||
|
gradv[i] = gvi;
|
||||||
|
|
||||||
|
const float mh = gmi*beta1h;
|
||||||
|
const float vh = sqrt(gvi*beta2h) + eps;
|
||||||
|
|
||||||
|
x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
|
||||||
|
}
|
|
@ -500,6 +500,8 @@ void process_shaders() {
|
||||||
|
|
||||||
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||||
|
|
||||||
|
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||||
|
|
||||||
for (auto &c : compiles) {
|
for (auto &c : compiles) {
|
||||||
c.wait();
|
c.wait();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue