feat: Make number of experts configurable
Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
This commit is contained in:
parent
9a81fafb48
commit
8aa493705c
1 changed files with 1 additions and 2 deletions
|
@ -384,7 +384,7 @@ class TensorNameMap:
|
||||||
|
|
||||||
mapping: dict[str, tuple[MODEL_TENSOR, str]]
|
mapping: dict[str, tuple[MODEL_TENSOR, str]]
|
||||||
|
|
||||||
def __init__(self, arch: MODEL_ARCH, n_blocks: int):
|
def __init__(self, arch: MODEL_ARCH, n_blocks: int, n_experts: int = 60):
|
||||||
self.mapping = {}
|
self.mapping = {}
|
||||||
for tensor, keys in self.mappings_cfg.items():
|
for tensor, keys in self.mappings_cfg.items():
|
||||||
if tensor not in MODEL_TENSORS[arch]:
|
if tensor not in MODEL_TENSORS[arch]:
|
||||||
|
@ -398,7 +398,6 @@ class TensorNameMap:
|
||||||
if tensor not in MODEL_TENSORS[arch]:
|
if tensor not in MODEL_TENSORS[arch]:
|
||||||
continue
|
continue
|
||||||
# TODO: make this configurable
|
# TODO: make this configurable
|
||||||
n_experts = 60
|
|
||||||
for xid in range(n_experts):
|
for xid in range(n_experts):
|
||||||
tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
|
tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
|
||||||
self.mapping[tensor_name] = (tensor, tensor_name)
|
self.mapping[tensor_name] = (tensor, tensor_name)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue