PyHealth Models

33+ Clinical ML Models

Production-ready models spanning EHR sequence learning, drug recommendation, biosignal analysis, graph neural networks, medical imaging, and clinical NLP — all with a unified training API.

Actively growing: new models are added regularly.

All Models (34)

RNN EHR

Gated recurrent network (GRU or LSTM cells) for sequential EHR data, a strong baseline for longitudinal clinical tasks.

from pyhealth.models import RNN
model = RNN(dataset=samples, rnn_type="GRU")
Transformer EHR

Self-attention transformer encoder for EHR sequences — captures long-range dependencies across clinical visits.

from pyhealth.models import Transformer
model = Transformer(dataset=samples)
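Self-attention lets every visit score every other visit in a single step, which is how long-range dependencies across visits are captured. Below is a minimal pure-Python sketch of one single-head scaled dot-product attention pass. It assumes identity query/key/value projections and toy two-dimensional visit embeddings. PyHealth's encoder learns these projections and stacks multiple heads and layers, so treat this as an illustration of the mechanism, not the library's code.

```python
import math

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(visits):
    """Single-head scaled dot-product self-attention over visit vectors.

    Uses identity query/key/value projections for brevity. Returns
    (outputs, weights), where weights[i][j] is how much visit i attends to visit j.
    """
    d = len(visits[0])
    outputs, weights = [], []
    for q in visits:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in visits]
        w = softmax(scores)
        weights.append(w)
        outputs.append([sum(w[j] * visits[j][c] for j in range(len(visits))) for c in range(d)])
    return outputs, weights

visits = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy visit embeddings
outputs, weights = self_attention(visits)
```

Because the first and last visits share the same dot product with the first visit, they receive equal attention from it, regardless of how far apart they are in time.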
MLP EHR

Multi-layer perceptron over aggregated patient features — a fast, simple non-sequential baseline.

from pyhealth.models import MLP
model = MLP(dataset=samples, hidden_dim=256)
LogisticRegression EHR

Logistic regression over bag-of-codes features — the standard interpretable clinical baseline.

from pyhealth.models import LogisticRegression
model = LogisticRegression(dataset=samples)
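A minimal sketch of logistic regression over a multi-hot bag of codes, assuming toy ICD-10 codes and hand-picked weights. How PyHealth featurizes codes internally may differ. Each weight is an additive log-odds contribution for one code, which is where the model's interpretability comes from.

```python
import math

def bag_of_codes(visits, vocab):
    """Multi-hot vector over vocab: 1.0 if the code appears in any visit."""
    present = {code for visit in visits for code in visit}
    return [1.0 if code in present else 0.0 for code in vocab]

def predict_proba(x, weights, bias):
    """sigmoid(bias + w . x); each weight is that code's log-odds contribution."""
    z = bias + sum(w * xi for w, xi in zip(weights, x))
    return 1.0 / (1.0 + math.exp(-z))

vocab = ["I10", "E11.9", "N18.3"]   # toy ICD-10 codes
visits = [["I10"], ["I10", "E11.9"]]
x = bag_of_codes(visits, vocab)     # [1.0, 1.0, 0.0]
p = predict_proba(x, weights=[0.8, 1.2, 2.0], bias=-1.5)
```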
RETAIN EHR

Reverse-time attention network providing visit-level and code-level interpretability for clinical risk prediction (Choi et al., 2016).

from pyhealth.models import RETAIN
model = RETAIN(dataset=samples)
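RETAIN's two kinds of interpretability come from two attention signals computed by recurrences that run backward in time: a scalar weight alpha per visit and a gate beta per embedding dimension, combined as context = sum_j alpha_j * (beta_j ⊙ v_j). The paper maps the beta gates back to individual codes through the embedding matrix. Below is a minimal pure-Python sketch under simplifying assumptions: vanilla tanh recurrences with scalar or diagonal weights stand in for the paper's GRUs and full weight matrices, and the parameter names and values in `params` are hypothetical. It is not PyHealth's implementation.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def retain_attention(visits, p):
    """Toy RETAIN-style attention over visit embeddings (lists of floats).

    Returns (context, alpha, beta): alpha[j] weights visit j, beta[j][k] gates
    dimension k of visit j, and context[k] = sum_j alpha[j] * beta[j][k] * visits[j][k].
    """
    d = len(visits[0])
    g, h = 0.0, [0.0] * d
    scores, gates = [], []
    # Run both recurrences from the most recent visit back to the first (reverse time).
    for v in reversed(visits):
        g = math.tanh(sum(w * x for w, x in zip(p["w_alpha"], v)) + p["u_alpha"] * g)
        h = [math.tanh(p["w_beta"][k] * v[k] + p["u_beta"] * h[k]) for k in range(d)]
        scores.append(p["w_e"] * g)                          # visit-level score e_j
        gates.append([math.tanh(p["w_b"] * x) for x in h])   # dimension-level gate beta_j
    scores.reverse()  # realign both lists with chronological visit order
    gates.reverse()
    alpha = softmax(scores)
    context = [
        sum(alpha[j] * gates[j][k] * visits[j][k] for j in range(len(visits)))
        for k in range(d)
    ]
    return context, alpha, gates

visits = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy visit embeddings, oldest first
params = {"w_alpha": [0.5, -0.3], "u_alpha": 0.8, "w_beta": [1.0, 0.7],
          "u_beta": 0.5, "w_e": 2.0, "w_b": 1.5}
context, alpha, beta = retain_attention(visits, params)
```

Since the context vector is a weighted sum, alpha shows which visits drove the prediction and beta shows which embedding dimensions within each visit did.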
StageNet EHR

Stage-aware LSTM that models disease progression stages from irregular time-series EHR data (Gao et al., 2020).

from pyhealth.models import StageNet
model = StageNet(dataset=samples)
StageNetMHA EHR

StageNet variant with multi-head attention over clinical stages for richer temporal modeling of disease trajectories.

from pyhealth.models import StageNetMHA
model = StageNetMHA(dataset=samples)
AdaCare EHR

Adaptive recurrent model with feature recalibration and multi-scale temporal convolution for clinical time-series (Ma et al., 2020).

from pyhealth.models import AdaCare
model = AdaCare(dataset=samples)
ConCare EHR

Context-aware attention model that uses demographic features to guide multi-head attention over clinical features (Ma et al., 2020).

from pyhealth.models import ConCare
model = ConCare(dataset=samples)
Agent EHR

Attention-based model with global-to-local information gathering for learning patient representations from irregular clinical visits.

from pyhealth.models import Agent
model = Agent(dataset=samples)
GRASP EHR

Graph-based similarity retrieval model that learns from both local patient features and similar patients in the dataset (Zhang et al., 2021).

from pyhealth.models import GRASP
model = GRASP(dataset=samples)
Deepr EHR

Convolutional network over sequences of medical codes, with visits separated by tokens that encode the time gap between them, so that the convolution learns local clinical motifs for risk prediction (Nguyen et al., 2016).

from pyhealth.models import Deepr
model = Deepr(dataset=samples)
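The Deepr paper (Nguyen et al., 2016) flattens a patient record into a single token sequence in which visits are separated by special tokens encoding the time gap, so a 1D convolution can pick up local motifs. The sketch below covers only that sequence-construction step. The gap buckets, token names such as `GAP_1-3M`, and within-visit sorting are hypothetical choices, not PyHealth's preprocessing, and the convolution and max-pooling stages are omitted.

```python
from datetime import date

# Hypothetical gap buckets: (upper bound in days, token). The paper's exact scheme may differ.
GAP_BUCKETS = [(30, "GAP_0-1M"), (90, "GAP_1-3M"), (180, "GAP_3-6M"), (365, "GAP_6-12M")]

def gap_token(days):
    """Map a gap in days to a discrete interval token."""
    for upper, token in GAP_BUCKETS:
        if days <= upper:
            return token
    return "GAP_12M+"

def to_token_sequence(visits):
    """visits: list of (visit_date, [codes]) sorted by date -> flat token list."""
    tokens = []
    for i, (visit_date, codes) in enumerate(visits):
        if i > 0:
            tokens.append(gap_token((visit_date - visits[i - 1][0]).days))
        tokens.extend(sorted(codes))  # deterministic order within a visit
    return tokens

visits = [
    (date(2020, 1, 1), ["I10", "E11.9"]),
    (date(2020, 2, 15), ["N18.3"]),   # 45 days later
    (date(2021, 6, 1), ["I10"]),      # 472 days later
]
tokens = to_token_sequence(visits)
# ['E11.9', 'I10', 'GAP_1-3M', 'N18.3', 'GAP_12M+', 'I10']
```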
EHRMamba EHR

State-space model (Mamba/SSM) adapted for EHR sequences — efficient linear-time alternative to transformers for long patient histories.

from pyhealth.models import EHRMamba
model = EHRMamba(dataset=samples)
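The linear-time property comes from replacing pairwise attention scores with a recurrent state update that touches each step once. Below is a minimal sketch of a scalar selective state-space scan. It assumes a single channel, a fixed negative decay `a`, and a step size `delta` that is the only input-dependent quantity. Mamba also makes its B and C projections input-dependent, uses vector states, and runs a hardware-aware parallel scan, so this illustrates the idea rather than EHRMamba's implementation.

```python
import math

def softplus(x):
    return math.log1p(math.exp(x))

def selective_scan(xs, a=-1.0, b=1.0, c=1.0, w_delta=1.0, bias_delta=0.0):
    """Scalar selective state-space scan over a sequence xs.

    One left-to-right pass: O(len(xs)) time and O(1) state, versus the
    O(len(xs)^2) pairwise scores of full self-attention.
    """
    assert a < 0, "a < 0 keeps the discretized decay inside (0, 1)"
    h, ys = 0.0, []
    for x in xs:
        delta = softplus(w_delta * x + bias_delta)  # input-dependent step size ("selective")
        a_bar = math.exp(delta * a)                 # discretized decay for this step
        h = a_bar * h + delta * b * x               # simplified (Euler-style) input term
        ys.append(c * h)
    return ys

ys = selective_scan([1.0, 0.0, 0.0])
```

With these defaults, a zero input gives delta = ln 2, so the state halves at each quiet step. The input therefore controls how quickly past information is forgotten.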
JambaEHR EHR

Hybrid Jamba architecture (SSM + attention layers) for EHR, balancing the efficiency of Mamba with the expressiveness of transformers.

from pyhealth.models import JambaEHR
model = JambaEHR(dataset=samples)
SafeDrug Drug Rec

Drug recommendation with molecular structure and drug-drug interaction constraints for safe prescription (Yang et al., 2021).

from pyhealth.models import SafeDrug
model = SafeDrug(dataset=samples)
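A drug-drug interaction (DDI) constraint can be expressed as a penalty on the predicted probabilities of interacting drug pairs, sum over pairs of A[i][j] * p_i * p_j. Below is a minimal sketch that assumes a toy symmetric 0/1 DDI adjacency matrix over three drugs. SafeDrug combines a term of this kind with its prediction loss under its own weighting schedule, which is not reproduced here. `ddi_rate` computes the fraction of recommended drug pairs that interact, a common safety metric for this task.

```python
from itertools import combinations

def ddi_penalty(probs, ddi):
    """Soft count of interacting pairs: sum over unordered pairs of A[i][j] * p_i * p_j."""
    n = len(probs)
    return sum(ddi[i][j] * probs[i] * probs[j] for i, j in combinations(range(n), 2))

def ddi_rate(selected, ddi):
    """Fraction of selected drug pairs that appear in the DDI graph."""
    pairs = list(combinations(sorted(selected), 2))
    if not pairs:
        return 0.0
    return sum(ddi[i][j] for i, j in pairs) / len(pairs)

ddi = [[0, 0, 1],
       [0, 0, 0],
       [1, 0, 0]]            # toy graph: drugs 0 and 2 interact
probs = [0.9, 0.8, 0.5]      # toy predicted probabilities
penalty = ddi_penalty(probs, ddi)   # 0.9 * 0.5
rate = ddi_rate({0, 1, 2}, ddi)     # 1 of 3 pairs interacts
```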
GAMENet Drug Rec

Graph augmented memory network leveraging drug knowledge graphs and patient history for medication recommendation (Shang et al., 2019).

from pyhealth.models import GAMENet
model = GAMENet(dataset=samples)
MICRON Drug Rec

Medication change prediction network that models prescription changes between visits via residual learning (Yang et al., 2021).

from pyhealth.models import MICRON
model = MICRON(dataset=samples)
MoleRec Drug Rec

Molecule-level drug recommendation with substructure-aware representation learning from drug molecular graphs (Yang et al., 2023).

from pyhealth.models import MoleRec
model = MoleRec(dataset=samples)
SparcNet Signal

Convolutional network built from densely connected 1D blocks for EEG and biosignal classification, with strong performance on sleep staging, abnormality detection, and EEG event classification.

from pyhealth.models import SparcNet
model = SparcNet(dataset=samples)
BIOT Signal

Biosignal foundation model with tokenized EEG patch representations and cross-dataset transfer learning capabilities (Yang et al., 2023).

from pyhealth.models import BIOT
model = BIOT(dataset=samples)
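Splitting each channel into fixed-length segments lets recordings with different channel counts and lengths share one flat token sequence, which is the idea behind BIOT's cross-dataset learning. Below is a minimal sketch of that segmentation step, assuming a hypothetical `patch_len`/`hop` scheme and toy integer samples. BIOT additionally converts each segment into spectral features and adds channel and position embeddings. Those steps are omitted here, and this is not PyHealth's code.

```python
def tokenize_biosignal(signal, patch_len, hop):
    """signal: list of channels, each a list of samples.

    Returns (channel_index, start, segment) tuples with all channels flattened
    into one token sequence; channel_index and start would later feed channel
    and position embeddings.
    """
    tokens = []
    for ch, samples in enumerate(signal):
        for start in range(0, len(samples) - patch_len + 1, hop):
            tokens.append((ch, start, samples[start:start + patch_len]))
    return tokens

signal = [list(range(8)), list(range(10, 18))]  # two toy channels, 8 samples each
tokens = tokenize_biosignal(signal, patch_len=4, hop=2)
```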
ContraWR Signal

Contrastive learning for wearable and EEG signals — self-supervised pretraining for sleep staging and physiological classification (Yang et al., 2023).

from pyhealth.models import ContraWR
model = ContraWR(dataset=samples)
TCN Signal

Temporal Convolutional Network with dilated causal convolutions for time-series classification — efficient and parallelizable.

from pyhealth.models import TCN
model = TCN(dataset=samples)
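Below is a minimal sketch of one dilated causal convolution. The output at step t reads only x[t], x[t - d], x[t - 2d], and so on, so no future samples leak in, and every output step can be computed independently, which is why the layer parallelizes. The kernel values are toy numbers and left zero-padding is an assumption.

```python
def dilated_causal_conv1d(x, kernel, dilation):
    """y[t] = sum_k kernel[k] * x[t - k * dilation], treating negative indices as 0."""
    y = []
    for t in range(len(x)):
        acc = 0.0
        for k, w in enumerate(kernel):
            idx = t - k * dilation
            if idx >= 0:  # causal: only current and past samples
                acc += w * x[idx]
        y.append(acc)
    return y

x = [1, 2, 3, 4, 5, 6]
y = dilated_causal_conv1d(x, kernel=[1, 1], dilation=2)  # y[t] = x[t] + x[t - 2]
```

Stacking layers with kernel size 2 and dilations 1, 2, 4 gives a receptive field of 1 + (2 - 1) * (1 + 2 + 4) = 8 steps, so coverage grows exponentially with depth.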
GNN Graph

General-purpose graph neural network for clinical knowledge graphs, patient similarity networks, and drug interaction graphs.

from pyhealth.models import GNN
model = GNN(dataset=samples, conv_type="GAT")
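The `conv_type="GAT"` argument suggests a graph-attention convolution, in which each node aggregates its neighbors using learned, softmax-normalized attention weights. Below is a minimal sketch of one single-head GAT-style layer. For brevity it assumes an identity weight matrix and no output nonlinearity, and the graph, features, and attention vector are toy values. It is not PyHealth's implementation.

```python
import math

def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x

def gat_layer(features, adjacency, attn):
    """One single-head graph-attention layer with an identity weight matrix.

    features: list of node vectors; adjacency: dict node -> neighbor list;
    attn: attention vector of length 2 * d scoring the concatenation [h_i || h_j].
    """
    d = len(features[0])
    out = []
    for i, h_i in enumerate(features):
        nbrs = [i] + [j for j in adjacency.get(i, []) if j != i]  # add a self-loop
        scores = [leaky_relu(sum(a * v for a, v in zip(attn, h_i + features[j]))) for j in nbrs]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        alphas = [e / sum(exps) for e in exps]  # softmax over the neighborhood
        out.append([sum(al * features[j][k] for al, j in zip(alphas, nbrs)) for k in range(d)])
    return out

features = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]  # node 2 has no neighbors
adjacency = {0: [1], 1: [0]}
out = gat_layer(features, adjacency, attn=[0.0, 0.0, 0.0, 0.0])
```

With an all-zero attention vector every neighbor scores equally, so the layer reduces to neighborhood averaging. A learned attention vector lets the model weight some neighbors, such as strongly related codes, above others.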
GraphCare Graph

Knowledge graph-enhanced patient representation that integrates clinical ontologies (ICD, ATC) for personalized healthcare prediction (Jiang et al., 2023).

from pyhealth.models import GraphCare
model = GraphCare(dataset=samples)
CNN Image

Convolutional neural network for medical image classification — configurable depth and normalization for chest X-ray and other imaging tasks.

from pyhealth.models import CNN
model = CNN(dataset=samples)
TorchvisionModel Image

Wrapper for any torchvision architecture (ResNet, ViT, EfficientNet, DenseNet) as a drop-in PyHealth model, with support for pre-trained weights.

from pyhealth.models import TorchvisionModel
model = TorchvisionModel(dataset=samples,
    backbone="resnet50")
VAE Image

Variational autoencoder for medical image generation and latent-space representation learning — supports chest X-ray synthesis.

from pyhealth.models import VAE
model = VAE(dataset=samples, latent_dim=128)
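Two pieces make a VAE trainable: sampling the latent code through the reparameterization z = mu + sigma * eps, and a KL term that pulls the approximate posterior toward a standard normal. Below is a minimal pure-Python sketch of both with toy inputs. In a real model, mu and logvar come from the encoder, and the total loss adds a reconstruction term, which is omitted here.

```python
import math
import random

def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps, eps ~ N(0, 1), with sigma = exp(0.5 * logvar).

    Writing the sample this way is what lets an autodiff framework pass
    gradients through to mu and logvar.
    """
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]

def kl_to_standard_normal(mu, logvar):
    """KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))

rng = random.Random(0)
z = reparameterize([0.0, 1.0], [0.0, -2.0], rng)
```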
GAN Image

Generative adversarial network for medical image synthesis and data augmentation — includes generator and discriminator for chest X-ray generation.

from pyhealth.models import GAN
model = GAN(dataset=samples)
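The generator and discriminator are trained against each other with binary cross-entropy. Below is a minimal sketch of the standard losses, written over discriminator output probabilities and using the common non-saturating generator objective. PyHealth's GAN may use a different formulation.

```python
import math

def discriminator_loss(d_real, d_fake):
    """BCE that pushes D(real) toward 1 and D(G(z)) toward 0; inputs are probabilities in (0, 1)."""
    real = -sum(math.log(p) for p in d_real) / len(d_real)
    fake = -sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real + fake

def generator_loss(d_fake):
    """Non-saturating generator loss: minimize -log D(G(z))."""
    return -sum(math.log(p) for p in d_fake) / len(d_fake)

# A discriminator that cannot tell real from fake outputs 0.5 everywhere.
d_loss = discriminator_loss([0.5], [0.5])
g_loss = generator_loss([0.5])
```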
TransformersModel Text

HuggingFace model wrapper — plug in any BERT, ClinicalBERT, BioBERT, or LLM checkpoint as a PyHealth model for clinical NLP tasks.

from pyhealth.models import TransformersModel
model = TransformersModel(dataset=samples,
    model_name="emilyalsentzer/Bio_ClinicalBERT")
TextEmbeddingModel Text

Extracts fixed-size text embeddings from clinical notes using a HuggingFace encoder, feeding downstream PyHealth models.

from pyhealth.models import TextEmbeddingModel
model = TextEmbeddingModel(dataset=samples,
    model_name="bert-base-uncased")
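A fixed-size embedding is typically obtained by pooling the encoder's per-token vectors. Below is a minimal sketch of masked mean pooling. It assumes toy token vectors and a HuggingFace-style attention mask, where 1 marks a real token and 0 marks padding. Whether PyHealth pools this way or uses the [CLS] vector is an assumption not confirmed here.

```python
def masked_mean_pool(token_embeddings, attention_mask):
    """Average the token vectors where mask == 1, excluding padding, to get one fixed-size vector."""
    d = len(token_embeddings[0])
    kept = [vec for vec, m in zip(token_embeddings, attention_mask) if m == 1]
    if not kept:
        return [0.0] * d
    return [sum(vec[k] for vec in kept) / len(kept) for k in range(d)]

tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]  # last row is a padding position
embedding = masked_mean_pool(tokens, [1, 1, 0])
```

Excluding padding matters: without the mask, the padding row would dominate the average in this example.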
SDOH Text

Social determinants of health model that extracts SDOH factors from unstructured clinical notes and integrates them into risk prediction.

from pyhealth.models import SDOH
model = SDOH(dataset=samples)
UnifiedMultimodalEmbeddingModel Multimodal

Unified temporal embedding model for simultaneous EHR codes, clinical text, medical images, and biosignals — the backbone of PyHealth 2.0's multimodal API.

from pyhealth.models import UnifiedMultimodalEmbeddingModel
model = UnifiedMultimodalEmbeddingModel(
    dataset=samples)
VisionEmbeddingModel Multimodal

Vision encoder bridge that wraps torchvision backbones as temporal feature processors compatible with the unified multimodal pipeline.

from pyhealth.models import VisionEmbeddingModel
model = VisionEmbeddingModel(dataset=samples,
    backbone="vit_b_16")
MedLink Multimodal

Patient record linkage model combining structured EHR features and free-text notes to identify duplicate patient records across data sources.

from pyhealth.models import MedLink
model = MedLink(dataset=samples)

Need a custom model?

PyHealth models inherit from BaseModel. A custom model that subclasses BaseModel and returns the standard output dictionary from forward() (loss, y_prob, y_true) works with the Trainer, all metrics, and explainability tools without extra wiring.
