Skip to content

Interchange

Interchange action for MLIR loop transformations.

This module implements the loop interchange transformation action, which reorders the loops of an operation. The loop permutation can be encoded in one of three ways: enumerated candidates, level pointers, or a continuous encoding.

InterchangeMethod

Bases: Enum

Enumeration of interchange encoding methods.

Interchange(parameters, state=None, /, *, process_params=True, **extras)

Bases: Action

Class representing the loop interchange action.

Source code in mlir_rl_artifact/actions/interchange.py
def __init__(
    self,
    parameters: list[int],
    state: Optional[OperationState] = None,
    /, *,
    process_params: bool = True,
    **extras
):
    """Create an Interchange action, decoding a raw parameter when requested.

    When ``state`` is given and ``process_params`` is true, ``parameters`` must
    contain exactly one integer. That integer is decoded into a loop permutation
    according to the class-level ``Interchange.method``:

    * ``EnumeratedCandidates``: used as an index into the list returned by
      ``__get_candidates(num_loops)``.
    * ``ContinuousEncoding``: decoded by ``__decode_continuous`` into a
      permutation of ``range(num_loops)``.
    * ``LevelsPointers``: appended to the permutation already collected by an
      incomplete previous interchange (if ``incomplete_interchange`` returns
      one). While the result is shorter than ``num_loops``, ``self.ready`` is
      set to ``False``.

    In every other case ``parameters`` is passed to the base class unchanged.

    Args:
        parameters: Either a single encoded integer (when decoding happens) or
            an already-decoded loop permutation.
        state: Operation state. ``num_loops`` is taken as
            ``len(state.operation_features.nested_loops)``.
        process_params: Whether to decode ``parameters`` using ``state``.
        **extras: Forwarded unchanged to the base class ``__init__``.

    Raises:
        AssertionError: When assertions are enabled, if decoding is requested
            and ``parameters`` does not have exactly one element, or, in
            ``LevelsPointers`` mode, if the loop index repeats or the
            permutation grows beyond ``num_loops``.
    """
    if state and process_params:
        # Case where state is provided -> Parameters need processing

        assert len(parameters) == 1, 'uncompatible parameters for constructor call'
        parameter = parameters[0]
        num_loops = len(state.operation_features.nested_loops)
        # `method` is a class attribute not visible here; presumably set from
        # configuration before actions are constructed — confirm.
        match Interchange.method:
            case InterchangeMethod.EnumeratedCandidates:
                parameters = self.__get_candidates(num_loops)[parameter]
            case InterchangeMethod.ContinuousEncoding:
                parameters = self.__decode_continuous(parameter, num_loops)
            case InterchangeMethod.LevelsPointers:
                # The permutation is built one loop index per action; resume
                # from the partial permutation of an unfinished interchange.
                old_action = self.incomplete_interchange(state)
                if old_action:
                    perm_buffer = old_action.parameters
                else:
                    perm_buffer = []

                assert parameter not in perm_buffer, 'repitition detected in permutation'
                parameters = perm_buffer + [parameter]
                assert len(parameters) <= num_loops, 'interchange parameter exceeds number of loops'
                if len(parameters) < num_loops:
                    # NOTE(review): set before super().__init__(); this assumes
                    # the base class does not reset `ready` — confirm.
                    self.ready = False
    super().__init__(parameters, state, **extras)

log_std class-attribute instance-attribute

Log standard deviation for continuous interchange encoding.

__decode_continuous(parameter, num_loops) staticmethod

Decode the interchange parameter to get the loop permutation.

Parameters:

Name Type Description Default
parameter int

The interchange parameter.

required
num_loops int

The number of loops in the operation.

required

Returns:

Type Description
list[int]

The loop permutation.

Source code in mlir_rl_artifact/actions/interchange.py
@staticmethod
def __decode_continuous(parameter: int, num_loops: int) -> list[int]:
    """Decode the interchange parameter to get the loop permutation.

    ``parameter`` is the index of a permutation of ``range(num_loops)`` in
    lexicographic order. It is first written in the factorial number system
    (its Lehmer code). Each digit then selects, by rank, one of the loops not
    yet placed.

    Args:
        parameter: The interchange parameter. Must lie in ``[0, num_loops!)``.
        num_loops: The number of loops in the operation.

    Returns:
        The loop permutation, containing each of ``0 .. num_loops - 1``
        exactly once. It is an empty list when ``num_loops`` is 0.

    Raises:
        ValueError: If ``parameter`` is negative or not below ``num_loops!``.
            ``ValueError`` subclasses ``Exception``, so handlers written for
            the previous bare ``Exception`` still catch it.
    """
    x = parameter
    n = num_loops
    # Reject negatives explicitly: otherwise they would silently decode to the
    # identity permutation.
    if not 0 <= x < math.factorial(n):
        raise ValueError(f"Invalid interchange parameter: {x}")

    # Factorial-base digits of x, least significant first. The k-th digit
    # (k = 1..n) has radix k, so the first digit is always 0. Exactly n
    # digits are produced, which also gives the right result for n == 0.
    lehmer: list[int] = []
    q = x
    for radix in range(1, n + 1):
        q, digit = divmod(q, radix)
        lehmer.append(digit)
    lehmer.reverse()  # most significant digit first

    # Decode the Lehmer code: each digit is the rank of the chosen loop among
    # the loops not placed so far.
    remaining = list(range(n))
    return [remaining.pop(digit) for digit in lehmer]

__get_candidates(num_loops) staticmethod

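Get all candidate interchanges for num_loops loops that swap two loops at distance 1, 2, or 3 (the "1c", "2c", and "3c" candidates). Each distance level is padded with identity permutations up to a fixed size.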

Parameters:

Name Type Description Default
num_loops int

The number of loops in the operation.

required

Returns:

Type Description
list[list[int]]

The list of all possible interchanges.

Source code in mlir_rl_artifact/actions/interchange.py
@staticmethod
def __get_candidates(num_loops: int) -> list[list[int]]:
    """Enumerate the fixed-distance swap candidates for `num_loops` loops.

    For each swap distance ``c`` in 1, 2 and 3 (the "1c", "2c", "3c"
    candidates), one candidate is produced per starting loop ``i``. It is made
    by swapping loops ``i`` and ``i + c`` in the identity permutation. Each
    distance level is padded with identity permutations to
    ``Config().max_num_loops - c`` entries.

    Args:
        num_loops: The number of loops in the operation.

    Returns:
        The candidate lists for distances 1, 2 and 3, concatenated in order.

    Raises:
        IndexError: If ``num_loops`` exceeds ``Config().max_num_loops`` (with
            more than one loop), because a swap slot would fall outside the
            padded level.
    """
    candidates: list[list[int]] = []
    for distance in (1, 2, 3):
        # Identity padding; presumably keeps the candidate index space a fixed
        # size regardless of num_loops — confirm against the action space.
        slots = [list(range(num_loops)) for _ in range(Config().max_num_loops - distance)]
        for start in range(num_loops - distance):
            swapped = list(range(num_loops))
            swapped[start], swapped[start + distance] = swapped[start + distance], swapped[start]
            slots[start] = swapped
        candidates.extend(slots)
    return candidates