Naive std::function implementation

After exploring std::function in a previous post, I thought it would be good practice to implement a simple (and partial) version of it. It turned out to be much less code than I anticipated. I hope you’ll like it.

Features

std::function has a few typedefs and methods, but its core functionality is assignment (operator=) and invocation (operator()).

We will only implement these two operators.

Declaration

The standard specifies that std::function will be declared as follows:

namespace std {
	template <typename>
	class function; // no definition

	template <typename ReturnValue, typename ... Args>
	class function<ReturnValue(Args...)> {
		// ...
	};
}

Why declare a primary template that has no definition and is never used? Ideally you’d only have the second version, but that version is a partial template specialization, and a partial specialization needs a primary template to specialize. The alternative would have been to define std::function as:

template <typename ReturnValue, typename ... Args>
class function { ... };

But this would mean that clients would look like std::function<int, bool, float> rather than std::function<int(bool, float)>. I personally think that the latter is much nicer, but there’s just no syntax to express this without partial specialization.

So let’s copy that:

template <typename>
class naive_function; // no definition

template <typename ReturnValue, typename ... Args>
class naive_function<ReturnValue(Args...)> {
public:
	// operator= goes here
	// operator() goes here
private:
	...
};

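Now any attempt to (mis)use naive_function with something other than a function type (example: naive_function<int>) will yield a compiler error along the lines of “using undefined class naive_function”. A plain argument list such as naive_function<bool, int> is rejected even earlier, because the primary template accepts only one template argument.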
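
To see the partial specialization trick in isolation, here is a minimal sketch that has nothing to do with callables. The name signature_traits and its members are made up for this illustration; the point is only that int(bool, float) is a single type which the specialization takes apart into ReturnValue and Args.

```cpp
#include <cstddef>
#include <type_traits>

// Primary template: declared but never defined, just like naive_function.
template <typename>
struct signature_traits;

// Partial specialization: matches any function type ReturnValue(Args...).
template <typename ReturnValue, typename... Args>
struct signature_traits<ReturnValue(Args...)> {
	using return_type = ReturnValue;
	static constexpr std::size_t arity = sizeof...(Args);
};

// int(bool, float) is one type; the specialization splits it into its parts.
using traits = signature_traits<int(bool, float)>;
static_assert(std::is_same<traits::return_type, int>::value, "return type is int");
static_assert(traits::arity == 2, "two parameters");
```

Both static_asserts are checked at compile time, so the snippet either compiles or it doesn't; there is nothing to run.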

Groundwork

Before we move on to implementing operator= and operator() we need to write some supporting code. The following classes will be private nested classes of naive_function, so they can use ReturnValue and Args.

Let’s start with an interface:

class ICallable {
public:
	virtual ~ICallable() = default;
	virtual ReturnValue Invoke(Args...) = 0;
};

Easy enough. Now for a concrete implementor:

template <typename T>
class CallableT : public ICallable {
public:
	CallableT(const T& t)
		: t_(t) {
	}

	~CallableT() override = default;

	ReturnValue Invoke(Args... args) override {
		return t_(args...);
	}

private:
	T t_;
};
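
To try these two classes outside of naive_function, here is a small sketch with the signature hard-coded to int(int). The helper make_callable, the sample callables twice and add_offset, and the function erasure_demo are assumptions added for the illustration; they are not part of the final implementation.

```cpp
#include <cassert>
#include <memory>

// Stand-alone copies of the two classes, with the signature fixed to int(int).
class ICallable {
public:
	virtual ~ICallable() = default;
	virtual int Invoke(int) = 0;
};

template <typename T>
class CallableT : public ICallable {
public:
	CallableT(const T& t) : t_(t) {}
	int Invoke(int arg) override { return t_(arg); }

private:
	T t_;
};

// Hypothetical helper: wraps any callable and hands it back through the interface.
template <typename T>
std::unique_ptr<ICallable> make_callable(T t) {
	return std::make_unique<CallableT<T>>(t);
}

int twice(int x) { return 2 * x; }

struct add_offset {
	int offset;
	int operator()(int x) const { return x + offset; }
};

int erasure_demo() {
	// The same pointer type holds a function pointer first and a functor second.
	std::unique_ptr<ICallable> c = make_callable(twice);
	int result = c->Invoke(21);
	assert(result == 42);

	c = make_callable(add_offset{10});
	result += c->Invoke(5);
	return result; // 42 + 15
}
```

The caller only ever sees ICallable; the concrete type T is known solely inside CallableT<T>, which is the whole idea behind the type erasure used here.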

Implementation

With the help of the above very simple classes it is now almost trivial to implement naive_function:

template <typename ReturnValue, typename... Args>
class naive_function<ReturnValue(Args...)> {
public:
	template <typename T>
	naive_function& operator=(T t) {
		callable_ = std::make_unique<CallableT<T>>(t);
		return *this;
	}

	ReturnValue operator()(Args... args) const {
		assert(callable_);
		return callable_->Invoke(args...);
	}

private:
	// ICallable as implemented above.
	// CallableT as implemented above.

	std::unique_ptr<ICallable> callable_;
};

There’s not even a lot of magic here:

operator= is templated, where T is anything that can be called with Args... and returns ReturnValue. There we dynamically create a CallableT<T> which is assigned to callable_ (of type std::unique_ptr<ICallable>). From then on, virtual dispatch through ICallable runs the proper code at runtime.

operator() is trivial. It must not be called before operator= has been called, hence the assert. After that, it simply calls Invoke on callable_ and returns the result.

Let’s test our creation:

void func() {
	cout << "func" << endl;
}

struct functor {
	void operator()() {
		cout << "functor" << endl;
	}
};

int main() {
	naive_function<void()> f;
	f = func;
	f();
	f = functor();
	f();
	f = []() { cout << "lambda" << endl; };
	f();
}

Output:

func
functor
lambda
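
The test above only exercises void(). As a rough check that arguments and return values also travel through the wrapper, here is a self-contained sketch using naive_function<int(int, int)>. The class is repeated from the post; the function add and the function signature_demo are made up for the example.

```cpp
#include <cassert>
#include <memory>

template <typename>
class naive_function; // no definition

template <typename ReturnValue, typename... Args>
class naive_function<ReturnValue(Args...)> {
public:
	template <typename T>
	naive_function& operator=(T t) {
		callable_ = std::make_unique<CallableT<T>>(t);
		return *this;
	}

	ReturnValue operator()(Args... args) const {
		assert(callable_);
		return callable_->Invoke(args...);
	}

private:
	class ICallable {
	public:
		virtual ~ICallable() = default;
		virtual ReturnValue Invoke(Args...) = 0;
	};

	template <typename T>
	class CallableT : public ICallable {
	public:
		CallableT(const T& t) : t_(t) {}
		ReturnValue Invoke(Args... args) override { return t_(args...); }

	private:
		T t_;
	};

	std::unique_ptr<ICallable> callable_;
};

int add(int a, int b) { return a + b; }

int signature_demo() {
	naive_function<int(int, int)> f;
	f = add;
	int total = f(2, 3); // add(2, 3) == 5

	int factor = 10;
	f = [factor](int a, int b) { return factor * a * b; }; // capturing lambda
	total += f(2, 3); // 10 * 2 * 3 == 60
	return total;
}
```

Calling signature_demo() from main should return 65, the sum of the two calls.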

Future improvements

This implementation lacks a few things which I consider beyond the scope of this post, but feel free to implement them on your own:

Forwarding references and perfect forwarding

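Specifically in the following places: operator= takes its argument by value and then copies it again into CallableT, and both operator() and Invoke pass args... along as lvalues instead of forwarding them.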
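
As a starting point, here is one possible sketch of a forwarding variant. The class name forwarding_function and the two demo functions are assumptions made for this example, and this is only one way to do it, not how any particular standard library does. With forwarding in place, move-only callables and move-only arguments work, which the naive version cannot handle.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>

template <typename>
class forwarding_function; // hypothetical name, no definition

template <typename ReturnValue, typename... Args>
class forwarding_function<ReturnValue(Args...)> {
public:
	// T&& is a forwarding reference: it binds to lvalues and rvalues alike.
	template <typename T>
	forwarding_function& operator=(T&& t) {
		using Stored = std::decay_t<T>; // strip references and cv-qualifiers
		callable_ = std::make_unique<CallableT<Stored>>(std::forward<T>(t));
		return *this;
	}

	ReturnValue operator()(Args... args) const {
		assert(callable_);
		return callable_->Invoke(std::forward<Args>(args)...);
	}

private:
	class ICallable {
	public:
		virtual ~ICallable() = default;
		virtual ReturnValue Invoke(Args...) = 0;
	};

	template <typename T>
	class CallableT : public ICallable {
	public:
		template <typename U>
		explicit CallableT(U&& u) : t_(std::forward<U>(u)) {}

		ReturnValue Invoke(Args... args) override {
			return t_(std::forward<Args>(args)...);
		}

	private:
		T t_;
	};

	std::unique_ptr<ICallable> callable_;
};

// A lambda that owns a unique_ptr is move-only; it is moved into the wrapper.
int demo_move_only_callable() {
	forwarding_function<int()> f;
	f = [p = std::make_unique<int>(42)]() { return *p; };
	return f();
}

// A move-only argument is moved from the call site down to the stored callable.
std::size_t demo_move_only_argument() {
	forwarding_function<std::size_t(std::unique_ptr<std::string>)> g;
	g = [](std::unique_ptr<std::string> s) { return s->size(); };
	return g(std::make_unique<std::string>("hello"));
}
```

One caveat with this sketch: a forwarding-reference operator= also matches a non-const forwarding_function lvalue, so a real implementation would need to constrain the template to exclude the wrapper type itself.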

Small-object optimization

For more details about Small Object Optimization (SOO) or Small String Optimization (SSO), see my previous posts about std::string and std::function.

Clang reserves 16 bytes for small objects in order to save dynamic allocations. In naive_function we always allocate dynamically.
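
To give an idea of what this could look like, here is a simplified sketch that stores small callables in an embedded buffer and falls back to the heap for larger ones. Everything specific in it is an assumption made for the illustration: the name small_function, the 32-byte kBufferSize (chosen arbitrarily, and it also has to hold the vtable pointer), the is_inline() accessor and the tag-dispatched store helpers. Copying is simply disabled to keep it short.

```cpp
#include <cassert>
#include <cstddef>
#include <new>
#include <type_traits>

template <typename>
class small_function; // hypothetical name, no definition

template <typename ReturnValue, typename... Args>
class small_function<ReturnValue(Args...)> {
public:
	small_function() = default;
	small_function(const small_function&) = delete;
	small_function& operator=(const small_function&) = delete;
	~small_function() { reset(); }

	template <typename T>
	small_function& operator=(T t) {
		reset();
		using Holder = CallableT<T>;
		constexpr bool fits = sizeof(Holder) <= kBufferSize &&
			alignof(Holder) <= alignof(std::max_align_t);
		store<Holder>(t, std::integral_constant<bool, fits>());
		return *this;
	}

	ReturnValue operator()(Args... args) const {
		assert(callable_);
		return callable_->Invoke(args...);
	}

	// Only here so the example can show which path was taken.
	bool is_inline() const { return callable_ != nullptr && inline_; }

private:
	class ICallable {
	public:
		virtual ~ICallable() = default;
		virtual ReturnValue Invoke(Args...) = 0;
	};

	template <typename T>
	class CallableT : public ICallable {
	public:
		CallableT(const T& t) : t_(t) {}
		ReturnValue Invoke(Args... args) override { return t_(args...); }

	private:
		T t_;
	};

	// Small case: construct inside the buffer, no heap allocation.
	template <typename Holder, typename T>
	void store(const T& t, std::true_type) {
		callable_ = ::new (static_cast<void*>(buffer_)) Holder(t);
		inline_ = true;
	}

	// Large case: allocate dynamically, as naive_function always does.
	template <typename Holder, typename T>
	void store(const T& t, std::false_type) {
		callable_ = new Holder(t);
		inline_ = false;
	}

	void reset() {
		if (callable_ == nullptr) return;
		if (inline_) callable_->~ICallable(); // destroy in place, keep the buffer
		else delete callable_;
		callable_ = nullptr;
	}

	static constexpr std::size_t kBufferSize = 32; // arbitrary size for this sketch

	alignas(std::max_align_t) unsigned char buffer_[kBufferSize];
	ICallable* callable_ = nullptr;
	bool inline_ = false;
};

struct big_functor {
	char padding[64]; // too large for the buffer
	int operator()(int x) { return x + 1; }
};

int small_demo() {
	small_function<int(int)> f;
	f = [](int x) { return x * 2; };
	assert(f.is_inline()); // captureless lambda fits in the buffer
	int total = f(4);

	f = big_functor{};
	assert(!f.is_inline()); // falls back to the heap
	return total + f(4); // 8 + 5
}
```

The price of the buffer is that every small_function object becomes larger, whether or not the stored callable ends up using it.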

Special handling for operator= with naive_function

We have a templated operator=. Can you guess what happens in the following piece of code?

naive_function<void()> f;
naive_function<void()> f2;
f2 = f;

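I was surprised by this, but this actually fails to compile (tested on Visual Studio and clang). The reason is that a function template is never a copy-assignment operator, so the compiler still declares an implicit one, and that one is defined as deleted because callable_ (a std::unique_ptr) cannot be copy-assigned. Overload resolution prefers the deleted non-template operator over our templated operator=, so there is no fallback.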

But even if it did fall back to our operator=, the result would wrap one naive_function inside another, so every call would go through an inefficient double dereference (or more, if the result were assigned to yet another naive_function).

A dedicated copy-assignment operator that copies the internals would avoid this, and would also behave more sanely when the user modifies objects that have been copied.
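
One common way to copy the internals is to add a virtual Clone method to the interface. The sketch below assumes that approach; the names copyable_function, Clone, counter and copy_demo are made up for this example. A copy then owns its own CallableT, so there is no extra indirection and the two objects no longer share state.

```cpp
#include <cassert>
#include <memory>

template <typename>
class copyable_function; // hypothetical name, no definition

template <typename ReturnValue, typename... Args>
class copyable_function<ReturnValue(Args...)> {
public:
	copyable_function() = default;

	// Copying clones the stored callable instead of wrapping the other wrapper.
	copyable_function(const copyable_function& other) {
		if (other.callable_) callable_ = other.callable_->Clone();
	}

	copyable_function& operator=(const copyable_function& other) {
		if (this != &other) {
			if (other.callable_) callable_ = other.callable_->Clone();
			else callable_.reset();
		}
		return *this;
	}

	template <typename T>
	copyable_function& operator=(T t) {
		callable_ = std::make_unique<CallableT<T>>(t);
		return *this;
	}

	ReturnValue operator()(Args... args) const {
		assert(callable_);
		return callable_->Invoke(args...);
	}

private:
	class ICallable {
	public:
		virtual ~ICallable() = default;
		virtual ReturnValue Invoke(Args...) = 0;
		virtual std::unique_ptr<ICallable> Clone() const = 0;
	};

	template <typename T>
	class CallableT : public ICallable {
	public:
		CallableT(const T& t) : t_(t) {}
		ReturnValue Invoke(Args... args) override { return t_(args...); }
		std::unique_ptr<ICallable> Clone() const override {
			return std::make_unique<CallableT>(t_); // copies the wrapped callable
		}

	private:
		T t_;
	};

	std::unique_ptr<ICallable> callable_;
};

// A stateful functor makes it visible that the copy is independent.
struct counter {
	int count = 0;
	int operator()() { return ++count; }
};

int copy_demo() {
	copyable_function<int()> f;
	f = counter{};
	f(); // f's counter is now 1

	copyable_function<int()> f2;
	f2 = f; // picks the copy-assignment operator; f2 starts at 1

	f();
	f(); // f's counter is now 3, f2 is unaffected

	int from_copy = f2();    // 2
	int from_original = f(); // 4
	return from_copy * 10 + from_original;
}
```

Here f2 = f selects the non-template copy-assignment operator, for the same overload-resolution reason that made the original snippet fail: a non-template candidate wins over an equally good template.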

Appendix: Full code

#include <iostream>
#include <memory>
#include <cassert>
using namespace std;

template <typename T>
class naive_function;

template <typename ReturnValue, typename... Args>
class naive_function<ReturnValue(Args...)> {
public:
	template <typename T>
	naive_function& operator=(T t) {
		callable_ = std::make_unique<CallableT<T>>(t);
		return *this;
	}

	ReturnValue operator()(Args... args) const {
		assert(callable_);
		return callable_->Invoke(args...);
	}

private:
	class ICallable {
	public:
		virtual ~ICallable() = default;
		virtual ReturnValue Invoke(Args...) = 0;
	};

	template <typename T>
	class CallableT : public ICallable {
	public:
		CallableT(const T& t)
			: t_(t) {
		}

		~CallableT() override = default;

		ReturnValue Invoke(Args... args) override {
			return t_(args...);
		}

	private:
		T t_;
	};

	std::unique_ptr<ICallable> callable_;
};

void func() {
	cout << "func" << endl;
}

struct functor {
	void operator()() {
		cout << "functor" << endl;
	}
};

int main() {
	naive_function<void()> f;
	f = func;
	f();
	f = functor();
	f();
	f = []() { cout << "lambda" << endl; };
	f();
}
