diff --git "a/Week13_\353\263\265\354\212\265\352\263\274\354\240\234_\352\271\200\354\234\244\354\204\234.ipynb" "b/Week13_\353\263\265\354\212\265\352\263\274\354\240\234_\352\271\200\354\234\244\354\204\234.ipynb" new file mode 100644 index 0000000..7ba8ba6 --- /dev/null +++ "b/Week13_\353\263\265\354\212\265\352\263\274\354\240\234_\352\271\200\354\234\244\354\204\234.ipynb" @@ -0,0 +1,1379 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "24TpKoXHYjO9" + }, + "source": [ + "# BART 모델을 활용한 뉴스 요약 Fine-tuning\n", + "\n", + "본 실습 노트북은 교재의 실습 코드를 바탕으로 구현되었으며, **BART (Denoising Sequence-to-Sequence)** 논문의 아키텍처적 특성을 이해하고 실제 생성 요약 Tasks에 적용하는 것을 목표로 합니다.\n", + "\n", + "1. **조건부 생성 (Conditional Generation):** BART는 인코더-디코더 전체 구조를 사용하여 입력 컨텍스트를 파악하고 새로운 문장을 생성합니다. 코드에서는 `BartForConditionalGeneration` 클래스를 사용합니다.\n", + "2. **손실 함수와 패딩 무시 (-100):** 생성 태스크에서 가변 길이 문장을 처리할 때 패딩 토큰은 손실(Loss) 계산에서 제외해야 합니다. 코드에서는 패딩 값으로 `-100`을 설정하여 Cross Entropy 손실 계산 시 자동 무시되도록 처리합니다.\n", + "3. **ROUGE 평가지표:** 생성된 요약문과 정답 요약문의 텍스트 유사도를 N-gram 정밀도 및 재현율 기반으로 평가합니다.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nIAUb2mZYjPD" + }, + "source": [ + "## 0. 필수 라이브러리 설치\n", + "데이터세트 로드 및 루지(ROUGE) 평가를 위한 Hugging Face 라이브러리들을 설치합니다." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "YnHvdJudYjPE", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "230b82ad-4636-4b23-f967-217a24caaedf" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: datasets in /usr/local/lib/python3.12/dist-packages (4.0.0)\n", + "Requirement already satisfied: transformers in /usr/local/lib/python3.12/dist-packages (5.10.2)\n", + "Requirement already satisfied: evaluate in /usr/local/lib/python3.12/dist-packages (0.4.6)\n", + "Requirement already satisfied: rouge_score in /usr/local/lib/python3.12/dist-packages (0.1.2)\n", + "Requirement already satisfied: absl-py in /usr/local/lib/python3.12/dist-packages (1.4.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from datasets) (3.29.2)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from datasets) (2.0.2)\n", + "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (18.1.0)\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.3.8)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (from datasets) (2.2.2)\n", + "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.12/dist-packages (from datasets) (2.32.4)\n", + "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.12/dist-packages (from datasets) (4.67.3)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.12/dist-packages (from datasets) (3.7.0)\n", + "Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.70.16)\n", + "Requirement already satisfied: fsspec<=2025.3.0,>=2023.1.0 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (2025.3.0)\n", + "Requirement already satisfied: huggingface-hub>=0.24.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (1.18.0)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from datasets) (26.2)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from datasets) (6.0.3)\n", + "Requirement already satisfied: regex>=2025.10.22 in /usr/local/lib/python3.12/dist-packages (from transformers) (2025.11.3)\n", + "Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.22.2)\n", + "Requirement already satisfied: typer in /usr/local/lib/python3.12/dist-packages (from transformers) (0.25.1)\n", + "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.8.0)\n", + "Requirement already satisfied: nltk in /usr/local/lib/python3.12/dist-packages (from rouge_score) (3.9.1)\n", + "Requirement already satisfied: six>=1.14.0 in /usr/local/lib/python3.12/dist-packages (from rouge_score) (1.17.0)\n", + "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (3.14.1)\n", + "Requirement already satisfied: click>=8.4.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24.0->datasets) (8.4.1)\n", + "Requirement already satisfied: hf-xet<2.0.0,>=1.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24.0->datasets) (1.5.1)\n", + "Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24.0->datasets) (0.28.1)\n", + "Requirement already satisfied: typing-extensions>=4.1.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24.0->datasets) (4.15.0)\n", + "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (3.4.7)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (3.18)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (2.5.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (2026.5.20)\n", + "Requirement already satisfied: shellingham>=1.3.0 in /usr/local/lib/python3.12/dist-packages (from typer->transformers) (1.5.4)\n", + "Requirement already satisfied: rich>=13.8.0 in /usr/local/lib/python3.12/dist-packages (from typer->transformers) (13.9.4)\n", + "Requirement already satisfied: annotated-doc>=0.0.2 in /usr/local/lib/python3.12/dist-packages (from typer->transformers) (0.0.4)\n", + "Requirement already satisfied: joblib in /usr/local/lib/python3.12/dist-packages (from nltk->rouge_score) (1.5.3)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2025.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2026.2)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (2.6.2)\n", + "Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.4.0)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (26.1.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.8.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (6.7.1)\n", + "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (0.5.2)\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.24.2)\n", + "Requirement already satisfied: anyio in /usr/local/lib/python3.12/dist-packages (from httpx<1,>=0.23.0->huggingface-hub>=0.24.0->datasets) (4.13.0)\n", + "Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx<1,>=0.23.0->huggingface-hub>=0.24.0->datasets) (1.0.9)\n", + "Requirement already satisfied: h11>=0.16 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx<1,>=0.23.0->huggingface-hub>=0.24.0->datasets) (0.16.0)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.12/dist-packages (from rich>=13.8.0->typer->transformers) (4.2.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.12/dist-packages (from rich>=13.8.0->typer->transformers) (2.20.0)\n", + "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.12/dist-packages (from markdown-it-py>=2.2.0->rich>=13.8.0->typer->transformers) (0.1.2)\n" + ] + } + ], + "source": [ + "!pip install datasets transformers evaluate rouge_score absl-py" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R8BbMrCWYjPG" + }, + "source": [ + "## 1. 뉴스 요약 데이터세트 불러오기 및 분할 (교재 예제 7.18)\n", + "미국의 AI 기업 아르길라(Argilla)가 공개한 뉴스 요약 데이터세트를 불러와 학습, 검증, 테스트 데이터로 분리합니다. 연산 속도를 확보하기 위해 5,000개의 샘플만 추출하여 사용합니다." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "rKyLLGXYYjPH", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "b9216dd2-20d5-4ecd-c188-5986b76549d3" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:112: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n", + "Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.\n", + "WARNING:huggingface_hub.utils._http:Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Source News: DANANG, Vietnam (Reuters) - Russian President Vladimir Putin said on Saturday he had a normal dialogue with U.S. leader Donald Trump at a summit in Vietnam, and described Trump as civil, well-educated\n", + "Summarization: Putin says had useful interaction with Trump at Vi\n", + "Training Data Size: 3000\n", + "Validation Data Size: 1000\n", + "Testing Data Size: 1000\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.12/dist-packages/numpy/_core/fromnumeric.py:57: FutureWarning: 'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.\n", + " return bound(*args, **kwds)\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "from datasets import load_dataset\n", + "\n", + "# 데이터세트 불러오기\n", + "news = load_dataset(\"argilla/news-summary\", split=\"test\")\n", + "df = news.to_pandas().sample(5000, random_state=42)[[\"text\", \"prediction\"]]\n", + "\n", + "# 전처리: 중첩된 딕셔너리 구조에서 정답 텍스트만 추출\n", + "df[\"prediction\"] = df[\"prediction\"].map(lambda x: x[0][\"text\"])\n", + "\n", + "# 6:2:2 비율로 데이터 분할 (학습: 3000, 검증: 1000, 테스트: 1000)\n", + "train, valid, test = np.split(\n", + " df.sample(frac=1, random_state=42),\n", + " [int(0.6 * len(df)), int(0.8 * len(df))]\n", + ")\n", + "\n", + "print(f\"Source News: {train.text.iloc[0][:200]}\")\n", + "print(f\"Summarization: {train.prediction.iloc[0][:50]}\")\n", + "print(f\"Training Data Size: {len(train)}\")\n", + "print(f\"Validation Data Size: {len(valid)}\")\n", + "print(f\"Testing Data Size: {len(test)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vTDI1bHwYjPI" + }, + "source": [ + "## 2. BART 입력 텐서 및 데이터로더 생성 (교재 예제 7.19)\n", + "BART 토크나이저를 활용하여 입력 문장과 정답 요약문을 토큰화하고 패딩을 적용합니다.\n", + "**BART 논문 및 교재 핵심 개념:** 교차 엔트로피 손실 함수에서 패딩된 토큰을 무시하도록 레이블 패딩 값을 `-100`으로 설정합니다." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "5b4BxlSgYjPK", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "d0b4bae6-3cc9-4bed-fb73-dc5486c36311" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Using device: cuda\n", + "데이터로더 구축 완료!\n" + ] + } + ], + "source": [ + "import torch\n", + "from transformers import BartTokenizer\n", + "from torch.utils.data import TensorDataset, DataLoader\n", + "from torch.utils.data import RandomSampler, SequentialSampler\n", + "from torch.nn.utils.rnn import pad_sequence\n", + "\n", + "def make_dataset(data, tokenizer, device):\n", + " # 본문 텍스트 토큰화\n", + " tokenized = tokenizer(\n", + " data.text.tolist(),\n", + " padding=\"longest\",\n", + " truncation=True,\n", + " max_length=1024, # 텍스트를 그대로 토큰화하면, CUDA에서 수용 가능한 범위를 넘어가 학습 시 AcceleratorError가 발생하므로 수용 가능한 크기로 최대 길이를 제한\n", + " return_tensors=\"pt\"\n", + " )\n", + "\n", + " labels = []\n", + " input_ids = tokenized[\"input_ids\"].to(device)\n", + " attention_mask = tokenized[\"attention_mask\"].to(device)\n", + "\n", + " # 요약문(정답) 토큰화\n", + " for target in data.prediction:\n", + " labels.append(tokenizer.encode(target, truncation=True, max_length=128, return_tensors=\"pt\").squeeze())\n", + "\n", + " # 손실 함수 계산 시 패딩 무시를 위해 padding_value=-100 사용\n", + " padded_labels = pad_sequence(labels, batch_first=True, padding_value=-100).to(device)\n", + " return TensorDataset(input_ids, attention_mask, padded_labels)\n", + "\n", + "def get_dataloader(dataset, sampler, batch_size):\n", + " data_sampler = sampler(dataset)\n", + " dataloader = DataLoader(dataset, sampler=data_sampler, batch_size=batch_size)\n", + " return dataloader\n", + "\n", + "# 하이퍼파라미터 및 디바이스 설정\n", + "epochs = 3\n", + "batch_size = 8\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "print(f\"Using device: {device}\")\n", + "\n", + "# 토크나이저 초기화 및 데이터로더 생성\n", + "tokenizer = BartTokenizer.from_pretrained(pretrained_model_name_or_path=\"facebook/bart-base\")\n", + "\n", + "train_dataset = make_dataset(train, tokenizer, device)\n", + "train_dataloader = get_dataloader(train_dataset, RandomSampler, batch_size)\n", + "\n", + "valid_dataset = make_dataset(valid, tokenizer, device)\n", + "valid_dataloader = get_dataloader(valid_dataset, SequentialSampler, batch_size)\n", + "\n", + "test_dataset = make_dataset(test, tokenizer, device)\n", + "test_dataloader = get_dataloader(test_dataset, SequentialSampler, batch_size)\n", + "\n", + "print(\"데이터로더 구축 완료!\")" + ] + }, + { + "cell_type": "code", + "source": [ + "# 본문 및 요약문의 토큰화 시 개수 확인\n", + "# 본문 길이 분포 확인\n", + "text_lengths = [len(tokenizer.encode(t)) for t in train.text.tolist()]\n", + "print(f\"본문 - 평균: {np.mean(text_lengths):.0f}, 중앙값: {np.median(text_lengths):.0f}, 95%: {np.percentile(text_lengths, 95):.0f}, 최대: {max(text_lengths)}\")\n", + "\n", + "# 요약문 길이 분포 확인\n", + "summary_lengths = [len(tokenizer.encode(t)) for t in train.prediction.tolist()]\n", + "print(f\"요약 - 평균: {np.mean(summary_lengths):.0f}, 중앙값: {np.median(summary_lengths):.0f}, 95%: {np.percentile(summary_lengths, 95):.0f}, 최대: {max(summary_lengths)}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "tSeTrABAHKPs", + "outputId": "3244ee4b-a513-404b-8f3e-299b1813f7ca" + }, + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "본문 - 평균: 498, 중앙값: 460, 95%: 1142, 최대: 3913\n", + "요약 - 평균: 15, 중앙값: 15, 95%: 20, 최대: 30\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "29ka76GWYjPM" + }, + "source": [ + "## 3. BART 조건부 생성 모델 및 최적화 함수 선언 (교재 예제 7.20)\n", + "조건부 생성 작업에 특화된 `BartForConditionalGeneration` 클래스를 사용해 6개 계층을 갖는 `facebook/bart-base` 모델을 인스턴스화합니다. 가중치 최적화를 위해 `AdamW` 알고리즘을 정의합니다." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "x-bpZIViYjPN", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68, + "referenced_widgets": [ + "241b49c563a5466d96699a2adc01854a", + "88c50fb1fffd40ea80e1cbc7451c6c07", + "a894cd7b0ba446b8b96d3efcaacb34ca", + "3f223e3dd0684d39baf222cb603f310c", + "6cfe35e7ab1f44c79741a9a5d0e7bccb", + "90d7606b346d4072a4ae44bb3ce20102", + "40e8ca7b0da54981ac936817e7a38c0b", + "a02bd94453ab499cbd000cf5cf110c7a", + "dcc7b59c097f43869ee0a983a8bfb02d", + "6588cafba8174c4eb5680c5836f9bdbd", + "f85f30f8848642e49bda0da8d1f6e610" + ] + }, + "outputId": "a38b9347-8660-441d-f391-cd8e8bd18d10" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Loading weights: 0%| | 0/259 [00:00 best_rouge:\n", + " best_rouge = val_rouge\n", + " torch.save(model.state_dict(), filepath)\n", + " print(\"⭐ 최고 성능 갱신! 모델 가중치 저장 완료.\\n\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4dOaOfgyYjPP" + }, + "source": [ + "## 5. 최종 테스트 데이터세트 평가 (교재 예제 7.22)\n", + "학습 과정에서 검증 세트 기준 성능이 가장 우수했던 모델 가중치를 로드하여, 학습에 쓰이지 않은 최종 테스트 데이터를 대상으로 평가를 수행합니다." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "rvV_IQXoYjPP", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 86, + "referenced_widgets": [ + "40cd1151b9cf45b989b76761fc7608af", + "9c13c216261543e99ef5eab42bb2c8d1", + "3035f38b29b7480c90fa395237b8de4c", + "4b96ca9a6c0a484aa99b3858cfd51c6f", + "6ff067161b3c45f48f130215f68c48e0", + "8e1c3a9b79e94ace89392a27392d1dce", + "43bc9160c08f4bc082712c08db4f4141", + "08c22a9c6cca4bcfa410f87a0edff02e", + "c70ffe2c706040288dca43710f8e07d9", + "a41499ed1a1a4cd88951136d09cd783e", + "2657f0d9160b4492ac3cdd32af5c1057" + ] + }, + "outputId": "007448c7-b883-4ba4-a929-5a519093ecb1" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Loading weights: 0%| | 0/259 [00:00