티스토리 뷰

반응형

증상

  • pytorch에서 loss.backward()를 호출하는 부분에서 아래와 같은 오류 발생
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

디버깅 방법

  • 아래와 같이 computation graph를 그려서 어느 부분부터 그래프가 중복되는지 확인한다.
from torchviz import make_dot
...
make_dot(loss, params=dict(net.named_parameters())).render(f"graph", format="png")

문제 해결

  • 문제는 computation graph가 중간에 겹쳐서 발생한 문제였다. computation graph가 겹치는 variable이 존재하였다.(기존 loop iteration에서 계산한 값을 다음 iteration에서 사용) 이런 variable은 index로 reference를 하지 않기 때문에 찾기가 극도록 곤란하였다.
  • 그래서 computation graph가 겹치는 변수들은 매 iteration마다 다시 계산하도록 코드를 수정하였고 정상 작동함을 확인하였다.
반응형
댓글