diff --git "a/Week13_\354\230\210\354\212\265\352\263\274\354\240\234_\352\271\200\354\230\210\353\202\230.ipynb" "b/Week13_\354\230\210\354\212\265\352\263\274\354\240\234_\352\271\200\354\230\210\353\202\230.ipynb" new file mode 100644 index 0000000..dbafea9 --- /dev/null +++ "b/Week13_\354\230\210\354\212\265\352\263\274\354\240\234_\352\271\200\354\230\210\353\202\230.ipynb" @@ -0,0 +1,14060 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "BERT" + ], + "metadata": { + "id": "zkzsaAgjPO5W" + }, + "id": "zkzsaAgjPO5W" + }, + { + "cell_type": "code", + "source": [ + "!pip install Korpora" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 440 + }, + "id": "jigj2nVsPJnI", + "outputId": "d9729f8c-89f7-41da-c229-48f772d04bfa" + }, + "id": "jigj2nVsPJnI", + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting Korpora\n", + " Downloading Korpora-0.2.0-py3-none-any.whl.metadata (26 kB)\n", + "Collecting dataclasses>=0.6 (from Korpora)\n", + " Downloading dataclasses-0.6-py3-none-any.whl.metadata (3.0 kB)\n", + "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.12/dist-packages (from Korpora) (2.0.2)\n", + "Requirement already satisfied: tqdm>=4.46.0 in /usr/local/lib/python3.12/dist-packages (from Korpora) (4.67.3)\n", + "Requirement already satisfied: requests>=2.20.0 in /usr/local/lib/python3.12/dist-packages (from Korpora) (2.32.4)\n", + "Requirement already satisfied: xlrd>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from Korpora) (2.0.2)\n", + "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.20.0->Korpora) (3.4.7)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.20.0->Korpora) (3.15)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.20.0->Korpora) (2.5.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.20.0->Korpora) (2026.5.20)\n", + "Downloading Korpora-0.2.0-py3-none-any.whl (57 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.8/57.8 kB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading dataclasses-0.6-py3-none-any.whl (14 kB)\n", + "Installing collected packages: dataclasses, Korpora\n", + "Successfully installed Korpora-0.2.0 dataclasses-0.6\n" + ] + }, + { + "output_type": "display_data", + "data": { + "application/vnd.colab-display-data+json": { + "pip_warning": { + "packages": [ + "dataclasses" + ] + }, + "id": "67768d920ffb4454bdf0e98c15426240" + } + }, + "metadata": {} + } + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0a583dd-df09-42a4-bed9-cbe3270fd368", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "f0a583dd-df09-42a4-bed9-cbe3270fd368", + "outputId": "f162b6d5-068e-489e-bf3c-c843149706de" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + " Korpora 는 다른 분들이 연구 목적으로 공유해주신 말뭉치들을\n", + " 손쉽게 다운로드, 사용할 수 있는 기능만을 제공합니다.\n", + "\n", + " 말뭉치들을 공유해 주신 분들에게 감사드리며, 각 말뭉치 별 설명과 라이센스를 공유 드립니다.\n", + " 해당 말뭉치에 대해 자세히 알고 싶으신 분은 아래의 description 을 참고,\n", + " 해당 말뭉치를 연구/상용의 목적으로 이용하실 때에는 아래의 라이센스를 참고해 주시기 바랍니다.\n", + "\n", + " # Description\n", + " Author : e9t@github\n", + " Repository : https://github.com/e9t/nsmc\n", + " References : www.lucypark.kr/docs/2015-pyconkr/#39\n", + "\n", + " Naver sentiment movie corpus v1.0\n", + " This is a movie review dataset in the Korean language.\n", + " Reviews were scraped from Naver Movies.\n", + "\n", + " The dataset construction is based on the method noted in\n", + " [Large movie review dataset][^1] from Maas et al., 2011.\n", + "\n", + " [^1]: http://ai.stanford.edu/~amaas/data/sentiment/\n", + "\n", + " # License\n", + " CC0 1.0 Universal (CC0 1.0) Public Domain Dedication\n", + " Details in https://creativecommons.org/publicdomain/zero/1.0/\n", + "\n", + "[Korpora] Corpus `nsmc` is already installed at /root/Korpora/nsmc/ratings_train.txt\n", + "[Korpora] Corpus `nsmc` is already installed at /root/Korpora/nsmc/ratings_test.txt\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from Korpora import Korpora\n", + "\n", + "\n", + "corpus = Korpora.load(\"nsmc\")\n", + "df = pd.DataFrame(corpus.test).sample(20000, random_state=42)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aab87a20", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "aab87a20", + "outputId": "d2b40d81-7e1f-4f7c-ec80-440a0baa9d73" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "| | text | label |\n", + "|------:|:---------------------------------------------------------|--------:|\n", + "| 26891 | 역시 코믹액션은 성룡, 홍금보, 원표 삼인방이 최고지!! | 1 |\n", + "| 25024 | 점수 후하게 줘야것네 별 반개~ | 0 |\n", + "| 11666 | 오랜만에 느낄수 있는 [감독] 구타욕구. | 0 |\n", + "| 40303 | 본지는 좀 됬지만 극장서 돈주고 본게 아직까지 아까운 영화 | 0 |\n", + "| 18010 | 징키스칸이란 소재를 가지고 이것밖에 못만드냐 | 0 |\n", + "Training Data Size : 12000\n", + "Validation Data Size : 4000\n", + "Testing Data Size : 4000\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.12/dist-packages/numpy/_core/fromnumeric.py:57: FutureWarning: 'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.\n", + " return bound(*args, **kwds)\n" + ] + } + ], + "source": [ + "train, valid, test = np.split(\n", + " df.sample(frac=1, random_state=42), [int(0.6 * len(df)), int(0.8 * len(df))]\n", + ")\n", + "\n", + "print(train.head(5).to_markdown())\n", + "print(f\"Training Data Size : {len(train)}\")\n", + "print(f\"Validation Data Size : {len(valid)}\")\n", + "print(f\"Testing Data Size : {len(test)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61a66d8a-bb12-4b74-9ea7-529192ee562e", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "61a66d8a-bb12-4b74-9ea7-529192ee562e", + "outputId": "ba2230f4-77d1-4ed3-eae6-f31165845005" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(tensor([ 101, 58466, 9812, 118956, 119122, 59095, 10892, 9434, 118888,\n", + " 117, 9992, 40032, 30005, 117, 9612, 37824, 9410, 12030,\n", + " 42337, 10739, 83491, 12508, 106, 106, 102, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0],\n", + " device='cuda:0'), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0], device='cuda:0'), tensor(1, device='cuda:0'))\n" + ] + } + ], + "source": [ + "import torch\n", + "from transformers import BertTokenizer\n", + "from torch.utils.data import TensorDataset, DataLoader\n", + "from torch.utils.data import RandomSampler, SequentialSampler\n", + "\n", + "\n", + "def make_dataset(data, tokenizer, device):\n", + " tokenized = tokenizer(\n", + " text=data.text.tolist(),\n", + " padding=\"longest\",\n", + " truncation=True,\n", + " return_tensors=\"pt\"\n", + " )\n", + " input_ids = tokenized[\"input_ids\"].to(device)\n", + " attention_mask = tokenized[\"attention_mask\"].to(device)\n", + " labels = torch.tensor(data.label.values, dtype=torch.long).to(device)\n", + " return TensorDataset(input_ids, attention_mask, labels)\n", + "\n", + "\n", + "def get_datalodader(dataset, sampler, batch_size):\n", + " data_sampler = sampler(dataset)\n", + " dataloader = DataLoader(dataset, sampler=data_sampler, batch_size=batch_size)\n", + " return dataloader\n", + "\n", + "\n", + "epochs = 5\n", + "batch_size = 32\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "tokenizer = BertTokenizer.from_pretrained(\n", + " pretrained_model_name_or_path=\"bert-base-multilingual-cased\",\n", + " do_lower_case=False\n", + ")\n", + "\n", + "train_dataset = make_dataset(train, tokenizer, device)\n", + "train_dataloader = get_datalodader(train_dataset, RandomSampler, batch_size)\n", + "\n", + "valid_dataset = make_dataset(valid, tokenizer, device)\n", + "valid_dataloader = get_datalodader(valid_dataset, SequentialSampler, batch_size)\n", + "\n", + "test_dataset = make_dataset(test, tokenizer, device)\n", + "test_dataloader = get_datalodader(test_dataset, SequentialSampler, batch_size)\n", + "\n", + "print(train_dataset[0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b04f95af-2ed6-4ee4-a293-218280946df2", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 355, + "referenced_widgets": [ + "bb5b2007245a4ec0956acf8c8b4add2c", + "3f371eac218a4cecaf4f0e75c0f2e82e", + "2a6fd3ce99994f1b9041341c77fb41ba", + "df1192b72fb546ef9929d8d24a370fb8", + "ef7a20f97c87409f81ee765f797f65e1", + "95ca79c96f0349389a17b0977a9f5bdd", + "f9bf19ea697a4606b08a08625ce76c7d", + "ee30f4a55cfc4bc9ac4c97667c01f781", + "92ee632a4be14e6f9dc37a78fd7eaf9f", + "282d90ba69ca4703bf585425a6c20676", + "c9d7a2dd08b1443a818025fac2c59d3a" + ] + }, + "id": "b04f95af-2ed6-4ee4-a293-218280946df2", + "outputId": "9f4e33fa-6eb7-4d87-c429-f5d3f88ab7da" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Loading weights: 0%| | 0/199 [00:00=2.0.0 in /usr/local/lib/python3.12/dist-packages (from evaluate) (4.0.0)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from evaluate) (2.0.2)\n", + "Requirement already satisfied: dill in /usr/local/lib/python3.12/dist-packages (from evaluate) (0.3.8)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (from evaluate) (2.2.2)\n", + "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.12/dist-packages (from evaluate) (2.32.4)\n", + "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.12/dist-packages (from evaluate) (4.67.3)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.12/dist-packages (from evaluate) (3.7.0)\n", + "Requirement already satisfied: multiprocess in /usr/local/lib/python3.12/dist-packages (from evaluate) (0.70.16)\n", + "Requirement already satisfied: fsspec>=2021.05.0 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]>=2021.05.0->evaluate) (2025.3.0)\n", + "Requirement already satisfied: huggingface-hub>=0.7.0 in /usr/local/lib/python3.12/dist-packages (from evaluate) (1.16.1)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from evaluate) (26.2)\n", + "Requirement already satisfied: nltk in /usr/local/lib/python3.12/dist-packages (from rouge_score) (3.9.1)\n", + "Requirement already satisfied: six>=1.14.0 in /usr/local/lib/python3.12/dist-packages (from rouge_score) (1.17.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from datasets>=2.0.0->evaluate) (3.29.0)\n", + "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets>=2.0.0->evaluate) (18.1.0)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from datasets>=2.0.0->evaluate) (6.0.3)\n", + "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]>=2021.05.0->evaluate) (3.13.5)\n", + "Requirement already satisfied: hf-xet<2.0.0,>=1.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.7.0->evaluate) (1.5.0)\n", + "Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.7.0->evaluate) (0.28.1)\n", + "Requirement already satisfied: typer>=0.20.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.7.0->evaluate) (0.25.1)\n", + "Requirement already satisfied: typing-extensions>=4.1.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.7.0->evaluate) (4.15.0)\n", + "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->evaluate) (3.4.7)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->evaluate) (3.15)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->evaluate) (2.5.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->evaluate) (2026.5.20)\n", + "Requirement already satisfied: click in /usr/local/lib/python3.12/dist-packages (from nltk->rouge_score) (8.4.0)\n", + "Requirement already satisfied: joblib in /usr/local/lib/python3.12/dist-packages (from nltk->rouge_score) (1.5.3)\n", + "Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.12/dist-packages (from nltk->rouge_score) (2025.11.3)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->evaluate) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->evaluate) (2025.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->evaluate) (2026.2)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (2.6.2)\n", + "Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.4.0)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (26.1.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.8.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (6.7.1)\n", + "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (0.5.2)\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.24.2)\n", + "Requirement already satisfied: anyio in /usr/local/lib/python3.12/dist-packages (from httpx<1,>=0.23.0->huggingface-hub>=0.7.0->evaluate) (4.13.0)\n", + "Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx<1,>=0.23.0->huggingface-hub>=0.7.0->evaluate) (1.0.9)\n", + "Requirement already satisfied: h11>=0.16 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx<1,>=0.23.0->huggingface-hub>=0.7.0->evaluate) (0.16.0)\n", + "Requirement already satisfied: shellingham>=1.3.0 in /usr/local/lib/python3.12/dist-packages (from typer>=0.20.0->huggingface-hub>=0.7.0->evaluate) (1.5.4)\n", + "Requirement already satisfied: rich>=13.8.0 in /usr/local/lib/python3.12/dist-packages (from typer>=0.20.0->huggingface-hub>=0.7.0->evaluate) (13.9.4)\n", + "Requirement already satisfied: annotated-doc>=0.0.2 in /usr/local/lib/python3.12/dist-packages (from typer>=0.20.0->huggingface-hub>=0.7.0->evaluate) (0.0.4)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.12/dist-packages (from rich>=13.8.0->typer>=0.20.0->huggingface-hub>=0.7.0->evaluate) (4.2.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.12/dist-packages (from rich>=13.8.0->typer>=0.20.0->huggingface-hub>=0.7.0->evaluate) (2.20.0)\n", + "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.12/dist-packages (from markdown-it-py>=2.2.0->rich>=13.8.0->typer>=0.20.0->huggingface-hub>=0.7.0->evaluate) (0.1.2)\n", + "Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.1/84.1 kB\u001b[0m \u001b[31m3.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hBuilding wheels for collected packages: rouge_score\n", + " Building wheel for rouge_score (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=26991bf05bd6360cc34a1613b894f9d63d142ffe04e36a886a12ed1e054b4909\n", + " Stored in directory: /root/.cache/pip/wheels/85/9d/af/01feefbe7d55ef5468796f0c68225b6788e85d9d0a281e7a70\n", + "Successfully built rouge_score\n", + "Installing collected packages: rouge_score, evaluate\n", + "Successfully installed evaluate-0.4.6 rouge_score-0.1.2\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "import numpy as np\n", + "from datasets import load_dataset\n", + "\n", + "\n", + "news = load_dataset(\"argilla/news-summary\", split=\"test\")\n", + "df = news.to_pandas().sample(5000, random_state=42)[[\"text\", \"prediction\"]]\n", + "df[\"prediction\"] = df[\"prediction\"].map(lambda x: x[0][\"text\"])\n", + "train, valid, test = np.split(\n", + " df.sample(frac=1, random_state=42), [int(0.6 * len(df)), int(0.8 * len(df))]\n", + ")\n", + "\n", + "print(f\"Source News : {train.text.iloc[0][:200]}\")\n", + "print(f\"Summarization : {train.prediction.iloc[0][:50]}\")\n", + "print(f\"Training Data Size : {len(train)}\")\n", + "print(f\"Validation Data Size : {len(valid)}\")\n", + "print(f\"Testing Data Size : {len(test)}\")" + ], + "metadata": { + "id": "GeJyAfazPd4u", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 465, + "referenced_widgets": [ + "a379c25a8888457eb793210fce6749c1", + "c1509c1f622e4f5fa02ce19688d74d4f", + "57412aaabdac404e87b4d1324fe50714", + "055d827f4e27431985090722cd7d20c6", + "8e2fbe2ed05c4a90a73d5f92bbf02d0c", + "3aa99498eb5848a983f36875d77cbc74", + "98cae9d8f69b43b0bb214c45edf8a41b", + "5f2973ec80fa44eab05f7fe7aa30efd9", + "c32c27bdd40e40a5a5acdc4585fcfd43", + "25a74aa1a3984bbd93b65efdbf05a22b", + "61d7e39b2eee431786d735b76505fc51", + "7f9f719952764f89a7024e743b31a607", + "069d4b3c5ea84ce5a21f934bcd7a3b7f", + "e6b4a9756a094d8ca6d49d9941d3a16c", + "152c011f196947658a16b19af6e305ff", + "acee5e7e99914bb3a793e200f7d0991d", + "ecdff9135b7840c7a3bac3581b6b1c27", + "5b79140d1ad64a84850257e8a47edd63", + "124d2c9889e443489a3f7409a0a380e2", + "8ad008be5a1d49dabf1943f2ef765586", + "cc477b2bc9b343c78991bcb8492996de", + "9183e19e096f47639958ed0e7420165f", + "c38b48fec3fc49ee9043ce5176534867", + "e64bbbf0d37445d88b8b45748101f879", + "1bb1ad73b6b84010b487ee5c01237415", + "015d42196b4c4b208d769caffb5b251a", + "3ff82299702647b48b7ccabfea4a4de4", + "1ac963a8484c4188b2645464300d9ff4", + "70888fc3f38046e68dc925840382d39d", + "377064f8d87d41b39f2a754343ae0691", + "a67ff25639dd46159c3867c5096f76e5", + "255ba95b7f3d47618c0e175cff2a58bb", + "d923868cbcd94fb6ab15ba9d7ed7f0f4", + "f5d70eaf05944fe0bcf56dc43ac310d0", + "0c14fdba719f47ca8421430536b5cbf9", + "0156632be3d8468fbe35522fe06bf2ab", + "3086200abab44bf3aad2313a6d4a5ed2", + "f0980f31374d4ca284aaf7435adb6393", + "8f62f7c7fe144b3fb8a9300acda7bedb", + "6924da2b9ce94ed9b20f61be5d4b491a", + "5ff6e9f5dae945058a9b666023befe9b", + "d369af61630746a2998c4345978617da", + "4b5751ab3b6345d797a297374b5b76be", + "4759096f00d24fe7bbdb53d5080b94ea", + "fd0d85a2292840949b83f8408b6f4330", + "4a0d94db2bf6400791a87c4056a79547", + "9dfd49d0e03c45c79b4f386ac692be42", + "31fd9c7561d348c4a4bf5c470a9b3fbb", + "087a51e6f5be4ebdbe5b513724baf9a3", + "2037d6a4228444069422a38ad1820ba2", + "4b509266c7134ea08b63246897d56e4e", + "11a38fec630a4c068e2b1945e898ed8b", + "8eae634301b04949b7785f2ec5589e0a", + "eedddd83a6e6421fbb5694b320c0b8a7", + "8913279a84ac4a0bbdc349f47ec66891" + ] + }, + "outputId": "b2e9520b-f9ff-4099-fa84-adebc024a46f" + }, + "id": "GeJyAfazPd4u", + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:112: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "README.md: 0%| | 0.00/2.02k [00:00=2.0.0 in /usr/local/lib/python3.12/dist-packages (from evaluate) (4.0.0)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from evaluate) (2.0.2)\n", + "Requirement already satisfied: dill in /usr/local/lib/python3.12/dist-packages (from evaluate) (0.3.8)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (from evaluate) (2.2.2)\n", + "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.12/dist-packages (from evaluate) (2.32.4)\n", + "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.12/dist-packages (from evaluate) (4.67.3)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.12/dist-packages (from evaluate) (3.7.0)\n", + "Requirement already satisfied: multiprocess in /usr/local/lib/python3.12/dist-packages (from evaluate) (0.70.16)\n", + "Requirement already satisfied: fsspec>=2021.05.0 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]>=2021.05.0->evaluate) (2025.3.0)\n", + "Requirement already satisfied: huggingface-hub>=0.7.0 in /usr/local/lib/python3.12/dist-packages (from evaluate) (1.16.1)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from evaluate) (26.2)\n", + "Requirement already satisfied: nltk in /usr/local/lib/python3.12/dist-packages (from rouge_score) (3.9.1)\n", + "Requirement already satisfied: six>=1.14.0 in /usr/local/lib/python3.12/dist-packages (from rouge_score) (1.17.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from datasets>=2.0.0->evaluate) (3.29.0)\n", + "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets>=2.0.0->evaluate) (18.1.0)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from datasets>=2.0.0->evaluate) (6.0.3)\n", + "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]>=2021.05.0->evaluate) (3.13.5)\n", + "Requirement already satisfied: hf-xet<2.0.0,>=1.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.7.0->evaluate) (1.5.0)\n", + "Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.7.0->evaluate) (0.28.1)\n", + "Requirement already satisfied: typer>=0.20.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.7.0->evaluate) (0.25.1)\n", + "Requirement already satisfied: typing-extensions>=4.1.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.7.0->evaluate) (4.15.0)\n", + "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->evaluate) (3.4.7)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->evaluate) (3.15)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->evaluate) (2.5.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->evaluate) (2026.5.20)\n", + "Requirement already satisfied: click in /usr/local/lib/python3.12/dist-packages (from nltk->rouge_score) (8.4.0)\n", + "Requirement already satisfied: joblib in /usr/local/lib/python3.12/dist-packages (from nltk->rouge_score) (1.5.3)\n", + "Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.12/dist-packages (from nltk->rouge_score) (2025.11.3)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->evaluate) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->evaluate) (2025.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->evaluate) (2026.2)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (2.6.2)\n", + "Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.4.0)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (26.1.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.8.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (6.7.1)\n", + "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (0.5.2)\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.24.2)\n", + "Requirement already satisfied: anyio in /usr/local/lib/python3.12/dist-packages (from httpx<1,>=0.23.0->huggingface-hub>=0.7.0->evaluate) (4.13.0)\n", + "Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx<1,>=0.23.0->huggingface-hub>=0.7.0->evaluate) (1.0.9)\n", + "Requirement already satisfied: h11>=0.16 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx<1,>=0.23.0->huggingface-hub>=0.7.0->evaluate) (0.16.0)\n", + "Requirement already satisfied: shellingham>=1.3.0 in /usr/local/lib/python3.12/dist-packages (from typer>=0.20.0->huggingface-hub>=0.7.0->evaluate) (1.5.4)\n", + "Requirement already satisfied: rich>=13.8.0 in /usr/local/lib/python3.12/dist-packages (from typer>=0.20.0->huggingface-hub>=0.7.0->evaluate) (13.9.4)\n", + "Requirement already satisfied: annotated-doc>=0.0.2 in /usr/local/lib/python3.12/dist-packages (from typer>=0.20.0->huggingface-hub>=0.7.0->evaluate) (0.0.4)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.12/dist-packages (from rich>=13.8.0->typer>=0.20.0->huggingface-hub>=0.7.0->evaluate) (4.2.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.12/dist-packages (from rich>=13.8.0->typer>=0.20.0->huggingface-hub>=0.7.0->evaluate) (2.20.0)\n", + "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.12/dist-packages (from markdown-it-py>=2.2.0->rich>=13.8.0->typer>=0.20.0->huggingface-hub>=0.7.0->evaluate) (0.1.2)\n", + "Epoch 1: Train Loss: 1.6152 Val Loss: 1.8923 Val Rouge 0.2633\n", + "Saved the model weights\n", + "Epoch 2: Train Loss: 1.2504 Val Loss: 1.9201 Val Rouge 0.2361\n", + "Epoch 3: Train Loss: 0.9689 Val Loss: 2.1170 Val Rouge 0.2444\n", + "Epoch 4: Train Loss: 0.7942 Val Loss: 2.2227 Val Rouge 0.2433\n", + "Epoch 5: Train Loss: 0.5598 Val Loss: 2.3783 Val Rouge 0.2405\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "model = BartForConditionalGeneration.from_pretrained(\n", + " pretrained_model_name_or_path=\"facebook/bart-base\"\n", + ").to(device)\n", + "model.load_state_dict(torch.load(\"../models/BartForConditionalGeneration.pt\"))\n", + "\n", + "test_loss, test_rouge_score = evaluation(model, test_dataloader)\n", + "print(f\"Test Loss : {test_loss:.4f}\")\n", + "print(f\"Test ROUGE-2 Score : {test_rouge_score:.4f}\")" + ], + "metadata": { + "id": "meNCXm_FPkOR", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 85, + "referenced_widgets": [ + "c0d529a03b5e487a9f8ddd54b6ce7152", + "e022583a113d43cb9edb5a39f6dea0d7", + "628da7b360da41578e5aecb3cce85e7f", + "504334f312834edbbef9a49fca1dcc1d", + "7f8d297fe66147a7b7c73cf5a613b076", + "782874293a7d4f7790e64d0f0ed2c8c2", + "23edb7e65d924e5cbabe9670220da479", + "2064eee26c084991802a5b35b4a69187", + "ef0d1c4bd82947f2a7d66359ec972a79", + "8241412e7ffd4ee4b1e31186a09d848d", + "891f976d7f504ae28036ad1058f40974" + ] + }, + "outputId": "c1be94db-38dc-4627-f6c5-8b42c223a3b9" + }, + "id": "meNCXm_FPkOR", + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Loading weights: 0%| | 0/259 [00:00 세션 다시 시작 > 아래 셀부터 실행" + ], + "metadata": { + "id": "4rq77NfBbByn" + }, + "id": "4rq77NfBbByn" + }, + { + "id": "ac52f320", + "cell_type": "code", + "metadata": { + "id": "ac52f320" + }, + "execution_count": null, + "source": [ + "import torch\n", + "from torch.utils.data import DataLoader\n", + "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n", + "from datasets import load_dataset\n", + "from torch.optim import AdamW\n", + "from tqdm import tqdm" + ], + "outputs": [] + }, + { + "id": "393f4136", + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "393f4136", + "outputId": "c274cc6e-86e5-454b-ca2c-4a76852dfef0" + }, + "execution_count": null, + "source": [ + "# batch_size와 epochs를 조정해보세요!\n", + "batch_size = 16\n", + "epochs = 2\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Using device: {device}\")" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Using device: cuda\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# 데이터셋 로드\n", + "raw_datasets = load_dataset(\"sst2\")\n", + "raw_datasets" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 608, + "referenced_widgets": [ + "14fe1eeaab11475c92df6a6870535073", + "a3c3d32da6644f24973017815729cb3d", + "f09d349f2e0c4b18a21d46475326508c", + "52a68f446c5c422e8e2dc151328eb5aa", + "c9a6ecd1255c438ea5180b8d13dc629f", + "b69fa88a1b444866b2b8e61a102b0dbb", + "e25c01c9e40d4599a1da0e509cb3adc9", + "13892e0b62e245c78324be5b17e64b7c", + "fcc83677327848dd9a55437b10686d2d", + "1d754ea160a741a9a4a9ccfa0f825a01", + "bf2a1f4b25ac4d7a9e2bc00cc4124dbc", + "acbe4035bc8d4e69ad786a14c151425d", + "792b33b8a3824ca3a7b9abd830c27c59", + "d594cfc53392408dbf9ab3c615a82f5e", + "56ae63c230854bd6bd76abc8e5f7c4c0", + "e4ba9d50e9c24c66bfdf4bf792019765", + "52077253fce3457bb1e792a3d999ff91", + "b40e8c1c05be4a8fb83cd9810088fb24", + "d3ce86ed664849418dc2837552c8200d", + "86a5ea1c143f4a6ab27541b20f7ff8c5", + "2960e7c196a141e89ab34852d46dc182", + "1f30d54aa9c2439fb7315a9f64b2b8dc", + "51b55b21f30041dfaf97532b7b6ccc8d", + "2224f1d2d13c4b1c8e3fbc80d0c2042c", + "b7f797ad68954cc1ba15b3ae64caf791", + "a52e3856dd7a4a78b88106fe492accdc", + "6353b5169c8845ae85b68da0a5c7058a", + "b4ee587eae1d43789aaa1d20f0b4307c", + "4895f6731c6c4de182cc02ff68f677d8", + "c6cd6f8e82804f0690362a7fb3c7998e", + "3ac67929d3a94084981345b1827a881b", + "b9c0f920709c4e0ea0cecd3fe5d75964", + "09c68df66dc04e8e8b63d9c2d6e5f879", + "927b6e80d2384dda8651651ad772e21b", + "b963067f4245402d956fe372e4ade514", + "3116ba237f25483d851eeaad11fff007", + "48e077d1d1e34c328f702d821578644e", + "81a49b84afae4475bf5b7ffb25aa38fd", + "e65a6c75b62146e0803d71ab601d23ee", + "9ae9b0cb934247ce91c8099bfa9a9681", + "df859c0298b548c697f46a6402e163b3", + "b4e97304a0564ff99fab5ab935cff0a1", + "358d212bfe6e4a34ac3d88ef3e9fbdf1", + "ecf1bf46b8e34d598786246c67991ffd", + "ab35faf9cec640f7ad9be63c58739175", + "8ea3aaa4e26a4c3e8fab2ac13d4e1fa0", + "c8cf0e4c70d040128ab2c47eaa476fe1", + "0ca5efea7538463f84528ba21878e79b", + "e4ea5a6ad7ea42d1b8bb16d25c70a7de", + "033249a14bfd4c8eaa7173b51015615e", + "ec206cd0f6bc49c18cd9cdddf83512e7", + "46aef0739b764cf5acbbb79fba74867b", + "0963d9b3ca2849daa498083b188a99f2", + "9d3ec63100c940a98f719f8c2954ab97", + "61e15dcf1868406480b5ff9bf9116cb5", + "fed4f9c2ecad4060aed0685d131ff87b", + "56ed45db4ffd4b37ac03f81d4fbaf746", + "76b32bdbe500485987bcb53c062beb8e", + "5eefe8df23f646b2979fc00c4ac2f67a", + "83dd57665a304cbdbf1dc78bb30ad537", + "c705a67195bd4472b2135742379f66d0", + "2aa936544db944af89e3edec92fc0e57", + "8d5a2ed665ea4da797b79c8fb9fc8238", + "073ec7f730b94a8aa33257afd8ece784", + "216a4468b5e34229b0d75f3e2fa924ca", + "eecd550d39624edf87ac89d9824c2fd6", + "62dca9a20a9145f8a7c31a2606068d59", + "ec3cdd3e092343e095170fe410f7248b", + "3bf3b9b8e6d84cf9874bfbe7ae6ec19f", + "0a3589acea314f45a7ca24c0e014f487", + "40a3cb0679cd4670a9bed448f0dd16d1", + "a71f43c0a2874c63a082c8de425b6bbf", + "c238bb23d6bf4e3892527d7692b1cfa6", + "0e2a35f588f14ab4a407ed162288c019", + "f4a315c1c65843e2a83a8a1d15ab24c1", + "23f8a9922e6948f6b88b6a1268fc8986", + "96e898c42754452399a8070c97acd107" + ] + }, + "id": "QTVKkGiIflzk", + "outputId": "a335eac6-c541-49cb-f573-08211009218d" + }, + "id": "QTVKkGiIflzk", + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "README.md: 0%| | 0.00/5.27k [00:00