Skip to content

PPO

Proximal Policy Optimization (PPO) training algorithm for MLIR RL.

This module implements the core PPO training loop including trajectory collection, policy updates, value function updates, and benchmark evaluation. It manages the interaction between the RL environment and the neural network models.

collect_trajectory(data, model, step)

Collect a trajectory using the model and the environment.

Parameters:

Name Type Description Default
data Benchmarks

The benchmarks dataset.

required
model HiearchyModel

The model to use.

required
step int

The current step of the main loop.

required

Returns:

Type Description
TrajectoryData

The collected trajectory.

Source code in mlir_rl_artifact/ppo.py
def collect_trajectory(data: Benchmarks, model: Model, step: int) -> TrajectoryData:
    """Collect a trajectory using the model and the environment.

    Args:
        data: The benchmarks dataset.
        model: The model to use.
        step: The current step of the main loop.

    Returns:
        The collected trajectory.
    """
    dm = DaskManager()
    fl = FileLogger()
    exe = Execution()
    cfg = Config()

    eps = None
    if 'epsilon' in cfg.exploration:
        ratio = step / cfg.nb_iterations
        final_eps = 0.001
        eps = final_eps + (cfg.init_epsilon - final_eps) * (1 - ratio)

    print_info(f"Collecting {cfg.bench_count} benchmarks using {dm.num_workers} workers...")
    traj_start = time()

    # Prepare benchmarks to explore
    indices = torch.randperm(len(data))[:cfg.bench_count].long().tolist()
    if len(indices) < cfg.bench_count:
        indices = (indices * cfg.bench_count)[:cfg.bench_count]
    envs: list[Env] = []
    states: list[OperationState] = []
    observations: list[torch.Tensor] = []
    tcs: list[TrajectoryCollector] = []
    for idx in indices:
        env = Env()
        state = env.reset(data, idx)
        envs.append(env)
        states.append(state)
        observations.append(Observation.from_state(state))
        tcs.append(TrajectoryCollector())

    with GPUOccupier().gpu_needed():
        while (active_states := [(i, s) for i, s in enumerate(states) if not s.terminal]):
            # Sample states that are not terminal yet
            obss = torch.cat([observations[i] for i, _ in active_states])
            actions_index, actions_bev_log_p, entropies = model.sample(obss.to(device), eps=eps)
            actions_index, actions_bev_log_p, entropies = actions_index.cpu(), actions_bev_log_p.cpu(), entropies.cpu()
            fl['train/entropy'].extend(entropies.tolist())

            # Record data and update states
            for (i, state), obs, action_index, action_bev_log_p in zip(active_states, obss, actions_index, actions_bev_log_p):
                obs = obs.unsqueeze(0)

                # Get action and use it to get next state
                action = ActionSpace.action_by_index(action_index, state)
                states[i] = next_state = envs[i].step(state, action)
                observations[i] = next_obs = Observation.from_state(next_state)

                # If the benchmark is not done yet, keep next operation state instead
                done = False
                if next_state.terminal:
                    next_op_state = envs[i].get_next_op_state(next_state)
                    if next_op_state is not None:
                        states[i] = next_op_state
                        observations[i] = Observation.from_state(next_op_state)
                    else:
                        done = True

                # Record available data
                tcs[i].append(
                    Observation.get_part(obs, NumLoops).long().item(),
                    action_index.unsqueeze(0),
                    obs,
                    next_obs,
                    action_bev_log_p.item(),
                    0.0,  # This will be filled after execution
                    done
                )

    traj_end_sampling = time()

    results = dm.map_objs(__execute_states, states, data, exe.main_exec_data, training=True, obj_str=lambda s: s.bench_name)

    traj_end_exec_states = time()

    results = [
        (*e.failed_seq(s.transformation_history), float(dm.batch_timeout))
        if not r else r
        for r, s, e in zip(results, states, envs)
    ]

    all_rewards, all_speedups, all_exec_times, cache_misses, worker_times = tuple(zip(*results))
    new_cache_data: dict[str, dict[str, int]] = {}
    for tc, state, rewards, speedup, exec_time in zip(tcs, states, all_rewards, all_speedups, all_exec_times):
        # Update trajectory
        tc.rewards = rewards
        # Log metrics
        fl['train/reward'].extend(rewards)
        fl['train/final_speedup'].append(speedup)
        # Get new cache data
        if exec_time is not None:
            cache_key = exe.get_code_cache_key(state.transformation_history)
            if state.bench_name not in new_cache_data:
                new_cache_data[state.bench_name] = {}
            new_cache_data[state.bench_name][cache_key] = exec_time

    tc = sum(tcs, TrajectoryCollector())
    exe.update_execution_cache(new_cache_data)

    traj_end = time()

    sampling_time = traj_end_sampling - traj_start
    exec_states_time = traj_end_exec_states - traj_end_sampling
    collection_time = traj_end - traj_start
    print_info((
        f"Timings: collection: {timedelta(seconds=collection_time)}"
        f", sampling: {timedelta(seconds=sampling_time)}"
        f", exec: {timedelta(seconds=exec_states_time)}"
    ))

    return tc.to_trajectory()

ppo_update(trajectory, model, optimizer)

Update the policy and value models using PPO algorithm.

Performs PPO training on the collected trajectory data by computing policy loss, value loss, and entropy bonus, then updating model parameters via backpropagation.

Parameters:

Name Type Description Default
trajectory TrajectoryData

The trajectory data collected from environment.

required
model HiearchyModel

The model to update.

required
optimizer Optimizer

The optimizer for model parameters.

required
Source code in mlir_rl_artifact/ppo.py
def ppo_update(trajectory: TrajectoryData, model: Model, optimizer: torch.optim.Optimizer):
    """Update the policy and value models using PPO algorithm.

    Performs PPO training on the collected trajectory data by computing policy loss,
    value loss, and entropy bonus, then updating model parameters via backpropagation.

    Args:
        trajectory: The trajectory data collected from environment.
        model: The model to update.
        optimizer: The optimizer for model parameters.
    """
    fl = FileLogger()
    cfg = Config()

    trajectory.update_attributes(model)

    ppo_start = time()
    data_loader = trajectory.loader(cfg.ppo_batch_size, 1)
    for _ in range(cfg.ppo_epochs):
        for batch in data_loader:
            batch: list[torch.Tensor] = [e.to(device, non_blocking=True) for e in batch]
            (
                _,
                actions_index,
                obs,
                _,
                actions_bev_log_p,
                _, _,
                values,
                _,
                actions_old_log_p,
                off_policy_rates,
                returns,
                advantages,
            ) = batch
            max_abs_adv = advantages.abs().max()
            if cfg.normalize_adv == 'standard' and advantages.size(0) > 1:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
            elif cfg.normalize_adv == 'max-abs' and max_abs_adv > 0:
                advantages = advantages / max_abs_adv

            with torch.enable_grad():
                actions_log_p, new_values, entropies = model(obs, actions_index)

                policy_loss, clip_frac = model.policy_model.loss(actions_log_p, actions_bev_log_p, off_policy_rates, advantages)
                loss = policy_loss

                if cfg.value_epochs == 0:
                    value_loss = model.value_model.loss(new_values, values, returns)
                    loss += cfg.value_coef * value_loss

                if 'entropy' in cfg.exploration:
                    entropy_loss = -entropies.mean()
                    loss += cfg.entropy_coef * entropy_loss

            approx_kl = (actions_old_log_p - actions_log_p).pow(2).mean() / 2

            optimizer.zero_grad()
            try:
                loss.backward()
                clip_factor = torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()
            except Exception as e:
                print_error(
                    'Error during PPO update\n'
                    f'Error: {e}'
                )

            # Logging
            fl['train_ppo/policy_loss'].append(policy_loss.item())
            fl['train_ppo/clip_frac'].append(clip_frac.item())
            fl['train_ppo/clip_factor'].append(clip_factor.item())
            fl['train_ppo/approx_kl'].append(approx_kl.item())
            if cfg.value_epochs == 0:
                fl['train_ppo/value_loss'].append(value_loss.item())
            if 'entropy' in cfg.exploration:
                fl['train_ppo/entropy_loss'].append(entropy_loss.item())
    ppo_end = time()
    print_info(f"PPO fit in {timedelta(seconds=ppo_end - ppo_start)}")

value_update(trajectory, model, optimizer)

Update the value function model using trajectory data.

Trains the value model to predict state values by minimizing MSE loss between predicted and computed returns.

Parameters:

Name Type Description Default
trajectory TrajectoryData

The trajectory data with returns computed.

required
model HiearchyModel

The hierarchical model to update.

required
optimizer Optimizer

The optimizer for value model parameters.

required
Source code in mlir_rl_artifact/ppo.py
def value_update(trajectory: TrajectoryData, model: Model, optimizer: torch.optim.Optimizer):
    """Update the value function model using trajectory data.

    Trains the value model to predict state values by minimizing MSE loss between
    predicted and computed returns.

    Args:
        trajectory: The trajectory data with returns computed.
        model: The hierarchical model to update.
        optimizer: The optimizer for value model parameters.
    """
    fl = FileLogger()
    cfg = Config()

    trajectory.update_attributes(model)

    value_start = time()
    data_loader = trajectory.loader(cfg.value_batch_size, 1)
    for _ in range(cfg.value_epochs):
        for batch in data_loader:
            batch: list[torch.Tensor] = [e.to(device, non_blocking=True) for e in batch]
            (
                _, _,
                obs,
                _, _, _, _,
                values,
                _, _, _,
                returns,
                _,
            ) = batch
            with torch.enable_grad():
                new_values = model.value_model(obs)

                loss = model.value_model.loss(new_values, values, returns)

            optimizer.zero_grad()
            try:
                loss.backward()
                clip_factor = torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()
            except Exception as e:
                print_error(
                    'Error during Value update\n'
                    f'Error: {e}'
                )

            # Logging
            fl['train_value/loss'].append(loss.item())
            fl['train_value/clip_factor'].append(clip_factor.item())
    value_end = time()
    print_info(f"Value fit in {timedelta(seconds=value_end - value_start)}")

evaluate_benchmarks(model, data)

Evaluate the model on all benchmarks and measure optimization results.

Runs the trained model in greedy mode on all benchmarks, applies optimizations, and measures the resulting execution times and speedups.

Parameters:

Name Type Description Default
model HiearchyModel

The trained model to evaluate.

required
data Benchmarks

The benchmark dataset to evaluate on.

required

Returns:

Type Description
dict[str, int]

Dictionary mapping benchmark names to execution times (in nanoseconds).

dict[str, float]

Dictionary mapping benchmark names to speedup factors.

Source code in mlir_rl_artifact/ppo.py
def evaluate_benchmarks(model: Model, data: Benchmarks) -> tuple[dict[str, int], dict[str, float]]:
    """Evaluate the model on all benchmarks and measure optimization results.

    Runs the trained model in greedy mode on all benchmarks, applies optimizations,
    and measures the resulting execution times and speedups.

    Args:
        model: The trained model to evaluate.
        data: The benchmark dataset to evaluate on.

    Returns:
        Dictionary mapping benchmark names to execution times (in nanoseconds).
        Dictionary mapping benchmark names to speedup factors.
    """
    dm = DaskManager()
    fl = FileLogger()
    exe = Execution()

    print_info("Evaluation started...")
    eval_start = time()

    # Prepare benchmarks to explore
    indices = range(len(data))
    envs: list[Env] = []
    states: list[OperationState] = []
    observations: list[torch.Tensor] = []
    for idx in indices:
        env = Env()
        state = env.reset(data, idx)
        envs.append(env)
        states.append(state)
        observations.append(Observation.from_state(state))

    with GPUOccupier().gpu_needed():
        while (active_states := [(i, s) for i, s in enumerate(states) if not s.terminal]):
            # Sample states that are not terminal yet
            obss = torch.cat([observations[i] for i, _ in active_states])
            actions_index, _, entropies = model.sample(obss.to(device), greedy=True)
            actions_index, entropies = actions_index.cpu(), entropies.cpu()
            fl['eval/entropy'].extend(entropies.tolist())

            # Record data and update states
            for (i, state), obs, action_index in zip(active_states, obss, actions_index):
                obs = obs.unsqueeze(0)

                # Get action and use it to get next state
                action = ActionSpace.action_by_index(action_index, state)
                states[i] = next_state = envs[i].step(state, action)
                observations[i] = Observation.from_state(next_state)

                # If the benchmark is not done yet, keep next operation state instead
                if next_state.terminal:
                    next_op_state = envs[i].get_next_op_state(next_state)
                    if next_op_state is not None:
                        states[i] = next_op_state
                        observations[i] = Observation.from_state(next_op_state)

    print_success("Schedules inferred! Applying optimizations...", flush=True)
    results = dm.map_objs(__execute_states, states, data, exe.main_exec_data, training=False, obj_str=lambda s: s.bench_name)
    results = [
        (*e.failed_seq(s.transformation_history), float(dm.batch_timeout))
        if not r else r
        for r, s, e in zip(results, states, envs)
    ]
    all_rewards, all_speedups, all_exec_times, _, _ = tuple(zip(*results))
    new_cache_data: dict[str, dict[str, int]] = {}
    bench_execs: dict[str, int] = {}
    bench_speedups: dict[str, float] = {}
    for state, rewards, speedup, exec_time in zip(states, all_rewards, all_speedups, all_exec_times):
        fl['eval/reward'].extend(rewards)
        fl['eval/cumulative_reward'].append(sum(rewards))
        fl['eval/final_speedup'].append(speedup)
        if exec_time is not None:
            fl[f'eval/exec_time/{state.bench_name}'].append(exec_time)
            fl[f'eval/speedup/{state.bench_name}'].append(speedup)
            bench_execs[state.bench_name] = exec_time
            bench_speedups[state.bench_name] = speedup
            cache_key = exe.get_code_cache_key(state.transformation_history)
            if state.bench_name not in new_cache_data:
                new_cache_data[state.bench_name] = {}
            new_cache_data[state.bench_name][cache_key] = exec_time

        print_success("\n--------------------", add_label=False)
        print_success("Bench:", state.bench_name, add_label=False)
        print_success("Transformations:", add_label=False)
        for op_seq in state.transformation_history:
            op_tag = op_seq[0].operation_tag
            print_success(f"  - {op_tag}: {list(map(str, op_seq))}", add_label=False)
        print_success("Speedup:", speedup, add_label=False)
        print_success("Execution time:", exec_time, add_label=False)
        print_success("--------------------\n", add_label=False)

    if len(all_speedups) > 0:
        fl['eval/average_speedup'].append(sum(all_speedups) / len(all_speedups))
    exe.update_execution_cache(new_cache_data)

    eval_end = time()
    print_info(f"Evaluation time: {timedelta(seconds=eval_end - eval_start)}")

    return bench_execs, bench_speedups

__execute_states(state, exec_data_file, benchs, main_exec_data)

Execute a benchmark with the transformation sequence stored in state.

Worker function for parallel execution. Initializes environment, applies transformations, and measures execution results.

Parameters:

Name Type Description Default
state OperationState

The operation state containing transformation history.

required
exec_data_file str

Path to the execution cache file.

required
benchs Benchmarks

The benchmark dataset.

required
main_exec_data dict[str, dict[str, int]] | None

Pre-computed execution data.

required

Returns:

Type Description
list[float]

List of rewards for each action in the sequence.

float

Speedup factor (ratio of original to optimized time).

int | None

Execution time in nanoseconds (None if execution failed).

bool

Cache miss flag (False if result was cached).

float

Worker execution time in seconds.

Source code in mlir_rl_artifact/ppo.py
def __execute_states(state: OperationState, exec_data_file: str, benchs: Benchmarks, main_exec_data: Optional[dict[str, dict[str, int]]]) -> tuple[list[float], float, Optional[int], bool, float]:
    """Execute a benchmark with the transformation sequence stored in state.

    Worker function for parallel execution. Initializes environment, applies transformations,
    and measures execution results.

    Args:
        state: The operation state containing transformation history.
        exec_data_file: Path to the execution cache file.
        benchs: The benchmark dataset.
        main_exec_data: Pre-computed execution data.

    Returns:
        List of rewards for each action in the sequence.
        Speedup factor (ratio of original to optimized time).
        Execution time in nanoseconds (None if execution failed).
        Cache miss flag (False if result was cached).
        Worker execution time in seconds.
    """
    print("Handling benchmark:", state.bench_name, flush=True)
    worker_start = time()

    Execution(exec_data_file, main_exec_data)
    env = Env()
    env.reset(benchs, state.bench_idx)
    rewards, speedup, new_exec_time, cache_miss = env.apply_and_run_sequence(state.transformation_history)

    worker_end = time()
    worker_time = worker_end - worker_start

    return rewards, speedup, new_exec_time, cache_miss, worker_time