파이토치 모델이 CPU에서 너무 느릴때. torch.set_flush_denormal()
결론부터 말하자면, 파이토치 관련 코드의 시작 부분에 이렇게 쓰면 된다.
torch.set_flush_denormal(True)
딥러닝 모델 훈련 및 실사용시엔 고차원 텐서 연산을 병렬적으로 처리할 수 있는 GPU가 거의 필수적이지만, 가벼우면서 real-time estimation이 필수적이지 않은 모델은 CPU에서 실행하더라도 충분히 실사용이 가능한 경우도 있다.
그러나 훈련된 모델을 CPU에서 실행시켜보면 간혹 예상한 속도보다 비교가 안 될 정도로 느려진다. GPU에서 하나의 입력에 대해 예측값을 내기까지 걸리는 시간이 0.1초도 걸리지 않는 모델이 CPU에서 실행했더니 20초가 넘게 걸린다고 가정해보자. 아무리 실시간 예측을 포기한다고 하더라도 결과 하나를 얻기 위해 수 초 이상을 기다리는 경험은 상당히 불쾌할 것이다.
위에 적은 내용은 실제로 내가 경험한 일이다. 해당 모델은 몇개의 모듈 단위로 구성되어 있었고, 각 모듈마다 선형 레이어가 사용되었다. 모듈별로 입/출력값은 달랐지만 선형 레이어와 입/출력값의 사이즈는 모두 같았다. 나는 당연히 입력 텐서와 선형 레이어의 크기가 같은 모듈끼린 실행 시간이 비슷할 것이라고 예상했지만, 두 모듈 내에서 선형 레이어의 연산 속도를 비교 출력해본 결과 처리 속도가 무려 40배 이상 차이가 났다. 처음엔 둘다 동일한 CPU에서 실행됐기 때문에 CPU와 관련된 문제라곤 생각하지 못했다. 각각의 입력값과 가중치가 문제인가도 고민했지만, 수십만개의 가중치와 입력값을 일일히 비교해보는 것도 무리였다.
특정 개발자 커뮤니티에 질문한 후 가장 유력한 답변을 하나 받게 되었는데, 0에 가까운 수는 연산이 느리고, 파이토치에 이를 0으로 처리하는 옵션이 있다는 것이다. 덕분에 torch.set_flush_denormal의 존재에 대해 알게 됐고, 해당 모델의 실행속도는 20초에서 0.3초 남짓으로 대폭 줄어들게 되었다. (성능 변화도 거의 없었다.)
비정규값 (Denormalized numbers)
일반적으로 부동소수점 값을 표현할 때는 아래 예시와 같이 유효숫자의 첫자리를 1의자리에서 시작한다. 가수부의 표현 범위를 1이상 9이하, 즉 유효숫자를 일의 자리로 정규화했기 때문에 이를 정규값이라고 한다.
\[1.05 \times 10^{-2}\]같은 수라도 정규화해서 표현하지 않고 지수부를 다르게 해서 자유롭게 표현할 수 있다.
\(0.105 \times 10^{-1}\) 또는 \(10.5 \times 10^{-3}\) 도 모두 같은 수이다.
만약 0.000000000105라는 숫자를 부동 소수점으로 표현하고 싶은데, 지수부가 표현 가능한 자릿수를 8로 제한한다면 정규화 되지 않은 수로만 표현 할 수 있을 것이다. 이렇게 지수부 제한으로 인해 정규화되지 못한 작은 값들을 비정규값(denormalized number)으로 부른다.
\[0.0105 \times 10^{-8}\]컴퓨터는 부동 소수점을 부호, 지수부, 가수부를 통해 2진법으로 나타내게 되는데, 메모리 비트 수에 따라 지수부가 표현할 수 있는 수의 범위가 제한된다. 정규값에서 가수부의 맨 앞 비트가 1의자리를 표현하지만, 비정규값은 지수부가 모두 0으로 채워져 있고 가수부의 맨 앞 비트가 0.1의 자리를 나타낸다. 만약 부호가 양수이고 지수부 8비트가 모두 0, 가수부 비트가 0101000….인 비정규값이 있다면, 실제 값을 이렇게 표현할 수 있다.
\[0.0101_{(2)} \times 2^{-255}\]이렇듯 0에 가까운 작은 수가 비정규값으로 처리되고, 파이토치에서 이러한 비정규값 입력과 가중치들을 모두 0으로 일괄 처리하는 옵션을 통해 연산속도를 향상시킬 수 있다는 점을 알게 되었다. 그러나 비정규값이 정규값에 비해 더 많은 비트를 차지하는 것도 아닌데 어째서 연산속도는 미치도록 느린 것인지 이해가 되지 않아 이유를 좀 더 찾아보았다.
x86 CPU의 비정규값 처리
해당 스택 오버플로우 답변을 통해 알 수 있었다. 답변자는 x86 CPU의 설계 경험이 있다고 한다… 고인물이다.
비정규값 연산 중 가수부와 지수부는 각각 다음과 같은 과정을 거친다.
- 가수부 비트는 left-shift 연산으로 정규화되고 연산 후엔 다시 right-shift로 변환된다.
- 지수부는 레지스터나 메모리에 적재될때는 32비트 중 8비트로만 제한적으로 표현되지만 연산과정 자체는 비트수의 제약을 받지 않는다.
간단한 shift연산과 정수연산… 이것만 본다면 딱히 느릴 이유가 없어보인다.
원인은 아키텍쳐 설계 원칙에 있다. CPU가 주로 처리하는 값의 대부분은 정규값이다. 자주 사용하는 연산을 더 빠른 회로에서 처리하고 덜 사용되는 연산은 상대적으로 더 오래 걸리는 회로에 배치하는 설계원칙에 의해 비정규값 연산의 우선순위가 뒤로 밀린 것이다.
해당 수가 정규값인지 비정규값인지부터 판단하고 연산을 하는 회로에선 정규값 계산시 50%의 추가 지연 시간이 생기기 때문에, x86 CPU는 모든 연산을 정규값 연산으로 처리해버리고 이후 비정규값으로 인해 예외가 발생하면 예외처리 후 뒤늦게 비정규값 연산을 수행하는 구조로 설계되었다고 한다. 때문에 정규값 연산은 클럭 주파수 3~6 사이클 수준이지만 비정규값 연산은 100 사이클 가까이 걸린다고.
결국 비정규값의 연산 그 자체가 오래 걸리는 건 아니지만 microcode exception handler까지 도달했다가 나오는 시간이 문제였다.
반대로 GPU는 비정규값을 처리하기 위한 파이프라인을 추가로 구축함으로써 정규값 연산속도에 약간의 trade-off가 존재하지만 덕분에 비정규값을 거의 속도 저하 없이 처리할 수 있다고 한다.
파이토치에서 해당 옵션 사용시 주의사항
SSE3 명령어셋을 지원하는 x86 CPU 또는 x64 CPU 에서만 사용 가능하다. (x64(64bit)는 x86(32bit)의 하위 호환성을 보장한다.)
CPU 아키텍쳐 확인
$ arch
>>> x86-64
SSE3 지원 확인
$ grep 'sse3\|pni' /proc/cpuinfo > /dev/null
if [ $? -eq 0 ]; then
echo "Supported!"
else
echo "Not supported!"
fi
>>> "Supported!"
arm등 다른 아키텍쳐의 CPU에서도 비정규값 처리 기능이 있고, 더 자세히 관련 내용에 대해 알아보려면 ‘Flush To Zero’라는 키워드로 검색해보면 된다.
출처
'computer_science' 카테고리의 다른 글
더보기파이토치 모델이 CPU에서 너무 느릴때. torch.set_flush_denormal() | 2021. 12. 13 |
---|