Skip to content

Changing a single example affects forward pass for other examples in a batch #335

@mayank31398

Description

@mayank31398

@stas00, I wrote this script to compute the conditional NLL of the labels given the context.
I tried different batches in which only the first example changes and the rest of the examples in the batch stay fixed. However, after a certain point, changing the first example affects the NLL of the other examples.

This is not supposed to happen.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "bigscience/bloom"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Per-device memory budget handed to the loader: device 0 is given a zero
# budget, devices 1-7 are given '51GIB' each.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    max_memory={0: '0GIB', **{device: '51GIB' for device in range(1, 8)}},
    torch_dtype=torch.bfloat16,
)

# Put the module in evaluation mode before running the forward passes below.
model.eval()

def compute_gen_loss(lm_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Return the mean next-token cross-entropy for each row of ``labels``.

    The logits at position ``t`` are scored against the label at ``t + 1``.
    Label value -100 is ``cross_entropy``'s default ``ignore_index``: such
    positions contribute zero to the summed loss and are excluded from the
    token count the sum is divided by.

    Args:
        lm_logits: Model logits; the last dimension is the vocabulary and the
            second-to-last is the sequence position.
        labels: Target ids with -100 at positions that must not be scored;
            its first dimension is the batch.

    Returns:
        One averaged loss value per batch row.
    """
    # Drop the final logit and the first label so predictions line up with
    # the token that follows them.
    predictions = lm_logits[..., :-1, :].contiguous()
    targets = labels[..., 1:].contiguous()

    vocab_size = predictions.size(-1)
    token_loss = torch.nn.functional.cross_entropy(
        predictions.view(-1, vocab_size),
        targets.view(-1),
        reduction="none",
    )

    # Back to one row per example, then average over the scored tokens only.
    per_example_sum = token_loss.reshape(labels.shape[0], -1).sum(dim=-1)
    scored_tokens = targets.ne(-100).sum(dim=-1)
    return per_example_sum / scored_tokens


def pad_ids(arrays, padding, max_length: int = -1) -> list:
    """Left-pad every sequence in ``arrays`` with ``padding`` to a common length.

    Args:
        arrays: Sequences of values (token ids, mask bits, ...), one per example.
        padding: Value prepended to sequences shorter than the target length.
        max_length: Target length. A negative value (the default) means "use
            the length of the longest sequence in ``arrays``".

    Returns:
        A new list of lists; the inputs are not modified. Sequences that are
        already at least ``max_length`` long are copied unchanged (nothing is
        truncated). An empty ``arrays`` yields an empty list.
    """
    if max_length < 0:
        # default=0 keeps an empty batch from raising ValueError inside max().
        max_length = max(map(len, arrays), default=0)

    # A negative repeat count produces [], so over-long rows get no padding.
    return [[padding] * (max_length - len(row)) + list(row) for row in arrays]


def forward(text: list, labels: list, conditional: bool = True):
    """Print the per-example mean NLL of ``labels`` appended to ``text``.

    Each prompt in ``text`` is tokenized and concatenated with the tokens of
    the positionally matching entry in ``labels``. The batch is left-padded to
    a common length, run through the module-level ``model``, and the losses
    returned by ``compute_gen_loss`` are printed as a Python list.

    Args:
        text: Prompt strings, one per example.
        labels: Continuation strings, paired positionally with ``text``.
        conditional: If true, prompt positions are set to -100 in the targets
            so only continuation tokens are scored; otherwise the whole
            concatenated sequence is used as the target.
    """
    prompt_ids = tokenizer(text).input_ids
    continuation_ids = tokenizer(labels).input_ids

    input_ids = [p + c for p, c in zip(prompt_ids, continuation_ids)]
    attention_mask = [[1] * len(ids) for ids in input_ids]
    if conditional:
        # Mask out the prompt so the loss covers the continuation only.
        targets = [[-100] * len(p) + c
                   for p, c in zip(prompt_ids, continuation_ids)]
    else:
        targets = input_ids

    # NOTE(review): 3 is presumably the tokenizer's pad token id -- confirm
    # against tokenizer.pad_token_id. Padded positions are masked out anyway.
    pad = 3
    input_ids = torch.tensor(pad_ids(input_ids, pad))
    attention_mask = torch.tensor(pad_ids(attention_mask, 0))
    targets = pad_ids(targets, -100)

    # Pure inference: disable autograd so no graph/activations are retained.
    with torch.no_grad():
        lm_logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        ).logits
        # labels need to be on output device
        targets = torch.tensor(targets, device=lm_logits.device)
        losses = compute_gen_loss(lm_logits, targets)

    print(losses.cpu().tolist())

text = [
    "DeepSpeed",
    "DeepSpeed is a",
    "DeepSpeed is a machine",
    "DeepSpeed is a machine learning framework",
]
labels = [
    " is awesome.",
    " good person.",
    " that can wipe out the planet.",
    " for generating memes.",
]

# Only the first example's label grows from run to run; examples 1-3 stay
# fixed, so only the first printed loss is expected to change.
first_label_variants = [
    " is awesome.",
    " is awesome. really awesome",
    " is awesome. really awesome. Try it.",
    " is awesome. really awesome. Try it. You'll be surprised",
    " is awesome. really awesome. Try it. You'll be surprised. BLOOM was trained using DeepSpeed.",
    " is awesome. really awesome. Try it. You'll be surprised. BLOOM was trained using DeepSpeed. Oh no the values are bugging out now.",
]
for first_label in first_label_variants:
    labels[0] = first_label
    forward(text, labels)

Output:

[4.8125, 5.1875, 3.296875, 5.09375]
[5.625, 5.1875, 3.296875, 5.09375]
[4.375, 5.1875, 3.296875, 5.09375]
[4.0625, 5.1875, 3.28125, 5.09375]
[3.953125, 5.1875, 3.28125, 5.0625]
[4.25, 5.1875, 3.296875, 5.09375]

The value in column 2 drops from 3.296875 to 3.28125 (in the fourth and fifth runs) even though only the example for column 0 is changed. Column 3 also changes, from 5.09375 to 5.0625, in the fifth run.
Only column 0 is supposed to change here.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions