Generating with Flax fails when using Causal Language models #18884

@SamKG

System Info

  • transformers version: 4.21.1
  • Platform: Linux-4.18.0-372.19.1.el8_6.x86_64-x86_64-with-glibc2.28
  • Python version: 3.10.4
  • Huggingface_hub version: 0.8.1
  • PyTorch version (GPU?): 1.12.1+cu102 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): 0.6.0 (gpu)
  • Jax version: 0.3.17
  • JaxLib version: 0.3.15
  • Using GPU in script?: Yes (Nvidia A100)
  • Using distributed or parallel set-up in script?: No

Who can help?

@patrickvonplaten
@Narsil
@patil-suraj

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Run the following code snippet:

from jax import numpy as jnp
import transformers

model = transformers.FlaxAutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")

sentence = "Paris is one of the densest populated areas in Europe."
input_ids = tokenizer(sentence, return_tensors="jax")["input_ids"]

model.generate(input_ids)

Expected behavior

The model should generate a completion for the given input IDs.

Observed behavior: the call to model.generate raises the following error:

File ~/.conda/envs/lm-extraction/lib/python3.10/site-packages/jax/_src/lax/lax.py:4577, in _check_same_dtypes(name, ignore_fp_precision, *ttypes)
   4575   equiv = _JNP_FUNCTION_EQUIVALENTS[name]
   4576   msg += f" (Tip: jnp.{equiv} is a similar function that does automatic type promotion on inputs)."
-> 4577 raise TypeError(msg.format(name, ", ".join(map(str, types))))

TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got int32, float32.

This looks like a dtype mismatch: lax.dynamic_update_slice is being given an int32 array together with a float32 array.
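For reference, the same TypeError can be triggered outside of transformers with a few lines of JAX. This is only a minimal sketch of the failure mode, not a diagnosis of where the mismatch originates inside generate: the array names, shapes, and start indices below are hypothetical stand-ins, chosen only so that the two dtypes match the ones in the error message.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-ins: an int32 buffer and a float32 patch.
# These are not the actual arrays used inside model.generate.
operand = jnp.zeros((1, 8), dtype=jnp.int32)
update = jnp.ones((1, 3), dtype=jnp.float32)

# lax functions do not promote dtypes, so mixed inputs are rejected.
try:
    lax.dynamic_update_slice(operand, update, (0, 2))
    mismatch_raised = False
except TypeError as err:
    mismatch_raised = True
    print(err)

# Casting the update to the operand's dtype makes the call succeed.
fixed = lax.dynamic_update_slice(operand, update.astype(operand.dtype), (0, 2))
print(fixed.dtype)
```

If the cause in generate is similar, an explicit cast of one of the two arrays would presumably avoid the error, but I have not traced which arguments are involved.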
