dak ブログ

python、rubyなどのプログラミング、MySQL、サーバーの設定などの備忘録。レゴの写真も。

GoogLeNet で出力層の手前の層を特徴ベクトルとして取得

2022-02-20 23:29:55 | 画像処理
GoogLeNet で出力層の手前の層を特徴ベクトルとして取得する方法のメモ。
register_forward_hook() で出力層の手前の層の出力を取得します。
# -*- coding:utf-8 -*-

import sys
import json
import torch
import torchvision.transforms as transforms
from PIL import Image
from googlenet_pytorch import GoogLeNet

feature_vector = None

def get_feature_vector(preproc, model, img_file):
    input_image = Image.open(img_file)
    input_tensor = preproc(input_image)
    input_batch = input_tensor.unsqueeze(0)
    
    logits = model(input_batch)
    preds = torch.topk(logits, k=5).indices.squeeze(0).tolist()
    return feature_vector

def forward_hook(module, inputs, outputs):
    global feature_vector
    feature_vector = outputs.detach().clone()[0].tolist()
    
def init():
    preproc = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    labels_map = json.load(open("labels_map.txt"))
    labels_map = [labels_map[str(i)] for i in range(1000)]

    model = GoogLeNet.from_pretrained("googlenet")
    model.eval()

    layers = list(model.children())
    handle = layers[-2].register_forward_hook(forward_hook)
    return preproc, model

def main():
    preproc, model = init()
    img_file = sys.argv[1]
    fv = get_feature_vector(preproc, model, img_file)
    print(fv)

    return 0

if __name__ == '__main__':
    res = main()
    exit(res)


CentOS8 で yum install でのエラー対処

2022-02-20 15:45:59 | linux
CentOS8 で yum install がエラーになり、
mirrorlist の URL を mirrorlist.centos.org から vault.centos.org に
変更すればよいという情報もありましたが、エラーが解消されないため、
以下の設定を行いました。

■/etc/yum.repos.d/CentOS-Linux-AppStream.repo
[appstream]
name=CentOS Linux $releasever - AppStream
baseurl=http://linuxsoft.cern.ch/centos-vault/8.4.2105/AppStream/$basearch/os/
#mirrorlist=http://mirrorlist.centos.org/?release=$releasever&arch=$basearch&repo=AppStream&infra=$infra
#baseurl=http://mirror.centos.org/$contentdir/$releasever/AppStream/$basearch/os/
gpgcheck=1
enabled=1
gpgkey=file:///etc/pki/rpm-gpg/RPM-GPG-KEY-centosofficial

baseurl 内のバージョン番号部分に $reeasever を記述すると、バージョンが 8 になってしまい、
エラーが続いたため、/etc/redhat-release に記述されているバージョン番号を指定しました。

/etc/yum.repos.d/ 内の他のファイルについても同様に URL を変更することで、
エラーが解消されました。