How to work with a partial dataset in the PyTorch C++ API

I’m working on a project using the PyTorch C++ frontend with the MNIST dataset. My goal is to work with only a portion of the full dataset instead of loading everything.

In the Python API, I could simply use torch.utils.data.Subset (or similar tools) to get a smaller chunk of the data, but I can't find equivalent functionality in the C++ API.

Currently my dataset loading looks like this:

auto dataset = torch::data::datasets::MNIST(data_path)
                 .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
                 .map(torch::data::transforms::Stack<>());
const size_t dataset_size = dataset.size().value();

I attempted to build a custom sampler to handle this:

class PartialSampler : public torch::data::samplers::Sampler<> {
public:
    explicit PartialSampler(std::vector<size_t> idx)
        : sample_indices(std::move(idx)) {}

    c10::optional<std::vector<size_t>> next(size_t batch_count) override {
        std::vector<size_t> result;
        while (result.size() < batch_count && position < sample_indices.size()) {
            result.push_back(sample_indices[position++]);
        }
        if (result.empty()) {
            return c10::nullopt;
        }
        return result;
    }

    void reset() {
        position = 0;
    }

    c10::optional<size_t> size() {
        return sample_indices.size();
    }

private:
    std::vector<size_t> sample_indices;
    size_t position = 0;
};

When I try to use this sampler with the data loader, I get compilation errors about missing BatchRequestType. What’s the proper way to work with dataset subsets in PyTorch C++?

your sampler looks fine, but you’re missing the template parameter. try inheriting from torch::data::samplers::Sampler<std::vector<size_t>> instead of just Sampler<>. the BatchRequestType error happens because the base class needs to know what type of batch requests it handles.

Nice custom sampler idea! But why not just use torch::data::samplers::RandomSampler constructed with your subset size instead of the full dataset size? It permutes the indices [0, size), so passing a smaller size restricts sampling to the first N examples. What's your subset size vs the full MNIST dataset? Might point us toward an easier fix.

You could also skip the custom sampler entirely: use torch::data::make_data_loader as-is and just break out of the training loop early once you've processed what you need. Also, what's the exact compilation error you're seeing with BatchRequestType? That'd help figure out what's going wrong.

Had this exact problem with custom datasets in the C++ API. The issue is your sampler inheritance: you're inheriting from torch::data::samplers::Sampler<> without specifying the batch request type, which breaks compilation. Fix it by templating your sampler properly: use torch::data::samplers::Sampler<std::vector<size_t>>, which tells the compiler what your next() method returns.

Two more gotchas: reset() must match the base class signature, void reset(torch::optional<size_t> new_size), otherwise the compiler won't treat it as an override and your class stays abstract. And save() and load() are pure virtual in the base class, so you need to override them for serialization support (they can stay empty for basic functionality).

Once those are fixed, just create your data loader with torch::data::make_data_loader() and pass in your custom sampler instance. This worked perfectly for me when I needed to train on specific data subsets for validation.