Algorithms API
This section documents the algorithm implementations in TwisteRL.
RL Module
Algorithm Base Class
The Algorithm base class provides the core training loop. Key methods:
learn(num_steps, best_metrics=None): Main training loop
learn_step(): Single training iteration (collect, transform, train)
collect(): Collect rollout data using the Rust collector
train(torch_data): Train for num_epochs, calling train_step
evaluate(kwargs): Evaluate the current policy
solve(state, deterministic, num_searches, num_mcts_searches, C, max_expand_depth): Solve from a given state
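The following is an illustrative sketch only of how these methods compose into one training iteration; the body is an assumption based on the descriptions above, and data_to_torch is the subclass conversion method documented below, not part of the base-class signature list.

def learn_sketch(algorithm, num_steps):
    # Repeatedly run a collect -> transform -> train iteration
    for _ in range(num_steps):
        data = algorithm.collect()                  # rollouts from the Rust collector
        torch_data = algorithm.data_to_torch(data)  # algorithm-specific tensor conversion
        algorithm.train(torch_data)                 # num_epochs passes of train_step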
PPO Implementation
The PPO class implements Proximal Policy Optimization.
Key methods:
data_to_torch(data): Convert collected data to PyTorch tensors
train_step(torch_data): Perform one gradient update
Training losses:
Policy loss (clipped surrogate objective)
Value function loss (MSE)
Entropy bonus
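As a reference for the three loss terms listed above, here is a hedged sketch of the standard PPO objective; tensor names, coefficient values, and the exact reduction are illustrative assumptions, not TwisteRL's implementation.

import torch

def ppo_losses(logits, old_logits, actions, advantages, returns, values,
               clip_eps=0.2, vf_coef=0.5, ent_coef=0.01):
    dist = torch.distributions.Categorical(logits=logits)
    old_dist = torch.distributions.Categorical(logits=old_logits)

    # Clipped surrogate objective
    ratio = torch.exp(dist.log_prob(actions) - old_dist.log_prob(actions))
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
    policy_loss = -torch.min(ratio * advantages, clipped * advantages).mean()

    # Value function loss (MSE) and entropy bonus
    value_loss = torch.nn.functional.mse_loss(values, returns)
    entropy = dist.entropy().mean()

    return policy_loss + vf_coef * value_loss - ent_coef * entropy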
AlphaZero Implementation
The AZ class implements AlphaZero with MCTS.
Key methods:
data_to_torch(data): Convert MCTS data to PyTorch tensors
train_step(torch_data): Train policy and value heads
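For orientation, this is a hedged sketch of an AlphaZero-style loss for the two heads: cross-entropy of the policy against MCTS visit-count targets plus value regression. Names and the weighting of the two terms are assumptions for illustration.

import torch

def az_losses(policy_logits, target_probs, values, target_values):
    # Policy head: cross-entropy between predicted log-probs and MCTS targets
    log_probs = torch.log_softmax(policy_logits, dim=-1)
    policy_loss = -(target_probs * log_probs).sum(dim=-1).mean()

    # Value head: mean squared error against the search/game outcome
    value_loss = torch.nn.functional.mse_loss(values, target_values)

    return policy_loss + value_loss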
Data Collection
Data collection is handled by Rust collectors (twisterl.collector.PPOCollector and twisterl.collector.AZCollector).
The collectors return data objects with:
obs: Observations
logits: Policy logits
values: Value predictions
rewards: Rewards
actions: Actions taken
additional_data: Algorithm-specific data (returns, advantages for PPO; remaining_values for AZ)
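A minimal sketch of converting these fields into PyTorch tensors, along the lines of the data_to_torch methods above; the dtypes and the dict-like access to additional_data are assumptions for illustration.

import torch

def collected_to_torch(data):
    tensors = {
        "obs": torch.as_tensor(data.obs, dtype=torch.float32),
        "actions": torch.as_tensor(data.actions, dtype=torch.long),
        "logits": torch.as_tensor(data.logits, dtype=torch.float32),
        "values": torch.as_tensor(data.values, dtype=torch.float32),
        "rewards": torch.as_tensor(data.rewards, dtype=torch.float32),
    }
    # Algorithm-specific extras, e.g. returns/advantages (PPO) or
    # remaining_values (AZ), assumed here to be exposed as a mapping
    for key, value in data.additional_data.items():
        tensors[key] = torch.as_tensor(value)
    return tensors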
Training Loop
Training is run via the command line:
python -m twisterl.train --config path/to/config.json
Or programmatically:
from twisterl.utils import prepare_algorithm, load_config
config = load_config("path/to/config.json")
algorithm = prepare_algorithm(config, run_path="runs/my_run")
algorithm.learn(num_steps=1000)
Metrics and Logging
Training metrics are logged to TensorBoard:
tensorboard --logdir runs/
Logged metrics include:
Benchmark/difficulty: Current difficulty level
Benchmark/success: Success rate on evaluation
Benchmark/reward: Average reward
Losses/value: Value function loss
Losses/policy: Policy loss
Losses/entropy: Entropy (PPO only)
Times/*: Timing breakdown for each step
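To show how these tags map onto TensorBoard scalars, here is an illustrative snippet using torch.utils.tensorboard; the values and step numbers are placeholders, and the actual logging is performed internally by the training loop.

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/my_run")
writer.add_scalar("Benchmark/success", 0.95, global_step=100)  # placeholder values
writer.add_scalar("Losses/policy", 0.12, global_step=100)
writer.close()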
Checkpointing
Checkpoints are saved automatically in safetensors format:
checkpoint_last.safetensors: Most recent checkpoint (frequency controlled by logging.checkpoint_freq)
checkpoint_best.safetensors: Best performing checkpoint
Load a checkpoint:
python -m twisterl.train --config config.json --load_checkpoint_path runs/my_run/checkpoint_best.safetensors
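Checkpoints can also be inspected directly with the safetensors library; this is a sketch only, and the tensor key names inside the file depend on the policy network configuration.

from safetensors.torch import load_file

state_dict = load_file("runs/my_run/checkpoint_best.safetensors")
print(sorted(state_dict.keys()))  # list the stored parameter tensors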