Added support for SFTTrainer checkpoint models and adapter models containing one or more non-LoRA weights

My initial commit was more of a brute-force approach. The edits suggested by @FirstTimeEZ reduce the complexity.
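In practice the simplified check keeps a tensor only if its name marks it as a LoRA tensor or a base-layer copy, and silently skips anything else an SFTTrainer checkpoint might store. A minimal sketch of that filter (the tensor names are hypothetical examples in PEFT's usual naming style, not taken from a real checkpoint):

# Minimal sketch of the name filter; hypothetical PEFT-style tensor names.
adapter_tensor_names = [
    "base_model.model.layers.0.self_attn.q_proj.lora_A.weight",      # kept (LoRA A)
    "base_model.model.layers.0.self_attn.q_proj.lora_B.weight",      # kept (LoRA B)
    "base_model.model.layers.0.self_attn.q_proj.base_layer.weight",  # kept here, skipped later
    "some_other_checkpoint_tensor.weight",                           # ignored
]

for name in adapter_tensor_names:
    if ("lora_" in name) or (".base_layer.weight" in name):
        print("convertible:", name)
    else:
        print("ignored:    ", name)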
Victor Oluwadare 2024-10-08 20:31:43 +01:00 committed by GitHub
parent c6396aa4bb
commit c2c2626ec6


@@ -338,16 +338,8 @@ if __name__ == '__main__':
             def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
                 tensor_map: dict[str, PartialLoraTensor] = {}
 
-                # The following edits will enable conversion for: SFTTrainer checkpoint adapter models and other adapter models that contain weights besides LoRA weights
-                # Here, we first get the items with the 'lora_' substring
-                lora_model_items_name = [name for name,_ in lora_model.items()]
-                lora_model_items_with_lora_tensors = [name for name in lora_model_items_name if 'lora_' in name]
-
                 for name, tensor in lora_model.items():
-                    # Check for only LoRA finetuned weights and base layer weights
-                    if (name in lora_model_items_with_lora_tensors) or (".base_layer.weight" in name):
+                    if ("lora_" in name) or (".base_layer.weight" in name):
                         if self.lazy:
                             tensor = LazyTorchTensor.from_eager(tensor)
                         base_name = get_base_tensor_name(name)
@@ -357,10 +349,6 @@ if __name__ == '__main__':
                         if ".base_layer.weight" in name:
                             continue
-                        # we will either have a lora weight or a base layer weight, this error becomes trivial
-                        # logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
-                        # sys.exit(1)
 
                     if base_name in tensor_map:
                         if is_lora_a:
                             tensor_map[base_name].A = tensor
@@ -371,6 +359,8 @@ if __name__ == '__main__':
                                 tensor_map[base_name] = PartialLoraTensor(A=tensor)
                             else:
                                 tensor_map[base_name] = PartialLoraTensor(B=tensor)
+                    else:
+                        pass
 
                 for name, tensor in tensor_map.items():
                     assert tensor.A is not None
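For readers skimming the diff: after the filter, the loop pairs each module's lora_A and lora_B tensors under a shared base name; base-layer copies hit the continue, and anything else now falls through to the new else: pass instead of triggering the old error/exit. A rough standalone sketch of that flow (get_base_name is a simplified stand-in for the script's get_base_tensor_name, and all names and values below are hypothetical):

# Rough standalone sketch of the pairing loop above; get_base_name is a
# simplified stand-in for the script's get_base_tensor_name helper, and the
# tensor names/values below are hypothetical.
from __future__ import annotations
from dataclasses import dataclass

@dataclass
class PartialLoraTensor:
    A: object | None = None
    B: object | None = None

def get_base_name(name: str) -> str:
    # strip the PEFT prefix and map lora_A/lora_B names to the shared base name
    name = name.replace("base_model.model.", "")
    return name.replace(".lora_A.weight", ".weight").replace(".lora_B.weight", ".weight")

lora_model = {
    "base_model.model.layers.0.q_proj.lora_A.weight": "A0",
    "base_model.model.layers.0.q_proj.lora_B.weight": "B0",
    "base_model.model.layers.0.q_proj.base_layer.weight": "W0",  # base-layer copy
    "extra_non_lora_tensor.weight": "X",                         # e.g. from an SFTTrainer checkpoint
}

tensor_map: dict[str, PartialLoraTensor] = {}
for name, tensor in lora_model.items():
    if ("lora_" in name) or (".base_layer.weight" in name):
        if ".base_layer.weight" in name:
            continue  # nothing to convert for the base weight itself
        entry = tensor_map.setdefault(get_base_name(name), PartialLoraTensor())
        if ".lora_A." in name:
            entry.A = tensor
        else:
            entry.B = tensor
    else:
        pass  # non-LoRA weight: ignore it instead of aborting

for base_name, pair in tensor_map.items():
    assert pair.A is not None and pair.B is not None
    print(base_name, "->", pair.A, pair.B)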