- Double Backward with Custom Functions
+ Double Backward with Custom Functions
=====================================
+ **Translation**: `ParkKunsu <https://github.com/ParkKunsu>`_

- It is sometimes useful to run backwards twice through backward graph, for
- example to compute higher-order gradients. It takes an understanding of
- autograd and some care to support double backwards, however. Functions
- that support performing backward a single time are not necessarily
- equipped to support double backward. In this tutorial we show how to
- write a custom autograd function that supports double backward, and
- point out some things to look out for.
+ It is sometimes useful to run backward twice through the backward graph,
+ for example to compute higher-order gradients. Supporting double backward,
+ however, requires an understanding of autograd and some care. Functions
+ that support performing backward a single time are not necessarily
+ equipped to support double backward. In this tutorial we show how to write
+ a custom autograd function that supports double backward, and point out
+ some things to look out for.

- When writing a custom autograd function to backward through twice,
- it is important to know when operations performed in a custom function
- are recorded by autograd, when they aren't, and most importantly, how
- `save_for_backward` works with all of this.
+ When writing a custom autograd function that is backwarded through twice,
+ it is important to understand when the computations performed inside the
+ function are recorded by autograd and when they are not, and above all
+ how `save_for_backward` behaves throughout this process.

- Custom functions implicitly affects grad mode in two ways:
+ Custom functions implicitly affect grad mode in two ways:

- - During forward, autograd does not record any the graph for any
-   operations performed within the forward function. When forward
-   completes, the backward function of the custom function
-   becomes the `grad_fn` of each of the forward's outputs
+ - During forward, autograd does not record a graph for any operation
+   performed inside the forward function. When forward completes, the
+   backward function of the custom function becomes the `grad_fn` of each
+   of the forward's outputs.

- - During backward, autograd records the computation graph used to
-   compute the backward pass if create_graph is specified
+ - During backward, if create_graph is specified, autograd records the
+   computation graph used to compute the backward pass (a minimal
+   illustration follows this list).
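
For built-in operators this is all it takes: passing `create_graph=True` to the
first backward call records the backward pass so that it can itself be
differentiated. A minimal sketch of these two points (values and names here are
purely illustrative):

.. code-block:: python

    import torch

    x = torch.tensor(2.0, requires_grad=True)
    y = x ** 3

    # First backward: create_graph=True makes autograd record the operations
    # of the backward pass, so dy_dx gets its own grad_fn.
    dy_dx, = torch.autograd.grad(y, x, create_graph=True)   # 3 * x**2 -> 12.0

    # Second backward, through the graph recorded above.
    d2y_dx2, = torch.autograd.grad(dy_dx, x)                 # 6 * x -> 12.0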

- Next, to understand how `save_for_backward` interacts with the above,
- we can explore a couple examples:
+ Next, to understand how `save_for_backward` interacts with the above,
+ let's walk through a couple of examples:

- Saving the Inputs
+
+ Saving the Inputs
-------------------------------------------------------------------
- Consider this simple squaring function. It saves an input tensor
- for backward. Double backward works automatically when autograd
- is able to record operations in the backward pass, so there is usually
- nothing to worry about when we save an input for backward as
- the input should have grad_fn if it is a function of any tensor
- that requires grad. This allows the gradients to be properly propagated.
+ Consider this simple squaring function. It saves an input tensor for
+ backward. Double backward works automatically when autograd is able to
+ record the operations of the backward pass, so there is usually nothing to
+ worry about when we save an input for backward: if the input is a function
+ of any tensor that requires grad, it has a grad_fn, which allows the
+ gradients to be propagated correctly.

.. code:: python

@@ -64,7 +62,7 @@ that requires grad. This allows the gradients to be properly propagated.
    torch.autograd.gradgradcheck(Square.apply, x)

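The collapsed block above defines the custom squaring function; a minimal sketch
of such a function, matching the description above (not necessarily the exact
tutorial code), looks like this:

.. code-block:: python

    import torch

    class Square(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # Save the *input* for backward. Because the input has a grad_fn
            # whenever it is a function of a tensor that requires grad, the
            # backward computation below can itself be recorded by autograd.
            ctx.save_for_backward(x)
            return x ** 2

        @staticmethod
        def backward(ctx, grad_out):
            x, = ctx.saved_tensors
            # d(x**2)/dx = 2x; an ordinary tensor expression, so it is
            # trackable when backward runs with create_graph=True.
            return grad_out * 2 * x

    x = torch.rand(3, dtype=torch.double, requires_grad=True)
    torch.autograd.gradcheck(Square.apply, x)
    torch.autograd.gradgradcheck(Square.apply, x)
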
- We can use torchviz to visualize the graph to see why this works
+ We can use torchviz to visualize the graph and see why this works:

.. code-block:: python

@@ -75,18 +73,17 @@ We can use torchviz to visualize the graph to see why this works
    grad_x, = torch.autograd.grad(out, x, create_graph=True)
    torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out})

- We can see that the gradient wrt to x, is itself a function of x (dout/dx = 2x)
- And the graph of this function has been properly constructed
+ We can see that the gradient with respect to x is itself a function of x
+ (dout/dx = 2x), and that the graph of this function has been properly
+ constructed.

.. image:: https://user-images.githubusercontent.com/13428986/126559699-e04f3cb1-aaf2-4a9a-a83d-b8767d04fbd9.png
   :width: 400


- Saving the Outputs
+ Saving the Outputs
-------------------------------------------------------------------
- A slight variation on the previous example is to save an output
- instead of input. The mechanics are similar because outputs are also
- associated with a grad_fn.
+ A slight variation on the previous example is to save an output instead of
+ the input. The mechanics are similar because outputs are also associated
+ with a grad_fn.

.. code-block:: python

@@ -111,7 +108,7 @@ associated with a grad_fn.
    torch.autograd.gradcheck(Exp.apply, x)
    torch.autograd.gradgradcheck(Exp.apply, x)

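The collapsed block above implements exponentiation while saving the output
rather than the input; a sketch of what such an `Exp` function might look like
(illustrative; details may differ from the tutorial's exact code):

.. code-block:: python

    import torch

    class Exp(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            result = torch.exp(x)
            # Save the *output*. Outputs are associated with this function's
            # grad_fn, so the backward pass can still be tracked.
            ctx.save_for_backward(result)
            return result

        @staticmethod
        def backward(ctx, grad_out):
            result, = ctx.saved_tensors
            # d(exp(x))/dx = exp(x), i.e. the saved output itself.
            return grad_out * result

    x = torch.rand(3, dtype=torch.double, requires_grad=True)
    torch.autograd.gradcheck(Exp.apply, x)
    torch.autograd.gradgradcheck(Exp.apply, x)
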
- Use torchviz to visualize the graph:
+ Visualize the graph with torchviz:

.. code-block:: python

@@ -123,23 +120,22 @@ Use torchviz to visualize the graph:
   :width: 332


- Saving Intermediate Results
+ Saving Intermediate Results
-------------------------------------------------------------------
- A more tricky case is when we need to save an intermediate result.
- We demonstrate this case by implementing:
+ A trickier case is when we need to save an intermediate result.
+ We demonstrate this case by implementing:

.. math::
   \sinh(x) := \frac{e^x - e^{-x}}{2}

- Since the derivative of sinh is cosh, it might be useful to reuse
- `exp(x)` and `exp(-x)`, the two intermediate results in forward
- in the backward computation.
+ Since the derivative of sinh is cosh, it can be efficient to reuse
+ `exp(x)` and `exp(-x)`, the two intermediate results of the forward pass,
+ in the backward computation.

- Intermediate results should not be directly saved and used in backward though.
- Because forward is performed in no-grad mode, if an intermediate result
- of the forward pass is used to compute gradients in the backward pass
- the backward graph of the gradients would not include the operations
- that computed the intermediate result. This leads to incorrect gradients.
+ Intermediate results should not be saved directly and used in backward,
+ however. Because forward runs in no-grad mode, if an intermediate result of
+ the forward pass is used to compute gradients in the backward pass, the
+ backward graph of the gradients will not include the operations that
+ computed the intermediate result, which leads to incorrect gradients. The
+ workaround, shown below, is to also return the intermediates as outputs of
+ forward so that autograd tracks them.

.. code-block:: python

@@ -172,7 +168,7 @@ that computed the intermediate result. This leads to incorrect gradients.
    torch.autograd.gradgradcheck(sinh, x)

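The collapsed block above applies the trick of returning the intermediates
`exp(x)` and `exp(-x)` as extra outputs of forward so that they are tracked,
with a small wrapper exposing only the first output. A sketch of this pattern
(a reconstruction, not necessarily the exact tutorial code):

.. code-block:: python

    import torch

    class Sinh(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            expx = torch.exp(x)
            expnegx = torch.exp(-x)
            ctx.save_for_backward(expx, expnegx)
            # Returning the intermediates as outputs lets autograd track them.
            return (expx - expnegx) / 2, expx, expnegx

        @staticmethod
        def backward(ctx, grad_out, grad_out_exp, grad_out_negexp):
            expx, expnegx = ctx.saved_tensors
            grad_input = grad_out * (expx + expnegx) / 2
            # Gradients flowing in through the extra outputs must still be
            # accumulated for double backward to be correct.
            grad_input += grad_out_exp * expx
            grad_input -= grad_out_negexp * expnegx
            return grad_input

    def sinh(x):
        # Expose only the actual result; the extra outputs are an internal trick.
        return Sinh.apply(x)[0]

    x = torch.rand(3, dtype=torch.double, requires_grad=True)
    torch.autograd.gradcheck(sinh, x)
    torch.autograd.gradgradcheck(sinh, x)
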
- Use torchviz to visualize the graph:
+ Visualize the graph with torchviz:

.. code-block:: python

@@ -184,12 +180,11 @@ Use torchviz to visualize the graph:
   :width: 460


- Saving Intermediate Results: What not to do
+ Saving Intermediate Results: What Not to Do
-------------------------------------------------------------------
- Now we show what happens when we don't also return our intermediate
- results as outputs: `grad_x` would not even have a backward graph
- because it is purely a function `exp` and `expnegx`, which don't
- require grad.
+ Now we show what happens when we don't also return our intermediate
+ results as outputs: `grad_x` would not even have a backward graph,
+ because it is purely a function of `exp` and `expnegx`, which do not
+ require grad.

.. code-block:: python

@@ -211,8 +206,8 @@ require grad.
        return grad_input

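For contrast, the broken variant looks roughly like the sketch below
(illustrative only): the intermediates are saved but not returned, so the
gradient computed from them in backward carries no graph of its own.

.. code-block:: python

    import torch

    class SinhBad(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            expx = torch.exp(x)
            expnegx = torch.exp(-x)
            # The intermediates are computed in no-grad mode and are NOT
            # returned as outputs, so nothing connects them to the graph.
            ctx.save_for_backward(expx, expnegx)
            return (expx - expnegx) / 2

        @staticmethod
        def backward(ctx, grad_out):
            expx, expnegx = ctx.saved_tensors
            # grad_input depends only on expx/expnegx, which do not require
            # grad, so it has no backward graph; second-order gradients
            # computed through it would be incorrect.
            grad_input = grad_out * (expx + expnegx) / 2
            return grad_input
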
- Use torchviz to visualize the graph. Notice that `grad_x` is not
- part of the graph!
+ Visualize the graph with torchviz. Notice that `grad_x` is not
+ part of the graph!

.. code-block:: python

@@ -225,15 +220,13 @@ part of the graph!



- When Backward is not Tracked
+ When Backward Cannot Be Tracked
-------------------------------------------------------------------
- Finally, let's consider an example when it may not be possible for
- autograd to track gradients for a functions backward at all.
- We can imagine cube_backward to be a function that may require a
- non-PyTorch library like SciPy or NumPy, or written as a
- C++ extension. The workaround demonstrated here is to create another
- custom function CubeBackward where you also manually specify the
- backward of cube_backward!
+ Finally, let's look at a case where autograd cannot track gradients
+ through a function's backward at all. Imagine that cube_backward uses an
+ external library such as SciPy or NumPy, or is implemented as a C++
+ extension. The workaround in this case is to create another custom
+ function, CubeBackward, in which you also manually specify the backward
+ of cube_backward!


.. code-block:: python

@@ -280,7 +273,7 @@ backward of cube_backward!
    torch.autograd.gradgradcheck(Cube.apply, x)

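The collapsed block above uses the nested-custom-function pattern; a sketch of
how it might look, where `cube_backward` and its derivatives stand in for
opaque implementations (NumPy/SciPy or C++) that autograd cannot trace:

.. code-block:: python

    import torch

    # Pretend these are implemented outside PyTorch, so autograd
    # cannot see inside them.
    def cube_forward(x):
        return x ** 3

    def cube_backward(grad_out, x):
        return grad_out * 3 * x ** 2

    def cube_backward_backward(grad_out, sav_grad_out, x):
        return grad_out * sav_grad_out * 6 * x

    def cube_backward_backward_grad_out(grad_out, x):
        return grad_out * 3 * x ** 2

    class Cube(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return cube_forward(x)

        @staticmethod
        def backward(ctx, grad_out):
            x, = ctx.saved_tensors
            # Delegate to a second custom function so the backward of
            # cube_backward is specified manually as well.
            return CubeBackward.apply(grad_out, x)

    class CubeBackward(torch.autograd.Function):
        @staticmethod
        def forward(ctx, grad_out, x):
            ctx.save_for_backward(x, grad_out)
            return cube_backward(grad_out, x)

        @staticmethod
        def backward(ctx, grad_input):
            x, grad_out = ctx.saved_tensors
            dx = cube_backward_backward(grad_input, grad_out, x)
            dgrad_out = cube_backward_backward_grad_out(grad_input, x)
            # Return gradients in the same order as forward's inputs.
            return dgrad_out, dx

    x = torch.rand(3, dtype=torch.double, requires_grad=True)
    torch.autograd.gradcheck(Cube.apply, x)
    torch.autograd.gradgradcheck(Cube.apply, x)
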
- Use torchviz to visualize the graph:
+ Visualize the graph with torchviz:

.. code-block:: python

@@ -292,10 +285,9 @@ Use torchviz to visualize the graph:
   :width: 352


- To conclude, whether double backward works for your custom function
- simply depends on whether the backward pass can be tracked by autograd.
- With the first two examples we show situations where double backward
- works out of the box. With the third and fourth examples, we demonstrate
- techniques that enable a backward function to be tracked, when they
- otherwise would not be.
+ To conclude, whether double backward works for your custom function simply
+ depends on whether autograd can track the backward pass. The first two
+ examples show situations where double backward works out of the box, while
+ the third and fourth examples demonstrate techniques that make a backward
+ function trackable when it otherwise would not be.
