#pragma once #include #include #include #include namespace torch { namespace data { namespace transforms { /// A `BatchTransform` that applies a user-provided functor to a batch. template class BatchLambda : public BatchTransform { public: using typename BatchTransform::InputBatchType; using typename BatchTransform::OutputBatchType; using FunctionType = std::function; /// Constructs the `BatchLambda` from the given `function` object. explicit BatchLambda(FunctionType function) : function_(std::move(function)) {} /// Applies the user-provided function object to the `input_batch`. OutputBatchType apply_batch(InputBatchType input_batch) override { return function_(std::move(input_batch)); } private: FunctionType function_; }; // A `Transform` that applies a user-provided functor to individual examples. template class Lambda : public Transform { public: using typename Transform::InputType; using typename Transform::OutputType; using FunctionType = std::function; /// Constructs the `Lambda` from the given `function` object. explicit Lambda(FunctionType function) : function_(std::move(function)) {} /// Applies the user-provided function object to the `input`. OutputType apply(InputType input) override { return function_(std::move(input)); } private: FunctionType function_; }; } // namespace transforms } // namespace data } // namespace torch