Image Retrieval using Deep Features
In this article, we focus on using deep learning to create non-linear features that improve the performance of machine learning models. We will also see how transfer learning lets us reuse deep features learned on one dataset to get strong performance on a different dataset. In this IPython notebook, we build image retrieval models and explore their results on different parts of our image dataset.
import graphlab
Load the CIFAR-10 dataset
# CSV format datasets https://d396qusza40orc.cloudfront.net/phoenixassets/image_train_data.csv
# https://d396qusza40orc.cloudfront.net/phoenixassets/image_test_data.csv
image_train = graphlab.SFrame('coursera-notebooks/course-1/image_train_data')
image_test = graphlab.SFrame('coursera-notebooks/course-1/image_test_data')
image_train.head()
id | image | label | deep_features | image_array |
---|---|---|---|---|
24 | Height: 32 Width: 32 | bird | [0.242871761322, 1.09545373917, 0.0, ... | [73.0, 77.0, 58.0, 71.0, 68.0, 50.0, 77.0, 69.0, ... |
33 | Height: 32 Width: 32 | cat | [0.525087952614, 0.0, 0.0, 0.0, 0.0, 0.0, ... | [7.0, 5.0, 8.0, 7.0, 5.0, 8.0, 5.0, 4.0, 6.0, 7.0, ... |
36 | Height: 32 Width: 32 | cat | [0.566015958786, 0.0, 0.0, 0.0, 0.0, 0.0, ... | [169.0, 122.0, 65.0, 131.0, 108.0, 75.0, ... |
70 | Height: 32 Width: 32 | dog | [1.12979578972, 0.0, 0.0, 0.778194487095, 0.0, ... | [154.0, 179.0, 152.0, 159.0, 183.0, 157.0, ... |
90 | Height: 32 Width: 32 | bird | [1.71786928177, 0.0, 0.0, 0.0, 0.0, 0.0, ... | [216.0, 195.0, 180.0, 201.0, 178.0, 160.0, ... |
97 | Height: 32 Width: 32 | automobile | [1.57818555832, 0.0, 0.0, 0.0, 0.0, 0.0, ... | [33.0, 44.0, 27.0, 29.0, 44.0, 31.0, 32.0, 45.0, ... |
107 | Height: 32 Width: 32 | dog | [0.0, 0.0, 0.220677852631, 0.0, ... | [97.0, 51.0, 31.0, 104.0, 58.0, 38.0, 107.0, 61.0, ... |
121 | Height: 32 Width: 32 | bird | [0.0, 0.23753464222, 0.0, 0.0, 0.0, 0.0, ... | [93.0, 96.0, 88.0, 102.0, 106.0, 97.0, 117.0, ... |
136 | Height: 32 Width: 32 | automobile | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 7.5737862587, 0.0, ... | [35.0, 59.0, 53.0, 36.0, 56.0, 56.0, 42.0, 62.0, ... |
138 | Height: 32 Width: 32 | bird | [0.658935725689, 0.0, 0.0, 0.0, 0.0, 0.0, ... | [205.0, 193.0, 195.0, 200.0, 187.0, 193.0, ... |
Train a nearest neighbour model for retrieving images using deep features
knn_model = graphlab.nearest_neighbors.create(image_train,
                                              features=['deep_features'],
                                              label='id')
PROGRESS: Starting brute force nearest neighbors model training.
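To make the "brute force" message above concrete, here is a minimal pure-Python sketch of what such a search does conceptually: measure the distance from the query vector to every reference vector and keep the k smallest. It assumes Euclidean distance and uses tiny made-up 3-dimensional vectors in place of real deep features. The names `euclidean`, `brute_force_knn` and `reference` are hypothetical, and this is not GraphLab's actual implementation.

```python
import heapq
import math


def euclidean(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def brute_force_knn(reference, query, k=5):
    """Return the k closest (label, distance) pairs, nearest first.

    reference: dict mapping a label (e.g. an image id) to a feature vector.
    """
    scored = ((label, euclidean(vec, query)) for label, vec in reference.items())
    return heapq.nsmallest(k, scored, key=lambda pair: pair[1])


# Made-up toy vectors standing in for deep features (illustrative only).
reference = {
    24: [0.2, 1.1, 0.0],
    33: [0.5, 0.0, 0.0],
    36: [0.6, 0.0, 0.1],
    70: [1.1, 0.0, 0.8],
}
neighbours = brute_force_knn(reference, query=[0.5, 0.0, 0.0], k=2)
print(neighbours)
```

A query that is itself part of the reference set comes back as its own nearest neighbour at distance 0.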
Use the image retrieval model with deep features to find similar images
graphlab.canvas.set_target('ipynb')
cat = image_train[18:19]
cat['image'].show()
knn_model.query(cat)
PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0 | 1 | 0.0498753 | 7.002ms |
PROGRESS: | Done | | 100 | 338.929ms |
PROGRESS: +--------------+---------+-------------+--------------+
query_label | reference_label | distance | rank |
---|---|---|---|
0 | 384 | 0.0 | 1 |
0 | 6910 | 36.9403137951 | 2 |
0 | 39777 | 38.4634888975 | 3 |
0 | 36870 | 39.7559623119 | 4 |
0 | 41734 | 39.7866014148 | 5 |
def get_images_id(query_result):
    # Look up the training rows whose id matches a returned reference_label.
    return image_train.filter_by(query_result['reference_label'], 'id')
cat_neighbours = get_images_id(knn_model.query(cat))
PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0 | 1 | 0.0498753 | 8.562ms |
PROGRESS: | Done | | 100 | 307.947ms |
PROGRESS: +--------------+---------+-------------+--------------+
cat_neighbours['image'].show()
car = image_train[8:9]
car['image'].show()
get_images_id(knn_model.query(car))['image'].show()
PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0 | 1 | 0.0498753 | 7.241ms |
PROGRESS: | Done | | 100 | 289.806ms |
PROGRESS: +--------------+---------+-------------+--------------+
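Just for fun, let's wrap the query and display steps in a lambda function, so we can show the neighbours of any training image by its row index
show_neightbors = lambda i: get_images_id(knn_model.query(image_train[i:i+1]))['image'].show()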
show_neightbors(8)
PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0 | 1 | 0.0498753 | 8.657ms |
PROGRESS: | Done | | 100 | 303.641ms |
PROGRESS: +--------------+---------+-------------+--------------+
show_neightbors(26)
PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0 | 1 | 0.0498753 | 6.87ms |
PROGRESS: | Done | | 100 | 332.992ms |
PROGRESS: +--------------+---------+-------------+--------------+
show_neightbors(1222)
PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0 | 1 | 0.0498753 | 6.621ms |
PROGRESS: | Done | | 100 | 299.262ms |
PROGRESS: +--------------+---------+-------------+--------------+
show_neightbors(2000)
PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0 | 1 | 0.0498753 | 6.76ms |
PROGRESS: | Done | | 100 | 291.575ms |
PROGRESS: +--------------+---------+-------------+--------------+
Creating category-specific image retrieval models
image_train_dog = image_train[image_train['label'] == 'dog']
len(image_train_dog)
509
image_train_cat = image_train[image_train['label'] == 'cat']
len(image_train_cat)
509
image_train_automobile = image_train[image_train['label'] == 'automobile']
len(image_train_automobile)
509
image_train_bird = image_train[image_train['label'] == 'bird']
len(image_train_bird)
478
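The four filters above follow the same pattern, so the split can also be expressed as a single pass that groups rows by label. The sketch below illustrates that idea on plain Python dictionaries, using a few id/label pairs from the head of the training table. `split_by_label` is a hypothetical helper, not part of GraphLab.

```python
from collections import defaultdict


def split_by_label(rows, label_key='label'):
    """Group rows (dicts) into one list per distinct label value."""
    groups = defaultdict(list)
    for row in rows:
        groups[row[label_key]].append(row)
    return dict(groups)


# A few id/label pairs from the head of the training table.
rows = [
    {'id': 24, 'label': 'bird'},
    {'id': 33, 'label': 'cat'},
    {'id': 36, 'label': 'cat'},
    {'id': 70, 'label': 'dog'},
    {'id': 97, 'label': 'automobile'},
]
subsets = split_by_label(rows)
print({label: len(group) for label, group in subsets.items()})
```

Grouping once keeps the per-category subsets in one dictionary, which is convenient when the same step (such as training a model) has to be repeated for each label.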
image_train_automobile.head()
id | image | label | deep_features | image_array |
---|---|---|---|---|
97 | Height: 32 Width: 32 | automobile | [1.57818555832, 0.0, 0.0, 0.0, 0.0, 0.0, ... | [33.0, 44.0, 27.0, 29.0, 44.0, 31.0, 32.0, 45.0, ... |
136 | Height: 32 Width: 32 | automobile | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 7.5737862587, 0.0, ... | [35.0, 59.0, 53.0, 36.0, 56.0, 56.0, 42.0, 62.0, ... |
302 | Height: 32 Width: 32 | automobile | [0.583938002586, 0.0, 0.0, 0.0, 0.0, 0.0, ... | [64.0, 52.0, 37.0, 85.0, 60.0, 40.0, 92.0, 66.0, ... |
312 | Height: 32 Width: 32 | automobile | [0.0, 0.0, 0.0, 0.392823398113, 0.0, ... | [124.0, 126.0, 113.0, 124.0, 126.0, 113.0, ... |
323 | Height: 32 Width: 32 | automobile | [0.0, 0.0, 0.0, 4.42310428619, ... | [241.0, 241.0, 241.0, 238.0, 238.0, 238.0, ... |
536 | Height: 32 Width: 32 | automobile | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 8.42903900146, 0.0, ... | [164.0, 154.0, 154.0, 128.0, 119.0, 120.0, ... |
593 | Height: 32 Width: 32 | automobile | [1.65033948421, 0.0, 0.0, 0.0, 0.0, 0.0, ... | [231.0, 222.0, 227.0, 232.0, 217.0, 221.0, ... |
962 | Height: 32 Width: 32 | automobile | [0.0, 0.0, 0.0, 0.39552795887, 0.0, 0.0, ... | [255.0, 255.0, 255.0, 255.0, 255.0, 255.0, ... |
997 | Height: 32 Width: 32 | automobile | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 8.04085636139, 0.0, ... | [145.0, 148.0, 157.0, 131.0, 134.0, 145.0, ... |
1421 | Height: 32 Width: 32 | automobile | [0.0, 0.0, 0.0, 0.0, 0.0, 0.359612941742, ... | [114.0, 95.0, 33.0, 118.0, 98.0, 26.0, 91.0, ... |
dog_knn_model = graphlab.nearest_neighbors.create(image_train_dog,
                                                  features=['deep_features'],
                                                  label='id')
PROGRESS: Starting brute force nearest neighbors model training.
cat_knn_model = graphlab.nearest_neighbors.create(image_train_cat,
                                                  features=['deep_features'],
                                                  label='id')
PROGRESS: Starting brute force nearest neighbors model training.
automobile_knn_model = graphlab.nearest_neighbors.create(image_train_automobile,
                                                         features=['deep_features'],
                                                         label='id')
PROGRESS: Starting brute force nearest neighbors model training.
bird_knn_model = graphlab.nearest_neighbors.create(image_train_bird,
                                                   features=['deep_features'],
                                                   label='id')
PROGRESS: Starting brute force nearest neighbors model training.
image_test[0:1]
id | image | label | deep_features | image_array |
---|---|---|---|---|
0 | Height: 32 Width: 32 | cat | [1.13469004631, 0.0, 0.0, 0.0, 0.0366497635841, ... | [158.0, 112.0, 49.0, 159.0, 111.0, 47.0, ... |
image_test[0:1]['image'].show()
test_cat = image_test[0:1]
cat_knn_model.query(test_cat)
PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0 | 1 | 0.196464 | 7.432ms |
PROGRESS: | Done | | 100 | 80.534ms |
PROGRESS: +--------------+---------+-------------+--------------+
query_label | reference_label | distance | rank |
---|---|---|---|
0 | 16289 | 34.623719208 | 1 |
0 | 45646 | 36.0068799284 | 2 |
0 | 32139 | 36.5200813436 | 3 |
0 | 25713 | 36.7548502521 | 4 |
0 | 331 | 36.8731228168 | 5 |
def get_images_id_cat(query_result):
    # Same lookup as get_images_id, restricted to the cat-only training subset.
    return image_train_cat.filter_by(query_result['reference_label'], 'id')
get_images_id_cat(cat_knn_model.query(test_cat))['image'].show()
PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0 | 1 | 0.196464 | 7.881ms |
PROGRESS: | Done | | 100 | 83.519ms |
PROGRESS: +--------------+---------+-------------+--------------+
test_cat_neighbours = cat_knn_model.query(test_cat)
PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0 | 1 | 0.196464 | 6.998ms |
PROGRESS: | Done | | 100 | 82.467ms |
PROGRESS: +--------------+---------+-------------+--------------+
test_cat_neighbours
query_label | reference_label | distance | rank |
---|---|---|---|
0 | 16289 | 34.623719208 | 1 |
0 | 45646 | 36.0068799284 | 2 |
0 | 32139 | 36.5200813436 | 3 |
0 | 25713 | 36.7548502521 | 4 |
0 | 331 | 36.8731228168 | 5 |
image_train_cat.filter_by(test_cat_neighbours['reference_label'], 'id')
id | image | label | deep_features | image_array |
---|---|---|---|---|
331 | Height: 32 Width: 32 | cat | [0.0, 0.0, 0.510963916779, 0.0, ... | [45.0, 65.0, 92.0, 72.0, 95.0, 110.0, 106.0, ... |
16289 | Height: 32 Width: 32 | cat | [0.964287519455, 0.0, 0.0, 0.0, 1.12515509129, ... | [215.0, 219.0, 231.0, 215.0, 219.0, 232.0, ... |
25713 | Height: 32 Width: 32 | cat | [0.536971271038, 0.0, 0.0, 0.0894458889961, ... | [228.0, 222.0, 236.0, 224.0, 213.0, 222.0, ... |
32139 | Height: 32 Width: 32 | cat | [1.29409468174, 0.0, 0.0, 0.513800263405, ... | [217.0, 220.0, 205.0, 221.0, 227.0, 218.0, ... |
45646 | Height: 32 Width: 32 | cat | [0.983677506447, 0.0, 0.0, 0.0, 0.0, ... | [51.0, 42.0, 26.0, 56.0, 47.0, 31.0, 59.0, 50.0, ... |
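Note that the rows above are listed in id order (331 first), not in the rank order of the query result (16289 first). If the ranking matters when displaying neighbours, the matched rows need to be reordered. Here is a minimal sketch of that reordering on plain Python data, using the ids and ranks from the query output. `order_by_rank` is a hypothetical helper, and the rows are reduced to their id and label.

```python
def order_by_rank(matched_rows, ranked_ids):
    """Return matched_rows reordered to follow ranked_ids (nearest first)."""
    by_id = {row['id']: row for row in matched_rows}
    return [by_id[i] for i in ranked_ids if i in by_id]


# Rows in the order the filter returned them (sorted by id).
matched_rows = [{'id': i, 'label': 'cat'}
                for i in (331, 16289, 25713, 32139, 45646)]
# reference_label values in rank order, taken from the query result.
ranked_ids = [16289, 45646, 32139, 25713, 331]

ordered = order_by_rank(matched_rows, ranked_ids)
print([row['id'] for row in ordered])
```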
nearest_cat_sframe = image_train_cat.filter_by(test_cat_neighbours['reference_label'],'id')
nearest_cat_sframe['image'].show()
get_images_id(dog_knn_model.query(test_cat))['image'].show()
PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0 | 1 | 0.196464 | 6.663ms |
PROGRESS: | Done | | 100 | 79.548ms |
PROGRESS: +--------------+---------+-------------+--------------+
nearest_dog = dog_knn_model.query(test_cat)
PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0 | 1 | 0.196464 | 7.628ms |
PROGRESS: | Done | | 100 | 84.028ms |
PROGRESS: +--------------+---------+-------------+--------------+
nearest_dog
query_label | reference_label | distance | rank |
---|---|---|---|
0 | 16976 | 37.4642628784 | 1 |
0 | 13387 | 37.5666832169 | 2 |
0 | 35867 | 37.6047267079 | 3 |
0 | 44603 | 37.7065585153 | 4 |
0 | 6094 | 38.5113254907 | 5 |
image_train_dog.filter_by(nearest_dog['reference_label'],'id')
id | image | label | deep_features | image_array |
---|---|---|---|---|
6094 | Height: 32 Width: 32 | dog | [0.470533549786, 0.0, 0.0, 0.0, 0.0, 0.0, ... | [91.0, 98.0, 71.0, 138.0, 123.0, 63.0, 135.0, ... |
13387 | Height: 32 Width: 32 | dog | [0.366494178772, 0.0, 0.0, 0.0, 0.0, 0.0, ... | [255.0, 255.0, 255.0, 255.0, 255.0, 255.0, ... |
16976 | Height: 32 Width: 32 | dog | [0.755595386028, 0.0, 0.0, 0.0, 0.0, 0.0, ... | [16.0, 17.0, 11.0, 18.0, 19.0, 13.0, 20.0, 21.0, ... |
35867 | Height: 32 Width: 32 | dog | [0.305321395397, 0.0, 0.0, 0.0, 0.0, 0.0, ... | [101.0, 93.0, 9.0, 93.0, 88.0, 9.0, 90.0, 86.0, ... |
44603 | Height: 32 Width: 32 | dog | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.2646656036, 0.0, ... | [8.0, 25.0, 9.0, 29.0, 39.0, 22.0, 66.0, 75.0, ... |
image_train_dog.filter_by(nearest_dog['reference_label'],'id')['image'].show()
Calculating mean distances from neighbours
test_cat_neighbours
query_label | reference_label | distance | rank |
---|---|---|---|
0 | 16289 | 34.623719208 | 1 |
0 | 45646 | 36.0068799284 | 2 |
0 | 32139 | 36.5200813436 | 3 |
0 | 25713 | 36.7548502521 | 4 |
0 | 331 | 36.8731228168 | 5 |
test_cat_neighbours['distance'].show()
nearest_dog
query_label | reference_label | distance | rank |
---|---|---|---|
0 | 16976 | 37.4642628784 | 1 |
0 | 13387 | 37.5666832169 | 2 |
0 | 35867 | 37.6047267079 | 3 |
0 | 44603 | 37.7065585153 | 4 |
0 | 6094 | 38.5113254907 | 5 |
nearest_dog['distance'].show()
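The mean itself can be computed directly from the five distances in each table. The sketch below does so in plain Python with the values copied from the two query results above. Inside the notebook, calling `.mean()` on the distance column would presumably give the same numbers, but that is an assumption, since the original does not show it. The variable and function names here are illustrative.

```python
# Distances from the test cat to its five nearest cats and five nearest dogs,
# copied from the query results above.
cat_neighbour_distances = [34.623719208, 36.0068799284, 36.5200813436,
                           36.7548502521, 36.8731228168]
dog_neighbour_distances = [37.4642628784, 37.5666832169, 37.6047267079,
                           37.7065585153, 38.5113254907]


def mean(values):
    """Arithmetic mean of a non-empty sequence of numbers."""
    return sum(values) / float(len(values))


cat_mean = mean(cat_neighbour_distances)
dog_mean = mean(dog_neighbour_distances)
print(round(cat_mean, 4), round(dog_mean, 4))
```

On these five neighbours, the test cat is on average closer to its nearest cats than to its nearest dogs.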
Computing nearest neighbors accuracy
len(image_test)
4000
image_test['label'].sketch_summary()
+------------------+-------+----------+
| item | value | is exact |
+------------------+-------+----------+
| Length | 4000 | Yes |
| # Missing Values | 0 | Yes |
| # unique values | 4 | No |
+------------------+-------+----------+
Most frequent items:
+-------+------------+------+------+------+
| value | automobile | cat | bird | dog |
+-------+------------+------+------+------+
| count | 1000 | 1000 | 1000 | 1000 |
+-------+------------+------+------+------+
image_test_cat = image_test[image_test['label'] == 'cat']
len(image_test_cat)
1000
image_test_dog = image_test[image_test['label'] == 'dog']
image_test_automobile = image_test[image_test['label'] == 'automobile']
image_test_bird = image_test[image_test['label'] == 'bird']
dog_cat_neighbors = cat_knn_model.query(image_test_dog, k=1)
PROGRESS: Starting blockwise querying.
PROGRESS: max rows per data block: 7668
PROGRESS: number of reference data blocks: 1
PROGRESS: number of query data blocks: 1
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 1000 | 509000 | 100 | 791.418ms |
PROGRESS: | Done | 509000 | 100 | 794.145ms |
PROGRESS: +--------------+---------+-------------+--------------+
dog_cat_neighbors
query_label | reference_label | distance | rank |
---|---|---|---|
0 | 33 | 36.4196077068 | 1 |
1 | 30606 | 38.8353268874 | 1 |
2 | 5545 | 36.9763410854 | 1 |
3 | 19631 | 34.5750072914 | 1 |
4 | 7493 | 34.778824791 | 1 |
5 | 47044 | 35.1171578292 | 1 |
6 | 13918 | 40.6095830913 | 1 |
7 | 10981 | 39.9036867306 | 1 |
8 | 45456 | 38.0674700168 | 1 |
9 | 44673 | 42.7258732951 | 1 |
Note: Only the head of the SFrame is printed.
You can use print_rows(num_rows=m, num_columns=n) to print more rows and columns.
dog_dog_neighbors = dog_knn_model.query(image_test_dog, k=1)
PROGRESS: Starting blockwise querying.
PROGRESS: max rows per data block: 7668
PROGRESS: number of reference data blocks: 1
PROGRESS: number of query data blocks: 1
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 1000 | 509000 | 100 | 790.626ms |
PROGRESS: | Done | 509000 | 100 | 793.114ms |
PROGRESS: +--------------+---------+-------------+--------------+
dog_dog_neighbors
query_label | reference_label | distance | rank |
---|---|---|---|
0 | 49803 | 33.4773590373 | 1 |
1 | 5755 | 32.8458495684 | 1 |
2 | 20715 | 35.0397073189 | 1 |
3 | 13387 | 33.9010327697 | 1 |
4 | 12089 | 37.4849250909 | 1 |
5 | 6094 | 34.945165344 | 1 |
6 | 3431 | 39.0957278345 | 1 |
7 | 6184 | 37.7696131032 | 1 |
8 | 2167 | 35.1089144603 | 1 |
9 | 7776 | 43.2422832585 | 1 |
Note: Only the head of the SFrame is printed.
You can use print_rows(num_rows=m, num_columns=n) to print more rows and columns.
dog_automobile_neighbors = automobile_knn_model.query(image_test_dog, k=1)
PROGRESS: Starting blockwise querying.
PROGRESS: max rows per data block: 7668
PROGRESS: number of reference data blocks: 1
PROGRESS: number of query data blocks: 1
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 1000 | 509000 | 100 | 797.994ms |
PROGRESS: | Done | 509000 | 100 | 800.718ms |
PROGRESS: +--------------+---------+-------------+--------------+
dog_automobile_neighbors
query_label | reference_label | distance | rank |
---|---|---|---|
0 | 33859 | 41.9579761457 | 1 |
1 | 2046 | 46.0021331807 | 1 |
2 | 19594 | 42.9462290692 | 1 |
3 | 11000 | 41.6866060048 | 1 |
4 | 19594 | 39.2269664935 | 1 |
5 | 49314 | 40.5845117698 | 1 |
6 | 40822 | 45.1067352961 | 1 |
7 | 44997 | 41.3221140974 | 1 |
8 | 33859 | 41.8244654995 | 1 |
9 | 33859 | 45.4976929401 | 1 |
Note: Only the head of the SFrame is printed.
You can use print_rows(num_rows=m, num_columns=n) to print more rows and columns.
dog_bird_neighbors = bird_knn_model.query(image_test_dog, k=1)
PROGRESS: Starting blockwise querying.
PROGRESS: max rows per data block: 7668
PROGRESS: number of reference data blocks: 1
PROGRESS: number of query data blocks: 1
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 1000 | 478000 | 100 | 765.493ms |
PROGRESS: | Done | 478000 | 100 | 768.23ms |
PROGRESS: +--------------+---------+-------------+--------------+
dog_bird_neighbors
query_label | reference_label | distance | rank |
---|---|---|---|
0 | 44658 | 41.7538647304 | 1 |
1 | 9215 | 41.3382958925 | 1 |
2 | 36675 | 38.6157590853 | 1 |
3 | 12582 | 37.0892269954 | 1 |
4 | 36122 | 38.272288694 | 1 |
5 | 8736 | 39.1462089236 | 1 |
6 | 38991 | 40.523040106 | 1 |
7 | 44177 | 38.1947918393 | 1 |
8 | 4549 | 40.1567131661 | 1 |
9 | 40225 | 45.5597962603 | 1 |
Note: Only the head of the SFrame is printed.
You can use print_rows(num_rows=m, num_columns=n) to print more rows and columns.
dog_distances = graphlab.SFrame({'dog-dog': dog_dog_neighbors['distance'],
                                 'dog-cat': dog_cat_neighbors['distance'],
                                 'dog-automobile': dog_automobile_neighbors['distance'],
                                 'dog-bird': dog_bird_neighbors['distance']})
dog_distances
dog-automobile | dog-bird | dog-cat | dog-dog |
---|---|---|---|
41.9579761457 | 41.7538647304 | 36.4196077068 | 33.4773590373 |
46.0021331807 | 41.3382958925 | 38.8353268874 | 32.8458495684 |
42.9462290692 | 38.6157590853 | 36.9763410854 | 35.0397073189 |
41.6866060048 | 37.0892269954 | 34.5750072914 | 33.9010327697 |
39.2269664935 | 38.272288694 | 34.778824791 | 37.4849250909 |
40.5845117698 | 39.1462089236 | 35.1171578292 | 34.945165344 |
45.1067352961 | 40.523040106 | 40.6095830913 | 39.0957278345 |
41.3221140974 | 38.1947918393 | 39.9036867306 | 37.7696131032 |
41.8244654995 | 40.1567131661 | 38.0674700168 | 35.1089144603 |
45.4976929401 | 45.5597962603 | 42.7258732951 | 43.2422832585 |
Note: Only the head of the SFrame is printed.
You can use print_rows(num_rows=m, num_columns=n) to print more rows and columns.
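Building the table column by column relies on all four query results listing the test dogs in the same `query_label` order. A more defensive variant would key each distance by its query label before combining. Here is a minimal sketch of that idea in plain Python, using the first three rows of each k=1 result above. `combine_by_query` is a hypothetical helper, and one column is shuffled on purpose.

```python
def combine_by_query(results):
    """Merge {column_name: [(query_label, distance), ...]} into one row per query.

    Returns a list of dicts ordered by query_label.
    """
    rows = {}
    for column, pairs in results.items():
        for query_label, distance in pairs:
            rows.setdefault(query_label, {})[column] = distance
    return [rows[q] for q in sorted(rows)]


# First three (query_label, distance) pairs from each k=1 result above.
results = {
    'dog-dog': [(0, 33.4773590373), (1, 32.8458495684), (2, 35.0397073189)],
    'dog-cat': [(0, 36.4196077068), (1, 38.8353268874), (2, 36.9763410854)],
    'dog-automobile': [(0, 41.9579761457), (1, 46.0021331807), (2, 42.9462290692)],
    # Deliberately shuffled to show that input order no longer matters.
    'dog-bird': [(2, 38.6157590853), (0, 41.7538647304), (1, 41.3382958925)],
}
combined = combine_by_query(results)
print(combined[1])
```

Keying on the query label costs a little extra bookkeeping but removes the silent dependence on row order.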
dog_distances[1]
{'dog-automobile': 46.002133180677895,
'dog-bird': 41.3382958924861,
'dog-cat': 38.83532688735544,
'dog-dog': 32.845849568405555}
def is_dog_correct(row):
    # Correct when the nearest dog is closer than the nearest cat, automobile, or bird.
    if min(row, key=row.get) == 'dog-dog':
        return 1
    else:
        return 0
is_dog_correct(dog_distances[1])
1
is_dog_correct(dog_distances[9])
0
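The expression `min(row, key=row.get)` iterates over the dictionary's keys and returns the key whose value is smallest, that is, the category with the closest neighbour. The sketch below applies the same check to the ten rows of `dog_distances` printed above, rebuilt as plain Python dictionaries. `closest_category` is a hypothetical helper name.

```python
def closest_category(row):
    """Return the key with the smallest distance (first key wins on exact ties)."""
    return min(row, key=row.get)


columns = ('dog-automobile', 'dog-bird', 'dog-cat', 'dog-dog')
# The ten rows of dog_distances printed above.
values = [
    (41.9579761457, 41.7538647304, 36.4196077068, 33.4773590373),
    (46.0021331807, 41.3382958925, 38.8353268874, 32.8458495684),
    (42.9462290692, 38.6157590853, 36.9763410854, 35.0397073189),
    (41.6866060048, 37.0892269954, 34.5750072914, 33.9010327697),
    (39.2269664935, 38.272288694, 34.778824791, 37.4849250909),
    (40.5845117698, 39.1462089236, 35.1171578292, 34.945165344),
    (45.1067352961, 40.523040106, 40.6095830913, 39.0957278345),
    (41.3221140974, 38.1947918393, 39.9036867306, 37.7696131032),
    (41.8244654995, 40.1567131661, 38.0674700168, 35.1089144603),
    (45.4976929401, 45.5597962603, 42.7258732951, 43.2422832585),
]
rows = [dict(zip(columns, v)) for v in values]

# 1 when the nearest dog is the closest neighbour overall, else 0.
flags = [1 if closest_category(r) == 'dog-dog' else 0 for r in rows]
print(flags, sum(flags) / float(len(flags)))
```

In rows 4 and 9 the nearest cat is closer than the nearest dog, so those two test dogs count as misclassified in this ten-row sample.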
dog_distances.apply(is_dog_correct)
dtype: int
Rows: 1000
[1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, ... ]
dog_distances.apply(is_dog_correct).sum()
678
accuracy = dog_distances.apply(is_dog_correct).sum() / float(len(image_test_dog))
accuracy
0.678