ggml/kompute: Move butf into struct ggml_backend_kompute_context

Signed-off-by: Weishi Li <liweishi@kylinos.cn>
This commit is contained in:
Weishi Li 2024-08-21 14:26:51 +08:00 committed by Feng Jiang
parent e914ac7c68
commit 74ba8516ce

View file

@ -71,8 +71,10 @@ struct ggml_backend_kompute_context {
std::string name;
std::shared_ptr<vk::DescriptorPool> pool;
ggml_backend_buffer_type buft;
ggml_backend_kompute_context(int device)
: device(device), name(ggml_kompute_format_name(device)) {}
: device(device), name(ggml_kompute_format_name(device)) { buft.context = nullptr; }
};
// FIXME: It would be good to consolidate the kompute manager and the kompute context into one object
@ -1918,24 +1920,25 @@ static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = {
};
ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
static std::vector<ggml_backend_buffer_type> bufts = []() {
std::vector<ggml_backend_buffer_type> vec;
if (!s_kompute_context)
s_kompute_context = new ggml_backend_kompute_context(device);
auto * buft = &s_kompute_context->buft;
if (!buft->context) {
auto devices = ggml_vk_available_devices_internal(0);
vec.reserve(devices.size());
for (const auto & dev : devices) {
vec.push_back({
/* .iface = */ ggml_backend_kompute_buffer_type_interface,
/* .context = */ new ggml_backend_kompute_buffer_type_context(dev.index, dev.bufferAlignment, dev.maxAlloc)
});
for (std::size_t i = 0; i < devices.size(); i++) {
if (device == devices[i].index) {
buft->context = new ggml_backend_kompute_buffer_type_context(
devices[i].index,
devices[i].bufferAlignment,
devices[i].maxAlloc);
buft->iface = ggml_backend_kompute_buffer_type_interface;
break;
}
}
return vec;
}();
}
auto it = std::find_if(bufts.begin(), bufts.end(), [device](const ggml_backend_buffer_type & t) {
return device == static_cast<ggml_backend_kompute_buffer_type_context *>(t.context)->device;
});
return it < bufts.end() ? &*it : nullptr;
return buft;
}
// backend
@ -1974,8 +1977,8 @@ static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struc
}
static bool ggml_backend_kompute_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
GGML_UNUSED(backend);
return buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name;
auto *ctx = static_cast<ggml_backend_kompute_context *>(backend->context);
return &ctx->buft == buft;
}
static struct ggml_backend_i kompute_backend_i = {
@ -2007,8 +2010,8 @@ static ggml_guid_t ggml_backend_kompute_guid() {
}
ggml_backend_t ggml_backend_kompute_init(int device) {
GGML_ASSERT(s_kompute_context == nullptr);
s_kompute_context = new ggml_backend_kompute_context(device);
if (!s_kompute_context)
s_kompute_context = new ggml_backend_kompute_context(device);
ggml_backend_t kompute_backend = new ggml_backend {
/* .guid = */ ggml_backend_kompute_guid(),