Practical Deep Learning for Coders, v3

Lesson 2_download

Creating your own dataset from Google Images

By: Francisco Ingham and Jeremy Howard. Inspired by Adrian Rosebrock.

In this tutorial we will see how to easily create an image dataset through Google Images. Note: You will have to repeat these steps for any new category you want to Google (e.g. once for dogs and once for cats).


In [ ]:
from fastai.vision import *

Get a list of URLs

Search and scroll

Go to Google Images and search for the images you are interested in. The more specific you are in your Google Search, the better the results and the less manual pruning you will have to do.

Scroll down until you've seen all the images you want to download, or until you see a button that says 'Show more results'. All the images you scrolled past are now available to download. To get more, click on the button, and continue scrolling. The maximum number of images Google Images shows is 700.

It is a good idea to put things you want to exclude into the search query, for instance if you are searching for the Eurasian wolf, "canis lupus lupus", it might be a good idea to exclude other variants:

"canis lupus lupus" -dog -arctos -familiaris -baileyi -occidentalis

You can also limit your results to show only photos by clicking on Tools and selecting Photos from the Type dropdown.
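
If you search for several categories, you may find it easier to build exclusion queries like the wolf example programmatically. This is a minimal sketch. `build_query` is a hypothetical helper and not part of fastai. It only produces the text you paste into the search box, using the `-term` exclusion syntax shown above:

```python
from typing import List

def build_query(phrase: str, exclude: List[str]) -> str:
    """Quote the main phrase and append each excluded term prefixed with '-'."""
    quoted = f'"{phrase}"'
    negatives = " ".join(f"-{term}" for term in exclude)
    return f"{quoted} {negatives}" if negatives else quoted

query = build_query("canis lupus lupus",
                    ["dog", "arctos", "familiaris", "baileyi", "occidentalis"])
print(query)  # "canis lupus lupus" -dog -arctos -familiaris -baileyi -occidentalis
```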

Download into file

Now you must run some JavaScript code in your browser, which will save the URLs of all the images you want for your dataset.

Press Ctrl+Shift+J in Windows/Linux or Cmd+Opt+J in Mac, and a small window, the JavaScript 'Console', will appear. That is where you will paste the JavaScript commands.

You will need to get the URLs of each of the images. Before running the following commands, you may want to disable ad-blocking extensions (such as uBlock) in Chrome. Otherwise the window.open() command doesn't work. Then you can run the following commands:

urls = Array.from(document.querySelectorAll('.rg_di .rg_meta')).map(el=>JSON.parse(el.textContent).ou);
window.open('data:text/csv;charset=utf-8,' + escape(urls.join('\n')));
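
Before uploading, it can be worth checking what the browser actually saved. The JavaScript above joins the URLs with newlines, so the file should hold one URL per line. `parse_url_list` and `load_url_list` are hypothetical helpers and not part of fastai. They drop blank lines, non-HTTP entries and duplicates:

```python
from pathlib import Path
from typing import List

def parse_url_list(text: str) -> List[str]:
    """Keep unique http(s) URLs, one per line, in their original order."""
    seen = set()
    urls = []
    for line in text.splitlines():
        url = line.strip()
        # Skip blank lines, anything that is not a web URL, and repeats
        if not url.startswith(("http://", "https://")) or url in seen:
            continue
        seen.add(url)
        urls.append(url)
    return urls

def load_url_list(csv_path) -> List[str]:
    """Read a URL file saved from the browser and clean it with parse_url_list."""
    return parse_url_list(Path(csv_path).read_text(encoding="utf-8"))
```

For example, `len(load_url_list('urls_black.csv'))` reports how many usable URLs you captured. If it is 0, the CSS selectors in the JavaScript snippet may no longer match the Google Images page.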

Create a directory and upload the URLs file to your server

Choose an appropriate name for your labeled images. You can run these steps multiple times to create different labels.


In [ ]:
folder = 'black'
file = 'urls_black.csv'

In [ ]:
folder = 'teddys'
file = 'urls_teddys.csv'

In [ ]:
folder = 'grizzly'
file = 'urls_grizzly.csv'

You will need to run this cell once per category.


In [1]:
path = Path('data/bears')
dest = path/folder
dest.mkdir(parents=True, exist_ok=True)
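
The three cells above are run one at a time, once per category. If you prefer, you could loop over the categories instead. This sketch uses only `pathlib`. The `categories` dict mirrors the `folder`/`file` pairs above, and a temporary directory stands in for `data/bears`:

```python
import tempfile
from pathlib import Path

# folder name -> URL file, mirroring the three cells above
categories = {
    "black": "urls_black.csv",
    "teddys": "urls_teddys.csv",
    "grizzly": "urls_grizzly.csv",
}

root = Path(tempfile.mkdtemp()) / "data" / "bears"  # stand-in for Path('data/bears')
for folder, file in categories.items():
    dest = root / folder
    dest.mkdir(parents=True, exist_ok=True)  # safe to re-run: existing folders are kept
    print(dest.name, "<-", root / file)

created = sorted(p.name for p in root.iterdir() if p.is_dir())
```

In the notebook, the same loop could also call `download_images(path/file, dest, max_pics=200)` for each category.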


In [ ]:
path.ls()


Out[ ]:
[PosixPath('data/bears/urls_teddy.csv'),
 PosixPath('data/bears/black'),
 PosixPath('data/bears/urls_grizzly.csv'),
 PosixPath('data/bears/urls_black.csv')]

Finally, upload your urls file. You just need to press 'Upload' in your working directory and select your file, then click 'Upload' for each of the displayed files.

Download images

Now you will need to download your images from their respective urls.

fast.ai has a function that does exactly that. You just have to specify the URLs filename and the destination folder, and this function will download and save all images that can be opened. Images that cannot be opened are not saved.

Let's download our images! Notice you can choose a maximum number of images to be downloaded. In this case we will not download all the urls.
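
In the fastai v1 source used here, `download_images` reads the URL file and keeps only the first `max_pics` lines. It saves download number `i` as `{i:08d}` plus the extension found at the end of the URL, falling back to `.jpg`. That is where names like `00000073.jpg` in the `verify_images` output below come from. The sketch below covers only that naming logic and does no network access. `plan_downloads` is a hypothetical name:

```python
import re
from pathlib import Path
from typing import List, Tuple

def plan_downloads(urls: List[str], dest, max_pics: int = 1000) -> List[Tuple[str, Path]]:
    """Pair each of the first `max_pics` URLs with the file name it would be saved under."""
    plan = []
    for i, url in enumerate(urls[:max_pics]):
        # Extension immediately before a '?' or the end of the URL; default to .jpg
        suffix = re.findall(r'\.\w+?(?=(?:\?|$))', url)
        suffix = suffix[0] if suffix else '.jpg'
        plan.append((url, Path(dest) / f"{i:08d}{suffix}"))
    return plan
```

The real function also downloads in parallel (controlled by `max_workers`) and only prints an error for URLs that fail, as in the output below.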

You will need to run this line once for every category.


In [ ]:
classes = ['teddys','grizzly','black']

In [ ]:
download_images(path/file, dest, max_pics=200)


100.00% [200/200 00:12<00:00]
Error https://npn-ndfapda.netdna-ssl.com/original/2X/9/973877494e28bd274c535610ffa8e262f7dcd0f2.jpeg HTTPSConnectionPool(host='npn-ndfapda.netdna-ssl.com', port=443): Max retries exceeded with url: /original/2X/9/973877494e28bd274c535610ffa8e262f7dcd0f2.jpeg (Caused by NewConnectionError('<urllib3.connection.VerifiedHTTPSConnection object at 0x7f2f7c168f60>: Failed to establish a new connection: [Errno -2] Name or service not known'))

In [ ]:
# If you have problems downloading, try with `max_workers=0` to see exceptions:
download_images(path/file, dest, max_pics=20, max_workers=0)


Then we can remove any images that can't be opened:


In [ ]:
for c in classes:
    print(c)
    verify_images(path/c, delete=True, max_size=500)


teddys
100.00% [199/199 00:02<00:00]
cannot identify image file <_io.BufferedReader name='data/bears/teddys/00000073.jpg'>
Image data/bears/teddys/00000106.gif has 1 instead of 3 channels
Image data/bears/teddys/00000067.png has 4 instead of 3 channels
Image data/bears/teddys/00000109.png has 4 instead of 3 channels
cannot identify image file <_io.BufferedReader name='data/bears/teddys/00000179.png'>
Image data/bears/teddys/00000125.jpg has 1 instead of 3 channels
Image data/bears/teddys/00000127.gif has 1 instead of 3 channels
cannot identify image file <_io.BufferedReader name='data/bears/teddys/00000012.jpg'>
cannot identify image file <_io.BufferedReader name='data/bears/teddys/00000145.jpg'>
cannot identify image file <_io.BufferedReader name='data/bears/teddys/00000165.jpg'>
Image data/bears/teddys/00000193.gif has 1 instead of 3 channels
cannot identify image file <_io.BufferedReader name='data/bears/teddys/00000059.jpg'>
cannot identify image file <_io.BufferedReader name='data/bears/teddys/00000075.jpg'>
Image data/bears/teddys/00000035.png has 4 instead of 3 channels
Image data/bears/teddys/00000086.png has 4 instead of 3 channels
cannot identify image file <_io.BufferedReader name='data/bears/teddys/00000177.jpg'>
Image data/bears/teddys/00000110.png has 4 instead of 3 channels
Image data/bears/teddys/00000099.gif has 1 instead of 3 channels
Image data/bears/teddys/00000010.png has 4 instead of 3 channels
grizzly
100.00% [199/199 00:02<00:00]
cannot identify image file <_io.BufferedReader name='data/bears/grizzly/00000116.jpg'>
cannot identify image file <_io.BufferedReader name='data/bears/grizzly/00000178.jpg'>
cannot identify image file <_io.BufferedReader name='data/bears/grizzly/00000119.jpg'>
cannot identify image file <_io.BufferedReader name='data/bears/grizzly/00000082.png'>
cannot identify image file <_io.BufferedReader name='data/bears/grizzly/00000108.jpg'>
cannot identify image file <_io.BufferedReader name='data/bears/grizzly/00000019.jpg'>
cannot identify image file <_io.BufferedReader name='data/bears/grizzly/00000132.jpg'>
Image data/bears/grizzly/00000175.gif has 1 instead of 3 channels
cannot identify image file <_io.BufferedReader name='data/bears/grizzly/00000122.jpg'>
black
100.00% [197/197 00:03<00:00]
cannot identify image file <_io.BufferedReader name='data/bears/black/00000020.jpg'>
cannot identify image file <_io.BufferedReader name='data/bears/black/00000095.jpg'>
cannot identify image file <_io.BufferedReader name='data/bears/black/00000186.jpg'>
cannot identify image file <_io.BufferedReader name='data/bears/black/00000143.jpg'>
cannot identify image file <_io.BufferedReader name='data/bears/black/00000176.jpg'>
cannot identify image file <_io.BufferedReader name='data/bears/black/00000008.jpg'>
cannot identify image file <_io.BufferedReader name='data/bears/black/00000118.jpg'>
cannot identify image file <_io.BufferedReader name='data/bears/black/00000135.jpg'>
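
`verify_images` works by trying to open every file. A quick pre-check can flag files that are not images at all, such as an HTML error page saved under a `.jpg` name, by looking at their first bytes. This is a different and much cruder technique than what fastai does. The sketch only knows the JPEG, PNG and GIF signatures, so a file in any other format would also be flagged and should be inspected by hand:

```python
from pathlib import Path
from typing import List, Optional

# Leading "magic" bytes of a few common image formats
SIGNATURES = {
    b"\xff\xd8\xff": "jpeg",
    b"\x89PNG\r\n\x1a\n": "png",
    b"GIF87a": "gif",
    b"GIF89a": "gif",
}

def sniff_format(header: bytes) -> Optional[str]:
    """Return the image format suggested by the first bytes, or None if unrecognised."""
    for magic, fmt in SIGNATURES.items():
        if header.startswith(magic):
            return fmt
    return None

def suspicious_files(folder) -> List[Path]:
    """List files in `folder` whose first bytes match none of the known signatures."""
    bad = []
    for f in sorted(Path(folder).iterdir()):
        if f.is_file():
            with open(f, "rb") as fh:
                if sniff_format(fh.read(8)) is None:
                    bad.append(f)
    return bad
```

For example, `suspicious_files(path/'black')` lists files you may want to inspect or delete. It does not catch the channel-count messages above, which concern files that are valid images.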

View data


In [ ]:
np.random.seed(42)
data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2,
        ds_tfms=get_transforms(), size=224, num_workers=4).normalize(imagenet_stats)

In [3]:
# If you already cleaned your data, run this cell instead of the one before
# np.random.seed(42)
# data = ImageDataBunch.from_csv(path, folder=".", valid_pct=0.2, csv_labels='cleaned.csv',
#         ds_tfms=get_transforms(), size=224, num_workers=4).normalize(imagenet_stats)

Good! Let's take a look at some of our pictures then.


In [ ]:
data.classes


Out[ ]:
['black', 'grizzly', 'teddys']

In [ ]:
data.show_batch(rows=3, figsize=(7,8))



In [ ]:
data.classes, data.c, len(data.train_ds), len(data.valid_ds)


Out[ ]:
(['black', 'grizzly', 'teddys'], 3, 448, 111)
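
These sizes are consistent with a split that holds out `int(valid_pct * n)` items, which is what fastai v1's random split appears to do. With 559 images in total (448 + 111), `int(0.2 * 559)` is 111. The stdlib sketch below reproduces that arithmetic with a seeded shuffle. Because it uses Python's `random` rather than NumPy, it will not pick the same files as `np.random.seed(42)` does in the notebook. `random_split` is a hypothetical name:

```python
import random
from typing import List, Sequence, Tuple

def random_split(items: Sequence, valid_pct: float = 0.2, seed: int = 42) -> Tuple[List, List]:
    """Shuffle indices with a fixed seed and hold out int(valid_pct * n) items for validation."""
    idx = list(range(len(items)))
    random.Random(seed).shuffle(idx)
    cut = int(valid_pct * len(items))
    valid = [items[i] for i in idx[:cut]]
    train = [items[i] for i in idx[cut:]]
    return train, valid

train, valid = random_split(list(range(559)), valid_pct=0.2)
print(len(train), len(valid))  # 448 111
```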

Train model


In [ ]:
learn = cnn_learner(data, models.resnet34, metrics=error_rate)

In [ ]:
learn.fit_one_cycle(4)


Total time: 00:26

epoch train_loss valid_loss error_rate
1 0.957604 0.199212 0.045045
2 0.556265 0.093994 0.036036
3 0.376028 0.082099 0.036036
4 0.273781 0.076548 0.027027
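
`fit_one_cycle` follows the one-cycle policy: the learning rate starts low, rises to `max_lr`, then anneals to a very small value, all within the requested epochs. The sketch below computes such a schedule per batch under some assumptions about the fastai v1 defaults: `div_factor=25`, `pct_start=0.3`, a final rate of `max_lr / (25 * 1e4)`, and cosine interpolation in both phases. Check `OneCycleScheduler` in your installed version before relying on the exact numbers. It ignores the momentum schedule. `max_lr=3e-3` and 100 batches are example values only:

```python
import math
from typing import List

def annealing_cos(start: float, end: float, pct: float) -> float:
    """Cosine interpolation: returns `start` at pct=0 and `end` at pct=1."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle_lrs(max_lr: float, n_iter: int, pct_start: float = 0.3,
                  div_factor: float = 25.0, final_div: float = 25.0 * 1e4) -> List[float]:
    """Per-batch learning rates: warm up from max_lr/div_factor to max_lr, then anneal."""
    n_up = int(n_iter * pct_start)   # batches spent increasing the learning rate
    n_down = n_iter - n_up           # batches spent decreasing it
    low_lr = max_lr / div_factor
    lrs = [annealing_cos(low_lr, max_lr, i / n_up) for i in range(n_up)]
    lrs += [annealing_cos(max_lr, max_lr / final_div, i / n_down) for i in range(n_down)]
    return lrs

# Example values only: max_lr=3e-3 over 100 batches
lrs = one_cycle_lrs(max_lr=3e-3, n_iter=100)
print(f"start {lrs[0]:.1e}, peak {max(lrs):.1e}, last {lrs[-1]:.1e}")
```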


In [ ]:
learn.save('stage-1')

In [ ]:
learn.unfreeze()

In [ ]:
learn.lr_find()

In [ ]:
learn.recorder.plot()

In [ ]:
learn.fit_one_cycle(2, max_lr=slice(3e-5,3e-4))


Total time: 00:11

epoch train_loss valid_loss error_rate
1 0.046916 0.072489 0.027027
2 0.041749 0.070343 0.027027
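
Passing `max_lr=slice(3e-5,3e-4)` gives each layer group its own learning rate: the earliest pretrained layers get the smallest and the new head the largest. This sketch assumes that fastai v1 spaces the rates geometrically between the two ends, as its `even_mults` helper does. It also assumes three layer groups for this ResNet. Under those assumptions, the rates come out at roughly 3.0e-5, 9.5e-5 and 3.0e-4. The function below is a re-implementation sketch, not the library code:

```python
from typing import List

def even_mults(start: float, stop: float, n: int) -> List[float]:
    """n values from start to stop, evenly spaced on a logarithmic (geometric) scale."""
    if n == 1:
        return [stop]
    step = (stop / start) ** (1 / (n - 1))  # constant ratio between neighbouring groups
    return [start * step ** i for i in range(n)]

# Assumption: three layer groups (two for the pretrained body, one for the head)
for group, lr in enumerate(even_mults(3e-5, 3e-4, 3)):
    print(f"layer group {group}: lr = {lr:.2e}")
```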


In [ ]:
learn.save('stage-2')

Interpretation


In [ ]:
learn.load('stage-2');

In [ ]:
interp = ClassificationInterpretation.from_learner(learn)

In [ ]:
interp.plot_confusion_matrix()
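
A confusion matrix counts, for each actual class (rows), how often the model predicted each class (columns). The diagonal holds the correct predictions. The plain-Python sketch below shows just the counting, with made-up labels. The plot in the notebook is produced by fastai's `ClassificationInterpretation`, not by this code:

```python
from typing import Dict, List, Sequence

def confusion_matrix(actual: Sequence[str], predicted: Sequence[str],
                     classes: List[str]) -> List[List[int]]:
    """Rows are actual classes, columns are predicted classes, cells are counts."""
    index: Dict[str, int] = {c: i for i, c in enumerate(classes)}
    cm = [[0] * len(classes) for _ in classes]
    for a, p in zip(actual, predicted):
        cm[index[a]][index[p]] += 1
    return cm

classes = ['black', 'grizzly', 'teddys']
# Toy labels, not results from this notebook
actual    = ['black', 'black', 'grizzly', 'teddys', 'grizzly']
predicted = ['black', 'grizzly', 'grizzly', 'teddys', 'grizzly']
for name, row in zip(classes, confusion_matrix(actual, predicted, classes)):
    print(f"{name:>8}", row)
```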


Cleaning Up

Some of our top losses aren't due to bad performance by our model. There are images in our data set that shouldn't be there.

Using the ImageCleaner widget from fastai.widgets we can prune our top losses, removing photos that don't belong.


In [ ]:
from fastai.widgets import *

First we need to get the file paths from our top_losses. We can do this with .from_toplosses. We then feed the top losses indexes and corresponding dataset to ImageCleaner.
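
The idea behind top losses is to rank images by their loss so that the most suspicious ones are reviewed first. The sketch below covers only the ranking step, using toy losses. The real `from_toplosses` computes the losses with the learner and also returns the dataset to display. `top_losses` here is a hypothetical stand-alone function:

```python
from typing import List, Sequence, Tuple

def top_losses(losses: Sequence[float], k: int) -> Tuple[List[float], List[int]]:
    """Return the k largest losses and their dataset indexes, largest first."""
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)[:k]
    return [losses[i] for i in order], order

# Toy per-image losses; index i would map to the i-th file in the dataset
losses = [0.02, 3.1, 0.4, 1.7, 0.05]
values, idxs = top_losses(losses, k=3)
```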

Notice that the widget will not delete images directly from disk. Instead, it will create a new csv file, cleaned.csv, from which you can create a new ImageDataBunch with the corrected labels to continue training your model.

In order to clean the entire set of images, we need to create a new dataset without the split. The video lecture demonstrated the use of the ds_type param, which no longer has any effect. See the thread for more details.


In [ ]:
db = (ImageList.from_folder(path)
                   .no_split()
                   .label_from_folder()
                   .transform(get_transforms(), size=224)
                   .databunch()
     )

In [1]:
# If you already cleaned your data using indexes from `from_toplosses`,
# run this cell instead of the one before to proceed with removing duplicates.
# Otherwise all the results of the previous step would be overwritten by
# the new run of `ImageCleaner`.
# To do so, uncomment the code below.

# db = (ImageList.from_csv(path, 'cleaned.csv', folder='.')
#                    .no_split()
#                    .label_from_df()
#                    .transform(get_transforms(), size=224)
#                    .databunch()
#      )

Then we create a new learner to use our new databunch with all the images.


In [ ]:
learn_cln = cnn_learner(db, models.resnet34, metrics=error_rate)

learn_cln.load('stage-2');

In [ ]:
ds, idxs = DatasetFormatter().from_toplosses(learn_cln)

Make sure you're running this notebook in Jupyter Notebook, not Jupyter Lab. That is accessible via /tree, not /lab. Running the ImageCleaner widget in Jupyter Lab is not currently supported.


In [ ]:
ImageCleaner(ds, idxs, path)


'No images to show :)'

Flag photos for deletion by clicking 'Delete'. Then click 'Next Batch' to delete flagged photos and keep the rest in that row. ImageCleaner will show you a new row of images until there are no more to show. In this case, the widget will show you images until there are none left from top_losses.

You can also find duplicates in your dataset and delete them! To do this, you need to run .from_similars to get the potential duplicates' ids and then run ImageCleaner with duplicates=True. The API works in a similar way as with misclassified images: just choose the ones you want to delete and click 'Next Batch' until there are no more images left.
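
`from_similars` finds candidate duplicates by comparing the network's activations for each pair of images. The sketch below assumes cosine similarity as the comparison, which is an assumption; check the fastai source for the exact layer and metric your version uses. It shows the idea with toy activation vectors. `cosine` and `most_similar_pairs` are hypothetical names:

```python
import math
from itertools import combinations
from typing import List, Sequence, Tuple

def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity: 1.0 means the vectors point in the same direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0  # treat an all-zero vector as dissimilar

def most_similar_pairs(feats: List[Sequence[float]], k: int) -> List[Tuple[int, int, float]]:
    """Return the k most similar (i, j, similarity) pairs, most similar first."""
    scored = [(i, j, cosine(feats[i], feats[j])) for i, j in combinations(range(len(feats)), 2)]
    return sorted(scored, key=lambda t: t[2], reverse=True)[:k]

# Toy activation vectors: images 0 and 2 are near-duplicates
feats = [[1.0, 0.0, 2.0], [0.0, 3.0, 0.1], [1.1, 0.0, 2.0]]
print(most_similar_pairs(feats, k=1))
```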

Make sure to recreate the databunch and learn_cln from the cleaned.csv file. Otherwise the file would be overwritten from scratch, losing all the results from cleaning the data from toplosses.


In [ ]:
ds, idxs = DatasetFormatter().from_similars(learn_cln)


Getting activations...
100.00% [1/1 00:01<00:00]
Computing similarities...
100.00% [55/55 00:00<00:00]

In [ ]:
ImageCleaner(ds, idxs, path, duplicates=True)


'No images to show :)'

Remember to recreate your ImageDataBunch from your cleaned.csv to include the changes you made in your data!

Putting your model in production

First things first, let's export the content of our Learner object for production:


In [2]:
learn.export()


This will create a file named 'export.pkl' in the directory where we were working that contains everything we need to deploy our model (the model, the weights but also some metadata like the classes or the transforms/normalization used).

You probably want to use CPU for inference, except at massive scale (and you almost certainly don't need to train in real time). If you don't have a GPU, that happens automatically. You can test your model on CPU like so:


In [ ]:
defaults.device = torch.device('cpu')

In [ ]:
img = open_image(path/'black'/'00000021.jpg')
img


Out[ ]:

We create our Learner in the production environment like this; just make sure that path contains the file 'export.pkl' from before.


In [ ]:
learn = load_learner(path)

In [ ]:
pred_class,pred_idx,outputs = learn.predict(img)
pred_class


Out[ ]:
Category black
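
`predict` returns three things: the predicted category, its index into `data.classes`, and a tensor of outputs in the same order as `data.classes`. For a single-label classifier like this one, those outputs are one probability per class. The plain-Python sketch below uses toy numbers to show how that last output can be turned into a ranked list. `rank_predictions` is a hypothetical helper:

```python
from typing import List, Sequence, Tuple

def rank_predictions(classes: Sequence[str], probs: Sequence[float]) -> List[Tuple[str, float]]:
    """Pair each class name with its probability, highest probability first."""
    return sorted(zip(classes, map(float, probs)), key=lambda p: p[1], reverse=True)

# Toy probabilities, listed in the order of data.classes
ranked = rank_predictions(['black', 'grizzly', 'teddys'], [0.07, 0.90, 0.03])
print(ranked[0][0])  # grizzly
```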

So you might create a route something like this (thanks to Simon Willison for the structure of this code):

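from io import BytesIO
import aiohttp
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from fastai.vision import load_learner, open_image

app = Starlette()
learner = load_learner(path)  # `path` must contain export.pkl

async def get_bytes(url):  # helper the snippet relies on: fetch raw bytes with aiohttp
    async with aiohttp.ClientSession() as session, session.get(url) as response:
        return await response.read()

@app.route("/classify-url", methods=["GET"])
async def classify_url(request):
    img_bytes = await get_bytes(request.query_params["url"])
    img = open_image(BytesIO(img_bytes))
    _, _, probs = learner.predict(img)  # one probability per class
    ranked = sorted(zip(learner.data.classes, map(float, probs)), key=lambda p: p[1], reverse=True)
    return JSONResponse({"predictions": ranked})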

(This example is for the Starlette web app toolkit.)

Things that can go wrong

  • Most of the time things will train fine with the defaults
  • There's not much you really need to tune (despite what you've heard!)
  • The settings you are most likely to need to adjust are
    • Learning rate
    • Number of epochs

Learning rate (LR) too high


In [ ]:
learn = cnn_learner(data, models.resnet34, metrics=error_rate)

In [ ]:
learn.fit_one_cycle(1, max_lr=0.5)


Total time: 00:13
epoch  train_loss  valid_loss  error_rate       
1      12.220007   1144188288.000000  0.765957    (00:13)
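
Why does a learning rate that is too high blow up the validation loss, while one that is too low barely improves anything? Here is a toy illustration using plain gradient descent on f(x) = x**2, not a neural network. Each update multiplies x by (1 - 2*lr). Any lr above 1 therefore makes |x| grow every step, while a tiny lr shrinks it only slightly:

```python
from typing import List

def gradient_descent(lr: float, steps: int = 10, x0: float = 1.0) -> List[float]:
    """Minimise f(x) = x**2 (gradient 2x); each step multiplies x by (1 - 2*lr)."""
    xs = [x0]
    for _ in range(steps):
        x = xs[-1]
        xs.append(x - lr * 2 * x)
    return xs

too_high = gradient_descent(lr=1.5)    # factor -2: |x| doubles every step and diverges
too_low = gradient_descent(lr=1e-3)    # factor 0.998: x barely moves
just_right = gradient_descent(lr=0.3)  # factor 0.4: x shrinks quickly towards 0
print(abs(too_high[-1]), abs(too_low[-1]), abs(just_right[-1]))
```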

Learning rate (LR) too low


In [ ]:
learn = cnn_learner(data, models.resnet34, metrics=error_rate)

Previously we had this result:

Total time: 00:57
epoch  train_loss  valid_loss  error_rate
1      1.030236    0.179226    0.028369    (00:14)
2      0.561508    0.055464    0.014184    (00:13)
3      0.396103    0.053801    0.014184    (00:13)
4      0.316883    0.050197    0.021277    (00:15)

In [ ]:
learn.fit_one_cycle(5, max_lr=1e-5)


Total time: 01:07
epoch  train_loss  valid_loss  error_rate
1      1.349151    1.062807    0.609929    (00:13)
2      1.373262    1.045115    0.546099    (00:13)
3      1.346169    1.006288    0.468085    (00:13)
4      1.334486    0.978713    0.453901    (00:13)
5      1.320978    0.978108    0.446809    (00:13)


In [ ]:
learn.recorder.plot_losses()


As well as taking a really long time, it's getting too many looks at each image, so it may overfit.

Too few epochs


In [ ]:
learn = cnn_learner(data, models.resnet34, metrics=error_rate, pretrained=False)

In [1]:
learn.fit_one_cycle(1)


Too many epochs


In [ ]:
np.random.seed(42)
data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.9, bs=32,
        ds_tfms=get_transforms(do_flip=False, max_rotate=0, max_zoom=1, max_lighting=0, max_warp=0),
        size=224, num_workers=4).normalize(imagenet_stats)

In [ ]:
learn = cnn_learner(data, models.resnet50, metrics=error_rate, ps=0, wd=0)
learn.unfreeze()

In [ ]:
learn.fit_one_cycle(40, slice(1e-6,1e-4))


Total time: 06:39
epoch  train_loss  valid_loss  error_rate
1      1.513021    1.041628    0.507326    (00:13)
2      1.290093    0.994758    0.443223    (00:09)
3      1.185764    0.936145    0.410256    (00:09)
4      1.117229    0.838402    0.322344    (00:09)
5      1.022635    0.734872    0.252747    (00:09)
6      0.951374    0.627288    0.192308    (00:10)
7      0.916111    0.558621    0.184982    (00:09)
8      0.839068    0.503755    0.177656    (00:09)
9      0.749610    0.433475    0.144689    (00:09)
10     0.678583    0.367560    0.124542    (00:09)
11     0.615280    0.327029    0.100733    (00:10)
12     0.558776    0.298989    0.095238    (00:09)
13     0.518109    0.266998    0.084249    (00:09)
14     0.476290    0.257858    0.084249    (00:09)
15     0.436865    0.227299    0.067766    (00:09)
16     0.457189    0.236593    0.078755    (00:10)
17     0.420905    0.240185    0.080586    (00:10)
18     0.395686    0.255465    0.082418    (00:09)
19     0.373232    0.263469    0.080586    (00:09)
20     0.348988    0.258300    0.080586    (00:10)
21     0.324616    0.261346    0.080586    (00:09)
22     0.311310    0.236431    0.071429    (00:09)
23     0.328342    0.245841    0.069597    (00:10)
24     0.306411    0.235111    0.064103    (00:10)
25     0.289134    0.227465    0.069597    (00:09)
26     0.284814    0.226022    0.064103    (00:09)
27     0.268398    0.222791    0.067766    (00:09)
28     0.255431    0.227751    0.073260    (00:10)
29     0.240742    0.235949    0.071429    (00:09)
30     0.227140    0.225221    0.075092    (00:09)
31     0.213877    0.214789    0.069597    (00:09)
32     0.201631    0.209382    0.062271    (00:10)
33     0.189988    0.210684    0.065934    (00:09)
34     0.181293    0.214666    0.073260    (00:09)
35     0.184095    0.222575    0.073260    (00:09)
36     0.194615    0.229198    0.076923    (00:10)
37     0.186165    0.218206    0.075092    (00:09)
38     0.176623    0.207198    0.062271    (00:10)
39     0.166854    0.207256    0.065934    (00:10)
40     0.162692    0.206044    0.062271    (00:09)
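
One common guard against training for too long is early stopping: track the validation loss and stop once it has not improved for a few epochs. The plain-Python sketch below implements this generic technique; it is not a fastai call. It is applied to the valid_loss values of epochs 15 to 20 from the table above. On the full 40-epoch run, the validation loss later falls below its epoch-15 value again, so a patience of 3 would have stopped too early here. This shows why the patience value matters:

```python
from typing import List, Optional

def early_stop_epoch(valid_losses: List[float], patience: int = 3) -> Optional[int]:
    """1-based epoch at which training would stop after `patience` epochs
    without improving on the best validation loss; None if that never happens."""
    best = float("inf")
    since_best = 0
    for epoch, loss in enumerate(valid_losses, start=1):
        if loss < best:
            best, since_best = loss, 0
        else:
            since_best += 1
            if since_best >= patience:
                return epoch
    return None

# valid_loss for epochs 15-20 of the run above
window = [0.227299, 0.236593, 0.240185, 0.255465, 0.263469, 0.258300]
print(early_stop_epoch(window, patience=3))  # 4, i.e. epoch 18 of the full run
```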