Skip to content

Action space and transformation action implementations.

This module defines all available transformation actions for loop nest optimization, including tiling, parallelization, fusion, interchange, and vectorization. It provides the ActionSpace class for action sampling and distribution management.

ActionSpace

Class holding information about the action space

distributions(obs, selection_logits, *actions_logits) classmethod

Create a list of distributions for the actions based on the logits.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `obs` | `Tensor` | Observation tensor. | *required* |
| `selection_logits` | `Tensor` | Logits for action selection. | *required* |
| `*actions_logits` | `Tensor` | Logits for each action's parameters. | `()` |

Returns:

| Type | Description |
| --- | --- |
| `list[Distribution \| None]` | List of distributions for each action. |

Source code in mlir_rl_artifact/actions/__init__.py
@classmethod
def distributions(cls, obs: torch.Tensor, selection_logits: torch.Tensor, *actions_logits: Optional[torch.Tensor]) -> list[Optional[Distribution]]:
    """Build one distribution per action from the given logits.

    The first entry is always a ``Categorical`` over action selection,
    masked by the observation's action mask; each subsequent entry is the
    parameter distribution for the corresponding supported action, or
    ``None`` for actions that take no parameters.

    Args:
        obs: Observation tensor.
        selection_logits: Logits for action selection.
        *actions_logits: Logits for each action's parameters.

    Returns:
        List of distributions for each action.
    """
    # Deferred import to avoid a circular dependency with the observation module.
    from mlir_rl_artifact.observation import Observation, ActionMask

    mask = Observation.get_part(obs, ActionMask).bool()
    # Invalid selections are driven to probability zero via -inf logits.
    selection_dist = Categorical(
        logits=selection_logits.where(mask[:, :cls.size()], -torch.inf)
    )
    result: list[Optional[Distribution]] = [selection_dist]

    boundaries = cls.cumulative_mask_sizes()
    for idx, act in enumerate(cls.supported_actions):
        if act.mask_size():
            assert actions_logits[idx] is not None, f"action '{act.symbol}' must have logits"
            # Slice out this action's segment of the flat mask and null out
            # invalid parameter choices before building its distribution.
            segment = mask[:, boundaries[idx]:boundaries[idx + 1]]
            result.append(act.distribution(actions_logits[idx].where(segment, -torch.inf)))
        else:
            # Parameter-free action: no distribution needed.
            result.append(None)

    return result

uniform_distributions(obs) classmethod

Create a list of uniform distributions for the actions based on the observation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `obs` | `Tensor` | Observation tensor. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `list[Distribution \| None]` | List of distributions for each action. |

Source code in mlir_rl_artifact/actions/__init__.py
@classmethod
def uniform_distributions(cls, obs: torch.Tensor) -> list[Optional[Distribution]]:
    """Build uniform distributions over the valid choices in the observation.

    Zero logits are used everywhere a choice is valid (equal probability)
    and ``-inf`` where the action mask forbids it. The first entry covers
    action selection; each subsequent entry covers one supported action's
    parameters, or is ``None`` for parameter-free actions.

    Args:
        obs: Observation tensor.

    Returns:
        List of distributions for each action.
    """
    # Deferred import to avoid a circular dependency with the observation module.
    from mlir_rl_artifact.observation import Observation, ActionMask, NumLoops

    mask = Observation.get_part(obs, ActionMask).bool()
    num_loops = Observation.get_part(obs, NumLoops)

    # Uniform over valid selections: zeros where allowed, -inf elsewhere.
    sel_mask = mask[:, :cls.size()]
    result: list[Optional[Distribution]] = [
        Categorical(logits=torch.zeros_like(sel_mask).where(sel_mask, -torch.inf))
    ]

    boundaries = cls.cumulative_mask_sizes()
    for idx, act in enumerate(cls.supported_actions):
        if not act.mask_size():
            # Parameter-free action: no distribution needed.
            result.append(None)
        else:
            segment = mask[:, boundaries[idx]:boundaries[idx + 1]]
            uniform_logits = torch.zeros_like(segment).where(segment, -torch.inf)
            result.append(act.uniform_distribution(uniform_logits, num_loops))

    return result