diff --git "a/Week13_\353\263\265\354\212\265\352\263\274\354\240\234_\354\236\245\354\204\234\354\227\260.ipynb" "b/Week13_\353\263\265\354\212\265\352\263\274\354\240\234_\354\236\245\354\204\234\354\227\260.ipynb" new file mode 100644 index 0000000..c0d1c45 --- /dev/null +++ "b/Week13_\353\263\265\354\212\265\352\263\274\354\240\234_\354\236\245\354\204\234\354\227\260.ipynb" @@ -0,0 +1,5404 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "24TpKoXHYjO9" + }, + "source": [ + "# BART 모델을 활용한 뉴스 요약 Fine-tuning\n", + "\n", + "본 실습 노트북은 교재의 실습 코드를 바탕으로 구현되었으며, **BART (Denoising Sequence-to-Sequence)** 논문의 아키텍처적 특성을 이해하고 실제 생성 요약 Tasks에 적용하는 것을 목표로 합니다.\n", + "\n", + "1. **조건부 생성 (Conditional Generation):** BART는 인코더-디코더 전체 구조를 사용하여 입력 컨텍스트를 파악하고 새로운 문장을 생성합니다. 코드에서는 `BartForConditionalGeneration` 클래스를 사용합니다.\n", + "2. **손실 함수와 패딩 무시 (-100):** 생성 태스크에서 가변 길이 문장을 처리할 때 패딩 토큰은 손실(Loss) 계산에서 제외해야 합니다. 코드에서는 패딩 값으로 `-100`을 설정하여 Cross Entropy 손실 계산 시 자동 무시되도록 처리합니다.\n", + "3. **ROUGE 평가지표:** 생성된 요약문과 정답 요약문의 텍스트 유사도를 N-gram 정밀도 및 재현율 기반으로 평가합니다.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nIAUb2mZYjPD" + }, + "source": [ + "## 0. 필수 라이브러리 설치\n", + "데이터세트 로드 및 루지(ROUGE) 평가를 위한 Hugging Face 라이브러리들을 설치합니다." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "YnHvdJudYjPE", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "5cf7a70b-e2c4-4970-9bd4-6f13d4b703e6" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: datasets in /usr/local/lib/python3.12/dist-packages (4.0.0)\n", + "Requirement already satisfied: transformers in /usr/local/lib/python3.12/dist-packages (5.10.2)\n", + "Collecting evaluate\n", + " Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)\n", + "Collecting rouge_score\n", + " Downloading rouge_score-0.1.2.tar.gz (17 kB)\n", + " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: absl-py in /usr/local/lib/python3.12/dist-packages (1.4.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from datasets) (3.29.2)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from datasets) (2.0.2)\n", + "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (18.1.0)\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.3.8)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (from datasets) (2.2.2)\n", + "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.12/dist-packages (from datasets) (2.32.4)\n", + "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.12/dist-packages (from datasets) (4.67.3)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.12/dist-packages (from datasets) (3.7.0)\n", + "Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.70.16)\n", + "Requirement already satisfied: fsspec<=2025.3.0,>=2023.1.0 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (2025.3.0)\n", + "Requirement already satisfied: huggingface-hub>=0.24.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (1.18.0)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from datasets) (26.2)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from datasets) (6.0.3)\n", + "Requirement already satisfied: regex>=2025.10.22 in /usr/local/lib/python3.12/dist-packages (from transformers) (2025.11.3)\n", + "Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.22.2)\n", + "Requirement already satisfied: typer in /usr/local/lib/python3.12/dist-packages (from transformers) (0.25.1)\n", + "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.8.0)\n", + "Requirement already satisfied: nltk in /usr/local/lib/python3.12/dist-packages (from rouge_score) (3.9.1)\n", + "Requirement already satisfied: six>=1.14.0 in /usr/local/lib/python3.12/dist-packages (from rouge_score) (1.17.0)\n", + "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (3.14.1)\n", + "Requirement already satisfied: click>=8.4.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24.0->datasets) (8.4.1)\n", + "Requirement already satisfied: hf-xet<2.0.0,>=1.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24.0->datasets) (1.5.1)\n", + "Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24.0->datasets) (0.28.1)\n", + "Requirement already satisfied: typing-extensions>=4.1.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24.0->datasets) (4.15.0)\n", + "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (3.4.7)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (3.18)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (2.5.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (2026.5.20)\n", + "Requirement already satisfied: shellingham>=1.3.0 in /usr/local/lib/python3.12/dist-packages (from typer->transformers) (1.5.4)\n", + "Requirement already satisfied: rich>=13.8.0 in /usr/local/lib/python3.12/dist-packages (from typer->transformers) (13.9.4)\n", + "Requirement already satisfied: annotated-doc>=0.0.2 in /usr/local/lib/python3.12/dist-packages (from typer->transformers) (0.0.4)\n", + "Requirement already satisfied: joblib in /usr/local/lib/python3.12/dist-packages (from nltk->rouge_score) (1.5.3)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2025.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2026.2)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (2.6.2)\n", + "Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.4.0)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (26.1.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.8.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (6.7.1)\n", + "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (0.5.2)\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.24.2)\n", + "Requirement already satisfied: anyio in /usr/local/lib/python3.12/dist-packages (from httpx<1,>=0.23.0->huggingface-hub>=0.24.0->datasets) (4.13.0)\n", + "Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx<1,>=0.23.0->huggingface-hub>=0.24.0->datasets) (1.0.9)\n", + "Requirement already satisfied: h11>=0.16 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx<1,>=0.23.0->huggingface-hub>=0.24.0->datasets) (0.16.0)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.12/dist-packages (from rich>=13.8.0->typer->transformers) (4.2.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.12/dist-packages (from rich>=13.8.0->typer->transformers) (2.20.0)\n", + "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.12/dist-packages (from markdown-it-py>=2.2.0->rich>=13.8.0->typer->transformers) (0.1.2)\n", + "Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.1/84.1 kB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hBuilding wheels for collected packages: rouge_score\n", + " Building wheel for rouge_score (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=a3c78a9e7ef9798bf5269c45ac9778d587d0110b0fcdf4e6e675ffc80adfdddc\n", + " Stored in directory: /root/.cache/pip/wheels/85/9d/af/01feefbe7d55ef5468796f0c68225b6788e85d9d0a281e7a70\n", + "Successfully built rouge_score\n", + "Installing collected packages: rouge_score, evaluate\n", + "Successfully installed evaluate-0.4.6 rouge_score-0.1.2\n" + ] + } + ], + "source": [ + "!pip install datasets transformers evaluate rouge_score absl-py" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R8BbMrCWYjPG" + }, + "source": [ + "## 1. 뉴스 요약 데이터세트 불러오기 및 분할 (교재 예제 7.18)\n", + "미국의 AI 기업 아르길라(Argilla)가 공개한 뉴스 요약 데이터세트를 불러와 학습, 검증, 테스트 데이터로 분리합니다. 연산 속도를 확보하기 위해 5,000개의 샘플만 추출하여 사용합니다." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "rKyLLGXYYjPH", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 533, + "referenced_widgets": [ + "9b428c251fb4403caf0a7faa608ba63d", + "dc1adeab56d749b79958ef05e17b2dba", + "211f8695a7cb424cb1d2a2b942168b98", + "2c8e5849b1ea461794eaf330da6991eb", + "9bc4603fe5c74a409490f5848598d2c3", + "3bf4a3fb458748408b4b594e52dfa710", + "a75f9023adac4a008ec0f559ee93673e", + "d8921ee59e134e6397c907761e4284c1", + "207c2c053113440a9c809c423e0ed5f1", + "83c55f3d8ea3416295446d289bce15c4", + "a6b4236f16c141c592ed06e032c30634", + "b86fedff9b84451d9e6313725e71c64d", + "7f5f45b1d44943c2b0ff2d6c57efa9ed", + "ef6f15a423d746488fc0a7a237fc1cbd", + "24915bfd227f4fdfad7151b99aee4b38", + "d78a89400e284f229d16243956026670", + "85e308cd5db24b37bee1854886746224", + "c8bb19dc45544fb2820db49f54f64f57", + "67adfec146ab42f9b1a43eb501f33a9c", + "3d9886fc5be945f6884ea8b66a0a0cae", + "ab238fe1be154242aa9fb874e0df86f7", + "75e268c2d02643578212dbcc287cd756", + "8636eca42b034a70b4fe5c7ba353ed7f", + "57628eba9eb14dfc890e6af48608017e", + "6bdce2d6ef0c419886e5b03331ddf7ca", + "c807e2e73824446d8e809ed6b94f054d", + "9d13332a906b456f973320db946dd913", + "459b98ef73184bfdae9eabb14329389a", + "c15be171a60f4769802b3f2214687ae1", + "c7402fdfd10f43e3a3a715e3bbdf649e", + "d53b537b8b744dfa8af6f7877f2dc1bc", + "7e2c882cd3bc46da900148fd0223ec79", + "167e973fff324594997f200a031d8b8a", + "1a3389a9dbf94a45bffb3f546ce97836", + "1e6b18dc4b7647e2a0834532eb0e6158", + "d5a9f75522cc4736966dcca62bb4a492", + "96f56e294ff7486792396540f03d4a2e", + "6aa267da443244e7814e7d4c82971ff7", + "04ce4a8e431b43ed8438f6c8f0487b00", + "439217dc5fe5424a8764e929efb01728", + "2e23bb4c7150488888ddfa94dace04f6", + "ac5a29b93e7a4ed4aa941fbc3a30fcc2", + "b1a417064c08432486cf91c806a983cb", + "565c06d5355e4eb4942e1582598dd6f9", + "3bfbcc5855d142a9b182cf7b4794b83c", + "29dc692a564a48ca92d88aee35664ec8", + "909104d9718f41ffa3b4336882b99c84", + "f0f36c6ea7db4bdf8202ef23c529eae9", + "415cfb6c13984e609d4e834e396d8856", + "4feeee35813a4a3d98160b3cb42438ba", + "1eb22325105342b980d9bb569e9f0b40", + "8ac75cc3c8d84a9aba5b5d2b413bd71a", + "57ddf0caa1ee4c10a8d428cd8eccb29e", + "2b8e3eb34d154cc29bd798a92320e07e", + "e2e1096108e9402196b809c1c8c1de9c" + ] + }, + "outputId": "e28fad27-01c6-4705-aaf7-45cdf3fc7f3b" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:112: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "README.md: 0%| | 0.00/2.02k [00:00 best_rouge:\n", + " best_rouge = val_rouge\n", + " torch.save(model.state_dict(), \"../models/BartForConditionalGeneration.pt\")\n", + " print(\"⭐ 최고 성능 갱신! 모델 가중치 저장 완료.\\n\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4dOaOfgyYjPP" + }, + "source": [ + "## 5. 최종 테스트 데이터세트 평가 (교재 예제 7.22)\n", + "학습 과정에서 검증 세트 기준 성능이 가장 우수했던 모델 가중치를 로드하여, 학습에 쓰이지 않은 최종 테스트 데이터를 대상으로 평가를 수행합니다." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "rvV_IQXoYjPP", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 86, + "referenced_widgets": [ + "0c6293ca44294cd584db2ed62223b132", + "c00d7847b59446c8bcd7781d27ba38ef", + "c7e943f25b7845dbbe589260fd41b7b3", + "f075643910f94aeb9446cda06624cc05", + "93f0fe1390724d3b83fdea6fa01a4585", + "1ea0624acdf54913910e04f17f7ec565", + "bea733735361491e8136e2af29b0c182", + "a77de872637e4b5b87b33a5e86aaf806", + "e2522e2fdb3446d1b3b6523fde888b50", + "bdada99845c04be7a98cf25a81615f3e", + "79ddae16b152428a8031155cb9b3d21d" + ] + }, + "outputId": "e0268ea0-7279-4990-9665-24e3166feb73" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Loading weights: 0%| | 0/259 [00:00