Research Software Engineering Summer School

Cython

Cython can be viewed as an extension of Python in which variables and functions are annotated with extra information, in particular types. Cython source code is translated into optimized C or C++ code, which is then compiled, and this can yield substantial speed-ups for slow Python code. In other words, Cython lets you write Python-like code whose performance can approach that of C/C++.

Start Coding in Cython

Unlike Python, Cython code must be compiled. This happens in two stages:

  • The Cython code in a .pyx file is translated into a C file.
  • The C file is compiled by a C compiler into a shared library (an extension module), which can then be imported directly into Python.

In a Jupyter notebook, this is much simpler. You only need to load the Cython extension once (%load_ext Cython) and start each cell of Cython code with the %%cython magic. Such cells are treated as .pyx code, compiled to C, and the names they define become available in the notebook.

For details, please see Building Cython Code.

Pure Python Mandelbrot set:

In [1]:
xmin = -1.5
ymin = -1.0
xmax = 0.5
ymax = 1.0
resolution = 300
xstep = (xmax - xmin) / resolution
ystep = (ymax - ymin) / resolution
xs = [(xmin + (xmax - xmin) * i / resolution) for i in range(resolution)]
ys = [(ymin + (ymax - ymin) * i / resolution) for i in range(resolution)]
In [2]:
def mandel(position, limit=50):
    value = position
    while abs(value) < 2:
        limit -= 1
        value = value**2 + position
        if limit < 0:
            return 0
    return limit
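To make the return value concrete: mandel iterates z → z² + c while counting down from limit. It returns 0 if z has still not left the disk |z| < 2 when the counter drops below zero (the point is treated as inside the set), and otherwise returns the remaining count, so points that escape quickly get large values. The snippet repeats the function so it runs on its own and checks a few values that can be worked out by hand:

```python
def mandel(position, limit=50):
    value = position
    while abs(value) < 2:
        limit -= 1
        value = value**2 + position
        if limit < 0:
            return 0
    return limit

# c = 0 stays at 0 forever: never escapes, so it is treated as inside the set.
print(mandel(0j))
# c = -1 cycles between -1 and 0: also never escapes.
print(mandel(-1 + 0j))
# c = 1: z goes 1 -> 2 after one step, using one unit of the budget.
print(mandel(1 + 0j))
# c = 3 already fails |z| < 2, so the full budget is returned.
print(mandel(3 + 0j))
```

The four calls print 0, 0, 49 and 50 respectively.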

Compiled by Cython:

In [3]:
%load_ext Cython
In [4]:
%%cython

def mandel_cython(position, limit=50):
    value = position
    while abs(value) < 2:
        limit -= 1
        value = value**2 + position
        if limit < 0:
            return 0
    return limit

Let's check that both versions produce the same picture:

In [5]:
from matplotlib import pyplot as plt
%matplotlib inline
f, axarr = plt.subplots(1, 2)
axarr[0].imshow([[mandel(complex(x, y)) for x in xs] for y in ys], interpolation='none')
axarr[0].set_title('Pure Python')
axarr[1].imshow([[mandel_cython(complex(x, y)) for x in xs] for y in ys], interpolation='none')
axarr[1].set_title('Cython')
[Figure: two Mandelbrot-set images side by side, titled 'Pure Python' and 'Cython']
In [6]:
%timeit [[mandel(complex(x,y)) for x in xs] for y in ys] # pure python
%timeit [[mandel_cython(complex(x,y)) for x in xs] for y in ys] # cython
345 ms ± 639 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
305 ms ± 688 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Just by passing the unchanged code through the Cython compiler, we have improved performance by a factor of about 1.13 (345 ms vs. 305 ms), without changing a single line!
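%timeit is an IPython magic. Outside a notebook, roughly the same measurement can be made with the standard-library timeit module. The sketch below assumes a smaller 50×50 grid (the helper names make_axes and time_per_image are illustrative, not part of any library) and reports the best of several repeats, as the timeit documentation recommends, rather than %timeit's mean ± std. dev. To time the compiled version, you would import mandel_cython from your built extension module (its module name depends on how you build it) and pass it instead of mandel.

```python
import timeit


def mandel(position, limit=50):
    value = position
    while abs(value) < 2:
        limit -= 1
        value = value**2 + position
        if limit < 0:
            return 0
    return limit


def make_axes(resolution=50, xmin=-1.5, xmax=0.5, ymin=-1.0, ymax=1.0):
    """Same grid construction as the notebook, at a configurable resolution."""
    xs = [xmin + (xmax - xmin) * i / resolution for i in range(resolution)]
    ys = [ymin + (ymax - ymin) * i / resolution for i in range(resolution)]
    return xs, ys


def time_per_image(func, xs, ys, repeat=3):
    """Best-of-`repeat` time in seconds to evaluate `func` over the whole grid."""
    timer = timeit.Timer(lambda: [[func(complex(x, y)) for x in xs] for y in ys])
    number, _ = timer.autorange()  # loop count whose total time is >= 0.2 s
    return min(timer.repeat(repeat=repeat, number=number)) / number


xs, ys = make_axes()
print(f"pure Python: {time_per_image(mandel, xs, ys) * 1e3:.1f} ms per image")
```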

Cython with C Types

We can do better still by telling Cython which C data types to use. Note that we are not actually writing C; we are writing Python with C type declarations.

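typed variable

Declaring value with cdef double complex makes it a C complex number rather than a Python object; the function itself is still an ordinary def function taking Python arguments.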

In [7]:
%%cython
def var_typed_mandel_cython(position, limit=50):
    cdef double complex value # typed variable
    value = position
    while abs(value) < 2:
        limit -= 1
        value = value**2 + position
        if limit < 0:
            return 0
    return limit
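Conceptually, once value is a double complex, each iteration is plain floating-point arithmetic on a real and an imaginary part, using z² + c = (a² − b² + c_re) + (2ab + c_im)i for z = a + bi. The pure-Python sketch below spells that arithmetic out with two floats. It is only an illustration of what the loop computes, not the C code Cython generates, and as an assumed simplification it tests |z|² < 4 instead of calling abs(), which avoids a square root but may disagree with abs(z) < 2 through rounding for points very close to the boundary.

```python
def mandel(position, limit=50):
    """Reference version from the notebook, using Python complex numbers."""
    value = position
    while abs(value) < 2:
        limit -= 1
        value = value**2 + position
        if limit < 0:
            return 0
    return limit


def mandel_split(position, limit=50):
    """Same iteration with real and imaginary parts held as two floats."""
    c_re, c_im = position.real, position.imag
    z_re, z_im = c_re, c_im
    while z_re * z_re + z_im * z_im < 4.0:  # |z|^2 < 4, i.e. |z| < 2
        limit -= 1
        # (a + bi)^2 + c = (a^2 - b^2 + c_re) + (2ab + c_im)i
        z_re, z_im = z_re * z_re - z_im * z_im + c_re, 2.0 * z_re * z_im + c_im
        if limit < 0:
            return 0
    return limit


for c in (0j, -1 + 0j, 1 + 0j, 3 + 0j, -2 + 0j, 0.5 + 0.5j):
    print(c, mandel(c), mandel_split(c))
```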

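typed function + typed variable

Here the function is declared with cpdef and its arguments are typed as well. A cpdef function has a fast C-level entry point plus a Python wrapper, so it can still be called from Python; when it is, position and limit are converted to C types once, on entry.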

In [8]:
%%cython
cpdef call_typed_mandel_cython(double complex position,
                               int limit=50): # typed function
    cdef double complex value # typed variable
    value = position
    while abs(value) < 2:
        limit -= 1
        value = value**2 + position
        if limit < 0:
            return 0
    return limit

performance for a single point:

In [9]:
# pure python
%timeit a = mandel(complex(0, 0)) 
7.45 μs ± 33.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [10]:
# primitive cython
%timeit a = mandel_cython(complex(0, 0)) 
6.44 μs ± 16.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [11]:
# cython with C type variable
%timeit a = var_typed_mandel_cython(complex(0, 0)) 
3.69 μs ± 1.62 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [12]:
# cython with typed variable + function
%timeit a = call_typed_mandel_cython(complex(0, 0))
1.04 μs ± 1.4 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
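The printed %timeit lines are easier to compare as ratios. The small helper below (mean_seconds is an illustrative name, written only for lines of the form "<mean> <unit> ± ...") converts each mean to seconds and divides the pure-Python time by it, using the four single-point timings above.

```python
import re

# Seconds per unit. The micro prefix may appear as the micro sign (U+00B5),
# Greek mu (U+03BC) or plain "us", so all three spellings are accepted.
UNITS = {"ns": 1e-9, "\u00b5s": 1e-6, "\u03bcs": 1e-6, "us": 1e-6, "ms": 1e-3, "s": 1.0}
PATTERN = re.compile(r"\s*([\d.]+)\s*(" + "|".join(map(re.escape, UNITS)) + r")\b")


def mean_seconds(timeit_line):
    """Return the mean from a %timeit result line, in seconds."""
    match = PATTERN.match(timeit_line)
    if match is None:
        raise ValueError(f"unrecognised %timeit line: {timeit_line!r}")
    value, unit = match.groups()
    return float(value) * UNITS[unit]


# Single-point timings copied from the cells above.
results = {
    "pure Python": "7.45 μs ± 33.6 ns per loop",
    "primitive Cython": "6.44 μs ± 16.4 ns per loop",
    "typed variable": "3.69 μs ± 1.62 ns per loop",
    "typed variable + function": "1.04 μs ± 1.4 ns per loop",
}
baseline = mean_seconds(results["pure Python"])
for name, line in results.items():
    print(f"{name:>26}: {baseline / mean_seconds(line):.2f}x vs pure Python")
```

For these numbers the three Cython variants come out at about 1.16x, 2.02x and 7.16x: typing both the variable and the function gives by far the largest gain.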

Cython with numpy ndarray

You can use NumPy from Cython exactly as in regular Python, but then you miss out on potentially large speed-ups: Cython supports fast, typed access to the elements of NumPy arrays.

In [13]:
import numpy as np
ymatrix, xmatrix = np.mgrid[ymin:ymax:ystep, xmin:xmax:xstep]
values = xmatrix + 1j * ymatrix
In [14]:
%%cython
import numpy as np
cimport numpy as np 

cpdef numpy_cython_1(np.ndarray[double complex, ndim=2] position, 
                     int limit=50): 
    # int64 on every platform (C long is only 32-bit on Windows)
    cdef np.ndarray[np.int64_t, ndim=2] diverged_at
    cdef double complex value
    cdef int xlim
    cdef int ylim
    cdef double complex pos
    cdef int steps
    cdef int x, y

    xlim = position.shape[1]
    ylim = position.shape[0]
    diverged_at = np.zeros([ylim, xlim], dtype=np.int64)
    for x in range(xlim):
        for y in range(ylim):
            steps = limit
            value = position[y, x]
            pos = position[y, x]
            # iterate z -> z**2 + c until |z| >= 2 or the step budget runs out
            while abs(value) < 2 and steps >= 0:
                steps -= 1
                value = value**2 + pos
            diverged_at[y, x] = steps

    return diverged_at
Compiling this cell prints a C compiler warning on stderr, "Using deprecated NumPy API, disable it with #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION" [-Wcpp]; it is only a warning, and the module still builds.

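Note the double import of numpy. import numpy as np gives the regular Python module, used at run time (for example for np.zeros), while cimport numpy as np gives Cython compile-time access to NumPy's C-level declarations, such as the array types used to declare position and diverged_at, which enable fast indexing and other array operations. Code that both calls NumPy functions and declares typed arrays needs both statements. The new element in the code above is the declaration of typed arrays with np.ndarray[dtype, ndim=N].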
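Typed array declarations such as np.ndarray[double complex, ndim=2] work through the PEP 3118 buffer protocol: Cython obtains the array's data pointer, element format, shape and strides, and turns position[y, x] into address arithmetic in C (plus bounds checking and negative-index handling, which are on by default). The standard library exposes the same buffer metadata through memoryview, so the pure-Python sketch below can show what an index [y, x] resolves to. It is not Cython code; an array.array of C doubles stands in for a NumPy array, and element_offset is a hypothetical helper name.

```python
import struct
from array import array

rows, cols = 3, 4
# A flat, C-contiguous block of C doubles (typecode "d") holding 0.0 .. 11.0.
flat = array("d", (float(i) for i in range(rows * cols)))

# View it as a 2-D buffer; memoryview casts have to go through a byte format.
grid = memoryview(flat).cast("B").cast("d", shape=[rows, cols])
print(grid.format, grid.shape, grid.strides)  # element type, dims, byte steps


def element_offset(view, y, x):
    """Byte offset of view[y, x]: one multiply-add per dimension."""
    return y * view.strides[0] + x * view.strides[1]


y, x = 1, 2
offset = element_offset(grid, y, x)
# Reading the raw bytes at that offset gives the same value as indexing.
(raw,) = struct.unpack_from("d", flat, offset)
print(grid[y, x], raw)
```

Typed access in compiled code boils down to this offset computation and a direct memory read, instead of a Python-level __getitem__ call and a boxed Python object per element.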

In [15]:
%timeit data_cy = [[mandel(complex(x,y)) for x in xs] for y in ys] # pure python
338 ms ± 2.81 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [16]:
%timeit data_cy = [[call_typed_mandel_cython(complex(x,y)) for x in xs] for y in ys] # typed cython
236 ms ± 239 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [17]:
%timeit numpy_cython_1(values) # ndarray
223 ms ± 387 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)

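A trick using np.vectorize

np.vectorize wraps a scalar function so that it can be applied to whole arrays. According to the NumPy documentation it is provided for convenience, not performance, and is essentially a for loop; any speed here comes from the compiled call_typed_mandel_cython itself.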

In [18]:
numpy_cython_2 = np.vectorize(call_typed_mandel_cython)
In [19]:
%timeit numpy_cython_2(values) # vectorize
RuntimeWarning: divide by zero encountered in call_typed_mandel_cython (vectorized)
RuntimeWarning: invalid value encountered in call_typed_mandel_cython (vectorized)
230 ms ± 356 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Calling C functions from Cython

Example: compare sin() from Python's math module with sin() from the C math library

In [20]:
%%cython
import math
cpdef py_sin():
    cdef int x
    cdef double y
    for x in range(10000000):  # 10**7 iterations; range() expects an int, not the float 1e7
        y = math.sin(x)
In [21]:
%%cython
from libc.math cimport sin as csin # import from C library
cpdef c_sin():
    cdef int x
    cdef double y
    for x in range(10000000):  # 10**7 iterations; range() expects an int, not the float 1e7
        y = csin(x)
In [22]:
import math
%timeit [math.sin(i) for i in range(int(1e7))] # python
1.01 s ± 6.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [23]:
%timeit py_sin()                                # cython call python library
999 ms ± 13.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [24]:
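%timeit c_sin()                                 # cython call C library
3.19 ms ± 4.13 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Calling math.sin from Cython barely helps (999 ms vs. 1.01 s), because every call still goes through Python's function-call machinery. The C version should be read with care, though: 3.19 ms for 10^7 iterations is about 0.3 ns per iteration, far less than evaluating sin() normally costs. Since y is never used, the C compiler may simplify or even remove the loop, so this benchmark likely overstates the gain; accumulating the results and returning them would give a fairer comparison.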