This issue tracks the migration of quantize_ per-workflow configuration from Callables to configs.
We are migrating the way quantize_ workflows are configured from callables (tensor subclass inserters) to direct configuration (config objects). Motivation: align with the rest of the ecosystem, enable inspection of configs after instantiation, remove a common source of confusion.
What is changing:
Specifically, here is how the signature of quantize_'s second argument will change:
```python
# torchao v0.8.0 and before
def quantize_(
    model: torch.nn.Module,
    apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module],
    ...,
): ...
```

```python
# torchao v0.9.0
def quantize_(
    model: torch.nn.Module,
    config: Union[AOBaseConfig, Callable[[torch.nn.Module], torch.nn.Module]],
    ...,
): ...
```

```python
# torchao v0.10.0 or later (exact version TBD)
def quantize_(
    model: torch.nn.Module,
    config: AOBaseConfig,
    ...,
): ...
```
- The name of the second argument to quantize_ changed from apply_tensor_subclass to config. Since the vast majority of callsites today pass the configuration as a positional argument, this change should not affect most people.
- The type of the second argument to quantize_ will change from Callable[[torch.nn.Module], torch.nn.Module] to AOBaseConfig, following a deprecation process detailed below.
- For individual workflows, the user facing API name changed from snake case (int8_weight_only) to camel case (Int8WeightOnlyConfig). All argument names for each config are kept as-is. We will keep the old snake case names around and alias them to the new names (int8_weight_only = Int8WeightOnlyConfig) to avoid breaking callsites. We plan to keep the old names forever. Here are all the workflow config name changes:
| old name (will keep working) | new name (recommended) |
| --- | --- |
| int4_weight_only | Int4WeightOnlyConfig |
| float8_dynamic_activation_float8_weight | Float8DynamicQuantizationFloat8WeightConfig |
| float8_static_activation_float8_weight | Float8StaticActivationFloat8WeightConfig |
| float8_weight_only | Float8WeightOnlyConfig |
| fpx_weight_only | FPXWeightOnlyConfig |
| gemlite_uintx_weight_only | GemliteUIntXWeightOnlyConfig |
| int4_dynamic_activation_int4_weight | Int4DynamicActivationInt4WeightConfig |
| int8_dynamic_activation_int4_weight | Int8DynamicActivationInt4WeightConfig |
| int8_dynamic_activation_int8_semi_sparse_weight | n/a (deprecated) |
| int8_dynamic_activation_int8_weight | Int8DynamicActivationInt8WeightConfig |
| int8_weight_only | Int8WeightOnlyConfig |
| uintx_weight_only | UIntXWeightOnlyConfig |
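The aliasing described above can be sketched as follows. The class body here is a simplified stand-in, not torchao's real definition; only the aliasing mechanism is the point:

```python
from dataclasses import dataclass
from typing import Optional

# Simplified stand-in for the real config class (fields are an assumption)
@dataclass
class Int8WeightOnlyConfig:
    group_size: Optional[int] = None

# the old snake case name is kept as an alias of the new config class
int8_weight_only = Int8WeightOnlyConfig

# old-style callsites keep working, but now return a config instance
cfg = int8_weight_only(group_size=128)
assert isinstance(cfg, Int8WeightOnlyConfig)
```

Because the old name is bound directly to the new class, no wrapper function is needed and `isinstance` checks against the new class still work on objects created via the old name.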
Configuration for prototype workflows using quantize_ will be migrated at a later time. sparsify_ will be migrated in a similar fashion at a later time.
How these changes can affect you:
If you are a user of existing quantize_ API workflows and are passing in the config as a positional argument (quantize_(model, int8_weight_only(group_size=128))), you are not affected. This syntax will keep working going forward. You have the option to migrate your callsite to the new config name (quantize_(model, Int8WeightOnlyConfig(group_size=128))) at your own pace.
If you are a user of existing quantize_ API workflows and are passing in the config as a keyword argument (quantize_(model, apply_tensor_subclass=int8_weight_only(group_size=128))), your callsite will break. You will need to change your callsite to quantize_(model, config=int8_weight_only(group_size=128)). We don't expect many people to be in this bucket.
If you are a developer of a quantize_ workflow, you will need to use the new configuration system. Please see migration of quantize_ workflow configuration from callables to configs #1690 for details.
If you are a user of sparsify_, you are not affected for now and a similar change will happen in a future version of torchao.
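The keyword-argument breakage above can be illustrated with a toy stand-in for quantize_ (the real function takes more parameters; only the rename matters here):

```python
# Toy stand-in for quantize_, illustrating only the keyword-argument rename
def quantize_(model, config):
    return model

cfg = object()  # stands in for int8_weight_only(group_size=128)

quantize_("model", config=cfg)  # new keyword: accepted

try:
    quantize_("model", apply_tensor_subclass=cfg)  # old keyword: rejected
except TypeError:
    print("old keyword argument no longer accepted")
```

Positional callsites are unaffected by a parameter rename, which is why only keyword callsites break.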
This migration will be a two step process:
- in torchao v0.9.0, we will enable the new syntax while starting the deprecation process for the old syntax.
- in torchao v0.10.0 or later, we will remove the old syntax.
We will keep the old callable syntax supported by quantize_ for one release cycle, and delete it afterwards. We will keep the old names as aliases for new names going forward (example: int4_weight_only as an alias of Int4WeightOnlyConfig) to keep existing callsites working without changes.
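A sketch of what the transitional v0.9.0 behavior could look like, assuming the deprecation path described above; actual torchao internals may differ:

```python
import warnings

class AOBaseConfig:  # stand-in for torchao's config base class
    pass

def quantize_(model, config):
    # transitional behavior (sketch): both forms accepted, callables warned about
    if not isinstance(config, AOBaseConfig) and callable(config):
        warnings.warn(
            "passing a Callable to quantize_ is deprecated; "
            "pass an AOBaseConfig instead",
            DeprecationWarning,
        )
        return config(model)  # old path: the callable transforms the model
    # new path: dispatch on type(config) to a registered module transform
    return model
```

After the transition release, the callable branch would be deleted and only AOBaseConfig instances accepted.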
impact on API users
If you are just using the torchao quantize_ API as specified in the README, this is not BC-breaking. For example, the following syntax will keep working:

```python
quantize_(model, int8_weight_only())
```

Note that the type of the object created by int8_weight_only() will change from a Callable to a config. You have the option to migrate to explicit config creation, as follows:

```python
quantize_(model, Int8WeightOnlyConfig())
```
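One of the motivations above is inspectability of configs after instantiation. The contrast can be sketched as follows; the config class body is a simplified stand-in, with group_size taken from the usage examples in this issue:

```python
from dataclasses import dataclass

# Simplified stand-in for the new-style config
@dataclass
class Int8WeightOnlyConfig:
    group_size: int = 128

# Simplified stand-in for the old-style callable factory
def old_style_int8_weight_only(group_size=128):
    def inserter(module):
        return (module, group_size)
    return inserter  # group_size is hidden inside a closure

new_cfg = Int8WeightOnlyConfig(group_size=64)
old_cfg = old_style_int8_weight_only(group_size=64)

# the config's settings are readable after instantiation...
print(new_cfg.group_size)  # 64
# ...while the callable buries them in opaque closure cells
print(old_cfg.__closure__)
```

A dataclass also gives a readable repr for free, which helps with logging and debugging quantization setups.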
user facing API changes
signature of quantize_
```python
# before
def quantize_(
    model: torch.nn.Module,
    apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module],
    ...,
): ...
```

```python
# after - intermediate state, support both old and new for one release
def quantize_(
    model: torch.nn.Module,
    config: Union[AOBaseConfig, Callable[[torch.nn.Module], torch.nn.Module]],
    ...,
): ...
```

```python
# after - long term state
def quantize_(
    model: torch.nn.Module,
    config: AOBaseConfig,
    ...,
): ...
```
usage example
An example for int4_weight_only:

```python
# before
quantize_(m, int4_weight_only(group_size=32))

# after, with new user facing names
quantize_(m, Int4WeightOnlyConfig(group_size=32))

# AND, after, with BC names
quantize_(m, int4_weight_only(group_size=32))
```
developer facing changes
See the PR details for examples, but they can be summarized as:
```python
# old
# quantize_ calls the result of calling this function on each module of the model
def int4_weight_only(group_size: int, ...) -> Callable:
    def new_callable(weight: torch.Tensor):
        # configuration is captured here via local variables
        ...
    # return type is a Callable
    return _get_linear_subclass_inserter(new_callable)
```

```python
# new

# config base class
class AOBaseConfig(abc.ABC):
    pass

# user facing configuration of a workflow
@dataclass
class Int4WeightOnlyConfig(AOBaseConfig):
    group_size: int = 128
    ...

# not user facing: transform of a module according to a workflow's configuration
@register_quantize_module_handler(Int4WeightOnlyConfig)
def _int4_weight_only_transform(
    module: torch.nn.Module,
    config: Int4WeightOnlyConfig,
) -> torch.nn.Module:
    # map to AQT, not user facing
    ...
```
migration status

- quantize_ non-prototype workflow configuration
- quantize_ prototype workflow configuration
  - Grep for callsites: `grep -r "quantize_(" torchao/prototype`
  - the quantize_ used here is a different function, so nothing to do
  - experimental
- sparsify_
  - sparsify_ to configs #1856
- tutorials (replace with new registration API)
- replace docblocks and public facing descriptions with new names
- verify partner integrations still work
  - confirmed two out of three here: vkuzo/pytorch_scripts#28
- delete old path (one version after migration)
  - config argument #1861