Vectorization action for MLIR loop transformations.
This module implements the vectorization transformation action, which vectorizes
operations and handles the preprocessing steps some of them need first, such as transposition and decomposition.
Vectorization(state=None, /, *, requires_transpose=None, requires_decompose=None, decompose_tile_sizes=None, **extras)
Bases: Action
Class representing the vectorization action.
Source code in mlir_rl_artifact/actions/vectorization.py
def __init__(
    self,
    state: Optional[OperationState] = None,
    /, *,
    requires_transpose: Optional[bool] = None,
    requires_decompose: Optional[bool] = None,
    decompose_tile_sizes: Optional[list[int]] = None,
    **extras
):
    """Build a vectorization action together with its preprocessing steps.

    The action is configured in exactly one of two ways:

    * from ``state``: the preprocessing attributes are derived from
      ``state.operation_features``, or
    * from the three keyword attributes ``requires_transpose``,
      ``requires_decompose`` and ``decompose_tile_sizes``, all of which
      must then be given.

    Args:
        state: Operation state to derive the preprocessing attributes from.
        requires_transpose: Whether ``transform_transpose_conv_2d`` must be
            applied before vectorizing.
        requires_decompose: Whether the op must be tiled (with
            ``decompose_tile_sizes``) and decomposed before vectorizing.
        decompose_tile_sizes: Tile sizes for the tiling step that precedes
            decomposition.
        **extras: Forwarded unchanged to the base ``Action.__init__``.

    Raises:
        ValueError: If neither or both of ``state`` and the preprocessing
            attributes are provided. This includes passing only some of the
            three attributes.
    """
    args_is_none = [
        requires_transpose is None,
        requires_decompose is None,
        decompose_tile_sizes is None,
    ]
    # Exactly one configuration source is allowed. With a state, all three
    # attributes must be omitted; without one, all three must be supplied.
    if state is None:
        invalid_args = any(args_is_none)
    else:
        invalid_args = not all(args_is_none)
    if invalid_args:
        raise ValueError("Either state or preprocessing attributes must be provided and not both")

    if state is not None:
        # Work on a copy of the features: the op name may be rewritten below
        # and that must not leak back into the caller's state.
        op_feats = state.operation_features.copy()
        if op_feats.operation_type not in [OperationType.Pooling, OperationType.Conv]:
            # Only pooling / convolution ops can need transpose or decompose.
            requires_transpose, requires_decompose, decompose_tile_sizes = False, False, []
        else:
            if requires_transpose := self.__requires_transpose(op_feats):
                # Once transposed, the op is treated as an NHWC/HWCF 2-D conv,
                # so the decomposition decision below sees that name.
                op_feats.operation_name = 'linalg.conv_2d_nhwc_hwcf'
            decompose_tile_sizes = []
            if requires_decompose := self.__requires_decompose(op_feats):
                decompose_tile_sizes = self.__decompose_tile_sizes(op_feats)

    super().__init__(
        state,
        requires_transpose=requires_transpose,
        requires_decompose=requires_decompose,
        decompose_tile_sizes=decompose_tile_sizes,
        vectorized=True,
        **extras
    )

    # Ordered list of callables, each taking a single argument ``m``.
    # NOTE(review): ``m`` is presumably the module / transform handle the
    # transform_* helpers operate on - confirm against those helpers.
    self.preprocessing = []
    if requires_transpose:
        self.preprocessing.append(lambda m: transform_transpose_conv_2d(m, self.operation_tag))
    if requires_decompose:
        # Tiling must happen before decomposition.
        self.preprocessing.append(lambda m: transform_tile(m, self.operation_tag, decompose_tile_sizes))
        self.preprocessing.append(lambda m: transform_decompose(m, self.operation_tag))
    # Pre-vectorization transform, appended last regardless of the flags above.
    self.preprocessing.append(lambda m: transform_pre_vec(m, self.operation_tag))
|
__requires_decompose(operation_features)
classmethod
Return whether the operation requires decomposition, i.e. whether it is a two-dimensional convolution-interface op.
Source code in mlir_rl_artifact/actions/vectorization.py
@classmethod
def __requires_decompose(cls, operation_features: OperationFeatures) -> bool:
    """Tell whether the op must be decomposed before vectorization.

    In other words: is this a two-dimensional convolution-interface op?
    That holds for any op whose name contains ``'conv_2d'``. It also holds
    for pooling ops with at least six nested loops.
    """
    if 'conv_2d' in operation_features.operation_name:
        return True
    # Non-conv ops only qualify when they are pooling ops ...
    if operation_features.operation_type != OperationType.Pooling:
        return False
    # ... with a loop nest at least six deep.
    return len(operation_features.nested_loops) >= 6
|