Observation
Observation encoding for operation state representation.
This module provides classes for encoding operation features into observation tensors used by the RL policy and value networks. It includes components for operation features, producer features, action history, action masks, and loop counts.
ObservationPart
Abstract base class for observation parts.
size()
classmethod
Get the size of this observation part.
Returns:
| Type | Description |
|---|---|
| `int` | The size of the observation part. |
from_state(state)
classmethod
Create the observation part from the current state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `OperationState` | The current operation state. | required |
Returns:
| Type | Description |
|---|---|
| `Tensor` | The observation part tensor. |
Source code in mlir_rl_artifact/observation.py
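The contract described above can be sketched as an abstract base class. This is only an illustration, not the module's code: `ToyState` and `ToyNumLoops` are hypothetical names, and a plain Python list stands in for the tensor so the example runs without any dependency.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class ToyState:
    """Hypothetical stand-in for OperationState."""
    num_loops: int


class ObservationPart(ABC):
    """Sketch of the contract: a fixed size plus a constructor from a state."""

    @classmethod
    @abstractmethod
    def size(cls) -> int:
        """Number of scalar entries this part contributes."""

    @classmethod
    @abstractmethod
    def from_state(cls, state: ToyState) -> list[float]:
        """Encode the state; a list stands in for a tensor here."""


class ToyNumLoops(ObservationPart):
    """Hypothetical part holding a single value, the loop count."""

    @classmethod
    def size(cls) -> int:
        return 1

    @classmethod
    def from_state(cls, state: ToyState) -> list[float]:
        return [float(state.num_loops)]


encoded = ToyNumLoops.from_state(ToyState(num_loops=3))
```

Because both methods are class methods, a part is never instantiated in this sketch; the class itself acts as the encoder.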
OpFeatures
Bases: ObservationPart
Class representing operation features in the observation
Attributes:
| Name | Type | Description |
|---|---|---|
| `arith_ops` | | List of supported arithmetic operations. |
__formula_str_to_list(formula)
staticmethod
Converts an assignment formula into a list of (index, factor) pairs.
Example
formula: "%x1 - %x2 + %x3 * 5 - %x5 * 3"
returns: [('%x1', 1), ('%x2', -1), ('%x3', 5), ('%x5', -3)]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `formula` | `str` | The formula as a string input. | required |
Returns:
| Type | Description |
|---|---|
| `list[tuple[str, int]]` | List of (index, factor) pairs. |
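One way to perform this conversion is a regular expression that matches one term at a time. The sketch below is an assumption about how such parsing could work, not the module's implementation. It only handles terms written as an index optionally followed by `* <integer>`, as in the documented example.

```python
import re

# One term: an optional sign, a %-prefixed index name, and an optional "* factor".
_TERM = re.compile(r"([+-]?)\s*(%\w+)(?:\s*\*\s*(\d+))?")


def formula_str_to_list(formula: str) -> list[tuple[str, int]]:
    """Split a linear formula into (index, signed factor) pairs."""
    pairs = []
    for sign, index, factor in _TERM.findall(formula):
        value = int(factor) if factor else 1  # a bare index has factor 1
        pairs.append((index, -value if sign == "-" else value))
    return pairs


result = formula_str_to_list("%x1 - %x2 + %x3 * 5 - %x5 * 3")
# result == [('%x1', 1), ('%x2', -1), ('%x3', 5), ('%x5', -3)]
```

Forms such as `5 * %x3` or standalone constants are not covered by this pattern and would need extra handling.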
ProducerOpFeatures
ActionHistory
ActionMask
NumLoops
Observation
Class to manage creation and use of observations
Attributes:
| Name | Type | Description |
|---|---|---|
| `parts` | `list[type[ObservationPart]]` | List of observation parts. |
cumulative_sizes()
classmethod
part_number(part)
classmethod
Get the index of a part in the observation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `part` | `type[ObservationPart]` | The part to get the index of. | required |
Returns:
| Type | Description |
|---|---|
| `int` | The index of the part in the observations. |
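A part's index becomes useful once it is paired with running offsets into the flat observation. The sketch below assumes that `cumulative_sizes()` returns such offsets, which the documentation does not state. The part sizes are made up for illustration, and parts are identified by name rather than by class to keep the example short.

```python
from itertools import accumulate

# Hypothetical layout: (part name, size) in the order the parts are concatenated.
PARTS = [("OpFeatures", 4), ("ActionMask", 3), ("NumLoops", 1)]


def part_number(name: str) -> int:
    """Index of a part within the ordered list of parts."""
    return [n for n, _ in PARTS].index(name)


def cumulative_sizes() -> list[int]:
    """Running offsets; entry i is where part i starts (an assumed convention)."""
    return [0, *accumulate(size for _, size in PARTS)]


i = part_number("ActionMask")
start, end = cumulative_sizes()[i], cumulative_sizes()[i + 1]
# With the sizes above, ActionMask occupies columns 4 to 6 (start=4, end=7).
```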
get_part(obs, part, squeeze=True)
classmethod
Get a specific part of the observation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `obs` | `Tensor` | The observation tensor. | required |
| `part` | `type[ObservationPart]` | The part to get. | required |
| `squeeze` | `bool` | Whether to squeeze the part if it has a size of 1, i.e. return `[batch_size]` instead of `[batch_size, 1]`. | `True` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | The tensor representing the part. |
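The effect of `squeeze` can be illustrated with a small stand-alone sketch. It assumes the observation is a `[batch_size, total_size]` array in which each part occupies a contiguous column range. Nested Python lists stand in for tensors, and the `LAYOUT` table is hypothetical.

```python
# Hypothetical layout: part name -> (start, end) columns in each observation row.
LAYOUT = {"OpFeatures": (0, 2), "NumLoops": (2, 3)}


def get_part(obs: list[list[float]], part: str, squeeze: bool = True):
    """Slice one part out of a [batch_size, total_size] observation."""
    start, end = LAYOUT[part]
    rows = [row[start:end] for row in obs]  # shape [batch_size, part_size]
    if squeeze and end - start == 1:
        return [row[0] for row in rows]  # shape [batch_size]
    return rows


obs = [[0.1, 0.2, 3.0], [0.3, 0.4, 5.0]]
loops = get_part(obs, "NumLoops")                    # [3.0, 5.0]
loops_2d = get_part(obs, "NumLoops", squeeze=False)  # [[3.0], [5.0]]
features = get_part(obs, "OpFeatures")               # unaffected by squeeze
```

Parts wider than one column keep their second dimension whatever the value of `squeeze`.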
get_parts(obs, *parts)
classmethod
Get multiple parts of the observation in a single tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `obs` | `Tensor` | The observation tensor. | required |
| `*parts` | `type[ObservationPart]` | The parts to get. | `()` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | The tensor representing the parts. |
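Selecting several parts at once amounts to slicing each one and joining the slices per row. The sketch below assumes the parts are concatenated in the order they are requested, which the documentation does not specify. Nested lists stand in for tensors and the `LAYOUT` table is hypothetical.

```python
# Hypothetical layout: part name -> (start, end) columns in each observation row.
LAYOUT = {"OpFeatures": (0, 2), "ActionMask": (2, 4), "NumLoops": (4, 5)}


def get_parts(obs: list[list[float]], *parts: str) -> list[list[float]]:
    """Concatenate the requested parts, in the order given, for every row."""
    selected = []
    for row in obs:
        merged = []
        for part in parts:
            start, end = LAYOUT[part]
            merged.extend(row[start:end])
        selected.append(merged)
    return selected


obs = [[0.1, 0.2, 1.0, 0.0, 3.0]]
subset = get_parts(obs, "OpFeatures", "NumLoops")  # [[0.1, 0.2, 3.0]]
```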
from_state(state)
classmethod
Create the full observation from the current state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `OperationState` | The current operation state. | required |
Returns:
| Type | Description |
|---|---|
| `Tensor` | The full observation tensor. |
from_states(states)
classmethod
Create the full observation for all the states.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `states` | `list[OperationState]` | The list of operation states. | required |
Returns:
| Type | Description |
|---|---|
| `Tensor` | The full observation tensor, with the states concatenated along the first dimension. |
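Batching along the first dimension can be sketched as follows. The example assumes each single-state observation has shape `[1, total_size]`. States are represented by plain dictionaries and tensors by nested lists, both hypothetical stand-ins chosen so the code runs on its own.

```python
def from_state(state: dict) -> list[list[float]]:
    """Hypothetical encoder: one state -> one row, shape [1, 2]."""
    return [[float(state["num_loops"]), float(state["step"])]]


def from_states(states: list[dict]) -> list[list[float]]:
    """Concatenate per-state observations along the first dimension."""
    batch = []
    for state in states:
        batch.extend(from_state(state))  # appends the single row of this state
    return batch


batch = from_states([{"num_loops": 3, "step": 0}, {"num_loops": 2, "step": 1}])
# batch has one row per state: [[3.0, 0.0], [2.0, 1.0]]
```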