I'm confused about how to interpret the output of .predict from a fitted CoxnetSurvivalAnalysis model in scikit-survival. I've read through the notebook Intro to Survival Analysis in scikit-survival and the API reference, but can't find an explanation. Below is a minimal example of what leads to my confusion:
import pandas as pd
from sksurv.datasets import load_veterans_lung_cancer
from sksurv.linear_model import CoxnetSurvivalAnalysis
# load data
data_X, data_y = load_veterans_lung_cancer()
# one-hot-encode categorical columns in X
categorical_cols = ['Celltype', 'Prior_therapy', 'Treatment']
X = data_X.copy()
for c in categorical_cols:
    # dtype=float keeps the dummy columns numeric (pandas >= 2.0 defaults to bool)
    dummy_matrix = pd.get_dummies(X[c], prefix=c, drop_first=False, dtype=float)
    X = pd.concat([X, dummy_matrix], axis=1).drop(c, axis=1)
del data_X  # the raw frame is no longer needed
# display final X to fit Cox Elastic Net model on
print(X.head(3))
So here's the X going into the model:
   Age_in_years  Celltype  Karnofsky_score  Months_from_Diagnosis  \
0          69.0  squamous             60.0                    7.0
1          64.0  squamous             70.0                    5.0
2          38.0  squamous             60.0                    3.0

   Prior_therapy  Treatment
0             no   standard
1            yes   standard
2             no   standard
...moving on to fitting the model and generating predictions:
# Fit Model
coxnet = CoxnetSurvivalAnalysis()
coxnet.fit(X, data_y)
# What are these predictions?
preds = coxnet.predict(X)
preds
preds has the same number of records as X, but its values are wildly different from the values in data_y, even when predicting on the same data the model was fit on.
print(preds.mean())
print(data_y['Survival_in_days'].mean())
Output:
-0.044114643249153422
121.62773722627738
So what exactly are preds? Clearly .predict means something pretty different here than in scikit-learn, but I can't figure out what. The API reference says it returns "The predicted decision function," but what does that mean? And how do I get from a given X to a predicted survival-time estimate yhat in months? I'm new to survival analysis, so I'm obviously missing something.
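For context, a Cox proportional hazards model (which CoxnetSurvivalAnalysis fits with an elastic-net penalty) writes the hazard of a sample with features x as h(t | x) = h0(t) * exp(x . beta). As I understand it, the "decision function" that predict returns is the linear predictor x . beta, i.e. a log relative hazard, not a time. The minimal stdlib sketch below uses hypothetical coefficients (made up for illustration, not taken from the fitted model) and feature values borrowed from the first two rows printed above, to show how two risk scores turn into a hazard ratio:

```python
import math

# Hypothetical coefficients, for illustration only
# (NOT the coefficients of the fitted veterans model).
coef = {"Age_in_years": 0.01, "Karnofsky_score": -0.03}


def linear_predictor(sample, coef):
    """Risk score x . beta, the quantity a Cox model's decision function reports."""
    return sum(coef[name] * sample[name] for name in coef)


def hazard_ratio(sample_a, sample_b, coef):
    """exp(lp_a - lp_b): how many times larger a's hazard is than b's, at every t."""
    return math.exp(linear_predictor(sample_a, coef) - linear_predictor(sample_b, coef))


a = {"Age_in_years": 69.0, "Karnofsky_score": 60.0}
b = {"Age_in_years": 64.0, "Karnofsky_score": 70.0}

lp_a = linear_predictor(a, coef)
lp_b = linear_predictor(b, coef)
print(round(lp_a, 2), round(lp_b, 2), round(hazard_ratio(a, b, coef), 3))
```

A higher score means a higher hazard at every time point, and hence a shorter expected survival, which is why the scores look nothing like Survival_in_days. Adding the same constant to every score leaves the Cox partial likelihood unchanged, so only differences between scores carry meaning, not their absolute level.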
Predictions are risk scores on an arbitrary scale, which means you can usually only determine the sequence of events, but not their exact time.
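That ordering-only interpretation is also how such scores are usually evaluated: Harrell's concordance index counts how often, among comparable pairs, the sample with the higher risk score failed first. Below is a simplified stdlib sketch (pairs with tied times are skipped, tied scores count as half); scikit-survival provides its own sksurv.metrics.concordance_index_censored, which is what I would use on real data. The toy times, events, and scores are made up.

```python
def harrell_c_index(times, events, risk_scores):
    """Fraction of comparable pairs where the higher-risk sample fails first.

    A pair (i, j) is comparable when times[i] < times[j] and sample i had an
    observed event. Tied risk scores count as half a concordant pair.
    Pairs with equal times are skipped in this simplified sketch.
    """
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # a censored sample cannot be the earlier failure in a pair
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5
    return concordant / comparable


# Made-up toy data (NOT the veterans dataset).
times = [72, 411, 228, 126, 118]
events = [True, True, True, False, True]
scores = [0.8, -0.5, -0.1, 0.3, 0.2]

c = harrell_c_index(times, events, scores)
# A positive affine rescaling keeps the ordering, so the C-index is unchanged.
c_rescaled = harrell_c_index(times, events, [10 * s + 3 for s in scores])
# Negating the scores reverses the ordering (1 - c when there are no ties).
c_negated = harrell_c_index(times, events, [-s for s in scores])
print(c, c_rescaled, c_negated)
```

Any strictly increasing transformation of the scores gives the same C-index, which is the practical meaning of "arbitrary scale".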
So that answers the "how do I interpret" question, I guess, but doesn't really get me closer to what I actually wanted, which was a prediction of likely survival time. To get that, apparently I need to use estimator.predict_survival_function in some manner. – Holds
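In a Cox model the survival function of a sample is S(t | x) = S0(t) ** exp(x . beta), where S0 is a baseline survival function estimated from the training data (typically with the Breslow estimator). As far as I know, CoxnetSurvivalAnalysis only supports predict_survival_function when it is constructed with fit_baseline_model=True, and each returned StepFunction exposes its time points and probabilities as .x and .y. A common point estimate of "likely survival time" is the median: the first time at which S(t | x) drops to 0.5 or below. The stdlib sketch below assumes a made-up baseline curve and risk scores; the helper names are hypothetical.

```python
import math


def individual_survival(baseline_surv, risk_score):
    """Cox model: S(t | x) = S0(t) ** exp(x . beta), evaluated on S0's time grid."""
    factor = math.exp(risk_score)
    return [s0 ** factor for s0 in baseline_surv]


def median_survival_time(times, surv_probs):
    """First time point where the survival probability is <= 0.5.

    Returns None if the curve never drops that low (median not reached).
    """
    for t, s in zip(times, surv_probs):
        if s <= 0.5:
            return t
    return None


# Made-up baseline survival curve on a grid of days (NOT the veterans data).
times = [30, 60, 90, 120, 180, 270, 365]
baseline = [0.90, 0.80, 0.70, 0.60, 0.45, 0.30, 0.20]

for score in (-1.0, 0.0, 1.0):
    surv = individual_survival(baseline, score)
    print(score, median_survival_time(times, surv))
```

With a fitted model, times and surv_probs would come from fn.x and fn.y of each function returned by predict_survival_function(X), which has already applied the exp(x . beta) step. Since data_y stores Survival_in_days, the estimate comes out in days, not months. The None case matters in practice: for low-risk samples the predicted curve may never fall to 0.5 within the observed follow-up.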