CNNs or ViT for Medical Imaging?

Bone Fracture Detection & Bone Age Detection

Introduction

As artificial intelligence becomes more deeply integrated into healthcare, medical imaging has emerged as a domain ripe for disruption. Whether the task is detecting fractures or identifying tumors, machine learning algorithms are increasingly pivotal in augmenting the capabilities of radiologists. In this study, we compare two powerful families of machine learning models for medical imaging: Convolutional Neural Networks (CNNs) and Vision Transformers (ViTs).

What are Convolutional Neural Networks (CNNs)?

Convolutional Neural Networks (CNNs) have been a go-to choice for image-related tasks for many years. Rooted in mathematical convolutions, CNNs process images through a series of convolutional layers, pooling layers, and fully connected layers. This enables them to capture various features like edges, textures, and more complex representations. Particularly in medical imaging, CNNs have shown promise in tasks like identifying tumors or analyzing X-rays.
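To make the idea of a convolutional layer concrete, here is a minimal pure-Python sketch of a single 2D convolution (strictly speaking a cross-correlation, which is what deep learning frameworks compute). The function name conv2d_valid, the toy image, and the hand-picked edge kernel are illustrative assumptions and are not taken from the experiment; in a real CNN the kernel values are learned during training.

```python
def conv2d_valid(image, kernel):
    """Slide `kernel` over `image` (lists of lists) with no padding and stride 1."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    output = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            # Weighted sum of the kernel-sized window whose top-left corner is (i, j).
            acc = 0
            for di in range(kh):
                for dj in range(kw):
                    acc += image[i + di][j + dj] * kernel[di][dj]
            row.append(acc)
        output.append(row)
    return output


# Toy 4x6 "image": dark on the left (0), bright on the right (1).
image = [[0, 0, 0, 1, 1, 1] for _ in range(4)]

# Hand-picked kernel that responds to dark-to-bright transitions along a row.
vertical_edge = [[-1, 0, 1],
                 [-1, 0, 1],
                 [-1, 0, 1]]

feature_map = conv2d_valid(image, vertical_edge)
for row in feature_map:
    print(row)
```

The feature map is non-zero only where the window straddles the dark-to-bright boundary, which is the sense in which a convolutional filter "detects" an edge.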

What are Vision Transformers (ViTs)?

Vision Transformers, or ViTs, bring the power of Transformers from the realm of Natural Language Processing (NLP) to computer vision. Unlike CNNs, which build up features hierarchically from local neighborhoods, ViTs divide an image into fixed-size patches and treat them as a sequence. The sequence passes through multiple Transformer layers, in which self-attention lets every patch attend to every other patch, making ViTs highly capable of capturing long-range dependencies. This could be especially beneficial in medical imaging, where small features can have significant implications.
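The patch step can be sketched in a few lines of pure Python. This is only an illustration of the idea, not the code used in the experiment: split_into_patches is a hypothetical helper, the 4x4 image is a toy, and the 224x224 input size used in the final calculation is an assumption (the article does not state the resolution).

```python
def split_into_patches(image, patch_size):
    """Split a 2D image (list of rows) into flattened, non-overlapping patches.

    Patches are returned in row-major order, which defines the sequence order.
    """
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image dimensions must be divisible by patch_size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = [image[top + r][left + c]
                     for r in range(patch_size)
                     for c in range(patch_size)]
            patches.append(patch)
    return patches


# Toy 4x4 single-channel image whose pixel values are 0..15.
toy_image = [[row * 4 + col for col in range(4)] for row in range(4)]
toy_patches = split_into_patches(toy_image, patch_size=2)
print(len(toy_patches))   # number of tokens in the sequence
print(toy_patches[0])     # top-left patch, flattened

# Sequence length for an ASSUMED 224x224 input with 16x16 patches.
patches_per_side = 224 // 16
num_patches = patches_per_side * patches_per_side
print(num_patches)
```

In a full ViT each flattened patch is then projected to an embedding vector by a learned linear layer before entering the Transformer.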

Datasets Used in the Experiment

We ran experiments on two datasets related to bone health. Both consist of X-ray images stored as PNG files, a lossless format that preserves image detail without compression artifacts.


Preparing the Data

For CNNs

The data was loaded with TensorFlow’s Keras API, using image_dataset_from_directory to load images from a directory tree and infer each image's class label from the name of its subdirectory. The images were resized and batched for training and validation.
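image_dataset_from_directory expects one subdirectory per class and, by default, orders the class names alphanumerically. The standard-library sketch below imitates that convention so the labeling rule is easy to see; it is a rough analogue, not the Keras implementation. The helper names and the class names "fractured" and "not_fractured" are hypothetical. The one_hot helper shows the label format that the categorical_crossentropy loss expects, which in Keras corresponds to passing label_mode='categorical'.

```python
import tempfile
from pathlib import Path


def index_image_directory(root):
    """Return (class_names, samples) for a root/<class_name>/<image>.png layout.

    `samples` is a list of (file path, integer label) pairs; labels follow the
    sorted order of the class subdirectory names.
    """
    root = Path(root)
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    samples = []
    for label, name in enumerate(class_names):
        for image_path in sorted((root / name).glob("*.png")):
            samples.append((image_path, label))
    return class_names, samples


def one_hot(label, num_classes):
    """Encode an integer label as the one-hot vector categorical cross-entropy expects."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


# Build a throwaway directory tree with hypothetical class names and empty files.
with tempfile.TemporaryDirectory() as tmp:
    for class_name, count in [("not_fractured", 2), ("fractured", 3)]:
        class_dir = Path(tmp) / class_name
        class_dir.mkdir()
        for i in range(count):
            (class_dir / f"xray_{i}.png").touch()

    class_names, samples = index_image_directory(tmp)
    labels = [label for _, label in samples]

print(class_names)
print(labels)
print(one_hot(labels[-1], len(class_names)))
```

Because labels depend on directory order, it is worth printing the class names once to confirm which index maps to which class before interpreting predictions.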

For ViTs

In the case of ViTs, PyTorch was used for data loading. Custom data loader classes were written for specific transformations and augmentations. Positional embeddings were added to the image patches to maintain spatial information.
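Positional embeddings matter because self-attention on its own has no notion of where a patch came from. A minimal pure-Python sketch of the idea follows; the table of random vectors stands in for parameters that would normally be learned, and all names and sizes here are illustrative assumptions rather than details of the model used in this experiment.

```python
import random

random.seed(0)

NUM_PATCHES = 4   # toy sequence length
EMBED_DIM = 3     # toy embedding size

# Stand-in for a learned table with one vector per patch position.
position_table = [[random.uniform(-1.0, 1.0) for _ in range(EMBED_DIM)]
                  for _ in range(NUM_PATCHES)]


def add_positional_embeddings(patch_embeddings, table):
    """Add the position vector for index i to the i-th patch embedding."""
    if len(patch_embeddings) != len(table):
        raise ValueError("need exactly one position vector per patch")
    return [[value + offset for value, offset in zip(patch, position)]
            for patch, position in zip(patch_embeddings, table)]


# Four identical patch embeddings: without position information they are
# indistinguishable to the Transformer.
patch_embeddings = [[0.5, 0.5, 0.5] for _ in range(NUM_PATCHES)]
tokens = add_positional_embeddings(patch_embeddings, position_table)

# After the addition, identical patches at different positions differ.
print(tokens[0] != tokens[1])
```

The same patch content now produces a different token at each location, which is how spatial layout survives the conversion of an image into a sequence.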


CNNs vs ViTs: Methodology

CNNs

Models like Xception, VGG16, and InceptionV3 were used. The base models were fine-tuned on the medical imaging datasets. Hyperparameters like learning rate and batch size were optimized, and the models were trained for 40 epochs.

Code Highlights:

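from tensorflow.keras.optimizers import SGD

# return_model is the project's own helper (defined in the repository, not shown here)
model = return_model(input_dim, len(classes), head=head, freeze=True)
model.compile(loss='categorical_crossentropy', optimizer=SGD(learning_rate=lr), metrics=['accuracy'])
model.fit(train_ds, epochs=epochs, validation_data=val_ds, callbacks=[save_weights])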

ViTs

We used the ViT-base model with patch size 16×16 and trained it for 10 epochs. The model was then fine-tuned for the specific medical imaging tasks.

Code Highlights:

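import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use a GPU when available
model = ViT().to(device)  # ViT is the project's own model class
trained_model = model_train(dataset['train'], EPOCHS, LEARNING_RATE, BATCH_SIZE)  # project's training helper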

Results and Observations

The CNN models never exceeded 44% accuracy, even after 40 epochs. In contrast, the ViT model reached 94% accuracy after just 10 epochs of training. Further testing on the test images confirmed that the ViT substantially outperformed the CNNs.

Conclusion

In this experiment, the stark contrast in performance underscores the efficacy of Vision Transformers for medical imaging tasks. While CNNs have been a reliable option for years, it might be time to reconsider their supremacy.

Resources and Code Repositories

You can find all the code snippets and resources on our GitHub profile, amine0110. A detailed video on this experiment is available on our YouTube channel, PYCAD.
