- toc {:toc}
Model.train() & Model.eval()
- ๋ชจ๋ธ์ ๊ต์กํ๊ณ ์๋ค๋ ๊ฒ์ ๋ชจ๋ธ์ ์๋ฆฐ๋ค.
- Batch Normalization Layer, Dropout Layer์ ๊ฒฝ์ฐ train, eval ์ ๋ฐ๋ผ์ ๋ค๋ฅด๊ฒ ๋์ํ๊ธฐ ๋๋ฌธ์ ํ์ต, ์ถ๋ก ์ ๋ฐ๋ฅธ ๋์์ ์ง์ ํด์ค์ผ ํ๋ค.
- train()
- Batch Norm - Batch Statistics๋ฅผ ์ด์ฉํ๋ค.
- Dropout - ์ฃผ์ด์ง ํ๋ฅ ์ ๋ฐ๋ผ ํ์ฑํ๋๋ค.
- eval()
- Batch Norm - ํ์ต ์ ์ฌ์ฉ๋ Batch Statistics๋ฅผ ํตํด ๊ฒฐ์ ๋ Running Statistics๋ฅผ ์ด์ฉํ๋ค.
- Dropout - ๋นํ์ฑํ ๋๋ค.
with torch.no_grad()
{: width=โ800โ}{: .center}
- no_grad()๋ฅผ with statement์ ํฌํจ์ํค๋ฉด Pytorch์ Autograd Engine์ ๋นํ์ฑํํ์ฌ gradient ์ฐ์ฐ์ ์ฌ์ฉํ์ง ์๋ค.
- ์ด ๊ฒฝ์ฐ backward๋ฅผ ์ฌ์ฉํ์ง ์๊ณ require_grad=False ๊ฐ ๋๋ค.
- gradient ์ฐ์ฐ์ ํ์ง ์๊ธฐ ๋๋ฌธ์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์ค์ด๊ณ ์๋๋ฅผ ๋์ธ๋ค.
optimizer.zero_grad()
-
Pytorch์ ๊ฒฝ์ฐ gradient๊ฐ๋ค์ backwardํ ๋ ๊ณ์ํด์ ์ถ์ ํ๋ค.
-
zero_grad()๋ฅผ ํ์ง ์๋๋ค๋ฉด ์ด์ step์์ backwardํ ๋ ์ฌ์ฉ๋ gradient๊ฐ์ด ํ์ฌ step์์๋ ์ฌ์ฉ๋๊ธฐ ๋๋ฌธ์ ์ค๋ณต ์ ์ฉ๋๋ค.
โ ์ค๋ณต ์ ์ฉ๋๋ค๋ฉด ์๋ํ ๋ฐฉํฅ๊ณผ ๋ค๋ฅธ ๋ฐฉํฅ์ ๊ฐ๋ฅด์ผ ํ์ต์ด ์ ๋์ง ์๋๋ค.
โ ๊ฒฐ๋ก ์ ์ผ๋ก ๋งค ํ์ต iteration๋ง๋ค optimizer.zero_grad()๋ฅผ ํด์ฃผ๋ฉด์ gradient๋ฅผ 0์ผ๋ก ์ด๊ธฐํ ์์ผ์ค์ผ ํ๋ค.