%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai.vision import *
from fastai.metrics import accuracy
PATH = "/home/katey/DeepLearning/Data/HAM10000/"
label_csv = f'{PATH}HAM10000_metadata.csv'
label_df = pd.read_csv(label_csv)
Roughly a quarter of the lesions have more than one image, so we need to make sure that images of the same lesion do not end up in both the training and validation sets. Take a random sample of the unique lesion ids and hold out all of their images as the validation set.
np.random.seed(827)  # fix the seed so the lesion-level split is reproducible
val_lesions = list(np.random.choice(label_df.lesion_id.unique(), size=3000))  # lesion ids to hold out (sampled with replacement, so duplicates are possible but harmless)
val_idxs = label_df[label_df['lesion_id'].isin(val_lesions)].index  # every image row belonging to those lesions
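As a quick sanity check (an addition, not part of the original split code), confirm that none of the held-out lesion ids remain among the training rows:
train_lesions = set(label_df.drop(val_idxs).lesion_id)  # lesion ids left in the training rows
assert len(train_lesions & set(val_lesions)) == 0  # no overlap means no lesion-level leakage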
Make a dataframe with just the image id and diagnosis (dx), renamed to the filename and label columns used below.
reduce_label_df = label_df.drop(columns = ['lesion_id','dx_type','age','sex','localization'])
reduce_label_df.columns = ['filename','label']
reduce_label_df.head()
label_csv = f'{PATH}labels.csv'
reduce_label_df.to_csv(label_csv, index = False)
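HAM10000 is heavily imbalanced, with the 'nv' (melanocytic nevus) class accounting for the majority of images, so it is worth glancing at the label counts before reading too much into accuracy. This check is an optional addition:
reduce_label_df['label'].value_counts()  # images per class; 'nv' dominates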
tfms = get_transforms(flip_vert = True)
sz = 256
bs=32
src = (ImageItemList.from_csv(PATH, 'labels.csv', folder = 'train', suffix = '.jpg')
.split_by_idx(val_idxs)
.label_from_df())
data = (src.transform(tfms, size=sz)
.databunch(bs=bs).normalize(imagenet_stats))
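Optionally (this cell is an addition), fastai's show_batch gives a quick visual confirmation that the transforms and labels look sensible:
data.show_batch(rows=3, figsize=(7, 7))  # grid of augmented training images with their labels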
Load the previously trained model.
learn = create_cnn(data, models.resnet50, metrics=accuracy).load('mod-resnet50-sz256')
Pick which image to use from the validation set
image_num = 87
idx=image_num
x,y = data.valid_ds[idx]
x.show()
data.valid_ds.y[idx]
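Before computing a heatmap, it helps to see what the model actually predicts for this image. learn.predict returns the predicted category, its index, and the class probabilities; this cell is an addition for context:
pred_class, pred_idx, probs = learn.predict(x)  # prediction for the chosen validation image
pred_class, probs[pred_idx]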
m = learn.model.eval();
Create a minibatch with a single item and put it on the GPU.
xb,_ = data.one_item(x)
xb_im = Image(data.denorm(xb)[0])
xb = xb.cuda()
from fastai.callbacks.hooks import *
A 'hook' saves the output of the convolutional part of the model, m[0], during the forward pass; a second hook, registered with grad=True, stores the gradients flowing back into that layer. The predictions are computed only to populate the hooks: backpropagating from the score of the target class gives the gradients we need.
def hooked_backward(cat=y):
    with hook_output(m[0]) as hook_a:                 # activations of the conv body
        with hook_output(m[0], grad=True) as hook_g:  # gradients w.r.t. those activations
            preds = m(xb)
            preds[0, int(cat)].backward()             # backprop from the target class score
    return hook_a, hook_g
hook_a,hook_g = hooked_backward()
acts = hook_a.stored[0].cpu()
acts.shape
Average the activations over the 2048 channels of the final convolutional layer, giving a single 8 x 8 heatmap.
avg_acts = acts.mean(0)
avg_acts.shape
def show_heatmap(hm):
    _, ax = plt.subplots()
    xb_im.show(ax)
    # stretch the small heatmap over the full sz x sz image and overlay it
    ax.imshow(hm, alpha=0.6, extent=(0, sz, sz, 0),
              interpolation='bilinear', cmap='magma');
show_heatmap(avg_acts)
Grad-CAM refines this by weighting each activation map with the mean gradient of the class score with respect to that channel, so the heatmap highlights the regions that most influenced the prediction.
grad = hook_g.stored[0][0].cpu()  # gradients w.r.t. the 2048 x 8 x 8 feature maps
grad_chan = grad.mean(1).mean(1)  # per-channel mean gradient: 2048 weights
grad.shape, grad_chan.shape
mult = (acts * grad_chan[..., None, None]).mean(0)  # weighted average of the activation maps
show_heatmap(mult)
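The original Grad-CAM formulation applies a ReLU so that only features with a positive influence on the chosen class are shown; clamping the negative values is an optional variant of the plot above:
show_heatmap(mult.clamp(min=0))  # keep only the positively contributing regions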