(See Getting Started with SFrames for setup instructions)
import graphlab
# Limit the number of worker processes to conserve system memory; this keeps hosted notebooks from crashing.
graphlab.set_runtime_config('GRAPHLAB_DEFAULT_NUM_PYLAMBDA_WORKERS', 4)
We will use a popular benchmark dataset in computer vision called CIFAR-10.
(We've reduced the data to just 4 categories: {'cat', 'bird', 'automobile', 'dog'}.)
This dataset is already split into a training set and test set. In this simple retrieval example, there is no notion of "testing", so we will only use the training data.
image_train = graphlab.SFrame('image_train_data/')
The two commented-out lines below compute the deep features. This computation takes a while, so we have already computed the features and saved them in the 'deep_features' column of the data you loaded.
(Note that if you would like to compute such deep features yourself and have a GPU on your machine, you should use the GPU-enabled version of GraphLab Create, which will be significantly faster for this task.)
# deep_learning_model = graphlab.load_model('http://s3.amazonaws.com/GraphLab-Datasets/deeplearning/imagenet_model_iter45')
# image_train['deep_features'] = deep_learning_model.extract_features(image_train)
image_train.head()
We will now build a simple image retrieval system that finds the nearest neighbors for any image.
knn_model = graphlab.nearest_neighbors.create(image_train, features=['deep_features'],
                                              label='id')
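Conceptually, a nearest-neighbor retrieval system compares the query's feature vector against every stored vector and returns the k closest ones. The sketch below is a brute-force version in plain Python. It assumes Euclidean distance and uses toy 3-dimensional vectors in place of the real deep features. It is not GraphLab's implementation (the library may pick a different distance or an index structure), and the helper name knn_query is hypothetical.

```python
import math

def euclidean(a, b):
    """Euclidean distance between two equal-length numeric vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def knn_query(references, query_vec, k=5):
    """Brute-force k-nearest-neighbor search (hypothetical helper).

    references: list of (id, feature_vector) pairs.
    Returns up to k dicts with 'reference_label', 'distance' and 'rank',
    ordered from closest to farthest.
    """
    # Score every reference vector against the query.
    scored = [(euclidean(vec, query_vec), ref_id) for ref_id, vec in references]
    scored.sort()  # ascending distance; ties broken by id
    return [
        {'reference_label': ref_id, 'distance': dist, 'rank': r}
        for r, (dist, ref_id) in enumerate(scored[:k], start=1)
    ]

# Toy stand-in for deep features: ids 0-3 with 3-dimensional vectors.
toy_refs = [(0, [0.0, 0.0, 0.0]), (1, [1.0, 0.0, 0.0]),
            (2, [0.0, 3.0, 4.0]), (3, [0.9, 0.1, 0.0])]
for row in knn_query(toy_refs, [1.0, 0.0, 0.0], k=3):
    print(row)
```

Brute force costs one distance computation per stored image per query, which is fine for a few thousand images but is the reason larger systems use indexing structures.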
Let's find similar images to this cat picture.
graphlab.canvas.set_target('ipynb')
cat = image_train[18:19]  # slice out the row at index 18 as a one-row SFrame
cat['image'].show()
knn_model.query(cat)
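Because the cat image is itself part of image_train, which is the reference set, the top-ranked neighbor is expected to be the query image itself, at distance 0. The sketch below shows one way to drop that self-match and re-rank the remaining neighbors. It assumes the ranked result is available as a list of dicts with 'reference_label', 'distance' and 'rank' keys; the helper drop_self_match and the example rows (ids and distances) are hypothetical, made-up values.

```python
def drop_self_match(rows, self_id):
    """Remove the query's own id from a ranked neighbor list and re-rank.

    rows: list of dicts with 'reference_label', 'distance' and 'rank'
    keys, already sorted by rank. Returns new dicts; the input is unchanged.
    """
    kept = [row for row in rows if row['reference_label'] != self_id]
    # dict(row, rank=i) copies each row with a fresh, contiguous rank.
    return [dict(row, rank=i) for i, row in enumerate(kept, start=1)]

# Made-up ranked result for a query whose own id is assumed to be 18.
ranked = [
    {'reference_label': 18, 'distance': 0.0, 'rank': 1},
    {'reference_label': 42, 'distance': 36.2, 'rank': 2},
    {'reference_label': 7, 'distance': 37.9, 'rank': 3},
]
print(drop_self_match(ranked, 18))
```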
We are going to create a simple function to view the nearest neighbors to save typing:
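The function filter_by() selects a subset of the rows in an SFrame: here, the rows whose 'id' value appears in the query result's 'reference_label' column.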
def get_images_from_ids(query_result):  # return the rows (images) whose id appears among the query's neighbors
    return image_train.filter_by(query_result['reference_label'], 'id')
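filter_by() behaves like a set-membership filter: it keeps every row whose 'id' is in the given values. A membership filter of this kind is not guaranteed to return rows in the neighbors' rank order, so the images shown may not be sorted from most to least similar. The plain-Python sketch below contrasts a membership filter with an order-preserving lookup. It assumes a list of dicts standing in for the SFrame; the helper get_rows_in_rank_order and the table contents are hypothetical.

```python
def get_rows_in_rank_order(rows, ranked_ids, id_column='id'):
    """Return rows whose id is in ranked_ids, ordered as in ranked_ids.

    rows: list of dicts standing in for an SFrame (hypothetical data).
    ranked_ids: ids in rank order, e.g. a 'reference_label' column.
    Ids with no matching row are skipped.
    """
    by_id = {row[id_column]: row for row in rows}  # index rows by id
    return [by_id[i] for i in ranked_ids if i in by_id]

table = [{'id': 3, 'image': 'img3'}, {'id': 7, 'image': 'img7'},
         {'id': 18, 'image': 'img18'}, {'id': 42, 'image': 'img42'}]
ranked_ids = [18, 42, 7]

# Membership filter (in the spirit of filter_by): rows come out in table order.
wanted = set(ranked_ids)
print([row['id'] for row in table if row['id'] in wanted])  # [7, 18, 42]

# Order-preserving lookup: rows come out in rank order.
print([row['id'] for row in get_rows_in_rank_order(table, ranked_ids)])  # [18, 42, 7]
```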
cat_neighbors = get_images_from_ids(knn_model.query(cat))
cat_neighbors['image'].show()
Very cool results showing similar cats.
car = image_train[8:9]
car['image'].show()
get_images_from_ids(knn_model.query(car))['image'].show()
show_neighbors = lambda i: get_images_from_ids(knn_model.query(image_train[i:i+1]))['image'].show()
show_neighbors(8)
show_neighbors(26)