State

State representation and feature extraction for MLIR operations.

This module provides data structures for representing benchmark and operation states, including features such as loops, memory accesses, and operation types. It also provides functions for extracting these features from the MLIR AST using the AstDumper tool.

OperationType

Bases: Enum

Enumeration of operation types for MLIR operations.

Attributes:

Name Type Description
Generic

Generic operation type.

Matmul

Matrix multiplication operation type.

Conv

Convolutional operation type.

Pooling

Pooling operation type.

Add

Add operation type.

unknown

Unknown operation type.

IteratorType

Bases: Enum

Enumeration of iterator types for loop dimensions.

Attributes:

Name Type Description
Parallel

Parallel iterator type.

Reduction

Reduction iterator type.

NestedLoopFeatures(arg, lower_bound, upper_bound, step, iterator_type) dataclass

Dataclass to store the nested loops features data.

Attributes:

Name Type Description
arg str

The argument representing the loop iterator.

lower_bound int

The lower bound of the loop.

upper_bound int

The upper bound of the loop.

step int

The loop step.

iterator_type IteratorType

The type of the loop iterator.

copy()

Copy the current NestedLoopFeatures object.

Returns:

Type Description
NestedLoopFeatures

The copy.

Source code in mlir_rl_artifact/state.py
def copy(self) -> 'NestedLoopFeatures':
    """Copy the current [NestedLoopFeatures][..] object.

    Returns:
        The copy.
    """
    return NestedLoopFeatures(self.arg, self.lower_bound, self.upper_bound, self.step, self.iterator_type)
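As a minimal sketch of how a loop record is built and copied, the block below uses stand-in definitions that mirror the documented fields. The enum string values (`"parallel"`, `"reduction"`) are assumptions; the real values are not shown on this page.

```python
from dataclasses import dataclass
from enum import Enum


class IteratorType(Enum):
    # Assumed string values; the real enum values are not shown on this page.
    Parallel = "parallel"
    Reduction = "reduction"


@dataclass
class NestedLoopFeatures:
    # Stand-in that mirrors the documented fields.
    arg: str
    lower_bound: int
    upper_bound: int
    step: int
    iterator_type: IteratorType

    def copy(self) -> "NestedLoopFeatures":
        return NestedLoopFeatures(self.arg, self.lower_bound, self.upper_bound, self.step, self.iterator_type)


loop = NestedLoopFeatures("%d0", 0, 256, 1, IteratorType.Parallel)
clone = loop.copy()
clone.upper_bound = 128  # mutating the copy leaves the original untouched
```

Because every field is an immutable value (string, integer, or enum member), a field-by-field copy is enough to make the two objects independent.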

OperationFeatures(operation_name, operation_type, op_count, load_data, store_data, nested_loops, producers, consumers, vectorizable, pre_actions) dataclass

Dataclass to store the operation features data.

Attributes:

Name Type Description
operation_name str

The name of the mlir operation.

operation_type OperationType

The type of the operation.

op_count dict[str, int]

Count of each arithmetic operation in the operation, keyed by operation name.

load_data list[list[str]]

List of load accesses where each load is represented by the list of access arguments.

store_data list[list[str]]

List of store accesses where each store is represented by the list of access arguments.

nested_loops list[NestedLoopFeatures]

List of nested loops where each loop is represented by the NestedLoopFeatures dataclass.

producers list[tuple[str, int]]

List of tags of operations that are consumed by the current operation along with their operand indices.

consumers list[tuple[str, int]]

List of tags of operations that consume the current operation, along with the index of the result each one consumes.

vectorizable bool

Flag to indicate if the operation is vectorizable.

pre_actions list[Action]

List of actions that have already been applied to the current operation.

copy()

Copy the current OperationFeatures object.

Returns:

Type Description
OperationFeatures

The copy.

Source code in mlir_rl_artifact/state.py
def copy(self) -> 'OperationFeatures':
    """Copy the current [OperationFeatures][..] object.

    Returns:
        The copy.
    """
    return OperationFeatures(
        self.operation_name,
        self.operation_type,
        self.op_count.copy(),
        [load.copy() for load in self.load_data],
        [store.copy() for store in self.store_data],
        [loop.copy() for loop in self.nested_loops],
        self.producers.copy(),
        self.consumers.copy(),
        self.vectorizable,
        self.pre_actions.copy()
    )

BenchmarkFeatures(bench_name, code, operation_tags, operations, root_exec_time) dataclass

Dataclass to store the benchmark features data.

Attributes:

Name Type Description
bench_name str

The name of the benchmark.

code str

The MLIR code of the benchmark.

operation_tags list[str]

List of operation tags.

operations dict[str, OperationFeatures]

Mapping from operation tag to its features, where each operation is represented by the OperationFeatures dataclass.

root_exec_time int

Execution time of the benchmark in nanoseconds without any transformation.

copy()

Copy the current BenchmarkFeatures object.

Returns:

Type Description
BenchmarkFeatures

The copy.

Source code in mlir_rl_artifact/state.py
def copy(self) -> 'BenchmarkFeatures':
    """Copy the current [BenchmarkFeatures][..] object.

    Returns:
        The copy.
    """
    return BenchmarkFeatures(
        self.bench_name,
        self.code,
        self.operation_tags.copy(),
        {tag: op.copy() for tag, op in self.operations.items()},
        self.root_exec_time
    )

OperationState(bench_idx, bench_name, operation_tag, original_operation_features, operation_features, producer_tag, producer_operand_idx, producer_features, transformation_history, terminal) dataclass

Dataclass to store the operation state data.

Attributes:

Name Type Description
bench_idx int

The index of the benchmark.

bench_name str

The name of the benchmark.

operation_tag str

The tag of the operation.

original_operation_features OperationFeatures

The original features of the operation, which are never modified.

operation_features OperationFeatures

The features of the operation.

producer_tag str | None

The tag of the selected producer.

producer_operand_idx int | None

The index of the producer's operand.

producer_features OperationFeatures | None

The features of the selected producer.

transformation_history list[list[Action]]

List of transformations with their parameters applied to the operation.

terminal bool

Flag to indicate if the state is terminal.

current_history property

Get the transformation sequence of the current operation being optimized.

Returns:

Type Description
list[Action]

The transformation sequence.

step_count property

Get the number of steps in the current transformation sequence.

Returns:

Type Description
int

The number of steps.

latest_action property

Get the latest action in the current transformation sequence.

Returns:

Type Description
Action | None

The latest action.

has_incomplete_action property

Check if the latest action is incomplete.

Returns:

Type Description
bool

True if the latest action is incomplete, False otherwise.

record_action(action)

Record an action in the current transformation sequence.

Note

If the latest action is incomplete, it will be replaced by the new action, and all the past incomplete actions will be kept in sub_actions.

Parameters:

Name Type Description Default
action Action

The action to record.

required
Source code in mlir_rl_artifact/state.py
def record_action(self, action: 'Action'):
    """Record an action in the current transformation sequence.

    Note:
        If the latest action is incomplete, it will be replaced by the new action,
        and all the past incomplete actions will be kept in `sub_actions`.

    Args:
        action: The action to record.
    """
    if self.has_incomplete_action:
        # Case where the last action should be replaced
        action.sub_actions = self.latest_action.sub_actions + [self.latest_action]
        self.current_history[-1] = action
    else:
        self.current_history.append(action)
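The replace-or-append behaviour described in the note can be sketched on a plain list. The `Action` class and its `complete` flag below are hypothetical stand-ins; how the real class decides that an action is incomplete is not shown on this page.

```python
from dataclasses import dataclass, field


@dataclass
class Action:
    # Hypothetical stand-in for the project's Action class.
    name: str
    complete: bool = True  # assumed flag; the real completeness check is not shown here
    sub_actions: list = field(default_factory=list)


def record_action(history: list, action: Action) -> None:
    """Mirror of the replace-or-append logic, applied to a plain list."""
    if history and not history[-1].complete:
        # The incomplete action is replaced and remembered in sub_actions.
        latest = history[-1]
        action.sub_actions = latest.sub_actions + [latest]
        history[-1] = action
    else:
        history.append(action)


history: list = []
record_action(history, Action("tile[dim 0]", complete=False))
record_action(history, Action("tile[dim 0, 1]", complete=False))
record_action(history, Action("tile[all dims]"))
record_action(history, Action("vectorize"))
chain = [a.name for a in history[0].sub_actions]
```

After these four calls the history holds two entries: the completed tiling action, whose `sub_actions` lists the two partial steps in order, and the vectorize action.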

copy()

Copy the current OperationState object.

Returns:

Type Description
OperationState

The copy.

Source code in mlir_rl_artifact/state.py
def copy(self) -> 'OperationState':
    """Copy the current [OperationState][..] object.

    Returns:
        The copy.
    """
    return OperationState(
        self.bench_idx,
        self.bench_name,
        self.operation_tag,
        self.original_operation_features.copy(),
        self.operation_features.copy(),
        self.producer_tag,
        self.producer_operand_idx,
        self.producer_features.copy() if self.producer_features is not None else None,
        [seq.copy() for seq in self.transformation_history],
        self.terminal
    )
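Note that the history is copied one level deep: each inner sequence becomes a new list, but the `Action` objects inside are shared between the original and the copy. A minimal sketch, using a hypothetical stand-in `Action` class:

```python
class Action:
    # Hypothetical stand-in; only identity and one mutable attribute matter here.
    def __init__(self, name: str):
        self.name = name


original = [[Action("tile")], []]
copied = [seq.copy() for seq in original]  # same expression as in copy()

copied[1].append(Action("vectorize"))  # inner lists are independent
copied[0][0].name = "interchange"      # Action objects are shared
```

Appending to a copied sequence does not affect the original, whereas mutating an existing action is visible through both.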

extract_bench_features_from_code(bench_name, code, root_execution_time)

Extract benchmark features from the given code.

Parameters:

Name Type Description Default
bench_name str

the benchmark name

required
code str

the code to extract features from

required
root_execution_time int

the root execution time

required

Returns:

Type Description
BenchmarkFeatures

the extracted benchmark features

Source code in mlir_rl_artifact/state.py
def extract_bench_features_from_code(bench_name: str, code: str, root_execution_time: int) -> BenchmarkFeatures:
    """Extract benchmark features from the given code.

    Args:
        bench_name: the benchmark name
        code: the code to extract features from
        root_execution_time: the root execution time

    Returns:
        the extracted benchmark features
    """
    result = subprocess.run(
        f'{os.getenv("AST_DUMPER_BIN_PATH")} -',
        shell=True,
        input=code.encode('utf-8'),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE
    )
    raw_ast_info = result.stdout.decode('utf-8')
    if result.returncode != 0:
        raise Exception(result.stderr.decode('utf-8'))

    return __extract_bench_features_from_ast_result(bench_name, raw_ast_info, root_execution_time)
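The function pipes the MLIR text to the AstDumper binary through standard input (the trailing `-` argument). A minimal sketch of the same pattern with an argument list instead of `shell=True` follows. The helper name `run_with_stdin` is hypothetical, and the Python interpreter stands in for the AstDumper binary so that the snippet runs anywhere.

```python
import subprocess
import sys


def run_with_stdin(cmd: list[str], text: str) -> str:
    """Run cmd, feed text on stdin, and return the decoded stdout."""
    result = subprocess.run(
        cmd,
        input=text.encode("utf-8"),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if result.returncode != 0:
        raise RuntimeError(result.stderr.decode("utf-8"))
    return result.stdout.decode("utf-8")


# Stand-in child process: echoes stdin in upper case.
echo_upper = [sys.executable, "-c", "import sys; sys.stdout.write(sys.stdin.read().upper())"]
output = run_with_stdin(echo_upper, "func.func @main()")
```

Passing the command as a list avoids shell quoting issues when a path contains spaces.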

extract_bench_features_from_file(bench_name, file_path, root_execution_time)

Extract benchmark features from the code in the file.

Parameters:

Name Type Description Default
bench_name str

the benchmark name

required
file_path str

the file path

required
root_execution_time int

the root execution time

required

Returns:

Type Description
BenchmarkFeatures

the extracted benchmark features

Source code in mlir_rl_artifact/state.py
def extract_bench_features_from_file(bench_name: str, file_path: str, root_execution_time: int) -> BenchmarkFeatures:
    """Extract benchmark features from the code in the file.

    Args:
        bench_name: the benchmark name
        file_path: the file path
        root_execution_time: the root execution time

    Returns:
        the extracted benchmark features
    """
    result = subprocess.run(
        f'{os.getenv("AST_DUMPER_BIN_PATH")} {file_path}',
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE
    )
    raw_ast_info = result.stdout.decode('utf-8')
    if result.returncode != 0:
        raise Exception(result.stderr.decode('utf-8'))

    return __extract_bench_features_from_ast_result(bench_name, raw_ast_info, root_execution_time)

__extract_bench_features_from_ast_result(bench_name, raw_ast_info, root_execution_time)

Extracts benchmark features from the code's AST result and execution time.

Parameters:

Name Type Description Default
bench_name str

the benchmark name

required
raw_ast_info str

the raw AST information

required
root_execution_time int

the root execution time

required

Returns:

Type Description
BenchmarkFeatures

extracted benchmark features
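The source of this function implies a marker-delimited text layout for each operation block. The sketch below parses one hand-written block in that layout. The sample text is inferred from the `split` calls and is not captured AstDumper output, the operation-count keys are illustrative, and `parse_block` is a hypothetical simplified helper that skips the configuration limits.

```python
import re

SAMPLE_BLOCK = """linalg.matmul
#START_VECTORIZABLE
true
#START_NESTED_LOOPS
d0 0 256 1 parallel
d1 0 512 1 parallel
d2 0 128 1 reduction
#START_LOAD_DATA
d0, d2
d2, d1
#START_STORE_DATA
d0, d1
#START_OP_COUNT
+ 1
* 1
#START_TAG
operation_0
"""


def parse_block(block: str) -> dict:
    """Split one operation block on its markers, in the order the parser uses."""
    rest, tag = block.split("#START_TAG")
    name, rest = rest.split("#START_VECTORIZABLE")
    vectorizable, rest = rest.split("#START_NESTED_LOOPS")
    loops, rest = rest.split("#START_LOAD_DATA")
    loads, rest = rest.split("#START_STORE_DATA")
    stores, counts = rest.split("#START_OP_COUNT")

    def accesses(text: str) -> list[list[str]]:
        # Prefix loop dimensions (d0, d1, ...) with '%', as the parser does.
        text = re.sub(r"d\d+", lambda m: f"%{m.group()}", text)
        return [line.split(", ") for line in text.strip().split("\n") if line]

    count_lines = [line for line in counts.strip().split("\n") if line]
    return {
        "tag": tag.strip().split("\n")[0],
        "name": name.strip(),
        "vectorizable": vectorizable.strip() == "true",
        "loops": [line.split(" ") for line in loops.strip().split("\n") if line],
        "loads": accesses(loads),
        "stores": accesses(stores),
        "op_count": {op: int(n) for op, n in (line.split(" ") for line in count_lines)},
    }


parsed = parse_block(SAMPLE_BLOCK)
```

Each loop line carries five space-separated fields: iterator argument, lower bound, upper bound, step, and iterator type.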

Source code in mlir_rl_artifact/state.py
def __extract_bench_features_from_ast_result(bench_name: str, raw_ast_info: str, root_execution_time: int) -> BenchmarkFeatures:
    """Extracts benchmark features from the code's AST result and execution time.

    Args:
        bench_name: the benchmark name
        raw_ast_info: the raw AST information
        root_execution_time: the root execution time

    Returns:
        extracted benchmark features
    """
    cfg = Config()

    info, full_code = raw_ast_info.split("########################################")
    operations_lines, graph_str = info.split('#BEGIN_GRAPH')

    operations_blocks = operations_lines.split('#START_OPERATION')
    operations_blocks = [block.strip() for block in operations_blocks if block]

    ops_tags = []
    operations: dict[str, OperationFeatures] = {}
    true_loads_count: dict[str, int] = {}
    for operation_block in operations_blocks:
        rest, operation_tag = operation_block.split("#START_TAG")
        operation_tag = operation_tag.strip().split("\n")[0]
        log_info = f"- Bench: {bench_name}\n- Operation: {operation_tag}"

        operation_name, rest = rest.split("#START_VECTORIZABLE")
        operation_name = operation_name.strip()
        operation_type = __get_operation_type(operation_name)

        nested_loops = []
        op_count = {}
        load_data: list[list[str]] = []
        store_data: list[list[str]] = []

        vectorizable_str, rest = rest.split("#START_NESTED_LOOPS")
        assert vectorizable_str.strip() in ["true", "false"], f"Vectorizable string is not valid: {vectorizable_str}"
        vectorizable = vectorizable_str.strip() == "true"

        nested_loops_str, rest = rest.split("#START_LOAD_DATA")
        for nested_loop_str in nested_loops_str.strip().split("\n"):
            if not nested_loop_str:
                continue
            # Each loop line holds: iterator argument, lower bound, upper bound, step, iterator type
            arg, low, high, step, iter_type = nested_loop_str.strip().split(" ")
            nested_loops.append(NestedLoopFeatures(
                arg=f'%{arg}',
                lower_bound=int(low),
                upper_bound=int(high),
                step=int(step),
                iterator_type=IteratorType(iter_type)
            ))
        if len(nested_loops) > cfg.max_num_loops:
            print_error(f"Number of loops {len(nested_loops)} is not supported\n" + log_info)
            continue

        loads_data_str, rest = rest.split("#START_STORE_DATA")
        loads_data_str = re.sub(r'd\d+', lambda m: f'%{m.group()}', loads_data_str)
        for load_data_str in loads_data_str.strip().split("\n"):
            if not load_data_str:
                continue
            load_data.append(load_data_str.split(", "))
        if any(len(load) > cfg.max_num_load_store_dim for load in load_data):
            print_error(f"Number of load dims {max(len(load) for load in load_data)} is not supported\n" + log_info)
            continue
        true_loads_count[operation_tag] = len(load_data)
        if len(load_data) > cfg.max_num_stores_loads:
            # We ignore this overflow, because there are many cases with a huge number of loads
            load_data = load_data[:cfg.max_num_stores_loads]

        stores_data_str, ops_count_str = rest.split("#START_OP_COUNT")
        stores_data_str = re.sub(r'd\d+', lambda m: f'%{m.group()}', stores_data_str)
        for store_data_str in stores_data_str.strip().split("\n"):
            if not store_data_str:
                continue
            store_data.append(store_data_str.split(", "))
        if any(len(store) > cfg.max_num_load_store_dim for store in store_data):
            print_error(f"Number of store dims {max(len(store) for store in store_data)} is not supported\n" + log_info)
            continue
        if len(store_data) > cfg.max_num_stores_loads:
            store_data = store_data[:cfg.max_num_stores_loads]

        for op_count_str in ops_count_str.strip().split("\n"):
            if not op_count_str:
                # Skip empty lines, e.g. when the operation has no counted ops
                continue
            op, count = op_count_str.strip().split(" ")
            op_count[op] = int(count)

        ops_tags.append(operation_tag)
        operations[operation_tag] = OperationFeatures(
            operation_name=operation_name,
            operation_type=operation_type,
            op_count=op_count,
            load_data=load_data,
            store_data=store_data,
            nested_loops=nested_loops,
            producers=[],
            consumers=[],
            vectorizable=vectorizable,
            pre_actions=[]
        )

    # Extract producer/consumer features
    graph_str = graph_str.replace("#END_GRAPH", "")
    graph_lines = [(line.split(' --> ')[0].split(' '), line.split(' --> ')[1].split(' ')) for line in graph_str.strip().split("\n") if line]

    for (producer, res_idx), (consumer, op_idx) in graph_lines:
        op_idx = int(op_idx)
        res_idx = int(res_idx)
        if op_idx >= len(operations[consumer].load_data):
            if 0 <= (op_idx - true_loads_count[consumer]) < len(operations[consumer].store_data):
                # Case where the index falls within the supported number of stores
                # -> align the index
                op_idx = op_idx - true_loads_count[consumer] + len(operations[consumer].load_data)
            else:
                # Case where the index falls within unsupported number of loads or stores
                # -> ignore
                continue

        operations[consumer].producers.append((producer, op_idx))
        operations[producer].consumers.append((consumer, res_idx))

    return BenchmarkFeatures(
        bench_name=bench_name,
        code=full_code,
        operation_tags=ops_tags,
        operations=operations,
        root_exec_time=root_execution_time,
    )
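The index alignment in the producer/consumer loop is easier to follow with numbers. The sketch below isolates that arithmetic in a hypothetical helper, `align_operand_index`. Reading the operand list as loads followed by stores is an inference from the index arithmetic, not something stated on this page.

```python
from typing import Optional


def align_operand_index(op_idx: int, true_loads: int, kept_loads: int, kept_stores: int) -> Optional[int]:
    """Return the index into the truncated operand list, or None if the edge is dropped."""
    if op_idx < kept_loads:
        return op_idx  # points at a load that survived truncation
    store_pos = op_idx - true_loads
    if 0 <= store_pos < kept_stores:
        return store_pos + kept_loads  # stores follow the kept loads
    return None  # truncated load or unsupported store


# An operation with 10 loads, of which 4 are kept, and 1 store.
results = [
    align_operand_index(i, true_loads=10, kept_loads=4, kept_stores=1)
    for i in (2, 6, 10, 11)
]
```

With these numbers, index 2 is kept as is, index 6 refers to a truncated load and is dropped, index 10 is the first store and moves to position 4, and index 11 falls outside the supported stores.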

__get_operation_type(operation_name)

Get the operation type from the operation name.

Parameters:

Name Type Description Default
operation_name str

The operation name.

required

Returns:

Type Description
OperationType

The operation type, or OperationType.unknown if no type matches.

Source code in mlir_rl_artifact/state.py
def __get_operation_type(operation_name: str) -> OperationType:
    """Get the operation type from the operation name.

    Args:
        operation_name: The operation name.

    Returns:
        The operation type, or OperationType.unknown if no type matches.
    """
    for operation_type in OperationType:
        if operation_type.value and operation_type.value in operation_name:
            return operation_type
    return OperationType.unknown
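The lookup is a substring match over the enum values, taken in definition order, so the first matching member wins. A minimal sketch follows; the enum values and the operation names are assumptions for illustration, since the real values are not shown on this page.

```python
from enum import Enum


class OperationType(Enum):
    # Assumed values; the real enum values are not shown on this page.
    Generic = "generic"
    Matmul = "matmul"
    Conv = "conv"
    Pooling = "pooling"
    Add = "add"
    unknown = ""


def get_operation_type(operation_name: str) -> OperationType:
    for operation_type in OperationType:
        # Skip empty values, which would otherwise match every name.
        if operation_type.value and operation_type.value in operation_name:
            return operation_type
    return OperationType.unknown


kinds = [get_operation_type(n) for n in ("linalg.matmul", "linalg.conv_2d", "linalg.fill")]
```

The truthiness check on the value matters: an empty string is a substring of every name, so without it a member with an empty value would match everything.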