With the embedding matrix X and your labels table, you can train several types of task heads to predict clinical outcomes.
A
Readmission risk — Binary
Logistic regression → ROC-AUC
Once the packages are imported and the data is split into train and test sets, Task A trains a binary classifier to predict subject readmission, as recorded in the readmission_risk column. ROC-AUC is reported.
# --- Setup ---
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.metrics import roc_auc_score, accuracy_score, mean_absolute_error
from lifelines import CoxPHFitter
# X, pids from embedding step above. Align labels to X row order:
labels_df = pd.read_parquet("your_labels.parquet")
labels = labels_df.set_index("subject_id").loc[pids].reset_index()
X_train, X_test, labels_train, labels_test = train_test_split(
    X, labels, test_size=0.2, random_state=42
)
print(f"Training on {len(X_train)} samples, testing on {len(X_test)}")
# --- Task A: Binary (Readmission Risk) ---
print("\n--- Task A: Binary Classification ---")
clf_bin = LogisticRegression(max_iter=1000)
clf_bin.fit(X_train, labels_train["readmission_risk"])
y_prob = clf_bin.predict_proba(X_test)[:, 1]
auc = roc_auc_score(labels_test["readmission_risk"], y_prob)
print(f"-> ROC-AUC: {auc:.3f}")
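To help interpret the ROC-AUC printed above: it can be read as the probability that a randomly chosen readmitted subject receives a higher predicted score than a randomly chosen non-readmitted one. The following is a minimal sketch in plain Python, using made-up labels and scores; `pairwise_auc` is a hypothetical helper written for illustration and is not part of scikit-learn.

```python
def pairwise_auc(y_true, y_score):
    """ROC-AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties in score count as half a correct pair.
    """
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        raise ValueError("Need at least one positive and one negative label")
    correct = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                correct += 1.0
            elif p == n:
                correct += 0.5
    return correct / (len(pos) * len(neg))


# Toy data (illustrative only): two non-readmitted, two readmitted subjects.
toy_labels = [0, 0, 1, 1]
toy_scores = [0.1, 0.4, 0.35, 0.8]
toy_auc = pairwise_auc(toy_labels, toy_scores)
print(f"Pairwise AUC: {toy_auc:.2f}")  # 3 of 4 pairs are ordered correctly -> 0.75
```

A value of 0.5 corresponds to random ranking and 1.0 to perfect ranking. On real data, `roc_auc_score` computes the same quantity far more efficiently than this quadratic loop.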
B
Phenotype stage — Multiclass
Logistic regression → Accuracy
Task B trains a multiclass classifier to predict cancer stage (1–4), as recorded in the phenotype_class column. Accuracy is reported.
# --- Task B: Multiclass (Phenotype) ---
print("\n--- Task B: Multiclass Phenotyping ---")
clf_multi = LogisticRegression(max_iter=1000)
clf_multi.fit(X_train, labels_train["phenotype_class"])
y_pred_class = clf_multi.predict(X_test)
acc = accuracy_score(labels_test["phenotype_class"], y_pred_class)
print(f"-> Accuracy: {acc:.3f}")
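Accuracy is easier to interpret next to a trivial baseline, because stage labels may be imbalanced. The sketch below, in plain Python with made-up stage labels, computes the accuracy of always predicting the most frequent training class; `majority_baseline_accuracy` is a hypothetical helper name.

```python
from collections import Counter


def majority_baseline_accuracy(y_train, y_test):
    """Accuracy of always predicting the most frequent class in y_train."""
    majority_class, _ = Counter(y_train).most_common(1)[0]
    hits = sum(1 for y in y_test if y == majority_class)
    return majority_class, hits / len(y_test)


# Toy stage labels (illustrative only).
toy_train_stages = [1, 1, 2, 3, 1, 4, 2, 1]
toy_test_stages = [1, 2, 1, 3, 1]
stage, baseline_acc = majority_baseline_accuracy(toy_train_stages, toy_test_stages)
print(f"Always predicting stage {stage}: accuracy {baseline_acc:.2f}")  # stage 1, 0.60
```

With the real split, the same helper can be called on `labels_train["phenotype_class"]` and `labels_test["phenotype_class"]`. scikit-learn's `DummyClassifier(strategy="most_frequent")` provides an equivalent baseline.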
C
Survival months — Regression
Ridge regression → MAE
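Task C uses the continuous overall_survival_months column to predict months of survival for subjects who died, so the data is restricted to rows where event_observed is 1. Mean absolute error is reported.
# --- Task C: Regression (Survival months) ---
print("\n--- Task C: Regression ---")
# Keep only subjects with an observed death (assumes event_observed == 1 marks death);
# censored durations are not true survival times.
died_train = labels_train["event_observed"].values == 1
died_test = labels_test["event_observed"].values == 1
reg = Ridge(alpha=1.0)
reg.fit(X_train[died_train], labels_train.loc[died_train, "overall_survival_months"])
y_pred_reg = reg.predict(X_test[died_test])
mae = mean_absolute_error(labels_test.loc[died_test, "overall_survival_months"], y_pred_reg)
print(f"-> MAE: {mae:.2f}")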
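As a reference point for the MAE above, a constant prediction equal to the median of the training targets is a common naive baseline, since the median minimizes mean absolute error among constant predictions. A minimal sketch in plain Python with made-up survival times; `median_baseline_mae` is a hypothetical helper name.

```python
from statistics import median


def median_baseline_mae(y_train, y_test):
    """MAE obtained by predicting the training median for every test subject."""
    m = median(y_train)
    return m, sum(abs(y - m) for y in y_test) / len(y_test)


# Toy survival times in months (illustrative only).
toy_train_months = [6, 12, 18, 24, 60]
toy_test_months = [10, 20, 30]
med, baseline_mae = median_baseline_mae(toy_train_months, toy_test_months)
print(f"Median {med} months -> baseline MAE {baseline_mae:.2f}")  # 18, 7.33
```

A regression head adds value only to the extent that its MAE falls below this baseline on the same test subjects.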
D
Cox proportional hazards — Survival
PCA + Cox PH → C-Index
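Task D fits a Cox proportional hazards model for risk of death, using overall_survival_months as the duration and event_observed as the event indicator. The embeddings are first reduced to 10 principal components with PCA, which keeps the number of covariates in the Cox model small. The concordance index on the test set is reported.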
# --- Task D: Survival (Cox PH) ---
print("\n--- Task D: Survival Analysis ---")
pca = PCA(n_components=10)
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)
cox_df = pd.DataFrame(X_train_pca, columns=[f"PC{i}" for i in range(10)])
cox_df["T"] = labels_train["overall_survival_months"].values
cox_df["E"] = labels_train["event_observed"].values
cph = CoxPHFitter()
cph.fit(cox_df, duration_col="T", event_col="E")
test_cox_df = pd.DataFrame(X_test_pca, columns=[f"PC{i}" for i in range(10)])
test_cox_df["T"] = labels_test["overall_survival_months"].values
test_cox_df["E"] = labels_test["event_observed"].values
c_index = cph.score(test_cox_df, scoring_method="concordance_index")
print(f"-> C-Index: {c_index:.3f}")
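The concordance index extends the idea behind ROC-AUC to censored data: among pairs of subjects whose ordering in time is known, it measures how often the subject who died earlier was assigned the higher risk. The sketch below is a simplified version of Harrell's C in plain Python, using made-up data. `simple_c_index` is a hypothetical helper; it skips pairs with tied durations and is not the lifelines implementation.

```python
def simple_c_index(durations, events, risk_scores):
    """Simplified Harrell's C: a higher risk score should mean earlier death.

    A pair (i, j) is comparable when subject i has an observed event
    (events[i] == 1) and a strictly shorter duration than subject j.
    Pairs with tied durations are skipped for simplicity.
    """
    concordant = 0.0
    comparable = 0
    n = len(durations)
    for i in range(n):
        if events[i] != 1:
            continue  # a censored subject cannot be the earlier member of a pair
        for j in range(n):
            if durations[i] < durations[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("No comparable pairs")
    return concordant / comparable


# Toy data (illustrative only): the third subject is censored at 15 months.
toy_T = [5, 10, 15, 20]
toy_E = [1, 1, 0, 1]
toy_risk = [0.9, 0.4, 0.5, 0.1]
toy_c = simple_c_index(toy_T, toy_E, toy_risk)
print(f"C-index: {toy_c:.2f}")  # 4 of 5 comparable pairs are concordant -> 0.80
```

As with ROC-AUC, 0.5 corresponds to random ordering and 1.0 to a perfect ordering of risk.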
You can generalize this pattern and use Standard Model embeddings as input to other classifier types, including more complex downstream models.
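As one illustration of swapping in a different head, the sketch below fits a random forest in place of logistic regression for a binary task. To stay self-contained it runs on synthetic data: `X_demo` and `y_demo` are made-up placeholders for the real embeddings and labels, and the hyperparameters are arbitrary choices, not recommendations.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# Synthetic stand-ins for the embedding matrix and a binary label (assumption).
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(200, 16))
y_demo = (X_demo[:, 0] + 0.5 * rng.normal(size=200) > 0).astype(int)

X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=42, stratify=y_demo
)

# Same fit / predict_proba / score pattern as Task A, with a different estimator.
rf = RandomForestClassifier(n_estimators=100, random_state=0)
rf.fit(X_tr, y_tr)
demo_auc = roc_auc_score(y_te, rf.predict_proba(X_te)[:, 1])
print(f"Random forest ROC-AUC on synthetic data: {demo_auc:.3f}")
```

Any estimator that follows the scikit-learn `fit` / `predict_proba` interface can be substituted in the same way; only the estimator line changes.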