Working with tuples¶
How to zip several tuples¶
Problem:¶
While `zip` works in jitted code, it produces an iterator and not a tuple.
Some functions, like `literal_unroll`, require a tuple to work.
Solution:¶
from numba.extending import overload
from numba import types
from numba.extending import intrinsic
from numba.core.cgutils import unpack_tuple
def tuple_zip(*args):
    """Zip several tuples together and return the result as a tuple.

    This is the pure-Python implementation of ``tuple_zip``.

    Parameters
    ----------
    *args
        The sequences to zip, passed as individual positional arguments.

    Returns
    -------
    tuple
        A tuple whose i-th element is a tuple holding the i-th item of every
        argument. Its length is that of the shortest argument; with no
        arguments the result is the empty tuple.
    """
    # The arguments must be unpacked: ``zip(args)`` would iterate over the
    # single tuple of arguments and wrap each whole input in a 1-tuple.
    return tuple(zip(*args))
@overload(tuple_zip)
def tuple_zip_ovrl(*args):
    """Numba overload of ``tuple_zip`` that forwards to ``tuple_zip_intr``."""

    def impl(*args):
        # Hand every argument straight to the intrinsic, which builds the
        # zipped tuple.
        return tuple_zip_intr(*args)

    return impl
@intrinsic
def tuple_zip_intr(tyctx, *tys):
    """Intrinsic that zips the given tuple types into a tuple of tuples.

    Typing: the argument types are wrapped in a ``StarArgTuple`` when more
    than one is received; the return type is a ``Tuple`` whose members are the
    ``Tuple`` types obtained by zipping the element types of the inputs.

    Raises
    ------
    ValueError
        If exactly one argument type is received.
    """
    n_received = len(tys)
    if n_received == 1:
        raise ValueError("Only one argument received. Tuples to be zipped must be passed as individual arguments")
    if n_received > 1:
        tys = types.StarArgTuple(tys)

    # The zipped result is as long as the shortest input tuple.
    nitems = min(tup_ty.count for tup_ty in tys)
    tuples = [types.Tuple(column_tys) for column_tys in zip(*tys)]
    ret = types.Tuple(tuples)

    def codegen(cgctx, builder, sig, args):
        assert len(args) == 1  # it is a vararg tuple
        unpacked = unpack_tuple(builder, args[0])
        values = []
        for pos in range(nitems):
            # Pull the pos-th member out of every input tuple, then pack
            # those members into the pos-th inner tuple of the result.
            members = []
            for source in unpacked:
                members.append(builder.extract_value(source, pos))
            values.append(cgctx.make_tuple(builder, tuples[pos], members))
        res = cgctx.make_tuple(builder, sig.return_type, values)
        # Increment the reference count of the returned tuple.
        cgctx.nrt.incref(builder, sig.return_type, res)
        return res

    return ret(tys), codegen
Example:¶
from numba import literal_unroll, njit
# Four trivial jitted functions, each returning a distinct constant.
@njit
def f1():
    return 1


@njit
def f2():
    return 2


@njit
def f3():
    return 3


@njit
def f4():
    return 4


# Tuples of functions to be zipped together.
a = (f1, f2)
b = (f3, f4)
c = (f1, f4)


@njit
def foo(a, b):
    # ``a`` and ``b`` are arguments, whereas ``c`` is not a parameter of
    # this function and is looked up in the enclosing module scope.
    for funcs in literal_unroll(tuple_zip(a, b, c)):
        f, g, h = funcs
        print(f() + g() + h())


foo(a, b)
How to use tuples of different length as keys in a dictionary¶
Problem:¶
The keys of a dictionary must all be of the same type. The type of a tuple depends on its length (as well as on the types of its elements), and therefore tuples of different lengths cannot be used as keys in the same dictionary.
Solution:¶
from numba import njit, literal_unroll, types
from numba.typed import Dict
import numpy as np
from numba.experimental import structref
from numba.extending import overload
import operator
# The idea here is to wrap a typed.Dict in another type, the "TupleKeyDictType".
# The purpose of this is so that operations like __getitem__ and __setitem__
# can be proxied through functions that call `hash` on the key. This makes it
# possible to have something that behaves like a dictionary, but supports
# heterogeneous keys (tuples of varying size/type).
# Define the new type and register it
@structref.register
class TupleKeyDictType(types.StructRef):
    """StructRef type that wraps the dictionary used by ``TupleKeyDict``."""

    def preprocess_fields(self, fields):
        """Return ``fields`` as a tuple with each type passed through
        ``types.unliteral``."""
        processed = []
        for field_name, field_type in fields:
            processed.append((field_name, types.unliteral(field_type)))
        return tuple(processed)
# Define the Python side proxy class
class TupleKeyDict(structref.StructRefProxy):
    """Python-side proxy for ``TupleKeyDictType`` instances."""

    @property
    def wrapped_dict(self):
        """The dictionary wrapped by this instance, fetched via a jitted
        getter."""
        return TupleKeyDict_get_wrapped_dict(self)


# Jitted getter backing the ``wrapped_dict`` property above; without it the
# property would raise a NameError when accessed from the interpreter. It is
# compiled lazily on first call, i.e. after ``define_proxy`` below has
# exposed the "wrapped_dict" attribute to jitted code.
@njit
def TupleKeyDict_get_wrapped_dict(self):
    return self.wrapped_dict


# Set up the wiring for it, "wrapped_dict" is the only member in the "struct"
# and it refers to the typed.Dict instance in use
structref.define_proxy(TupleKeyDict, TupleKeyDictType, ["wrapped_dict"])
# Overload operator.getitem for the TupleKeyDictType; note how it defers the
# lookup to the wrapped_dict member and hashes the key
@overload(operator.getitem)
def ol_tkd_getitem(inst, key):
    """Provide ``inst[key]`` for ``TupleKeyDictType`` instances."""
    # Only TupleKeyDictType is handled; for anything else no
    # implementation is supplied.
    if not isinstance(inst, TupleKeyDictType):
        return None

    def impl(inst, key):
        # The wrapped dictionary is indexed by the hash of the key, not by
        # the key itself.
        hashed = hash(key)
        return inst.wrapped_dict[hashed]

    return impl
# Overload operator.setitem for the TupleKeyDictType, again, it's hashing the
# key before use.
@overload(operator.setitem)
def ol_tkd_setitem(inst, key, value):
    """Provide ``inst[key] = value`` for ``TupleKeyDictType`` instances."""
    # Only TupleKeyDictType is handled; for anything else no
    # implementation is supplied.
    if not isinstance(inst, TupleKeyDictType):
        return None

    def impl(inst, key, value):
        # Store the value under the hash of the key in the wrapped dict.
        hashed = hash(key)
        inst.wrapped_dict[hashed] = value

    return impl
Example:¶
@njit
def foo(keys, values):
    # Build the typed.Dict that will be wrapped: intp keys (the hashes of
    # the tuples) mapping to complex128 values.
    backing = Dict.empty(types.intp, types.complex128)
    # Wrap it in the hashing proxy.
    tkd_inst = TupleKeyDict(backing)

    # Populate it (a slightly contrived demonstration). ``keys`` holds
    # tuples of different sizes, i.e. of different types, so the loop body
    # has to be versioned per element with literal_unroll.
    pos = 0
    for key in literal_unroll(keys):
        tkd_inst[key] = values[pos]
        pos += 1

    # Show the contents of the wrapped dictionary.
    print(tkd_inst.wrapped_dict)

    # Demonstrate getitem.
    print("getitem", (1, 2), "gives", tkd_inst[(1, 2)])


keyz = ((1, 2), (3, 4, 5), (6,))
valuez = (1j, 2j, 3j)
foo(keyz, valuez)