Scatterplot matrix with matplotlib

I couldn’t find a function in matplotlib.pyplot that draws a scatterplot matrix, so I wrote one. I’m not an expert, so any advice or criticism is welcome ;)

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt

def scatterplot(data, data_name, bins=10, fontsize=12):
    """Make a scatterplot matrix.

    The figure is an N x N grid, where N = len(data).  Panel (row i,
    column j) plots data[j] on the x axis against data[i] on the y axis.
    The diagonal panels (i == j) show a histogram of data[i] instead.
    Each panel in the first column gets the row's name as its y label.
    Each panel in the first row gets the column's name as its title.

    Inputs:
      data - a list of data [dataX, dataY, dataZ, ...];
             all elements must have the same length

      data_name - a list of descriptions of the data;
                  len(data_name) must equal len(data)

      bins - number of histogram bins used on the diagonal (default 10)

      fontsize - font size of the row labels and column titles
                 (default 12)

    Output:
      fig - matplotlib.figure.Figure object

    Raises:
      ValueError - if len(data_name) != len(data), or if the elements of
                   data do not all have the same length
    """
    n = len(data)

    # Validate up front.  A short name list would otherwise fail halfway
    # through drawing with an IndexError, and a long one would be silently
    # truncated.
    if len(data_name) != n:
        raise ValueError(
            "data_name has %d entries but data has %d"
            % (len(data_name), n))
    if len(set(len(series) for series in data)) > 1:
        raise ValueError("all elements of data must have the same length")

    fig = plt.figure()

    for i in range(n):
        for j in range(n):
            # add_subplot numbers panels 1..n*n in row-major order.
            ax = fig.add_subplot(n, n, i * n + j + 1)

            if j == 0:
                ax.set_ylabel(data_name[i], size=fontsize)
            if i == 0:
                ax.set_title(data_name[j], size=fontsize)

            if i == j:
                # A variable plotted against itself is a straight line, so
                # show its distribution instead.
                ax.hist(data[i], bins)
            else:
                ax.scatter(data[j], data[i])

    return fig

# Example: four interdependent random series drawn as a 4 x 4 matrix.
if __name__ == "__main__":
    import numpy.random as npr

    n_samples = 100

    X = npr.randn(n_samples)
    Y = 1.2 * X + npr.normal(0.0, 0.1, n_samples)
    Z = - Y ** 2 + X + 0.05 * npr.random(n_samples)
    W = X + Y - Z + npr.normal(0.0, 2.0, n_samples)

    data = [X, Y, Z, W]
    data_name = ['Data X', 'Data Y', 'Data Z', 'Data W']

    fig = scatterplot(data, data_name)

    # Writes the image to the current working directory before displaying.
    fig.savefig('scatterplot.png', dpi=120)
    plt.show()

The output of the example looks like the figure below:

1 comment on "scatterplot matrix with matplotlib"

  1. Jim
    01/21/2012 - 11:05 AM | Permalink

    I haven’t given this a go yet, but this looks awesome! I love the histograms too – something you don’t get with the pairs command in R.

Leave a Reply