Skip to content

KTO finetuning - float division by zero #1651

@jetlime

Description

@jetlime

I am attempting to finetune the LLama3-8B-Instruct model on the UNSW-NB15 dataset.

dataset = load_dataset("Jetlime/NF-UNSW-NB15-v2", streaming=False, split="train")

# Model
base_model = "meta-llama/Meta-Llama-3-8B-Instruct"

# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, token=HUGGING_FACE_READ_TOKEN)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    attn_implementation=attn_implementation,
    token=HUGGING_FACE_READ_TOKEN
)
model, tokenizer = setup_chat_format(model, tokenizer)
model = prepare_model_for_kbit_training(model)

# Use only a small subset of the training set for a first finetuning trial
dataset = dataset.train_test_split(test_size=0.95, seed=123, stratify_by_column="Attack")
dataset_finetuning = dataset["train"]
dataset_finetuning

# Dataset({
#    features: ['input', 'output', 'Attack'],
#    num_rows: 113538
#})

# Creating the dataset columns required by the KTO finetuner
import random
def format_chat_template(row):
    row['prompt'] = row["input"]
    if random.randrange(0,1):
        row["label"] = False
        if row["output"] == 1:
            row["completion"] = '0'
        else:
            row["completion"] = '1'
    else:
        row["label"] = True
        row["completion"] = str(row["output"])
    return row

dataset_finetuning = dataset_finetuning.map(
    format_chat_template, num_proc=os.cpu_count()
)
dataset_finetuning

When I then perform the training,

training_args = KTOConfig(
    beta=0.1,
    desirable_weight=1.0,
    undesirable_weight=1.0,
    output_dir="./results-KTO/"
)

kto_trainer = KTOTrainer(
    model,
    args=training_args,
    train_dataset=dataset_finetuning,
    tokenizer=tokenizer,
)

kto_trainer.train()

# Tokenizing train dataset: 100%|██████████| 113538/113538 [01:26<00:00, 1313.22 examples/s]
# Extracting KL train dataset: 100%|██████████| 113538/113538 [00:08<00:00, 14077.15 examples/s]
# Processing tokenized train dataset: 100%|██████████| 113538/113538 [00:42<00:00, 2678.53 examples/s]
# Processing tokenized train KL dataset: 100%|██████████| 113538/113538 [00:40<00:00, 2805.04 examples/s]
# Filtering desirable examples: 100%|██████████| 113538/113538 [01:37<00:00, 1163.86 examples/s]
# Filtering undesirable examples: 100%|██████████| 113538/113538 [01:36<00:00, 1170.86 examples/s]

I obtain a Zero Division Error:

ZeroDivisionError                         Traceback (most recent call last)
Cell In[7], [line 8](vscode-notebook-cell:?execution_count=7&line=8)
      [1](vscode-notebook-cell:?execution_count=7&line=1) training_args = KTOConfig(
      [2](vscode-notebook-cell:?execution_count=7&line=2)     beta=0.1,
      [3](vscode-notebook-cell:?execution_count=7&line=3)     desirable_weight=1.0,
      [4](vscode-notebook-cell:?execution_count=7&line=4)     undesirable_weight=1.0,
      [5](vscode-notebook-cell:?execution_count=7&line=5)     output_dir="./results-KTO/"
      [6](vscode-notebook-cell:?execution_count=7&line=6) )
----> [8](vscode-notebook-cell:?execution_count=7&line=8) kto_trainer = KTOTrainer(
      [9](vscode-notebook-cell:?execution_count=7&line=9)     model,
     [10](vscode-notebook-cell:?execution_count=7&line=10)     args=training_args,
     [11](vscode-notebook-cell:?execution_count=7&line=11)     train_dataset=dataset_finetuning,
     [12](vscode-notebook-cell:?execution_count=7&line=12)     tokenizer=tokenizer,
     [13](vscode-notebook-cell:?execution_count=7&line=13) )
     [15](vscode-notebook-cell:?execution_count=7&line=15) kto_trainer.train()
     [16](vscode-notebook-cell:?execution_count=7&line=16) kto_trainer.save_model(new_model)

File ~/.local/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:599, in KTOTrainer.__init__(self, model, ref_model, args, train_dataset, eval_dataset, tokenizer, data_collator, model_init, callbacks, optimizers, preprocess_logits_for_metrics, peft_config, compute_metrics, model_adapter_name, ref_adapter_name)
    [597](https://vscode-remote+ssh-002dremote-002b10-002e35-002e14-002e201.vscode-resource.vscode-cdn.net/home/paul/Documents/llm-nids/~/.local/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:597) des_weight_lower_bound = round((len(undesirable) * self.undesirable_weight / len(desirable)) * 1, 2)
    [598](https://vscode-remote+ssh-002dremote-002b10-002e35-002e14-002e201.vscode-resource.vscode-cdn.net/home/paul/Documents/llm-nids/~/.local/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:598) des_weight_upper_bound = round((len(undesirable) * self.undesirable_weight / len(desirable)) * 1.33, 2)
--> [599](https://vscode-remote+ssh-002dremote-002b10-002e35-002e14-002e201.vscode-resource.vscode-cdn.net/home/paul/Documents/llm-nids/~/.local/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:599) und_weight_lower_bound = round((len(desirable) * self.desirable_weight / len(undesirable)) / 1.33, 2)
    [600](https://vscode-remote+ssh-002dremote-002b10-002e35-002e14-002e201.vscode-resource.vscode-cdn.net/home/paul/Documents/llm-nids/~/.local/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:600) und_weight_upper_bound = round((len(desirable) * self.desirable_weight / len(undesirable)) / 1, 2)
    [602](https://vscode-remote+ssh-002dremote-002b10-002e35-002e14-002e201.vscode-resource.vscode-cdn.net/home/paul/Documents/llm-nids/~/.local/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:602) des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound

ZeroDivisionError: float division by zero

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions