How to interpret output of .predict() from fitted scikit-survival model in python?

I'm confused how to interpret the output of .predict from a fitted CoxnetSurvivalAnalysis model in scikit-survival. I've read through the notebook Intro to Survival Analysis in scikit-survival and the API reference, but can't find an explanation. Below is a minimal example of what leads to my confusion:

import pandas as pd
from sksurv.datasets import load_veterans_lung_cancer
from sksurv.linear_model import CoxnetSurvivalAnalysis

# load data
data_X, data_y = load_veterans_lung_cancer()

# one-hot-encode categorical columns in X
categorical_cols = ['Celltype', 'Prior_therapy', 'Treatment']

X = data_X.copy()
for c in categorical_cols:
    dummy_matrix = pd.get_dummies(X[c], prefix=c, drop_first=False)
    X = pd.concat([X, dummy_matrix], axis=1).drop(c, axis=1)

# display final X to fit Cox Elastic Net model on
del data_X
print(X.head(3))

Here are the first rows of the data (the output below shows the columns as they look before one-hot encoding):

   Age_in_years  Celltype  Karnofsky_score  Months_from_Diagnosis  \
0          69.0  squamous             60.0                    7.0   
1          64.0  squamous             70.0                    5.0   
2          38.0  squamous             60.0                    3.0   

  Prior_therapy Treatment  
0            no  standard  
1           yes  standard  
2            no  standard  

Moving on to fitting the model and generating predictions:

# Fit Model
coxnet = CoxnetSurvivalAnalysis()
coxnet.fit(X, data_y)    

# What are these predictions?    
preds = coxnet.predict(X)

preds has the same number of records as X, but its values are very different from the values in data_y, even when predicted on the same data the model was fit on.

print(preds.mean()) 
print(data_y['Survival_in_days'].mean())

output:

-0.044114643249153422
121.62773722627738

So what exactly are preds? Clearly .predict means something quite different here than in scikit-learn, but I can't figure out what. The API reference says it returns "The predicted decision function," but what does that mean? And how do I get a predicted survival time yhat (in days, like Survival_in_days) for a given X? I'm new to survival analysis, so I'm obviously missing something.

Holds answered 13/11, 2017 at 22:5 Comment(3)
Did you ever figure this out? - Department
@Department it looks like it's the hazard ratio - Courbevoie
@francium87d, sort of. I posted this question on GitHub (github.com/sebp/scikit-survival/issues/15), and the library author mentioned that predictions are risk scores on an arbitrary scale, which means you can usually only determine the sequence of events, but not their exact time. So that answers the "how do I interpret" question, I guess, but doesn't get me closer to what I really wanted, which was a prediction of likely survival time. To get that, apparently I need to use estimator.predict_survival_function in some manner. - Holds

I posted this question on GitHub (github.com/sebp/scikit-survival/issues/15), though the author renamed the issue.

I got a helpful explanation of what the predict output is, but I'm still not sure how to get a set of predicted survival times, which is what I really want. Here are a couple of helpful explanations from that GitHub thread:

predictions are risk scores on an arbitrary scale, which means you can 
usually only determine the sequence of events, but not their exact time.

-sebp (library author)

It [predict] returns a type of risk score. Higher value means higher
risk of your event (class value = True)...You were probably looking
for a predicted time. You can get the predicted survival function with
estimator.predict_survival_function as in the example 00
notebook...EDIT: Actually, I’m trying to extract this but it’s been a
bit of a pain to munge

-pavopax.
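
The point about risk scores only fixing the order of events can be illustrated with a small concordance calculation. This is a minimal sketch on made-up data, not output from the fitted model: `simple_concordance` is a hypothetical helper written for this example, and it skips the tie handling of the usual definition (scikit-survival provides `sksurv.metrics.concordance_index_censored` for real use).

```python
import numpy as np

def simple_concordance(event, days, risk):
    """Fraction of comparable pairs that the risk scores order correctly.

    A pair (i, j) is comparable when sample i had an observed event and
    its time is shorter than sample j's. Ties in risk get no credit here,
    which is a simplification of the usual definition.
    """
    concordant = 0
    comparable = 0
    for i in range(len(days)):
        if not event[i]:
            continue  # a censored sample cannot be the earlier event of a pair
        for j in range(len(days)):
            if days[i] < days[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1
    return concordant / comparable

# Made-up data: event indicator, observed time in days, and risk score.
event = np.array([True, True, False, True])
days = np.array([10.0, 50.0, 80.0, 200.0])
risk = np.array([2.1, 0.3, 0.9, -1.5])

c_original = simple_concordance(event, days, risk)
# Shifting and rescaling the scores keeps their order, so the result is unchanged.
c_rescaled = simple_concordance(event, days, 10.0 * risk + 100.0)
print(c_original, c_rescaled)
```

Because only the ordering of the scores enters the calculation, their mean (such as the -0.044 in the question) carries no meaning on the time scale of data_y.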

There's more explanation in the GitHub thread, though I wasn't able to follow all of it. I need to experiment with predict_survival_function and predict_cumulative_hazard_function and see if I can get a prediction of the most likely survival time for each row in X, which is what I really want.
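
One common way to reduce a predicted survival curve to a single time is the median survival time: the first time at which the estimated survival probability drops to 0.5 or below. The sketch below assumes the curve is already available as two arrays (time points and survival probabilities). In recent scikit-survival versions the step functions returned by predict_survival_function appear to expose these as `.x` and `.y`, and CoxnetSurvivalAnalysis appears to need `fit_baseline_model=True` for that method to work, but both points are assumptions to verify against your installed version. The helper name `median_survival_time` and the curve values are made up for illustration.

```python
import numpy as np

def median_survival_time(times, surv_prob):
    """Return the first time at which S(t) <= 0.5, or NaN if never reached."""
    times = np.asarray(times, dtype=float)
    surv_prob = np.asarray(surv_prob, dtype=float)
    below = np.nonzero(surv_prob <= 0.5)[0]
    if below.size == 0:
        # The curve never crosses 0.5, so the median is undefined.
        return float("nan")
    return float(times[below[0]])

# Made-up survival curve for one patient (times in days).
times = [10, 30, 60, 120, 250]
surv_prob = [0.95, 0.80, 0.55, 0.40, 0.10]

median_days = median_survival_time(times, surv_prob)
print(median_days)
```

The median is only one possible summary; when a curve never falls to 0.5 within the observed follow-up, the helper returns NaN instead of guessing.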

I'm not going to accept this answer here, in case anyone else has a better one.

Holds answered 25/11, 2017 at 0:39 Comment(0)

predict returns the model's linear predictor: the dot product of each row of X with the fitted coefficients. This is the method's source:

def predict(self, X, alpha=None):
    """The linear predictor of the model.
    Parameters
    ----------
    X : array-like, shape = (n_samples, n_features)
        Test data of which to calculate log-likelihood from
    alpha : float, optional
        Constant that multiplies the penalty terms. If the same alpha was used during training, exact
        coefficients are used, otherwise coefficients are interpolated from the closest alpha values that
        were used during training. If set to ``None``, the last alpha in the solution path is used.
    Returns
    -------
    T : array, shape = (n_samples,)
        The predicted decision function
    """
    X = check_array(X)
    coef = self._get_coef(alpha)
    return numpy.dot(X, coef)

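check_array is an input-validation helper from scikit-learn (sklearn.utils.check_array); the coefficients themselves come from self._get_coef(alpha). You can review the rest in the coxnet source code.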
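
To make the return value concrete, here is a minimal sketch of the same dot product with NumPy. The coefficients and feature rows are hypothetical values chosen for illustration, not the ones a fitted model would produce.

```python
import numpy as np

# Hypothetical coefficients and feature rows, for illustration only;
# a fitted model would supply its own coefficients.
coef = np.array([0.02, -0.03])  # e.g. age, Karnofsky score
X = np.array([[69.0, 60.0],
              [64.0, 70.0],
              [38.0, 60.0]])

# What predict() computes in the code above: the linear predictor X @ coef.
risk_score = X @ coef

# In a Cox model the hazard is h0(t) * exp(linear predictor), so exp() of the
# difference between two scores is the hazard ratio between those two rows.
hazard_ratio_0_vs_1 = np.exp(risk_score[0] - risk_score[1])

# Rank rows from highest to lowest predicted risk.
order = np.argsort(-risk_score)
print(risk_score, hazard_ratio_0_vs_1, order)
```

The scores are therefore log hazard ratios relative to the baseline hazard: they support comparisons between rows, but they are not times.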

Parson answered 24/11, 2017 at 15:9 Comment(1)
I think the key part of this code is probably coef = self._get_coef(alpha), not X = check_array(X). - Holds
