๋ณธ๋ฌธ ๋ฐ”๋กœ๊ฐ€๊ธฐ
Study/Deep Learning

[Deep learning] Class-Incremental Learning (LwF, PODNet)

by ์œ ๋ฏธ๋ฏธYoomimi 2024. 7. 8.

 ๋ชจ๋“  ๋ฐ์ดํ„ฐ๋ฅผ ํ•œ ๋ฒˆ์— ์ €์žฅํ•˜๊ณ  ํ•™์Šตํ•˜๋Š” ๊ฒƒ์€ ๋น„ํšจ์œจ์ ์ด๊ณ  ์–ด์ฉŒ๋ฉด ๋น„ํ˜„์‹ค์ ์ด๋‹ค. ํŠนํžˆ, ๋ฐ์ดํ„ฐ๊ฐ€ ๋งค์šฐ ํฌ๊ฑฐ๋‚˜ ๋ฏผ๊ฐํ•œ ์ •๋ณด๋ฅผ ํฌํ•จํ•˜๋Š” ๊ฒฝ์šฐ์—๋Š” ๋”์šฑ! ๊ทธ๋ž˜์„œ ๊ณ ์•ˆ๋œ Class-Incremental Learning (CIL)์€ ๋ชจ๋ธ์ด ์‹œ๊ฐ„์ด ์ง€๋‚จ์— ๋”ฐ๋ผ ์ ์ง„์ ์œผ๋กœ ์ƒˆ๋กœ์šด ํด๋ž˜์Šค๋ฅผ ํ•™์Šตํ•˜๋Š” ํ•™์Šต๋ฒ•์ด๋‹ค. ์ „ํ†ต์ ์ธ ํ•™์Šต ๋ฐฉ์‹์—์„œ๋Š” ๋ชจ๋“  ํด๋ž˜์Šค๋ฅผ ํ•œ ๋ฒˆ์— ํ•™์Šตํ•˜์ง€๋งŒ, CIL์—์„œ๋Š” ๋ฐ์ดํ„ฐ๊ฐ€ ์ ์ง„์ ์œผ๋กœ ์ œ๊ณต๋˜๋ฉฐ ๋ชจ๋ธ์ด ์ƒˆ๋กœ์šด ํด๋ž˜์Šค๋ฅผ ํ•™์Šตํ•  ๋•Œ ์ด์ „์— ํ•™์Šตํ•œ ๋‚ด์šฉ์„ ์žŠ์ง€ ์•Š๋„๋ก ํ•˜๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค. ์ด๋Š” CIL์ด ๋‹ค๋ฃจ๋Š” Catastrophic Forgetting ๋ฌธ์ œ๋ผ๊ณ ๋„ ๋ถˆ๋ฆฌ๋Š”๋ฐ, ์ด๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•œ ๋‹ค์–‘ํ•œ ๋ฐฉ๋ฒ• ์ค‘ LwF์™€ PODNet์„ ๊ฐ„๋‹จํžˆ ์†Œ๊ฐœํ•ด ๋ณด๊ฒ ๋‹ค.

๋‘˜์„ ๊ตฌํ˜„ํ•œ ipynb ํŒŒ์ผ์„ ์•„๋ž˜ Github repo์—์„œ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค.

 

https://github.com/yoomimi/Class-Incremental-Learning

 


 

[CIL ๋ฐฉ๋ฒ•๋ก ]

  • Regularization : learn the new task while keeping the parameters learned on previous tasks as unchanged as possible
  • Distillation : distill the knowledge learned on previous tasks into the network for the new task
  • Distillation + Memory : keep a small memory of data from previous tasks and use it when learning the new task
  • Distillation + Memory + Bias correction : treat the bias toward the new task as the main problem and focus on correcting it
  • Distillation + Memory + Dynamic structure : use a network structure that can adapt per task
  • Distillation + Memory + Generative model : replay data from previous tasks using a generative model
  • Dynamic structure : use pruning / masking etc. to assign the parameters or subnetworks each task may use

 

1. LwF: Learning Without Forgetting (ECCV 2016) https://arxiv.org/abs/1606.09282

 

LwF๋Š” Distillation์„ ์ด์šฉํ•œ CIL์ด๋‹ค.

์ด ๋…ผ๋ฌธ์—์„œ๋Š” ์„ธ ๊ฐ€์ง€ ์ฃผ์š” ์ ‘๊ทผ ๋ฐฉ์‹์„ ๋น„๊ตํ•˜๋Š”๋ฐ,

  1. Feature Extraction: extract features using the activations of the old network's last hidden layer
  2. Fine-Tuning: fine-tune the old network's parameters for the new task (standard fine-tuning)
  3. Joint Training: optimize all parameters on all tasks together
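These three baselines can be seen as different choices of which parameters to train. A minimal PyTorch sketch under that framing (the helper name and the backbone/head split are hypothetical, for illustration only):

```python
import torch.nn as nn

def select_trainable(backbone, old_head, new_head, strategy):
    """Toggle requires_grad to mimic the three baselines (hypothetical helper).
    - 'feature_extraction': freeze backbone and old head; train only the new head
    - 'fine_tuning': train backbone and new head; the old head stays frozen
    - 'joint': train everything (needs old-task data, which CIL does not assume)
    """
    for p in backbone.parameters():
        p.requires_grad = strategy != "feature_extraction"
    for p in old_head.parameters():
        p.requires_grad = strategy == "joint"
    for p in new_head.parameters():
        p.requires_grad = True  # the new task's head is always trained
    modules = (backbone, old_head, new_head)
    return [p for m in modules for p in m.parameters() if p.requires_grad]
```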

LwF๋Š” ์ด๋Ÿฌํ•œ ์ ‘๊ทผ ๋ฐฉ์‹์„ ํ˜ผํ•ฉํ•ด ์ƒˆ๋กœ์šด task์˜ class์™€ ๊ธฐ์กด task์˜ class๋กœ ๋‚˜๋ˆ ์„œ loss๋ฅผ ์ •์˜ํ–ˆ๋‹ค. 

์šฐ์„  new loss๋Š” classification์— ๋งŽ์ด ์“ฐ์ด๋Š” cross entropy loss๋ฅผ ์ผ๋‹ค. old loss๋Š” ์žŠ์œผ๋ฉด ์•ˆ๋˜๋‹ˆ Distillation loss๋ฅผ ํ™œ์šฉํ–ˆ๋‹ค. ( -> Less-forgetting Learning in Deep Neural Networks (LFL)์—์„œ ์ฒ˜์Œ ๊ณ ์•ˆ๋จ.) (y๋Š” softmax ์ถœ๋ ฅ๊ฐ’์ด๋‹ค.)

LFL๊ณผ ๋ญ๊ฐ€ ๋‹ค๋ฅด๋ƒ? LFL์—์„œ๋Š” ๋งˆ์ง€๋ง‰ Feature์— ๋Œ€ํ•ด Distillation Loss๋ฅผ ์ ์šฉํ•œ๋ฐ ๋ฐ˜ํ•ด LwF์—์„œ๋Š” Softmax ์ถœ๋ ฅ๊ฐ’์— ๋Œ€ํ•ด Distillation Loss๋ฅผ ์ ์šฉํ–ˆ๋‹ค.

 

 

2. PODNet: Pooled Outputs Distillation for Small-Tasks Incremental Learning (ECCV 2020) https://arxiv.org/abs/2004.13513

 

PODNet์€ Distillation + Memory + Bias correction์„ ์ด์šฉํ•œ CIL ๋ฐฉ๋ฒ•๋ก ์ด๋‹ค. POD Distillation loss๋ฅผ ์ œ์•ˆํ•˜๊ณ  classifier๋Š” Local Similarity Classifier๋กœ ๋ณ€๊ฒฝํ–ˆ๋‹ค.

 

 

1) Pooled Outputs Distillation loss

์ €์ž๋“ค์€ width pooling๊ณผ height pooling์„ ๋ฌถ์–ด POD-spatial์ด๋ผ๊ณ  ์ •์˜ํ•˜๊ณ  ์ด๋ ‡๊ฒŒ Pooling๋œ Feature์— ๋Œ€ํ•œ L2 Loss๋ฅผ ์ทจํ•ด POD-spatial loss๋ฅผ ์„ค์ •ํ–ˆ๋‹ค. ์—ฌ๊ธฐ์„œ ์ž ๊น ๋ณด๋Š” pooling ์ข…๋ฅ˜.

 

final embedding์—๋Š” POD-flat์ด๋ผ๋Š” ๋ฐฉ์‹์˜ flatten์„ ์ ์šฉํ–ˆ๋‹ค.

 

๊ฒฐ๊ตญ ์ตœ์ข… POD Distillation loss๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค.

 

 

 

2) Local similarity classifier 

 ๊ธฐ์กด์˜ UCIR (Unified Class Incremental Learning via Rebalancing)์€ ์•„๋ž˜์™€ ๊ฐ™์€ logit์„ ์ผ๋‹ค.

์ด๋•Œ theta๊ฐ€ ํ•˜๋‚˜์˜ class์— ๋Œ€ํ•œ ํ•˜๋‚˜์˜ proxy์ธ๋ฐ, ์ €์ž๋“ค์€ ๊ฐ๊ฐ์˜ ํด๋ž˜์Šค๋ฅผ ํ•˜๋‚˜์˜ ๋ฒกํ„ฐ๊ฐ€ ์•„๋‹Œ, ์—ฌ๋Ÿฌ ๋ฒกํ„ฐ๋“ค๋กœ ํ‘œํ˜„ํ•˜๊ฒŒ ํ•˜๊ธฐ ์œ„ํ•œ multi proxy๋ฅผ ์ œ์•ˆํ–ˆ๋‹ค.

  ์ด๋ ‡๊ฒŒ  local similarity classifier loss๋ฅผ ๋งŒ๋“ค์—ˆ๋‹ค.

 

 

 

3) Final loss
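Putting the pieces together, the final objective combines the LSC classification loss with the two POD distillation terms (a hedged reconstruction from the components above; lambda_c and lambda_f denote trade-off weights):

```latex
\mathcal{L}(x, y) = \mathcal{L}_{\mathrm{LSC}}(x, y)
  + \lambda_c \, \mathcal{L}_{\mathrm{POD\text{-}spatial}}(x)
  + \lambda_f \, \mathcal{L}_{\mathrm{POD\text{-}flat}}(x)
```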

 

 

 

4) Result of PODNet

 

๊ฐ„๋žตํžˆ ๋ณด์ž๋ฉด cifar100์—์„œ iCaRL, BiC, UCIR๋ณด๋‹ค PODNet์ด ๋” ๋‚˜์€ ์„ฑ๋Šฅ์„ ๋ณด์˜€๋‹ค. ์™ธ์—๋„ ImageNet100/1000์—์„œ๋„ SOTA๋ฅผ ๋‹ฌ์„ฑํ–ˆ๋‹ค.