aaanalysis.TreeModel.predict_proba

TreeModel.predict_proba(X)

Obtain a Monte Carlo estimate of the positive-class prediction probability for each sample in X.

Predictions are performed using all tree-based models from the list_model_classes attribute and feature selections from the is_selected attribute.

Parameters:

X (array-like, shape (n_samples, n_features)) – Feature matrix. Rows typically correspond to samples and columns to features.

Returns:

  • pred (array-like, shape (n_samples)) – Array with the average prediction score for the positive class for each sample.

  • pred_std (array-like, shape (n_samples)) – Array with the standard deviation of prediction scores for the positive class for each sample.

Examples

To demonstrate the TreeModel().predict_proba() method, we obtain the DOM_GSEC example dataset and its respective feature set (see [Breimann25a]):

import aaanalysis as aa
aa.options["verbose"] = False # Disable verbosity

df_seq = aa.load_dataset(name="DOM_GSEC")
labels = df_seq["label"].to_list()
df_feat = aa.load_features(name="DOM_GSEC").head(100)

# Create feature matrix
sf = aa.SequenceFeature()
df_parts = sf.get_df_parts(df_seq=df_seq)
X = sf.feature_matrix(features=df_feat["feature"], df_parts=df_parts)

We can now fit the TreeModel, which will internally fit 3 tree-based models over 5 training rounds by default:

tm = aa.TreeModel()
tm = tm.fit(X, labels=labels)
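
As a rough picture of what these defaults imply for prediction, the sketch below aggregates positive-class scores from several fitted estimators into a per-sample mean and standard deviation. It is not the aaanalysis implementation: the aggregate_scores helper, the simulated scores, the assumption that 3 models over 5 rounds yield 15 score vectors, and the use of the population standard deviation are all illustrative assumptions.

```python
import random
from statistics import fmean, pstdev


def aggregate_scores(scores_per_estimator):
    """Return per-sample mean and standard deviation across estimators.

    scores_per_estimator: list of lists, one inner list of positive-class
    scores (one value per sample) for each fitted estimator.
    """
    # Transpose from estimator-major to sample-major order
    per_sample = list(zip(*scores_per_estimator))
    pred = [fmean(s) for s in per_sample]
    # Population standard deviation is an assumption made for this sketch
    pred_std = [pstdev(s) for s in per_sample]
    return pred, pred_std


# Hypothetical setup: 3 models x 5 rounds = 15 score vectors for 4 samples
rng = random.Random(0)
n_models, n_rounds = 3, 5
base = [0.95, 0.80, 0.20, 0.05]  # made-up score tendency per sample
scores = [
    [min(1.0, max(0.0, b + rng.gauss(0, 0.02))) for b in base]
    for _ in range(n_models * n_rounds)
]

pred, pred_std = aggregate_scores(scores)
for p, s in zip(pred, pred_std):
    print(f"{p:.3f} +/- {s:.3f}")
```

A small standard deviation means the individual estimators agree on a sample; a large one means the averaged score depends strongly on which model and round produced it.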

The TreeModel().predict_proba() method calculates probability predictions by averaging across multiple models and rounds, a Monte Carlo approach that makes the estimate more robust than a single model's output:

pred, pred_std = tm.predict_proba(X)

df_seq["prediction"] = pred
df_seq["pred_std"] = pred_std

print("Prediction scores for 5 substrates")
aa.display_df(df_seq.head(5))

print("Prediction scores for 5 non-substrates")
aa.display_df(df_seq.tail(5))

Prediction scores for 5 substrates
     entry   sequence                           label  tmd_start  tmd_stop  jmd_n       tmd                      jmd_c       prediction  pred_std
1    P05067  MLPGLALLLLAAWTA...GYENPTYKFFEQMQN  1      701        723       FAEDVGSNKG  AIIGLMVGGVVIATVIVITLVML  KKKQYTSIHH  0.981000    0.013928
2    P14925  MAGRARSGLLLLLLG...EEEYSAPLPKPAPSS  1      868        890       KLSTEPGSGV  SVVLITTLLVIPVLVLLAIVMFI  RWKKSRAFGD  0.973000    0.012884
3    P70180  MRSLLLFTFSACVLL...RELREDSIRSHFSVA  1      477        499       PCKSSGGLEE  SAVTGIVVGALLGAGLLMAFYFF  RKKYRITIER  0.976000    0.008000
4    Q03157  MGPTSPAARGQGRRW...HGYENPTYRFLEERP  1      585        607       APSGTGVSRE  ALSGLLIMGAGGGSLIVLSLLLL  RKKKPYGTIS  0.989000    0.008000
5    Q06481  MAATGTAAAAATGRL...GYENPTYKYLEQMQI  1      694        716       LREDFSLSSS  ALIGLLVIAVAIATVIVISLVML  RKRQYGTISH  0.996000    0.003742

Prediction scores for 5 non-substrates
     entry   sequence                           label  tmd_start  tmd_stop  jmd_n       tmd                      jmd_c       prediction  pred_std
122  P36941  MLLPWATSAPGLAWG...TPSNRGPRNQFITHD  0      226        248       PLPPEMSGTM  LMLAVLLPLAFFLLLATVFSCIW  KSHPSLCRKL  0.008000    0.005099
123  P25446  MLWIWAVLPLVLAGS...STPDTGNENEGQCLE  0      170        187       NCRKQSPRNR  LWLLTILVLLIPLVFIYR       KYRKRKCWKR  0.150000    0.025884
124  Q9P2J2  MVWCLGLAVLSLVIS...AYRQPVPHPEQATLL  0      738        760       PGLLPQPVLA  GVVGGVCFLGVAVLVSILAGCLL  NRRRAARRRR  0.122000    0.017205
125  Q96J42  MVPAAGRRPPRVMRL...SIRWLIPGQEQEHVE  0      324        342       LPSTLIKSVD  WLLVFSLFFLISFIMYATI      RTESIRWLIP  0.037000    0.012490
126  P0DPA2  MRVGGAFHLLLVCLS...DCAEGPVQCKNGLLV  0      265        287       KVSDSRRIGV  IIGIVLGSLLALGCLAVGIWGLV  CCCCGGSGAG  0.011000    0.004899
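
The two returned arrays can be combined into a simple confidence check on each prediction. The sketch below is illustrative only: the summarize helper, the 0.5 threshold, and the rule that a prediction is "confident" when the threshold lies more than one standard deviation from the mean score are assumptions, not part of aaanalysis. The input values are copied from a subset of the rows printed above.

```python
# (entry, prediction, pred_std) copied from the displayed output
rows = [
    ("P05067", 0.981, 0.013928),
    ("P14925", 0.973, 0.012884),
    ("P36941", 0.008, 0.005099),
    ("P25446", 0.150, 0.025884),
    ("Q9P2J2", 0.122, 0.017205),
]

THRESHOLD = 0.5  # assumed decision cutoff, not an aaanalysis default


def summarize(pred, pred_std, threshold=THRESHOLD):
    """Return (predicted_label, is_confident) for one sample.

    A sample counts as confident here when the interval
    pred +/- pred_std does not contain the threshold.
    """
    label = int(pred >= threshold)
    is_confident = abs(pred - threshold) > pred_std
    return label, is_confident


for entry, pred, pred_std in rows:
    label, is_confident = summarize(pred, pred_std)
    status = "confident" if is_confident else "uncertain"
    print(entry, label, status)
```

For real use, the threshold and the width of the interval would need to be chosen for the task at hand, for example from a validation set.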