
Tiled parallelization

Tiled parallelization action for MLIR loop transformations.

This module implements the tiled parallelization action, which tiles an operation's loop nest and parallelizes it using forall constructs.
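The following is a minimal sketch of the idea in plain Python, not MLIR. It tiles two parallel loops and runs each tile as an independent forall iteration. The function name tiled_parallel_add is hypothetical, and a ThreadPoolExecutor stands in for forall. Under CPython's GIL, the threads illustrate the schedule rather than deliver a speedup.

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import product


def tiled_parallel_add(a, b, tile_rows, tile_cols, workers=4):
    """Compute c = a + b for 2-D lists by tiling both (parallel) loops.

    Hypothetical illustration: each (row-tile, col-tile) pair is an
    independent "forall" iteration submitted to a thread pool, while the
    loops inside a tile run sequentially.
    """
    rows, cols = len(a), len(a[0])
    c = [[0] * cols for _ in range(rows)]

    def run_tile(r0, c0):
        # Sequential intra-tile loops; min() clips partial edge tiles.
        for i in range(r0, min(r0 + tile_rows, rows)):
            for j in range(c0, min(c0 + tile_cols, cols)):
                c[i][j] = a[i][j] + b[i][j]

    # Outer tile loops: the "forall" over tile origins.
    origins = product(range(0, rows, tile_rows), range(0, cols, tile_cols))
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(run_tile, r0, c0) for r0, c0 in origins]
        for f in futures:
            f.result()  # re-raise any exception raised inside a tile
    return c
```

For example, tiled_parallel_add([[1, 2, 3], [4, 5, 6]], [[10, 20, 30], [40, 50, 60]], 1, 2) returns [[11, 22, 33], [44, 55, 66]]. Tiles write disjoint elements, so they need no synchronization.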

TiledParallelization(parameters, state=None, /, *, iterators=None, **extras)

Bases: Tiling

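Class representing the tiled parallelization action. The constructor splits parameters by iterator type. Sizes for reduction loops go to tiling_params, and sizes for all other loops go to parallel_params. In each list, 0 marks a loop that the list does not handle.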
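The constructor below sends reduction-loop sizes to tiling_params and all other sizes to parallel_params. The sketch that follows shows one plausible reading of that split for a matrix multiply whose loops have iterator types [parallel, parallel, reduction]. It is an illustrative assumption, not the code this action generates. The name tiled_parallel_matmul is hypothetical: ti and tj stand in for parallel_params, tk stands in for tiling_params, and threads again stand in for forall.

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import product


def tiled_parallel_matmul(A, B, ti, tj, tk, workers=4):
    """Hypothetical sketch: C = A @ B with iterators [parallel, parallel, reduction].

    (ti, tj) play the role of parallel_params: each (i, j) tile is an
    independent forall iteration. tk plays the role of tiling_params: the
    reduction loop is tiled but stays sequential inside each tile.
    """
    M, K, N = len(A), len(B), len(B[0])
    C = [[0] * N for _ in range(M)]

    def run_tile(i0, j0):
        for k0 in range(0, K, tk):  # sequential reduction tiles
            for i in range(i0, min(i0 + ti, M)):
                for j in range(j0, min(j0 + tj, N)):
                    acc = C[i][j]
                    for k in range(k0, min(k0 + tk, K)):
                        acc += A[i][k] * B[k][j]
                    C[i][j] = acc

    # Forall over (i, j) tile origins; the reduction never crosses tiles.
    origins = product(range(0, M, ti), range(0, N, tj))
    with ThreadPoolExecutor(max_workers=workers) as pool:
        for f in [pool.submit(run_tile, i0, j0) for i0, j0 in origins]:
            f.result()
    return C
```

Because the reduction loop stays inside each tile, no two workers accumulate into the same element of C. That is plausibly why reduction loops are kept out of the parallel sizes. For instance, tiled_parallel_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]], 1, 1, 1) returns [[19, 22], [43, 50]].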

Source code in mlir_rl_artifact/actions/tiled_parallelization.py
def __init__(
    self,
    parameters: list[int],
    state: Optional[OperationState] = None,
    /, *,
    iterators: Optional[list[str]] = None,
    **extras
):
    # Exactly one of `state` and `iterators` must be given.
    if (state is None) == (iterators is None):
        raise ValueError("Either state or iterators must be provided and not both")
    if state is not None:
        # Derive one iterator-type string per loop from the operation's loop nest.
        iterators = [loop.iterator_type.value for loop in state.operation_features.nested_loops]
    super().__init__(parameters, state, iterators=iterators, **extras)

    # Sizes for the parallelized loops: every reduction loop gets 0.
    self.parallel_params = [
        0 if iterator == IteratorType.Reduction.value
        else param for param, iterator in zip(self.parameters, iterators)
    ]
    # Sizes for plain tiling: only reduction loops keep their size.
    self.tiling_params = [
        param if iterator == IteratorType.Reduction.value
        else 0 for param, iterator in zip(self.parameters, iterators)
    ]
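The split performed at the end of the constructor can be reproduced in isolation. The sketch below defines a stand-in IteratorType enum whose string values "parallel" and "reduction" are an assumption borrowed from MLIR's linalg iterator-type names; the project's own enum may differ. The helper split_tile_params is hypothetical. Unlike the constructor, it rejects mismatched lengths instead of letting zip truncate silently.

```python
from enum import Enum


class IteratorType(Enum):
    # Stand-in for the project's IteratorType; the values are assumptions.
    Parallel = "parallel"
    Reduction = "reduction"


def split_tile_params(parameters, iterators):
    """Hypothetical helper mirroring the constructor: returns (parallel, tiling)."""
    if len(parameters) != len(iterators):
        raise ValueError("need exactly one parameter per loop")
    reduction = IteratorType.Reduction.value
    # Non-reduction loops keep their size; reduction loops get 0.
    parallel_params = [0 if it == reduction else p for p, it in zip(parameters, iterators)]
    # Only reduction loops keep their size; all others get 0.
    tiling_params = [p if it == reduction else 0 for p, it in zip(parameters, iterators)]
    return parallel_params, tiling_params
```

For a matmul-like nest, split_tile_params([32, 64, 8], ["parallel", "parallel", "reduction"]) returns ([32, 64, 0], [0, 0, 8]). The zeros presumably follow MLIR's tiling convention, where a tile size of 0 leaves that loop untiled.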