Added support for SFTTrainer checkpoint adapter models and for adapter models containing some non-LoRA weights

The previous code logs an "Unexpected name" error and calls sys.exit(1) (lines 350-351 of the current version) if even a single weight in lora_model is not a lora_A, lora_B, or base-layer weight.
This edit collects the names of all LoRA weights in the model before the for loop at line 341 (current version). Then, at line 350 (edited version), the subsequent operations are performed only on LoRA and base-layer weights, so any non-LoRA weights in lora_model are ignored.
This allows the script to extract the LoRA weights and convert the adapter to GGUF even when it contains one or more non-LoRA weights.
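
For illustration, here is a minimal standalone sketch of the filtering idea. The tiny stand-in state dict and its tensor names are hypothetical, not taken from an actual SFTTrainer checkpoint:

    # Hypothetical adapter state dict: a LoRA pair, a base-layer weight, and one non-LoRA weight.
    lora_model = {
        "base_model.model.layers.0.self_attn.q_proj.lora_A.weight": "A0",
        "base_model.model.layers.0.self_attn.q_proj.lora_B.weight": "B0",
        "base_model.model.layers.0.self_attn.q_proj.base_layer.weight": "W0",
        "v_head.summary.weight": "extra",  # would previously trigger the error and sys.exit(1)
    }

    # Collect the names of all LoRA weights up front ...
    lora_tensor_names = [name for name in lora_model if "lora_" in name]

    # ... then operate only on LoRA and base-layer weights, ignoring everything else.
    for name, tensor in lora_model.items():
        if name in lora_tensor_names or ".base_layer.weight" in name:
            print(f"convert: {name}")
        else:
            print(f"skip:    {name}")

Only the tensor names are used for the filtering, so the check behaves the same for lazy and eager tensors.
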
Author: Victor Oluwadare, 2024-10-08 02:35:08 +01:00 (committed by GitHub)
Parent: 6374743747
Commit: c6396aa4bb

@@ -338,28 +338,39 @@ if __name__ == '__main__':
     def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
         tensor_map: dict[str, PartialLoraTensor] = {}
 
-        for name, tensor in lora_model.items():
-            if self.lazy:
-                tensor = LazyTorchTensor.from_eager(tensor)
-            base_name = get_base_tensor_name(name)
-            is_lora_a = ".lora_A.weight" in name
-            is_lora_b = ".lora_B.weight" in name
-            if not is_lora_a and not is_lora_b:
-                if ".base_layer.weight" in name:
-                    continue
-                logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
-                sys.exit(1)
-
-            if base_name in tensor_map:
-                if is_lora_a:
-                    tensor_map[base_name].A = tensor
-                else:
-                    tensor_map[base_name].B = tensor
-            else:
-                if is_lora_a:
-                    tensor_map[base_name] = PartialLoraTensor(A=tensor)
-                else:
-                    tensor_map[base_name] = PartialLoraTensor(B=tensor)
+        # The following edits enable conversion for SFTTrainer checkpoint adapter models
+        # and other adapter models that contain weights besides LoRA weights.
+        # First, collect the names of all items that carry the 'lora_' substring.
+        lora_model_items_name = [name for name, _ in lora_model.items()]
+        lora_model_items_with_lora_tensors = [name for name in lora_model_items_name if 'lora_' in name]
+
+        for name, tensor in lora_model.items():
+            # Process only LoRA fine-tuned weights and base layer weights; skip everything else
+            if (name in lora_model_items_with_lora_tensors) or (".base_layer.weight" in name):
+                if self.lazy:
+                    tensor = LazyTorchTensor.from_eager(tensor)
+                base_name = get_base_tensor_name(name)
+                is_lora_a = ".lora_A.weight" in name
+                is_lora_b = ".lora_B.weight" in name
+                if not is_lora_a and not is_lora_b:
+                    if ".base_layer.weight" in name:
+                        continue
+                    # Every name reaching this point is either a LoRA weight or a base layer
+                    # weight, so the former error is no longer needed here:
+                    # logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
+                    # sys.exit(1)
+
+                if base_name in tensor_map:
+                    if is_lora_a:
+                        tensor_map[base_name].A = tensor
+                    else:
+                        tensor_map[base_name].B = tensor
+                else:
+                    if is_lora_a:
+                        tensor_map[base_name] = PartialLoraTensor(A=tensor)
+                    else:
+                        tensor_map[base_name] = PartialLoraTensor(B=tensor)
 
         for name, tensor in tensor_map.items():
             assert tensor.A is not None