A couple of years ago, I came across a language called Julia. Its multiple dispatch feature was very interesting; I wanted to know how it worked under the hood, but I didn't have the knowledge to do that yet. So here I am, finally giving it a try. Now that I have an implementation, I've realized there is nothing tying this algorithm to runtime dispatch; I think it could be used in a language with static dispatch as well. So I guess this post is really just about selecting the most specific function for a given set of arguments in a language with subtyping. If you're interested in learning about multiple dispatch itself, I left some links at the end of the post. Ok, let's get started.
Describing the Problem
I'm assuming you're familiar with the concepts of subtyping and function overloading. Briefly, subtyping is a way to relate types to each other. For example, if we have types `A` and `B`, and `B` is a subtype of `A`, we can say that a `B` is an `A`, so we can use a `B` anywhere we need an `A`. Function overloading is a way to define multiple functions with the same name but different parameter types. The function that gets called is the one that matches the arguments "best".
Let's say we have a subtyping hierarchy like this:
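```julia
# Julia-style pseudocode: real Julia only lets you subtype abstract types,
# and Base already defines Number, Complex and Real.
abstract type Number end
struct Complex <: Number end
struct Real <: Complex end
```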
And functions like these:
```julia
function foo(x::Number, y::Number)
    println("Number, Number")
end

function foo(x::Complex, y::Complex)
    println("Complex, Complex")
end

function foo(x::Real, y::Real)
    println("Real, Real")
end
```
We want to select the most specific function for the given arguments. For example, if we call `foo(Real(), Real())`, we want the last definition; if we call `foo(Complex(), Number())`, we get the first definition, because we can't pass a `Number` where a `Complex` is expected. The catch is that for `foo(Real(), Real())`, calling the first or the second definition would also be perfectly legal; the code would work just fine. So we need not only a way to find the methods that accept the arguments but also a way to rank them.
Modeling Subtyping
Let's start by modeling subtyping first. I want something like this:
```ruby
any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)
string = Type.new("String", any)

puts real.is?(number)  # true
puts real.is?(complex) # true
puts real.is?(real)    # true
puts real.is?(string)  # false
puts real.is?(any)     # true
puts string.is?(any)   # true
```
All types are subtypes of `Any`, which has no supertype. This type is also called the top type. You'll notice that a type is a subtype of itself. Why? Well, going back to the definition, a type is a subtype of another if it can be used in place of the other. If we have a function that takes a `Real`, we can pass a `Real` to it, so `Real` is a subtype of `Real`. That also means `Any` is a supertype of all types, including itself, which I think is quite nice.
We could implement this API like this:
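```ruby
class Type
  attr_reader :name, :supertype

  def initialize(name, supertype = nil)
    @name = name
    @supertype = supertype
  end

  # true if self can be used wherever `type` is expected
  def is?(type)
    return true if type == self     # every type is a subtype of itself
    return false if @supertype.nil? # reached the top without finding `type`
    @supertype.is?(type)            # keep walking up the hierarchy
  end

  # Value equality: two types are equal if they have the same name.
  def ==(type)
    type.is_a?(Type) && @name == type.name
  end
end
```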
We already talked about every type being a subtype of itself, so the first `return` statement does that check. The second one handles reaching the top of the hierarchy without finding `type`, for example when we call `any.is?(real)`. The last line recursively walks up the hierarchy to find out whether `type` is in there.
I also overrode the `==` method so that types are compared by value (their name) rather than by identity.
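To make the walk that `is?` performs visible, here's a small sketch. It repeats the `Type` class from above and adds a hypothetical `ancestors` helper (not part of the original code) that collects every type visited on the way up to the top type.

```ruby
# Minimal copy of the Type class above, plus a hypothetical `ancestors`
# helper that lists the chain `is?` walks through.
class Type
  attr_reader :name, :supertype

  def initialize(name, supertype = nil)
    @name = name
    @supertype = supertype
  end

  def is?(type)
    return true if type == self
    return false if @supertype.nil?
    @supertype.is?(type)
  end

  def ==(type)
    type.is_a?(Type) && @name == type.name
  end

  # Hypothetical helper: self, then each supertype up to the top type.
  def ancestors
    chain = [self]
    chain << chain.last.supertype until chain.last.supertype.nil?
    chain
  end
end

any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)
string = Type.new("String", any)

puts real.ancestors.map(&:name).join(" <: ")   # Real <: Complex <: Number <: Any
puts string.ancestors.map(&:name).join(" <: ") # String <: Any

# String's chain never passes through Number, so the walk stops at Any.
puts string.is?(number) # false
```

Because every chain ends at `Any`, `is?` always terminates: it either meets the target type or runs out of supertypes.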
Modeling Function Signatures
A signature is just a list of types.
```ruby
class Signature
  attr_reader :types

  def initialize(types)
    @types = types
  end

  def ==(signature)
    @types == signature.types
  end
end
```
We need a way to know if it is legal to call a signature with a given list of argument types. Let's look at a simple case:
```julia
function f(x::Real) end
```
We obviously shouldn't be able to call this function with `f(Complex())`, because `Complex` is not a subtype of `Real`.
This generalizes to multiple arguments as well. Given `function f(x::Real, y::Real) end`, we should be able to call `f(Real(), Real())` but not `f(Complex(), Real())`. There is nothing that makes the first argument special. So the gist of it is: we should be able to call a function whose signature is `(s1, s2, ..., sn)` with argument types `(t1, t2, ..., tn)` if every `ti` is a subtype of the corresponding `si`.
The implementation is quite simple:
```ruby
class Signature
  # Other stuff...

  # Returns true if a call whose argument types are @types can be
  # dispatched to a function with the given `signature`:
  # 1. both must have the same number of types
  # 2. each of our types must be a subtype of the corresponding type in `signature`
  def conforms?(signature)
    return false if signature.types.length != @types.length
    @types.zip(signature.types).all? { |a, b| a.is?(b) }
  end
end
```
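Here's a quick sketch of how `conforms?` behaves, using minimal copies of `Type` and `Signature` from above. The receiver holds the call's argument types and the parameter is a function's signature, which matches how `FunctionTable#find` uses it later on. Note that the check is not symmetric.

```ruby
# Minimal copies of Type and Signature from above.
class Type
  attr_reader :name, :supertype

  def initialize(name, supertype = nil)
    @name = name
    @supertype = supertype
  end

  def is?(type)
    return true if type == self
    return false if @supertype.nil?
    @supertype.is?(type)
  end

  def ==(type)
    type.is_a?(Type) && @name == type.name
  end
end

class Signature
  attr_reader :types

  def initialize(types)
    @types = types
  end

  # True if a call with these argument types may be dispatched to `signature`.
  def conforms?(signature)
    return false if signature.types.length != @types.length
    @types.zip(signature.types).all? { |a, b| a.is?(b) }
  end
end

any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)

params = Signature.new([complex, complex]) # like f(x::Complex, y::Complex)

puts Signature.new([real, real]).conforms?(params)    # true: Real <: Complex
puts Signature.new([complex, real]).conforms?(params) # true
puts Signature.new([number, real]).conforms?(params)  # false: a Number is not a Complex
puts Signature.new([real]).conforms?(params)          # false: wrong number of arguments
puts params.conforms?(Signature.new([real, real]))    # false: the check is not symmetric
```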
Ranking Conforming Signatures
Now we come to the meat of the problem. How do we select the most specific function for a given list of argument types? We need a way to rank them. Let's look at the simplest case again.
```julia
function f(x::Number) end
function f(x::Complex) end

f(Real())
```
If I asked you which function should be called, you would say the second one, right? Why? Well, in the subtype hierarchy `Real <: Complex <: Number <: Any`, the candidate closest to `Real` is `Complex`, so you choose that. This is the main idea behind the ranking algorithm: among the signatures the arguments conform to, we want the one whose types are closest to the argument types.
We need a way to measure how far a type is from one of its supertypes:
```ruby
class Type
  # Other stuff...

  def distance(type)
    # We only measure along a single chain of supertypes, not across the tree,
    # so types from different branches (e.g. Real and String) have no distance.
    raise "Not a subtype" unless is?(type)
    # If we are at the same type, we are 0 distance away.
    return 0 if self == type
    # Otherwise we are 1 + our supertype's distance to `type`.
    # Example:
    # real.distance(any) = 1 + complex.distance(any)
    #                    = 1 + 1 + number.distance(any)
    #                    = 1 + 1 + 1 + any.distance(any)
    #                    = 1 + 1 + 1 + 0
    #                    = 3
    1 + @supertype.distance(type)
  end
end
```
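With `distance` in place, the single-argument example from above can already be resolved: drop the candidates that `Real` isn't a subtype of, then pick the one with the smallest distance. This is a sketch with a minimal copy of `Type`; the candidate list is just an array of parameter types standing in for `f(x::Number)`, `f(x::Complex)` and a hypothetical `f(x::String)`.

```ruby
# Minimal copy of Type with is? and distance from above.
class Type
  attr_reader :name, :supertype

  def initialize(name, supertype = nil)
    @name = name
    @supertype = supertype
  end

  def is?(type)
    return true if type == self
    return false if @supertype.nil?
    @supertype.is?(type)
  end

  def ==(type)
    type.is_a?(Type) && @name == type.name
  end

  def distance(type)
    raise "Not a subtype" unless is?(type)
    return 0 if self == type
    1 + @supertype.distance(type)
  end
end

any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)
string = Type.new("String", any)

# Parameter types of f(x::Number), f(x::Complex) and a hypothetical f(x::String).
candidates = [number, complex, string]

# Keep only the parameter types a Real can be passed to, then pick the closest.
applicable = candidates.select { |t| real.is?(t) }
best = applicable.min_by { |t| real.distance(t) }

puts applicable.map { |t| "#{t.name}: #{real.distance(t)}" }.join(", ") # Number: 2, Complex: 1
puts best.name # Complex

# Types on different branches have no distance.
begin
  real.distance(string)
rescue RuntimeError => e
  puts e.message # Not a subtype
end
```

Filtering with `is?` first matters, since `distance` raises for unrelated types rather than returning something like infinity.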
We can already rank functions with a single argument with this. One way to generalize it to multiple arguments is to compute the distance between each argument type and the corresponding parameter type, and sum them up.
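```ruby
class Signature
  # Other stuff...

  # Sum of the per-argument distances. Assumes self conforms to `signature`
  # (same length, each type a subtype), so check conforms? first.
  def distance(signature)
    @types.zip(signature.types).sum { |a, b| a.distance(b) }
  end
end
```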
If you're a bit interested in math: think of a signature as a point in n-dimensional space, with each type as one coordinate. Adding up the per-coordinate distances is what's called the Manhattan distance. Strictly speaking, it isn't a true metric here, because `Type#distance` is only defined from a type up to one of its supertypes, so it isn't symmetric. You could also use something like Euclidean distance, but I think this one suffices.
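How the per-argument distances are combined isn't just cosmetic; it can change which function wins. As a small illustration with the hierarchy from this post, take a `(Real, Real)` call: the signature `(Number, Real)` has per-argument distances `[2, 0]`, while `(Complex, Complex)` has `[1, 1]`. This sketch compares the two combinations on those vectors; `manhattan` and `euclidean` are illustrative helper names, not part of the post's code.

```ruby
# Two ways to combine per-argument distances into one score.
def manhattan(distances)
  distances.sum
end

def euclidean(distances)
  Math.sqrt(distances.sum { |d| d * d })
end

# Per-argument distances for a (Real, Real) call.
candidates = {
  "(Number, Real)"     => [2, 0], # Real->Number = 2, Real->Real = 0
  "(Complex, Complex)" => [1, 1], # Real->Complex = 1 for both arguments
}

candidates.each do |sig, ds|
  puts "#{sig}: manhattan=#{manhattan(ds)}, euclidean=#{euclidean(ds).round(3)}"
end
# (Number, Real): manhattan=2, euclidean=2.0
# (Complex, Complex): manhattan=2, euclidean=1.414
```

Under the Manhattan distance the two candidates tie, so a lookup that raises on ties reports an ambiguous call; under the Euclidean distance `(Complex, Complex)` wins outright.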
If we put this to work:
```ruby
any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)

f1 = Signature.new([number, number])
f2 = Signature.new([complex, complex])
f3 = Signature.new([real, real])

call_signature = Signature.new([real, real])
puts [f1, f2, f3].min_by { |f| call_signature.distance(f) } == f3 # true
```
A function isn't just a signature, though; it also has a name.
```ruby
class Function
  attr_accessor :name, :signature

  def initialize(name, signature)
    @name = name
    @signature = signature
  end

  # e.g. "f(Real, Real)"
  def to_s
    "#{@name}(#{@signature.types.map(&:name).join(', ')})"
  end

  def ==(other)
    @name == other.name && @signature == other.signature
  end
end
```
And we'll have a table of functions that contains all the definitions and gives the most specific one for a given call.
```ruby
class FunctionTable
  attr_accessor :functions

  def initialize
    @functions = []
  end

  def add(function)
    raise "Function already exists" if @functions.include?(function)
    @functions << function
  end

  def find(function)
    call = function.signature
    # find all the functions with the same name whose signature the call conforms to.
    candidates = @functions.select { |m| m.name == function.name && call.conforms?(m.signature) }
    raise "No matching function for #{function}" if candidates.empty?
    # sort them by distance from closest to furthest, computing each distance once.
    sorted_by_distance = candidates.map { |m| [m, call.distance(m.signature)] }.sort_by { |_, d| d }
    # find the closest one.
    # There may be more than one with the same distance, so we find all of them.
    min_distance = sorted_by_distance.first[1]
    closest_functions = sorted_by_distance.take_while { |_, d| d == min_distance }.map(&:first)
    raise "Ambiguous function call between #{closest_functions.join(', ')}" if closest_functions.length > 1
    closest_functions.first
  end
end
```
The most interesting method is `find`:

- It finds all the functions with the same name and a conforming signature.
- It sorts them by distance.
- It finds the closest one.
- If more than one function has the same smallest distance, it raises an error. I chose to raise an error, but another option would be to return the second closest method. (I wonder if there is a metric under which two different points can never be the same distance from a third point.)
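To see the ambiguity error in action, here's a sketch with condensed copies of the classes above (`Function` is shortened to a `Struct` here, and `find` is trimmed to the same select-then-minimum logic). The two definitions `f(Number, Real)` and `f(Complex, Complex)` both accept a `(Real, Real)` call, and both sit at distance 2 (2 + 0 and 1 + 1), so `find` raises.

```ruby
# Condensed copies of Type, Signature, Function and FunctionTable.
class Type
  attr_reader :name, :supertype

  def initialize(name, supertype = nil)
    @name = name
    @supertype = supertype
  end

  def is?(type)
    return true if type == self
    return false if @supertype.nil?
    @supertype.is?(type)
  end

  def ==(type)
    type.is_a?(Type) && @name == type.name
  end

  def distance(type)
    raise "Not a subtype" unless is?(type)
    return 0 if self == type
    1 + @supertype.distance(type)
  end
end

class Signature
  attr_reader :types

  def initialize(types)
    @types = types
  end

  def ==(other)
    @types == other.types
  end

  def conforms?(signature)
    signature.types.length == @types.length &&
      @types.zip(signature.types).all? { |a, b| a.is?(b) }
  end

  def distance(signature)
    @types.zip(signature.types).sum { |a, b| a.distance(b) }
  end
end

Function = Struct.new(:name, :signature) do
  def to_s
    "#{name}(#{signature.types.map(&:name).join(', ')})"
  end
end

class FunctionTable
  def initialize
    @functions = []
  end

  def add(function)
    @functions << function
  end

  def find(function)
    call = function.signature
    candidates = @functions.select { |m| m.name == function.name && call.conforms?(m.signature) }
    min_distance = candidates.map { |m| call.distance(m.signature) }.min
    closest = candidates.select { |m| call.distance(m.signature) == min_distance }
    raise "Ambiguous function call between #{closest.join(', ')}" if closest.length > 1
    closest.first
  end
end

any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)

table = FunctionTable.new
table.add(Function.new("f", Signature.new([number, real])))
table.add(Function.new("f", Signature.new([complex, complex])))

begin
  table.find(Function.new("f", Signature.new([real, real])))
rescue RuntimeError => e
  puts e.message # Ambiguous function call between f(Number, Real), f(Complex, Complex)
end

# Adding a more specific definition (distance 0) breaks the tie.
table.add(Function.new("f", Signature.new([real, real])))
puts table.find(Function.new("f", Signature.new([real, real]))) # f(Real, Real)
```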
If we put it all together:
```ruby
any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)

f1 = Function.new("f", Signature.new([number, number]))
f2 = Function.new("f", Signature.new([complex, complex]))
f3 = Function.new("f", Signature.new([real, real]))

table = FunctionTable.new
table.add(f1)
table.add(f2)
table.add(f3)

call_signature = Signature.new([real, real])
puts table.find(Function.new("f", call_signature)) == f3 # true
```
Further Reading
- https://journal.stuffwithstuff.com/2012/06/12/multimethods-global-scope-and-monkey-patching/
- https://journal.stuffwithstuff.com/2010/10/01/solving-the-expression-problem/
- https://journal.stuffwithstuff.com/2011/04/21/multimethods-multiple-inheritance-multiawesome/
- https://eli.thegreenplace.net/2016/the-expression-problem-and-its-solutions/
- https://www.youtube.com/watch?v=kc9HwsxE1OY&t=442s