Added cache_block_outputs parameter to handle models with non-regular structure such as ChatGLM #1479
Conversation
Thanks for adding more flexibility to GPTQ @AlexKoff88 ! This PR looks good. I have some questions regarding `cache_block_outputs`:
- Did you run some tests to see whether we really use less memory? It seems like we are recomputing the inputs from scratch to get the output of the previous block, which we already get here (L450): `layer_output = block(*layer_inputs[j], **layer_input_kwargs[j])`. To me, we get the same input in both cases.
- The issue with ChatGLM is that the entire input was required for the forward of the attention block. This was solved by passing `*layer_inputs[j]` and keeping the whole input, as before. I was able to run the chatglm model with both `cache_block_outputs=True` and `cache_block_outputs=False`. Here is the snippet:
from transformers import GPTQConfig, AutoTokenizer, AutoModelForCausalLM
model_id = "THUDM/chatglm2-6b"
dataset = ["auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."]
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
gptq_config = GPTQConfig(bits=4, dataset=dataset, tokenizer=tokenizer, block_name_to_quantize="transformer.encoder.layers", cache_block_outputs=False)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=gptq_config, trust_remote_code=True)
Thanks @SunMarc for the prompt review. I get the following error with your code if I don't set
SunMarc left a comment
Thanks a lot for adding this. I've left a few comments. For the test, it was a mistake on my side: I thought that I had added the arg, but since it was not present in the init, the block outputs were always being cached. Can you add a test for `cache_block_outputs` in test_quantization.py inside the gptq folder? Apart from that, this should be good to go!
Awesome @AlexKoff88 ! This should be good to be merged once the conflicts are resolved =)
Done
SunMarc left a comment
LGTM! Thanks again @AlexKoff88 for iterating!
fxmarty left a comment
LGTM, thank you @AlexKoff88
`cache_block_outputs` enables caching of each block's output to speed up the GPTQ process. However, it does not work for some models, such as ChatGLM, where the LayerNorm is the first layer in the block. Just compare:
OPT structure:
model.decoder.layers.0.self_attn
model.decoder.layers.0.self_attn.k_proj
model.decoder.layers.0.self_attn.v_proj
model.decoder.layers.0.self_attn.q_proj
model.decoder.layers.0.self_attn.out_proj
model.decoder.layers.0.activation_fn
model.decoder.layers.0.self_attn_layer_norm
model.decoder.layers.0.fc1
model.decoder.layers.0.fc2
model.decoder.layers.0.final_layer_norm
ChatGLM structure:
transformer.encoder.layers.0
transformer.encoder.layers.0.input_layernorm
transformer.encoder.layers.0.self_attention
transformer.encoder.layers.0.self_attention.query_key_value
transformer.encoder.layers.0.self_attention.core_attention
transformer.encoder.layers.0.self_attention.core_attention.attention_dropout
transformer.encoder.layers.0.self_attention.dense
transformer.encoder.layers.0.post_attention_layernorm
transformer.encoder.layers.0.mlp
transformer.encoder.layers.0.mlp.dense_h_to_4h
transformer.encoder.layers.0.mlp.dense_4h_to_h
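Listings like the two above can be produced for any model with `named_modules()`. A minimal sketch using a toy module that mimics the ChatGLM block layout (the module names here are illustrative, not the real checkpoint):

```python
import torch.nn as nn

# Toy block mimicking the ChatGLM layout: the LayerNorm comes first,
# followed by the attention and MLP submodules.
block = nn.ModuleDict({
    "input_layernorm": nn.LayerNorm(8),
    "self_attention": nn.ModuleDict({"query_key_value": nn.Linear(8, 24)}),
    "post_attention_layernorm": nn.LayerNorm(8),
    "mlp": nn.ModuleDict({"dense_h_to_4h": nn.Linear(8, 32)}),
})

# named_modules() walks the tree depth-first in registration order,
# yielding dotted names like the listings above.
names = [name for name, _ in block.named_modules() if name]
print("\n".join(names))
```

Running the same loop over a real OPT or ChatGLM checkpoint reproduces the listings above and makes the structural difference immediately visible.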
The solution is to disable self-attention block output caching and instead collect the inputs of the block being quantized by running from the beginning of the model. This slows the optimization down a bit but is more stable.
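The difference between the two modes can be sketched with a toy sequential model (names like `capture_block_input` are illustrative, not the actual optimum API; in real GPTQ the weights change between blocks, which is where the two paths can diverge for non-regular models):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for a stack of transformer blocks (purely sequential, like OPT).
blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
x = torch.randn(2, 4)

def capture_block_input(blocks, idx, x):
    # cache_block_outputs=False path: rerun from the start of the model
    # and record what actually flows into block `idx`.
    h = x
    for block in blocks[:idx]:
        h = block(h)
    return h

# cache_block_outputs=True path: keep each block's output as the next input.
cached_inputs = []
h = x
with torch.no_grad():
    for block in blocks:
        cached_inputs.append(h)  # GPTQ would quantize `block` with this input
        h = block(h)

with torch.no_grad():
    recomputed_inputs = [capture_block_input(blocks, i, x) for i in range(len(blocks))]
```

For a purely sequential model both paths see identical inputs; the caching path just avoids the repeated forward passes. For a model like ChatGLM, whose block forward needs more than the previous block's single output, only the recompute-from-scratch path works.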
Related PR to Transformers: huggingface/transformers#27032