Base
Base action classes for MLIR loop transformations.
This module defines the abstract base class for transformation actions, that is, the interface that every concrete transformation action must implement.
Action(arg1=None, arg2=None, /, *, operation_tag=None, **extras)
Base action class
Source code in mlir_rl_artifact/actions/base.py
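The `/` and `*` markers in the constructor signature make the two leading arguments positional-only and `operation_tag` keyword-only; any other keyword is collected into `extras`. The sketch below illustrates that calling convention only. `SketchAction` is a hypothetical stand-in that copies the documented signature and nothing else, so it says nothing about what the real `Action` does with these values.

```python
class SketchAction:
    """Hypothetical stand-in that mirrors only the documented signature."""

    def __init__(self, arg1=None, arg2=None, /, *, operation_tag=None, **extras):
        self.args = (arg1, arg2)
        self.operation_tag = operation_tag
        self.extras = extras


# The two leading arguments are passed by position; the tag by keyword.
action = SketchAction(4, 8, operation_tag="op_0", note="demo")

# Because of the `/` marker, a keyword named `arg1` is not bound to the
# first parameter; it is collected into **extras instead.
shadowed = SketchAction(arg1=4)
```

Passing `operation_tag` positionally (a third positional argument) raises `TypeError` under this signature.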
__repr__()
String representation of the action, including its extra parameters
__str__()
from_str(state, action_str)
classmethod
Create an action from a string representation
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `OperationState` | current state to apply the action on | required |
| `action_str` | `str` | string representation of the action | required |
Returns:
| Type | Description |
|---|---|
| `Action` | action created from the string representation |
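The string format accepted by `from_str` is not described on this page. As a rough illustration of the classmethod-constructor pattern, the sketch below assumes a hypothetical `SYMBOL(p1,p2,...)` syntax. `SketchState`, `SketchAction`, the `symbol` attribute, and the syntax itself are all assumptions made for the example.

```python
import re
from dataclasses import dataclass


@dataclass
class SketchState:
    """Hypothetical stand-in for OperationState."""
    operation_tag: str


class SketchAction:
    symbol = "A"  # hypothetical short name used in the string form

    def __init__(self, *params, operation_tag=None):
        self.params = params
        self.operation_tag = operation_tag

    @classmethod
    def from_str(cls, state, action_str):
        # Assumed format: SYMBOL(p1,p2,...), for example "A(4, 8)".
        pattern = rf"{re.escape(cls.symbol)}\(([\d,\s]*)\)"
        match = re.fullmatch(pattern, action_str.strip())
        if match is None:
            raise ValueError(f"cannot parse {action_str!r} as {cls.__name__}")
        params = [int(p) for p in match.group(1).split(",") if p.strip()]
        # The state supplies the context the action is bound to.
        return cls(*params, operation_tag=state.operation_tag)


action = SketchAction.from_str(SketchState("op_0"), "A(4, 8)")
```

A string that does not match the assumed syntax raises `ValueError` in this sketch; how the real method reports a malformed string is not stated here.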
params_size()
classmethod
Return the size of the parameters in the index for this action type
Returns:
| Type | Description |
|---|---|
| `int` | size of the parameters for this action type |
network_output_size()
classmethod
Return the size of the network output for this action type
Returns:
| Type | Description |
|---|---|
| `int` | size of the network output for this action type |
mask_size()
classmethod
Return the size of the mask for this action type
Returns:
| Type | Description |
|---|---|
| `int` | size of the mask for this action type |
history_size()
classmethod
Return the size of the history for this action type
Returns:
| Type | Description |
|---|---|
| `int` | size of the history for this action type |
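The four size classmethods (`params_size`, `network_output_size`, `mask_size`, `history_size`) each report a per-action-type length. One plausible use, shown here purely as an assumption, is to lay out several action types side by side in one flat vector. The subclasses, the sizes 12 and 5, and the `output_slices` helper are hypothetical.

```python
class SketchAction:
    """Hypothetical base: an action type with no network output."""

    @classmethod
    def network_output_size(cls):
        return 0


class SketchTiling(SketchAction):
    @classmethod
    def network_output_size(cls):
        return 12  # made-up size for illustration


class SketchInterchange(SketchAction):
    @classmethod
    def network_output_size(cls):
        return 5  # made-up size for illustration


def output_slices(action_types):
    """Assign each action type a contiguous slice of a flat output vector."""
    slices = {}
    offset = 0
    for action_type in action_types:
        size = action_type.network_output_size()
        slices[action_type.__name__] = slice(offset, offset + size)
        offset += size
    return slices, offset


slices, total = output_slices([SketchTiling, SketchInterchange])
```

The same bookkeeping would apply to masks and histories by swapping in the corresponding size method.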
is_allowed(state)
classmethod
Check if this action type is allowed in the current state
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `OperationState` | current state to check the action on | required |
Returns:
| Type | Description |
|---|---|
| `bool` | True if the action is allowed, False otherwise |
action_mask(state)
classmethod
Return the action mask for this action type in the current state
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `OperationState` | current state to check the action on | required |
Returns:
| Type | Description |
|---|---|
| `Tensor \| None` | action mask for this action type, or None if not applicable |
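An action mask is commonly used to rule out invalid choices before logits are normalized into probabilities. The sketch below shows that general technique with plain Python lists rather than tensors. It assumes a boolean mask in which `True` means "allowed"; the mask convention used by `action_mask` is not stated on this page, and `masked_softmax` is a hypothetical helper.

```python
import math


def masked_softmax(logits, mask):
    """Softmax over `logits` where entries with a False mask get probability 0."""
    # Replace disallowed logits with -inf so exp() maps them to exactly 0.
    masked = [l if allowed else -math.inf for l, allowed in zip(logits, mask)]
    peak = max(masked)
    if peak == -math.inf:
        raise ValueError("mask disallows every choice")
    # Subtract the maximum for numerical stability.
    exps = [math.exp(l - peak) for l in masked]
    total = sum(exps)
    return [e / total for e in exps]


probs = masked_softmax([1.0, 2.0, 3.0], [True, False, True])
```

The second entry receives zero probability even though its logit is larger than the first, and the remaining mass is renormalized over the allowed entries.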
action_history(seq)
classmethod
Return the action history for this action type in the current state
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `seq` | `list[Action]` | sequence of actions in the current state | required |
Returns:
| Type | Description |
|---|---|
| `Tensor \| None` | action history for this action type, or None if not applicable |
distribution(logits)
classmethod
Create a distribution for this action type based on the logits
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `logits` | `Tensor` | Logits for the action selection. | required |
Returns:
| Type | Description |
|---|---|
| `Distribution` | A distribution object for this action type. |
uniform_distribution(logits, num_loops)
classmethod
Create a uniform distribution for this action type based on the logits and number of loops
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `logits` | `Tensor` | Logits for the action selection. | required |
| `num_loops` | `Tensor` | Number of loops in the operation state. | required |
Returns:
| Type | Description |
|---|---|
| `Distribution` | A uniform distribution object for this action type. |
distribution_stats(distribution, index, eps_distribution, eps=None)
classmethod
Calculate the log probabilities and entropies for the distribution
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `distribution` | `Distribution` | The distribution to calculate stats for. | required |
| `eps_distribution` | `Distribution \| None` | The epsilon distribution for exploration. | required |
| `index` | `Tensor` | The params index. | required |
| `eps` | `float \| None` | Epsilon value for exploration. Defaults to None. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[Tensor, Tensor]` | Log probabilities and entropies. |
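How `eps` and `eps_distribution` enter the computation is not spelled out on this page. A common approach to epsilon-style exploration is to evaluate the chosen index under a mixture of the policy distribution and an exploration distribution. The sketch below assumes that mixture form, `(1 - eps) * p + eps * q`, and uses plain probability lists in place of distribution objects. `mixed_stats` is a hypothetical helper, not the method documented above.

```python
import math


def mixed_stats(probs, eps_probs, index, eps=None):
    """Log-probability of `index` and entropy under an assumed eps-mixture."""
    if eps is None or eps_probs is None:
        # No exploration: use the policy probabilities unchanged.
        mixed = list(probs)
    else:
        # Assumed mixture of policy (p) and exploration (q) probabilities.
        mixed = [(1 - eps) * p + eps * q for p, q in zip(probs, eps_probs)]
    log_prob = math.log(mixed[index])
    # Zero-probability entries contribute nothing to the entropy.
    entropy = -sum(p * math.log(p) for p in mixed if p > 0)
    return log_prob, entropy


plain = mixed_stats([0.5, 0.5], None, 0)
explored = mixed_stats([1.0, 0.0], [0.5, 0.5], 1, eps=0.2)
```

Under this assumption, mixing keeps the log-probability finite for an index the policy alone would assign zero probability.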
sample(distribution, eps_distribution, num_loops, uniform, greedy)
classmethod
Sample an action based on the distribution
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `distribution` | `Distribution` | The distribution to sample from. | required |
| `eps_distribution` | `Distribution` | The epsilon distribution for exploration. | required |
| `num_loops` | `Tensor` | Number of loops in the operation state. | required |
| `uniform` | `bool` | Whether to sample uniformly. | required |
| `greedy` | `bool` | Whether to sample greedily. | required |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Sampled action index. |
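The `uniform` and `greedy` flags suggest three sampling modes. The sketch below shows one way such flags could be interpreted, again with plain probability lists: `greedy` picks the most likely index, `uniform` draws from the exploration distribution, and otherwise the index is drawn from the policy distribution. The precedence of the flags, the omission of `num_loops`, and the `sample_index` helper are assumptions for illustration.

```python
import random


def sample_index(probs, eps_probs, uniform, greedy, rng=random):
    """Pick an index from `probs` according to the assumed flag semantics."""
    if greedy:
        # Greedy: take the index with the highest probability.
        return max(range(len(probs)), key=probs.__getitem__)
    # Uniform: draw from the exploration distribution instead of the policy.
    weights = eps_probs if uniform else probs
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


policy = [0.1, 0.7, 0.2]
exploration = [1 / 3, 1 / 3, 1 / 3]
best = sample_index(policy, exploration, uniform=False, greedy=True)
drawn = sample_index(policy, exploration, uniform=True, greedy=False,
                     rng=random.Random(0))
```

Passing a seeded `random.Random` instance as `rng` makes the stochastic branches reproducible in tests.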
apply(module)
update_features(operation_features)
Update the operation features based on the action
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `operation_features` | `OperationFeatures` | The operation features to update. | required |
Returns:
| Type | Description |
|---|---|
| `OperationFeatures` | The updated operation features. |