kompute : fix merge issues

This commit is contained in:
Jared Van Bortel 2024-01-29 12:41:02 -05:00
parent da1dc66659
commit dc08e512cc
2 changed files with 56 additions and 30 deletions

View file

@ -118,9 +118,9 @@ static void enable_sam() {
}
#endif
static bool ggml_vk_checkPhysicalDeviceFeatures(vk::PhysicalDevice physicalDevice) {
static bool ggml_vk_checkPhysicalDeviceFeatures(vk::PhysicalDevice physical_device) {
vk::PhysicalDeviceFeatures availableFeatures;
physicalDevice.getFeatures(&availableFeatures);
physical_device.getFeatures(&availableFeatures);
if (!availableFeatures.shaderInt16)
return false;
@ -134,7 +134,7 @@ static bool ggml_vk_checkPhysicalDeviceFeatures(vk::PhysicalDevice physicalDevic
vk::PhysicalDeviceFeatures2 features2;
features2.pNext = &availableFeatures11;
physicalDevice.getFeatures2(&features2);
physical_device.getFeatures2(&features2);
if (!availableFeatures11.uniformAndStorageBuffer16BitAccess ||
!availableFeatures11.storageBuffer16BitAccess) {
@ -169,29 +169,31 @@ static std::vector<ggml_vk_device> ggml_vk_available_devices_internal(size_t mem
if (!komputeManager()->hasVulkan() || !komputeManager()->hasInstance())
return results;
std::vector<vk::PhysicalDevice> physicalDevices;
std::vector<vk::PhysicalDevice> physical_devices;
try {
physicalDevices = komputeManager()->listDevices();
physical_devices = komputeManager()->listDevices();
} catch (vk::SystemError & err) {
std::cerr << __func__ << ": ignoring Vulkan exception: " << err.what() << "\n";
return results;
}
uint32_t deviceCount = physicalDevices.size();
uint32_t deviceCount = physical_devices.size();
if (deviceCount == 0)
return results;
std::unordered_map<std::string, size_t> count_by_name;
for (uint32_t i = 0; i < deviceCount; i++) {
VkPhysicalDeviceProperties properties = physicalDevices.at(i).getProperties();
VkPhysicalDeviceMemoryProperties memoryProperties = physicalDevices.at(i).getMemoryProperties();
const uint32_t major = VK_VERSION_MAJOR(properties.apiVersion);
const uint32_t minor = VK_VERSION_MINOR(properties.apiVersion);
const auto & physical_device = physical_devices[i];
VkPhysicalDeviceProperties dev_props = physical_device.getProperties();
VkPhysicalDeviceMemoryProperties memoryProperties = physical_device.getMemoryProperties();
const uint32_t major = VK_VERSION_MAJOR(dev_props.apiVersion);
const uint32_t minor = VK_VERSION_MINOR(dev_props.apiVersion);
if (major < 1 || minor < 2)
continue;
if (!ggml_vk_checkPhysicalDeviceFeatures(physicalDevices.at(i)))
if (!ggml_vk_checkPhysicalDeviceFeatures(physical_device))
continue;
size_t heapSize = 0;
@ -206,23 +208,45 @@ static std::vector<ggml_vk_device> ggml_vk_available_devices_internal(size_t mem
if (heapSize < memoryRequired)
continue;
vk::PhysicalDeviceSubgroupProperties subgroupProperties;
vk::PhysicalDeviceProperties2 deviceProperties2;
deviceProperties2.pNext = &subgroupProperties;
physicalDevices.at(i).getProperties2(&deviceProperties2);
auto ext_props = physical_device.enumerateDeviceExtensionProperties();
bool has_maintenance4 = false;
if (subgroupProperties.subgroupSize < 32)
// Check if maintenance4 is supported
for (const auto & properties : ext_props) {
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
has_maintenance4 = true;
}
}
vk::PhysicalDeviceSubgroupProperties subgroup_props;
vk::PhysicalDeviceProperties2 dev_props2;
vk::PhysicalDeviceMaintenance3Properties dev_props3;
vk::PhysicalDeviceMaintenance4Properties dev_props4;
dev_props2.pNext = &dev_props3;
dev_props3.pNext = &subgroup_props;
if (has_maintenance4) {
subgroup_props.pNext = &dev_props4;
}
physical_device.getProperties2(&dev_props2);
if (subgroup_props.subgroupSize < 32)
continue;
ggml_vk_device d;
d.index = i;
d.type = properties.deviceType;
d.type = dev_props.deviceType;
d.heapSize = heapSize;
d.vendor = strdup(ggml_vk_getVendorName(properties.vendorID));
d.subgroupSize = subgroupProperties.subgroupSize;
d.bufferAlignment = properties.limits.minStorageBufferOffsetAlignment;
d.vendor = strdup(ggml_vk_getVendorName(dev_props.vendorID));
d.subgroupSize = subgroup_props.subgroupSize;
d.bufferAlignment = dev_props.limits.minStorageBufferOffsetAlignment;
std::string name(properties.deviceName);
if (has_maintenance4) {
d.maxAlloc = std::min(dev_props3.maxMemoryAllocationSize, dev_props4.maxBufferSize);
} else {
d.maxAlloc = dev_props3.maxMemoryAllocationSize;
}
std::string name(dev_props.deviceName);
size_t n_idx = ++count_by_name[name];
if (n_idx > 1) {
name += " (" + std::to_string(n_idx) + ")";
@ -413,12 +437,6 @@ vk::DeviceMemory *ggml_vk_allocate(size_t size, vk::MemoryPropertyFlags flags, v
static size_t ggml_vk_aligned_offset(ggml_backend_buffer_t buffer, size_t offset) {
size_t minStorageBufferOffsetAlignment = ggml_backend_buffer_get_alignment(buffer);
if (minStorageBufferOffsetAlignment == 0) {
vk::PhysicalDeviceProperties deviceProperties;
deviceProperties = komputeManager()->physicalDevice()->getProperties();
vk::PhysicalDeviceLimits deviceLimits = deviceProperties.limits;
minStorageBufferOffsetAlignment = deviceLimits.minStorageBufferOffsetAlignment;
}
// If offset is already aligned, return it directly
if (offset % minStorageBufferOffsetAlignment == 0) {
@ -1731,10 +1749,11 @@ struct ggml_backend_kompute_buffer_type_context {
int device;
int device_ref = 0;
uint64_t buffer_alignment;
uint64_t max_alloc;
std::string name;
ggml_backend_kompute_buffer_type_context(int device, uint64_t buffer_alignment)
: device(device), buffer_alignment(buffer_alignment), name(ggml_kompute_format_name(device)) {}
ggml_backend_kompute_buffer_type_context(int device, uint64_t buffer_alignment, uint64_t max_alloc)
: device(device), buffer_alignment(buffer_alignment), max_alloc(max_alloc), name(ggml_kompute_format_name(device)) {}
};
static void ggml_backend_kompute_device_ref(ggml_backend_buffer_type_t buft) {
@ -1842,6 +1861,11 @@ static size_t ggml_backend_kompute_buffer_type_get_alignment(ggml_backend_buffer
return ctx->buffer_alignment;
}
// Buffer-type callback: reports the largest allocation size for this buffer
// type, i.e. the max_alloc value recorded in the buffer type's context when
// the context was constructed.
static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
    // buft->context holds the ggml_backend_kompute_buffer_type_context for this buffer type.
    return static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context)->max_alloc;
}
static bool ggml_backend_kompute_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
GGML_UNUSED(buft);
return ggml_backend_is_kompute(backend);
@ -1851,6 +1875,7 @@ static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = {
/* .get_name = */ ggml_backend_kompute_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_kompute_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_kompute_buffer_type_get_alignment,
/* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size,
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
/* .supports_backend = */ ggml_backend_kompute_buffer_type_supports_backend,
/* .is_host = */ NULL,
@ -1865,7 +1890,7 @@ ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
for (const auto & dev : devices) {
vec.push_back({
/* .iface = */ ggml_backend_kompute_buffer_type_interface,
/* .context = */ new ggml_backend_kompute_buffer_type_context(dev.index, dev.bufferAlignment)
/* .context = */ new ggml_backend_kompute_buffer_type_context(dev.index, dev.bufferAlignment, dev.maxAlloc)
});
}
return vec;

View file

@ -19,6 +19,7 @@ struct ggml_vk_device {
const char * vendor;
int subgroupSize;
uint64_t bufferAlignment;
uint64_t maxAlloc;
};
struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count);