Introduction
This hands-on machine learning tutorial is based on the event reconstruction challenges of the ALPHA-g experiment at CERN.
Our goal is to predict the position where antihydrogen atoms annihilate, a task known as vertex reconstruction.
Instead of building everything from scratch, you'll start from a fully working baseline model. This allows us to focus on:
- Understanding the physics and data.
- Exploring the model's design.
- Improving its performance.
- Experimenting with more advanced ideas.
By the end of this tutorial, you'll have a clear picture of how deep learning can be applied to real-world problems, and you'll be ready to tackle your own machine learning projects.
What is ALPHA-g?
ALPHA-g is an experiment at CERN that seeks to answer a fundamental question: does antimatter fall under gravity the same way as matter?
To investigate this, ALPHA-g traps antihydrogen, a neutral antimatter atom made of an antiproton and a positron, in a magnetic trap. The atom is then released and allowed to fall. When it annihilates on contact with matter, it produces several detectable particles.
Our job is to figure out where that annihilation happened. This position, the vertex, is key to determining how the antihydrogen moved and whether gravity influenced it.
The detector records a series of 3D points left by charged particles resulting from the annihilation. Reconstructing the vertex from these points is what this tutorial is all about.
If you'd like a visual overview of the ALPHA-g experiment, this short animation by CERN provides an excellent introduction.
Vertex Reconstruction Overview
When antihydrogen annihilates, it typically produces 2-4 charged pions, which travel through the detector and leave behind a trail of hits.
The challenge is to infer the origin point of these tracks, i.e. the location of the annihilation.
In traditional reconstruction pipelines, this is done using geometry-based algorithms that:
- Identify clusters of hits.
- Attempt to form tracks.
- Fit back to a common origin.
In this tutorial, you'll work with a deep learning-based approach:
- We provide you with a working model that predicts the z-coordinate of the vertex directly from a cloud of hits.
- You'll explore how this model works and then experiment with ways to make it better.
By the end of the tutorial, you'll have applied meaningful improvements to a real-world ML system used in fundamental physics research.
Why Machine Learning?
Reconstructing a vertex from a cloud of hit positions is a perfect use case for machine learning.
Why?
- It avoids the need to identify and fit individual particle tracks, which can be difficult or ambiguous.
- It can handle messy or sparse events where geometry-based methods might struggle.
- ML models can be trained end-to-end to optimize for the specific reconstruction task, and once trained, they're fast and scalable.
Session 1: Getting Started
In this session you'll get hands-on with the dataset and train your first machine learning model for vertex reconstruction.
You will:
- Explore what the detector data looks like.
- Preprocess point clouds for training and validation.
- Train a working baseline model.
This gives you a full end-to-end view, from raw data to vertex prediction!
By the end, you'll be ready to start evaluating and improving this model in Session 2.
Setting up the Environment
Before you can run code for this tutorial, you need to set up your environment.
This section walks you through:
- Setting up your Python environment.
- Verifying everything is working.
- Getting the tutorial code.
Step 1: Set Up Your Python Environment
We provide a containerized environment that includes all the necessary dependencies for this tutorial. To access this environment, follow the instructions given here.
Step 2: Verify the Setup
To verify that everything is set up correctly, run the following command:
python -c "import torch; print(torch.__version__)"
If you see the version of PyTorch printed without any errors, your setup is successful.
Step 3: Get the Tutorial Code
First, fork this repository to your own GitHub account. This allows you to make changes and save your work without affecting the original repository. You can find instructions on how to fork a repository here.
Finally, clone your forked repository from within a terminal session inside the container:
git clone https://github.com/YOUR-USERNAME/AdvancedTutorial.git
Exploring the Data
Each annihilation event produces a number of charged particles that leave a trail of hits in the ALPHA-g detector. We have prepared a dataset with 100,000 events:
/fast_scratch_1/TRISEP_data/AdvancedTutorial/small_dataset.parquet
Each event in the dataset contains:
- A true vertex position (x, y, z), i.e. the origin of the annihilation.
- A set of 3D hit positions [(x1, y1, z1), ..., (xn, yn, zn)].
You can get a quick overview of the dataset by launching a Python interpreter and running the following code:
>>> import polars as pl
>>> df = pl.read_parquet("/fast_scratch_1/TRISEP_data/AdvancedTutorial/small_dataset.parquet")
>>> print(df)
"""
shape: (100_000, 2)
┌─────────────────────────────────┬─────────────────────────────────┐
│ target ┆ point_cloud │
│ --- ┆ --- │
│ array[f32, 3] ┆ list[array[f32, 3]] │
╞═════════════════════════════════╪═════════════════════════════════╡
│ [1.455413, 15.901725, -571.578… ┆ [[46.253021, 175.486893, -558.… │
│ [22.550814, 3.005712, 834.1053… ┆ [[171.600311, -59.034386, 1070… │
│ [4.511479, -8.75235, -1014.025… ┆ [[33.221375, -178.431686, -106… │
│ [-9.729183, 4.537313, -970.073… ┆ [[33.232174, -178.489685, -110… │
│ [16.184988, -9.351818, -66.521… ┆ [[152.121017, -98.965576, -218… │
│ … ┆ … │
│ [5.014961, -12.403949, 62.5900… ┆ [[-158.989197, -87.506714, 226… │
│ [-7.504006, -18.486027, 689.96… ┆ [[-181.138474, 11.128488, -30.… │
│ [-19.474358, 13.368332, -59.90… ┆ [[-120.215218, -135.953278, -1… │
│ [-8.580782, 1.076302, -866.250… ┆ [[15.571182, -180.818787, -846… │
│ [-18.750277, 6.330079, 969.173… ┆ [[176.59346, -41.938202, 798.0… │
└─────────────────────────────────┴─────────────────────────────────┘
"""
Before training any model, it's important to understand the structure and characteristics of the data.
Activity:
- Where do annihilation events occur? A skewed distribution in vertex z positions might cause the model to "cheat" by always guessing the most common region.
- Do all events have the same number of hits? Variable-length point clouds will require special handling in the model architecture.
To help you answer these questions, we've provided the script AdvancedTutorial/code/visualize.py.
You can run it directly to visualize key properties of a dataset:
# Target z distribution
python visualize.py target-z /path/to/dataset.parquet
# Point cloud size distribution
python visualize.py cloud-size /path/to/dataset.parquet
Iterating Through the Dataset with PyTorch
To train a model, we need to iterate through the dataset. PyTorch provides a primitive torch.utils.data.Dataset class that allows us to decouple the data loading from the model training/batching process.
We've provided a PyTorch-compatible dataset class. It wraps a .parquet file and gives you easy access to the data in a PyTorch-friendly way. Create a new Python script in the AdvancedTutorial/code/ directory:
from data.dataset import PointCloudDataset

config = {"cloud_size": 140}
dataset = PointCloudDataset(
    "/fast_scratch_1/TRISEP_data/AdvancedTutorial/small_dataset.parquet", config
)

index = 0  # First event
point_cloud, target = dataset[index]
Try running the code above and plot some point clouds and their corresponding targets (annihilation vertices).
Activity:
- Inspect the PointCloudDataset class. How does it handle variable-length point clouds?
- Using the first 10 events (indices 0-9), plot the point clouds and their targets. Do they look like you expected? You can make a 3D scatter plot using matplotlib:

import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.scatter(point_cloud[0], point_cloud[1], point_cloud[2])
ax.scatter(0, 0, target.item(), color="red")
plt.show()

- Using the next 10 events (indices 10-19), plot the point clouds without their targets. Make an educated guess about the target vertex position based on the point cloud. Compare your guess with the actual target positions.
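If you'd like to look at several events in one go, a minimal sketch that loops over the dataset object created above (titles and layout are up to you) could look like:

import matplotlib.pyplot as plt

# Plot the first 10 events, each with its target vertex marked in red.
# Assumes `dataset` was created as shown above.
for index in range(10):
    point_cloud, target = dataset[index]
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.scatter(point_cloud[0], point_cloud[1], point_cloud[2])
    ax.scatter(0, 0, target.item(), color="red")
    ax.set_title(f"Event {index}")
    plt.show()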
Data Preprocessing
Before we can train a model, we need to split our dataset into three parts:
- Training set: The data the model learns from.
- Validation set: Used to tune hyperparameters and monitor performance during training.
- Test set: Used only at the end to evaluate final model performance.
You can then create these splits with a new script like this:
import polars as pl
complete_df = pl.read_parquet(
"/fast_scratch_1/TRISEP_data/AdvancedTutorial/small_dataset.parquet"
)
# Example: Take the first 10000 events
subset_df = complete_df.slice(offset=0, length=10000)
subset_df.write_parquet("/path/to/first_10000.parquet")
Activity:
- Create a training subset with 80% of the data.
- Create a validation subset with 10% of the data.
- Create a test subset with the remaining 10% of the data.
Note that these splits must not overlap. Events shared between the splits leak training information into the evaluation and lead to an overly optimistic estimate of the model's performance.
Once you've created your splits, it's a good idea to repeat the data exploration steps on each one of them. Check the target distribution and point cloud sizes to make sure nothing looks unusual (which could affect the model's performance if the splits are not representative of the whole dataset).
Tip: If you want to shuffle a dataset, you can do it with:
shuffled_df = df.sample(fraction=1.0, shuffle=True)
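Putting the pieces together, here is a minimal sketch of the 80/10/10 split (the output paths are placeholders; adjust them to your own directories):

import polars as pl

complete_df = pl.read_parquet(
    "/fast_scratch_1/TRISEP_data/AdvancedTutorial/small_dataset.parquet"
)

# Shuffle first so that each split is representative of the whole dataset
shuffled_df = complete_df.sample(fraction=1.0, shuffle=True, seed=42)

n = len(shuffled_df)
n_train = int(0.8 * n)
n_val = int(0.1 * n)

# Non-overlapping slices: 80% training, 10% validation, 10% test
shuffled_df.slice(0, n_train).write_parquet("/path/to/train.parquet")
shuffled_df.slice(n_train, n_val).write_parquet("/path/to/validate.parquet")
shuffled_df.slice(n_train + n_val, n - n_train - n_val).write_parquet("/path/to/test.parquet")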
Training and Evaluating the Baseline Model
Now that your data is prepared, let's train a model that predicts the vertex position from detector hits.
We've provided a pre-built model and a training script, which you can run like:
python train.py /path/to/train.parquet /path/to/validate.parquet --output-dir /path/to/output
This will run a full training loop and create the following files in the /path/to/output directory:
- config.toml: the hyperparameters used during training.
- training_log.csv: a CSV file with the training and validation loss at each epoch.
- model.pt: the trained model.
Training the model can take a while depending on the size of your dataset, the
number of epochs, and the hardware you are using. By default, the training
script will run for 50 epochs. You can inspect the training_log.csv
file to
monitor training progress.
Activity:
Write a script to plot the training and validation loss as a function of epoch.
What do you observe?
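A minimal sketch for such a plot is given below. The column names in training_log.csv are assumptions here; check the file's header and adjust them:

import matplotlib.pyplot as plt
import polars as pl

# Column names ("epoch", "train_loss", "val_loss") are assumptions; check your CSV header.
log = pl.read_csv("/path/to/output/training_log.csv")
epochs = log["epoch"].to_numpy()

plt.plot(epochs, log["train_loss"].to_numpy(), label="training")
plt.plot(epochs, log["val_loss"].to_numpy(), label="validation")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()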
Once the model is trained, you can evaluate how well it performs on the validation set using the provided test script:
python test.py /path/to/model.pt /path/to/validate.parquet --output /path/to/output.csv
This will create a CSV file with both the true and predicted vertex positions for each event.
Activity:
- Write a script to plot predicted vertex position vs. true vertex position.
- Create a histogram of the residuals (predicted - true).
- Plot the residuals as a function of true z-position. Is there any pattern?
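To get started, here is a minimal sketch for these plots. The column names in the output CSV are assumptions; inspect the file and adjust them:

import matplotlib.pyplot as plt
import polars as pl

# Column names ("true_z", "predicted_z") are assumptions; check the test script's output CSV.
results = pl.read_csv("/path/to/output.csv")
true_z = results["true_z"].to_numpy()
predicted_z = results["predicted_z"].to_numpy()
residuals = predicted_z - true_z

# Predicted vs. true vertex position
plt.scatter(true_z, predicted_z, s=1)
plt.xlabel("True z")
plt.ylabel("Predicted z")
plt.show()

# Histogram of residuals
plt.hist(residuals, bins=100)
plt.xlabel("Predicted - true z")
plt.show()

# Residuals vs. true z
plt.scatter(true_z, residuals, s=1)
plt.xlabel("True z")
plt.ylabel("Residual (predicted - true)")
plt.show()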
If you finish early or are curious to experiment, try following the bonus activity to see how a small change can impact model performance. Otherwise, feel free to jump ahead to Session 2, where we'll explore the model in more depth.
Bonus Activity (Optional)
You've now trained and evaluated a baseline model. Let's tweak one component to see what happens!
Try changing the loss function here from Mean Squared Error to Mean Absolute Error.
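In PyTorch, this typically amounts to swapping the loss criterion; a minimal illustration (the variable name and its exact location in the training script may differ):

import torch.nn as nn

# Mean Squared Error (the baseline)
criterion = nn.MSELoss()

# Mean Absolute Error (the variant to try)
criterion = nn.L1Loss()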
Then retrain the model and evaluate it again.
Session 2: Understanding and Improving the Model
Now that you've trained your first model, it's time to dig deeper and start making it better.
You will:
- Break down the architecture of the baseline model.
- Extend it with additional layers.
- Experiment with hyperparameters to improve performance.
This session will help you build intuition for how the model works and how to improve it.
By the end, you'll have hands-on experience modifying a neural network and optimizing training.
Understanding the Model Architecture
Before improving our model, let's understand its structure.
The model you're using is a simplified
PointNet. Surprisingly, we can understand
everything by just studying the very simple (~20 actual lines of code)
_TNet
class.
This architecture consists of three main components:
- Feature Extraction:
The input tensor has shape (B, C, L), where B is the batch size, C is the number of channels (input features), and L is the number of points.
The feature extractor applies a series of 1D convolution layers to represent the input points in a higher-dimensional space of abstract features.
Let's see a minimal example:
import torch
import torch.nn as nn

x = torch.randn(2, 3, 4)
print("Input shape:", x.shape)

net = nn.Conv1d(3, 5, 1)
x = net(x)
print("Output shape:", x.shape)
Note that the feature extractor also includes batch normalization and ReLU activation layers, which are essential for training deep neural networks.
Activity:
Change the net in the example above to nn.BatchNorm1d(3) or nn.ReLU(), and manually compute what these layers would do to the input tensor (print x before and after the operation to verify your calculations).

Together, these layers make up the entire feature extractor. Just basic operations, chained one after another. Once you break it down, there's no mystery!
- Global Features Aggregation:
Once each point has been mapped to a higher-dimensional feature vector, we need to summarize the entire collection of points into a single, fixed-size representation.
This is done using a simple operation known as max pooling. It takes the maximum value across all points for each feature, resulting in a single vector that captures the most significant features of the entire point cloud.
Activity:
Change the net(x) in the example above to an x.max(dim=2).values transformation and check the output values and shape.

Note that this operation is inherently order-invariant, making it suitable for point clouds, where the order of points doesn't matter.
- Fully Connected Regressor:
After pooling, we are left with a single (B, F) tensor: one global feature vector per event in the batch. The final step is to map this to our final prediction.
This is done by the regressor block, a series of fully connected linear layers, each followed by batch normalization and ReLU activation.
Activity:
Try passing a (2, 1024) tensor through an nn.Linear(1024, 1) layer and then a ReLU() to see how this maps feature vectors toward output predictions.
You now understand the full architecture of _TNet
. This structure is compact
but powerful, and it is nearly the complete architecture of our model.
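To tie the three components together, here is a minimal, self-contained sketch of the same conv, max-pool, linear pipeline. The layer sizes are arbitrary toy values for illustration, not the ones used in the real model:

import torch
import torch.nn as nn

# Toy input: batch of 2 events, 3 channels (x, y, z), 4 points each
x = torch.randn(2, 3, 4)

# 1. Feature extraction: per-point 1D convolutions with batch norm and ReLU
features = nn.Sequential(
    nn.Conv1d(3, 8, 1), nn.BatchNorm1d(8), nn.ReLU(),
    nn.Conv1d(8, 16, 1), nn.BatchNorm1d(16), nn.ReLU(),
)
x = features(x)           # shape (2, 16, 4)

# 2. Global feature aggregation: max over the points dimension
x = x.max(dim=2).values   # shape (2, 16)

# 3. Fully connected regressor: map the global feature vector to one output
regressor = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 1))
prediction = regressor(x)  # shape (2, 1)

print(prediction.shape)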
Activity:
Open the full model and compare the _TNet class to the full Regressor class.
What is the difference between them?
Try sketching the full model using a diagram or by describing it in your own words.
Improving the Model
Now that you understand the full structure of the model, it's time to explore a simple but effective improvement.
As you saw in the previous section, the model has an
inner _TNet
that learns the coefficients of a matrix used to
align the abstract features.
This feature alignment allows the network to adaptively transform the point
features, often leading to better performance.
Our New Idea: A Learnable Z-Shift
We're now going to add a second internal _TNet, but this one has a different goal: instead of predicting a full matrix, it will predict a single scalar, a z-shift.
This shift will be subtracted from the z-coordinate of all points before any feature extraction. The idea is to let the network:
- First guess a coarse global z-position of the vertex, and
- Then predict the small delta from this shifted position.
Activity:
- Modify your Regressor class to include a new _TNet instance at the start of your model (this _TNet should have an out_dim of 1).
- Use this new _TNet to predict a z-shift given the input point cloud (before any feature extraction).
- Subtract this z-shift from the z-coordinate of all points in the point cloud. You can do this like:

x = torch.stack((x[:, 0, :], x[:, 1, :], x[:, 2, :] - input_trans), dim=1)

- Let the rest of the model continue unchanged, i.e. it is now predicting a delta from the shifted point cloud.
- Update the return value of the forward method to include the z-shift plus the delta.
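Putting these steps together, a rough, illustrative sketch of the resulting structure is shown below. The class and attribute names (ZShiftRegressor, z_shift_net, delta_net, ...) are made up, and the toy layers stand in for the real _TNet and feature extractor; adapt the idea to the actual Regressor class rather than copying this verbatim:

import torch
import torch.nn as nn

class ZShiftRegressor(nn.Module):
    # Illustrative sketch only: predict a coarse global z-shift first,
    # then predict a small delta from the shifted point cloud.
    def __init__(self):
        super().__init__()
        # Stand-in for the new _TNet with out_dim=1 (one scalar per event)
        self.z_shift_net = nn.Sequential(nn.Conv1d(3, 8, 1), nn.ReLU())
        self.z_shift_head = nn.Linear(8, 1)
        # Stand-in for the rest of the model (predicts the delta)
        self.delta_net = nn.Sequential(nn.Conv1d(3, 8, 1), nn.ReLU())
        self.delta_head = nn.Linear(8, 1)

    def forward(self, x):
        # x has shape (B, 3, L)
        z_shift = self.z_shift_head(self.z_shift_net(x).max(dim=2).values)  # (B, 1)
        # Subtract the shift from the z-coordinate of every point
        x = torch.stack((x[:, 0, :], x[:, 1, :], x[:, 2, :] - z_shift), dim=1)
        delta = self.delta_head(self.delta_net(x).max(dim=2).values)        # (B, 1)
        # Final prediction: coarse shift plus fine delta
        return z_shift + delta

model = ZShiftRegressor()
print(model(torch.randn(2, 3, 140)).shape)  # torch.Size([2, 1])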
Once you've made these changes, retrain the model and evaluate it again. Do you see an improvement in performance?
Hyperparameter Tuning
Now that you have made some improvements to the model, it's time to focus on hyperparameter tuning. Small changes in settings like batch size, cloud size, or number of epochs can significantly impact model performance.
The provided training script supports setting hyperparameters using environment variables, making it easy to scan different values using e.g. simple shell scripts. For example, you can set the number of epochs like this:
RECO_TRAIN_TRAINING__MAX_EPOCHS=100 python train.py /path/to/train /path/to/val
Activity:
Set up a hyperparameter scan to run overnight. You can focus on as many hyperparameters as you like.
Note that you can use the --dry-run
flag to test your hyperparameter
assignments. This will only save the used values to a file without actually
running the training.
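One way to organize such a scan is a small Python driver that launches train.py once per setting; a minimal sketch (the paths and scanned values are placeholders, and you can add --dry-run to the command while testing):

import os
import subprocess

# Values to scan; adjust to your own plan
batch_sizes = [32, 64, 128]

for batch_size in batch_sizes:
    output_dir = f"/path/to/scan/batch_size_{batch_size}"
    env = os.environ.copy()
    env["RECO_TRAIN_TRAINING__BATCH_SIZE"] = str(batch_size)
    subprocess.run(
        [
            "python", "train.py",
            "/path/to/train.parquet", "/path/to/validate.parquet",
            "--output-dir", output_dir,
        ],
        env=env,
        check=True,
    )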
Supported Hyperparameters
Here's a summary of tunable hyperparameters and how to set them:
Environment Variable | Hyperparameter Description | Default Value |
---|---|---|
RECO_TRAIN_DATA__CLOUD_SIZE | Size of the point cloud | 140 |
RECO_TRAIN_MODEL__CONV_FEATURE_EXTRACTOR_PRE | List of Conv1d layer dimensions for pre-alignment feature extraction | [64] |
RECO_TRAIN_MODEL__CONV_FEATURE_EXTRACTOR_POST | List of Conv1d layer dimensions for post-alignment feature extraction | [128,1024] |
RECO_TRAIN_MODEL__FC_REGRESSOR | List of Linear layer dimensions | [512,256] |
RECO_TRAIN_MODEL__INPUT_TRANSFORM_NET__CONV_FEATURE_EXTRACTOR | Inner _TNet (z-shift) list of Conv1d layer dimensions | [64,128,1024] |
RECO_TRAIN_MODEL__INPUT_TRANSFORM_NET__FC_REGRESSOR | Inner _TNet (z-shift) list of Linear layer dimensions | [512,256] |
RECO_TRAIN_MODEL__FEATURE_TRANSFORM_NET__CONV_FEATURE_EXTRACTOR | Inner _TNet (feature alignment) list of Conv1d layer dimensions | [64,128,1024] |
RECO_TRAIN_MODEL__FEATURE_TRANSFORM_NET__FC_REGRESSOR | Inner _TNet (feature alignment) list of Linear layer dimensions | [512,256] |
RECO_TRAIN_TRAINING__BATCH_SIZE | Batch size | 64 |
RECO_TRAIN_TRAINING__MAX_EPOCHS | Maximum number of training epochs | 50 |
Note you can mix and match any of these. Just prefix your training command with the relevant environment variable assignments.
Optional: Large Dataset
If you’ve completed the main activities and have a specific idea that could benefit from access to more data, you can optionally use a larger dataset we’ve provided:
/fast_scratch_1/TRISEP_data/AdvancedTutorial/large_dataset.parquet
This version contains 4 million events, but we don't recommend using it by default. For most learning goals in this tutorial, the smaller dataset is more than sufficient and will train models much faster. Training on the larger dataset can take ~30 minutes per epoch, so it is only recommended if you have a specific hypothesis or idea that requires more data.
Session 3: Presenting Your Results
Congratulations on completing the first two sessions of this tutorial!
In this third and final session, you'll take time to reflect on your work, choose your best model configuration based on your training outputs, and prepare a short presentation.
Each group will have:
- 45 minutes to prepare their presentation.
- 3 minutes to present their work to the rest of the class.
Your presentation is short, so focus on what you find most interesting or important. Here are some ideas of what you might want to include:
- A quick overview of your model and any modifications.
- Highlights from your hyperparameter tuning results.
- A challenge you encountered or any insights you gained.
- Suggestions for future improvements.
Want to Go Further?
If you're curious to keep exploring after the tutorial, here are a few ideas:
- Try predicting the full vertex position: (x, y, z) instead of just z.
- Experiment with different architectures, loss functions, or regularization techniques.
You're also welcome to check out the repository where the actual ALPHA-g model is being developed.
If you find an improvement or bug while playing with this tutorial, contributions are always welcome!