Algorithms API

This section documents the algorithm implementations in TwisteRL.

RL Module

Algorithm Base Class

The Algorithm base class provides the core training loop. Key methods:

  • learn(num_steps, best_metrics=None): Main training loop

  • learn_step(): Single training iteration (collect, transform, train); see the sketch after this list

  • collect(): Collect rollout data using the Rust collector

  • train(torch_data): Train for num_epochs epochs, calling train_step on each

  • evaluate(kwargs): Evaluate the current policy

  • solve(state, deterministic, num_searches, num_mcts_searches, C, max_expand_depth): Solve from a given state
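
The three-phase structure of learn_step can be sketched roughly as follows. This is an illustrative outline based on the method names above, not the exact TwisteRL implementation; timing and metric bookkeeping are omitted, and data_to_torch is the subclass hook documented below.

def learn_step(self):
    # 1. Collect rollouts through the Rust collector
    data = self.collect()
    # 2. Transform: algorithm-specific conversion to PyTorch tensors
    #    (data_to_torch is implemented by subclasses such as PPO and AZ)
    torch_data = self.data_to_torch(data)
    # 3. Train for num_epochs, each epoch calling train_step
    return self.train(torch_data)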

PPO Implementation

The PPO class implements Proximal Policy Optimization.

Key methods:

  • data_to_torch(data): Convert collected data to PyTorch tensors

  • train_step(torch_data): Perform one gradient update

Training losses:

  • Policy loss (clipped surrogate objective)

  • Value function loss (MSE)

  • Entropy bonus
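
As an illustration, a clipped-surrogate PPO loss combines these three terms roughly as below. The function and hyperparameter names (clip_coef, vf_coef, ent_coef) are illustrative, not TwisteRL's actual train_step.

import torch

def ppo_loss(new_logp, old_logp, advantages, values, returns, entropy,
             clip_coef=0.2, vf_coef=0.5, ent_coef=0.01):
    # Probability ratio between the current policy and the data-collecting policy
    ratio = torch.exp(new_logp - old_logp)
    # Clipped surrogate objective (negated, since we minimize the loss)
    policy_loss = -torch.min(
        ratio * advantages,
        torch.clamp(ratio, 1.0 - clip_coef, 1.0 + clip_coef) * advantages,
    ).mean()
    # Value function loss (MSE against the computed returns)
    value_loss = torch.nn.functional.mse_loss(values, returns)
    # Entropy bonus encourages exploration
    return policy_loss + vf_coef * value_loss - ent_coef * entropy.mean()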

AlphaZero Implementation

The AZ class implements AlphaZero-style training with Monte Carlo Tree Search (MCTS).

Key methods:

  • data_to_torch(data): Convert MCTS data to PyTorch tensors

  • train_step(torch_data): Train policy and value heads
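
In AlphaZero-style training the policy head is typically fit to the MCTS visit distribution and the value head to the outcome of search or self-play. The sketch below shows such a combined loss; tensor names are illustrative and this is not the exact AZ.train_step.

import torch.nn.functional as F

def az_loss(policy_logits, mcts_probs, value_preds, value_targets):
    # Cross-entropy between the network policy and the MCTS visit distribution
    policy_loss = -(mcts_probs * F.log_softmax(policy_logits, dim=-1)).sum(dim=-1).mean()
    # MSE between predicted values and targets derived from search / rollouts
    value_loss = F.mse_loss(value_preds, value_targets)
    return policy_loss + value_loss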

Data Collection

Data collection is handled by Rust collectors (twisterl.collector.PPOCollector and twisterl.collector.AZCollector).

The collectors return data objects with:

  • obs: Observations

  • logits: Policy logits

  • values: Value predictions

  • rewards: Rewards

  • actions: Actions taken

  • additional_data: Algorithm-specific data (returns and advantages for PPO; remaining_values for AZ)
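
As a rough illustration, a data_to_torch implementation might consume these fields as below. The attribute access and dtypes are assumptions for the sketch, not the exact TwisteRL conversion.

import torch

def data_to_torch_sketch(data):
    obs = torch.as_tensor(data.obs, dtype=torch.float32)        # observations
    actions = torch.as_tensor(data.actions, dtype=torch.long)   # actions taken
    logits = torch.as_tensor(data.logits, dtype=torch.float32)  # policy logits
    values = torch.as_tensor(data.values, dtype=torch.float32)  # value predictions
    rewards = torch.as_tensor(data.rewards, dtype=torch.float32)
    return obs, actions, logits, values, rewards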

Training Loop

Training is run via the command line:

python -m twisterl.train --config path/to/config.json

Or programmatically:

from twisterl.utils import prepare_algorithm, load_config

config = load_config("path/to/config.json")
algorithm = prepare_algorithm(config, run_path="runs/my_run")
algorithm.learn(num_steps=1000)

Metrics and Logging

Training metrics are logged in TensorBoard format; view them with:

tensorboard --logdir runs/

Logged metrics include:

  • Benchmark/difficulty: Current difficulty level

  • Benchmark/success: Success rate on evaluation

  • Benchmark/reward: Average reward

  • Losses/value: Value function loss

  • Losses/policy: Policy loss

  • Losses/entropy: Entropy (PPO only)

  • Times/*: Timing breakdown for each step
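
For reference, metrics with these names map directly onto TensorBoard scalar tags. The snippet below shows how equivalent scalars would be written with the standard SummaryWriter API; the values are dummy placeholders, not TwisteRL's logging code.

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/my_run")
# Tag names mirror the metrics listed above; values here are dummies.
writer.add_scalar("Benchmark/success", 0.9, global_step=100)
writer.add_scalar("Losses/policy", 0.12, global_step=100)
writer.add_scalar("Losses/value", 0.34, global_step=100)
writer.close()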

Checkpointing

Checkpoints are saved automatically in safetensors format:

  • checkpoint_last.safetensors: Most recent checkpoint (save frequency controlled by logging.checkpoint_freq)

  • checkpoint_best.safetensors: Best-performing checkpoint

Load a checkpoint:

python -m twisterl.train --config config.json --load_checkpoint_path runs/my_run/checkpoint_best.safetensors
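
Checkpoints can also be opened directly with the safetensors library, for example to inspect or reuse the weights programmatically. The model variable below is an assumption; its parameter names must match the checkpoint.

from safetensors.torch import load_file

# Read the checkpoint into a plain dict mapping parameter names to tensors.
state_dict = load_file("runs/my_run/checkpoint_best.safetensors")

# Assuming `model` is a policy network with matching parameter names:
# model.load_state_dict(state_dict)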