Tiled parallelization action for MLIR loop transformations.
This module implements the tiled parallelization transformation action, which applies
tiling with parallelization using forall constructs.
TiledParallelization(parameters, state=None, /, *, iterators=None, **extras)
Bases: Tiling
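Class representing the Tiled Parallelization action. Its constructor splits the per-loop parameters into `parallel_params` (reduction loops zeroed) and `tiling_params` (only reduction loops kept).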
Source code in mlir_rl_artifact/actions/tiled_parallelization.py
def __init__(
    self,
    parameters: list[int],
    state: Optional[OperationState] = None,
    /, *,
    iterators: Optional[list[str]] = None,
    **extras
):
    """Build a tiled-parallelization action.

    The per-loop ``parameters`` are split into two lists according to each
    loop's iterator type:

    * ``parallel_params`` keeps the value for every non-reduction loop and
      sets reduction loops to 0;
    * ``tiling_params`` keeps the value only for reduction loops and sets
      every other loop to 0.

    NOTE(review): presumably ``parallel_params`` drives the forall-based
    parallel tiling and ``tiling_params`` a separate tiling of the reduction
    loops -- confirm against the methods that consume them.

    Args:
        parameters: One integer per loop of the nest, forwarded to
            ``Tiling.__init__``.
        state: Operation state whose ``operation_features.nested_loops``
            supply the iterator types. Mutually exclusive with ``iterators``.
        iterators: Iterator-type strings, one per loop, compared against
            ``IteratorType.Reduction.value``. Mutually exclusive with
            ``state``.
        **extras: Forwarded unchanged to ``Tiling.__init__``.

    Raises:
        ValueError: If both or neither of ``state`` and ``iterators`` are
            given.
    """
    # Exactly one source of iterator types must be supplied (XOR check).
    if (state is None) == (iterators is None):
        raise ValueError("Either state or iterators must be provided and not both")
    # Test against None (not truthiness) to agree with the check above: a
    # non-None state must always be used, otherwise `iterators` would stay
    # None and the zip below would raise TypeError.
    if state is not None:
        iterators = [loop.iterator_type.value for loop in state.operation_features.nested_loops]
    super().__init__(parameters, state, iterators=iterators, **extras)

    reduction = IteratorType.Reduction.value
    # Single pass over (parameter, iterator) pairs. As with the original
    # comprehensions, zip stops at the shorter of the two sequences.
    self.parallel_params = []
    self.tiling_params = []
    for param, iterator in zip(self.parameters, iterators):
        is_reduction = iterator == reduction
        self.parallel_params.append(0 if is_reduction else param)
        self.tiling_params.append(param if is_reduction else 0)
|