Skip to content

Tiled fusion

Tiled fusion action for MLIR loop transformations.

This module implements the tiled fusion transformation action, which applies tiling and fusion of producer-consumer operations.

TiledFusion(parameters, state=None, /, *, producer_tag=None, producer_operand_idx=None, **extras)

Bases: TiledParallelization

Class representing the tiled fusion action.

Source code in mlir_rl_artifact/actions/tiled_fusion.py
def __init__(
    self,
    parameters: list[int],
    state: Optional[OperationState] = None,
    /, *,
    producer_tag: Optional[str] = None,
    producer_operand_idx: Optional[int] = None,
    **extras
):
    """Create a tiled fusion action.

    The producer information must come from exactly one source:

    - ``state`` is given, and ``producer_tag`` / ``producer_operand_idx`` are
      read from it (the keyword arguments must then be omitted), or
    - ``state`` is omitted and BOTH keyword arguments are given explicitly.

    Args:
        parameters: Action parameters, forwarded to the parent constructor.
        state: Operation state from which the producer information is read.
        producer_tag: Producer tag; only valid when ``state`` is None.
        producer_operand_idx: Producer operand index; only valid when
            ``state`` is None.
        **extras: Extra keyword arguments forwarded to the parent constructor.

    Raises:
        ValueError: If both sources or neither source is provided (including
            when only one of the two keyword arguments is given), or if
            ``state`` does not carry the producer information.
    """
    if state is None:
        # Without a state, both producer attributes must be supplied explicitly.
        invalid_args = producer_tag is None or producer_operand_idx is None
    else:
        # With a state, the attributes are read from it; passing any of them
        # explicitly as well would be ambiguous.
        invalid_args = producer_tag is not None or producer_operand_idx is not None
    if invalid_args:
        raise ValueError("Either state or preprocessing attributes must be provided and not both")

    if state is not None:
        producer_tag = state.producer_tag
        producer_operand_idx = state.producer_operand_idx
    if producer_tag is None or producer_operand_idx is None:
        # Previously an `assert`, which is stripped under `python -O`. This is
        # reachable when the supplied state lacks the producer information.
        raise ValueError("state must provide producer_tag and producer_operand_idx")

    super().__init__(
        parameters,
        state,
        producer_tag=producer_tag,
        producer_operand_idx=producer_operand_idx,
        **extras
    )

    self.producer_tag = producer_tag
    self.producer_operand_idx = producer_operand_idx
    # Set to True by update_producer_features once it has run.
    self.producer_feats_updated = False

update_producer_features(state, bench_feats)

Update the features of the producer after the fusion.

Note
  • This update modifies the bench features in place
  • Currently, only a single use within the containing op is supported
Source code in mlir_rl_artifact/actions/tiled_fusion.py
def update_producer_features(self, state: OperationState, bench_feats: BenchmarkFeatures):
    """Refresh the producer's features after the fusion has been applied.

    A copy of ``state.producer_features`` (made with ``.copy()``) is passed
    through the bookkeeping helpers below. The action is then flagged via
    ``producer_feats_updated``.

    Note:
        - ``bench_feats`` is modified in place.
        - Currently, only a single use inside the containing op is supported.
    """
    producer_copy = state.producer_features.copy()

    # The helpers are invoked in this exact order; later steps may rely on
    # the effects of earlier ones (NOTE(review): confirm before reordering).
    self.__update_consumers_and_producers(producer_copy, state)
    self.__record_implicit_tiling(producer_copy, state)
    self.__insert_in_bench_feats(producer_copy, state, bench_feats)
    self.__handle_producer_original_op(bench_feats)

    # Remember that the producer features have now been updated.
    self.producer_feats_updated = True