I couldn’t find a function that draws a scatterplot matrix with matplotlib.pyplot, so I wrote one. I’m not an expert, so any advice or criticism is welcome.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
def scatterplot(data, data_name):
    """Makes a scatterplot matrix.

    The figure is an N x N grid of axes, where N = len(data). The cell in
    row i, column j is a scatter plot of data[j] (x axis) against data[i]
    (y axis); cells on the diagonal show a 10-bin histogram of data[i]
    instead. The first column carries the row names as y-axis labels and
    the first row carries the column names as titles.

    Inputs:
        data - a list of data [dataX, dataY, dataZ, ...];
            all elements must have same length
        data_name - a list of descriptions of the data;
            len(data_name) must be equal to len(data)
    Output:
        fig - matplotlib.figure.Figure Object
    Raises:
        ValueError - if len(data) and len(data_name) differ
    """
    n_vars = len(data)
    # Validate before creating the figure: otherwise a short name list fails
    # with an IndexError halfway through, leaving a partially drawn figure
    # registered with pyplot, and a long one is silently truncated.
    if len(data_name) != n_vars:
        raise ValueError(
            "data and data_name must have the same length "
            "(got %d and %d)" % (n_vars, len(data_name))
        )

    fig = plt.figure()
    for row in range(n_vars):
        for col in range(n_vars):
            # Subplot indices are 1-based and run row by row.
            ax = fig.add_subplot(n_vars, n_vars, row * n_vars + col + 1)
            if col == 0:
                ax.set_ylabel(data_name[row], size=12)
            if row == 0:
                ax.set_title(data_name[col], size=12)
            if row == col:
                # A variable plotted against itself is uninformative, so the
                # diagonal shows its distribution instead.
                ax.hist(data[row], 10)
            else:
                ax.scatter(data[col], data[row])
    return fig
# Example: plot the scatterplot matrix of four related random series,
# save it to 'scatterplot.png' and show it on screen.
if __name__ == "__main__":
    import numpy.random as npr

    # X is drawn with randn; Y, Z and W are built from it plus random noise,
    # so the off-diagonal panels show visible relationships.
    X = npr.randn(100)
    Y = 1.2 * X + npr.normal(0.0, 0.1, 100)
    Z = - Y ** 2 + X + 0.05 * npr.random(100)
    W = X + Y - Z + npr.normal(0.0, 2.0, 100)

    data = [X, Y, Z, W]
    data_name = ['Data X', 'Data Y', 'Data Z', 'Data W']

    fig = scatterplot(data, data_name)
    fig.savefig('scatterplot.png', dpi=120)
    plt.show()
The output of the example looks like the figure below…

I haven’t given this a go yet but this looks awesome! I love the histograms too – something you don’t get with the pairs command in R.