PyHealth Models
33+ Clinical ML Models
Production-ready models spanning EHR sequence learning, drug recommendation, biosignal analysis, graph neural networks, medical imaging, and clinical NLP — all with a unified training API.
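The unified training API means every model on this page drops into the same train/evaluate loop. A minimal sketch of that workflow, assuming a sample dataset (`samples`, as used throughout this page) has already been built and using PyHealth's standard `split_by_patient`, `get_dataloader`, and `Trainer` helpers — check the API reference for the exact signatures in your installed version:

```python
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.models import Transformer
from pyhealth.trainer import Trainer

# Split by patient (not by visit) to avoid leakage across splits
train_ds, val_ds, test_ds = split_by_patient(samples, [0.8, 0.1, 0.1])
train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)
test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)

# Any model below can be swapped in here unchanged
model = Transformer(dataset=samples)
trainer = Trainer(model=model)
trainer.train(train_dataloader=train_loader,
              val_dataloader=val_loader,
              epochs=10)
print(trainer.evaluate(test_loader))
```

Swapping `Transformer` for `RETAIN`, `SafeDrug`, or any other model on this page leaves the rest of the loop untouched.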
All Models (34)
Full API reference →
Gated recurrent network (GRU/LSTM) for sequential EHR data — a strong, interpretable baseline for longitudinal clinical tasks.
from pyhealth.models import RNN
model = RNN(dataset=samples, rnn_type="GRU")
Self-attention transformer encoder for EHR sequences — captures long-range dependencies across clinical visits.
from pyhealth.models import Transformer
model = Transformer(dataset=samples)
Multi-layer perceptron over aggregated patient features — a fast, simple non-sequential baseline.
from pyhealth.models import MLP
model = MLP(dataset=samples, hidden_dim=256)
Logistic regression over bag-of-codes features — the standard interpretable clinical baseline.
from pyhealth.models import LogisticRegression
model = LogisticRegression(dataset=samples)
Reverse-time attention network providing visit-level and code-level interpretability for clinical risk prediction (Choi et al., 2016).
from pyhealth.models import RETAIN
model = RETAIN(dataset=samples)
Stage-aware LSTM that models disease progression stages from irregular time-series EHR data (Gao et al., 2020).
from pyhealth.models import StageNet
model = StageNet(dataset=samples)
StageNet variant with multi-head attention over clinical stages for richer temporal modeling of disease trajectories.
from pyhealth.models import StageNetMHA
model = StageNetMHA(dataset=samples)
Adaptive recurrent model with feature recalibration and multi-scale temporal convolution for clinical time-series (Ma et al., 2020).
from pyhealth.models import AdaCare
model = AdaCare(dataset=samples)
Context-aware attention model that uses demographic features to guide multi-head attention over clinical features (Ma et al., 2020).
from pyhealth.models import ConCare
model = ConCare(dataset=samples)
Attention-based model with global-to-local information gathering for learning patient representations from irregular clinical visits.
from pyhealth.models import Agent
model = Agent(dataset=samples)
Graph-based similarity retrieval model that learns from both local patient features and similar patients in the dataset (Zhang et al., 2021).
from pyhealth.models import GRASP
model = GRASP(dataset=samples)
Deep representation learning from clinical records using sparse ICD code embeddings with temporal decay (Nguyen et al., 2016).
from pyhealth.models import DeepR
model = DeepR(dataset=samples)
State-space model (Mamba/SSM) adapted for EHR sequences — efficient linear-time alternative to transformers for long patient histories.
from pyhealth.models import EHRMamba
model = EHRMamba(dataset=samples)
Hybrid Jamba architecture (SSM + attention layers) for EHR, balancing the efficiency of Mamba with the expressiveness of transformers.
from pyhealth.models import JambaEHR
model = JambaEHR(dataset=samples)
Drug recommendation with molecular structure and drug-drug interaction constraints for safe prescription (Yang et al., 2021).
from pyhealth.models import SafeDrug
model = SafeDrug(dataset=samples)
Graph-augmented memory network that leverages drug knowledge graphs and patient history for medication recommendation (Shang et al., 2019).
from pyhealth.models import GAMENet
model = GAMENet(dataset=samples)
Medication change prediction network that models prescription changes between visits via residual learning (Yang et al., 2021).
from pyhealth.models import MICRON
model = MICRON(dataset=samples)
Molecule-level drug recommendation with substructure-aware representation learning from drug molecular graphs (Yang et al., 2023).
from pyhealth.models import MoleRec
model = MoleRec(dataset=samples)
Sparse neural network for EEG and biosignal classification — strong performance on sleep staging, abnormality detection, and EEG events.
from pyhealth.models import SparcNet
model = SparcNet(dataset=samples)
Biosignal foundation model with tokenized EEG patch representations and cross-dataset transfer learning capabilities (Yang et al., 2023).
from pyhealth.models import BIOT
model = BIOT(dataset=samples)
Contrastive learning for wearable and EEG signals — self-supervised pretraining for sleep staging and physiological classification (Yang et al., 2023).
from pyhealth.models import ContraWR
model = ContraWR(dataset=samples)
Temporal Convolutional Network with dilated causal convolutions for time-series classification — efficient and parallelizable.
from pyhealth.models import TCN
model = TCN(dataset=samples)
General-purpose graph neural network for clinical knowledge graphs, patient similarity networks, and drug interaction graphs.
from pyhealth.models import GNN
model = GNN(dataset=samples, conv_type="GAT")
Knowledge graph-enhanced patient representation that integrates clinical ontologies (ICD, ATC) for personalized healthcare prediction (Jiang et al., 2023).
from pyhealth.models import GraphCare
model = GraphCare(dataset=samples)
Convolutional neural network for medical image classification — configurable depth and normalization for chest X-ray and other imaging tasks.
from pyhealth.models import CNN
model = CNN(dataset=samples)
Wrapper for any torchvision architecture (ResNet, ViT, EfficientNet, DenseNet) as a drop-in PyHealth model with pre-trained weights support.
from pyhealth.models import TorchvisionModel
model = TorchvisionModel(dataset=samples, backbone="resnet50")
Variational autoencoder for medical image generation and latent-space representation learning — supports chest X-ray synthesis.
from pyhealth.models import VAE
model = VAE(dataset=samples, latent_dim=128)
Generative adversarial network for medical image synthesis and data augmentation — includes generator and discriminator for chest X-ray generation.
from pyhealth.models import GAN
model = GAN(dataset=samples)
HuggingFace model wrapper — plug in any BERT, ClinicalBERT, BioBERT, or LLM checkpoint as a PyHealth model for clinical NLP tasks.
from pyhealth.models import TransformersModel
model = TransformersModel(dataset=samples, model_name="emilyalsentzer/Bio_ClinicalBERT")
Extracts fixed-size text embeddings from clinical notes using a HuggingFace encoder, for feeding into downstream PyHealth models.
from pyhealth.models import TextEmbeddingModel
model = TextEmbeddingModel(dataset=samples, model_name="bert-base-uncased")
Social determinants of health model that extracts SDOH factors from unstructured clinical notes and integrates them into risk prediction.
from pyhealth.models import SDOH
model = SDOH(dataset=samples)
Unified temporal embedding model that jointly embeds EHR codes, clinical text, medical images, and biosignals — the backbone of PyHealth 2.0's multimodal API.
from pyhealth.models import UnifiedMultimodalEmbeddingModel
model = UnifiedMultimodalEmbeddingModel(dataset=samples)
Vision encoder bridge that wraps torchvision backbones as temporal feature processors compatible with the unified multimodal pipeline.
from pyhealth.models import VisionEmbeddingModel
model = VisionEmbeddingModel(dataset=samples, backbone="vit_b_16")
Patient record linkage model combining structured EHR features and free-text notes to identify duplicate patient records across data sources.
from pyhealth.models import MedLink
model = MedLink(dataset=samples)
Need a custom model?
PyHealth models inherit from BaseModel. Your custom model immediately works with the Trainer, all metrics, and explainability tools — no extra wiring needed.
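A minimal sketch of the pattern — the class name, the `"codes"`/`"label"` batch keys, and the vocabulary size here are all illustrative placeholders; consult the `BaseModel` docstring for the exact constructor arguments and forward contract your PyHealth version expects:

```python
import torch.nn as nn
import torch.nn.functional as F
from pyhealth.models import BaseModel

class MeanPoolModel(BaseModel):
    """Toy custom model: embed codes, mean-pool the sequence, classify."""

    def __init__(self, dataset, embedding_dim=128, num_classes=2, **kwargs):
        super().__init__(dataset=dataset, **kwargs)
        # Placeholder vocabulary size; derive it from the dataset in practice.
        self.embedding = nn.Embedding(1000, embedding_dim)
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, **batch):
        x = self.embedding(batch["codes"])   # (batch, seq_len, embedding_dim)
        logits = self.fc(x.mean(dim=1))      # pool over the visit sequence
        loss = F.cross_entropy(logits, batch["label"])
        # The Trainer consumes a dict with at least a "loss" entry;
        # "y_prob"/"y_true" feed the metrics and explainability tools.
        return {"loss": loss,
                "y_prob": logits.softmax(dim=-1),
                "y_true": batch["label"]}
```

Because the subclass returns the same output dict as the built-in models, `Trainer(model=MeanPoolModel(dataset=samples))` works without further changes.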