diff --git "a/BART_\354\213\244\354\212\265_\354\234\240\353\213\244\355\230\204.ipynb" "b/BART_\354\213\244\354\212\265_\354\234\240\353\213\244\355\230\204.ipynb" new file mode 100644 index 0000000..181f6b2 --- /dev/null +++ "b/BART_\354\213\244\354\212\265_\354\234\240\353\213\244\355\230\204.ipynb" @@ -0,0 +1,957 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "24TpKoXHYjO9" + }, + "source": [ + "# Fine-tuning BART for News Summarization\n", + "\n", + "This hands-on notebook is built on the textbook's example code. Its goal is to understand the architectural characteristics described in the **BART (Denoising Sequence-to-Sequence)** paper and apply them to an actual abstractive summarization task.\n", + "\n", + "1. **Conditional Generation:** BART uses the full encoder-decoder structure to read the input context and generate new sentences. The code uses the `BartForConditionalGeneration` class.\n", + "2. **Loss function and ignoring padding (-100):** When a generation task handles variable-length sentences, padding tokens must be excluded from the loss computation. The code pads the labels with `-100`, the default `ignore_index` of PyTorch's cross-entropy loss, so those positions are skipped automatically.\n", + "3. **ROUGE metric:** Measures the textual similarity between the generated summary and the reference summary based on n-gram precision and recall.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nIAUb2mZYjPD" + }, + "source": [ + "## 0. Installing the Required Libraries\n", + "Install the Hugging Face libraries needed to load the dataset and compute ROUGE scores." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "YnHvdJudYjPE", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "73081155-e9b0-4d77-caa6-3f0370e4487b" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: datasets in /usr/local/lib/python3.12/dist-packages (4.0.0)\n", + "Requirement already satisfied: transformers in /usr/local/lib/python3.12/dist-packages (4.44.2)\n", + "Requirement already satisfied: evaluate in /usr/local/lib/python3.12/dist-packages (0.4.6)\n", + "Requirement already satisfied: rouge_score in /usr/local/lib/python3.12/dist-packages (0.1.2)\n", + "Requirement already satisfied: absl-py in /usr/local/lib/python3.12/dist-packages (1.4.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from datasets) (3.29.2)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from datasets) (2.0.2)\n", + "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (18.1.0)\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.3.8)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (from datasets) (2.2.2)\n", + "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.12/dist-packages (from datasets) (2.32.4)\n", + "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.12/dist-packages (from datasets) (4.67.3)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.12/dist-packages (from datasets) (3.7.0)\n", + "Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.70.16)\n", + "Requirement already satisfied: fsspec<=2025.3.0,>=2023.1.0 in /usr/local/lib/python3.12/dist-packages (from 
fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (2025.3.0)\n", + "Requirement already satisfied: huggingface-hub>=0.24.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.36.2)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from datasets) (26.2)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from datasets) (6.0.3)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers) (2025.11.3)\n", + "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.8.0)\n", + "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.19.1)\n", + "Requirement already satisfied: nltk in /usr/local/lib/python3.12/dist-packages (from rouge_score) (3.9.1)\n", + "Requirement already satisfied: six>=1.14.0 in /usr/local/lib/python3.12/dist-packages (from rouge_score) (1.17.0)\n", + "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (3.14.1)\n", + "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24.0->datasets) (1.5.1)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24.0->datasets) (4.15.0)\n", + "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (3.4.7)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (3.18)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (2.5.0)\n", + "Requirement already satisfied: 
certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (2026.5.20)\n", + "Requirement already satisfied: click in /usr/local/lib/python3.12/dist-packages (from nltk->rouge_score) (8.4.1)\n", + "Requirement already satisfied: joblib in /usr/local/lib/python3.12/dist-packages (from nltk->rouge_score) (1.5.3)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2025.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2026.2)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (2.6.2)\n", + "Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.4.0)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (26.1.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.8.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (6.7.1)\n", + "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (0.5.2)\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from 
aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.24.2)\n" + ] + } + ], + "source": [ + "!pip install datasets transformers evaluate rouge_score absl-py" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R8BbMrCWYjPG" + }, + "source": [ + "## 1. Loading and Splitting the News Summarization Dataset (Textbook Example 7.18)\n", + "Load the news summarization dataset published by Argilla and split it into training, validation, and test sets. To keep computation fast, only 5,000 samples are drawn and used." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "rKyLLGXYYjPH", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "5e11583d-32ee-46fc-c3f5-638a97b35b69" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Source News: DANANG, Vietnam (Reuters) - Russian President Vladimir Putin said on Saturday he had a normal dialogue with U.S. leader Donald Trump at a summit in Vietnam, and described Trump as civil, well-educated\n", + "Summarization: Putin says had useful interaction with Trump at Vi\n", + "Training Data Size: 3000\n", + "Validation Data Size: 1000\n", + "Testing Data Size: 1000\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.12/dist-packages/numpy/_core/fromnumeric.py:57: FutureWarning: 'DataFrame.swapaxes' is deprecated and will be removed in a future version. 
Please use 'DataFrame.transpose' instead.\n", + " return bound(*args, **kwds)\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "from datasets import load_dataset\n", + "\n", + "# Load the dataset\n", + "news = load_dataset(\"argilla/news-summary\", split=\"test\")\n", + "df = news.to_pandas().sample(5000, random_state=42)[[\"text\", \"prediction\"]]\n", + "\n", + "# Preprocessing: extract only the reference text from the nested dictionary structure\n", + "df[\"prediction\"] = df[\"prediction\"].map(lambda x: x[0][\"text\"])\n", + "\n", + "# Split the data 6:2:2 (train: 3000, validation: 1000, test: 1000)\n", + "train, valid, test = np.split(\n", + " df.sample(frac=1, random_state=42),\n", + " [int(0.6 * len(df)), int(0.8 * len(df))]\n", + ")\n", + "\n", + "print(f\"Source News: {train.text.iloc[0][:200]}\")\n", + "print(f\"Summarization: {train.prediction.iloc[0][:50]}\")\n", + "print(f\"Training Data Size: {len(train)}\")\n", + "print(f\"Validation Data Size: {len(valid)}\")\n", + "print(f\"Testing Data Size: {len(test)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vTDI1bHwYjPI" + }, + "source": [ + "## 2. Building the BART Input Tensors and Dataloaders (Textbook Example 7.19)\n", + "Use the BART tokenizer to tokenize the input articles and the reference summaries, then apply padding.\n", + "**Key concept:** The label padding value is set to `-100`, the default `ignore_index` of PyTorch's cross-entropy loss, so that padded positions are ignored when the loss is computed." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "5b4BxlSgYjPK", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 161, + "referenced_widgets": [ + "ae39e71510294570a5e05f7b96c7731c", + "6849bd5080bc4d7dadf42d09e4dca6a1", + "6fe4dd581d0342298577887cab37ccf8", + "85363b3181f249d28124ea5abc4903eb", + "ff3c56ca0bb44ebdb9fcf54fee98ed17", + "db3b3b954a744518983eff72f4eef95b", + "dce445be829e4720beef6117f4a1ce24", + "18ffd9f7b26b4c748e796092f969f790", + "4741948e5ff2477e8c081c108c564434", + "48fe72a0503a4a9b91bdfef805bcd9b0", + "f4e9354c1b254a0fb6bd7bd641c8963f" + ] + }, + "outputId": "dda039c3-aba9-452b-cd71-3e11f6fa47ee" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "0it [00:00, ?it/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "ae39e71510294570a5e05f7b96c7731c" + } + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Using device: cuda\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.12/dist-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. 
For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", + " warnings.warn(\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "데이터로더 구축 완료!\n" + ] + } + ], + "source": [ + "import torch\n", + "from transformers import BartTokenizer\n", + "from torch.utils.data import TensorDataset, DataLoader\n", + "from torch.utils.data import RandomSampler, SequentialSampler\n", + "from torch.nn.utils.rnn import pad_sequence\n", + "\n", + "def make_dataset(data, tokenizer, device):\n", + " # Tokenize the article text\n", + " tokenized = tokenizer(\n", + " data.text.tolist(),\n", + " padding=\"longest\",\n", + " truncation=True,\n", + " max_length=1024,\n", + " return_tensors=\"pt\"\n", + " )\n", + "\n", + " labels = []\n", + " input_ids = tokenized[\"input_ids\"].to(device)\n", + " attention_mask = tokenized[\"attention_mask\"].to(device)\n", + "\n", + " # Tokenize the reference summaries (labels)\n", + " for target in data.prediction:\n", + " labels.append(tokenizer.encode(target, return_tensors=\"pt\").squeeze())\n", + "\n", + " # Pad with padding_value=-100 so padded positions are ignored by the loss\n", + " padded_labels = pad_sequence(labels, batch_first=True, padding_value=-100).to(device)\n", + " return TensorDataset(input_ids, attention_mask, padded_labels)\n", + "\n", + "def get_dataloader(dataset, sampler, batch_size):\n", + " data_sampler = sampler(dataset)\n", + " dataloader = DataLoader(dataset, sampler=data_sampler, batch_size=batch_size)\n", + " return dataloader\n", + "\n", + "# Hyperparameters and device selection\n", + "epochs = 3\n", + "batch_size = 8\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "print(f\"Using device: {device}\")\n", + "\n", + "# Initialize the tokenizer and build the dataloaders\n", + "tokenizer = BartTokenizer.from_pretrained(pretrained_model_name_or_path=\"facebook/bart-base\")\n", + "\n", + "train_dataset = make_dataset(train, tokenizer, device)\n", + "train_dataloader = get_dataloader(train_dataset, RandomSampler, batch_size)\n", + "\n", + "valid_dataset = 
make_dataset(valid, tokenizer, device)\n", + "valid_dataloader = get_dataloader(valid_dataset, SequentialSampler, batch_size)\n", + "\n", + "test_dataset = make_dataset(test, tokenizer, device)\n", + "test_dataloader = get_dataloader(test_dataset, SequentialSampler, batch_size)\n", + "\n", + "print(\"데이터로더 구축 완료!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "29ka76GWYjPM" + }, + "source": [ + "## 3. Declaring the BART Conditional Generation Model and Optimizer (Textbook Example 7.20)\n", + "Instantiate the `facebook/bart-base` model, which has 6 encoder layers and 6 decoder layers, using the `BartForConditionalGeneration` class designed for conditional generation. Define the `AdamW` optimizer to update the weights." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "x-bpZIViYjPN", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "7af7b4d4-7118-494c-9965-248e32ac4f07" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "BART 모델 및 옵티마이저 선언 완료!\n" + ] + } + ], + "source": [ + "from torch import optim\n", + "from transformers import BartForConditionalGeneration\n", + "\n", + "# Load the model and move it to the selected device\n", + "model = BartForConditionalGeneration.from_pretrained(\n", + " pretrained_model_name_or_path=\"facebook/bart-base\"\n", + ").to(device)\n", + "\n", + "# Configure the optimizer (learning rate = 5e-5)\n", + "optimizer = optim.AdamW(model.parameters(), lr=5e-5, eps=1e-8)\n", + "\n", + "print(\"BART 모델 및 옵티마이저 선언 완료!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "o8BV_yIIYjPO" + }, + "source": [ + "## 4. Building the Training and Evaluation Loops (Based on Textbook Example 7.21)\n", + "This is the full training pipeline: it includes the ROUGE-2 metric function (`calc_rouge`) formulated in the textbook and the validation loop (`evaluation`), and fine-tunes the model for 3 epochs.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "S5zGjhyPYjPO", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "526c3f6f-e8f3-4004-f259-36652e3ec04e" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "🚀 BART 미세 조정 학습을 시작합니다. 
(약 20~30분 소요)\n", + "Epoch [1/3] | Step [100/375] | Loss: 1.8201\n", + "Epoch [1/3] | Step [200/375] | Loss: 2.1882\n", + "Epoch [1/3] | Step [300/375] | Loss: 2.0177\n", + "\n", + "[Epoch 1 결과] Train Loss: 2.1592 | Val Loss: 1.8341 | Val Rouge-2: 0.2606\n", + "⭐ 최고 성능 갱신! 모델 가중치 저장 완료.\n", + "\n", + "Epoch [2/3] | Step [100/375] | Loss: 1.3673\n", + "Epoch [2/3] | Step [200/375] | Loss: 1.7122\n", + "Epoch [2/3] | Step [300/375] | Loss: 1.9037\n", + "\n", + "[Epoch 2 결과] Train Loss: 1.6616 | Val Loss: 1.8735 | Val Rouge-2: 0.2594\n", + "Epoch [3/3] | Step [100/375] | Loss: 1.0899\n", + "Epoch [3/3] | Step [200/375] | Loss: 1.0555\n", + "Epoch [3/3] | Step [300/375] | Loss: 1.2623\n", + "\n", + "[Epoch 3 결과] Train Loss: 1.2493 | Val Loss: 1.9696 | Val Rouge-2: 0.2561\n" + ] + } + ], + "source": [ + "import evaluate\n", + "\n", + "# Load the ROUGE metric from the Hugging Face evaluate library\n", + "rouge_score = evaluate.load(\"rouge\", tokenizer=tokenizer)\n", + "\n", + "def calc_rouge(preds, labels):\n", + " # Pick the highest-scoring token index at each position (argmax over the logits)\n", + " preds = preds.argmax(axis=-1)\n", + "\n", + " # Replace the -100 padding values with the tokenizer's pad_token_id so the labels can be decoded\n", + " labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n", + "\n", + " # Convert the integer token arrays into text strings\n", + " decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n", + " decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n", + "\n", + " # Compute the ROUGE scores\n", + " rouge2 = rouge_score.compute(\n", + " predictions=decoded_preds,\n", + " references=decoded_labels\n", + " )\n", + " return rouge2[\"rouge2\"]\n", + "\n", + "def evaluation(model, dataloader):\n", + " with torch.no_grad():\n", + " model.eval() # Switch to evaluation mode\n", + " val_loss, val_rouge = 0.0, 0.0\n", + "\n", + " for input_ids, attention_mask, labels in dataloader:\n", + " outputs = model(\n", + " input_ids=input_ids, attention_mask=attention_mask, labels=labels\n", + " )\n", + " logits = outputs.logits\n", + " loss = outputs.loss\n", + "\n", + " # Move the tensors to the CPU and convert them to NumPy arrays\n",
+ " logits = logits.detach().cpu().numpy()\n", + " label_ids = labels.to(\"cpu\").numpy()\n", + "\n", + " rouge = calc_rouge(logits, label_ids)\n", + " val_loss += loss.item()\n", + " val_rouge += rouge\n", + "\n", + " val_loss = val_loss / len(dataloader)\n", + " val_rouge = val_rouge / len(dataloader)\n", + " return val_loss, val_rouge\n", + "\n", + "# --- Main training loop ---\n", + "import os\n", + "os.makedirs(\"../models\", exist_ok=True)\n", + "best_rouge = 0.0\n", + "\n", + "print(\"🚀 BART 미세 조정 학습을 시작합니다. (약 20~30분 소요)\")\n", + "for epoch in range(epochs):\n", + " model.train() # Switch to training mode\n", + " total_train_loss = 0.0\n", + "\n", + " for step, (input_ids, attention_mask, labels) in enumerate(train_dataloader):\n", + " optimizer.zero_grad()\n", + "\n", + " outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)\n", + " loss = outputs.loss\n", + "\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " total_train_loss += loss.item()\n", + "\n", + " if (step + 1) % 100 == 0:\n", + " print(f\"Epoch [{epoch+1}/{epochs}] | Step [{step+1}/{len(train_dataloader)}] | Loss: {loss.item():.4f}\")\n", + "\n", + " # Evaluate on the validation set at the end of each epoch\n", + " avg_train_loss = total_train_loss / len(train_dataloader)\n", + " val_loss, val_rouge = evaluation(model, valid_dataloader)\n", + "\n", + " print(f\"\\n[Epoch {epoch+1} 결과] Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Rouge-2: {val_rouge:.4f}\")\n", + "\n", + " # Save the model whenever the validation ROUGE-2 improves\n", + " if val_rouge > best_rouge:\n", + " best_rouge = val_rouge\n", + " torch.save(model.state_dict(), \"../models/BartForConditionalGeneration.pt\")\n", + " print(\"⭐ 최고 성능 갱신! 모델 가중치 저장 완료.\\n\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4dOaOfgyYjPP" + }, + "source": [ + "## 5. Final Evaluation on the Test Dataset (Textbook Example 7.22)\n", + "Load the model weights that scored best on the validation set during training and evaluate them on the held-out test data, which was not used for training." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "rvV_IQXoYjPP", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "3c910ce4-0b8d-44d3-8177-cddd1c33b7e2" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Test Loss: 1.8017\n", + "Test ROUGE-2 Score: 0.2670\n" + ] + } + ], + "source": [ + "# Load the best saved weights\n", + "model = BartForConditionalGeneration.from_pretrained(\n", + " pretrained_model_name_or_path=\"facebook/bart-base\"\n", + ").to(device)\n", + "model.load_state_dict(torch.load(\"../models/BartForConditionalGeneration.pt\"))\n", + "\n", + "# Run the final evaluation\n", + "test_loss, test_rouge_score = evaluation(model, test_dataloader)\n", + "print(f\"Test Loss: {test_loss:.4f}\")\n", + "print(f\"Test ROUGE-2 Score: {test_rouge_score:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "g9i9ZBatYjPP" + }, + "source": [ + "## 6. Comparing Generated Summaries with the References (Textbook Example 7.23)\n", + "Use Hugging Face's `pipeline` inference function to generate summaries for actual news articles from the test data, then compare them side by side with the reference summaries to check their quality." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "Duxpci-UYjPP", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "4b670342-6534-4dd9-a4aa-3cbc9273d5d5" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[1번 뉴스 요약 비교]\n", + "정답 요약문 : Clinton leads Trump by 4 points in Washington Post: ABC News poll\n", + "모델 요약문 : Clinton leads Trump by 4 percentage points in Washington Post poll\n", + "--------------------------------------------------\n", + "[2번 뉴스 요약 비교]\n", + "정답 요약문 : Democrats question independence of Trump Supreme Court nominee\n", + "모델 요약문 : U.S. senators question Gorsuch's ability to serve on Supreme Court\n", + "--------------------------------------------------\n", + "[3번 뉴스 요약 비교]\n", + "정답 요약문 : In push for Yemen aid, U.S. warned Saudis of threats in Congress\n", + "모델 요약문 : U.S. 
warns Saudi Arabia over Yemen humanitarian situation\n", + "--------------------------------------------------\n", + "[4번 뉴스 요약 비교]\n", + "정답 요약문 : Romanian ruling party leader investigated over 'criminal group'\n", + "모델 요약문 : Romanian anti-graft prosecutors probe ruling party's leader\n", + "--------------------------------------------------\n", + "[5번 뉴스 요약 비교]\n", + "정답 요약문 : Billionaire environmental activist Tom Steyer endorses Clinton\n", + "모델 요약문 : Steyer backs Hillary Clinton for U.S. president\n", + "--------------------------------------------------\n" + ] + } + ], + "source": [ + "from transformers import pipeline\n", + "\n", + "# Define the summarization pipeline (device=0 selects the first GPU, -1 selects the CPU)\n", + "summarizer = pipeline(\n", + " task=\"summarization\",\n", + " model=model,\n", + " tokenizer=tokenizer,\n", + " max_length=54,\n", + " device=0 if torch.cuda.is_available() else -1\n", + ")\n", + "\n", + "# Take the first 5 test examples and compare the summaries\n", + "for index in range(5):\n", + " news_text = test.text.iloc[index]\n", + " summarization = test.prediction.iloc[index]\n", + " predicted_summarization = summarizer(news_text)[0][\"summary_text\"]\n", + "\n", + " print(f\"[{index+1}번 뉴스 요약 비교]\")\n", + " print(f\"정답 요약문 : {summarization}\")\n", + " print(f\"모델 요약문 : {predicted_summarization}\")\n", + " print(\"-\" * 50)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "accelerator": "GPU", + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "ae39e71510294570a5e05f7b96c7731c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": 
"1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_6849bd5080bc4d7dadf42d09e4dca6a1", + "IPY_MODEL_6fe4dd581d0342298577887cab37ccf8", + "IPY_MODEL_85363b3181f249d28124ea5abc4903eb" + ], + "layout": "IPY_MODEL_ff3c56ca0bb44ebdb9fcf54fee98ed17" + } + }, + "6849bd5080bc4d7dadf42d09e4dca6a1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_db3b3b954a744518983eff72f4eef95b", + "placeholder": "​", + "style": "IPY_MODEL_dce445be829e4720beef6117f4a1ce24", + "value": "" + } + }, + "6fe4dd581d0342298577887cab37ccf8": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_18ffd9f7b26b4c748e796092f969f790", + "max": 1, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_4741948e5ff2477e8c081c108c564434", + "value": 0 + } + }, + "85363b3181f249d28124ea5abc4903eb": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": 
"@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_48fe72a0503a4a9b91bdfef805bcd9b0", + "placeholder": "​", + "style": "IPY_MODEL_f4e9354c1b254a0fb6bd7bd641c8963f", + "value": " 0/0 [00:00<?, ?it/s]" + } + }, + "ff3c56ca0bb44ebdb9fcf54fee98ed17": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "db3b3b954a744518983eff72f4eef95b": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, 
+ "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "dce445be829e4720beef6117f4a1ce24": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "18ffd9f7b26b4c748e796092f969f790": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + 
"grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "20px" + } + }, + "4741948e5ff2477e8c081c108c564434": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "48fe72a0503a4a9b91bdfef805bcd9b0": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + 
"min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f4e9354c1b254a0fb6bd7bd641c8963f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/README.md b/README.md index 2d9a624..3e7a827 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ # 10th-Research -# 🐲EURON 10기 Research 세션🐲 +# 🐥EURON 10기 Research 세션🐥 ## 🗂️ Curriculum |주차|날짜|내용|발제자|필수 토의 참여자 |---|---|---|---|---| |0주차|2026/03/03|OT|| -|1주차|2026/03/10|1주차 논문스터디-ResNet||| -|2주차|2026/03/17|2주차 논문스터디-Transformer||| +|1주차|2026/03/10|1주차 논문스터디-ResNet|유다현, 장서연|고은서, 김윤서| +|2주차|2026/03/17|2주차 논문스터디-Transformer|고은서, 최지희|장서연, 박예나| |3주차|2026/03/24|3주차 논문스터디-VAE||| |4주차|2026/03/31|4주차 논문스터디-DQN||| |5주차|2026/04/07|5주차 논문스터디-BERT||| diff --git "a/Week13_\354\234\240\353\213\244\355\230\204_\354\230\210\354\212\265\352\263\274\354\240\234.pdf" "b/Week13_\354\234\240\353\213\244\355\230\204_\354\230\210\354\212\265\352\263\274\354\240\234.pdf" new file mode 100644 index 0000000..ee78049 Binary files /dev/null and "b/Week13_\354\234\240\353\213\244\355\230\204_\354\230\210\354\212\265\352\263\274\354\240\234.pdf" differ