Remove data points below a curve with python
I need to compare some theoretical data with real data in Python. The theoretical data comes from solving an equation. To improve the comparison, I would like to remove data points that fall far from the theoretical curve; that is, I want to remove the points below and above the red dashed lines in the figure (made with matplotlib). [Figure: data points and theoretical curves]

Both the theoretical curves and the data points are arrays of different length.

I can try to remove the points roughly by eye. For example, the first upper point can be detected using:

data2[(data2.redshift<0.4)&(data2.dmodulus>1)]
rec.array([('1997o', 0.374, 1.0203223485103787, 0.44354759972859786)], dtype=[('SN_name', '|S10'), ('redshift', '<f8'), ('dmodulus', '<f8'), ('dmodulus_error', '<f8')])    

But I would like a less ad hoc approach.

So, can anyone help me find an easy way to remove the problematic points?

Thank you!

Parceling answered 31/10, 2011 at 19:52 Comment(2)
Just purely from a scientific point of view, I would not remove the points unless there is an EXTREMELY valid reason that you think that they are wrong. You have enough data that the outlying points will not have any effect on the fit, so removing them only serves to make the graph look pretty, without serving any scientific purpose. - Massorete
You are right, but I was told to. - Parceling
D
4

This might be overkill, and it is based on your comment:

Both the theoretical curves and the data points are arrays of different length.

I would do the following:

  1. Truncate the data set so that its x values lie within the max and min values of the theoretical set.
  2. Interpolate the theoretical curve using scipy.interpolate.interp1d and the above truncated data x values. The reason for step (1) is to satisfy the constraints of interp1d, which by default raises an error for x values outside the range it was built from.
  3. Use numpy.where to find data y values that are outside the range of acceptable theory values.
  4. DON'T discard these values, as was suggested in comments and other answers. If you want, for clarity, point them out by plotting the 'inliers' in one color and the 'outliers' in another color.

Here's a script that is close to what you are looking for, I think. It hopefully will help you accomplish what you want:

import numpy as np
import scipy.interpolate as interpolate
import matplotlib.pyplot as plt

# make up data
def makeUpData():
    '''Make many more data points (x,y,yerr) than theory (x,y),
    with theory yerr corresponding to a constant "sigma" in y, 
    about x,y value'''
    NX= 150
    dataX = (np.random.rand(NX)*1.1)**2
    dataY = (1.5*dataX+np.random.rand(NX)**2)*dataX
    dataErr = np.random.rand(NX)*dataX*1.3
    theoryX = np.arange(0,1,0.1)
    theoryY = theoryX*theoryX*1.5
    theoryErr = 0.5
    return dataX,dataY,dataErr,theoryX,theoryY,theoryErr

def makeSameXrange(theoryX,dataX,dataY):
    '''
    Truncate the dataX and dataY ranges so that dataX min and max are with in
    the max and min of theoryX.
    '''
    minT,maxT = theoryX.min(),theoryX.max()
    goodIdxMax = np.where(dataX<maxT)
    goodIdxMin = np.where(dataX[goodIdxMax]>minT)
    return (dataX[goodIdxMax])[goodIdxMin],(dataY[goodIdxMax])[goodIdxMin]

# take 'theory' and get values at every 'data' x point
def theoryYatDataX(theoryX,theoryY,dataX):
    '''For every dataX point, find the interpolated theoryY value. theoryX is
    needed for the interpolation.'''
    f = interpolate.interp1d(theoryX,theoryY)
    return f(dataX[np.where(dataX<np.max(theoryX))])

# collect valid points
def findInlierSet(dataX,dataY,interpTheoryY,theoryErr):
    '''Find where theoryY-theoryErr < dataY < theoryY+theoryErr and return
    the x and y values of those inliers.'''
    withinUpper = np.where(dataY<(interpTheoryY+theoryErr))
    withinLower = np.where(dataY[withinUpper]
                    >(interpTheoryY[withinUpper]-theoryErr))
    return (dataX[withinUpper])[withinLower],(dataY[withinUpper])[withinLower]

def findOutlierSet(dataX,dataY,interpTheoryY,theoryErr):
    '''Find where dataY > theoryY+theoryErr or dataY < theoryY-theoryErr and
    return the x and y values of the upper and lower outliers.'''
    aboveUpper = np.where(dataY>(interpTheoryY+theoryErr))
    belowLower = np.where(dataY<(interpTheoryY-theoryErr))
    return (dataX[aboveUpper],dataY[aboveUpper],
            dataX[belowLower],dataY[belowLower])

if __name__ == "__main__":

    dataX,dataY,dataErr,theoryX,theoryY,theoryErr = makeUpData()

    TruncDataX,TruncDataY = makeSameXrange(theoryX,dataX,dataY)

    interpTheoryY = theoryYatDataX(theoryX,theoryY,TruncDataX)

    inDataX,inDataY = findInlierSet(TruncDataX,TruncDataY,interpTheoryY,
                                    theoryErr)

    outUpX,outUpY,outDownX,outDownY = findOutlierSet(TruncDataX,
                                                    TruncDataY,
                                                    interpTheoryY,
                                                    theoryErr)
    #print inlierIndex
    fig = plt.figure()
    ax = fig.add_subplot(211)

    ax.errorbar(dataX,dataY,dataErr,fmt='.',color='k')
    ax.plot(theoryX,theoryY,'r-')
    ax.plot(theoryX,theoryY+theoryErr,'r--')
    ax.plot(theoryX,theoryY-theoryErr,'r--')
    ax.set_xlim(0,1.4)
    ax.set_ylim(-.5,3)
    ax = fig.add_subplot(212)

    ax.plot(inDataX,inDataY,'ko')
    ax.plot(outUpX,outUpY,'bo')
    ax.plot(outDownX,outDownY,'ro')
    ax.plot(theoryX,theoryY,'r-')
    ax.plot(theoryX,theoryY+theoryErr,'r--')
    ax.plot(theoryX,theoryY-theoryErr,'r--')
    ax.set_xlim(0,1.4)
    ax.set_ylim(-.5,3)
    fig.savefig('findInliers.png')

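This figure (saved as findInliers.png) is the result: the top panel shows the raw data with error bars, and the bottom panel shows the inliers in black, the outliers above the band in blue, and the outliers below the band in red.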
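Steps 1 to 3 can also be sketched more compactly with boolean masks instead of nested np.where calls. This is only a minimal sketch under stated assumptions: it uses np.interp (linear interpolation, theory x values increasing) in place of scipy.interpolate.interp1d, and the function name split_by_band and the sample arrays are made up for illustration.

```python
import numpy as np

def split_by_band(data_x, data_y, theory_x, theory_y, half_width):
    """Return boolean masks (inside, above, below) relative to the band
    theory +/- half_width. Points outside the theory x range get False
    in all three masks."""
    # Step 1: restrict to data whose x lies inside the tabulated theory range.
    in_range = (data_x >= theory_x.min()) & (data_x <= theory_x.max())
    # Step 2: linearly interpolate the theory curve at every data x.
    theory_at_data = np.interp(data_x, theory_x, theory_y)
    # Step 3: compare the residual with the half-width of the band.
    resid = data_y - theory_at_data
    inside = in_range & (np.abs(resid) <= half_width)
    above = in_range & (resid > half_width)
    below = in_range & (resid < -half_width)
    return inside, above, below

# Made-up sample data: a coarse theory curve and five data points.
theory_x = np.linspace(0.0, 1.0, 11)
theory_y = 1.5 * theory_x**2
data_x = np.array([0.2, 0.5, 0.5, 0.9, 1.3])
data_y = np.array([0.1, 1.2, -0.3, 1.3, 2.0])

inside, above, below = split_by_band(data_x, data_y, theory_x, theory_y, 0.5)
print(data_x[inside], data_y[inside])  # points to keep or plot in one color
print(data_x[above], data_x[below])    # outliers to plot in other colors
```

Because the masks have the same length as the data, they can index any parallel array (for example the error bars) without further bookkeeping.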

Durer answered 1/11, 2011 at 13:3 Comment(0)
P
4

In the end I used some of Yann's code:

def theoryYatDataX(theoryX,theoryY,dataX):
    '''For every dataX point, find interpolated theoryY value. theoryX needed
    for interpolation.'''
    f = interpolate.interp1d(theoryX,theoryY)
    return f(dataX[np.where(dataX<np.max(theoryX))])

def findOutlierSet(data,interpTheoryY,theoryErr):
    '''Split data into points inside the band theoryY-theoryErr <= dmodulus
    <= theoryY+theoryErr (datain) and points outside it (dataout).'''

    up = np.where(data.dmodulus > (interpTheoryY+theoryErr))
    low = np.where(data.dmodulus < (interpTheoryY-theoryErr))
    # join all the index together in a flat array
    out = np.hstack([up,low]).ravel()

    index = np.ones(len(data),dtype=bool)
    index[out]=False

    datain = data[index]
    dataout = data[out]

    return datain, dataout

def selectdata(data,theoryX,theoryY):
    """
    Data selection: z<1 and +-0.5 LFLRW separation
    """
    # Select data with redshift z<1
    data1 = data[data.redshift < 1]

    # From modulus to light distance (modulus2distance is not shown here):
    data1.dmodulus, data1.dmodulus_error = modulus2distance(data1.dmodulus,data1.dmodulus_error)

    # Sort the data by redshift
    data1.sort(order='redshift')

    # Outliers: distance to LFLRW curve bigger than +-0.5
    theoryErr = 0.5
    # Theory curve Interpolation to get the same points as data
    interpy = theoryYatDataX(theoryX,theoryY,data1.redshift)

    datain, dataout = findOutlierSet(data1,interpy,theoryErr)
    return datain, dataout

Using those functions I can finally obtain:

[Figure: data selection]

Thank you all for your help.

Parceling answered 8/11, 2011 at 14:52 Comment(1)
+1 for showing us your solution and also for keeping the outlying points on the graph. - Massorete
S
1

Just look at the difference between the red curve and the points. If it is bigger than the difference between the red curve and the dashed red curve, remove the point.

diff = np.abs(points - red_curve)
# keep only the points no farther from the curve than the dashed line is
index = diff <= np.abs(dashed_curve - red_curve)
filtered = points[index]

But please take the comment from NickLH seriously. Your data looks pretty good without any filtering; your "outliers" all have very large errors and won't affect the fit much.

Squab answered 31/10, 2011 at 21:15 Comment(2)
It is a good idea, but I find it difficult to calculate the difference between the red curve and the points, since the two arrays have different lengths. One could use interpolation to build a red-curve array with the same length as the points array. - Parceling
The red curves are probably made with a function; just put the relevant x-values into it. - Squab
B
0

You could either use numpy.where() to identify which (x, y) pairs meet your plotting criteria, or use enumerate to do much the same thing. Example:

x_list = [ 1,  2,  3,  4,  5,  6 ]
y_list = ['f','o','o','b','a','r']

result = [y_list[i] for i, x in enumerate(x_list) if 2 <= x < 5]

print(result)  # ['o', 'o', 'b']

I'm sure you could change the conditions so that the '2' and '5' in the above example are replaced by the functions of your curves.
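The numpy.where() variant mentioned at the start of this answer might look like the following. It is a rough sketch only: the bound functions lower and upper are hypothetical stand-ins for the two dashed curves, and the sample arrays are made up.

```python
import numpy as np

# Hypothetical bound functions standing in for the two dashed curves.
def lower(x):
    return 1.5 * x**2 - 0.5

def upper(x):
    return 1.5 * x**2 + 0.5

# Made-up sample data.
x = np.array([0.1, 0.4, 0.6, 0.8, 0.9])
y = np.array([0.0, 1.0, 0.5, 0.3, 1.4])

# With a single condition, np.where returns a tuple of index arrays;
# [0] picks the indices along the first (and only) axis.
keep = np.where((y >= lower(x)) & (y <= upper(x)))[0]
x_kept, y_kept = x[keep], y[keep]

print(keep)
print(x_kept, y_kept)
```

Here the fixed limits of the list comprehension are replaced by bounds evaluated at each x, so the same condition follows the curve.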

Brachiopod answered 2/11, 2011 at 15:8 Comment(0)
