Transcript:

What’s up, guys? In this video, I want to give you a deep understanding of how hooks work in PyTorch. The main reason hooks exist in PyTorch is so that you can inject code into parts of the computational flow that you otherwise wouldn’t have access to, or that would be hard to reach, and I think this will make more sense once we get into the examples. There are two types of hooks you can use: ones you can add to tensors and ones you can add to modules. The first type we’re going to look at are the ones you can add to tensors. These are hooks that allow you to access the gradients as they flow through your backward graph.
Let’s go through an example so we’re on the same page about how autograd works in PyTorch. If you want a deeper understanding of this, I have a separate video, and I’ll link to it in the description. We’ll start off our example by creating two tensors named a and b. a will have a value of 2, b will have a value of 3, and they’ll both require a gradient. Over here in the diagram, this will be the name of the tensor, and these will be the properties on that tensor. There are also other properties on these tensors that are hidden, and I’ll just show the ones that I think are important for this example.
Next, we’ll multiply a times b to get c, which will have the value 6. When we multiply a and b together, we start building our backward graph, so we create a node here called MulBackward0. I’m not sure why the zero is in the name, but that’s just what it is when you inspect the graph, and if you know why the zero is there, let me know in the comments; I’m interested. We also create two AccumulateGrad nodes, one for tensor a and one for tensor b, and these just accumulate the gradients as they flow backward through the graph so that they can be stored on these leaf tensors. A leaf tensor that’s part of the graph is one that not only requires a gradient but also wasn’t created by any previous operation that had input tensors requiring a gradient. For example, this c tensor was created by an operation that took in at least one tensor that required a gradient (in fact, in this case both of them do), so it isn’t a leaf node. It’s an intermediate node, and that’s why it gets a grad_fn property that points to a node in the backward graph.
The reason it has a pointer to this node is that if we call c.backward() to start propagating a gradient through our backward graph, the backward pass starts by passing the gradient to the node specified on the grad_fn property of the c tensor. If we want to start off with a gradient other than one, we can specify that as an argument in our backward call, but if no gradient is specified, it starts out with a default gradient value of one that gets passed to this MulBackward0 node. The node passes it through its backward method. I didn’t spell it out here, but basically it multiplies this gradient by the necessary values to get the gradient with respect to each input of the operator. In this case, because a was multiplied by 3, the gradient with respect to a is 3, so it multiplies the incoming value by 3 and passes it to the AccumulateGrad node for the a tensor. Similarly, since b is multiplied by a, and a has a value of 2, we multiply the incoming value by 2 to get the gradient with respect to b and pass that to the AccumulateGrad node for the b tensor. The AccumulateGrad nodes then just assign the gradient that’s passed to them to the grad property on the tensors.
But I want to make this example a little more elaborate so we can get a deeper understanding of what’s actually happening as we build our backward graph and as we pass gradients through it. So instead of calling c.backward() here, we’ll create a new tensor d with the value 4. We’ll then multiply c times d to get e. When we do the second multiplication, we create a second MulBackward0 node, and we also create this e tensor, which has the value 24 and has a grad_fn property that points to this new MulBackward0 node. We also create this AccumulateGrad node for the d tensor. The next_functions on these MulBackward0 nodes just point to the nodes that the gradient gets passed to next; there are two entries, one for each input, because the operation takes in two inputs.
So now if we call e.backward(), we start the backward pass from the e tensor. It looks at the grad_fn property on this tensor, which is this MulBackward0 node, and starts by passing it the default gradient of one. We then go into the backward method of this node, which computes the gradient of the output of this multiply with respect to its inputs, in this case 4 and 6, and multiplies those values by the incoming gradient of 1, so we just get the outputs 4 and 6. The 6 gets passed to this AccumulateGrad node and is then stored on the grad property of the d tensor. The 4 gets passed along to this MulBackward0 node, and just as before, it multiplies this incoming gradient by 3 and by 2 and sends the results to these AccumulateGrad nodes, which then save those gradients on the grad property of the a and b tensors.
So here’s the example without hooks. You can see we have a forward graph and a backward graph, and if we want to update the forward graph, or if we want to inspect any of these tensors as they’re being computed, we can just write that here.
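For reference, the hook-free example walked through above can be sketched roughly as follows. This is a minimal sketch that assumes a working PyTorch install and uses the tensor names from the diagram; the exact code shown in the video may differ.

```python
import torch

# Leaf tensors: created directly and requiring a gradient.
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a * b                      # intermediate tensor, value 6
d = torch.tensor(4.0, requires_grad=True)
e = c * d                      # intermediate tensor, value 24

# Leaves have no grad_fn; intermediate tensors point at a backward node.
print(a.is_leaf, c.is_leaf)                 # True False
print(type(c.grad_fn).__name__)             # MulBackward0

# next_functions lists the nodes each outgoing gradient is passed to.
print([type(fn).__name__ for fn, _ in e.grad_fn.next_functions])

e.backward()                   # starts from a default gradient of 1

# By default only the leaf tensors end up with a stored gradient.
print(a.grad, b.grad, d.grad)  # tensor(12.) tensor(8.) tensor(6.)
```

The printed gradients match the walkthrough: 4 * 3 for a, 4 * 2 for b, and 6 for d.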
We can write print(a) or print(b) to print out the values of these tensors, or we can add in more computations if we want to change the forward computational graph. But once we call e.backward(), this whole computation of the gradient getting passed through these nodes is inaccessible to us. We can’t really inspect the gradients as they flow through it, or change them if we want to, and we’re only able to see the gradients that were output to the leaf nodes. That’s where hooks on tensors come in: they allow us to inspect the gradients as they flow backward through the graph and potentially change them if we want to.
So this is the same example as before, but I’ve added hooks. I’ll walk through what happens to the graph as you add the hooks, and then I’ll show how the gradients are computed as they flow backward through the graph with the hooks in place. Right here is where we add the first hook: we call c.register_hook and pass it a function that takes in a gradient and optionally returns a new gradient. If you don’t return anything from this function, it’ll just use the same gradient as before and pass it along. When we register this hook, it first gets added to the backward hooks on this c tensor. This is an ordered dictionary, so it matters what order you add the hooks to the tensor, because in the backward graph they’ll get called in that order. So right here we have c_hook as the first hook.
Next, we’ll register another hook. This time we’re just passing it a lambda function that takes in a gradient and prints it out. It returns the value from the print function, which is actually None, so it won’t change the gradient. It’ll just print it out and continue using the previous gradient in the backward graph, and you can see here that the lambda function has been added to the tensor’s backward hooks.
Next we’ll call c.retain_grad(), and this is what you would want to do if you want to store the gradient on an intermediate node. In this example, a, b and d are leaf nodes, and they’re the only nodes that get gradients stored on them by default by these AccumulateGrad nodes. If we want a gradient to be stored on an intermediate node, we call retain_grad, and what that does is register a backward hook that is the retain_grad function. Whenever that function gets called, whatever gradient was passed to it will be stored on the grad property of this tensor.
Next we’ll create the d tensor, and then we’ll register a hook on the d tensor. Here it’s just a lambda function that takes in the gradient and adds 100 to it. Because it’s returning a gradient, it will replace the gradient that was passed to it, and you can see that lambda function was added over here to the d tensor’s backward hooks.
Something to note about how the internals of the backward hooks system work is that there’s a difference between adding a hook to an intermediate node and adding one to a leaf node. When you add a hook to a leaf node, it just gets added to its backward hooks ordered dictionary. But when you add a hook to an intermediate node, the first time you add a hook to its backward hooks ordered dictionary, you also notify the node in the backward graph associated with this tensor, in this case this MulBackward0 node, that this tensor has backward hooks. So you add this tensor’s backward hooks to this backward node’s list of pre-hooks, and these are the hooks that will get called before the gradient is passed to its backward method.
After multiplying c and d to get e, we’ll then call e.retain_grad(), and this, as before, will add the retain_grad hook to the e tensor’s backward hooks. It will also notify this backward node that this tensor has hooks by adding these backward hooks to this backward node’s pre-hooks. We’ll then register another hook on this e tensor.
It’s just a lambda function that takes in the gradient and multiplies it by two, replacing the original gradient. We’ll then call e.retain_grad() again, and I wanted to do this just to show you that when you call retain_grad a second time on a tensor, it’s basically a no-op. It doesn’t have any effect. This method call will actually check to see if that tensor already retains a gradient, and if so, it will just ignore the second call. So as we can see here, just the single retain_grad hook and this lambda function are added.
Now, with all these hooks in place, let’s walk through the backward pass. We start by calling e.backward(), which starts by passing this gradient of 1 to this MulBackward0 node. It then gets passed through this node’s pre-hooks, in this case the e tensor’s backward hooks over here. We go through these backward hooks one at a time. First, we go through the retain_grad hook, which takes this gradient and saves it to the grad property of this tensor, so now it’s 1. We then continue passing this gradient of 1 into this lambda function, which takes in the gradient and multiplies it by two, so the gradient is now 2. It has now exited these pre-hooks, and it’s passed through the backward method on this node. As before, the node takes in that gradient of 2 and multiplies it by 4 for this output and by 6 for this output.
This 12 gets passed to this AccumulateGrad node. The AccumulateGrad nodes will always check to see if the tensor associated with the node has any backward hooks, and if it does, they’ll pass the gradient through those hooks before saving it to the grad property on that tensor. In this case we only have one hook, and it’s this lambda function, which, if you remember, was this one over here. It just takes in that gradient and returns a new gradient, which is the original plus 100. So 12 plus 100 is 112, and then that gradient gets stored on this tensor.
This gradient of 8 then gets passed to this MulBackward0 node, and it goes through its pre-hooks, which are the c tensor’s backward hooks. It starts off by going through this c_hook function, which first prints the value, so it prints out 8, and then returns a new gradient, which is the original plus two, so 8 plus 2 is 10. That value of 10 then gets passed on to the next hook, which is just this lambda function, so we print out 10, and since we don’t return any value, the original gradient gets passed along. So 10 gets passed to this retain_grad hook function, which will just store the value of 10 on this c tensor.
One thing to note here is that on the c tensor we called retain_grad at the end of our hooks, and on the e tensor we called retain_grad at the beginning of the hooks. That’s something to keep in mind if you’re using retain_grad with hooks that change the gradient: the gradient that actually gets saved to that tensor will depend on where the retain_grad is in your list of hooks. So for example, here on e, the gradient was initially 1, and it stored that gradient before passing it to this lambda function, which multiplied it by 2. And up here on c, we were passed in a gradient of 8, and this c_hook function adds 2 to it, and that happens before we retain the gradient, so here we retain the value of 10.
Now, continuing: we’ve finished these pre-hooks and we have the value of 10. It’s then passed into this node’s backward method, which as before multiplies that gradient by 3 and passes it to this AccumulateGrad node, now with a value of 30, and it also multiplies that gradient of 10 by 2 and passes it to this other AccumulateGrad node. These AccumulateGrad nodes check whether these tensors have any hooks, and they don’t, so they just store the values directly on the grad property of these tensors.
Now, you may have noticed that each of these hooks has a unique key in the backward hooks ordered dictionary. This key is actually the ID of the handle to this hook.
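The hooked example from this walkthrough can be sketched as below. It is a reconstruction from the description, not the code shown on screen, and the print labels are made up for readability. One caveat, stated as an assumption: I believe newer PyTorch releases record a retained gradient only after the tensor's other hooks have run, regardless of registration order, so e.grad may come out as 2 rather than the 1 described above. The sketch therefore leaves e.grad without an expected value.

```python
import torch

def c_hook(grad):
    print("c_hook received", grad)
    return grad + 2            # a returned value replaces the gradient

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a * b

c.register_hook(c_hook)
c.register_hook(lambda grad: print("lambda received", grad))  # returns None
c.retain_grad()                # keep the gradient on this intermediate tensor

d = torch.tensor(4.0, requires_grad=True)
d.register_hook(lambda grad: grad + 100)

e = c * d
e.retain_grad()
e.register_hook(lambda grad: grad * 2)
e.retain_grad()                # second call has no further effect

e.backward()

print(d.grad)  # 2 * 6 + 100 -> tensor(112.)
print(c.grad)  # 2 * 4 + 2   -> tensor(10.)
print(a.grad)  # 10 * 3      -> tensor(30.)
print(b.grad)  # 10 * 2      -> tensor(20.)
print(e.grad)  # depends on when the retained gradient is recorded
```

The gradients on a, b, c and d do not depend on that ordering detail, since retain_grad on c is registered after the hooks that modify its gradient.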
And whenever you call register_hook, it will actually return the handle to this hook. So for example, if here, when we call c.register_hook, we save the output of that to the variable h, we can later call h.remove(), and that will remove the hook from whichever dictionary it’s stored in. So in this example, before, we had c_hook, but after we remove it, it’s no longer in this tensor’s hooks, and it will no longer get called during the backward pass.
Another thing to keep in mind is that in your hook functions, you don’t actually want to change the gradient tensor that’s passed in. What I mean by this is that you don’t want to perform any in-place operations on those tensors, and the reason is that this gradient may actually be passed to other parts of your backward graph as well. Just to show a quick example of that, take our previous example, and instead of multiplying c and d to get e, just add c to d to get e. In this case, if we call e.backward(), we start with a gradient of one. It gets passed to this backward method, which in this case just passes that gradient along in the backward graph. This is because when you pass a gradient backward through an addition, the same gradient gets passed to the c tensor and to the d tensor. So now if we add a hook to this d tensor over here, with d.register_hook(d_hook), and this function performs an in-place operation, not only will it change the gradient that gets passed to this AccumulateGrad node, but it’ll also change the gradient that gets passed to this MulBackward0 node. It ends up having the effect of multiplying all the gradients in our leaf nodes by 100, which is probably a side effect you did not intend when using this in-place operation.
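A small sketch of the in-place pitfall, comparing a hook that mutates its argument with one that returns a new tensor. The helper run and the two hook names are hypothetical, introduced only for this illustration. Whether a and b are affected by the in-place version depends on the backward pass handing the same gradient tensor to both branches, as described above, so the sketch states an expected result only for the out-of-place version.

```python
import torch

def run(d_hook):
    """Build e = a * b + d with d_hook registered on d; return leaf gradients."""
    a = torch.tensor(2.0, requires_grad=True)
    b = torch.tensor(3.0, requires_grad=True)
    d = torch.tensor(4.0, requires_grad=True)
    d.register_hook(d_hook)
    e = a * b + d              # addition passes its incoming gradient to both inputs
    e.backward()
    return a.grad.item(), b.grad.item(), d.grad.item()

def in_place_hook(grad):
    grad.mul_(100)             # mutates the tensor that was passed in; returns None

def out_of_place_hook(grad):
    return grad * 100          # builds a new tensor; the incoming one is untouched

risky = run(in_place_hook)     # d's gradient is 100; a and b may be scaled as well
safe = run(out_of_place_hook)
print(risky)
print(safe)                    # (3.0, 2.0, 100.0)
```

Returning a new tensor keeps the change local to d, which is usually what was intended.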
You can see that we didn’t even return the gradient from this function, so it shouldn’t have updated the gradient for this d tensor at all. But because nothing was returned from this function, it tries to use the same gradient as before, and because that gradient was changed by the in-place operation, the gradient that gets stored on this tensor is the updated value. So in most cases, you should just avoid in-place operations unless you’re strictly memory bound and you’re trying to do something to get around that memory issue. In that case, just be aware that this is a potential issue, and make sure that the gradient that’s getting passed to your hook function isn’t getting passed anywhere else in your backward graph.
Now, those were the hooks on tensors, and those were actually the much more complicated hooks. We’ll now look at the hooks on modules, and these will be a lot easier to understand. First of all, a typical module will have a forward method. Here we’ve just taken three inputs; we add them together and return the output. The hooks you have on modules let you add a function that gets called before this forward method, or a function that gets called after this forward method. So it’s pretty straightforward, but there are some unique things about these hooks that I wanted to point out.
In this example, we’ll start with a module that takes in three values, adds them together and returns the result. We’ll instantiate an instance of this module. We’ll then register a hook that will be called before this forward method, and that’s done by calling register_forward_pre_hook. We’ll then register a hook that gets called after this forward method, and that’s just done with register_forward_hook. These hooks will be these functions here. We’ll then create three tensors with the values one, two and three and pass them in to our module.
So the first thing to note is that we pass in a and b as positional arguments and c as a keyword argument, and that actually makes a difference in these hook functions. The values that get passed to this forward pre-hook function are the module instance, which will be this SumNet instance, as well as the positional input arguments. So in this case, it’ll be a tuple of the values a and b. We can unpack that tuple, and if we want to change the inputs that get passed to the forward method, we can return another tuple here. In this case, we add 10 to the value of a and keep b the same. So a and b are the values 1 and 2 here, but we’re going to return the values 11 and 2, and those values then get passed to the forward method call. We now have the values 11, 2 and 3. We add those together to get 16 and return that value here.
After that, this forward hook function gets called. We’ll have the same module instance as before in this function, and we’ll also receive the inputs to the forward method call. Now, these inputs may be different from the inputs the pre-hook saw, and in this case they are, because these are going to be the values returned by the pre-hook function. So up here they were the values 1 and 2, and down here they’ll be the values 11 and 2. We also get passed the output from the forward method call, and if we want, we can return a new value, which will override the output returned from this module. In this case, we add 100 to the original output value of 16 and get the value 116. So down here with d, we can print it out and confirm that that’s what happened, and here we see a tensor of 116.
Just as with the hooks on tensors, you can also save the value returned when you register the hook, and that will be a handle to the hook so that you can remove it later if you want. So here, the first time we run it through, we’ll get the value of 116, and after we remove the hooks, we should just get the value of 6.
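The module example just described might look something like this. It is a sketch reconstructed from the narration: the hook function names and d_plain are placeholders, and only SumNet is named in the video.

```python
import torch
import torch.nn as nn

class SumNet(nn.Module):
    def forward(self, a, b, c):
        return a + b + c

def forward_pre_hook(module, inputs):
    a, b = inputs              # positional arguments only; c was passed by keyword
    return a + 10, b           # returned tuple replaces the positional inputs

def forward_hook(module, inputs, output):
    # inputs here are the (possibly modified) positional arguments
    return output + 100        # returned value replaces the module output

net = SumNet()
pre_handle = net.register_forward_pre_hook(forward_pre_hook)
post_handle = net.register_forward_hook(forward_hook)

a = torch.tensor(1.0)
b = torch.tensor(2.0)
c = torch.tensor(3.0)

d = net(a, b, c=c)
print(d)                       # tensor(116.)

# Each registration returned a handle that can remove its hook later.
pre_handle.remove()
post_handle.remove()
d_plain = net(a, b, c=c)
print(d_plain)                 # tensor(6.)
```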
That’s just 1 plus 2 plus 3, and if we look at the output of running this, we’ll see that’s what we get.
Now, there is one more type of hook that you can use on a module, the backward hook, but currently it’s somewhat broken, so I’d recommend not using it at all. You can see it listed here in the API docs, and they have a warning message that says it may not work as expected with complex modules. There’s also this issue open on GitHub about it, if you want to learn more about how it’s currently broken, or if you want to see whether it’s finally been fixed. But if you look at Adam’s comment here, he says no, this has never been fixed, and I think it’s a won’t-fix for the foreseeable future. So I’m not sure when this will be fixed, but there is a workaround, and that is to just use tensor hooks on the tensors you pass into your module and on the tensor output from the module. To quickly demonstrate how it’s currently broken and how we can work around that, we’ll start with this example. It’s just a module that takes in two values, multiplies them together and returns the result, so we’ll start by making an instance of this module.
We’ll register the backward hook, which is this function here. It will be passed an instance of the module, the gradients with respect to the input and the gradients with respect to the output, and we’ll just print them out here. We’ll make two tensors with the values two and three, pass them into our module, and then call c.backward() on the output. In this case, it works as expected: the gradient with respect to the output is one, and the gradients with respect to the inputs a and b are three and two. So this is a case where the backward hook will work, because it’s a simple module.
But suppose we update this module so that it now takes in three values, multiplies all of them together and returns that result, with everything else the same. We’ll now make three tensors with the values two, three and four, pass them in to our instance of this module, and then call the backward method on the output. Now we start to see the unusual behavior. The gradient with respect to the output is still correct; it starts off with just one. But now the gradients with respect to the input are the values four and six, and we can see that something went wrong. Not only did we give it three input values and get only two values back, but this value of four isn’t the correct gradient for any of our inputs.
In the current implementation, what’s actually happening is that this is two separate operations, as we saw before. First we multiply a times b, and then we multiply the result of that times c, and the gradients we’re seeing in our hook function are actually the gradients for the last operation in this module. So here we take in 2 and 3 and multiply them together to get 6, and then we multiply together 6 and 4 to get our output, and that’s where we’re getting these gradients of 4 and 6.
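The workaround mentioned earlier, putting tensor hooks on the module's inputs and output instead of using the module backward hook, can be sketched as follows. The module name MulNet, the grads dictionary and the save_grad helper are placeholders chosen for this illustration, and the hooks record the gradients rather than printing them.

```python
import torch
import torch.nn as nn

class MulNet(nn.Module):       # hypothetical name for the three-input module
    def forward(self, a, b, c):
        return a * b * c

grads = {}

def save_grad(name):
    """Return a tensor hook that records the gradient under the given name."""
    def hook(grad):
        grads[name] = grad.item()   # returns None, so the gradient is unchanged
    return hook

net = MulNet()
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = torch.tensor(4.0, requires_grad=True)

# Hooks on the tensors going into the module...
a.register_hook(save_grad("a"))
b.register_hook(save_grad("b"))
c.register_hook(save_grad("c"))

# ...and on the tensor coming out of it.
d = net(a, b, c)
d.register_hook(save_grad("d"))

d.backward()
print(grads)   # d: 1.0, c: 6.0, a: 12.0, b: 8.0 (insertion order may vary)
```

Unlike the module-level hook, every input gets its own correct gradient here, and a hook could also return a modified gradient if needed.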
So what we can do instead is just add the hooks directly to the tensors that are input to the module and returned from the module. Here, we’ll register a hook on a, b, c and d, each of which just prints out the gradient with respect to that tensor. So now when we call d.backward(), we’ll see that the gradient with respect to the output, the d tensor, is the value one, which is correct, and the gradient with respect to c is 6, with respect to a is 12, and with respect to b is 8. And if we wanted to change the gradient as it flows backward through this graph, we could also do that in these hook functions.
So I hope this video gave you a better intuition for how hooks work in PyTorch, and I want to thank Azalee Razag and Nikunj for suggesting that I make this video. If you have any other suggestions for videos you want me to make, let me know in the comments. As always, links to the code in the video are in the description, and I’ll see you guys next time.