How to extract all functions and API calls used in Python source code?

Consider the following Python source code:

import os

def package_data(pkg, roots):
    data = []
    for root in roots:
        for dirname, _, files in os.walk(os.path.join(pkg, root)):
            for fname in files:
                data.append(os.path.relpath(os.path.join(dirname, fname), pkg))

    return {pkg: data}

From this source code, I want to extract all the function and API calls. I found a similar question with a solution; when I run the solution given there, it produces [os.walk, data.append]. However, I am looking for the following output: [os.walk, os.path.join, data.append, os.path.relpath, os.path.join].

From analyzing the solution code below, I understand that it only traces the function expression before the first opening parenthesis (visit_Call only visits node.func) and drops everything else, including the arguments.

import ast

class CallCollector(ast.NodeVisitor):
    def __init__(self):
        self.calls = []
        self.current = None

    def visit_Call(self, node):
        # new call, trace the function expression
        self.current = ''
        self.visit(node.func)
        self.calls.append(self.current)
        self.current = None

    def generic_visit(self, node):
        if self.current is not None:
            print("warning: {} node in function expression not supported".format(
                  node.__class__.__name__))
        super(CallCollector, self).generic_visit(node)

    # record the func expression 
    def visit_Name(self, node):
        if self.current is None:
            return
        self.current += node.id

    def visit_Attribute(self, node):
        if self.current is None:
            # not inside a function expression: keep walking, then stop
            self.generic_visit(node)
            return
        self.visit(node.value)
        self.current += '.' + node.attr

tree = ast.parse(yoursource)  # yoursource: a string holding the code to analyze
cc = CallCollector()
cc.visit(tree)
print(cc.calls)

Can anyone help me modify this code so that it also traverses the API calls inside the parentheses?

N.B.: This could be done with regular expressions in Python, but that requires a lot of manual labor to identify the appropriate API calls. So I am looking for a solution based on the Abstract Syntax Tree (AST).

1 answer

  • answered 2018-07-20 21:25 MSeifert

    I'm not sure if this is the best or simplest solution, but at least it works as intended for your case:

    import ast
    
    class CallCollector(ast.NodeVisitor):
        def __init__(self):
            self.calls = []
            self._current = []
            self._in_call = False
    
        def visit_Call(self, node):
            self._current = []
            self._in_call = True
            self.generic_visit(node)
    
        def visit_Attribute(self, node):
            if self._in_call:
                self._current.append(node.attr)
            self.generic_visit(node)
    
        def visit_Name(self, node):
            if self._in_call:
                self._current.append(node.id)
                self.calls.append('.'.join(self._current[::-1]))
                # Reset the state
                self._current = []
                self._in_call = False
            self.generic_visit(node)
    

    For your example, this gives:

    ['os.walk', 'os.path.join', 'data.append', 'os.path.relpath', 'os.path.join']
    

    The problem is that you need to call generic_visit in every visit_* method to make sure the whole tree, including the call arguments, is walked. I also used a list for the current name parts and join the reversed list afterwards.
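Applying the same idea to the collector from the question, here is a minimal sketch: the main change is that visit_Call also visits node.args and node.keywords after recording the name. It also adds an early return in visit_Attribute, which becomes necessary once arguments are visited (without it, a call such as foo(os.sep) would hit None + str). Like the answer's version, it does not handle chained calls: on d.setdefault(10, []).append(10) it raises a TypeError, because the inner visit_Call resets self.current to None.

```python
import ast

SOURCE = '''
def package_data(pkg, roots):
    data = []
    for root in roots:
        for dirname, _, files in os.walk(os.path.join(pkg, root)):
            for fname in files:
                data.append(os.path.relpath(os.path.join(dirname, fname), pkg))
    return {pkg: data}
'''


class CallCollector(ast.NodeVisitor):
    """The question's collector, changed so visit_Call also walks the arguments."""

    def __init__(self):
        self.calls = []
        self.current = None

    def visit_Call(self, node):
        # Trace the function expression exactly as before.
        self.current = ''
        self.visit(node.func)
        self.calls.append(self.current)
        self.current = None
        # New: visit positional and keyword arguments so nested calls are found.
        for arg in node.args:
            self.visit(arg)
        for kw in node.keywords:
            self.visit(kw)

    def generic_visit(self, node):
        if self.current is not None:
            print("warning: {} node in function expression not supported".format(
                node.__class__.__name__))
        super().generic_visit(node)

    def visit_Name(self, node):
        if self.current is None:
            return
        self.current += node.id

    def visit_Attribute(self, node):
        if self.current is None:
            # Attribute outside a function expression: just keep walking.
            self.generic_visit(node)
            return
        self.visit(node.value)
        self.current += '.' + node.attr


cc = CallCollector()
cc.visit(ast.parse(SOURCE))
print(cc.calls)
```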

    One case I found that doesn't work with this approach is chained calls, for example d.setdefault(10, []).append(10).
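For that chained example, the collector above reports only 'd.setdefault': the nested visit_Call resets _current and the 'append' part is lost. A different approach (not the answer's) avoids shared state entirely: compute the name directly from node.func inside visit_Call with a small recursive helper, then call generic_visit so arguments and chained calls are still walked. This is a sketch under that assumption; call_name and collect_calls are hypothetical names, and rendering an intermediate call as '()' is an arbitrary choice made here.

```python
import ast


def call_name(func):
    """Return a dotted name for a call's function expression, or None.

    A call inside the chain is rendered as '()' so that chained calls
    like d.setdefault(10, []).append(10) keep every part of the name.
    """
    if isinstance(func, ast.Name):
        return func.id
    if isinstance(func, ast.Attribute):
        base = call_name(func.value)
        return None if base is None else base + '.' + func.attr
    if isinstance(func, ast.Call):
        base = call_name(func.func)
        return None if base is None else base + '()'
    # Subscripts, lambdas, etc. are not resolved in this sketch.
    return None


class CallCollector(ast.NodeVisitor):
    def __init__(self):
        self.calls = []

    def visit_Call(self, node):
        name = call_name(node.func)
        if name is not None:
            self.calls.append(name)
        # Keep walking: func (for chained calls), args and keywords.
        self.generic_visit(node)


def collect_calls(source):
    collector = CallCollector()
    collector.visit(ast.parse(source))
    return collector.calls


print(collect_calls("d.setdefault(10, []).append(10)"))
# ['d.setdefault().append', 'd.setdefault']
```

The outer call is recorded before its inner call because visit_Call records the name first and only then descends into node.func.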


    Just in case you're interested in how I arrived at that solution:

    Assume a very simple implementation of a node-visitor:

    import ast
    
    class CallCollector(ast.NodeVisitor):
        def generic_visit(self, node):
            try:
                print(node, node.id)
            except AttributeError:
                try:
                    print(node, node.attr)
                except AttributeError:
                    print(node)
            return super().generic_visit(node)
    
    CallCollector().visit(ast.parse(yoursource))  # yoursource: the code from the question
    

    This prints a lot of output, but if you look at the result you'll see some patterns, like:

    ...
    <_ast.Call object at 0x000001AAEE8FFA58>
    <_ast.Attribute object at 0x000001AAEE8FFBE0> walk
    <_ast.Name object at 0x000001AAEE8FF518> os
    ...
    

    and

    ...
    <_ast.Call object at 0x000001AAEE8FF160>
    <_ast.Attribute object at 0x000001AAEE8FF588> join
    <_ast.Attribute object at 0x000001AAEE8FFC50> path
    <_ast.Name object at 0x000001AAEE8FF5C0> os
    ...
    

    So first the Call node is visited, then the Attribute nodes (if any), and finally the Name node. That means you have to reset the state when you visit a Call node, append every attribute to it, and stop when you hit a Name node.
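The nesting behind these patterns can also be inspected directly. A small sketch, assuming Python 3.9+ (for the indent argument of ast.dump): it dumps the tree of one call from the question, then walks the func chain by hand, which yields the parts in the same reversed order that the answer's visitor collects them in.

```python
import ast

# Parse a single expression; with mode="eval" the result's .body is the Call node.
expr = ast.parse("os.path.join(dirname, fname)", mode="eval").body
dump = ast.dump(expr, indent=2)  # indent= requires Python 3.9+
print(dump)

# Follow Call.func through the Attribute chain down to the Name.
node = expr.func
chain = []
while isinstance(node, ast.Attribute):
    chain.append(node.attr)
    node = node.value
chain.append(node.id)
print(chain)  # ['join', 'path', 'os']
```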

    One could do it within generic_visit, but it's probably better to do it in the methods visit_Call, ... and then just call generic_visit from those.


    A word of caution is probably in order: this works well for simple cases, but as soon as the code becomes non-trivial it will not work reliably. For example, what if you import a subpackage? What if you bind the function to a local variable? What if you call the result of a getattr call? Listing the called functions through static analysis alone is probably impossible in Python, because besides the ordinary problems there's also frame hacking and dynamic assignment (for example, if some import or called function re-assigned the name os in your module).
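A quick illustration of that caution, assuming Python 3.9+ for ast.unparse: every call in the snippet below ends up in os.path.join or os.walk at runtime, yet a purely syntactic collector only sees the spellings p.join, join, j and getattr(os, 'walk'). The snippet is only parsed, never executed.

```python
import ast

# Each call below reaches os.path.join or os.walk at runtime,
# but only the syntax is visible to a static collector.
SOURCE = '''
import os
import os.path as p
from os.path import join
j = os.path.join
p.join("a", "b")
join("a", "b")
j("a", "b")
getattr(os, "walk")(".")
'''

tree = ast.parse(SOURCE)
# ast.unparse (Python 3.9+) turns each call's function expression back into source text.
names = [ast.unparse(node.func) for node in ast.walk(tree) if isinstance(node, ast.Call)]
print(names)
```

Mapping these spellings back to os.path.join would require tracking imports and assignments, and even that cannot see the dynamic cases mentioned above.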