Tools to use for conditional density estimation in Python [closed]
I have a large dataset that contains 3 attributes per row: A, B, C.

Column A can take the values 0, 1 and 2. Columns B and C can take any value.

I'd like to perform density estimation using histograms for P(A = 2 | B,C) and plot the results using Python.

I do not need the code to do it; I can try to figure that out on my own. I just need to know the procedures and the tools I should use.

Wistrup answered 25/10, 2014 at 1:45
Take a stab at it and post your code. – Crunch
To answer your overall question, we should go through several steps, each answering a different question:

  • How to read a CSV file (or text data)?

  • How to filter data?

  • How to plot data?

At each stage you need specific techniques and tools, and you may have several choices at each stage (you can search online for alternatives).

1- How to read a CSV file:

Python's built-in csv module can read the CSV file where you store your data, but most people recommend Pandas for working with CSV files.

After installing the Pandas package, you can read your CSV file with the read_csv function:

import pandas as pd

df = pd.read_csv("file.csv")  # replace "file.csv" with the path to your data
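For comparison, the built-in csv module mentioned above can load the same file without Pandas. This is a minimal sketch, assuming the file has a header row naming the columns A, B and C; the function name read_abc is hypothetical.

```python
import csv


def read_abc(path):
    """Read columns A, B and C from a CSV file that has a header row.

    Assumes (hypothetically) the header names the columns A, B and C.
    Returns three lists: A as ints, B and C as floats.
    """
    a_values, b_values, c_values = [], [], []
    with open(path, newline="") as f:
        # DictReader maps each row to {column name: string value}
        for row in csv.DictReader(f):
            a_values.append(int(row["A"]))
            b_values.append(float(row["B"]))
            c_values.append(float(row["C"]))
    return a_values, b_values, c_values
```

Usage: `a, b, c = read_abc("file.csv")`. You get plain lists, so filtering and binning have to be done by hand, which is why Pandas is usually more convenient for the steps below.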

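Since you didn't share your CSV file, I will generate a random dataset to illustrate the upcoming steps. Note that in this toy dataset the discrete column (values 0, 1 and 2) is called B, while A and C hold continuous random values, so B plays the role that column A has in your question.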

import pandas as pd
import numpy as np

t = [1,1,1,2,0,1,1,0,0,2,1,1,2,0,0,0,0,1,1,1]
df = pd.DataFrame(np.random.randn(20, 2), columns=list('AC'))  # two columns of random normal values
df['B'] = t  # add a column containing only the values 0, 1 and 2

Note: NumPy is a Python package for numerical operations. You don't strictly need it for reading or filtering the data; here it is only used to generate the random values.

If you print df, you will get something like the following (columns A and C are random, so your values will differ):

         A         C    B
0  -0.090162  0.035458  1
1   2.068328 -0.357626  1
2  -0.476045 -1.217848  1
3  -0.405150 -1.111787  2
4   0.502283  1.586743  0
5   1.822558 -0.398833  1
6   0.367663  0.305023  1
7   2.731756  0.563161  0
8   2.096459  1.323511  0
9   1.386778 -1.774599  2
10 -0.512147 -0.677339  1
11 -0.091165  0.587496  1
12 -0.264265  1.216617  2
13  1.731371 -0.906727  0
14  0.969974  1.305460  0
15 -0.795679 -0.707238  0
16  0.274473  1.842542  0
17  0.771794 -1.726273  1
18  0.126508 -0.206365  1
19  0.622025 -0.322115  1

2- How to filter data:

There are different techniques to filter data. The easiest is boolean indexing: select a column of your DataFrame and apply a condition to it. In our case, the criterion is the value 2 in column B.

l = df[df['B'] == 2]  # keep only the rows where B equals 2
print(l)

You can also use other approaches, such as groupby or a lambda function, to go through the DataFrame and apply different conditions to filter the data.

for group in df.groupby('B'):
    print(group)  # each item is a (value of B, sub-DataFrame) tuple

If you run the two snippets above, you'll get:

For the first one: only the rows where B == 2

           A         C  B
3  -0.405150 -1.111787  2
9   1.386778 -1.774599  2
12 -0.264265  1.216617  2

For the second one: the data split into groups, printed as one (value of B, DataFrame) tuple per group.

(0,            A         C  B
4   0.502283  1.586743  0
7   2.731756  0.563161  0
8   2.096459  1.323511  0
13  1.731371 -0.906727  0
14  0.969974  1.305460  0
15 -0.795679 -0.707238  0
16  0.274473  1.842542  0)
(1,            A         C  B
0  -0.090162  0.035458  1
1   2.068328 -0.357626  1
2  -0.476045 -1.217848  1
5   1.822558 -0.398833  1
6   0.367663  0.305023  1
10 -0.512147 -0.677339  1
11 -0.091165  0.587496  1
17  0.771794 -1.726273  1
18  0.126508 -0.206365  1
19  0.622025 -0.322115  1)
(2,            A         C  B
3  -0.405150 -1.111787  2
9   1.386778 -1.774599  2
12 -0.264265  1.216617  2)
3- How to plot your data:

The simplest way to plot your data is with matplotlib.

The easiest way to plot the data in column B is to run:

import matplotlib.pyplot as plt

plt.hist(df['B'], bins=20, color='blue')  # histogram of the values in column B
plt.show()

You'll get this result:

[Figure: histogram of column B]
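The histogram above only counts the values of a single column. The quantity in the question, P(A = 2 | B, C), can be estimated with histograms by binning the (B, C) plane and, in each bin, dividing the number of rows with A = 2 by the total number of rows in that bin. A minimal NumPy sketch, assuming the question's layout (A discrete, B and C continuous, unlike the toy DataFrame above where the discrete column is B); the function name conditional_hist, the bin count and the synthetic data are hypothetical:

```python
import numpy as np


def conditional_hist(a, b, c, target=2, bins=10):
    """Histogram estimate of P(A == target | B, C) on a regular 2-D grid.

    Returns (prob, b_edges, c_edges). prob[i, j] is the estimate for the bin
    [b_edges[i], b_edges[i+1]] x [c_edges[j], c_edges[j+1]]; NaN if the bin is empty.
    """
    a, b, c = np.asarray(a), np.asarray(b), np.asarray(c)
    # Number of rows falling in each (B, C) bin
    total, b_edges, c_edges = np.histogram2d(b, c, bins=bins)
    # Number of rows with A == target in the same bins (reuse the edges)
    mask = a == target
    hits, _, _ = np.histogram2d(b[mask], c[mask], bins=[b_edges, c_edges])
    # Ratio of the two histograms; empty bins get NaN instead of 0/0
    with np.errstate(invalid="ignore", divide="ignore"):
        prob = np.where(total > 0, hits / total, np.nan)
    return prob, b_edges, c_edges


# Hypothetical synthetic data in the question's layout: A in {0, 1, 2}, B and C continuous
rng = np.random.default_rng(0)
b_demo = rng.normal(size=1000)
c_demo = rng.normal(size=1000)
# A == 2 whenever B + C > 0, otherwise 0 or 1 at random
a_demo = np.where(b_demo + c_demo > 0, 2, rng.integers(0, 2, size=1000))

prob, b_edges, c_edges = conditional_hist(a_demo, b_demo, c_demo, bins=8)
print(np.round(prob, 2))
```

With real data loaded by read_csv, the call would be `conditional_hist(df['A'], df['B'], df['C'])`. The bin count trades resolution against noise: bins with only a few rows give unreliable ratios.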

If you want to plot all the columns together, use different colors and markers so that each series stays readable.

import matplotlib.pyplot as plt

t = range(len(df))  # row index on the x-axis
plt.plot(t, df['A'], 'r--', label='A')
plt.plot(t, df['B'], 'bs--', label='B')
plt.plot(t, df['C'], 'g^--', label='C')
plt.legend()  # uses the labels set above
plt.show()

You'll get:

[Figure: line plot of columns A, B and C against the row index]

Which plot to use depends on what you need to show. You can explore the different ways to plot data in the examples on the official matplotlib.org website.
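To plot the conditional estimate the question asks about, rather than the raw columns, one option is a colored grid over the (B, C) plane. This is a hedged sketch, assuming the question's layout (A discrete, B and C continuous); plot_conditional and the synthetic data are hypothetical, and it recomputes the two-histogram ratio so it runs on its own:

```python
import numpy as np
import matplotlib.pyplot as plt


def plot_conditional(a, b, c, target=2, bins=10):
    """Plot a histogram estimate of P(A == target | B, C) as a colored grid."""
    a, b, c = np.asarray(a), np.asarray(b), np.asarray(c)
    # Rows with A == target divided by all rows, per (B, C) bin
    total, b_edges, c_edges = np.histogram2d(b, c, bins=bins)
    hits, _, _ = np.histogram2d(b[a == target], c[a == target], bins=[b_edges, c_edges])
    with np.errstate(invalid="ignore", divide="ignore"):
        prob = np.where(total > 0, hits / total, np.nan)

    fig, ax = plt.subplots()
    # prob is indexed [B bin, C bin]; pcolormesh wants [row = y, column = x],
    # so transpose to put B on the x-axis. Masking leaves empty bins blank.
    mesh = ax.pcolormesh(b_edges, c_edges, np.ma.masked_invalid(prob.T),
                         vmin=0, vmax=1, cmap="viridis")
    fig.colorbar(mesh, ax=ax, label=f"P(A = {target} | B, C)")
    ax.set_xlabel("B")
    ax.set_ylabel("C")
    return fig, prob


# Hypothetical synthetic data: A == 2 exactly when B > C
rng = np.random.default_rng(0)
b_demo = rng.normal(size=2000)
c_demo = rng.normal(size=2000)
a_demo = np.where(b_demo > c_demo, 2, 0)

fig, prob = plot_conditional(a_demo, b_demo, c_demo, bins=12)
plt.show()
```

Fixing vmin=0 and vmax=1 keeps the color scale comparable across datasets, since the plotted values are probabilities.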

Diminish answered 25/10, 2014 at 3:58
Wow! Your answer was extremely well explained and detailed. Thank you so much! I really appreciate the time you took to answer. – Wistrup
@OliverHoffman: Glad you found the answer helpful! :) Good luck with your project. – Diminish
If you're looking for tools that do something more sophisticated than nonparametric density estimation with histograms, please check the Python repository of the cde package, or install it directly with

pip install cde

In addition to extensive documentation, the package implements

  • nonparametric methods (conditional and neighborhood kernel density estimation),
  • semiparametric methods (least-squares conditional density estimation), and
  • parametric neural-network-based methods (mixture density networks, kernel mixture networks).

The package can also compute centered moments, statistical divergences (KL divergence, Hellinger distance, Jensen-Shannon divergence), percentiles and expected shortfall, and it provides data-generating processes for simulation (ARMA-jump, jump-diffusion, GMMs, etc.).
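The cde package's own API is not reproduced here. As a hedged illustration of the simplest kernel-based idea behind such estimators (a swapped-in technique, not the package's code), P(A = 2 | B, C) can be estimated by Nadaraya-Watson kernel regression of the indicator A == 2 on (B, C): a weighted share of rows with A = 2, where rows closer to the query point get larger Gaussian weights. This gives a smooth alternative to the histogram ratio. The function name kernel_conditional_prob, the bandwidth values and the synthetic data are hypothetical:

```python
import numpy as np


def kernel_conditional_prob(a, bc, query, bandwidth=0.5, target=2):
    """Gaussian-kernel (Nadaraya-Watson) estimate of P(A == target | (B, C) = query).

    a:     shape (n,)   discrete labels
    bc:    shape (n, 2) observed (B, C) pairs
    query: shape (m, 2) points at which to evaluate the estimate
    Returns shape (m,); NaN if every kernel weight underflows to zero.
    """
    a = np.asarray(a)
    bc = np.asarray(bc, dtype=float)
    query = np.atleast_2d(np.asarray(query, dtype=float))
    # Squared distance between every query point and every observation: shape (m, n)
    d2 = ((query[:, None, :] - bc[None, :, :]) ** 2).sum(axis=-1)
    w = np.exp(-0.5 * d2 / bandwidth**2)  # Gaussian kernel weights
    # Weighted share of observations with A == target
    return (w * (a == target)).sum(axis=1) / w.sum(axis=1)


# Hypothetical synthetic data: A == 2 exactly when B > C
rng = np.random.default_rng(1)
bc_demo = rng.normal(size=(500, 2))
a_demo = np.where(bc_demo[:, 0] > bc_demo[:, 1], 2, 0)
grid = np.array([[1.0, -1.0], [-1.0, 1.0]])
print(kernel_conditional_prob(a_demo, bc_demo, grid, bandwidth=0.3))
```

The bandwidth plays the role of the bin width: too small gives noisy estimates, too large blurs the boundary. This sketch also builds an (m, n) distance matrix, so for a large dataset you would evaluate the query points in chunks.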

Disclaimer: I am one of the package developers.

Madalene answered 19/3, 2019 at 10:11
