Extending Inference with User-Defined Types

Building on the interface for extending inference described in the previous section, Klara also provides a wrapper for defining user types and for specifying how those types behave under Python operations. This is how the inference system is extended to support z3.

In the sections below, we’ll integrate z3 step by step. By the end, you should have a good understanding of how to introduce any new user type.

Defining the wrapper

First, we’ll define the data type class, which subclasses klara.InferProxy:

import klara

class Z3Proxy(klara.InferProxy):
    def __init__(self, z3_expr):
        super(Z3Proxy, self).__init__(z3_expr)

klara.InferProxy is itself a subclass of the klara.Const node, whose value attribute holds our data type; in this case, a Z3 expression. With this, our type is defined and ready to use.
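To make the relationship concrete, here is a minimal pure-Python sketch of the class hierarchy. Const and InferProxy below are hypothetical stand-ins written for illustration, not Klara's actual classes. A plain string plays the role of the z3 expression so the snippet runs without z3 installed.

```python
class Const:
    """Hypothetical stand-in for klara.Const: a node that stores a value."""

    def __init__(self, value=None):
        self.value = value


class InferProxy(Const):
    """Hypothetical stand-in for klara.InferProxy: a Const wrapping a user object."""


class Z3Proxy(InferProxy):
    def __init__(self, z3_expr):
        super().__init__(z3_expr)


# A string stands in for a real z3 expression such as z3.Int("a") > 2
proxy = Z3Proxy("a > 2")
print(isinstance(proxy, Const), proxy.value)  # True a > 2
```

Because the wrapped object lives in value, any code that already knows how to read a Const can reach the z3 expression without special handling.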

Defining the inference point

Next, we need to tell the inference system which nodes should be inferred as our data type. We can use the plugin registration method described in Extending to make the system yield a Z3Proxy for certain nodes. In this case, we want to convert each argument that has a type annotation into a z3 expression. For example:

def foo(a: int, b: str):
    return 1 if a > 2 and b == "s" else 2

We want to convert Python comparisons, such as the klara.Compare node a > 2, into equivalent z3 expressions. The full condition a > 2 and b == "s" converts to something like the following:

>>> import z3
>>> a = z3.Int("a")
>>> b = z3.String("b")
>>> expr = z3.And(a > 2, b == z3.StringVal("s"))
>>> Z3Proxy(expr)
And(a > 2, b == "s")
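The translation itself can be sketched with Python's standard ast module, independently of Klara. The function to_z3_text below is hypothetical. It walks a parsed condition and emits the z3-style text shown above as a string, rather than building real z3 objects, and it only handles the node types this example needs.

```python
import ast

# Operator symbols for the comparison nodes this sketch supports
_CMP_OPS = {ast.Gt: ">", ast.Lt: "<", ast.Eq: "==", ast.NotEq: "!="}


def to_z3_text(node):
    """Render a small subset of Python expressions as z3-style text."""
    if isinstance(node, ast.Expression):
        return to_z3_text(node.body)
    if isinstance(node, ast.BoolOp) and isinstance(node.op, ast.And):
        # `x and y and ...` becomes And(x, y, ...)
        return "And(" + ", ".join(to_z3_text(v) for v in node.values) + ")"
    if isinstance(node, ast.Compare) and len(node.ops) == 1:
        op = _CMP_OPS[type(node.ops[0])]
        left = to_z3_text(node.left)
        right = to_z3_text(node.comparators[0])
        return f"{left} {op} {right}"
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Constant):
        # Use double quotes for strings to match the z3 output shown above
        if isinstance(node.value, str):
            return '"' + node.value + '"'
        return str(node.value)
    raise NotImplementedError(type(node).__name__)


tree = ast.parse('a > 2 and b == "s"', mode="eval")
print(to_z3_text(tree))  # And(a > 2, b == "s")
```

Unsupported constructs raise NotImplementedError instead of producing a wrong expression. Inside Klara, the corresponding move is to defer to default inference, as the note below describes.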

We then need to define our custom infer function for klara.Arg:

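import z3
import klara

# Map annotation names to the z3 constructor for a variable of that type
AST2Z3TYPE_MAP = {"int": z3.Int, "float": z3.Real, "bool": z3.Bool, "str": z3.String}

@klara.inference.inference_transform_wrapper
def _infer_arg(node: klara.Arg, context):
    name = node.arg
    # node.annotation is a node; convert it to its name (e.g. "int") for the lookup
    z3_var_type = AST2Z3TYPE_MAP[str(node.annotation)]
    z3_var = z3_var_type(name)
    proxy = Z3Proxy(z3_var)
    yield klara.inference.InferenceResult.load_result(proxy)

# Run _infer_arg whenever inference reaches an Arg node
klara.MANAGER.register_transform(klara.Arg, _infer_arg)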

Note

This is a very minimal implementation that does not handle errors (e.g. when the annotation is an unsupported type). When an error occurs, the function can raise klara.inference.UseInferenceDefault to fall back to the default inference or to other plugins.
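The fallback behaviour described in this note can be modelled in plain Python. Everything below is hypothetical. UseDefault stands in for klara.inference.UseInferenceDefault. infer_arg_type mimics _infer_arg but returns the name of the z3 constructor instead of calling it. run_transforms is a simplified model of trying each registered plugin in turn before falling back to default inference, not Klara's actual dispatch code.

```python
class UseDefault(Exception):
    """Hypothetical stand-in for klara.inference.UseInferenceDefault."""


# Annotation name -> name of the z3 constructor used in the example above
ANNOTATION_TO_Z3 = {"int": "z3.Int", "float": "z3.Real", "bool": "z3.Bool", "str": "z3.String"}


def infer_arg_type(annotation):
    """Like _infer_arg, but defers on unsupported annotations instead of crashing."""
    try:
        return ANNOTATION_TO_Z3[annotation]
    except KeyError:
        # e.g. `list` or a user class: let other plugins or the default handle it
        raise UseDefault(annotation) from None


def run_transforms(annotation, transforms, default):
    """Try each transform in order; fall back to `default` if all of them defer."""
    for transform in transforms:
        try:
            return transform(annotation)
        except UseDefault:
            continue
    return default(annotation)


def default_inference(annotation):
    return "uninferable"


print(run_transforms("int", [infer_arg_type], default_inference))   # z3.Int
print(run_transforms("list", [infer_arg_type], default_inference))  # uninferable
```

Raising an exception, rather than returning a sentinel, lets a plugin bail out from anywhere inside its logic while the caller decides what happens next.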

Defining Python operations

With the above, we can yield a z3 expression for any annotated function argument. So far, however, we’ve only covered constructing z3 variables. We also need to specify how these variables behave under binary operations, comparisons, etc. (e.g. to build the z3 expression a + 2 > 12). This is done by defining special dunder methods with a __k_ prefix (for example, __k_add__ for __add__), which avoids clashing with the actual dunder methods:

class Z3Proxy(klara.InferProxy):
    def __init__(self, z3_expr=None):
        super(Z3Proxy, self).__init__(z3_expr)

    def __k_add__(self, other: klara.Const):
        """represent __add__ dunder method"""
        left = self.value
        right = other.value
        expr = left + right
        # we'll create a new Z3Proxy, wrapping the new expression
        return klara.inference.InferenceResult.load_result(Z3Proxy(expr))

    def __k_eq__(self, other: klara.Const):
        left = self.value
        right = other.value
        expr = left == right
        return klara.inference.InferenceResult.load_result(Z3Proxy(expr))

    def __k_bool__(self):
        yield klara.inference.InferenceResult(self, status=True)

Note

__k_bool__ is needed because, when evaluating a Compare node, Python calls bool() on the value to determine whether the result is true or false (source).
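To see how prefixed hooks stay separate from Python's own operators, consider the following simplified model of how an inference engine could look them up by name. The binop helper and the string-based SymProxy are hypothetical and written only for illustration. Klara's real dispatch differs, and its __k_ methods return inference results rather than bare proxies as they do here.

```python
import operator


class SymProxy:
    """Hypothetical proxy that builds expressions as strings instead of z3 terms."""

    def __init__(self, expr):
        self.value = expr

    def __k_add__(self, other):
        return SymProxy(f"{self.value} + {_text(other)}")

    def __k_eq__(self, other):
        return SymProxy(f"{self.value} == {_text(other)}")


def _text(x):
    """Unwrap proxies; render plain values with str()."""
    return x.value if isinstance(x, SymProxy) else str(x)


def binop(left, op_name, right):
    """Call the __k_<op>__ hook if the left operand has one, else plain Python."""
    hook = getattr(left, f"__k_{op_name}__", None)
    if hook is not None:
        return hook(right)
    return getattr(operator, op_name)(left, right)


a = SymProxy("a")
expr = binop(binop(a, "add", 2), "eq", 12)
print(expr.value)          # a + 2 == 12
print(binop(3, "add", 4))  # 7
```

Since the engine asks for __k_add__ explicitly, SymProxy never defines a real __add__. Ordinary Python code that touches the proxy is therefore unaffected by the symbolic behaviour.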

Using it

We can now obtain any expression that uses only the + and == operations. The inferred value can then be used, for example, to query the z3 solver:

import z3
import klara

# Assumes Z3Proxy and _infer_arg from above are defined and registered
source = """
    def foo(a: int):
        return a + 2 == 12
    """
tree = klara.parse(source)
for res in tree.body[0].infer_return_value():
    z3.solve(res.result.value)
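What the solver is asked to do here is find an assignment that makes the constraint true. As a deliberately naive illustration, the hypothetical find_model helper below does this by brute-force search over a bounded integer range. This is not how z3 works internally: z3 reasons symbolically and is not limited to a fixed range. The constraint is written as a Python lambda rather than a z3 expression.

```python
import itertools


def find_model(constraint, names, domain):
    """Return the first assignment of `names` from `domain` satisfying `constraint`."""
    for values in itertools.product(domain, repeat=len(names)):
        if constraint(*values):
            return dict(zip(names, values))
    return None  # no satisfying assignment within the searched range


# The constraint produced by foo: a + 2 == 12
print(find_model(lambda a: a + 2 == 12, ["a"], range(-100, 101)))  # {'a': 10}
```

The search cost grows exponentially with the number of variables, which is exactly why handing the expression to a real solver such as z3 is worthwhile.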

Putting it all together

import z3
import klara


class Z3Proxy(klara.InferProxy):
    def __init__(self, z3_expr):
        super(Z3Proxy, self).__init__(z3_expr)

    def __k_add__(self, other: klara.Const):
        """represent __add__ dunder method"""
        left = self.value
        right = other.value
        expr = left + right
        # we'll create a new Z3Proxy, wrapping the new expression
        return klara.inference.InferenceResult.load_result(Z3Proxy(expr))

    def __k_eq__(self, other: klara.Const):
        left = self.value
        right = other.value
        expr = left == right
        return klara.inference.InferenceResult.load_result(Z3Proxy(expr))

    def __k_bool__(self):
        yield klara.inference.InferenceResult(self, status=True)


AST2Z3TYPE_MAP = {"int": z3.Int, "float": z3.Real, "bool": z3.Bool, "str": z3.String}


@klara.inference.inference_transform_wrapper
def _infer_arg(node: klara.Arg, context):
    name = node.arg
    z3_var_type = AST2Z3TYPE_MAP[str(node.annotation)]
    z3_var = z3_var_type(name)
    proxy = Z3Proxy(z3_var)
    yield klara.inference.InferenceResult.load_result(proxy)


klara.MANAGER.register_transform(klara.Arg, _infer_arg)

source = """
    def foo(a: int):
        return a + 2 == 12
    """
tree = klara.parse(source)
for res in tree.body[0].infer_return_value():
    z3.solve(res.result.value)

Which will print:

[a = 10]