This is the curiously recurring template pattern, or CRTP for short. A major advantage of this technique is that it enables so-called static polymorphism, meaning that functions in `torch::data::datasets::Dataset` can call into functions of `CustomDataset` without those functions being virtual (and thus without paying for virtual method dispatch at runtime). You can also perform compile-time metaprogramming, such as `enable_if`s that depend on the properties of the custom dataset type.
In the case of PyTorch, `BaseDataset` (the superclass of `Dataset`) uses this technique heavily to support operations such as mapping and filtering:
```cpp
template <typename TransformType>
MapDataset<Self, TransformType> map(TransformType transform) & {
  return datasets::map(static_cast<Self&>(*this), std::move(transform));
}
```
Note the static cast of `this` to the derived type. This is legal as long as CRTP is applied correctly, i.e. `Self` really is the class that derives from `BaseDataset<Self>`. `datasets::map` constructs a `MapDataset` object that is also parameterized by the dataset type, allowing the `MapDataset` implementation to call methods such as `get_batch` statically (or to fail at compile time if they do not exist).
Furthermore, since `MapDataset` receives the custom dataset type as a type parameter, compile-time metaprogramming is possible:
```cpp
/// The implementation of `get_batch()` for the stateless case, which simply
/// applies the transform to the output of `get_batch()` from the dataset.
template <
    typename D = SourceDataset,
    typename = torch::disable_if_t<D::is_stateful>>
OutputBatchType get_batch_impl(BatchRequestType indices) {
  return transform_.apply_batch(dataset_.get_batch(std::move(indices)));
}

/// The implementation of `get_batch()` for the stateful case. Here, we follow
/// the semantics of `Optional.map()` in many functional languages, which
/// applies a transformation to the optional's content when the optional
/// contains a value, and returns a new optional (of a different type) if the
/// original optional returned by `get_batch()` was empty.
template <typename D = SourceDataset>
torch::enable_if_t<D::is_stateful, OutputBatchType> get_batch_impl(
    BatchRequestType indices) {
  if (auto batch = dataset_.get_batch(std::move(indices))) {
    return transform_.apply_batch(std::move(*batch));
  }
  return nullopt;
}
```
Notice that the conditional enable depends on `SourceDataset`, which is available as a concrete type only because `BaseDataset::map` passes `Self`, the derived dataset type it knows through CRTP, on to `MapDataset`.