This project explores how generative models can help address class imbalance in bone tumor X-ray classification.
We use the BTXRD dataset, which contains X-ray images of different primary bone tumor entities.
- Implement a ResNet-based CNN as a baseline model for tumor classification.
- Apply and analyze class imbalance handling techniques such as Weighted Loss, Focal Loss, and Data Augmentation.
- Use generative models (e.g., GANs or Diffusion Models) to create synthetic X-ray images and evaluate their impact on classification performance.
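The imbalance-handling idea behind focal loss can be illustrated with a few lines of standalone arithmetic (an illustrative sketch only — the actual loss implementations live in the training code and are selected via `--loss-fn`):

```python
import math

def ce(p_true):
    """Cross-entropy for one sample, given the softmax probability of the true class."""
    return -math.log(p_true)

def focal(p_true, gamma=2.0):
    """Focal loss: cross-entropy scaled by (1 - p)^gamma, shrinking easy examples."""
    return (1.0 - p_true) ** gamma * ce(p_true)

# A well-classified (often majority-class) sample with p = 0.9 is strongly
# down-weighted, while a hard minority-class sample with p = 0.1 keeps
# most of its cross-entropy loss.
for p in (0.9, 0.5, 0.1):
    print(f"p={p:.1f}  CE={ce(p):.3f}  focal(gamma=2)={focal(p):.3f}")
```

Weighted variants (`wce`, `wfocal`) additionally multiply each sample's loss by a per-class weight, typically the inverse class frequency.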
We focus on classifying seven tumor types:
- Osteochondroma
- Osteosarcoma
- Multiple Osteochondromas
- Simple Bone Cyst
- Giant Cell Tumor
- Synovial Osteochondroma
- Osteofibroma
This project is part of the Applied Deep Learning in Medicine (ADLM) course at the
Technical University of Munich (TUM), organized by the
Chair of Artificial Intelligence in Healthcare and Medicine.
First, create and activate a virtual environment:

Option 1: Using Conda

```bash
conda create -n bone-tumour-classification python=3.12
conda activate bone-tumour-classification
```

Option 2: Using venv

```bash
python3 -m venv venv
source venv/bin/activate   # On Linux/macOS
# or
venv\Scripts\activate      # On Windows
```

Then install dependencies:

```bash
pip install -r requirements.txt
```
This project supports optional logging to Weights & Biases (W&B). To enable W&B logging, set the following values in `config.py`:

```python
WANDB_ENTITY = "your-wandb-entity"
WANDB_PROJECT = "your-wandb-project"
```

To disable W&B logging, leave these values as empty strings:

```python
WANDB_ENTITY = ""
WANDB_PROJECT = ""
```

When W&B is disabled, training and testing still work normally, but metrics are only logged locally (TensorBoard) and printed to the console.
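A minimal sketch of how such a switch can work (a hypothetical helper, not code from this repo — the repo's actual check lives in its training scripts):

```python
def maybe_init_wandb(entity: str, project: str):
    """Initialize a W&B run only when both config values are non-empty.

    Returns the run object, or None when W&B logging is disabled.
    """
    if not (entity and project):
        return None  # disabled: metrics go to TensorBoard/console only
    import wandb  # imported lazily so the package is only needed when enabled
    return wandb.init(entity=entity, project=project)

# With the empty-string defaults, logging stays off:
print(maybe_init_wandb("", ""))  # None
```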
You can download the BTXRD dataset here
Place the BTXRD dataset under the following paths relative to the project root:
```
data/
  dataset/
    BTXRD/
      images/        # Original X-ray images (e.g., IMG000123.jpeg)
      Annotations/   # JSON annotation files (same basenames as images)
```

(Optional) folders created by scripts in this repo:

```
data/
  dataset/
    final_patched_BTXRD/   # Patches extracted from annotations (the final dataset used for training and testing)
    squared_padded/        # Padded originals for 106 special cases (optional)
```
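To sanity-check the placement, a small helper like the following can list images that lack a same-named annotation file (a hypothetical utility, not part of the repo; paths assume the layout above):

```python
from pathlib import Path

def unmatched_images(image_dir: str, annotation_dir: str) -> list[str]:
    """Return image filenames that have no same-named JSON annotation."""
    annotations = {p.stem for p in Path(annotation_dir).glob("*.json")}
    return sorted(
        p.name for p in Path(image_dir).iterdir()
        if p.is_file() and p.stem not in annotations
    )

# Usage (after placing the dataset):
# unmatched_images("data/dataset/BTXRD/images", "data/dataset/BTXRD/Annotations")
```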
The following files were used to identify and resolve issues that occurred during the preprocessing of the BTXRD dataset. These validation and debugging scripts ensured data integrity and enabled the creation of the final `final_patched_BTXRD` dataset.
- `bounding_box_checker.py`: Adds images whose bounding box exceeds the original image size to the CSV file, and reports whether a given image's bounding box exceeds the image size.
- `bounding_box_visualization.py`: Helps visualize bounding boxes.
- `tumour_bounding_box.py`: Provides a function that computes a square bounding box (with optional margin) around all given tumour points.
- `pad_unsquared.py`: Creates the `squared_padded` folder in the dataset folder, containing padded images with their padding info in the CSV file.
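The square-box idea behind `tumour_bounding_box.py` can be sketched as follows (an independent reimplementation of the concept, not the repo's exact code):

```python
def square_bbox(points, margin=0.0):
    """Smallest square box (x0, y0, x1, y1) covering all (x, y) tumour points.

    The tight axis-aligned box is expanded to a square around its centre,
    then optionally grown by `margin` (a fraction of the side length).
    """
    xs = [x for x, _ in points]
    ys = [y for _, y in points]
    cx, cy = (min(xs) + max(xs)) / 2, (min(ys) + max(ys)) / 2
    side = max(max(xs) - min(xs), max(ys) - min(ys)) * (1 + margin)
    half = side / 2
    return (cx - half, cy - half, cx + half, cy + half)

# A wide, flat point cloud still yields a square patch:
print(square_bbox([(0, 0), (10, 2)]))  # (0.0, -4.0, 10.0, 6.0)
```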
```bash
python data/btxrd_bounding_box_dataset_extractor.py
```

This creates `data/dataset/final_patched_BTXRD/` from `BTXRD/images` + `BTXRD/Annotations`.
```bash
python train.py [arguments]
```

Arguments:

- `--learning-rate` (`--learning_rate`, float, default `9.502991994821847e-05`)
- `--weight-decay` (`--weight_decay`, float, default `1e-05`)
- `--batch-size` (`--batch_size`, int, default `32`)
- `--epochs` (int, default `30`)
- `--dropout` (float, default `0.3498137514984224`)
- `--scheduler-factor` (`--scheduler_factor`, float, default `0.5`)
- `--scheduler-patience` (`--scheduler_patience`, int, default `4`)
- `--test-size` (`--test_size`, float, default `0.2`)
- `--random-state` (`--random_state`, int, default `42`)
- `--loss-fn` (`--loss_fn`, choices: `ce|wce|focal|wfocal`, default `wce`)
- `--focal-gamma` (`--focal_gamma`, float, default `2.924897740591147`)
- `--apply-minority-aug` (`--apply_minority_aug`, bool, default `False`)
- `--early-stop` (`--early_stop`, bool, default `False`)
- `--early-stop-patience` (`--early_stop_patience`, int, default `5`)
- `--early-stop-min-delta` (`--early_stop_min_delta`, float, default `0.0`)
- `--run-name-prefix` (`--run_name_prefix`, str, default `resnet_gan_15800`)
- `--num-classes` (`--num_classes`, int, default `7`)
- `--architecture` (choices: `resnet34|resnet50|densenet121`, default `resnet34`)
- `--trainwsyn` (str, path to split JSON, default `None`)
- `--image-dir` (`--image_dir`, str, default `data/dataset/final_patched_BTXRD`)
- `--json-dir` (`--json_dir`, str, default `data/dataset/BTXRD/Annotations`)
Example:
```bash
python train.py --architecture resnet50 --loss-fn wce --early-stop true --early-stop-patience 10 --early-stop-min-delta 0.001 --batch-size 32 --epochs 50
```

Notes:

- Labels are read directly from annotation JSON files.
- Checkpoints are saved under `checkpoints/<run_name>/`.
```bash
python test.py --run-name <RUN_NAME> [--architecture resnet34|resnet50|densenet121]
```

Arguments:

- `--run-name` (required): Run/checkpoint folder name under `checkpoints/`, matching the W&B display name.
- `--architecture` (optional): Backbone used during training. Default is `resnet34`.
Example:
```bash
python test.py --run-name resnet_gan_15800_wce_noaug_2026-02-12_14-30-00 --architecture resnet50
```

SupCon pretraining:

```bash
python supcon/train_supcon.py
```

Linear classifier training:

```bash
python supcon/train_linear.py \
  --encoder-path checkpoints_supcon/<run_timestamp>/encoder_supcon.pth \
  --split-path checkpoints_supcon/<run_timestamp>/split.json \
  --dataset-dir data/dataset
```

Evaluation:

```bash
python supcon/eval_supcon.py \
  --encoder-path checkpoints_supcon/<run_timestamp>/encoder_supcon.pth \
  --classifier-path checkpoints_linear/<run_timestamp>/classifier.pth \
  --split-path checkpoints_supcon/<run_timestamp>/split.json \
  --dataset-dir data/dataset \
  --wandb
```

Outputs:

- `checkpoints_supcon/<time>/encoder_supcon.pth`
- `checkpoints_linear/<time>/classifier.pth`
In addition to the separate training scripts, the repository provides a single pipeline script that performs:
- Supervised Contrastive (SupCon) pretraining
- Linear classifier training
- Validation and test evaluation
- Logging to Weights & Biases (W&B)
- Saving model checkpoints and split information
```bash
python supcon/<PIPELINE_SCRIPT_NAME>.py \
  --run-name-prefix <name> \
  --random-state <int> \
  --test-size <float> \
  --val-size <float> \
  --temperature <float> \
  --feature-dim <int> \
  --supcon-lr <float> \
  --supcon-epochs <int> \
  --linear-lr <float> \
  --linear-epochs <int> \
  --apply-minority-aug <true/false> \
  --minority-classes <comma-separated-class-names>
```

Note: Instructions on the different approaches for generating synthetic images are in the sections below.
json_adjuster.py builds a series of training splits and corresponding annotations by mixing synthetic images into the original dataset. It:
- Copies original images and annotations into the output dataset folders.
- Samples synthetic images per class across multiple steps (see `STEPS` in the script).
- Creates JSON annotations for each synthetic image.
- Writes one split file per step: `split_step1.json`, `split_step2.json`, ...
Basic usage:
```bash
python json_adjuster.py \
  --input_split data/baseline_split.json \
  --output_split data/dataset/splits \
  --synthetic_images <PATH_TO_SYNTHETIC_IMAGES> \
  --input_images data/dataset/final_patched_BTXRD \
  --output_images data/dataset/BTXRD_images_new \
  --input_annotations data/dataset/BTXRD/Annotations \
  --output_annotations data/dataset/BTXRD/Annotations_new
```

Adjustable arguments:

- `--output_split`: Output directory for incremental split files, e.g. `data/dataset/splits/` (one per step).
- `--synthetic_images`: Root folder containing class subfolders of generated images (instructions on how to generate these are below).
- `--output_images`: Target folder for originals + synthetic images (e.g. `data/dataset/BTXRD_images_new`).
- `--output_annotations`: Target folder for originals + synthetic JSON annotations (e.g. `data/dataset/BTXRD/Annotations_new`).
Outputs:

- `data/dataset/splits/split_step*.json` with incrementally expanded `train` indices.
- New synthetic images and JSON annotations.
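The incremental per-step sampling can be pictured with this simplified sketch (illustrative only — function and variable names are made up; the authoritative logic is `json_adjuster.py` and its `STEPS` constant):

```python
import random

def sample_per_step(synthetic_by_class, steps, per_class_per_step, seed=42):
    """Build one cumulative synthetic-image selection per step.

    Step k contains everything from step k-1 plus `per_class_per_step`
    fresh images per class, mirroring how split_step1.json,
    split_step2.json, ... grow the train set.
    """
    rng = random.Random(seed)
    # Shuffle each class pool once so prefixes give nested selections.
    pools = {c: rng.sample(imgs, len(imgs)) for c, imgs in synthetic_by_class.items()}
    splits = []
    for step in range(1, steps + 1):
        take = step * per_class_per_step
        splits.append({c: pool[:take] for c, pool in pools.items()})
    return splits

demo = {"osteosarcoma": [f"syn_{i}.png" for i in range(10)]}
steps = sample_per_step(demo, steps=3, per_class_per_step=2)
print([len(s["osteosarcoma"]) for s in steps])  # [2, 4, 6]
```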
You have to add the `--trainwsyn` argument to train with a specific step.
Basic usage:
```bash
python train.py \
  --trainwsyn <SYNTHETIC_SPLIT> \
  --run-name-prefix <RUN_NAME_PREFIX>
```

Key arguments:

- `--trainwsyn`: Path to a synthetic split JSON (e.g. `data/dataset/splits/split_step3.json`). This selects which step's augmented train indices are used.
- `--run-name-prefix`: Prefix for the run/checkpoint name (e.g. `resnet_gan_15800`). The final run name includes this prefix plus a timestamp/settings.
```bash
python test.py --run-name <RUN_NAME> [--architecture resnet34|resnet50|densenet121]
```

Arguments:

- `--run-name` (required): Run/checkpoint folder name under `checkpoints/`, matching the W&B display name.
- `--architecture` (optional): Backbone used during training. Default is `resnet34`.
Example:
```bash
python test.py --run-name resnet_gan_15800_wce_noaug_2026-02-12_14-30-00 --architecture resnet50
```

1. Clone stylegan2-ada-pytorch

Place these folders under `stylegan2-ada-pytorch/data/dataset/`:

- `BTXRD/` (must contain `Annotations/`)
- `final_patched_BTXRD/` (patched images)
- `dataset_split.json` (or adapt `--split-path` below)
Expected layout:
```
stylegan2-ada-pytorch/
  data/
    dataset/
      BTXRD/
        Annotations/
      final_patched_BTXRD/
      dataset_split.json
```
The full BTXRD preprocessing pipeline is handled by one script call via
style_gan_preprocessing.py full-pipeline (resize/sort, index-map creation,
and train-split correction).
```bash
python data/style_gan_preprocessing.py full-pipeline [arguments]
```

Arguments:

- `--image-dir` (path, default: `data/dataset/final_patched_BTXRD`): Input image directory for preprocessing.
- `--json-dir` (path, default: `data/dataset/BTXRD/Annotations`): Directory with BTXRD annotation JSON files.
- `--preprocess-output-dir` (path, default: `data/dataset/BTXRD_resized_sorted_with_anatomical_location`): Output directory for resized and class-sorted images.
- `--target-size` (int, default: `256`): Output square image size.
- `--center-crop` (flag): Center-crop images to square before resizing.
- `--no-dataset-json` (flag): Skip writing `dataset.json` during the preprocess step.
- `--use-anatomical-location` (flag): Prefix tumor labels with anatomical location to create 21 classes.
- `--xlsx-path` (path, default: `data/dataset/BTXRD/dataset.xlsx`): Path to the metadata XLSX used when `--use-anatomical-location` is set.
- `--split-path` (path, default: `data/dataset/dataset_split.json`): Split JSON used to build the index map and keep only train images.
- `--index-map-dataset-dir` (path, default: `data/dataset/final_patched_BTXRD`): Directory used to build the index-to-filename map.
- `--index-map-output-path` (path, default: `data/dataset/final_patched_index_map.json`): Output JSON path for the generated index map.
- `--correct-split-dataset-dir` (path, default: `--preprocess-output-dir`): Directory to apply train-split filtering on.
- `--dry-run` (flag): Preview split correction without deleting files or rewriting `dataset.json`.
Example:
```bash
python data/style_gan_preprocessing.py full-pipeline \
  --image-dir data/dataset/final_patched_BTXRD \
  --json-dir data/dataset/BTXRD/Annotations \
  --preprocess-output-dir data/dataset/BTXRD_resized_sorted \
  --target-size 256 \
  --split-path data/dataset/dataset_split.json \
  --index-map-dataset-dir data/dataset/final_patched_BTXRD \
  --index-map-output-path data/dataset/final_patched_index_map.json \
  --correct-split-dataset-dir data/dataset/BTXRD_resized_sorted
```

Other available subcommands:

- `preprocess`: Only resize/sort images and optionally write `dataset.json`.
- `build-index-map`: Only create `final_patched_index_map.json` from split indices.
- `correct-split`: Only filter the resized dataset to the train split and rewrite `dataset.json`.
```bash
python data/style_gan_preprocessing.py preprocess --help
python data/style_gan_preprocessing.py build-index-map --help
python data/style_gan_preprocessing.py correct-split --help
python data/style_gan_preprocessing.py full-pipeline --help
```

```bash
python data/dataset_tool.py [arguments]
```

Arguments:

- `--source` (required, path): Input dataset path (folder or archive).
- `--dest` (required, path): Output dataset path (folder or archive, e.g. `.zip`).
- `--max-images` (int, optional): Limit the number of images.
- `--resize-filter` (choice: `box|lanczos`, default: `lanczos`): Resize interpolation filter.
- `--transform` (choice: `center-crop|center-crop-wide`, optional): Crop/resize mode.
- `--width` (int, optional): Output width.
- `--height` (int, optional): Output height.
Example:
```bash
python data/dataset_tool.py \
  --source data/dataset/BTXRD_resized_sorted \
  --dest data/btxrd_corrected_dataset.zip \
  --width 256 --height 256 --resize-filter box
```

```bash
python train.py [arguments]
```

Arguments:

- `--outdir` (required): Output directory for training runs.
- `--data` (required): Training dataset path (directory or zip).
- `--gpus` (int, default: `1`): Number of GPUs (power of two).
- `--batch` (int, optional): Override batch size.
- `--gamma` (float, optional): Override R1 gamma.
- `--cond` (bool, default: `false`): Enable conditional training from labels in `dataset.json`.
- `--mirror` (bool, default: `false`): Enable horizontal flips.
- `--aug` (choice: `noaug|ada|fixed`, default: `ada`): Augmentation mode.
- `--cfg` (choice: `auto|stylegan2|paper256|paper512|paper1024|cifar`, default: `auto`): Base configuration.
- `--snap` (int, default: `50`): Snapshot interval in ticks.
- `--resume` (default: `noresume`): Resume from a pickle or predefined source.
- `--kimg` (int, optional): Override training duration.
- `--seed` (int, default: `0`): Random seed.
- `--metrics` (default: `fid50k_full`): Comma-separated metric list, or `none`.
Example:
```bash
python train.py \
  --outdir=/checkpoints/stylegan2ada_cond_train \
  --data=./data/btxrd_train_dataset.zip \
  --gpus=1 \
  --batch=16 \
  --gamma=6 \
  --cond=1 \
  --mirror=1 \
  --aug=ada \
  --cfg=auto \
  --snap=50 \
  --seed=42 \
  --metrics=fid50k_full
```

After training, pick a snapshot from `training-runs/.../network-snapshot-xxxxxx.pkl` and sample images:
```bash
python generate.py [arguments]
```

Arguments:

- `--network` (required): Network pickle path (`.pkl`) or URL.
- `--outdir` (required): Output image directory.
- `--seeds` (required unless `--projected-w` is used): Comma list or range, e.g. `0,1,2` or `0-199`.
- `--trunc` (float, default: `1.0`): Truncation psi.
- `--class` (int, optional): Class ID for conditional models.
- `--noise-mode` (choice: `const|random|none`, default: `const`): Noise handling.
- `--projected-w` (file, optional): Generate from projected latent `W` instead of seeds.
Example:
```bash
python generate.py \
  --outdir out/btxrd_samples \
  --trunc 0.7 \
  --seeds 0-799 \
  --class 0 \
  --network training-runs/stylegan2ada_cond_train/00000-btxrd_train_anatomical_dataset-cond-mirror-auto1-gamma6-batch16-ada/network-snapshot-020000.pkl
```

Notes:

- `--class` is required when training with `--cond true`; class IDs come from `dataset.json`.
- Repeat with different `--class` values to generate each tumor class, or use the `generate.sbatch` file.
- For each class, a minimum of 800 images should be generated for `json_adjuster.py` to work.
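One way to script the per-class runs (a hypothetical convenience loop that only builds the command strings; class IDs 0-6 are assumed to follow the `dataset.json` ordering, and the snapshot path is a placeholder you must fill in):

```python
NETWORK = "training-runs/stylegan2ada_cond_train/.../network-snapshot-020000.pkl"  # placeholder snapshot path

def generate_cmd(class_id, network=NETWORK, n_images=800):
    """Build one conditional generate.py call; --seeds 0-(n-1) yields n images."""
    return (f"python generate.py --outdir out/class_{class_id} "
            f"--trunc 0.7 --seeds 0-{n_images - 1} "
            f"--class {class_id} --network {network}")

for class_id in range(7):  # one command per tumor class
    print(generate_cmd(class_id))
```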
Note: Because the finetuning process is compute-intensive, the following instructions assume access to a Slurm cluster.
```bash
python data/generate_hf_dataset.py
```

2. Clone diffusers and navigate to the cloned repository

```bash
python3 -m venv venv
source venv/bin/activate   # On Linux/macOS
# or
venv\Scripts\activate      # On Windows
pip install .
```

Place `finetuned_latent_diffusion/finetune_latent_diffusion.sh` in `examples/text_to_image` in the diffusers repository.

```bash
bash finetune_latent_diffusion.sh [resolution] [lora_rank] [batch_size] [hf_dataset_path]
```

- `resolution` is the image resolution of generated images. We recommend sticking with 512, and downsampling/upsampling images if needed for the downstream task.
- `lora_rank` is the LoRA rank to use in finetuning. We recommend 16, 32, or 64.
- `batch_size` is the batch size to use in finetuning. We recommend 4 or 8.
- `hf_dataset_path` is the path to the dataset in the Hugging Face format you previously generated. It is in this repository at `/data/dataset/hf_dataset`.
Example:
```bash
bash finetune_latent_diffusion.sh 512 32 4 "../bone-tumour-classification/data/dataset/hf_dataset"
```

```bash
sbatch generate_augmentation_images.sh <model_base> <lora_model_path> <num_images> <tumor_subtype> [use_detailed_prompt] [output_dir]
```

- `model_base`: Base diffusion model to use. Choices: `stable-diffusion` or `roentgen`.
- `lora_model_path`: Path to the LoRA weights directory from finetuning. This is in `/diffusers/examples/text_to_image/`. You can also use a subfolder inside for a specific checkpoint.
- `num_images`: Number of synthetic images to generate per tumor subtype.
- `tumor_subtype`: Which tumor type to generate. Choices: `osteochondroma`, `osteosarcoma`, `multiple_osteochondromas`, `simple_bone_cyst`, `giant_cell_tumor`, `synovial_osteochondroma`, `osteofibroma`, or `all`.
- `use_detailed_prompt` (optional): `true` or `false` (default: `false`). If `true`, randomly samples anatomical locations and views to create more varied prompts.
- `output_dir` (optional): Custom output directory for generated images. This should be specified if you are using `all` as the tumor type, so that you can easily use the generated images for the classifier. This path is used as input for `--synthetic_images` for `json_adjuster.py`.
Example:
```bash
sbatch generate_augmentation_images.sh stable-diffusion "../diffusers/examples/text_to_image/sd-1-5-lora-rank-64-batch-8-resolution-512-2026-02-11/checkpoint-10000" 800 all true generated_images
```

To compute the FID score for the diffusion models, images first need to be generated.
```bash
python generate_fid_samples.py \
  --model_base "$MODEL_BASE" \
  --lora_model_path "$LORA_MODEL_PATH" \
  --output_dir "$OUTPUT_DIR" \
  --n_samples "$N_SAMPLES"
```

- `model_base`: Base diffusion model to use.
- `lora_model_path`: Path to the LoRA weights directory from finetuning.
- `output_dir`: Custom output directory for generated images.
- `n_samples`: Number of images to generate.
And then run evaluate_fid.sh pointing at the right folders with real and fake images.
```bash
python -m pytorch_fid \
  "$REAL_DIR" \
  "$FAKE_DIR" \
  --device cuda:0 \
  --batch-size 32 \
  --num-workers 0
```

- `REAL_DIR`: Flattened folder of real images.
- `FAKE_DIR`: Flattened folder of generated images to evaluate.
If needed, real images can be prepared and parsed by prepare_real_samples.sh.
LPIPS computation is similar to that described in Metric Evaluation. However, since images need to be organized in subfolders by anatomical site and cancer subtype, two additional files are provided to reorganize your generated images.
- For real samples:

```bash
sbatch parse_real_images.sh
```

- For generated samples:

```bash
sbatch parse_diff_images.sh
```

And then compute LPIPS using:
```bash
python lpips_eval.py \
  --real_root "${REAL_ROOT}" \
  --gen_root "${GEN_ROOT}" \
  ${CLASSES:+--classes ${CLASSES}} \
  --pairs "${PAIRS}" \
  --img_size "${IMG_SIZE}" \
  --backbone "${BACKBONE}" \
  --device "${DEVICE}" \
  --seed "${SEED}"
```

Arguments:

- `--real_root` (required): Root directory with real class subfolders.
- `--gen_root` (required): Root directory with generated class subfolders.
- `--classes` (optional): Explicit class list. If omitted, all classes under `real_root` are used.
- `--pairs` (int, default: `10000`): Number of random image pairs per class.
- `--img_size` (int, default: `256`): Resize target for LPIPS input.
- `--backbone` (choice: `alex|vgg|squeeze`, default: `alex`): LPIPS backbone.
- `--device` (default: auto): Device for evaluation (`cuda` if available, else `cpu`).
- `--seed` (int, default: `0`): Random seed for pair sampling.
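The `--pairs` option is a Monte-Carlo estimate of intra-class diversity: it averages a distance over randomly drawn image pairs. A simplified sketch of that pairing logic (with a toy distance standing in for the LPIPS network applied to loaded, resized image tensors):

```python
import random

def mean_pairwise(distance, filenames, n_pairs, seed=0):
    """Estimate the mean pairwise distance within one class folder:
    draw n_pairs random distinct-image pairs and average `distance`."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_pairs):
        a, b = rng.sample(filenames, 2)  # two different images per pair
        total += distance(a, b)
    return total / n_pairs

# Toy distance on fake "images" (plain integers); with LPIPS this would be
# the network distance between the two image tensors.
imgs = list(range(50))
print(mean_pairwise(lambda a, b: abs(a - b), imgs, n_pairs=1000, seed=0))
```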
```bash
python -m custom_latent_diffusion.vae.train
python -m custom_latent_diffusion.vae.train --run-name <RUN_NAME>
python -m custom_latent_diffusion.vae.sample --run-name <RUN_NAME>
```

`<RUN_NAME>` is the name of the run directory you want to sample from, for example `train_vae_2025-12-07_17-36-29`.

```bash
python -m custom_latent_diffusion.diffusion.train --run-name <RUN_NAME>
```

`<RUN_NAME>` is the directory of the VAE train run, for example `train_vae_2025-12-07_17-36-29`.

```bash
python -m custom_latent_diffusion.sample --vae-run-name <VAE_RUN_NAME> --ldm-run-name <LDM_RUN_NAME> --class-name <CLASS_NAME>
```

- `<VAE_RUN_NAME>` is the directory of the VAE train run, for example `train_vae_2025-12-07_17-36-29`.
- `<LDM_RUN_NAME>` is the directory of the diffusion train run, for example `train_ldm_2025-12-07_17-36-29`.
- `<CLASS_NAME>` is the tumor subtype to sample, for example `osteochondroma`.
Work has started on a new custom latent diffusion approach that uses the diffusers library, in the `custom_latent_diffusion_new` folder. However, it remains work in progress and is not yet usable.
This repo provides two complementary metrics to assess generated image quality and diversity:
- LPIPS (intra-class diversity): `lpips_eval.py`
- FID (distribution distance, real vs. generated): `finetuned_latent_diffusion/evaluate_fid.sh`
Run from project root:
```bash
python lpips_eval.py \
  --real_root data/dataset/final_patched_BTXRD \
  --gen_root <GENERATED_ROOT> \
  --pairs 10000 \
  --img_size 256 \
  --backbone alex
```

Arguments:

- `--real_root` (required): Root directory with real class subfolders.
- `--gen_root` (required): Root directory with generated class subfolders.
- `--classes` (optional): Explicit class list. If omitted, all classes under `real_root` are used.
- `--pairs` (int, default: `10000`): Number of random image pairs per class.
- `--img_size` (int, default: `256`): Resize target for LPIPS input.
- `--backbone` (choice: `alex|vgg|squeeze`, default: `alex`): LPIPS backbone.
- `--device` (default: auto): Device for evaluation (`cuda` if available, else `cpu`).
- `--seed` (int, default: `0`): Random seed for pair sampling.
Output:
- Per-class LPIPS mean/std for real and generated sets.
- Macro average across classes.
To run the LPIPS script, images need to be organized in per-class subfolders under `--real_root` and `--gen_root`.
The SLURM script computes FID with pytorch-fid:
```bash
sbatch finetuned_latent_diffusion/evaluate_fid.sh
```

Before running, edit these variables in `finetuned_latent_diffusion/evaluate_fid.sh`:

- `REAL_DIR`: Flattened folder of real images.
- `FAKE_DIR`: Flattened folder of generated images to evaluate.