feat: Make number of experts configurable

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
teleprint-me 2024-05-17 03:20:14 -04:00
parent 9a81fafb48
commit 8aa493705c
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48


@@ -384,7 +384,7 @@ class TensorNameMap:
     mapping: dict[str, tuple[MODEL_TENSOR, str]]
 
-    def __init__(self, arch: MODEL_ARCH, n_blocks: int):
+    def __init__(self, arch: MODEL_ARCH, n_blocks: int, n_experts: int = 60):
         self.mapping = {}
         for tensor, keys in self.mappings_cfg.items():
             if tensor not in MODEL_TENSORS[arch]:
@@ -398,7 +398,6 @@ class TensorNameMap:
                 if tensor not in MODEL_TENSORS[arch]:
                     continue
                 # TODO: make this configurable
-                n_experts = 60
                 for xid in range(n_experts):
                     tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
                     self.mapping[tensor_name] = (tensor, tensor_name)
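
With this change, a conversion script can pass the model's actual expert count when building the tensor-name map instead of relying on the previously hard-coded value of 60. A minimal usage sketch follows; the architecture, block count, and expert count are placeholder values chosen for illustration and are not part of the commit:

from gguf.constants import MODEL_ARCH
from gguf.tensor_mapping import TensorNameMap

# Example only: a Mixtral-style MoE model with 8 experts per layer.
# Omitting n_experts keeps the previous behaviour via the default of 60.
tensor_map = TensorNameMap(MODEL_ARCH.LLAMA, n_blocks=32, n_experts=8)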