I was doing some machine learning coding and I wanted to programmatically create a scatter plot. My data looks like:
xs = np.array([1, 2, 3, 4, 5, 6, 7, 8])
ys = np.array([3, 4, 8, 5, 6, 7, 3, 2])
lbls = np.array([0, 2, 1, 1, 1, 2, 0, 0])
So, there are eight data points. The first one is (1,3) and has label=0, the second point is (2,4) with label=2, and so on. There are three class labels: 0, 1, 2 so I wanted three colors. Now I’ve created such a scatter plot using matplotlib many times, but I can never remember the exact syntax. So I did what I thought would be a quick Internet search.
This was one of those weird cases where I found reference after reference — but they were all looking at very unusual scenarios, not the basic scenario I wanted to perform.
Anyway, I eventually put the pieces of the puzzle together but it took much longer than I thought it would: about 30 minutes when I had expected maybe 2 minutes.
The crux of the issue is that in most situations, you want to plot many points with just a few classes, so you process the data one class at a time rather than one data item at a time. I wanted to process by data item, not by class.
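For contrast, here is a rough sketch of what the by-class approach looks like with my eight points. The function name scatter_by_class and the "class k" legend text are just things I made up for illustration. The idea is one scatter() call per class, using a boolean mask to pick out that class's points.

```python
import matplotlib.pyplot as plt
import numpy as np

def scatter_by_class(xs, ys, lbls, colors):
    # hypothetical helper: one scatter() call per class label
    fig, ax = plt.subplots()
    for k, color in enumerate(colors):
        mask = (lbls == k)  # True where the point belongs to class k
        ax.scatter(xs[mask], ys[mask], s=120, color=color,
                   label="class " + str(k))
    ax.legend()
    return ax

xs = np.array([1, 2, 3, 4, 5, 6, 7, 8])
ys = np.array([3, 4, 8, 5, 6, 7, 3, 2])
lbls = np.array([0, 2, 1, 1, 1, 2, 0, 0])

ax = scatter_by_class(xs, ys, lbls, ['red', 'lightseagreen', '#F39C12'])
```

Add a call to plt.show() at the end to display the plot. One nice side effect of going by class is that you get a legend entry per class, which the by-item approach doesn't give you directly.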
One of the morals of this story is that I’ve become spoiled by the information that’s available on the Internet. I remember the pre-Internet days when tackling a small problem often meant a walk to the library, and a search through several books.
Left: Men prefer red. In an experiment where a woman was pictured in a red dress and the same dress in blue, men consistently rated the woman in red as more attractive. Center: Color affects taste. In an experiment, hot chocolate was served in red, orange, cream, and white mugs. People consistently rated the hot chocolate served in the white mug and cream mug as better tasting. Right: Some colors can affect your mood. Although the effect varies, a pink room tends to make people calmer, and a yellow room can make people mildly nauseous, which is why airplanes rarely have yellow interiors.
# scatter_test.py

import matplotlib.pyplot as plt
import numpy as np

xs = np.array([1, 2, 3, 4, 5, 6, 7, 8])
ys = np.array([3, 4, 8, 5, 6, 7, 3, 2])
lbls = np.array([0, 2, 1, 1, 1, 2, 0, 0])

# can use named colors or HTML codes
colormap = np.array(['red', 'lightseagreen', '#F39C12'])

plt.scatter(xs, ys, s=120, c=colormap[lbls])
plt.show()
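The key piece in the demo is the expression colormap[lbls]. As I understand it, indexing a NumPy array with an integer array returns one element per index, so each label gets swapped for its color. Here is a small sketch that shows the result without any plotting. The names point_colors and same_thing are mine, just for illustration.

```python
import numpy as np

lbls = np.array([0, 2, 1, 1, 1, 2, 0, 0])
colormap = np.array(['red', 'lightseagreen', '#F39C12'])

# integer-array indexing: one color per data item
point_colors = colormap[lbls]
print(point_colors)

# equivalent, spelled out one item at a time
same_thing = [colormap[k] for k in lbls]
```

Note that this only works because colormap is a NumPy array. A plain Python list can't be indexed with an array of labels this way.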