fix flake8
This commit is contained in:
parent
84d9277fe2
commit
61d44b0089
1 changed files with 3 additions and 0 deletions
|
@ -20,9 +20,11 @@ SOURCE_FATTN_WMMA_START = """// This file has been autogenerated by generate_cu_
|
||||||
|
|
||||||
SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n"
|
SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n"
|
||||||
|
|
||||||
|
|
||||||
def get_short_name(long_quant_name):
|
def get_short_name(long_quant_name):
|
||||||
return long_quant_name.replace("GGML_TYPE_", "").lower()
|
return long_quant_name.replace("GGML_TYPE_", "").lower()
|
||||||
|
|
||||||
|
|
||||||
def get_head_sizes(type_k, type_v):
|
def get_head_sizes(type_k, type_v):
|
||||||
if type_k == "GGML_TYPE_F16" and type_v == "GGML_TYPE_F16":
|
if type_k == "GGML_TYPE_F16" and type_v == "GGML_TYPE_F16":
|
||||||
return [64, 128, 256]
|
return [64, 128, 256]
|
||||||
|
@ -30,6 +32,7 @@ def get_head_sizes(type_k, type_v):
|
||||||
return [64, 128]
|
return [64, 128]
|
||||||
return [128]
|
return [128]
|
||||||
|
|
||||||
|
|
||||||
for filename in glob("*.cu"):
|
for filename in glob("*.cu"):
|
||||||
os.remove(filename)
|
os.remove(filename)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue