fix flake8

This commit is contained in:
Johannes Gäßler 2024-05-29 17:09:25 +02:00
parent 84d9277fe2
commit 61d44b0089

View file

@ -20,9 +20,11 @@ SOURCE_FATTN_WMMA_START = """// This file has been autogenerated by generate_cu_
SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n" SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n"
def get_short_name(long_quant_name): def get_short_name(long_quant_name):
return long_quant_name.replace("GGML_TYPE_", "").lower() return long_quant_name.replace("GGML_TYPE_", "").lower()
def get_head_sizes(type_k, type_v): def get_head_sizes(type_k, type_v):
if type_k == "GGML_TYPE_F16" and type_v == "GGML_TYPE_F16": if type_k == "GGML_TYPE_F16" and type_v == "GGML_TYPE_F16":
return [64, 128, 256] return [64, 128, 256]
@ -30,6 +32,7 @@ def get_head_sizes(type_k, type_v):
return [64, 128] return [64, 128]
return [128] return [128]
for filename in glob("*.cu"): for filename in glob("*.cu"):
os.remove(filename) os.remove(filename)