How to plot a separator line between two data classes?
I have a simple exercise that I am not sure how to do. I have the following data sets:

male100

    Year    Time
0   1896    12.00
1   1900    11.00
2   1904    11.00
3   1906    11.20
4   1908    10.80
5   1912    10.80
6   1920    10.80
7   1924    10.60
8   1928    10.80
9   1932    10.30
10  1936    10.30
11  1948    10.30
12  1952    10.40
13  1956    10.50
14  1960    10.20
15  1964    10.00
16  1968    9.95
17  1972    10.14
18  1976    10.06
19  1980    10.25
20  1984    9.99
21  1988    9.92
22  1992    9.96
23  1996    9.84
24  2000    9.87
25  2004    9.85
26  2008    9.69

and the second one:

female100

    Year    Time
0   1928    12.20
1   1932    11.90
2   1936    11.50
3   1948    11.90
4   1952    11.50
5   1956    11.50
6   1960    11.00
7   1964    11.40
8   1968    11.00
9   1972    11.07
10  1976    11.08
11  1980    11.06
12  1984    10.97
13  1988    10.54
14  1992    10.82
15  1996    10.94
16  2000    11.12
17  2004    10.93
18  2008    10.78

I have the following code:

import matplotlib.pyplot as plt

y = -0.014*male100['Year'] + 38

plt.plot(male100['Year'], y, 'b-')
ax = plt.gca()  # gca stands for 'get current axis'
ax = male100.plot(x=0, y=1, kind='scatter', color='g', label="Mens 100m", ax=ax)
female100.plot(x=0, y=1, kind='scatter', color='r', label="Womens 100m", ax=ax)

Which produces this result:

[plot: green (male) and red (female) scatter points with the blue line y = -0.014*Year + 38]

I need to plot a line that goes exactly between them, so that all of the green points lie below the line and all of the red points lie above it. How do I do that?

I've tried playing with the parameters of y, but to no avail. I also tried fitting a linear regression to male100, female100, and the merged version of them (concatenated across rows), but couldn't get any results.

Any help would be appreciated!

Perimorph answered 18/11, 2020 at 13:44 Comment(0)

One solution uses a support vector machine (SVM). A linear SVM finds the two margins that separate the two classes of points; the line midway between the two support-vector margins is your answer. Note that this only works when the two sets of points are linearly separable.

[plot: the two classes with the SVM margin lines and the separator between them]
You can use the following code to see the result:

Data Entry

male = [
(1896  ,  12.00),
(1900  ,  11.00),
(1904  ,  11.00),
(1906  ,  11.20),
(1908  ,  10.80),
(1912  ,  10.80),
(1920  ,  10.80),
(1924  ,  10.60),
(1928  ,  10.80),
(1932  ,  10.30),
(1936  ,  10.30),
(1948  ,  10.30),
(1952  ,  10.40),
(1956  ,  10.50),
(1960  ,  10.20),
(1964  ,  10.00),
(1968  ,  9.95),
(1972  ,  10.14),
(1976  ,  10.06),
(1980  ,  10.25),
(1984  ,  9.99),
(1988  ,  9.92),
(1992  ,  9.96),
(1996  ,  9.84),
(2000  ,  9.87),
(2004  ,  9.85),
(2008  ,  9.69)
        ]
female = [
(1928,    12.20),
(1932,    11.90),
(1936,    11.50),
(1948,    11.90),
(1952,    11.50),
(1956,    11.50),
(1960,    11.00),
(1964,    11.40),
(1968,    11.00),
(1972,    11.07),
(1976,    11.08),
(1980,    11.06),
(1984,    10.97),
(1988,    10.54),
(1992,    10.82),
(1996,    10.94),
(2000,    11.12),
(2004,    10.93),
(2008,    10.78)
]

Main Code

Notice that the value of C is important here. If it is left at the default of 1, you can't get the preferred result; a large C penalizes misclassified points heavily and approximates a hard-margin SVM.

from sklearn import svm
import numpy as np
import matplotlib.pyplot as plt

X = np.array(male + female)
Y = np.array([0] * len(male) + [1] * len(female))

# fit the model
clf = svm.SVC(kernel='linear', C=1000)  # C is important here
clf.fit(X, Y)

# get the separating hyperplane w[0]*x + w[1]*y + intercept = 0
w = clf.coef_[0]
a = -w[0] / w[1]
xx = np.linspace(1890, 2010)
yy = a * xx - clf.intercept_[0] / w[1]

plt.figure(figsize=(8, 4))
plt.plot(xx, yy, "k-")  # ***** this is the separator line *****

plt.scatter(X[:, 0], X[:, 1], c=Y, zorder=10, cmap=plt.cm.Paired, edgecolors="k")
plt.xlim((1890, 2010))
plt.ylim((9, 13))
plt.show()
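To double-check that the fitted line really separates the two classes, you can look at the sign of clf.decision_function for every point (negative for class 0, positive for class 1). A self-contained sketch using the data from the question:

```python
import numpy as np
from sklearn import svm

# (Year, Time) data from the question
male = [(1896, 12.00), (1900, 11.00), (1904, 11.00), (1906, 11.20), (1908, 10.80),
        (1912, 10.80), (1920, 10.80), (1924, 10.60), (1928, 10.80), (1932, 10.30),
        (1936, 10.30), (1948, 10.30), (1952, 10.40), (1956, 10.50), (1960, 10.20),
        (1964, 10.00), (1968, 9.95), (1972, 10.14), (1976, 10.06), (1980, 10.25),
        (1984, 9.99), (1988, 9.92), (1992, 9.96), (1996, 9.84), (2000, 9.87),
        (2004, 9.85), (2008, 9.69)]
female = [(1928, 12.20), (1932, 11.90), (1936, 11.50), (1948, 11.90), (1952, 11.50),
          (1956, 11.50), (1960, 11.00), (1964, 11.40), (1968, 11.00), (1972, 11.07),
          (1976, 11.08), (1980, 11.06), (1984, 10.97), (1988, 10.54), (1992, 10.82),
          (1996, 10.94), (2000, 11.12), (2004, 10.93), (2008, 10.78)]

X = np.array(male + female)
Y = np.array([0] * len(male) + [1] * len(female))

clf = svm.SVC(kernel='linear', C=1000)
clf.fit(X, Y)

# decision_function returns w.x + b: one sign per side of the line
scores = clf.decision_function(X)
print((scores[:len(male)] < 0).all(), (scores[len(male):] > 0).all())
```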
Sweetie answered 18/11, 2020 at 13:58 Comment(0)

I believe your idea of using regression lines is correct: without them, the line would be merely superficial (and impossible to justify if the points overlap in the case of messy data). So, using some randomly generated data with a known linear relationship, we can do the following:

import random
import numpy as np
from matplotlib import pyplot as plt
from sklearn.linear_model import LinearRegression

x_values = np.arange(0, 51, 1)

y_points_1 = [i * 2 + random.randint(5, 30) for i in x_values]
y_points_2 = [i - random.randint(5, 30) for i in x_values]

# sklearn expects a 2D feature array
x_points = x_values.reshape(-1, 1)

def regression(x, y):
    model = LinearRegression().fit(x, y)
    y_pred = model.predict(x)
    
    return y_pred

# the barrier is the pointwise average of the two regression lines
barrier = (regression(x=x_points, y=y_points_1) + regression(x=x_points, y=y_points_2)) / 2

plt.plot(x_points, regression(x=x_points, y=y_points_1))
plt.plot(x_points, regression(x=x_points, y=y_points_2))
plt.plot(x_points, barrier)
plt.scatter(x_values, y_points_1)
plt.scatter(x_values, y_points_2)
plt.grid(True)
plt.show()

Giving us the following plot:

[plot: two scatter clouds, their regression lines, and the barrier line midway between them]

This method also works for an overlap in the data points, so if we change the random data slightly and apply the same process:

x_values = np.arange(0, 51, 1)

y_points_1 = [i * 2 + random.randint(-10, 30) for i in x_values]
y_points_2 = [i - random.randint(-10, 30) for i in x_values]

We get something like the following:

[plot: the same construction with overlapping data; the barrier still runs midway between the regression lines]

It is important to note that the lists used here are of the same length, so to use this method on the question's data you would need to add predicted points to the female data after fitting the regression. These extra points simply lie on the female regression line, at the x-values (years) that appear only in the male data.
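A minimal sketch of that last step, using numpy.polyfit (equivalent to the LinearRegression fit above) to predict female times at the years present only in the male data; the variable names are illustrative:

```python
import numpy as np

# years that appear only in the male data from the question
male_only_years = np.array([1896, 1900, 1904, 1906, 1908, 1912, 1920, 1924])

# female data from the question
female_years = np.array([1928, 1932, 1936, 1948, 1952, 1956, 1960, 1964, 1968,
                         1972, 1976, 1980, 1984, 1988, 1992, 1996, 2000, 2004, 2008])
female_times = np.array([12.20, 11.90, 11.50, 11.90, 11.50, 11.50, 11.00, 11.40,
                         11.00, 11.07, 11.08, 11.06, 10.97, 10.54, 10.82, 10.94,
                         11.12, 10.93, 10.78])

# degree-1 fit, then evaluate the fitted line at the missing years
slope, intercept = np.polyfit(female_years, female_times, 1)
predicted = slope * male_only_years + intercept
print(predicted.round(2))
```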

Subroutine answered 18/11, 2020 at 15:53 Comment(0)

Because sklearn might be a bit over the top for a linear fit, and to drop the requirement that the male and female data have the same number of points, here is the same approach implemented with numpy.polyfit. It also demonstrates that averaging the two regression lines is not a guaranteed solution to the problem.

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

#data import
male = pd.read_csv("test1.txt", delim_whitespace=True)
female = pd.read_csv("test2.txt", delim_whitespace=True)

#linear fit of both populations
pmale   = np.polyfit(male.Year, male.Time, 1)
pfemale = np.polyfit(female.Year, female.Time, 1)

#more appealing presentation, let's pretend we do not just fit a line
x_fitmin=min(male.Year.min(), female.Year.min())
x_fitmax=max(male.Year.max(), female.Year.max())
x_fit=np.linspace(x_fitmin, x_fitmax, 100)

#create functions for the three fit lines
male_fit   = np.poly1d(pmale)
print(male_fit)
female_fit = np.poly1d(pfemale)
print(female_fit)
sep        = np.poly1d(np.mean([pmale, pfemale], axis=0))
print(sep)

#plot all markers and lines
ax = male.plot(x="Year", y="Time", c="blue", kind="scatter", label="male")
female.plot(x="Year", y="Time", c="red", kind="scatter", ax=ax, label="female")
ax.plot(x_fit, male_fit(x_fit), c="blue", ls="dotted", label="male fit")
ax.plot(x_fit, female_fit(x_fit), c="red", ls="dotted", label="female fit")
ax.plot(x_fit, sep(x_fit), c="black", ls="dashed", label="separator")

plt.legend()
plt.show()

Sample output:

-0.01333 x + 36.42
 
-0.01507 x + 40.92
 
-0.0142 x + 38.67

[plot: both scatter series, the two dotted fit lines, and the dashed separator line]

And one point is still in the wrong section. However - I find this question so interesting because I expected answers from the sklearn crowd for non-linear data groups. I even installed sklearn in anticipation! If in the next days nobody posts a good solution with SVMs, I will set a bounty on this question.
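The misbehaving point can be found programmatically. A sketch that refits both lines on the data from the question and lists the male years whose points end up above the averaged separator:

```python
import numpy as np

# data from the question
male_year = np.array([1896, 1900, 1904, 1906, 1908, 1912, 1920, 1924, 1928, 1932,
                      1936, 1948, 1952, 1956, 1960, 1964, 1968, 1972, 1976, 1980,
                      1984, 1988, 1992, 1996, 2000, 2004, 2008])
male_time = np.array([12.00, 11.00, 11.00, 11.20, 10.80, 10.80, 10.80, 10.60, 10.80,
                      10.30, 10.30, 10.30, 10.40, 10.50, 10.20, 10.00, 9.95, 10.14,
                      10.06, 10.25, 9.99, 9.92, 9.96, 9.84, 9.87, 9.85, 9.69])
female_year = np.array([1928, 1932, 1936, 1948, 1952, 1956, 1960, 1964, 1968, 1972,
                        1976, 1980, 1984, 1988, 1992, 1996, 2000, 2004, 2008])
female_time = np.array([12.20, 11.90, 11.50, 11.90, 11.50, 11.50, 11.00, 11.40, 11.00,
                        11.07, 11.08, 11.06, 10.97, 10.54, 10.82, 10.94, 11.12, 10.93,
                        10.78])

# separator = coefficient-wise mean of the two linear fits, as above
sep = np.poly1d(np.mean([np.polyfit(male_year, male_time, 1),
                         np.polyfit(female_year, female_time, 1)], axis=0))

# male points should lie below the separator; list the violators
above = male_year[male_time > sep(male_year)]
print(above)  # the 1896 male point (12.00 s) lies above the line
```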

Gradatim answered 18/11, 2020 at 23:23 Comment(5)
Thanks for the answer, but there is a blue point at the left end which is above the line!Perimorph
I said it replicates what Chadd Robertson proposed, demonstrating that their approach is not a solution to the problem, to spare you the disappointment of going through all this hassle implementing the general approach on your data for nothing. I have now made this clear in my post.Gradatim
I agree with this completely. However, it should be noted that when comparing two sets of random data, the desire to have a line cut them perfectly may not always be satisfiable (that's the reason I grouped the data points more closely in my second plot). But the answer by OmG looks like it accounts for this; definitely a better approach.Subroutine
I tried the convex hull approach after posting this. But what looks like an easy task, the separation of two clearly distinct convex hull figures, turned out to be more difficult to formalize. Tbh, I wasted most of my time trying to understand the barely documented qhull feature. But I love this question because it led me into territories I rarely visit. Should I come up with a complete qhull solution (OmG cheats in this answer by omitting the difficult part and taking the equation from his brilliant SVM approach), I will post it belatedly.Gradatim
That would be appreciated; this seemingly easy question actually turned out to be a bigger monster than I initially thought. Thinking about it now, the convex hull approach seems to be the best starting point. I played around with making the perimeters and casting normal vectors from them, but have since put it aside.Subroutine

One solution is a geometrical approach. You can find the convex hull of each data class, then find a line that passes between the two hulls. To find such a line, you can compute the inner tangent line between the two convex hulls using this code, and rotate it a little bit.

[plot: the convex hulls of both classes with a separator line running between them]

You can use the following code:

import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import ConvexHull

# male and female are the lists of (Year, Time) tuples from my first answer
male = np.array(male)
female = np.array(female)

hull_male = ConvexHull(male)
hull_female = ConvexHull(female)

plt.plot(male[:, 0], male[:, 1], 'o')
for simplex in hull_male.simplices:
    plt.plot(male[simplex, 0], male[simplex, 1], 'k-')

# Here, the separator line comes from the SVM result,
# just to show a separator as an example:
# plt.plot(xx, yy, "k-")

plt.plot(female[:, 0], female[:, 1], 'o')
for simplex in hull_female.simplices:
    plt.plot(female[simplex, 0], female[simplex, 1], 'k-')

plt.xlim((1890, 2010))
plt.ylim((9, 13))
plt.show()
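This geometric approach only works if the two hulls do not overlap. One way to verify that, sketched here with scipy.spatial.Delaunay (find_simplex returns -1 for points outside the triangulated hull region):

```python
import numpy as np
from scipy.spatial import Delaunay

# (Year, Time) data from the question
male = np.array([(1896, 12.00), (1900, 11.00), (1904, 11.00), (1906, 11.20),
                 (1908, 10.80), (1912, 10.80), (1920, 10.80), (1924, 10.60),
                 (1928, 10.80), (1932, 10.30), (1936, 10.30), (1948, 10.30),
                 (1952, 10.40), (1956, 10.50), (1960, 10.20), (1964, 10.00),
                 (1968, 9.95), (1972, 10.14), (1976, 10.06), (1980, 10.25),
                 (1984, 9.99), (1988, 9.92), (1992, 9.96), (1996, 9.84),
                 (2000, 9.87), (2004, 9.85), (2008, 9.69)])
female = np.array([(1928, 12.20), (1932, 11.90), (1936, 11.50), (1948, 11.90),
                   (1952, 11.50), (1956, 11.50), (1960, 11.00), (1964, 11.40),
                   (1968, 11.00), (1972, 11.07), (1976, 11.08), (1980, 11.06),
                   (1984, 10.97), (1988, 10.54), (1992, 10.82), (1996, 10.94),
                   (2000, 11.12), (2004, 10.93), (2008, 10.78)])

# no point of one class may fall inside the other class's convex hull
male_inside_female = Delaunay(female).find_simplex(male) >= 0
female_inside_male = Delaunay(male).find_simplex(female) >= 0
print(male_inside_female.any(), female_inside_male.any())
```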
Sweetie answered 19/11, 2020 at 17:41 Comment(0)
