Observation

Observation encoding for operation state representation.

This module provides classes for encoding operation features into observation tensors used by the RL policy and value networks. It includes components for operation features, producer features, action history, action masks, and loop counts.

ObservationPart

Abstract base class for observation parts.

size() classmethod

Get the size of this observation part.

Returns:

| Type | Description |
|------|-------------|
| int | The size of the observation part. |

Source code in mlir_rl_artifact/observation.py
@classmethod
def size(cls) -> int:
    """Get the size of this observation part.

    Returns:
        The size of the observation part.
    """
    raise NotImplementedError

from_state(state) classmethod

Create the observation part from the current state.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| state | OperationState | The current operation state. | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | The observation part tensor. |

Source code in mlir_rl_artifact/observation.py
@classmethod
def from_state(cls, state: OperationState) -> torch.Tensor:
    """Create the observation part from the current state.

    Args:
        state: The current operation state.

    Returns:
        The observation part tensor.
    """
    raise NotImplementedError
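
As an illustration of the contract above, a concrete part overrides both class methods. The sketch below is hypothetical: `ToyState`, its `loop_count` field, and `ToyNumLoops` are stand-ins invented for this example and are not taken from `mlir_rl_artifact`. The base class is re-declared so the snippet is self-contained, and the part returns a 1-D tensor on the assumption that this is what `Observation.from_state` expects when it concatenates parts.

```python
from dataclasses import dataclass

import torch


class ObservationPart:
    """Minimal stand-in mirroring the abstract base class documented above."""

    @classmethod
    def size(cls) -> int:
        raise NotImplementedError

    @classmethod
    def from_state(cls, state) -> torch.Tensor:
        raise NotImplementedError


@dataclass
class ToyState:
    """Hypothetical state; the real OperationState has a different layout."""
    loop_count: int


class ToyNumLoops(ObservationPart):
    """Hypothetical part that encodes a single scalar feature."""

    @classmethod
    def size(cls) -> int:
        return 1

    @classmethod
    def from_state(cls, state: ToyState) -> torch.Tensor:
        # Return a 1-D tensor whose length equals size().
        return torch.tensor([state.loop_count], dtype=torch.float32)


part = ToyNumLoops.from_state(ToyState(loop_count=3))
print(part.shape)  # torch.Size([1])
```

Keeping `size()` and the length of the tensor returned by `from_state()` in agreement matters because the offsets used to slice an observation are derived from `size()` alone.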

OpFeatures

Bases: ObservationPart

Class representing operation features in the observation

Attributes:

| Name | Type | Description |
|------|------|-------------|
| arith_ops | | List of supported arithmetic operations |

__formula_str_to_list(formula) staticmethod

Turns an assignment formula into a list of (index, factor) pairs.

Example

formula: "%x1 - %x2 + %x3 * 5 - %x5 * 3"

returns: [('%x1', 1), ('%x2', -1), ('%x3', 5), ('%x5', -3)]

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| formula | str | the formula as a string input | required |

Returns:

| Type | Description |
|------|-------------|
| list[tuple[str, int]] | list of (index, factor) pairs |

Source code in mlir_rl_artifact/observation.py
@staticmethod
def __formula_str_to_list(formula: str) -> list[tuple[str, int]]:
    """Turns an assignment formula into a list of (index, factor) pairs.

    Example:
        formula: `"%x1 - %x2 + %x3 * 5 - %x5 * 3"`

        returns: `[('%x1', 1), ('%x2', -1), ('%x3', 5), ('%x5', -3)]`

    Args:
        formula: the formula as a string input

    Returns:
        list of (index, factor) pairs
    """
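    # Append a trailing '+' so the final term is flushed by the loop below.
    formula = formula + ' +'
    terms = formula.split(' ')

    running_factor = 1
    running_term = None

    save = []

    for term in terms:
        if term.startswith('%'):
            # An index such as '%x1': remember it until the next '+' or '-'.
            running_term = term
        elif term == '+':
            # Flush the pending term; the next term starts with factor +1.
            save.append((running_term, running_factor))
            running_factor = 1
        elif term == '-':
            # Flush the pending term; the next term starts with factor -1.
            save.append((running_term, running_factor))
            running_factor = -1
        elif term.isnumeric():
            # A literal multiplier: fold it into the current factor.
            running_factor *= int(term)

    # A formula that starts with a sign flushes an empty term first; drop it.
    if save[0][0] is None:
        save = save[1:]

    return save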
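
Because the method is private to `OpFeatures`, the sketch below copies its logic into a free function (the name `formula_str_to_list` is chosen here for illustration) so the behavior can be tried in isolation. It reproduces the documented example and a formula with a leading minus sign.

```python
def formula_str_to_list(formula: str) -> list[tuple[str, int]]:
    """Standalone copy of the parsing logic shown above."""
    terms = (formula + ' +').split(' ')

    running_factor = 1
    running_term = None
    save = []

    for term in terms:
        if term.startswith('%'):
            running_term = term
        elif term == '+':
            save.append((running_term, running_factor))
            running_factor = 1
        elif term == '-':
            save.append((running_term, running_factor))
            running_factor = -1
        elif term.isnumeric():
            running_factor *= int(term)

    # Drop the empty term produced by a leading sign.
    if save[0][0] is None:
        save = save[1:]
    return save


documented = formula_str_to_list("%x1 - %x2 + %x3 * 5 - %x5 * 3")
print(documented)  # [('%x1', 1), ('%x2', -1), ('%x3', 5), ('%x5', -3)]

leading_minus = formula_str_to_list("- %x1 + %x2")
print(leading_minus)  # [('%x1', -1), ('%x2', 1)]

# A bare constant is attached to the previously seen index.
with_constant = formula_str_to_list("%x1 + 2")
print(with_constant)  # [('%x1', 1), ('%x1', 2)]
```

The parser relies on tokens being separated by single spaces. As the last call shows, a standalone constant is not treated as its own term, so the input is presumably expected to contain only index terms with optional integer multipliers.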

ProducerOpFeatures

Bases: OpFeatures

Class representing producer operation features in the observation

ActionHistory

Bases: ObservationPart

Class representing action history in the observation

ActionMask

Bases: ObservationPart

Class representing action mask in the observation

NumLoops

Bases: ObservationPart

Class representing number of loops in the observation

Observation

Class to manage creation and use of observations

Attributes:

| Name | Type | Description |
|------|------|-------------|
| parts | list[type[ObservationPart]] | List of observation parts |

cumulative_sizes() classmethod

Get cumulative sizes of all observation parts.

Returns:

| Type | Description |
|------|-------------|
| list[int] | List of cumulative sizes of all observation parts. |

Source code in mlir_rl_artifact/observation.py
@classmethod
def cumulative_sizes(cls) -> list[int]:
    """Get cumulative sizes of all observation parts.

    Returns:
        List of cumulative sizes of all observation parts.
    """
    sizes = [0]
    for part in cls.parts:
        sizes.append(sizes[-1] + part.size())
    return sizes
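
The running sum above yields one more entry than there are parts, so that consecutive entries bound each part's column range. A minimal sketch, using made-up part sizes (the real sizes depend on the configuration and are not stated here):

```python
from itertools import accumulate

# Hypothetical part sizes, chosen only for illustration.
part_sizes = {"OpFeatures": 4, "ActionHistory": 3, "NumLoops": 1}

# Same running-sum logic as cumulative_sizes(): start at 0, add each size.
sizes = [0]
for size in part_sizes.values():
    sizes.append(sizes[-1] + size)

print(sizes)  # [0, 4, 7, 8]

# Equivalent formulation with itertools.accumulate.
assert sizes == list(accumulate(part_sizes.values(), initial=0))

# Consecutive entries give each part's [start, end) column range.
ranges = dict(zip(part_sizes, zip(sizes, sizes[1:])))
print(ranges)  # {'OpFeatures': (0, 4), 'ActionHistory': (4, 7), 'NumLoops': (7, 8)}
```

The last entry equals the total width of one observation row.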

part_number(part) classmethod

Get the index of a part in the observation.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| part | type[ObservationPart] | The part to get the index of. | required |

Returns:

| Type | Description |
|------|-------------|
| int | The index of the part in the observations. |

Source code in mlir_rl_artifact/observation.py
@classmethod
def part_number(cls, part: type[ObservationPart]) -> int:
    """Get the index of a part in the observation.

    Args:
        part: The part to get the index of.

    Returns:
        The index of the part in the observations.
    """
    return cls.parts.index(part)

get_part(obs, part, squeeze=True) classmethod

Get a specific part of the observation.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| obs | Tensor | The observation tensor. | required |
| part | type[ObservationPart] | The part to get. | required |
| squeeze | bool | Whether to squeeze the part if it has a size of 1, i.e., return [batch_size] instead of [batch_size, 1]. | True |

Returns:

| Type | Description |
|------|-------------|
| Tensor | The tensor representing the part. |

Source code in mlir_rl_artifact/observation.py
@classmethod
def get_part(cls, obs: torch.Tensor, part: type[ObservationPart], squeeze: bool = True) -> torch.Tensor:
    """Get a specific part of the observation.

    Args:
        obs: The observation tensor.
        part: The part to get.
        squeeze: Whether to squeeze the part if it has a size of 1,
            i.e., return [batch_size] instead of [batch_size, 1].

    Returns:
        The tensor representing the part.
    """
    part_idx = cls.part_number(part)
    cum_sizes = cls.cumulative_sizes()
    start = cum_sizes[part_idx]
    if part.size() == 1 and squeeze:
        return obs[:, start]
    end = cum_sizes[part_idx + 1]
    return obs[:, start:end]
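
The effect of `squeeze` comes down to the difference between indexing a single column and slicing a one-column range. The sketch below shows both on a small tensor; the cumulative sizes are hypothetical values picked for the example, not the real layout.

```python
import torch

# Hypothetical layout: three parts of width 4, 3 and 1 (not the real sizes).
cum_sizes = [0, 4, 7, 8]

# A batch of 2 observations, each with 8 features.
obs = torch.arange(16, dtype=torch.float32).reshape(2, 8)

# Part 1 occupies columns [4, 7).
middle = obs[:, cum_sizes[1]:cum_sizes[2]]
print(middle.shape)  # torch.Size([2, 3])

# Part 2 has size 1. With squeeze=True the method indexes a single column,
# which drops the feature dimension.
squeezed = obs[:, cum_sizes[2]]
print(squeezed.shape)  # torch.Size([2])

# With squeeze=False the slice keeps the trailing dimension of size 1.
unsqueezed = obs[:, cum_sizes[2]:cum_sizes[3]]
print(unsqueezed.shape)  # torch.Size([2, 1])
```

Both forms select the same values; only the shape differs, which matters when the result is later concatenated with other parts along the feature dimension.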

get_parts(obs, *parts) classmethod

Get multiple parts of the observation in a single tensor.

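Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| obs | Tensor | The observation tensor. | required |
| *parts | type[ObservationPart] | The parts to get. | () |

Returns:

| Type | Description |
|------|-------------|
| Tensor | The tensor representing the parts. |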

Source code in mlir_rl_artifact/observation.py
@classmethod
def get_parts(cls, obs: torch.Tensor, *parts: type[ObservationPart]) -> torch.Tensor:
    """Get multiple parts of the observation in a single tensor.

    Args:
        obs: The observation tensor.
        *parts: The parts to get.

    Returns:
        The tensor representing the parts.
    """
    return torch.cat([cls.get_part(obs, part, False) for part in parts], dim=1)
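
Two details of the one-liner above are easy to miss: each part is fetched with `squeeze=False` so that every slice stays 2-D, and the result follows the order of the arguments rather than the order in `parts`. The sketch below demonstrates both with a stand-in class; `PartA`, `PartB`, `PartC`, and `ToyObservation` are hypothetical names that only copy the documented slicing logic.

```python
import torch


class PartA:
    @classmethod
    def size(cls) -> int:
        return 2


class PartB:
    @classmethod
    def size(cls) -> int:
        return 1


class PartC:
    @classmethod
    def size(cls) -> int:
        return 3


class ToyObservation:
    """Stand-in that copies the slicing logic documented above."""

    parts = [PartA, PartB, PartC]

    @classmethod
    def cumulative_sizes(cls) -> list[int]:
        sizes = [0]
        for part in cls.parts:
            sizes.append(sizes[-1] + part.size())
        return sizes

    @classmethod
    def get_part(cls, obs, part, squeeze=True):
        idx = cls.parts.index(part)
        cum = cls.cumulative_sizes()
        if part.size() == 1 and squeeze:
            return obs[:, cum[idx]]
        return obs[:, cum[idx]:cum[idx + 1]]

    @classmethod
    def get_parts(cls, obs, *parts):
        # squeeze=False keeps every slice 2-D so they can be joined on dim=1.
        return torch.cat([cls.get_part(obs, p, False) for p in parts], dim=1)


obs = torch.arange(12).reshape(2, 6)  # rows: [0..5] and [6..11]

# Columns come back in argument order (PartC, then PartB), not storage order.
picked = ToyObservation.get_parts(obs, PartC, PartB)
print(picked.tolist())  # [[3, 4, 5, 2], [9, 10, 11, 8]]
```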

from_state(state) classmethod

Create the full observation from the current state.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| state | OperationState | The current operation state. | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | The full observation tensor. |

Source code in mlir_rl_artifact/observation.py
@classmethod
def from_state(cls, state: OperationState) -> torch.Tensor:
    """Create the full observation from the current state.

    Args:
        state: The current operation state.

    Returns:
        The full observation tensor.
    """
    obs_parts = [part.from_state(state) for part in cls.parts]
    return torch.cat(obs_parts).unsqueeze(0)

from_states(states) classmethod

Create the full observation for all the states.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| states | list[OperationState] | The list of operation states. | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | The full observation tensor. States concatenated along the first dimension. |

Source code in mlir_rl_artifact/observation.py
@classmethod
def from_states(cls, states: list[OperationState]) -> torch.Tensor:
    """Create the full observation for all the states.

    Args:
        states: The list of operation states.

    Returns:
        The full observation tensor. States concatenated along the first dimension.
    """
    return torch.cat([cls.from_state(s) for s in states])
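
The shapes produced by `from_state` and `from_states` can be traced with plain tensors. In the sketch below, `encode` is a hypothetical helper that imitates `from_state` using made-up part tensors of widths 4, 3, and 1; it assumes each part is 1-D, which is what concatenating without a `dim` argument and then calling `unsqueeze(0)` suggests.

```python
import torch


def encode(state_id: int) -> torch.Tensor:
    """Hypothetical stand-in for from_state with made-up part tensors."""
    part_tensors = [
        torch.full((4,), float(state_id)),
        torch.zeros(3),
        torch.ones(1),
    ]
    # Mirrors from_state: concatenate the parts, then add a batch dimension.
    return torch.cat(part_tensors).unsqueeze(0)


single = encode(0)
print(single.shape)  # torch.Size([1, 8])

# Mirrors from_states: join the per-state rows along dimension 0.
batch = torch.cat([encode(i) for i in range(5)])
print(batch.shape)  # torch.Size([5, 8])
```

Because every single-state observation already carries a leading dimension of 1, batching needs only a concatenation along dimension 0, and row `i` of the batch corresponds to `states[i]`.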