GoogLeNet で出力層の手前の層を特徴ベクトルとして取得する方法のメモ。
register_forward_hook() で出力層の手前の層の出力を取得します。
# -*- coding:utf-8 -*-
"""Extract a feature vector from a pretrained GoogLeNet.

A forward hook is registered on the second-to-last child module of the
model; every forward pass stores that module's output in the module-level
``feature_vector`` variable, which ``get_feature_vector()`` then returns.

Usage:
    python <this script> IMAGE_FILE
"""
import json
import sys

import torch
import torchvision.transforms as transforms
from PIL import Image
from googlenet_pytorch import GoogLeNet

# Most recent activation captured by forward_hook(); None until a forward
# pass has gone through the hooked module.
feature_vector = None


def get_feature_vector(preproc, model, img_file):
    """Run one image through ``model`` and return the hooked activation.

    Args:
        preproc: Callable that converts a PIL image into a single-image
            tensor (a batch axis is added here with ``unsqueeze(0)``).
        model: Model on which ``forward_hook`` has been registered
            (see ``init()``).
        img_file: Path or file object accepted by ``PIL.Image.open``.

    Returns:
        The value stored by ``forward_hook()`` during this forward pass,
        or ``None`` if the hook did not fire.
    """
    global feature_vector
    # Reset first so a vector left over from an earlier call is never
    # returned by mistake.
    feature_vector = None

    with Image.open(img_file) as input_image:
        # The Normalize() step built in init() is configured for three
        # channels, so grayscale / RGBA / palette images are forced to RGB.
        input_tensor = preproc(input_image.convert("RGB"))
    input_batch = input_tensor.unsqueeze(0)

    # Only the hooked activation is needed; the model's own return value is
    # discarded and no autograd graph has to be built.
    with torch.no_grad():
        model(input_batch)

    return feature_vector


def forward_hook(module, inputs, outputs):
    """Store the first batch item of ``outputs`` as nested Python lists.

    The signature is the one PyTorch uses for forward hooks; ``module`` and
    ``inputs`` are not used.
    """
    global feature_vector
    # tolist() already copies the data into plain Python objects, so the
    # stored value does not share memory with the tensor.
    feature_vector = outputs.detach()[0].tolist()


def init():
    """Build the preprocessing pipeline and the hooked pretrained model.

    Returns:
        A ``(preproc, model)`` tuple: the torchvision transform pipeline and
        the GoogLeNet instance in eval mode with ``forward_hook`` attached
        to its second-to-last child module.
    """
    preproc = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Per-channel mean/std (values commonly used with ImageNet-pretrained
        # models).
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    model = GoogLeNet.from_pretrained("googlenet")
    model.eval()

    # The last child is taken to be the output layer, so the one before it
    # supplies the feature vector.
    # NOTE(review): confirm the child order for the installed
    # googlenet_pytorch version.
    penultimate = list(model.children())[-2]
    penultimate.register_forward_hook(forward_hook)

    return preproc, model


def main():
    """Print the feature vector of the image named on the command line.

    Returns:
        0 on success, 1 when no image path was given.
    """
    if len(sys.argv) < 2:
        print("usage: {} IMAGE_FILE".format(sys.argv[0]), file=sys.stderr)
        return 1

    preproc, model = init()
    fv = get_feature_vector(preproc, model, sys.argv[1])
    print(fv)
    return 0


if __name__ == '__main__':
    sys.exit(main())