Skip to content

Base

Base action classes for MLIR loop transformations.

This module defines the abstract base class for transformation actions and provides the action interface that all concrete transformation actions must implement.

Action(arg1=None, arg2=None, /, *, operation_tag=None, **extras)

Action(*, operation_tag: str, **extras)
Action(state: OperationState, /, **extras)
Action(
    parameters: list[int],
    /,
    *,
    operation_tag: str,
    **extras,
)
Action(
    parameters: list[int],
    state: OperationState,
    /,
    **extras,
)

Base action class

Source code in mlir_rl_artifact/actions/base.py
def __init__(
    self,
    arg1: Optional[Union[OperationState, list[int]]] = None,
    arg2: Optional[OperationState] = None,
    /, *,
    operation_tag: Optional[str] = None,
    **extras
):
    """Initialize the action from an operation state or an operation tag.

    Supported call shapes:
        Action(*, operation_tag=..., **extras)
        Action(state, /, **extras)
        Action(parameters, /, *, operation_tag=..., **extras)
        Action(parameters, state, /, **extras)

    Args:
        arg1: either the operation state or the list of parameters.
        arg2: the operation state when ``arg1`` holds the parameters;
            ignored when ``arg1`` is itself an operation state.
        operation_tag: tag of the target operation. Must be given exactly
            when no state is given.
        **extras: additional keyword values stored in ``self.extras``
            (a ``process_params`` entry is not kept there).

    Raises:
        ValueError: if both or neither of state and operation tag are provided.
    """
    if isinstance(arg1, OperationState):
        parameters, state = None, arg1
    else:
        parameters, state = arg1, arg2

    if (state is None) == (operation_tag is None):
        raise ValueError("Either state or operation tag must be provided and not both")

    # Identity check instead of truthiness: a state object that happens to
    # evaluate as falsy must still supply its operation tag.
    if state is not None:
        operation_tag = state.operation_tag

    self.operation_tag = operation_tag
    self.parameters = parameters
    self.extras = {'operation_tag': operation_tag, **extras}
    self.extras.pop('process_params', None)

__repr__()

String representation of the action, including its extra parameters

Source code in mlir_rl_artifact/actions/base.py
def __repr__(self) -> str:
    """Return ``ClassName(p1, p2, key = value, ...)``.

    The positional parameters (if any) come first, followed by every entry
    of ``self.extras`` rendered as ``key = value``.
    """
    positional = [str(param) for param in self.parameters] if self.parameters else []
    keyword = [f'{key} = {value}' for key, value in self.extras.items()]
    body = ', '.join(positional + keyword)

    return f"{self.__class__.__name__}({body})"

__str__()

String representation of the action

Source code in mlir_rl_artifact/actions/base.py
def __str__(self) -> str:
    """Return ``symbol(p1,p2,...)``, or ``symbol()`` when there are no parameters."""
    if self.parameters:
        joined = ','.join(str(param) for param in self.parameters)
    else:
        joined = ''
    return f"{self.symbol}({joined})"

from_str(state, action_str) classmethod

Create an action from a string representation

Parameters:

Name Type Description Default
state OperationState

current state to apply the action on

required
action_str str

string representation of the action

required

Returns:

Type Description
Action

action created from the string representation

Source code in mlir_rl_artifact/actions/base.py
@classmethod
def from_str(cls, state: OperationState, action_str: str) -> 'Action':
    """Create an action from a string representation

    The expected format is ``symbol(p1,p2,...)``. An empty parameter list
    (``symbol()``) or a bare ``symbol`` without parentheses builds the action
    from the state alone.

    Args:
        state: current state to apply the action on
        action_str: string representation of the action

    Returns:
        action created from the string representation

    Raises:
        ValueError: if the symbol does not match ``cls.symbol`` or a
            parameter is not an integer.
    """
    symbol, _, remainder = action_str.partition('(')
    if symbol != cls.symbol:
        raise ValueError(f'Symbol mismatch for class {cls.__name__}: {symbol} != {cls.symbol}')

    # Text between the first '(' and the following ')'.
    args_text = remainder.split(')')[0].strip()

    # ''.split(',') yields [''], which int() rejects, so the parameterless
    # case has to be detected before parsing.
    if not args_text:
        return cls(state)

    parameters = [int(token) for token in args_text.split(',')]
    return cls(parameters, state, process_params=False)

params_size() classmethod

Return the size of the parameters in the index for this action type

Returns:

Type Description
int

size of the parameters for this action type

Source code in mlir_rl_artifact/actions/base.py
@classmethod
def params_size(cls) -> int:
    """Size of this action type's parameters in the index.

    Returns:
        The parameter size; this base implementation reports 0.
    """
    return 0

network_output_size() classmethod

Return the size of the network output for this action type

Returns:

Type Description
int

size of the network output for this action type

Source code in mlir_rl_artifact/actions/base.py
@classmethod
def network_output_size(cls) -> int:
    """Size of the network output for this action type.

    Returns:
        The network output size; this base implementation reports 0.
    """
    return 0

mask_size() classmethod

Return the size of the mask for this action type

Returns:

Type Description
int

size of the mask for this action type

Source code in mlir_rl_artifact/actions/base.py
@classmethod
def mask_size(cls) -> int:
    """Size of the mask for this action type.

    Returns:
        The mask size, which this implementation takes directly from
        ``cls.network_output_size()``.
    """
    size = cls.network_output_size()
    return size

history_size() classmethod

Return the size of the history for this action type

Returns:

Type Description
int

size of the history for this action type

Source code in mlir_rl_artifact/actions/base.py
@classmethod
def history_size(cls) -> int:
    """Size of the history for this action type.

    Returns:
        The history size; this base implementation reports 0.
    """
    return 0

is_allowed(state) classmethod

Check if this action type is allowed in the current state

Parameters:

Name Type Description Default
state OperationState

current state to check the action on

required

Returns:

Type Description
bool

True if the action is allowed, False otherwise

Source code in mlir_rl_artifact/actions/base.py
@classmethod
def is_allowed(cls, state: OperationState) -> bool:
    """Tell whether this action type may be used in the given state.

    Args:
        state: current state to check the action on

    Returns:
        True if the action is allowed, False otherwise. This base
        implementation always answers True and does not inspect ``state``.
    """
    return True

action_mask(state) classmethod

Return the action mask for this action type in the current state

Parameters:

Name Type Description Default
state OperationState

current state to check the action on

required

Returns:

Type Description
Tensor | None

action mask for this action type, or None if not applicable

Source code in mlir_rl_artifact/actions/base.py
@classmethod
def action_mask(cls, state: OperationState) -> Optional[torch.Tensor]:
    """Build the action mask of this action type for the given state.

    Args:
        state: current state to check the action on

    Returns:
        The action mask, or None if not applicable. This base
        implementation always returns None.
    """
    return None

action_history(seq) classmethod

Return the action history for this action type in the current state

Parameters:

Name Type Description Default
seq list[Action]

sequence of actions in the current state

required

Returns:

Type Description
Tensor | None

action history for this action type, or None if not applicable

Source code in mlir_rl_artifact/actions/base.py
@classmethod
def action_history(cls, seq: list['Action']) -> Optional[torch.Tensor]:
    """Build the action history of this action type from a sequence of actions.

    Args:
        seq: sequence of actions in the current state

    Returns:
        The action history, or None if not applicable. This base
        implementation always returns None.
    """
    return None

distribution(logits) classmethod

Create a distribution for this action type based on the logits

Parameters:

Name Type Description Default
logits Tensor

Logits for the action selection.

required

Returns:

Type Description
Distribution

A distribution object for this action type.

Source code in mlir_rl_artifact/actions/base.py
@classmethod
def distribution(cls, logits: torch.Tensor) -> Distribution:
    """Build the distribution of this action type from the logits.

    Args:
        logits: Logits for the action selection.

    Returns:
        A distribution object for this action type.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError

uniform_distribution(logits, num_loops) classmethod

Create a uniform distribution for this action type based on the logits and number of loops

Parameters:

Name Type Description Default
logits Tensor

Logits for the action selection.

required
num_loops Tensor

Number of loops in the operation state.

required

Returns:

Type Description
Distribution

A uniform distribution object for this action type.

Source code in mlir_rl_artifact/actions/base.py
@classmethod
def uniform_distribution(cls, logits: torch.Tensor, num_loops: torch.Tensor) -> Distribution:
    """Build the uniform distribution of this action type.

    Args:
        logits: Logits for the action selection.
        num_loops: Number of loops in the operation state. This
            implementation does not use it.

    Returns:
        A uniform distribution object for this action type; here it is
        whatever ``cls.distribution(logits)`` produces.
    """
    fallback = cls.distribution(logits)
    return fallback

distribution_stats(distribution, index, eps_distribution, eps=None) classmethod

Calculate the log probabilities and entropies for the distribution

Parameters:

Name Type Description Default
distribution Distribution

The distribution to calculate stats for.

required
eps_distribution Distribution | None

The epsilon distribution for exploration.

required
index Tensor

The params index.

required
eps float | None

Epsilon value for exploration. Defaults to None.

None

Returns:

Type Description
tuple[Tensor, Tensor]

Log probabilities and entropies.

Source code in mlir_rl_artifact/actions/base.py
@classmethod
def distribution_stats(
    cls,
    distribution: Distribution,
    index: torch.Tensor,
    eps_distribution: Optional[Distribution],
    eps: Optional[float] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute log probabilities and entropies for a distribution.

    Args:
        distribution: The distribution to calculate stats for.
        index: The params index.
        eps_distribution: The epsilon distribution for exploration.
        eps: Epsilon value for exploration. Defaults to None.

    Returns:
        Log probabilities and entropies.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError

sample(distribution, eps_distribution, num_loops, uniform, greedy) classmethod

Sample an action based on the distribution

Parameters:

Name Type Description Default
distribution Distribution

The distribution to sample from.

required
eps_distribution Distribution

The epsilon distribution for exploration.

required
num_loops Tensor

Number of loops in the operation state.

required
uniform bool

Whether to sample uniformly.

required
greedy bool

Whether to sample greedily.

required

Returns:

Type Description
Tensor

Sampled action index.

Source code in mlir_rl_artifact/actions/base.py
@classmethod
def sample(
    cls,
    distribution: Distribution,
    eps_distribution: Distribution,
    num_loops: torch.Tensor,
    uniform: bool,
    greedy: bool,
) -> torch.Tensor:
    """Draw an action index from the distribution.

    Args:
        distribution: The distribution to sample from.
        eps_distribution: The epsilon distribution for exploration.
        num_loops: Number of loops in the operation state.
        uniform: Whether to sample uniformly.
        greedy: Whether to sample greedily.

    Returns:
        Sampled action index.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError

apply(module)

Apply the action to the current code

Parameters:

Name Type Description Default
module Module

current code to apply the action on

required

Returns:

Type Description
Module

the new transformed code

Source code in mlir_rl_artifact/actions/base.py
def apply(self, module: Module) -> Module:
    """Apply action on the current code

    When the action is not ready (``self.ready`` is falsy) nothing is applied
    and the module is handed back untouched. Otherwise the work is delegated
    to ``self._apply_ready``.

    Args:
        module: current code to apply the action on

    Returns:
        the new transformed code
    """
    if not self.ready:
        return module

    # NOTE(review): `_apply_ready` is not visible here. If it transforms the
    # module in place and returns None, the same module object is returned;
    # if it returns a new module, that one is returned — confirm which.
    transformed = self._apply_ready(module)
    return module if transformed is None else transformed

update_features(operation_features)

Update the operation features based on the action

Parameters:

Name Type Description Default
operation_features OperationFeatures

The operation features to update.

required

Returns:

Type Description
OperationFeatures

The updated operation features.

Source code in mlir_rl_artifact/actions/base.py
def update_features(self, operation_features: OperationFeatures) -> OperationFeatures:
    """Reflect this action in the operation features.

    Args:
        operation_features: The operation features to update.

    Returns:
        The updated operation features. This base implementation hands back
        the very object it received, unmodified.
    """
    return operation_features