diff --git a/app.log b/app.log new file mode 100644 index 0000000..e688b87 --- /dev/null +++ b/app.log @@ -0,0 +1,22 @@ +2026-03-24 14:58:30,184 - my_logger - INFO - Application start. Server waiting for [1, 1] clients. +2026-03-24 15:01:29,011 - my_logger - INFO - Start training round 1 +2026-03-24 22:44:03,514 - my_logger - INFO - Application start. Server waiting for [1, 1] clients. +2026-03-24 22:45:14,428 - my_logger - INFO - Start training round 1 +2026-03-25 09:13:23,571 - my_logger - INFO - Round 1 complete. Waiting for READY. +2026-03-25 09:13:29,550 - my_logger - INFO - Round 1 fully complete. +2026-03-25 09:13:32,331 - my_logger - INFO - Start training round 2 +2026-03-25 19:44:00,127 - my_logger - INFO - Application start. Server waiting for [1, 1] clients. +2026-03-25 19:45:03,466 - my_logger - INFO - Start training round 1 +2026-03-25 19:47:07,712 - my_logger - INFO - Start training round 1 +2026-03-25 19:47:19,011 - my_logger - INFO - Start training round 1 +2026-03-25 19:49:41,646 - my_logger - INFO - Start training round 1 +2026-03-25 19:52:49,376 - my_logger - INFO - Application start. Server waiting for [1, 1] clients. +2026-03-25 19:54:50,514 - my_logger - INFO - Start training round 1 +2026-03-25 19:56:16,092 - my_logger - INFO - Start training round 1 +2026-03-25 19:57:18,773 - my_logger - INFO - Application start. Server waiting for [1, 1] clients. +2026-03-25 19:57:49,168 - my_logger - INFO - Start training round 1 +2026-03-25 20:05:50,761 - my_logger - INFO - Application start. Server waiting for [1, 1] clients. +2026-03-25 20:06:20,483 - my_logger - INFO - Start training round 1 +2026-03-26 04:53:12,204 - my_logger - INFO - Round 1 complete. Waiting for READY. +2026-03-26 04:53:17,911 - my_logger - INFO - Round 1 fully complete. 
+2026-03-26 04:53:20,860 - my_logger - INFO - Start training round 2 diff --git a/client - Copy.py b/client - Copy.py new file mode 100644 index 0000000..1f3e7b7 --- /dev/null +++ b/client - Copy.py @@ -0,0 +1,50 @@ +import pika +import uuid +import argparse +import yaml +import os + +import torch + +import src.Log +from src.RpcClient import RpcClient + +parser = argparse.ArgumentParser(description="Split learning framework") +parser.add_argument('--layer_id', type=int, required=True, help='ID of layer, start from 1') +parser.add_argument('--device', type=str, required=False, help='Device of client') + +args = parser.parse_args() + +with open('config.yaml', 'r') as file: + config = yaml.safe_load(file) + +client_id = uuid.uuid4() +address = config["rabbit"]["address"] +username = config["rabbit"]["username"] +password = config["rabbit"]["password"] +virtual_host = config["rabbit"]["virtual-host"] + +device = None +if args.device is None: + if torch.cuda.is_available(): + device = "cuda" + print(f"Using device: {torch.cuda.get_device_name(device)}") + else: + device = "cpu" + print(f"Using device: CPU") +else: + device = args.device + print(f"Using device: {device}") + +credentials = pika.PlainCredentials(username, password) +connection = pika.BlockingConnection(pika.ConnectionParameters(address, 5672, f'{virtual_host}', credentials)) +channel = connection.channel() + +if __name__ == "__main__": + src.Log.print_with_color("[>>>] Client sending registration message to server...", "red") + + data = {"action": "REGISTER", "client_id": client_id, "layer_id": args.layer_id,"message": "Hello from Client!"} + client = RpcClient(client_id, args.layer_id, channel, device) + client.send_to_server(data) + client.wait_response() + diff --git a/client.py b/client.py index 1f3e7b7..2d7f9ce 100644 --- a/client.py +++ b/client.py @@ -15,9 +15,8 @@ args = parser.parse_args() -with open('config.yaml', 'r') as file: +with open("config.yaml", "r", encoding="utf-8") as file: config = 
yaml.safe_load(file) - client_id = uuid.uuid4() address = config["rabbit"]["address"] username = config["rabbit"]["username"] @@ -37,7 +36,17 @@ print(f"Using device: {device}") credentials = pika.PlainCredentials(username, password) -connection = pika.BlockingConnection(pika.ConnectionParameters(address, 5672, f'{virtual_host}', credentials)) +# FIX: heartbeat=0 tắt timeout, tránh StreamLostError khi train lâu +connection = pika.BlockingConnection( + pika.ConnectionParameters( + host=address, + port=5672, + virtual_host=f'{virtual_host}', + credentials=credentials, + heartbeat=0, + blocked_connection_timeout=None, + ) +) channel = connection.channel() if __name__ == "__main__": @@ -46,5 +55,4 @@ data = {"action": "REGISTER", "client_id": client_id, "layer_id": args.layer_id,"message": "Hello from Client!"} client = RpcClient(client_id, args.layer_id, channel, device) client.send_to_server(data) - client.wait_response() - + client.wait_response() \ No newline at end of file diff --git a/config.yaml b/config.yaml index 6189867..203755e 100644 --- a/config.yaml +++ b/config.yaml @@ -1,12 +1,14 @@ name: SplitFedLLM + server: - global-round: 1 + global-round: 10 clients: - 1 - 1 cut-layers: 4 - model-name: Bert # GPT2/Llama/Bert - data-name: EMOTION # EMOTION/GSM8K + model-name: GPT2 # GPT2 / Llama / Bert + data-name: E2E # E2E / EMOTION / AG_NEWS / GSM8K + pretrained_path: GPT2.pt model: GPT2: n_block: 12 @@ -17,15 +19,15 @@ server: parameters: load: True save: True - validation: True + validation: False data-distribution: non-iid: False - num-sample: 500 - num-label: 10 + num-sample: 10000 + num-label: 1 dirichlet: alpha: 1 refresh-each-round: True - random-seed: 1 + random-seed: 42 rabbit: address: 127.0.0.1 @@ -34,14 +36,14 @@ rabbit: virtual-host: / log_path: . 
"""Convert a fine-tuned Hugging Face GPT-2 checkpoint to the flat ``GPT2.pt`` layout.

The Hugging Face state dict prefixes every backbone tensor with
``transformer.``; this script strips that prefix and leaves all other keys
(such as ``lm_head.weight``) untouched, then saves the result as ``GPT2.pt``.
"""
import torch
from transformers import GPT2LMHeadModel

HF_PREFIX = "transformer."

# Load the fine-tuned Hugging Face model from the local directory.
hf_model = GPT2LMHeadModel.from_pretrained("./gpt2_e2e_finetuned")

# Strip only the *leading* prefix.  str.replace() would also delete any later
# occurrence of the substring inside a key, silently corrupting that name.
new_sd = {
    (key[len(HF_PREFIX):] if key.startswith(HF_PREFIX) else key): value
    for key, value in hf_model.state_dict().items()
}

torch.save(new_sd, "GPT2.pt")

print("Converted → GPT2.pt")
"""Optional training optimizations for the split-learning clients.

Contents:
  * QLoRA / LoRA configuration helpers (need ``peft`` and ``bitsandbytes``),
  * an attention helper built on ``F.scaled_dot_product_attention``,
  * mixed-precision helpers (autocast context, dtype casting, GradScaler),
  * symmetric INT8 quantization of hidden states sent between clients,
  * gradient checkpointing for Hugging Face and custom models,
  * ``OptimizationBundle``, which groups the settings above in one object.
"""

import math
from contextlib import contextmanager

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from transformers.pytorch_utils import Conv1D
except ImportError:
    # Keep the pure-torch helpers importable when transformers is missing.
    Conv1D = None

try:
    from peft import (
        LoraConfig, TaskType, get_peft_model,
        prepare_model_for_kbit_training,
    )
    from transformers import BitsAndBytesConfig
    HAS_PEFT = True
except ImportError:
    HAS_PEFT = False

try:
    import bitsandbytes as bnb
    HAS_BNB_PKG = True
except ImportError:
    HAS_BNB_PKG = False


def detect_fan_in_fan_out(model):
    """Return True when *model* contains at least one ``Conv1D`` module.

    Returns False when ``transformers`` (and therefore ``Conv1D``) is not
    installed.
    """
    if Conv1D is None:
        return False
    return any(isinstance(module, Conv1D) for module in model.modules())


def build_qlora_config(qlora_cfg: dict, model, model_name: str):
    """Build the ``(BitsAndBytesConfig, LoraConfig)`` pair for QLoRA.

    Args:
        qlora_cfg: mapping with optional keys ``bits`` (default 4),
            ``double_quant`` (default True), ``r`` (default 8) and
            ``alpha`` (default 16).
        model: model that will receive the adapter; only inspected for
            ``Conv1D`` layers to set ``fan_in_fan_out``.
        model_name: ``"GPT2"``, ``"Llama"`` or ``"Bert"``; any other name
            falls back to the target modules ``["query", "value"]``.

    Returns:
        ``(bnb_config, lora_config)``, or ``(None, None)`` when ``peft`` or
        ``bitsandbytes`` is not installed.
    """
    if not HAS_BNB_PKG or not HAS_PEFT:
        return None, None

    bits = qlora_cfg.get("bits", 4)
    double_quant = qlora_cfg.get("double_quant", True)
    r = qlora_cfg.get("r", 8)
    alpha = qlora_cfg.get("alpha", 16)

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=(bits == 4),
        load_in_8bit=(bits == 8),
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=double_quant,
        bnb_4bit_compute_dtype=(
            torch.bfloat16 if torch.cuda.is_available() else torch.float16
        ),
    )

    target_map = {
        "GPT2": ["c_attn", "c_proj", "c_fc"],
        "Llama": ["q_proj", "k_proj", "v_proj", "o_proj",
                  "gate_proj", "up_proj", "down_proj"],
        "Bert": ["query", "key", "value", "dense"],
    }
    targets = target_map.get(model_name, ["query", "value"])

    task = TaskType.SEQ_CLS if model_name == "Bert" else TaskType.CAUSAL_LM
    fan_in = detect_fan_in_fan_out(model)

    lora_config = LoraConfig(
        task_type=task,
        r=r,
        lora_alpha=alpha,
        lora_dropout=0.05,
        bias="none",
        target_modules=targets,
        fan_in_fan_out=fan_in,
    )
    return bnb_config, lora_config


def apply_qlora(model, lora_config, model_name: str):
    """Prepare *model* for k-bit training and attach the LoRA adapter.

    Returns *model* unchanged when ``peft`` is not installed.  For Bert the
    classification head, when the model has one, is made trainable.
    """
    if not HAS_PEFT:
        return model
    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=True
    )
    model = get_peft_model(model, lora_config)
    if model_name == "Bert":
        # Not every Bert half owns a classifier, so look it up defensively
        # instead of raising AttributeError on the half without one.
        classifier = getattr(model, "classifier", None)
        if classifier is not None:
            for param in classifier.parameters():
                param.requires_grad = True
    model.print_trainable_parameters()
    return model


def _manual_attention(q, k, v, keep, dropout_p):
    """Reference attention in fp32 for PyTorch builds without SDPA.

    ``keep`` is a boolean tensor, True where attention is allowed.
    """
    scale = math.sqrt(q.size(-1))
    att = (q.float() @ k.float().transpose(-2, -1)) / scale
    # Large finite negative instead of -inf: a fully masked row then gives a
    # uniform distribution rather than NaN.
    fill_val = torch.finfo(att.dtype).min / 2
    att = att.masked_fill(~keep, fill_val)
    att = torch.softmax(att, dim=-1)
    if dropout_p > 0.0:
        att = F.dropout(att, p=dropout_p, training=True)
    return (att @ v.float()).to(q.dtype)


def flash_scaled_dot_product(q, k, v, mask=None, dropout_p=0.0):
    """Scaled dot-product attention with an optional keep-mask.

    Args:
        q, k, v: query / key / value tensors; the last dimension is the head
            dimension and the second-to-last is the sequence dimension.
        mask: optional tensor broadcastable to the attention scores.
            Non-zero (or True) entries are attended to, zero entries are
            masked out.  When omitted, causal attention is used.
        dropout_p: dropout probability on the attention weights; applied
            whenever it is greater than zero.

    Returns:
        The attention output in the dtype of ``q``.
    """
    # Boolean SDPA masks use True = "take part in attention", the same
    # polarity as the incoming mask, so the mask must NOT be inverted.
    keep = None if mask is None else (mask != 0)

    if not hasattr(F, "scaled_dot_product_attention"):
        if keep is None:
            # Mirror is_causal=True of the SDPA path so both code paths
            # return the same result.
            keep = torch.ones(
                q.size(-2), k.size(-2), dtype=torch.bool, device=q.device
            ).tril()
        return _manual_attention(q, k, v, keep, dropout_p)

    return F.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=keep,
        dropout_p=dropout_p,
        is_causal=keep is None,
    )


@contextmanager
def precision_context(precision: str, device: str = "cuda"):
    """Context manager enabling CUDA autocast for the requested precision.

    ``"fp16"`` and ``"int8"`` autocast to float16, ``"bf16"`` to bfloat16.
    Nothing is enabled for ``"fp32"``, for unknown precisions, when CUDA is
    unavailable, or when *device* is not a CUDA device.
    """
    autocast_dtype = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "int8": torch.float16,
    }.get(precision)
    on_cuda = torch.cuda.is_available() and str(device).startswith("cuda")

    if autocast_dtype is None or not on_cuda:
        yield
        return

    with torch.autocast(device_type="cuda", dtype=autocast_dtype):
        yield


def cast_model_precision(model: nn.Module, precision: str) -> nn.Module:
    """Cast the model parameters to the requested precision.

    ``"fp16"`` and ``"bf16"`` cast every parameter; ``"int8"`` swaps
    ``nn.Linear`` layers for bitsandbytes ``Linear8bitLt`` when the package
    is installed.  Any other value returns the model unchanged.
    """
    if precision == "fp16":
        return model.half()
    if precision == "bf16":
        return model.to(torch.bfloat16)
    if precision == "int8":
        if HAS_BNB_PKG:
            return _replace_linear_int8(model)
        else:
            print("[Optimizer] bitsandbytes chưa cài")
    return model


def _replace_linear_int8(model: nn.Module) -> nn.Module:
    """Recursively replace ``nn.Linear`` children with ``Linear8bitLt``."""
    for name, module in model.named_children():
        if name.startswith("lora_"):
            # LoRA adapter matrices are the trainable part of a PEFT model;
            # quantizing them would freeze them.
            continue
        if isinstance(module, nn.Linear):
            has_bias = module.bias is not None
            new_layer = bnb.nn.Linear8bitLt(
                module.in_features,
                module.out_features,
                bias=has_bias,
                has_fp16_weights=False,
                threshold=6.0,  # LLM.int8() outlier threshold
            )
            # Linear8bitLt.forward reads Int8Params attributes on the weight;
            # a plain nn.Parameter does not provide them.
            new_layer.weight = bnb.nn.Int8Params(
                module.weight.data,
                requires_grad=False,
                has_fp16_weights=False,
            )
            if has_bias:
                new_layer.bias = nn.Parameter(module.bias.data)
            setattr(model, name, new_layer)
        else:
            _replace_linear_int8(module)
    return model


def quantize_hidden(tensor: torch.Tensor) -> tuple:
    """Compress a hidden state to INT8 before it is pickled and sent.

    Symmetric per-tensor quantization::

        scale = max(|x|) / 127
        q     = round(x / scale).clamp(-127, 127).to(int8)

    Returns:
        ``(q, scale)`` where ``q`` is an int8 ``numpy.ndarray`` and ``scale``
        a Python float.  An all-zero or empty tensor uses ``scale = 1e-8``.
    """
    t = tensor.detach().float()
    # max() raises on an empty tensor, so treat it like the all-zero case.
    scale = t.abs().max().item() / 127.0 if t.numel() else 0.0
    if scale == 0.0:
        scale = 1e-8
    q = (t / scale).round().clamp(-127, 127).to(torch.int8)
    return q.cpu().numpy(), scale


def dequantize_hidden(q_numpy: np.ndarray, scale: float,
                      device: str, requires_grad: bool = False) -> torch.Tensor:
    """Expand an INT8 array back to a float32 tensor on *device*.

    When *requires_grad* is true the returned tensor tracks gradients.
    """
    q = torch.from_numpy(q_numpy).float() * scale
    q = q.to(device)
    if requires_grad:
        q.requires_grad_(True)
    return q


def enable_gradient_checkpointing(model: nn.Module) -> nn.Module:
    """Enable gradient checkpointing on *model* and return it.

    Models exposing ``gradient_checkpointing_enable()`` use that method.
    All other models get their transformer blocks wrapped with
    ``torch.utils.checkpoint``; if such a model also exposes
    ``enable_input_require_grads()`` it is called as well.
    """
    if hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()
        print("[Optimizer] gradient_checkpointing_enable() called.")
    else:
        if hasattr(model, "enable_input_require_grads"):
            # On its own this call does not checkpoint anything, so it
            # complements the wrapping below instead of replacing it.
            model.enable_input_require_grads()
        _wrap_checkpointing(model)
    return model


def _checkpoint_forward(module: nn.Module) -> None:
    """Route ``module.forward`` through ``torch.utils.checkpoint`` once."""
    if getattr(module, "_ckpt_wrapped", False):
        return
    from torch.utils.checkpoint import checkpoint as ckpt

    fwd = module.forward

    def checkpointed_forward(*args, **kwargs):
        return ckpt(fwd, *args, use_reentrant=False, **kwargs)

    module.forward = checkpointed_forward
    module._ckpt_wrapped = True


def _wrap_checkpointing(model: nn.Module):
    """Wrap blocks of a custom model with activation checkpointing.

    Children whose attribute name contains ``layer``, ``block`` or
    ``decoder`` are wrapped.  A matching ``ModuleList`` / ``ModuleDict`` is
    only a container whose own ``forward`` is never called, so its members
    are wrapped instead.  Non-matching children are searched recursively.
    """
    KEYWORDS = ("layer", "block", "decoder")

    for name, module in model.named_children():
        if any(k in name.lower() for k in KEYWORDS):
            if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
                for child in module.children():
                    _checkpoint_forward(child)
            else:
                _checkpoint_forward(module)
        else:
            _wrap_checkpointing(module)


def make_scaler(precision: str):
    """Return a CUDA ``GradScaler`` for ``"fp16"`` on CUDA, otherwise None.

    Intended use::

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    """
    if precision == "fp16" and torch.cuda.is_available():
        # torch.amp.GradScaler is missing on older torch releases; fall back
        # to the torch.cuda.amp class there.
        amp_scaler = getattr(getattr(torch, "amp", None), "GradScaler", None)
        if amp_scaler is not None:
            return amp_scaler("cuda")
        return torch.cuda.amp.GradScaler()
    return None


class OptimizationBundle:
    """Bundle every optimization setting into one object.

    Reads ``precision`` (default ``"fp32"``), ``quantize_hidden``,
    ``gradient_checkpointing`` and ``flash_attention`` (all default False)
    from *opt_cfg* and creates the matching GradScaler.
    """

    def __init__(self, opt_cfg: dict):
        self.precision = opt_cfg.get("precision", "fp32")
        self.quantize_hidden = opt_cfg.get("quantize_hidden", False)
        self.gradient_checkpointing = opt_cfg.get("gradient_checkpointing", False)
        self.scaler = make_scaler(self.precision)
        self.flash_attention = opt_cfg.get("flash_attention", False)

    def precision_ctx(self, device="cuda"):
        """Return the autocast context for this bundle's precision."""
        return precision_context(self.precision, device)

    def quant(self, tensor: torch.Tensor):
        """Encode a hidden state for transport.

        Returns ``(int8 array, scale)`` when quantization is enabled, else
        ``(float16 array, None)`` so both paths share one interface.
        """
        if self.quantize_hidden:
            return quantize_hidden(tensor)
        return tensor.detach().cpu().to(torch.float16).numpy(), None

    def dequant(self, q_numpy, scale, device, requires_grad=False):
        """Decode a hidden state produced by :meth:`quant`.

        NOTE(review): the quantized path returns float32 while the
        non-quantized path returns float16 — confirm the consumers cast to
        the dtype of their model.
        """
        if scale is not None:
            return dequantize_hidden(q_numpy, scale, device, requires_grad)
        t = torch.from_numpy(q_numpy.astype(np.float16)).to(device)
        if requires_grad:
            t.requires_grad_(True)
        return t

    def step(self, loss, optimizer):
        """Run backward and the optimizer step, via the GradScaler if set."""
        if self.scaler is not None:
            self.scaler.scale(loss).backward()
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            loss.backward()
            optimizer.step()
class RpcClient:
    """Client side of the split-learning protocol over RabbitMQ.

    The client polls its private reply queue, trains its half of the model
    on every ``START`` message, writes the LoRA-merged weights of its half to
    ``<model>_layer<id>.pt`` and answers with a ``READY`` message.
    """

    def __init__(self, client_id, layer_id, channel, device):
        self.client_id = client_id
        self.layer_id = layer_id
        self.channel = channel
        self.device = device
        result = self.channel.queue_declare(queue="", exclusive=True)
        self.callback_queue = result.method.queue
        self.model_train = None
        self.train_loader = None
        self.response = None
        self._refresh_loader = False

    @staticmethod
    def _slice_for_layer(full_sd, layer_id, cut_layers):
        """Extract the entries of a full checkpoint that belong to one half.

        Layer 1 keeps ``wte*`` / ``wpe*`` and blocks ``h.<i>`` with
        ``i < cut_layers``.  Any other layer keeps ``ln_f*`` / ``lm_head*``
        and blocks with ``i >= cut_layers``, renumbered to start at 0.
        """
        sliced = {}
        for key, value in full_sd.items():
            if key.startswith("h."):
                parts = key.split(".")
                idx = int(parts[1])
                if layer_id == 1:
                    if idx < cut_layers:
                        sliced[key] = value
                elif idx >= cut_layers:
                    # Rewrite only the index component of the key.
                    parts[1] = str(idx - cut_layers)
                    sliced[".".join(parts)] = value
            elif layer_id == 1:
                if key.startswith(("wte", "wpe")):
                    sliced[key] = value
            elif key.startswith(("ln_f", "lm_head")):
                sliced[key] = value
        return sliced

    @staticmethod
    def _globalize_lora_keys(lora_sd, offset):
        """Shift the block index of ``base_model.model.h.<i>.`` keys by *offset*.

        The second half numbers its blocks from 0, while the base checkpoint
        numbers them globally; the shift makes LoRA deltas land on the blocks
        they were trained for.
        """
        if not offset:
            return dict(lora_sd)
        pattern = re.compile(r"^(base_model\.model\.h\.)(\d+)(?=\.)")
        return {
            pattern.sub(
                lambda m: f"{m.group(1)}{int(m.group(2)) + offset}", key
            ): value
            for key, value in lora_sd.items()
        }

    def wait_response(self):
        """Poll the reply queue until a handler asks to stop."""
        reply_queue_name = f"reply_{self.client_id}"
        self.channel.queue_declare(queue=reply_queue_name, durable=False)
        while True:
            method_frame, _, body = self.channel.basic_get(
                queue=reply_queue_name, auto_ack=True
            )
            if body:
                status = self.response_message(body)
                if not status:
                    break
            time.sleep(0.1)

    def response_message(self, body):
        """Decode one server message and dispatch on its ``action``.

        Returns False when the client should stop polling, True otherwise.
        """
        # NOTE(review): pickle.loads executes arbitrary code from the payload;
        # this is only safe while the broker is reachable by trusted peers.
        self.response = pickle.loads(body)
        src.Log.print_with_color(
            f"[<<<] Client received: {self.response['message']}", "blue"
        )
        action = self.response["action"]
        if action == "START":
            return self._handle_start()
        elif action == "STOP":
            return False
        return True

    def _handle_start(self):
        """Run one training round for this client's half of the model.

        Steps: build the half model, load its weights, attach the LoRA /
        QLoRA adapter, train, merge the trained adapter into the base
        snapshot, save this half's slice and send ``READY`` to the server.
        Always returns True so the polling loop continues.
        """
        resp = self.response
        model_name = resp["model_name"]
        cut_layers = resp["cut_layers"]
        total_block = resp["total_block"]
        clip_grad_norm = resp["clip_grad_norm"]
        data_name = resp["data_name"]
        num_sample = resp["num_sample"]
        fine_tune_config = resp["fine_tune_config"]
        opt_cfg = resp.get("opt_config", {})
        batch_size = resp["batch_size"]
        lr = resp["lr"]
        weight_decay = resp["weight_decay"]
        control_count = resp["control_count"]
        refresh_each_round = resp.get("refresh_each_round", False)

        opt = OptimizationBundle(opt_cfg)
        use_flash = getattr(opt, "flash_attention", False)
        if use_flash and self.device == "cpu":
            use_flash = False

        klass_map = {"GPT2": GPT2, "Llama": Llama, "Bert": Bert}
        klass = klass_map.get(model_name, GPT2)
        if self.layer_id == 1:
            model = klass(layer_id=1, n_block=cut_layers, use_flash=use_flash)
        else:
            model = klass(layer_id=2, n_block=total_block - cut_layers, use_flash=use_flash)

        # Load weights from this layer's local file, else the full checkpoint.
        layer_file = f"{model_name}_layer{self.layer_id}.pt"
        pretrained_file = f"{model_name}.pt"
        base_file = f"{model_name}_base.pt"
        load_file = layer_file if os.path.exists(layer_file) else pretrained_file

        # Snapshot the base weights exactly once; every round merges its LoRA
        # delta into this snapshot.
        # NOTE(review): when a round starts from the layer file (base plus the
        # previous delta), merging only the new delta into the base snapshot
        # drops the earlier rounds' deltas — confirm this is intended.
        if not os.path.exists(base_file) and os.path.exists(pretrained_file):
            import shutil
            shutil.copy2(pretrained_file, base_file)
            src.Log.print_with_color(f"[INFO] Created base snapshot: {base_file}", "green")

        if os.path.exists(load_file):
            src.Log.print_with_color(
                f"[INFO] Layer {self.layer_id}: Loading {load_file}", "green"
            )
            full_sd = torch.load(load_file, map_location="cpu")
            if load_file == layer_file:
                # The layer file was saved already sliced and renumbered;
                # slicing it again would drop or misalign layer-2 blocks.
                new_sd = full_sd
            else:
                new_sd = self._slice_for_layer(full_sd, self.layer_id, cut_layers)

            missing, unexpected = model.load_state_dict(new_sd, strict=False)
            src.Log.print_with_color(
                f"[DEBUG] Layer {self.layer_id} missing={len(missing)} "
                f"unexpected={len(unexpected)}", "blue"
            )
        else:
            src.Log.print_with_color(
                f"[WARN] {load_file} not found — using random init", "yellow"
            )

        # LoRA / QLoRA adapter
        ft_name = fine_tune_config.get("name", "LoRA")
        peft_config = None

        if fine_tune_config.get("enable", False):
            if ft_name == "QLoRA":
                _, peft_config = build_qlora_config(
                    fine_tune_config["QLoRA"], model, model_name
                )
                if peft_config is not None:
                    model = apply_qlora(model, peft_config, model_name)
                else:
                    # QLoRA dependencies are missing: fall back to plain LoRA.
                    peft_config = _build_lora_config(fine_tune_config, model_name)
                    if peft_config is not None:
                        model = get_peft_model(model, peft_config)
            else:
                peft_config = _build_lora_config(fine_tune_config, model_name)
                if peft_config is not None:
                    model = get_peft_model(model, peft_config)

            if peft_config is not None:
                if model_name == "Bert" and self.layer_id == 2:
                    for p in model.classifier.parameters():
                        p.requires_grad = True
                # Only a PEFT-wrapped model provides this method.
                model.print_trainable_parameters()

        if ft_name != "QLoRA":
            model = cast_model_precision(model, opt.precision)
        if opt.gradient_checkpointing:
            model = enable_gradient_checkpointing(model)

        model.to(self.device)

        ft_map = {"GPT2": Ft_GPT2, "Llama": Ft_Llama, "Bert": Ft_Bert}
        self.model_train = ft_map.get(model_name, Ft_GPT2)(
            self.client_id, self.layer_id, self.channel, self.device
        )

        if refresh_each_round:
            self.train_loader = None

        if self.layer_id == 1 and self.train_loader is None:
            self.train_loader = dataloader(
                model_name, data_name, batch_size, num_sample, train=True
            )

        if self.layer_id == 1:
            result, size = self.model_train.first_layer(
                model, lr, weight_decay, clip_grad_norm,
                control_count, self.train_loader, opt=opt
            )
        else:
            result, size = self.model_train.last_layer(
                model, lr, weight_decay, clip_grad_norm, opt=opt
            )

        if fine_tune_config.get("enable", False) and peft_config is not None:
            if not result:
                src.Log.print_with_color(
                    f"[WARN] Layer {self.layer_id}: Training failed, bỏ qua lưu LoRA.",
                    "yellow"
                )
            else:
                lora_sd = {
                    k: v.detach().cpu().float()
                    for k, v in model.state_dict().items()
                    if "lora" in k.lower()
                }
                src.Log.print_with_color(
                    f"[INFO] Layer {self.layer_id}: {len(lora_sd)} LoRA keys collected.",
                    "green"
                )

                merge_src = base_file if os.path.exists(base_file) else pretrained_file
                if os.path.exists(merge_src):
                    base_sd = torch.load(merge_src, map_location="cpu")
                    if "wte.weight" in base_sd:
                        # Use the rank / alpha of the adapter that was
                        # actually trained (LoRA or QLoRA section).
                        lora_defaults = fine_tune_config.get("LoRA", {})
                        lora_r = getattr(peft_config, "r", None) or lora_defaults.get("r", 8)
                        lora_alpha = (
                            getattr(peft_config, "lora_alpha", None)
                            or lora_defaults.get("alpha", 16)
                        )
                        # This half numbers its blocks locally; the base
                        # checkpoint numbers them globally.
                        block_offset = 0 if self.layer_id == 1 else cut_layers
                        merged_sd = _merge_lora_into_base(
                            base_sd,
                            self._globalize_lora_keys(lora_sd, block_offset),
                            lora_r, lora_alpha, model_name
                        )

                        partial_sd = self._slice_for_layer(
                            merged_sd, self.layer_id, cut_layers
                        )

                        torch.save(partial_sd, layer_file)
                        size_mb = os.path.getsize(layer_file) / 1e6
                        src.Log.print_with_color(
                            f"[INFO] Layer {self.layer_id}: Saved {len(partial_sd)} keys "
                            f"→ {layer_file} ({size_mb:.1f} MB)", "green"
                        )
                    else:
                        src.Log.print_with_color(
                            f"[WARN] Layer {self.layer_id}: {merge_src} không có base weights.",
                            "yellow"
                        )
                else:
                    src.Log.print_with_color(
                        f"[WARN] Layer {self.layer_id}: Không tìm thấy {merge_src} để merge.",
                        "yellow"
                    )

        ready_msg = {
            "action": "READY",
            "client_id": self.client_id,
            "layer_id": self.layer_id,
            "message": f"Layer {self.layer_id} saved LoRA, ready for next round.",
        }
        src.Log.print_with_color(
            f"[>>>] Layer {self.layer_id}: Gửi READY về server.", "red"
        )
        self.send_to_server(ready_msg)
        return True

    def send_to_server(self, message):
        """Publish a pickled *message* on ``rpc_queue``.

        Clears ``self.response`` first and returns it, so the return value is
        always None.
        """
        self.response = None
        self.channel.queue_declare("rpc_queue", durable=False)
        self.channel.basic_publish(
            exchange="", routing_key="rpc_queue", body=pickle.dumps(message)
        )
        return self.response
config["server"]["cut-layers"] - self.global_round = config["server"]["global-round"] - self.round = self.global_round + self.model_name = config["server"]["model-name"] + self.data_name = config["server"]["data-name"] + self.total_clients = config["server"]["clients"] + self.cut_layers = config["server"]["cut-layers"] + self.global_round = config["server"]["global-round"] + self.round = self.global_round self.save_parameters = config["server"]["parameters"]["save"] self.load_parameters = config["server"]["parameters"]["load"] - self.validation = config["server"]["validation"] - - # Clients - self.total_block = config["server"]["model"][self.model_name]["n_block"] - self.batch_size = config["learning"]["batch-size"] - self.lr = config["learning"]["learning-rate"] - self.weight_decay = config["learning"]["weight-decay"] - self.control_count = config["learning"]["control-count"] - self.clip_grad_norm = config["learning"]["clip-grad-norm"] + self.validation = config["server"]["validation"] + + self.total_block = config["server"]["model"][self.model_name]["n_block"] + self.batch_size = config["learning"]["batch-size"] + self.lr = config["learning"]["learning-rate"] + self.weight_decay = config["learning"]["weight-decay"] + self.control_count = config["learning"]["control-count"] + self.clip_grad_norm = config["learning"]["clip-grad-norm"] self.data_distribution = config["server"]["data-distribution"] - # Data distribution - self.non_iid = self.data_distribution["non-iid"] - self.num_label = self.data_distribution["num-label"] - self.num_sample = self.data_distribution["num-sample"] - self.refresh_each_round = self.data_distribution["refresh-each-round"] - self.random_seed = config["server"]["random-seed"] - self.label_counts = None + self.non_iid = self.data_distribution["non-iid"] + self.num_label = self.data_distribution["num-label"] + self.num_sample = self.data_distribution["num-sample"] + self.refresh_each_round = self.data_distribution.get("refresh-each-round", 
False) + self.random_seed = config["server"].get("random-seed", 1) + + self.fine_tune_config = config["fine-tune"] + self.opt_config = config.get("optimization", {}) + self.config = config - # Fine tune config - self.fine_tune_config = config['fine-tune'] + self.model_params = { + "vocab_size": config["server"].get("vocab_size", 50257), + "n_embd": config["server"].get("n_embd", 768), + "n_layer": config["server"].get("n_layer", 12), + "n_head": config["server"].get("n_head", 12), + "pretrained_path": config["server"].get("pretrained_path", f"{self.model_name}.pt"), + } if self.random_seed: random.seed(self.random_seed) - log_path = config["log_path"] - + log_path = config["log_path"] credentials = pika.PlainCredentials(username, password) self.connection = pika.BlockingConnection( - pika.ConnectionParameters(address, 5672, f'{virtual_host}', credentials)) + pika.ConnectionParameters( + host=address, port=5672, + virtual_host=f"{virtual_host}", + credentials=credentials, + heartbeat=0, + blocked_connection_timeout=None, + ) + ) self.channel = self.connection.channel() - self.channel.queue_declare(queue='rpc_queue') + self.channel.queue_declare(queue="rpc_queue") - self.count_update = [0 for _ in range(len(self.total_clients))] + self.count_notify = 0 + self.count_ready = 0 self.register_clients = [0 for _ in range(len(self.total_clients))] - self.count_notify = 0 - self.responses = {} - self.list_clients = [] - self.round_result = True - - self.global_model_parameters = [[] for _ in range(len(self.total_clients))] - self.global_client_sizes = [[] for _ in range(len(self.total_clients))] - self.avg_state_dict = [] + self.responses = {} + self.list_clients = [] + self.round_result = True self.channel.basic_qos(prefetch_count=1) self.reply_channel = self.connection.channel() - self.channel.basic_consume(queue='rpc_queue', on_message_callback=self.on_request) + self.channel.basic_consume(queue="rpc_queue", on_message_callback=self.on_request) debug_mode = 
config["debug_mode"] self.logger = src.Log.Logger(f"{log_path}/app.log", debug_mode) - self.logger.log_info(f"Application start. Server is waiting for {self.total_clients} clients.") - src.Log.print_with_color(f"Application start. Server is waiting for {self.total_clients} clients.", "green") + self.logger.log_info( + f"Application start. Server waiting for {self.total_clients} clients." + ) + src.Log.print_with_color( + f"Application start. Server waiting for {self.total_clients} clients.", "green" + ) + + # ── Helpers ─────────────────────────────────────────────────────────────── def distribution(self): + num_clients = sum(self.total_clients) if self.non_iid: - label_distribution = np.random.dirichlet([self.data_distribution["dirichlet"]["alpha"]] * self.num_label, - self.total_clients[0]) - - self.label_counts = (label_distribution * self.num_sample).astype(int) + label_dist = np.random.dirichlet( + [self.data_distribution["dirichlet"]["alpha"]] * self.num_label, + num_clients + ) + self.label_counts = (label_dist * self.num_sample).astype(int) else: - self.label_counts = np.full((self.total_clients[0], self.num_label), self.num_sample // self.num_label) + self.label_counts = np.full( + (num_clients, self.num_label), + self.num_sample // self.num_label + ) + + # ── Main handler ────────────────────────────────────────────────────────── def on_request(self, ch, method, props, body): - message = pickle.loads(body) - routing_key = props.reply_to - action = message["action"] + message = pickle.loads(body) + action = message["action"] client_id = message["client_id"] - layer_id = message["layer_id"] - self.responses[routing_key] = message + layer_id = message["layer_id"] + # ── REGISTER ────────────────────────────────────────────────────────── if action == "REGISTER": if (str(client_id), layer_id) not in self.list_clients: self.list_clients.append((str(client_id), layer_id)) - src.Log.print_with_color(f"[<<<] Received message from client: {message}", "blue") - # 
Save messages from clients + src.Log.print_with_color( + f"[<<<] REGISTER from client {client_id} layer {layer_id}", "blue" + ) self.register_clients[layer_id - 1] += 1 - # If consumed all clients - Register for first time - if self.register_clients == self.total_clients: - src.Log.print_with_color("All clients are connected. Sending notifications.", "green") - + if all( + self.register_clients[i] >= self.total_clients[i] + for i in range(len(self.total_clients)) + ): + src.Log.print_with_color("All clients connected. Starting round 1.", "green") self.distribution() - - self.logger.log_info(f"Start training round 1") + self.logger.log_info("Start training round 1") self.notify_clients() - elif action == "NOTIFY": - src.Log.print_with_color(f"[<<<] Received message from client: {message}", "blue") - message = {"action": "PAUSE", - "message": "Pause training and please send your parameters", - "parameters": None} + # Client báo hoàn thành train. Server gửi PAUSE để client biết có thể lưu LoRA. 
+ elif action == "NOTIFY": + src.Log.print_with_color( + f"[<<<] NOTIFY from client {client_id} layer {layer_id}", "blue" + ) self.count_notify += 1 - if self.count_notify == self.total_clients[0]: + if self.count_notify == sum(self.total_clients): self.count_notify = 0 - src.Log.print_with_color(f"Received all the finish training notification", "yellow") - - for (client_id, layer_id) in self.list_clients: - self.send_to_response(client_id, pickle.dumps(message)) - - elif action == "UPDATE": - # self.distribution() - data_message = message["message"] - result = message["result"] - src.Log.print_with_color(f"[<<<] Received message from {client_id}: {data_message}", "blue") - - self.count_update[layer_id - 1] += 1 - if not result: - self.round_result = False - - # Save client's model parameters - if self.save_parameters and self.round_result: - model_state_dict = message["parameters"] - client_size = message["size"] - self.global_model_parameters[layer_id - 1].append(model_state_dict) - self.global_client_sizes[layer_id - 1].append(client_size) - - # If consumed all client's parameters - if self.count_update == self.total_clients: - src.Log.print_with_color("Collected all parameters.", "yellow") - if self.save_parameters and self.round_result: - - self.avg_all_parameters() - self.global_model_parameters = [[] for _ in range(len(self.total_clients))] - self.global_client_sizes = [[] for _ in range(len(self.total_clients))] - - self.count_update = [0 for _ in range(len(self.total_clients))] - # Test - if self.save_parameters and self.validation and self.round_result: - state_dict_full = self.concatenate() - self.avg_state_dict = [] - if not get_val(self.model_name, self.data_name, state_dict_full,self.logger): - self.logger.log_warning("Training failed!") - else: - # Save to files - torch.save(state_dict_full, f'{self.model_name}.pt') - self.round -= 1 - else: - self.round -= 1 - - # Start a new training round - self.round_result = True + current_round = 
self.global_round - self.round + 1 + src.Log.print_with_color( + f"All clients finished training round {current_round}.", "yellow" + ) + + # Gửi PAUSE → client nhận xong mới lưu LoRA rồi gửi READY + pause_msg = { + "action": "PAUSE", + "message": f"Round {current_round} done. Save your LoRA.", + "current_round": current_round, + "parameters": None, + } + for (cid, lid) in self.list_clients: + self.send_to_response(cid, pickle.dumps(pause_msg)) + + self.logger.log_info(f"Round {current_round} complete. Waiting for READY.") + + # Server chỉ gửi START round tiếp khi đủ tất cả READY. + elif action == "READY": + src.Log.print_with_color( + f"[<<<] READY from client {client_id} layer {layer_id}", "blue" + ) + self.count_ready += 1 + + if self.count_ready == sum(self.total_clients): + self.count_ready = 0 + self.round -= 1 + current_round = self.global_round - self.round + + src.Log.print_with_color( + f"All clients ready. Round {current_round} fully complete.", "green" + ) + self.logger.log_info(f"Round {current_round} fully complete.") + + # Merge GPT2_layer1.pt + GPT2_layer2.pt → GPT2.pt + self._merge_layer_files() if self.round > 0: - self.logger.log_info(f"Start training round {self.global_round - self.round + 1}") - if self.save_parameters: - self.notify_clients() - else: - self.notify_clients(register=False) + next_round = self.global_round - self.round + 1 + self.logger.log_info(f"Start training round {next_round}") + self.notify_clients() else: self.logger.log_info("Stop training !!!") self.notify_clients(start=False) sys.exit() - ch.basic_ack(delivery_tag=method.delivery_tag) + elif action == "UPDATE": + src.Log.print_with_color( + f"[WARN] Deprecated UPDATE from client {client_id} — ignored.", "yellow" + ) - def notify_clients(self, start=True, register=True): + try: + ch.basic_ack(delivery_tag=method.delivery_tag) + except Exception as e: + src.Log.print_with_color(f"[WARN] basic_ack failed: {e}", "yellow") + self._reconnect() - # Send message to clients when 
consumed all clients - if self.model_name == 'GPT2': - klass = GPT2 - elif self.model_name == 'Llama': - klass = Llama - elif self.model_name == 'Bert': - klass = Bert - else: - klass = globals()[f'{self.model_name}'] + def notify_clients(self, start=True): for (client_id, layer_id) in self.list_clients: - # Read parameters file - filepath = f'{self.model_name}.pt' - state_dict = None - - if start: - if self.load_parameters and register: - if os.path.exists(filepath): - full_state_dict = torch.load(filepath, weights_only=True) - - if layer_id == 1: - model = klass(layer_id=1, n_block=self.cut_layers) - state_dict = model.state_dict() - keys = state_dict.keys() - - for key in keys: - state_dict[key] = full_state_dict[key] - - else: - model = klass(layer_id=2, n_block=self.total_block - self.cut_layers) - state_dict = model.state_dict() - state_dict = src.Utils.change_keys(state_dict, self.cut_layers, True) - keys = state_dict.keys() - - for key in keys: - state_dict[key] = full_state_dict[key] - - state_dict =src.Utils.change_keys(state_dict, self.cut_layers, False) - src.Log.print_with_color(f"Load pretrain model successfully", "green") - - else: - src.Log.print_with_color(f"File {filepath} does not exist.", "yellow") - self.logger.log_info(f"File {filepath} does not exist.") - - src.Log.print_with_color(f"[>>>] Sent start training request to client {client_id}", "red") - - response = {"action": "START", - "message": "Server accept the connection!", - "parameters": copy.deepcopy(state_dict), - "cut_layers": self.cut_layers, - "total_block": self.total_block, - "model_name": self.model_name, - "data_name": self.data_name, - "num_sample": self.num_sample, - "control_count": self.control_count, - "batch_size": self.batch_size, - "lr": self.lr, - "weight_decay": self.weight_decay, - "clip_grad_norm": self.clip_grad_norm, - "fine_tune_config": self.fine_tune_config - } - self.send_to_response(client_id, pickle.dumps(response)) - + if not start: + self.send_to_response( 
+ client_id, + pickle.dumps({"action": "STOP", "message": "Stop training!", "parameters": None}) + ) + src.Log.print_with_color(f"[>>>] STOP → client {client_id}", "red") + continue + response = { + "action": "START", + "message": "Server accept the connection!", + "parameters": None, + "cut_layers": self.cut_layers, + "total_block": self.total_block, + "model_name": self.model_name, + "data_name": self.data_name, + "num_sample": self.num_sample, + "control_count": self.control_count, + "batch_size": self.batch_size, + "lr": self.lr, + "weight_decay": self.weight_decay, + "clip_grad_norm": self.clip_grad_norm, + "fine_tune_config": self.fine_tune_config, + "opt_config": self.opt_config, + "refresh_each_round": self.refresh_each_round, + } + src.Log.print_with_color( + f"[>>>] START → client {client_id} layer {layer_id}", "red" + ) + self.send_to_response(client_id, pickle.dumps(response)) + + + def _merge_layer_files(self): + """ + Merge GPT2_layer1.pt (wte, wpe, h.0-3) và GPT2_layer2.pt (h.4-11, ln_f, lm_head) + thành GPT2.pt đầy đủ để round tiếp theo load. + Layer 2 keys cần được remap: h.0 → h.{cut_layers}, h.1 → h.{cut_layers+1}, ... 
+ """ + f1 = f"{self.model_name}_layer1.pt" + f2 = f"{self.model_name}_layer2.pt" + + if not os.path.exists(f1) or not os.path.exists(f2): + src.Log.print_with_color( + f"[WARN] Merge skipped: {f1} exists={os.path.exists(f1)}, " + f"{f2} exists={os.path.exists(f2)}", "yellow" + ) + return + + sd1 = torch.load(f1, map_location="cpu") + sd2 = torch.load(f2, map_location="cpu") + + merged = {} + + for k, v in sd1.items(): + merged[k] = v + + for k, v in sd2.items(): + if k.startswith("h."): + parts = k.split(".") + old_idx = int(parts[1]) + new_idx = old_idx + self.cut_layers + new_k = ".".join([parts[0], str(new_idx)] + parts[2:]) + merged[new_k] = v else: - src.Log.print_with_color(f"[>>>] Sent stop training request to client {client_id}", "red") - response = {"action": "STOP", - "message": "Stop training!", - "parameters": None} - self.send_to_response(client_id, pickle.dumps(response)) + merged[k] = v - def start(self): - self.channel.start_consuming() - def send_to_response(self, client_id, message): - reply_queue_name = f'reply_{client_id}' - self.reply_channel.queue_declare(reply_queue_name, durable=False) + if "lm_head.weight" not in merged and "wte.weight" in merged: + merged["lm_head.weight"] = merged["wte.weight"].clone() - src.Log.print_with_color(f"[>>>] Sent notification to client {client_id}", "red") - self.reply_channel.basic_publish( - exchange='', - routing_key=reply_queue_name, - body=message + out_file = f"{self.model_name}.pt" + torch.save(merged, out_file) + src.Log.print_with_color( + f"[>>>] Merged {f1} + {f2} → {out_file} " + f"({len(merged)} keys)", "green" ) + self.logger.log_info(f"Merged layer files → {out_file}") + + def _reconnect(self): + src.Log.print_with_color("[>>>] Reconnecting to RabbitMQ...", "yellow") + try: + self.connection.close() + except Exception: + pass + credentials = pika.PlainCredentials( + self.config["rabbit"]["username"], + self.config["rabbit"]["password"] + ) + self.connection = pika.BlockingConnection( + 
pika.ConnectionParameters( + host=self.config["rabbit"]["address"], + port=5672, + virtual_host=self.config["rabbit"]["virtual-host"], + credentials=credentials, + heartbeat=0, + blocked_connection_timeout=None, + ) + ) + self.channel = self.connection.channel() + self.reply_channel = self.connection.channel() + self.channel.queue_declare(queue="rpc_queue") + self.channel.basic_qos(prefetch_count=1) + self.channel.basic_consume(queue="rpc_queue", on_message_callback=self.on_request) + src.Log.print_with_color("[>>>] Reconnected successfully.", "green") - def avg_all_parameters(self): - layer_sizes = self.global_client_sizes - layer_params = self.global_model_parameters - - for layer_idx, list_state_dicts in enumerate(layer_params): - list_sizes = layer_sizes[layer_idx] - if not list_state_dicts or not list_sizes: - self.avg_state_dict.append({}) - continue - avg_sd = src.Utils.fed_avg_state_dicts(list_state_dicts, weights=list_sizes) - self.avg_state_dict.append(avg_sd) - - def concatenate(self): - avg_layers = self.avg_state_dict - if not avg_layers: - print(f"Warning: don't has averaged layers, skipping.") - - full_dict = {} - for idx, layer_dict in enumerate(avg_layers): - if idx == 0: - sd = layer_dict - full_dict.update(copy.deepcopy(sd)) - else: - sd = src.Utils.change_keys(layer_dict, self.cut_layers, True) - full_dict.update(copy.deepcopy(sd)) + def start(self): + while True: + try: + self.channel.start_consuming() + except ( + pika.exceptions.ChannelWrongStateError, + pika.exceptions.StreamLostError, + pika.exceptions.ConnectionClosedByBroker, + Exception, + ) as e: + src.Log.print_with_color( + f"[WARN] Connection error: {e} — reconnecting...", "yellow" + ) + try: + self._reconnect() + except Exception as re: + src.Log.print_with_color(f"[ERROR] Reconnect failed: {re}", "red") + import time; time.sleep(5) - return full_dict + def send_to_response(self, client_id, message): + reply_queue_name = f"reply_{client_id}" + 
self.reply_channel.queue_declare(reply_queue_name, durable=False) + self.reply_channel.basic_publish( + exchange="", routing_key=reply_queue_name, body=message + ) \ No newline at end of file diff --git a/src/Utils.py b/src/Utils.py index 8f76063..258dca2 100644 --- a/src/Utils.py +++ b/src/Utils.py @@ -49,7 +49,7 @@ def change_keys(state_dict, num, increase=True): return new_state_dict -def fed_avg_state_dicts(state_dicts, weights = None): +def fed_avg_state_dicts(state_dicts, weights=None): num = len(state_dicts) if num == 0: raise ValueError("fed_avg_state_dicts: don't have any state_dict.") @@ -57,25 +57,35 @@ def fed_avg_state_dicts(state_dicts, weights = None): if weights is None: weights = [1.0] * num total_w = sum(weights) + if total_w == 0: + raise ValueError("fed_avg_state_dicts: tổng weight = 0.") all_keys = set().union(*(sd.keys() for sd in state_dicts)) avg_dict = {} for key in all_keys: + acc = None + acc_w = 0.0 # Fix: theo dõi tổng weight thực sự đóng góp cho key này - acc = None for sd, w in zip(state_dicts, weights): if key not in sd: continue t = sd[key].float() if torch.isnan(t).any(): - t = torch.nan_to_num(t) # zero-fill + t = torch.nan_to_num(t, nan=0.0) t = t * w - acc = t if acc is None else acc + t + acc = t if acc is None else acc + t + acc_w += w - avg = acc / total_w + # Fix: nếu key không có trong bất kỳ state_dict nào → bỏ qua + if acc is None or acc_w == 0: + continue + + avg = acc / acc_w # Fix: chia cho tổng weight thực tế của key, không total_w - orig = next(sd[key] for sd in state_dicts if key in sd) + orig = next((sd[key] for sd in state_dicts if key in sd), None) + if orig is None: + continue if orig.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.bool): avg = avg.round().to(orig.dtype) else: diff --git a/src/__pycache__/Log.cpython-313.pyc b/src/__pycache__/Log.cpython-313.pyc new file mode 100644 index 0000000..30bbf19 Binary files /dev/null and b/src/__pycache__/Log.cpython-313.pyc differ diff --git 
a/src/__pycache__/Optimizer.cpython-310.pyc b/src/__pycache__/Optimizer.cpython-310.pyc new file mode 100644 index 0000000..4b5eb8c Binary files /dev/null and b/src/__pycache__/Optimizer.cpython-310.pyc differ diff --git a/src/__pycache__/Optimizer.cpython-313.pyc b/src/__pycache__/Optimizer.cpython-313.pyc new file mode 100644 index 0000000..e7740b9 Binary files /dev/null and b/src/__pycache__/Optimizer.cpython-313.pyc differ diff --git a/src/__pycache__/RpcClient.cpython-313.pyc b/src/__pycache__/RpcClient.cpython-313.pyc new file mode 100644 index 0000000..f5a8bd4 Binary files /dev/null and b/src/__pycache__/RpcClient.cpython-313.pyc differ diff --git a/src/__pycache__/Server.cpython-313.pyc b/src/__pycache__/Server.cpython-313.pyc new file mode 100644 index 0000000..6a59da5 Binary files /dev/null and b/src/__pycache__/Server.cpython-313.pyc differ diff --git a/src/__pycache__/Utils.cpython-313.pyc b/src/__pycache__/Utils.cpython-313.pyc new file mode 100644 index 0000000..d84d518 Binary files /dev/null and b/src/__pycache__/Utils.cpython-313.pyc differ diff --git a/src/__pycache__/__init__.cpython-310.pyc b/src/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..e0edd48 Binary files /dev/null and b/src/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/__pycache__/__init__.cpython-313.pyc b/src/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..ff6e90a Binary files /dev/null and b/src/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/convert.py b/src/convert.py new file mode 100644 index 0000000..f916861 --- /dev/null +++ b/src/convert.py @@ -0,0 +1,13 @@ +import torch +from transformers import GPT2LMHeadModel + +# load HF model +model = GPT2LMHeadModel.from_pretrained("./gpt2_e2e_finetuned") + +# lấy state_dict +state_dict = model.state_dict() + +# save về format server dùng +torch.save(state_dict, "GPT2.pt") + +print("Saved GPT2.pt") \ No newline at end of file diff --git 
a/src/dataset/__pycache__/EMOTION.cpython-313.pyc b/src/dataset/__pycache__/EMOTION.cpython-313.pyc new file mode 100644 index 0000000..30f2083 Binary files /dev/null and b/src/dataset/__pycache__/EMOTION.cpython-313.pyc differ diff --git a/src/dataset/__pycache__/GSM8K.cpython-313.pyc b/src/dataset/__pycache__/GSM8K.cpython-313.pyc new file mode 100644 index 0000000..1a9910e Binary files /dev/null and b/src/dataset/__pycache__/GSM8K.cpython-313.pyc differ diff --git a/src/dataset/__pycache__/__init__.cpython-313.pyc b/src/dataset/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..516a025 Binary files /dev/null and b/src/dataset/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/dataset/__pycache__/dataloader.cpython-313.pyc b/src/dataset/__pycache__/dataloader.cpython-313.pyc new file mode 100644 index 0000000..d60545d Binary files /dev/null and b/src/dataset/__pycache__/dataloader.cpython-313.pyc differ diff --git a/src/dataset/dataloader.py b/src/dataset/dataloader.py index fbfe2af..cfbd983 100644 --- a/src/dataset/dataloader.py +++ b/src/dataset/dataloader.py @@ -12,6 +12,64 @@ from torch.utils.data import DataLoader def dataloader(model_name =None, data_name=None, batch_size=None, distribution=500, train=True): + if data_name == 'E2E': + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + tokenizer.pad_token = tokenizer.eos_token + path = os.path.join("data/", "e2e_train.jsonl") + + with open(path, "r", encoding="utf-8") as f: + data = [json.loads(line) for line in f if line.strip()] + + random.shuffle(data) + + from torch.utils.data import Dataset + + class E2EDataset(Dataset): + def __init__(self, tokenizer, data, max_length=512): + self.tokenizer = tokenizer + self.data = data + self.max_length = max_length + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + ex = self.data[idx] + + prompt = f" {ex['input']} " + target = ex["output"] + " <|endoftext|>" + + full = prompt + " " + target + + enc = 
self.tokenizer( + full, + truncation=True, + padding="max_length", + max_length=self.max_length, + return_tensors="pt" + ) + + input_ids = enc["input_ids"].squeeze(0) + attention_mask = enc["attention_mask"].squeeze(0) + + labels = input_ids.clone() + + + prompt_len = len( + self.tokenizer(prompt, add_special_tokens=False)["input_ids"] + ) + labels[:prompt_len] = -100 + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels + } + + train_set = E2EDataset(tokenizer, data) + train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) + + return train_loader if data_name == 'GSM8K': if model_name == 'GPT2': tokenizer = GPT2Tokenizer.from_pretrained("gpt2") @@ -34,7 +92,50 @@ def dataloader(model_name =None, data_name=None, batch_size=None, distribution=5 print(f"{len(train_set)} train examples") - train_set = GSM8K(tokenizer, train_set, False) + from torch.utils.data import Dataset + + class E2EDataset(Dataset): + def __init__(self, tokenizer, data, max_length=512): + self.tokenizer = tokenizer + self.data = data + self.max_length = max_length + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + ex = self.data[idx] + + prompt = f"### Input:\n{ex['question']}\n\n### Output:\n" + target = ex["answer"] + " <|endoftext|>" + + full = prompt + target + + enc = self.tokenizer( + full, + truncation=True, + padding="max_length", + max_length=self.max_length, + return_tensors="pt" + ) + + input_ids = enc["input_ids"].squeeze(0) + attention_mask = enc["attention_mask"].squeeze(0) + + + labels = input_ids.clone() + + prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"].squeeze(0) + prompt_len = len(self.tokenizer(prompt, add_special_tokens=False)["input_ids"]) + + labels[:prompt_len] = -100 + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels + } + train_set = E2EDataset(tokenizer, train_set) train_loader = DataLoader(train_set, 
batch_size=batch_size, shuffle=True) return train_loader else: @@ -55,6 +156,28 @@ def dataloader(model_name =None, data_name=None, batch_size=None, distribution=5 return test_loader if data_name == 'EMOTION': + + dataset = load_dataset( + 'emotion', + download_mode='reuse_dataset_if_exists', + cache_dir='./hf_cache' + ) + tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + if train: + # EMOTION có 6 class (sadness/joy/love/anger/fear/surprise) + num_label = int(distribution / 6) + distribution = [num_label] * 6 + train_texts, train_labels = load_train_EMOTION(dataset, distribution) + train_set = EMOTIONDataset(train_texts, train_labels, tokenizer, max_length=128) + train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) + return train_loader + else: + test_texts, test_label = load_test_EMOTION(2000, dataset) + test_set = EMOTIONDataset(test_texts, test_label, tokenizer, max_length=128) + test_loader = DataLoader(test_set, batch_size=100, shuffle=False) + return test_loader + + elif data_name == 'AG_NEWS': dataset = load_dataset( 'ag_news', download_mode='reuse_dataset_if_exists', @@ -63,7 +186,7 @@ def dataloader(model_name =None, data_name=None, batch_size=None, distribution=5 tokenizer = BertTokenizer.from_pretrained('bert-base-cased') if train: num_label = int(distribution / 4) - distribution = [num_label, num_label, num_label, num_label] + distribution = [num_label] * 4 train_texts, train_labels = load_train_EMOTION(dataset, distribution) train_set = EMOTIONDataset(train_texts, train_labels, tokenizer, max_length=128) train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) @@ -72,4 +195,4 @@ def dataloader(model_name =None, data_name=None, batch_size=None, distribution=5 test_texts, test_label = load_test_EMOTION(2000, dataset) test_set = EMOTIONDataset(test_texts, test_label, tokenizer, max_length=128) test_loader = DataLoader(test_set, batch_size=100, shuffle=False) - return test_loader + return test_loader \ No 
newline at end of file diff --git a/src/fine_tune/Bert.py b/src/fine_tune/Bert.py index 4154a40..dcb359c 100644 --- a/src/fine_tune/Bert.py +++ b/src/fine_tune/Bert.py @@ -7,229 +7,249 @@ import torch.nn as nn import src.Log +from src.Optimizer import OptimizationBundle + + class Ft_Bert: def __init__(self, client_id, layer_id, channel, device): - self.client_id = client_id - self.layer_id = layer_id - self.channel = channel - self.device = device + self.client_id = client_id + self.layer_id = layer_id + self.channel = channel + self.device = device self.data_count = 0 - self.size = None - - def send_intermediate_output(self, data_id, output, labels, trace): - - forward_queue_name = f'intermediate_queue_{self.layer_id}' - self.channel.queue_declare(forward_queue_name, durable=False) - - if trace: - trace.append(self.client_id) - message = pickle.dumps( - {"data_id": data_id, "data": output.detach().cpu().numpy(), "label": labels.cpu(), "trace": trace} - ) - else: - message = pickle.dumps( - {"data_id": data_id, "data": output.detach().cpu().numpy(), "label": labels.cpu(), "trace": [self.client_id]} - ) + self.size = None + + # ── Giao tiếp ──────────────────────────────────────────── + + def send_intermediate_output(self, data_id, q_numpy, scale, labels, trace): + fwd_q = f"intermediate_queue_{self.layer_id}" + self.channel.queue_declare(fwd_q, durable=False) + trace_out = list(trace) + [self.client_id] if trace else [self.client_id] + msg = pickle.dumps({ + "data_id": data_id, + "data": q_numpy, + "scale": scale, + "label": labels.cpu(), + "trace": trace_out, + }) if self.size is None: - self.size = len(message) - print(f'Length message: {self.size} (bytes).') - self.channel.basic_publish( - exchange='', - routing_key=forward_queue_name, - body=message - ) + self.size = len(msg) + print(f"Length message: {self.size} bytes.") + self.channel.basic_publish(exchange="", routing_key=fwd_q, body=msg) def send_gradient(self, data_id, gradient, trace): - to_client_id = 
trace[-1] - trace.pop(-1) - backward_queue_name = f'gradient_queue_{self.layer_id - 1}_{to_client_id}' - self.channel.queue_declare(queue=backward_queue_name, durable=False) - - message = pickle.dumps( - {"data_id": data_id, "data": gradient.detach().cpu().numpy(), "trace": trace}) - + to_id = trace[-1] + trace = trace[:-1] + bwd_q = f"gradient_queue_{self.layer_id - 1}_{to_id}" + self.channel.queue_declare(queue=bwd_q, durable=False) + msg = pickle.dumps({ + "data_id": data_id, + "data": gradient.detach().cpu().numpy(), + "trace": trace, + }) if self.size is None: - self.size = len(message) - print(f'Length message: {self.size} (bytes).') + self.size = len(msg) + self.channel.basic_publish(exchange="", routing_key=bwd_q, body=msg) + + def send_to_server(self, message): + self.channel.queue_declare("rpc_queue", durable=False) self.channel.basic_publish( - exchange='', - routing_key=backward_queue_name, - body=message + exchange="", routing_key="rpc_queue", body=pickle.dumps(message) ) - def send_to_server(self, message): - self.channel.queue_declare('rpc_queue', durable=False) - self.channel.basic_publish(exchange='', - routing_key='rpc_queue', - body=pickle.dumps(message)) + # ── Layer 1 ─────────────────────────────────────────────── + + def first_layer(self, model, lr, weight_decay, clip_grad_norm, + control_count=1, train_loader=None, + opt: OptimizationBundle = None): - def first_layer(self, model, lr, weight_decay, clip_grad_norm, control_count=1, train_loader=None): - optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) + if opt is None: + opt = OptimizationBundle({}) - backward_queue_name = f'gradient_queue_{self.layer_id}_{self.client_id}' - self.channel.queue_declare(queue=backward_queue_name, durable=False) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) + bwd_q_name = f"gradient_queue_{self.layer_id}_{self.client_id}" + self.channel.queue_declare(queue=bwd_q_name, durable=False) 
self.channel.basic_qos(prefetch_count=1) model = model.to(self.device) - data_iter = iter(train_loader) - forward = [] - backward = [] - comm = [] - num_forward = 0 - num_backward = 0 - end_data = False - data_store = {} + data_iter = iter(train_loader) + forward_t = [] + backward_t = [] + comm_t = [] + num_fwd = num_bwd = 0 + end_data = False + data_store = {} # {uuid: input_ids_tensor} with tqdm(total=len(train_loader), desc="Processing", unit="step") as pbar: while True: - # Training model model.train() - optimizer.zero_grad() - # Process gradient - method_frame, header_frame, body = self.channel.basic_get(queue=backward_queue_name, auto_ack=True) + # ── Backward pass nếu có gradient ────────── + method_frame, _, body = self.channel.basic_get( + queue=bwd_q_name, auto_ack=True + ) if method_frame and body: - num_backward += 1 - received_data = pickle.loads(body) - gradient_numpy = received_data["data"] - gradient = torch.tensor(gradient_numpy).to(self.device) - data_id = received_data["data_id"] - - data_input = data_store.pop(data_id) - start_backward = time.time() - output = model(input_ids=data_input) + num_bwd += 1 + recv = pickle.loads(body) + gradient = torch.tensor(recv["data"]).to(self.device) + data_id = recv["data_id"] + data_input = data_store.pop(data_id) # Fix Bug 7: xóa ngay sau dùng + + t0 = time.time() + optimizer.zero_grad() + with opt.precision_ctx(self.device): + output = model(input_ids=data_input) + + # Fix Bug 1: chỉ gọi backward 1 lần với gradient từ layer sau + # KHÔNG gọi opt.step() ở đây vì đây là split backward output.backward(gradient=gradient) - optimizer.step() - time_backward = time.time() - start_backward - backward.append(time_backward) + if clip_grad_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm) + + # Fix Bug 1: scaler.step nếu FP16, không thì optimizer.step thường + if opt.scaler is not None: + opt.scaler.step(optimizer) + opt.scaler.update() + else: + optimizer.step() + + 
backward_t.append(time.time() - t0) + else: - # speed control + # Fix Bug 7: giới hạn data_store để tránh OOM if len(data_store) > control_count: continue - try: - batch = next(data_iter) - input_ids = batch['input_ids'].to(self.device) - labels = batch['labels'].to(self.device) + batch = next(data_iter) + inp_ids = batch["input_ids"].to(self.device) + labels = batch["labels"].to(self.device) data_id = uuid.uuid4() - data_store[data_id] = input_ids + data_store[data_id] = inp_ids - start_forward = time.time() - intermediate_output = model(input_ids=input_ids) - time_forward = time.time() - start_forward - forward.append(time_forward) - intermediate_output = intermediate_output.detach().requires_grad_(True) + t0 = time.time() + with opt.precision_ctx(self.device): + inter_out = model(input_ids=inp_ids) + forward_t.append(time.time() - t0) - num_forward += 1 + inter_out = inter_out.detach().requires_grad_(True) + num_fwd += 1 self.data_count += 1 - pbar.update(1) - start_comm = time.time() - self.send_intermediate_output(data_id, intermediate_output, labels, trace=None) - time_comm = time.time() - start_comm - comm.append(time_comm) + + t0 = time.time() + q_numpy, scale = opt.quant(inter_out) + self.send_intermediate_output(data_id, q_numpy, scale, labels, trace=None) + comm_t.append(time.time() - t0) except StopIteration: end_data = True - if end_data and (num_forward == num_backward): + if end_data and num_fwd == num_bwd: break - - notify_data = {"action": "NOTIFY", "client_id": self.client_id, "layer_id": self.layer_id, - "message": "Finish training!"} - + notify = { + "action": "NOTIFY", "client_id": self.client_id, + "layer_id": self.layer_id, "message": "Finish training!", + } src.Log.print_with_color("[>>>] Finish training!", "red") - self.send_to_server(notify_data) + self.send_to_server(notify) - broadcast_queue_name = f'reply_{self.client_id}' + bcast_q = f"reply_{self.client_id}" while True: - method_frame, header_frame, body = 
self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + _, _, body = self.channel.basic_get(queue=bcast_q, auto_ack=True) if body: - received_data = pickle.loads(body) - src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue") - if received_data["action"] == "PAUSE": - print(f'Forward time: {forward}s.') - print(f'Backward time: {backward}s.') - print(f'Comm time: {comm}s.') + recv = pickle.loads(body) + src.Log.print_with_color(f"[<<<] {recv}", "blue") + if recv["action"] == "PAUSE": + print(f"Forward: {forward_t}") + print(f"Backward: {backward_t}") + print(f"Comm: {comm_t}") return True, self.data_count time.sleep(0.5) - def last_layer(self, model, lr, weight_decay, clip_grad_norm): - optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) - criterion = nn.CrossEntropyLoss() - result = True + # ── Layer 2 ─────────────────────────────────────────────── + + def last_layer(self, model, lr, weight_decay, clip_grad_norm, + opt: OptimizationBundle = None): + + if opt is None: + opt = OptimizationBundle({}) - forward_queue_name = f'intermediate_queue_{self.layer_id - 1}' - self.channel.queue_declare(queue=forward_queue_name, durable=False) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) + criterion = nn.CrossEntropyLoss() + result = True + + fwd_q_name = f"intermediate_queue_{self.layer_id - 1}" + self.channel.queue_declare(queue=fwd_q_name, durable=False) self.channel.basic_qos(prefetch_count=1) - print('Waiting for intermediate output. To exit press CTRL+C') + print("Waiting for intermediate output. 
To exit press CTRL+C") model.to(self.device) model.train() - execute = [] - comm = [] + + exec_t = [] + comm_t = [] while True: - method_frame, header_frame, body = self.channel.basic_get(queue=forward_queue_name, auto_ack=True) + method_frame, _, body = self.channel.basic_get( + queue=fwd_q_name, auto_ack=True + ) if method_frame and body: - optimizer.zero_grad() - received_data = pickle.loads(body) - intermediate_output_numpy = received_data["data"] - trace = received_data["trace"] - data_id = received_data["data_id"] - labels = received_data["label"].to(self.device) - start_exec = time.time() - intermediate_output = torch.tensor(intermediate_output_numpy, requires_grad=True).to(self.device) + recv = pickle.loads(body) + trace = recv["trace"] + data_id = recv["data_id"] + labels = recv["label"].to(self.device) - output = model(input_ids=intermediate_output) + inter_out = opt.dequant( + recv["data"], recv.get("scale"), self.device, requires_grad=True + ) - loss = criterion(output, labels) + t0 = time.time() + with opt.precision_ctx(self.device): + output = model(input_ids=inter_out) + loss = criterion(output, labels) - if torch.isnan(loss).any(): + if torch.isnan(loss): src.Log.print_with_color("NaN detected in loss", "yellow") result = False - print(f"Loss: {loss.item()}") - intermediate_output.retain_grad() - loss.backward() + print(f"Loss: {loss.item():.4f}") + + # Fix Bug 3: retain_grad() TRƯỚC backward + inter_out.retain_grad() + + # Fix Bug 2: backward + clip + step chỉ 1 lần, không lồng opt.step + if opt.scaler is not None: + opt.scaler.scale(loss).backward() + opt.scaler.unscale_(optimizer) + if clip_grad_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm) + opt.scaler.step(optimizer) + opt.scaler.update() + else: + loss.backward() + if clip_grad_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm) + optimizer.step() - optimizer.step() - time_exec = time.time() - start_exec - execute.append(time_exec) + 
exec_t.append(time.time() - t0) self.data_count += 1 - gradient = intermediate_output.grad - start_comm = time.time() - self.send_gradient(data_id, gradient, trace) - time_comm = time.time() - start_comm - comm.append(time_comm) + # Fix Bug 3: inter_out.grad giờ đã có sau backward + t0 = time.time() + self.send_gradient(data_id, inter_out.grad, trace) + comm_t.append(time.time() - t0) else: - broadcast_queue_name = f'reply_{self.client_id}' - method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + bcast_q = f"reply_{self.client_id}" + _, _, body = self.channel.basic_get(queue=bcast_q, auto_ack=True) if body: - received_data = pickle.loads(body) - src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue") - if received_data["action"] == "PAUSE": - print(f'Forward + Backward: {execute}s') - print(f'Comm time: {comm}s') + recv = pickle.loads(body) + src.Log.print_with_color(f"[<<<] {recv}", "blue") + if recv["action"] == "PAUSE": + print(f"Exec: {exec_t}") + print(f"Comm: {comm_t}") return result, self.data_count - def train_on_middle_layer(self, model, lr, momentum, clip_grad_norm, control_count=5, cluster=None): - pass - - def alone_training(self, model, lr, momentum, clip_grad_norm, train_loader=None, cluster=None): - pass - - - - - - - - - - - - + def train_on_middle_layer(self, *args, **kwargs): pass + def alone_training(self, *args, **kwargs): pass diff --git a/src/fine_tune/GPT2.py b/src/fine_tune/GPT2.py index ce6cbf1..e31a988 100644 --- a/src/fine_tune/GPT2.py +++ b/src/fine_tune/GPT2.py @@ -7,244 +7,337 @@ import torch.nn as nn import src.Log +from src.Optimizer import OptimizationBundle from transformers import GPT2Tokenizer +from torch.optim.lr_scheduler import CosineAnnealingLR + + class Ft_GPT2: def __init__(self, client_id, layer_id, channel, device): - self.client_id = client_id - self.layer_id = layer_id - self.channel = channel - self.device = device + self.client_id = 
client_id + self.layer_id = layer_id + self.channel = channel + self.device = device self.data_count = 0 - self.gpu_cpu = [] - self.encode = [] - self.tokenizer = None - - def send_intermediate_output(self, data_id, output, attention_mask, labels, trace): - - forward_queue_name = f'intermediate_queue_{self.layer_id}' - self.channel.queue_declare(forward_queue_name, durable=False) - - if trace: - trace.append(self.client_id) - start_cpu = time.time() - output = output.detach().cpu().numpy() - labels = labels.cpu() - - self.gpu_cpu.append(time.time() - start_cpu) - message = pickle.dumps( - {"data_id": data_id, "data": output, "label": labels, "trace": trace, - "attention_mask": attention_mask.cpu()} - ) - else: - message = pickle.dumps( - {"data_id": data_id, "data": output.detach().cpu().numpy(), "label": labels.cpu(), "trace": [self.client_id], - "attention_mask" :attention_mask.cpu()} - ) - print(f'len message : {len(message)} bytes') + def send_intermediate_output(self, data_id, q_numpy, scale, mask_out, labels, trace): + fwd_q = f"intermediate_queue_{self.layer_id}" + trace_out = list(trace) + [self.client_id] if trace else [self.client_id] + self.channel.queue_declare(fwd_q, durable=False) + msg = pickle.dumps({ + "data_id": data_id, + "data": q_numpy, + "scale": scale, + "label": labels.cpu(), + "trace": trace_out, + "mask": mask_out.cpu(), + }) + print(f"len message: {len(msg)} bytes") + self.channel.basic_publish(exchange="", routing_key=fwd_q, body=msg) + + def send_end_signal(self): + fwd_q = f"intermediate_queue_{self.layer_id}" + self.channel.queue_declare(fwd_q, durable=False) self.channel.basic_publish( - exchange='', - routing_key=forward_queue_name, - body=message + exchange="", routing_key=fwd_q, + body=pickle.dumps({"action": "END"}) ) + src.Log.print_with_color("[>>>] Client 1 gửi END signal cho client 2", "yellow") def send_gradient(self, data_id, gradient, trace): - to_client_id = trace[-1] - trace.pop(-1) - backward_queue_name = 
f'gradient_queue_{self.layer_id - 1}_{to_client_id}' - self.channel.queue_declare(queue=backward_queue_name, durable=False) - - message = pickle.dumps( - {"data_id": data_id, "data": gradient.detach().cpu().numpy(), "trace": trace}) + to_id = trace[-1] + trace = trace[:-1] + bwd_q = f"gradient_queue_{self.layer_id - 1}_{to_id}" + self.channel.queue_declare(queue=bwd_q, durable=False) + msg = pickle.dumps({ + "data_id": data_id, + "data": gradient.detach().cpu().numpy(), + "trace": trace, + }) + self.channel.basic_publish(exchange="", routing_key=bwd_q, body=msg) + def send_to_server(self, message): + self.channel.queue_declare("rpc_queue", durable=False) self.channel.basic_publish( - exchange='', - routing_key=backward_queue_name, - body=message + exchange="", routing_key="rpc_queue", body=pickle.dumps(message) ) - def send_to_server(self, message): - self.channel.queue_declare('rpc_queue', durable=False) - self.channel.basic_publish(exchange='', - routing_key='rpc_queue', - body=pickle.dumps(message)) - def first_layer(self, model, lr, weight_decay, clip_grad_norm, control_count=1, - train_loader=None): - optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) - backward_queue_name = f'gradient_queue_{self.layer_id}_{self.client_id}' - self.channel.queue_declare(queue=backward_queue_name, durable=False) - self.channel.basic_qos(prefetch_count=1) + def first_layer(self, model, lr, weight_decay, clip_grad_norm, + control_count=1, train_loader=None, + opt: OptimizationBundle = None): + + if opt is None: + opt = OptimizationBundle({}) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) + bwd_q_name = f"gradient_queue_{self.layer_id}_{self.client_id}" + self.channel.queue_declare(queue=bwd_q_name, durable=False) + self.channel.basic_qos(prefetch_count=1) model = model.to(self.device) - forward_time = [] - backward_time = [] - comm_time = [] - - for i in range(1): - data_iter = iter(train_loader) - 
num_forward = 0 - num_backward = 0 - end_data = False - data_store = {} - - with tqdm(total=len(train_loader), desc="Processing", unit="step") as pbar: - while True: - # Training model - model.train() - optimizer.zero_grad() - # Process gradient - method_frame, header_frame, body = self.channel.basic_get(queue=backward_queue_name, auto_ack=True) - if method_frame and body: - num_backward += 1 - received_data = pickle.loads(body) - gradient_numpy = received_data["data"] - gradient = torch.tensor(gradient_numpy).to(self.device) - data_id = received_data["data_id"] - - data_input = data_store.pop(data_id) - start_backward = time.time() - output, mask = model(input_ids=data_input[0], attention_mask=data_input[1]) - output.backward(gradient=gradient) - optimizer.step() - stop_backward = time.time() - backward_time.append(stop_backward - start_backward) - else: - # speed control - if len(data_store) >= control_count: - continue - - try: - batch = next(data_iter) - input_ids = batch['input_ids'].to(self.device) - attention_mask = batch['attention_mask'].to(self.device) - labels = batch['labels'].to(self.device) - data_id = uuid.uuid4() - data_store[data_id] = (input_ids, attention_mask) - start_forward = time.time() - intermediate_output, mask = model(input_ids=input_ids, attention_mask=attention_mask) - stop_forward = time.time() - forward_time.append(stop_forward - start_forward) - intermediate_output = intermediate_output.detach().requires_grad_(True) - - num_forward += 1 - self.data_count += 1 - - pbar.update(1) - start_comm = time.time() - self.send_intermediate_output(data_id, intermediate_output, mask, - labels, trace=None) - stop_comm = time.time() - comm_time.append(start_comm - stop_comm) - - except StopIteration: - end_data = True - - if end_data and (num_forward == num_backward): - break - - notify_data = {"action": "NOTIFY", "client_id": self.client_id, "layer_id": self.layer_id, - "message": "Finish training!"} - - src.Log.print_with_color("[>>>] Finish 
training!", "red") - self.send_to_server(notify_data) - - broadcast_queue_name = f'reply_{self.client_id}' - while True: # Wait for broadcast - method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + forward_t = [] + backward_t = [] + comm_t = [] + data_iter = iter(train_loader) + num_fwd = num_bwd = 0 + end_data = False + data_store = {} + scheduler = CosineAnnealingLR( + optimizer, T_max=max(len(train_loader), 1), eta_min=lr / 10 + ) + + with tqdm(total=len(train_loader), desc="Layer 1", unit="step") as pbar: + while True: + model.train() + + method_frame, _, body = self.channel.basic_get( + queue=bwd_q_name, auto_ack=True + ) + if method_frame and body: + num_bwd += 1 + recv = pickle.loads(body) + gradient = torch.tensor(recv["data"]).to(self.device) + # FIX: cast gradient về đúng dtype của inter (fp16/fp32) + gradient = gradient.to(dtype=inter.dtype) + data_id = recv["data_id"] + inter = data_store.pop(data_id) + + if torch.isnan(gradient).any(): + print("[ERROR] Skip backward layer 1 (NaN gradient)") + continue + + optimizer.zero_grad() + inter.backward(gradient) + + has_nan = any( + p.grad is not None and torch.isnan(p.grad).any() + for p in model.parameters() + ) + if has_nan: + print("[ERROR] NaN grad → skip step") + optimizer.zero_grad() + continue + + if clip_grad_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm) + + optimizer.step() + scheduler.step() + pbar.update(1) + + else: + + if len(data_store) >= control_count: + continue + try: + batch = next(data_iter) + inp_ids = batch["input_ids"].to(self.device) + attn = batch["attention_mask"].to(self.device) + labels = batch["labels"].to(self.device) + data_id = uuid.uuid4() + + t0 = time.time() + with opt.precision_ctx(self.device): + out = model(input_ids=inp_ids, attention_mask=attn) + inter = out["hidden_states"] + mask_out = out["mask"] + inter = inter.detach().requires_grad_(True) + data_store[data_id] = inter + 
forward_t.append(time.time() - t0) + + num_fwd += 1 + self.data_count += 1 + + t0 = time.time() + q_numpy, scale = opt.quant(inter) + self.send_intermediate_output( + data_id, q_numpy, scale, mask_out, labels, trace=None + ) + comm_t.append(time.time() - t0) + + except StopIteration: + end_data = True + + if end_data and num_fwd == num_bwd: + break + + self.send_end_signal() + + notify = { + "action": "NOTIFY", + "client_id": self.client_id, + "layer_id": self.layer_id, + "message": "Finish training!", + } + src.Log.print_with_color("[>>>] Client 1 gửi NOTIFY về server", "red") + self.send_to_server(notify) + + # Chờ PAUSE từ server + bcast_q = f"reply_{self.client_id}" + while True: + _, _, body = self.channel.basic_get(queue=bcast_q, auto_ack=True) if body: - received_data = pickle.loads(body) - src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue") - if received_data["action"] == "PAUSE": - print(f'forward: {forward_time}') - print(f'backward: {backward_time}') - print(f'comm: {comm_time}') + recv = pickle.loads(body) + src.Log.print_with_color(f"[<<<] Client 1: {recv['action']}", "blue") + if recv["action"] == "PAUSE": + src.Log.print_with_color( + f"[INFO] Forward: {len(forward_t)} steps, " + f"Backward: {num_bwd} steps", "green" + ) return True, self.data_count time.sleep(0.5) - def last_layer(self, model, lr, weight_decay, clip_grad_norm): + + + def last_layer(self, model, lr, weight_decay, clip_grad_norm, + opt: OptimizationBundle = None): + + if opt is None: + opt = OptimizationBundle({}) + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") - pad_id = tokenizer.pad_token_id - if pad_id is None: + if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token - pad_id = tokenizer.eos_token_id optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) - criterion = nn.CrossEntropyLoss(ignore_index=pad_id) - result = True + criterion = nn.CrossEntropyLoss(ignore_index=-100) + result = True 
+ + scheduler = CosineAnnealingLR(optimizer, T_max=1000, eta_min=lr / 10) - forward_queue_name = f'intermediate_queue_{self.layer_id - 1}' - self.channel.queue_declare(queue=forward_queue_name, durable=False) + fwd_q_name = f"intermediate_queue_{self.layer_id - 1}" + self.channel.queue_declare(queue=fwd_q_name, durable=False) self.channel.basic_qos(prefetch_count=1) - print('Waiting for intermediate output. To exit press CTRL+C') + src.Log.print_with_color("Layer 2: Waiting for hidden states...", "green") model.to(self.device) model.train() - exec_time = [] - comm_time = [] + + exec_t = [] + comm_t = [] + num_received = 0 + num_grad_sent = 0 + end_received = False + nan_count = 0 + while True: - method_frame, header_frame, body = self.channel.basic_get(queue=forward_queue_name, auto_ack=True) + method_frame, _, body = self.channel.basic_get( + queue=fwd_q_name, auto_ack=True + ) if method_frame and body: - optimizer.zero_grad() - received_data = pickle.loads(body) - intermediate_output_numpy = received_data["data"] - attention_mask = received_data["attention_mask"].to(self.device) - trace = received_data["trace"] - data_id = received_data["data_id"] - labels = received_data["label"].to(self.device) - - intermediate_output = torch.tensor(intermediate_output_numpy, requires_grad=True).to(self.device) - - start = time.time() - output, _ = model(input_ids=intermediate_output, attention_mask=attention_mask) - shift_logits = output[:, :-1, :].contiguous() # [B, L-1, V] - shift_labels = labels[:, 1:].contiguous() # [B, L-1] - - loss = criterion( - shift_logits.view(-1, shift_logits.size(-1)), # [(B*(L-1)), V] - shift_labels.view(-1) # [(B*(L-1))] + recv = pickle.loads(body) + + # END sentinel + if recv.get("action") == "END": + src.Log.print_with_color( + f"[<<<] END received. 
Received {num_received} batches, " + f"sent {num_grad_sent} gradients.", "yellow" + ) + end_received = True + continue + + mask = recv["mask"].to(self.device) + trace = recv["trace"] + data_id = recv["data_id"] + labels = recv["label"].to(self.device) + num_received += 1 + + inter = opt.dequant( + recv["data"], recv.get("scale"), self.device, requires_grad=True ) + # FIX: đồng bộ dtype của inter với model để tránh Half/Float mismatch + model_dtype = next(model.parameters()).dtype + inter = inter.to(dtype=model_dtype) - # loss = criterion(output.view(-1, output.size(-1)), labels.view(-1)) - if torch.isnan(loss).any(): - src.Log.print_with_color("NaN detected in loss", "yellow") - result = False - - print(f"Loss: {loss.item()}") - intermediate_output.retain_grad() - loss.backward() + optimizer.zero_grad() - optimizer.step() - exec_time.append(time.time() - start) + t0 = time.time() + with opt.precision_ctx(self.device): + out = model(input_ids=inter, attention_mask=mask) + output = out["logits"] + shift_logits = output[:, :-1, :].contiguous() + shift_labels = labels[:, 1:].contiguous() + loss = criterion( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1), + ) + + if torch.isnan(loss): + src.Log.print_with_color("NaN loss — sending zero gradient", "yellow") + nan_count += 1 + num_grad_sent += 1 + self.send_gradient(data_id, torch.zeros_like(inter), trace) + continue + + print(f"Loss: {loss.item():.4f}") + + inter.retain_grad() + + if opt.scaler is not None: + opt.scaler.scale(loss).backward() + opt.scaler.unscale_(optimizer) + if clip_grad_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm) + opt.scaler.step(optimizer) + opt.scaler.update() + else: + loss.backward() + if clip_grad_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm) + optimizer.step() + + scheduler.step() + exec_t.append(time.time() - t0) self.data_count += 1 - gradient = intermediate_output.grad - start_comm = time.time() - 
self.send_gradient(data_id, gradient, trace) # 1F1B - comm_time.append(time.time() - start_comm) - # Check training process - else: - broadcast_queue_name = f'reply_{self.client_id}' - method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) - if body: - received_data = pickle.loads(body) - src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue") - if received_data["action"] == "PAUSE": - print(f'exec_time: {exec_time}') - print(f'comm_time: {comm_time}') - return result, self.data_count - - def train_on_middle_layer(self, model, lr, momentum, clip_grad_norm, control_count=5, cluster=None): - pass - - def alone_training(self, model, lr, momentum, clip_grad_norm, train_loader=None, cluster=None): - pass - - + t0 = time.time() + grad = inter.grad if inter.grad is not None else torch.zeros_like(inter) + if inter.grad is None: + src.Log.print_with_color( + "[WARN] inter.grad is None — sending zero gradient", "yellow" + ) + self.send_gradient(data_id, grad, trace) + num_grad_sent += 1 + comm_t.append(time.time() - t0) + else: + if end_received and num_grad_sent == num_received: + if num_received > 0 and nan_count / num_received > 0.5: + result = False + src.Log.print_with_color( + f"[WARN] {nan_count}/{num_received} batches NaN → round failed", + "yellow" + ) + + src.Log.print_with_color( + f"[>>>] Tất cả {num_grad_sent} gradient đã gửi. 
Gửi NOTIFY.", "red" + ) + notify = { + "action": "NOTIFY", + "client_id": self.client_id, + "layer_id": self.layer_id, + "message": "Finish training!", + } + self.send_to_server(notify) + end_received = False + + # Chờ PAUSE từ server + bcast_q = f"reply_{self.client_id}" + _, _, body = self.channel.basic_get(queue=bcast_q, auto_ack=True) + if body: + recv = pickle.loads(body) + src.Log.print_with_color(f"[<<<] Client 2: {recv['action']}", "blue") + if recv["action"] == "PAUSE": + src.Log.print_with_color( + f"[INFO] Exec: {len(exec_t)} steps", "green" + ) + return result, self.data_count + time.sleep(0.1) + def train_on_middle_layer(self, *args, **kwargs): pass + def alone_training(self, *args, **kwargs): passs \ No newline at end of file diff --git a/src/fine_tune/Llama.py b/src/fine_tune/Llama.py index 5435d21..50a0c46 100644 --- a/src/fine_tune/Llama.py +++ b/src/fine_tune/Llama.py @@ -7,208 +7,226 @@ import torch.nn as nn import src.Log +from src.Optimizer import OptimizationBundle + class Ft_Llama: def __init__(self, client_id, layer_id, channel, device): - self.client_id = client_id - self.layer_id = layer_id - self.channel = channel - self.device = device + self.client_id = client_id + self.layer_id = layer_id + self.channel = channel + self.device = device self.data_count = 0 - def send_intermediate_output(self, data_id, output, attention_mask, labels, trace): - - forward_queue_name = f'intermediate_queue_{self.layer_id}' - self.channel.queue_declare(forward_queue_name, durable=False) - - if trace: - trace.append(self.client_id) - message = pickle.dumps( - {"data_id": data_id, "data": output.detach().cpu().numpy(), "label": labels.cpu(), "trace": trace, - "attention_mask": attention_mask.cpu()} - ) - else: - message = pickle.dumps( - {"data_id": data_id, "data": output.detach().cpu().numpy(), "label": labels.cpu(), "trace": [self.client_id], - "attention_mask" :attention_mask.cpu()} - ) - - self.channel.basic_publish( - exchange='', - 
routing_key=forward_queue_name, - body=message - ) + # ── Giao tiếp ──────────────────────────────────────────── + + def send_intermediate_output(self, data_id, q_numpy, scale, attention_mask, labels, trace): + fwd_q = f"intermediate_queue_{self.layer_id}" + self.channel.queue_declare(fwd_q, durable=False) + trace_out = list(trace) + [self.client_id] if trace else [self.client_id] + msg = pickle.dumps({ + "data_id": data_id, + "data": q_numpy, + "scale": scale, + "label": labels.cpu(), + "trace": trace_out, + "attention_mask": attention_mask.cpu(), + }) + self.channel.basic_publish(exchange="", routing_key=fwd_q, body=msg) def send_gradient(self, data_id, gradient, trace): - to_client_id = trace[-1] - trace.pop(-1) - backward_queue_name = f'gradient_queue_{self.layer_id - 1}_{to_client_id}' - self.channel.queue_declare(queue=backward_queue_name, durable=False) - - message = pickle.dumps( - {"data_id": data_id, "data": gradient.detach().cpu().numpy(), "trace": trace}) + to_id = trace[-1] + trace = trace[:-1] + bwd_q = f"gradient_queue_{self.layer_id - 1}_{to_id}" + self.channel.queue_declare(queue=bwd_q, durable=False) + msg = pickle.dumps({ + "data_id": data_id, + "data": gradient.detach().cpu().numpy(), + "trace": trace, + }) + self.channel.basic_publish(exchange="", routing_key=bwd_q, body=msg) + def send_to_server(self, message): + self.channel.queue_declare("rpc_queue", durable=False) self.channel.basic_publish( - exchange='', - routing_key=backward_queue_name, - body=message + exchange="", routing_key="rpc_queue", body=pickle.dumps(message) ) - def send_to_server(self, message): - self.channel.queue_declare('rpc_queue', durable=False) - self.channel.basic_publish(exchange='', - routing_key='rpc_queue', - body=pickle.dumps(message)) + # ── Layer 1 ─────────────────────────────────────────────── - def first_layer(self, model, lr, weight_decay, clip_grad_norm, control_count=1, - train_loader=None): - optimizer = torch.optim.AdamW(model.parameters(), lr=lr, 
weight_decay=weight_decay) + def first_layer(self, model, lr, weight_decay, clip_grad_norm, + control_count=1, train_loader=None, + opt: OptimizationBundle = None): - backward_queue_name = f'gradient_queue_{self.layer_id}_{self.client_id}' - self.channel.queue_declare(queue=backward_queue_name, durable=False) - self.channel.basic_qos(prefetch_count=1) + if opt is None: + opt = OptimizationBundle({}) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) + bwd_q_name = f"gradient_queue_{self.layer_id}_{self.client_id}" + self.channel.queue_declare(queue=bwd_q_name, durable=False) + self.channel.basic_qos(prefetch_count=1) model = model.to(self.device) - for i in range(1): - data_iter = iter(train_loader) - num_forward = 0 - num_backward = 0 - end_data = False - data_store = {} - - with tqdm(total=len(train_loader), desc="Processing", unit="step") as pbar: - while True: - # Training model - model.train() - optimizer.zero_grad() + data_iter = iter(train_loader) + num_fwd = num_bwd = 0 + end_data = False + data_store = {} # {uuid: (input_ids, attention_mask)} + + with tqdm(total=len(train_loader), desc="Processing", unit="step") as pbar: + while True: + model.train() + + method_frame, _, body = self.channel.basic_get( + queue=bwd_q_name, auto_ack=True + ) + if method_frame and body: + num_bwd += 1 + recv = pickle.loads(body) + gradient = torch.tensor(recv["data"]).to(self.device) + data_id = recv["data_id"] + inp, mask = data_store.pop(data_id) # Fix Bug 7 - # Process gradient - method_frame, header_frame, body = self.channel.basic_get(queue=backward_queue_name, auto_ack=True) - if method_frame and body: - num_backward += 1 - received_data = pickle.loads(body) - gradient_numpy = received_data["data"] - gradient = torch.tensor(gradient_numpy).to(self.device) - data_id = received_data["data_id"] - - data_input = data_store.pop(data_id) - output, mask = model(input_ids=data_input[0], attention_mask=data_input[1]) - 
output.backward(gradient=gradient) - optimizer.step() + optimizer.zero_grad() + with opt.precision_ctx(self.device): + output, _ = model(input_ids=inp, attention_mask=mask) + + # Fix Bug 1: chỉ 1 backward với split gradient + output.backward(gradient=gradient) + if clip_grad_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm) + if opt.scaler is not None: + opt.scaler.step(optimizer) + opt.scaler.update() else: - # speed control - if len(data_store) >= control_count: - continue - - try: - batch = next(data_iter) - input_ids = batch['input_ids'].to(self.device) - attention_mask = batch['attention_mask'].to(self.device) - labels = batch['input_ids'].to(self.device) - data_id = uuid.uuid4() - data_store[data_id] = (input_ids, attention_mask) - - intermediate_output, mask = model(input_ids=input_ids, attention_mask=attention_mask) - intermediate_output = intermediate_output.detach().requires_grad_(True) - - num_forward += 1 - self.data_count += 1 - - pbar.update(1) - - self.send_intermediate_output(data_id, intermediate_output, mask, - labels, trace=None) - - except StopIteration: - end_data = True - - if end_data and (num_forward == num_backward): - break - - notify_data = {"action": "NOTIFY", "client_id": self.client_id, "layer_id": self.layer_id, - "message": "Finish training!"} + optimizer.step() + else: + # Fix Bug 7: giới hạn pipeline depth + if len(data_store) >= control_count: + continue + try: + batch = next(data_iter) + inp_ids = batch["input_ids"].to(self.device) + attn = batch["attention_mask"].to(self.device) + labels = batch["labels"].to(self.device) + data_id = uuid.uuid4() + data_store[data_id] = (inp_ids, attn) + + with opt.precision_ctx(self.device): + inter, mask_out = model(input_ids=inp_ids, attention_mask=attn) + + inter = inter.detach().requires_grad_(True) + num_fwd += 1 + self.data_count += 1 + pbar.update(1) + + q_numpy, scale = opt.quant(inter) + self.send_intermediate_output( + data_id, q_numpy, scale, mask_out, 
labels, trace=None + ) + + except StopIteration: + end_data = True + + if end_data and num_fwd == num_bwd: + break + + notify = { + "action": "NOTIFY", "client_id": self.client_id, + "layer_id": self.layer_id, "message": "Finish training!", + } src.Log.print_with_color("[>>>] Finish training!", "red") - self.send_to_server(notify_data) + self.send_to_server(notify) - broadcast_queue_name = f'reply_{self.client_id}' - while True: # Wait for broadcast - method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + bcast_q = f"reply_{self.client_id}" + while True: + _, _, body = self.channel.basic_get(queue=bcast_q, auto_ack=True) if body: - received_data = pickle.loads(body) - src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue") - if received_data["action"] == "PAUSE": + recv = pickle.loads(body) + src.Log.print_with_color(f"[<<<] {recv}", "blue") + if recv["action"] == "PAUSE": return True, self.data_count time.sleep(0.5) - def last_layer(self, model, lr, weight_decay, clip_grad_norm): + # ── Layer 2 ─────────────────────────────────────────────── + + def last_layer(self, model, lr, weight_decay, clip_grad_norm, + opt: OptimizationBundle = None): + + if opt is None: + opt = OptimizationBundle({}) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) criterion = nn.CrossEntropyLoss(ignore_index=-100) - result = True + result = True - forward_queue_name = f'intermediate_queue_{self.layer_id - 1}' - self.channel.queue_declare(queue=forward_queue_name, durable=False) + fwd_q_name = f"intermediate_queue_{self.layer_id - 1}" + self.channel.queue_declare(queue=fwd_q_name, durable=False) self.channel.basic_qos(prefetch_count=1) - print('Waiting for intermediate output. To exit press CTRL+C') + print("Waiting for intermediate output. 
To exit press CTRL+C") model.to(self.device) model.train() + while True: - method_frame, header_frame, body = self.channel.basic_get(queue=forward_queue_name, auto_ack=True) + method_frame, _, body = self.channel.basic_get( + queue=fwd_q_name, auto_ack=True + ) if method_frame and body: - optimizer.zero_grad() - received_data = pickle.loads(body) - intermediate_output_numpy = received_data["data"] - attention_mask = received_data["attention_mask"].to(self.device) - trace = received_data["trace"] - data_id = received_data["data_id"] - labels = received_data["label"].to(self.device) - - intermediate_output = torch.tensor(intermediate_output_numpy, requires_grad=True).to(self.device) - - output, _ = model(input_ids=intermediate_output, attention_mask=attention_mask) - - loss = criterion(output.view(-1, output.size(-1)), labels.view(-1)) - if torch.isnan(loss).any(): + recv = pickle.loads(body) + attn = recv["attention_mask"].to(self.device) + trace = recv["trace"] + data_id = recv["data_id"] + labels = recv["label"].to(self.device) + + inter = opt.dequant( + recv["data"], recv.get("scale"), self.device, requires_grad=True + ) + + with opt.precision_ctx(self.device): + output, _ = model(input_ids=inter, attention_mask=attn) + # Causal LM loss: shift + shift_logits = output[:, :-1, :].contiguous() + shift_labels = labels[:, 1:].contiguous() + loss = criterion( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1), + ) + + if torch.isnan(loss): src.Log.print_with_color("NaN detected in loss", "yellow") result = False - print(f"Loss: {loss.item()}") - intermediate_output.retain_grad() - loss.backward() + print(f"Loss: {loss.item():.4f}") + + # Fix Bug 4: retain_grad() TRƯỚC backward + inter.retain_grad() + + # Fix Bug 2: 1 backward, 1 optimizer.step + if opt.scaler is not None: + opt.scaler.scale(loss).backward() + opt.scaler.unscale_(optimizer) + if clip_grad_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm) + 
opt.scaler.step(optimizer) + opt.scaler.update() + else: + loss.backward() + if clip_grad_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm) + optimizer.step() - optimizer.step() self.data_count += 1 + self.send_gradient(data_id, inter.grad, trace) - gradient = intermediate_output.grad - - self.send_gradient(data_id, gradient, trace) # 1F1B - # Check training process else: - broadcast_queue_name = f'reply_{self.client_id}' - method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + bcast_q = f"reply_{self.client_id}" + _, _, body = self.channel.basic_get(queue=bcast_q, auto_ack=True) if body: - received_data = pickle.loads(body) - src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue") - if received_data["action"] == "PAUSE": + recv = pickle.loads(body) + src.Log.print_with_color(f"[<<<] {recv}", "blue") + if recv["action"] == "PAUSE": return result, self.data_count - def train_on_middle_layer(self, model, lr, momentum, clip_grad_norm, control_count=5, cluster=None): - pass - - def alone_training(self, model, lr, momentum, clip_grad_norm, train_loader=None, cluster=None): - pass - - - - - - - - - - - - - + def train_on_middle_layer(self, *args, **kwargs): pass + def alone_training(self, *args, **kwargs): pass diff --git a/src/fine_tune/__pycache__/Bert.cpython-313.pyc b/src/fine_tune/__pycache__/Bert.cpython-313.pyc new file mode 100644 index 0000000..8c97500 Binary files /dev/null and b/src/fine_tune/__pycache__/Bert.cpython-313.pyc differ diff --git a/src/fine_tune/__pycache__/GPT2.cpython-313.pyc b/src/fine_tune/__pycache__/GPT2.cpython-313.pyc new file mode 100644 index 0000000..10b263b Binary files /dev/null and b/src/fine_tune/__pycache__/GPT2.cpython-313.pyc differ diff --git a/src/fine_tune/__pycache__/Llama.cpython-313.pyc b/src/fine_tune/__pycache__/Llama.cpython-313.pyc new file mode 100644 index 0000000..299657e Binary files /dev/null and 
b/src/fine_tune/__pycache__/Llama.cpython-313.pyc differ diff --git a/src/fine_tune/__pycache__/__init__.cpython-313.pyc b/src/fine_tune/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..8e4d268 Binary files /dev/null and b/src/fine_tune/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/model/Bert.py b/src/model/Bert.py index 6e44e5f..c4769bc 100644 --- a/src/model/Bert.py +++ b/src/model/Bert.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from src.Optimizer import flash_scaled_dot_product class DotDict(dict): def __getattr__(self, k): @@ -37,47 +38,35 @@ def forward(self, input_ids, token_type_ids=None): return embeddings class BertSdpaSelfAttention(nn.Module): - def __init__(self, hidden_size, num_attention_heads, dropout_prob): + def __init__(self, hidden_size, num_attention_heads, dropout_prob, use_flash=False): super(BertSdpaSelfAttention, self).__init__() self.num_attention_heads = num_attention_heads self.attention_head_size = int(hidden_size / num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size + self.use_flash = use_flash self.query = nn.Linear(hidden_size, self.all_head_size) self.key = nn.Linear(hidden_size, self.all_head_size) self.value = nn.Linear(hidden_size, self.all_head_size) self.dropout = nn.Dropout(dropout_prob) + self._dropout_p = dropout_prob def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) + return x.permute(0, 2, 1, 3) # (B, H, T, D) def forward(self, hidden_states): + q = self.transpose_for_scores(self.query(hidden_states)) + k = self.transpose_for_scores(self.key(hidden_states)) + v = self.transpose_for_scores(self.value(hidden_states)) - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) - - query_layer = 
self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) - - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - - import math - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = self.dropout(attention_probs) - - context_layer = torch.matmul(attention_probs, value_layer) + dp = self._dropout_p if self.training else 0.0 + context_layer = flash_scaled_dot_product(q, k, v, dropout_p=dp) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - - return context_layer + new_shape = context_layer.size()[:-2] + (self.all_head_size,) + return context_layer.view(*new_shape) class BertSelfOutput(nn.Module): def __init__(self, hidden_size, dropout_prob): @@ -93,9 +82,10 @@ def forward(self, hidden_states, input_tensor): return hidden_states class BertAttention(nn.Module): - def __init__(self, hidden_size, num_attention_heads, dropout_prob): + def __init__(self, hidden_size, num_attention_heads, dropout_prob, use_flash=False): super(BertAttention, self).__init__() - self.self = BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob) + self.self = BertSdpaSelfAttention(hidden_size, num_attention_heads, + dropout_prob, use_flash=use_flash) self.output = BertSelfOutput(hidden_size, dropout_prob) def forward(self, hidden_states): @@ -128,9 +118,11 @@ def forward(self, hidden_states, input_tensor): return hidden_states class BertLayer(nn.Module): - def __init__(self, hidden_size, num_attention_heads, intermediate_size, dropout_prob): + def __init__(self, hidden_size, num_attention_heads, intermediate_size, + dropout_prob, use_flash=False): super(BertLayer, self).__init__() - self.attention = 
BertAttention(hidden_size, num_attention_heads, dropout_prob) + self.attention = BertAttention(hidden_size, num_attention_heads, + dropout_prob, use_flash=use_flash) self.intermediate = BertIntermediate(hidden_size, intermediate_size) self.output = BertOutput(hidden_size, intermediate_size, dropout_prob) @@ -165,10 +157,12 @@ def forward(self, pooled_output): class Bert(nn.Module): def __init__( self, vocab_size=28996, hidden_size=768, num_attention_heads=12, intermediate_size=3072, - max_position_embeddings=512, type_vocab_size=2, dropout_prob=0.1, layer_id=0, n_block=12 + max_position_embeddings=512, type_vocab_size=2, dropout_prob=0.1, layer_id=0, n_block=12, + use_flash=False ): super(Bert, self).__init__() self.layer_id = layer_id + self.use_flash = use_flash self.config = DotDict( model_type="bert", vocab_size=vocab_size, @@ -180,29 +174,30 @@ def __init__( self, vocab_size=28996, hidden_size=768, num_attention_heads=12, i use_return_dict=True, output_attentions=False, output_hidden_states=False ) + def _make_layers(n): + return nn.ModuleList([ + BertLayer(hidden_size, num_attention_heads, intermediate_size, + dropout_prob, use_flash=use_flash) + for _ in range(n) + ]) + if self.layer_id == 1: - self.embeddings = BertEmbeddings(vocab_size=vocab_size, hidden_size=hidden_size, max_position_embeddings=max_position_embeddings, - type_vocab_size=type_vocab_size,dropout_prob=dropout_prob) - self.layers = nn.ModuleList( - [BertLayer(hidden_size, num_attention_heads, intermediate_size, dropout_prob) - for _ in range(n_block)] - ) + self.embeddings = BertEmbeddings( + vocab_size=vocab_size, hidden_size=hidden_size, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, dropout_prob=dropout_prob) + self.layers = _make_layers(n_block) elif self.layer_id == 2: - self.layers = nn.ModuleList( - [BertLayer(hidden_size, num_attention_heads, intermediate_size, dropout_prob) - for _ in range(n_block)] - ) + self.layers = _make_layers(n_block) 
self.pooler = BertPooler(hidden_size) self.dropout = nn.Dropout(dropout_prob) self.classifier = nn.Linear(hidden_size, 4) else: - self.embeddings = BertEmbeddings(vocab_size=vocab_size, hidden_size=hidden_size, - max_position_embeddings=max_position_embeddings, - type_vocab_size=type_vocab_size, dropout_prob=dropout_prob) - self.layers = nn.ModuleList( - [BertLayer(hidden_size, num_attention_heads, intermediate_size, dropout_prob) - for _ in range(n_block)] - ) + self.embeddings = BertEmbeddings( + vocab_size=vocab_size, hidden_size=hidden_size, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, dropout_prob=dropout_prob) + self.layers = _make_layers(n_block) self.pooler = BertPooler(hidden_size) self.dropout = nn.Dropout(dropout_prob) self.classifier = nn.Linear(hidden_size, 4) diff --git a/src/model/GPT2.py b/src/model/GPT2.py index ce4b9bd..1a4067c 100644 --- a/src/model/GPT2.py +++ b/src/model/GPT2.py @@ -4,6 +4,7 @@ from transformers.pytorch_utils import Conv1D from transformers.activations import NewGELUActivation +from src.Optimizer import flash_scaled_dot_product class DotDict(dict): def __getattr__(self, k): @@ -14,38 +15,35 @@ def __delattr__(self, k): del self[k] class Attention(nn.Module): - def __init__(self, embed_size: int, num_heads: int, dropout=0.0): + def __init__(self, embed_size: int, num_heads: int, dropout=0.0, use_flash=False): super().__init__() assert embed_size % num_heads == 0 self.embed_size = embed_size self.num_heads = num_heads self.head_dim = embed_size // num_heads + self.use_flash = use_flash self.c_attn = Conv1D(3 * embed_size, embed_size) self.c_proj = Conv1D(embed_size, embed_size) self.attn_drop = nn.Dropout(dropout) self.resid_drop = nn.Dropout(dropout) + self._dropout_p = dropout def forward(self, x, mask=None): B, T, E = x.shape H, D = self.num_heads, self.head_dim - qkv = self.c_attn(x) # (B, T, 3E) - q, k, v = qkv.split(E, dim=-1) # (B, T, E) x3 - q = q.view(B, T, H, D) - k = k.view(B, 
T, H, D) - v = v.view(B, T, H, D) - - # (B, H, T, T) - att = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(D) - if mask is not None: - att = att.masked_fill(mask == 0, float("-1e20")) - att = torch.softmax(att, dim=-1) - att = self.attn_drop(att) - - # (B, T, H, D) -> (B, T, E) - y = torch.einsum("bhqk,bkhd->bqhd", att, v).contiguous().view(B, T, E) + qkv = self.c_attn(x) + q, k, v = qkv.split(E, dim=-1) + q = q.view(B, T, H, D).transpose(1, 2) + k = k.view(B, T, H, D).transpose(1, 2) + v = v.view(B, T, H, D).transpose(1, 2) + + dp = self._dropout_p if self.training else 0.0 + y = flash_scaled_dot_product(q, k, v, mask=mask, dropout_p=dp) # (B,H,T,D) + + y = y.transpose(1, 2).contiguous().view(B, T, E) y = self.c_proj(y) y = self.resid_drop(y) return y @@ -66,10 +64,10 @@ def forward(self, x): return x class GPT2Block(nn.Module): - def __init__(self, embed_size: int, num_heads: int, dropout: float): + def __init__(self, embed_size: int, num_heads: int, dropout: float, use_flash=False): super().__init__() self.ln_1 = nn.LayerNorm(embed_size) - self.attn = Attention(embed_size, num_heads, dropout) + self.attn = Attention(embed_size, num_heads, dropout, use_flash=use_flash) self.ln_2 = nn.LayerNorm(embed_size) self.mlp = MLP(embed_size, dropout) @@ -79,12 +77,12 @@ def forward(self, x, mask): return x def build_causal_mask(B, T, device): - return torch.ones(T, T, device=device).tril().unsqueeze(0).unsqueeze(1).expand(B, 1, T, T) + return torch.ones(T, T, device=device).tril().unsqueeze(0).unsqueeze(1).expand(B, 1, T, T).contiguous() class GPT2(nn.Module): def __init__(self, vocab_size=50257, max_length=1024, n_layer=12, n_head=12, n_embd=768, dropout=0.1, - layer_id=0, n_block=12): + layer_id=0, n_block=12, use_flash=False): super().__init__() self.vocab_size = vocab_size self.max_length = max_length @@ -98,87 +96,99 @@ def __init__(self, vocab_size=50257, max_length=1024, n_layer=12, n_head=12, n_e bos_token_id=50256, eos_token_id=50256, pad_token_id=0, 
is_encoder_decoder=False, tie_word_embeddings=False, ) + + def _make_blocks(n): + return nn.ModuleList([ + GPT2Block(n_embd, n_head, dropout, use_flash=use_flash) + for _ in range(n) + ]) + if self.layer_id == 1: self.wte = nn.Embedding(vocab_size, n_embd) self.wpe = nn.Embedding(max_length, n_embd) self.drop = nn.Dropout(dropout) - - self.h = nn.ModuleList([ - GPT2Block(n_embd, n_head, dropout) - for _ in range(n_block) - ]) + self.h = _make_blocks(n_block) elif self.layer_id == 2: - self.h = nn.ModuleList([ - GPT2Block(n_embd, n_head, dropout) - for _ in range(n_block) - ]) + self.h = _make_blocks(n_block) self.ln_f = nn.LayerNorm(n_embd) self.lm_head = nn.Linear(n_embd, vocab_size, bias=False) else: self.wte = nn.Embedding(vocab_size, n_embd) self.wpe = nn.Embedding(max_length, n_embd) self.drop = nn.Dropout(dropout) - - self.h = nn.ModuleList([ - GPT2Block(n_embd, n_head, dropout) - for _ in range(n_block) - ]) + self.h = _make_blocks(n_block) self.ln_f = nn.LayerNorm(n_embd) self.lm_head = nn.Linear(n_embd, vocab_size, bias=False) + self.lm_head.weight = self.wte.weight + + def forward(self, input_ids=None, attention_mask=None, **kwargs): - def forward(self, input_ids, attention_mask=None, - **kwargs): if self.layer_id == 1: B, T = input_ids.shape h = self.wte(input_ids) - pos = torch.arange(0, T).unsqueeze(0).to(attention_mask.device) + pos = torch.arange(0, T).unsqueeze(0).to(input_ids.device) h = self.drop(h + self.wpe(pos)) - mask = build_causal_mask(B, T, attention_mask.device) + + mask = build_causal_mask(B, T, input_ids.device) + if attention_mask is not None: key_mask = attention_mask[:, None, None, :].to(mask.dtype) qry_mask = attention_mask[:, None, :, None].to(mask.dtype) - mask = mask * key_mask * qry_mask + mask = (mask * key_mask * qry_mask) > 0 for blk in self.h: h = blk(h, mask) + return {"hidden_states": h, "mask": mask} + elif self.layer_id == 2: h = input_ids mask = attention_mask + for blk in self.h: - h = blk(h, attention_mask) + h = 
blk(h, mask) + h = self.ln_f(h) - h = self.lm_head(h) + logits = self.lm_head(h) + + return {"logits": logits} else: B, T = input_ids.shape h = self.wte(input_ids) - pos = torch.arange(0, T).unsqueeze(0).to(attention_mask.device) + + pos = torch.arange(0, T).unsqueeze(0).to(input_ids.device) h = self.drop(h + self.wpe(pos)) - mask = build_causal_mask(B, T, attention_mask.device) + + mask = build_causal_mask(B, T, input_ids.device) + if attention_mask is not None: key_mask = attention_mask[:, None, None, :].to(mask.dtype) qry_mask = attention_mask[:, None, :, None].to(mask.dtype) - mask = mask * key_mask * qry_mask + mask = (mask * key_mask * qry_mask) > 0 for blk in self.h: h = blk(h, mask) - h = self.ln_f(h) - h = self.lm_head(h) - return h, mask + h = self.ln_f(h) + logits = self.lm_head(h) + return { + "logits": logits, + "hidden_states": h, + "mask": mask + } def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs): - B, T = input_ids.shape - device = input_ids.device - causal = torch.ones(T, T, device=device).tril().unsqueeze(0).unsqueeze(1).expand(B, 1, T, T) - if attention_mask is not None: - key_mask = attention_mask[:, None, None, :].to(causal.dtype) - qry_mask = attention_mask[:, None, :, None].to(causal.dtype) - mask = causal * key_mask * qry_mask - else: - mask = causal - return {"input_ids": input_ids, "mask": mask} + B, T = input_ids.shape + device = input_ids.device + causal = torch.ones(T, T, device=device).tril().unsqueeze(0).unsqueeze(1).expand(B, 1, T, T).contiguous() + if attention_mask is not None: + key_mask = attention_mask[:, None, None, :].to(causal.dtype) + qry_mask = attention_mask[:, None, :, None].to(causal.dtype) + mask = causal * key_mask * qry_mask + else: + mask = causal + return {"input_ids": input_ids, "mask": mask} def shift_labels_for_lm(input_ids, ignore_index=-100): labels = input_ids.clone() diff --git a/src/model/Llama.py b/src/model/Llama.py index bcc5b8e..6808026 100644 --- a/src/model/Llama.py +++ 
b/src/model/Llama.py @@ -3,6 +3,7 @@ import torch.nn.functional as F import math +from src.Optimizer import flash_scaled_dot_product class DotDict(dict): def __getattr__(self, k): @@ -35,7 +36,7 @@ def apply_rotary_pos_emb(q, k, cos, sin): return q_rot, k_rot def build_causal_mask(B, T, device): - return torch.ones(T, T, device=device).tril().unsqueeze(0).unsqueeze(1).expand(B, 1, T, T) + return torch.ones(T, T, device=device).tril().unsqueeze(0).unsqueeze(1).expand(B, 1, T, T).contiguous() class LlamaRotaryEmbedding(nn.Module): def __init__( @@ -62,17 +63,19 @@ def __init__( self._set_cos_sin_cache(max_position_embeddings, dtype=torch.float32, device=device) def _set_cos_sin_cache(self, seq_len: int, dtype: torch.dtype, device=None): + # Fix Bug 12: khi device=None (gọi lại từ forward), dùng inv_freq.device device = device if device is not None else self.inv_freq.device - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) # [seq_len] + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) freq = torch.einsum("i,j->ij", t, self.inv_freq) emb = torch.cat([freq, freq], dim=-1) cos = emb.cos()[None, None, :, :] # [1, 1, seq_len, dim] - sin = emb.sin()[None, None, :, :] # [1, 1, seq_len, dim] + sin = emb.sin()[None, None, :, :] - self.cos_cached = cos.to(dtype=dtype, device=device) - self.sin_cached = sin.to(dtype=dtype, device=device) + # Lưu FP32 rồi cast khi dùng để tránh precision mismatch + self.cos_cached = cos.to(dtype=torch.float32, device=device) + self.sin_cached = sin.to(dtype=torch.float32, device=device) self.max_seq_len_cached = seq_len def forward(self, x, seq_len: int = None): @@ -80,19 +83,22 @@ def forward(self, x, seq_len: int = None): seq_len = x.shape[-2] if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len, dtype=x.dtype, device=x.device) + self._set_cos_sin_cache(seq_len, dtype=torch.float32, device=x.device) + # Cast về dtype và device của x tại thời điểm dùng (hỗ trợ FP16/BF16) cos = 
self.cos_cached[:, :, :seq_len, :].to(dtype=x.dtype, device=x.device) sin = self.sin_cached[:, :, :seq_len, :].to(dtype=x.dtype, device=x.device) return cos, sin class LlamaAttention(nn.Module): - def __init__(self, hidden_size, num_heads, num_kv_heads, dropout=0.0, head_dim=64): + def __init__(self, hidden_size, num_heads, num_kv_heads, dropout=0.0, head_dim=64, + use_flash=False): super().__init__() self.hidden_size = hidden_size self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_dim = head_dim + self.use_flash = use_flash self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False) self.k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False) @@ -101,6 +107,7 @@ def __init__(self, hidden_size, num_heads, num_kv_heads, dropout=0.0, head_dim=6 self.rope = LlamaRotaryEmbedding(head_dim) self.dropout = nn.Dropout(dropout) + self._dropout_p = dropout def forward(self, x, attention_mask=None): B, T, C = x.size() @@ -116,17 +123,9 @@ def forward(self, x, attention_mask=None): k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) - att = (q @ k.transpose(-1, -2)) / math.sqrt(self.head_dim) + dp = self._dropout_p if self.training else 0.0 + out = flash_scaled_dot_product(q, k, v, mask=attention_mask, dropout_p=dp) - if attention_mask is not None: - att = att.masked_fill( - attention_mask == 0, - torch.finfo(att.dtype).min - ) - - att = F.softmax(att, dim=-1) - att = self.dropout(att) - out = att @ v out = out.transpose(1, 2).contiguous().view(B, T, C) return self.o_proj(out) @@ -143,10 +142,12 @@ def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) class LlamaDecoderLayer(nn.Module): - def __init__(self, hidden_size, num_heads, num_kv_heads, intermediate_size, rms_eps=1e-6): + def __init__(self, hidden_size, num_heads, num_kv_heads, intermediate_size, + rms_eps=1e-6, use_flash=False): super().__init__() 
self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_eps) - self.self_attn = LlamaAttention(hidden_size, num_heads, num_kv_heads) + self.self_attn = LlamaAttention(hidden_size, num_heads, num_kv_heads, + use_flash=use_flash) self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_eps) self.mlp = LlamaMLP(hidden_size, intermediate_size) @@ -159,9 +160,10 @@ def forward(self, x, attention_mask=None): class Llama(nn.Module): def __init__(self, vocab_size=32000, hidden_size=768, intermediate_size=3072, num_attention_heads=12, num_key_value_heads=12, - layer_id=0, n_block=12): + layer_id=0, n_block=12, use_flash=False): super().__init__() self.layer_id = layer_id + self.use_flash = use_flash self.config = DotDict( model_type="llama", vocab_size=vocab_size, @@ -181,25 +183,25 @@ def __init__(self, vocab_size=32000, hidden_size=768, intermediate_size=3072, nu use_cache=True, torch_dtype="float32", ) + + def _make_layers(n): + return nn.ModuleList([ + LlamaDecoderLayer(hidden_size, num_attention_heads, + num_key_value_heads, intermediate_size, + use_flash=use_flash) + for _ in range(n) + ]) + if self.layer_id == 1: self.embed_tokens = nn.Embedding(vocab_size, hidden_size) - self.layers = nn.ModuleList([ - LlamaDecoderLayer(hidden_size, num_attention_heads, num_key_value_heads, intermediate_size) for _ in - range(n_block) - ]) + self.layers = _make_layers(n_block) elif self.layer_id == 2: - self.layers = nn.ModuleList([ - LlamaDecoderLayer(hidden_size, num_attention_heads, num_key_value_heads, intermediate_size) for _ in - range(n_block) - ]) + self.layers = _make_layers(n_block) self.norm = LlamaRMSNorm(hidden_size) self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False) else: self.embed_tokens = nn.Embedding(vocab_size, hidden_size) - self.layers = nn.ModuleList([ - LlamaDecoderLayer(hidden_size, num_attention_heads, num_key_value_heads, intermediate_size) for _ in - range(n_block) - ]) + self.layers = _make_layers(n_block) self.norm = 
LlamaRMSNorm(hidden_size) self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False) @@ -208,7 +210,9 @@ def forward(self, input_ids, attention_mask=None, **kwargs): B, T = input_ids.shape x = self.embed_tokens(input_ids) - masks = build_causal_mask(B, T, attention_mask.device) + # Fix: dùng input_ids.device làm fallback khi attention_mask là None + ref_device = attention_mask.device if attention_mask is not None else input_ids.device + masks = build_causal_mask(B, T, ref_device) if attention_mask is not None: key_mask = attention_mask[:, None, None, :].to(masks.dtype) qry_mask = attention_mask[:, None, :, None].to(masks.dtype) @@ -229,7 +233,8 @@ def forward(self, input_ids, attention_mask=None, **kwargs): B, T = input_ids.shape x = self.embed_tokens(input_ids) - masks = build_causal_mask(B, T, attention_mask.device) + ref_device = attention_mask.device if attention_mask is not None else input_ids.device + masks = build_causal_mask(B, T, ref_device) if attention_mask is not None: key_mask = attention_mask[:, None, None, :].to(masks.dtype) qry_mask = attention_mask[:, None, :, None].to(masks.dtype) @@ -246,11 +251,4 @@ def forward(self, input_ids, attention_mask=None, **kwargs): def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs): B, T = input_ids.shape device = input_ids.device - causal = torch.ones(T, T, device=device).tril().unsqueeze(0).unsqueeze(1).expand(B, 1, T, T) - if attention_mask is not None: - key_mask = attention_mask[:, None, None, :].to(causal.dtype) - qry_mask = attention_mask[:, None, :, None].to(causal.dtype) - masks = causal * key_mask * qry_mask - else: - masks = causal - return {"input_ids": input_ids, "attention_mask": masks, **kwargs} + causal = torch.ones(T, T, device=device).tril().unsqueeze(0).unsqueeze(1).expand(B, 1, T, T).contiguous() diff --git a/src/model/__pycache__/Bert.cpython-313.pyc b/src/model/__pycache__/Bert.cpython-313.pyc new file mode 100644 index 0000000..279f3da Binary files /dev/null and 
b/src/model/__pycache__/Bert.cpython-313.pyc differ diff --git a/src/model/__pycache__/GPT2.cpython-310.pyc b/src/model/__pycache__/GPT2.cpython-310.pyc new file mode 100644 index 0000000..23c41c5 Binary files /dev/null and b/src/model/__pycache__/GPT2.cpython-310.pyc differ diff --git a/src/model/__pycache__/GPT2.cpython-313.pyc b/src/model/__pycache__/GPT2.cpython-313.pyc new file mode 100644 index 0000000..eb3dfec Binary files /dev/null and b/src/model/__pycache__/GPT2.cpython-313.pyc differ diff --git a/src/model/__pycache__/Llama.cpython-313.pyc b/src/model/__pycache__/Llama.cpython-313.pyc new file mode 100644 index 0000000..d29f953 Binary files /dev/null and b/src/model/__pycache__/Llama.cpython-313.pyc differ diff --git a/src/model/__pycache__/__init__.cpython-310.pyc b/src/model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..dcb7fd2 Binary files /dev/null and b/src/model/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/model/__pycache__/__init__.cpython-313.pyc b/src/model/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..5412035 Binary files /dev/null and b/src/model/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/val/Bert.py b/src/val/Bert.py index 48b0d40..f048c52 100644 --- a/src/val/Bert.py +++ b/src/val/Bert.py @@ -9,7 +9,7 @@ def val_Bert(model_name, data_name, state_dict_full, logger): criterion = nn.CrossEntropyLoss() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - test_loader = dataloader(model_name==model_name, data_name=data_name, train=False) + test_loader = dataloader(model_name=model_name, data_name=data_name, train=False) model = Bert() model = model.to(device) model.load_state_dict(state_dict_full) @@ -35,6 +35,8 @@ def val_Bert(model_name, data_name, state_dict_full, logger): logger.log_info(f"Test Loss: {avg_loss:.2f}; Test Acc: {acc:.2f}") + return True + diff --git a/src/val/GPT2.py b/src/val/GPT2.py index 5973ec7..4e3de3c 100644 --- 
a/src/val/GPT2.py +++ b/src/val/GPT2.py @@ -1,3 +1,4 @@ +import os import torch import torch.nn as nn from tqdm import tqdm @@ -7,96 +8,163 @@ import re -def extract_final_number(s: str) -> str: +# Metric cho E2E (cần cài: pip install sacrebleu rouge-score) +try: + from sacrebleu import corpus_bleu + from rouge_score import rouge_scorer as rouge_scorer_lib + BLEU_AVAILABLE = True +except ImportError: + BLEU_AVAILABLE = False + +def extract_final_number(s: str) -> str: if s is None: return "" - m = re.search(r"####\s*([\-+]?\d+(?:\.\d+)?)", s) if m: return m.group(1).strip() - nums = re.findall(r"[\-+]?\d+(?:\.\d+)?", s) if nums: return nums[-1].strip() return "" + +def _greedy_generate(model, prompt_ids, prompt_mask, max_new_tokens, pad_id, device): + cur_ids = prompt_ids.clone() + cur_mask = prompt_mask.clone() + generated = [] + + for _ in range(max_new_tokens): + out = model(input_ids=cur_ids, attention_mask=cur_mask) + logits = out["hidden_states"] + next_token = logits[0, -1, :].argmax(-1).item() + generated.append(next_token) + + if next_token == pad_id: + break + + next_tensor = torch.tensor([[next_token]], device=device) + next_mask = torch.ones((1, 1), device=device, dtype=cur_mask.dtype) + cur_ids = torch.cat([cur_ids, next_tensor], dim=1) + cur_mask = torch.cat([cur_mask, next_mask], dim=1) + + return generated + + def val_GPT2(model_name, data_name, state_dict_full, logger): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("Eval device:", device) tokenizer = GPT2Tokenizer.from_pretrained("gpt2") tokenizer.pad_token = tokenizer.eos_token - pad_id = tokenizer.pad_token_id - loss_fct = nn.CrossEntropyLoss( - ignore_index=pad_id, - ) + loss_fct = nn.CrossEntropyLoss(ignore_index=pad_id, reduction='sum') test_loader = dataloader(model_name=model_name, data_name=data_name, train=False) model = GPT2() - model.load_state_dict(state_dict_full) + + pretrained_path = f"{model_name}.pt" + if os.path.exists(pretrained_path): + base_state = 
torch.load(pretrained_path, map_location="cpu") + missing, unexpected = model.load_state_dict(base_state, strict=False) + if missing: + logger.log_info(f"[val_GPT2] Base load missing keys: {missing}") + if unexpected: + logger.log_info(f"[val_GPT2] Base load unexpected keys: {unexpected}") + else: + logger.log_warning( + f"[val_GPT2] Pretrained file '{pretrained_path}' not found. " + f"Validating with random base weights." + ) + + if state_dict_full: + missing, unexpected = model.load_state_dict(state_dict_full, strict=False) + if missing: + logger.log_info(f"[val_GPT2] LoRA overlay missing keys: {missing}") + if unexpected: + logger.log_info(f"[val_GPT2] LoRA overlay unexpected keys: {unexpected}") + model = model.to(device) model.eval() - total_loss = 0.0 - total_tokens = 0 - - total_samples = 0 + total_loss = 0.0 + total_tokens = 0 + total_samples = 0 correct_samples = 0 + nl_id = tokenizer.encode("\n", add_special_tokens=False)[0] + is_e2e = (data_name == 'E2E') + hypotheses = [] + references = [] + with torch.no_grad(): for batch in tqdm(test_loader): input_ids = batch['input_ids'].to(device) - mask = batch['attention_mask'].to(device) - labels = batch['labels'].to(device) + mask = batch['attention_mask'].to(device) + labels = batch['labels'].to(device) - logits, _ = model(input_ids, mask) + out = model(input_ids, mask) + logits = out["hidden_states"] shift_logits = logits[:, :-1, :].contiguous() shift_labels = labels[:, 1:].contiguous() - loss = loss_fct( shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1) + shift_labels.view(-1), ) - total_loss += loss.item() - - predicted_tokens = torch.argmax(logits, dim=-1) - - generated_texts = [ - tokenizer.decode(p, skip_special_tokens=True).strip() - for p in predicted_tokens - ] - reference_texts = [ - tokenizer.decode(l, skip_special_tokens=True).strip() - for l in labels - ] - - for gen, ref in zip(generated_texts, reference_texts): - if not gen: - gen = " " - if not ref: - ref = " " - # 
print(f'gen : {gen}') - # print(f'ref : {ref}') - # break - # break - - pred_num = extract_final_number(gen) - gold_num = extract_final_number(ref) + valid = (shift_labels != pad_id).sum().item() + total_loss += loss.item() + total_tokens += valid + + for i in range(input_ids.size(0)): + tokens = input_ids[i].tolist() + if nl_id in tokens: + cut = tokens.index(nl_id) + 1 + else: + cut = len(tokens) // 2 + + prompt_ids = input_ids[i, :cut].unsqueeze(0) + prompt_mask = mask[i, :cut].unsqueeze(0) + + generated = _greedy_generate( + model, prompt_ids, prompt_mask, + max_new_tokens=64, pad_id=pad_id, device=device + ) + + gen_text = tokenizer.decode(generated, skip_special_tokens=True).strip() + ref_text = tokenizer.decode( + labels[i][labels[i] != pad_id], skip_special_tokens=True + ).strip() total_samples += 1 - if gold_num and pred_num == gold_num: - correct_samples += 1 - - avg_loss = total_loss / max(total_samples, 1) - accuracy = correct_samples / max(total_samples, 1) - - print(f"Loss / token : {avg_loss:.4f}; Accuracy (answer) : {accuracy * 100:.2f}%") - logger.log_info( - f"Loss / token : {avg_loss:.4f}; Accuracy (answer) : {accuracy * 100:.2f}%" - ) + if is_e2e: + # E2E: dùng BLEU/ROUGE + hypotheses.append(gen_text) + references.append(ref_text) + else: + # GSM8K: dùng exact match số + pred_num = extract_final_number(gen_text) + gold_num = extract_final_number(ref_text) + if gold_num and pred_num == gold_num: + correct_samples += 1 + + avg_loss = total_loss / max(total_tokens, 1) + + if is_e2e and BLEU_AVAILABLE and hypotheses: + bleu = corpus_bleu(hypotheses, [references]).score + scorer = rouge_scorer_lib.RougeScorer(["rougeL"], use_stemmer=True) + rouge_l = sum( + scorer.score(r, h)["rougeL"].fmeasure + for h, r in zip(hypotheses, references) + ) / len(hypotheses) + print(f"Loss / token : {avg_loss:.4f}; BLEU: {bleu:.2f}; ROUGE-L: {rouge_l:.4f}") + logger.log_info(f"Loss / token : {avg_loss:.4f}; BLEU: {bleu:.2f}; ROUGE-L: {rouge_l:.4f}") + else: + accuracy 
= correct_samples / max(total_samples, 1) + print(f"Loss / token : {avg_loss:.4f}; Accuracy (answer) : {accuracy * 100:.2f}%") + logger.log_info(f"Loss / token : {avg_loss:.4f}; Accuracy (answer) : {accuracy * 100:.2f}%") + + return model.state_dict() \ No newline at end of file diff --git a/src/val/Llama.py b/src/val/Llama.py index c62caec..4c6da23 100644 --- a/src/val/Llama.py +++ b/src/val/Llama.py @@ -14,10 +14,11 @@ def val_Llama(model_name, data_name, state_dict_full, logger): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("Eval device:", device) - smooth = SmoothingFunction().method1 - scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True) + smooth = SmoothingFunction().method1 + scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True) tokenizer = AutoTokenizer.from_pretrained("JackFram/llama-160m") tokenizer.pad_token = tokenizer.eos_token + pad_id = tokenizer.pad_token_id test_loader = dataloader(model_name=model_name, data_name=data_name, train=False) @@ -26,74 +27,70 @@ def val_Llama(model_name, data_name, state_dict_full, logger): model = model.to(device) model.eval() - total_bleu = 0.0 - total_rouge = 0.0 - total_perplexity = 0.0 + # Fix: tạo criterion 1 lần ngoài vòng lặp (tránh tạo object mới mỗi batch) + loss_fct = nn.CrossEntropyLoss(ignore_index=pad_id, reduction='sum') + + total_bleu = 0.0 + total_rouge = 0.0 + total_log_loss = 0.0 # Fix: tích lũy log-loss thay vì perplexity trực tiếp + total_valid_tok = 0 # Fix: đếm token hợp lệ để avg chính xác count = 0 with torch.no_grad(): for batch in tqdm(test_loader): input_ids = batch['input_ids'].to(device) - mask = batch['attention_mask'].to(device) - labels = batch['input_ids'].to(device) + mask = batch['attention_mask'].to(device) + labels = batch['input_ids'].to(device) # causal LM: labels = input logits, _ = model(input_ids, mask) predicted_tokens = torch.argmax(logits, dim=-1) - - generated_texts = [ + generated_texts = [ tokenizer.decode(p, 
skip_special_tokens=True).strip() for p in predicted_tokens ] - reference_texts = [ + reference_texts = [ tokenizer.decode(l, skip_special_tokens=True).strip() for l in labels ] for gen, ref in zip(generated_texts, reference_texts): - if not gen: - gen = " " - if not ref: - ref = " " - - bleu_score = sentence_bleu( - [ref.split()], - gen.split(), - smoothing_function=smooth - ) + gen = gen or " " + ref = ref or " " + bleu = sentence_bleu([ref.split()], gen.split(), + smoothing_function=smooth) try: rouge_l = scorer.score(ref, gen)["rougeL"].fmeasure - except: + except Exception: rouge_l = 0.0 - total_bleu += bleu_score + total_bleu += bleu total_rouge += rouge_l count += 1 + # Fix: perplexity = exp(avg_cross_entropy_per_token) + # Tích lũy tổng cross-entropy và tổng token; tính exp 1 lần cuối shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() - loss_fct = nn.CrossEntropyLoss( - ignore_index=tokenizer.pad_token_id, - reduction='sum' - ) - - loss = loss_fct( + loss = loss_fct( shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1) + shift_labels.view(-1), ) + valid_tokens = (shift_labels != pad_id).sum().item() + total_log_loss += loss.item() + total_valid_tok += max(valid_tokens, 1) - valid_tokens = (shift_labels != tokenizer.pad_token_id).sum().item() - perplexity = math.exp(loss.item() / max(1, valid_tokens)) - total_perplexity += perplexity + # Fix: tính perplexity từ avg log-loss (tránh exp overflow mỗi batch) + avg_log_loss = total_log_loss / max(total_valid_tok, 1) + # Clamp để tránh math overflow với exp() + avg_perplexity = math.exp(min(avg_log_loss, 20.0)) - avg_bleu = total_bleu / count - avg_rouge = total_rouge / count - avg_perplexity = total_perplexity / count + avg_bleu = total_bleu / max(count, 1) + avg_rouge = total_rouge / max(count, 1) print(f"Evaluation Results: BLEU: {avg_bleu:.4f}, ROUGE-L: {avg_rouge:.4f}, Perplexity: {avg_perplexity:.4f}") - logger.log_info( f"Evaluation Results: BLEU: 
{avg_bleu:.4f}, ROUGE-L: {avg_rouge:.4f}, Perplexity: {avg_perplexity:.4f}" ) diff --git a/src/val/__pycache__/Bert.cpython-313.pyc b/src/val/__pycache__/Bert.cpython-313.pyc new file mode 100644 index 0000000..f22348b Binary files /dev/null and b/src/val/__pycache__/Bert.cpython-313.pyc differ diff --git a/src/val/__pycache__/GPT2.cpython-313.pyc b/src/val/__pycache__/GPT2.cpython-313.pyc new file mode 100644 index 0000000..2b4277b Binary files /dev/null and b/src/val/__pycache__/GPT2.cpython-313.pyc differ diff --git a/src/val/__pycache__/Llama.cpython-313.pyc b/src/val/__pycache__/Llama.cpython-313.pyc new file mode 100644 index 0000000..9d068d7 Binary files /dev/null and b/src/val/__pycache__/Llama.cpython-313.pyc differ diff --git a/src/val/__pycache__/__init__.cpython-313.pyc b/src/val/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..3cc69da Binary files /dev/null and b/src/val/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/val/__pycache__/get_val.cpython-313.pyc b/src/val/__pycache__/get_val.cpython-313.pyc new file mode 100644 index 0000000..54c88b2 Binary files /dev/null and b/src/val/__pycache__/get_val.cpython-313.pyc differ diff --git a/src/val/get_val.py b/src/val/get_val.py index 700c749..9bfc08c 100644 --- a/src/val/get_val.py +++ b/src/val/get_val.py @@ -2,11 +2,22 @@ from src.val.Llama import val_Llama from src.val.Bert import val_Bert + def get_val(model_name, data_name, state_dict_full, logger): - if model_name == 'GPT2': - val_GPT2(model_name, data_name, state_dict_full, logger) - elif model_name == 'Llama': - val_Llama(model_name, data_name, state_dict_full, logger) - elif model_name == 'Bert': - val_Bert(model_name, data_name, state_dict_full, logger) - return True + """ + Chạy validation và trả về full state dict nếu thành công, None nếu lỗi. + Server dùng full state dict này để lưu vào GPT2.pt cho round tiếp theo. 
+ """ + try: + if model_name == 'GPT2': + return val_GPT2(model_name, data_name, state_dict_full, logger) + elif model_name == 'Llama': + return val_Llama(model_name, data_name, state_dict_full, logger) + elif model_name == 'Bert': + return val_Bert(model_name, data_name, state_dict_full, logger) + else: + logger.log_warning(f"Unknown model_name '{model_name}' for validation.") + return None + except Exception as e: + logger.log_error(f"Validation failed with exception: {e}") + return None \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000..3c58111 --- /dev/null +++ b/test.py @@ -0,0 +1,177 @@ +""" +test.py — Evaluate GPT2 checkpoints sau khi train SplitFedLLM + +Usage: + python test.py # test GPT2.pt + python test.py --model GPT2_round3.pt + python test.py --model GPT2.pt --device cpu +""" +import argparse +import os +import torch +import torch.nn as nn +from transformers import GPT2Tokenizer +from src.model.GPT2 import GPT2 + +# ── Args ────────────────────────────────────────────────────────────────────── +parser = argparse.ArgumentParser() +parser.add_argument("--model", default="GPT2.pt", help="Path to .pt checkpoint") +parser.add_argument("--device", default=None, help="cpu / cuda (auto-detect nếu bỏ qua)") +parser.add_argument("--tokens", default=80, type=int, help="Max new tokens khi generate") +args = parser.parse_args() + +DEVICE = torch.device( + args.device if args.device + else ("cuda" if torch.cuda.is_available() else "cpu") +) +print(f"[INFO] Device: {DEVICE}") + +TEST_PROMPTS = [ + " name[The Golden Curry], food[Fast food], customer rating[low], area[riverside], familyFriendly[yes], near[Café Rouge] ", + " name[Fitzbillies], eatType[coffee shop], food[French], priceRange[£20-25], customer rating[3 out of 5] ", + " name[The Twenty Two], eatType[restaurant], food[Italian], familyFriendly[no] ", + " name[Cotto], eatType[coffee shop], food[Indian], priceRange[moderate], area[riverside], near[The Portland Arms] 
", + " name[Giraffe], eatType[pub], food[Fast food], area[city centre], familyFriendly[no] ", +] + +EVAL_PAIRS = [ + ( + " name[The Golden Curry], food[Fast food] ", + "The Golden Curry is a fast food place with a low customer rating located near Café Rouge." + ), + ( + " name[Fitzbillies], eatType[coffee shop], food[French] ", + "Fitzbillies is a French coffee shop with a customer rating of 3 out of 5." + ), +] + +# ── Load model ──────────────────────────────────────────────────────────────── +def load_model(path: str): + if not os.path.exists(path): + raise FileNotFoundError(f"Không tìm thấy: {path}") + + print(f"\n[INFO] Loading {path} ({os.path.getsize(path)/1e6:.1f} MB)") + sd = torch.load(path, map_location="cpu") + + model = GPT2() + missing, unexpected = model.load_state_dict(sd, strict=False) + print(f"[DEBUG] Missing keys : {len(missing)}") + print(f"[DEBUG] Unexpected keys: {len(unexpected)}") + + if hasattr(model, "lm_head") and hasattr(model, "wte"): + model.lm_head.weight = model.wte.weight + print("[OK] Weight tying applied") + + model.to(DEVICE) + model.eval() + return model + +# ── Generate ────────────────────────────────────────────────────────────────── +def generate(model, tokenizer, prompt: str, max_new_tokens: int = 80) -> str: + """Top-k sampling + repetition penalty.""" + ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE) + prompt_len = ids.shape[1] + generated = [] + + for _ in range(max_new_tokens): + with torch.no_grad(): + out = model(input_ids=ids) + logits = out["logits"][:, -1, :].clone() # (1, vocab) + + # Repetition penalty mạnh hơn (1.5 thay vì 1.3) + for tid in set(ids[0].tolist()): + logits[0, tid] /= 1.5 + + # Penalty riêng cho các từ lặp gần đây (context 20 tokens cuối) + recent = ids[0, -20:].tolist() + for tid in set(recent): + logits[0, tid] /= 1.3 + + # Temperature + top-k (top-k=40 để output tập trung hơn) + logits /= 0.7 + topk_vals, topk_idx = torch.topk(logits, 40) + probs = torch.softmax(topk_vals, 
dim=-1) + next_token = topk_idx[0, torch.multinomial(probs, 1)] + + if next_token.item() == tokenizer.eos_token_id: + break + + generated.append(next_token.item()) + ids = torch.cat([ids, next_token.view(1, 1)], dim=1) + + output = tokenizer.decode(generated, skip_special_tokens=True).strip() + return output + +# ── Loss / Perplexity ───────────────────────────────────────────────────────── +def compute_loss(model, tokenizer, pairs: list) -> tuple: + """ + pairs: list of (prompt, reference) tuples. + Tính loss chỉ trên phần reference (không tính prompt), giống notebook. + """ + criterion = nn.CrossEntropyLoss(ignore_index=-100, reduction="sum") + total_loss = 0.0 + total_tokens = 0 + + with torch.no_grad(): + for prompt, ref in pairs: + full_text = prompt + " " + ref + full_ids = tokenizer.encode(full_text, return_tensors="pt").to(DEVICE) + prompt_len = len(tokenizer.encode(prompt)) + + out = model(input_ids=full_ids) + logits = out["logits"] # (1, seq, vocab) + + shift_logits = logits[:, :-1, :].contiguous() + shift_labels = full_ids[:, 1:].contiguous().clone() + + # Mask phần prompt: chỉ tính loss trên phần reference + shift_labels[:, :prompt_len - 1] = -100 + + loss = criterion( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1), + ) + n_ref_tokens = (shift_labels != -100).sum().item() + total_loss += loss.item() + total_tokens += n_ref_tokens + + avg_loss = total_loss / max(total_tokens, 1) + ppl = torch.exp(torch.tensor(avg_loss)).item() + return avg_loss, ppl + +# ── Sanity check ────────────────────────────────────────────────────────────── +def sanity_check(model, tokenizer): + print("\n[Sanity check]") + ids = tokenizer.encode("The restaurant", return_tensors="pt").to(DEVICE) + with torch.no_grad(): + out = model(input_ids=ids) + print("[DEBUG] Output keys:", list(out.keys())) + if "logits" not in out: + raise RuntimeError("Model không trả về 'logits' → kiểm tra GPT2.forward()") + topk = torch.topk(out["logits"][0, -1], 5) + 
print("[DEBUG] Top next tokens:", [tokenizer.decode([i]) for i in topk.indices.tolist()]) + +# ── Main ────────────────────────────────────────────────────────────────────── +def main(): + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + tokenizer.pad_token = tokenizer.eos_token + + model = load_model(args.model) + sanity_check(model, tokenizer) + + # ── Generation ──────────────────────────────────────────────────────────── + print("\n" + "=" * 60) + for prompt in TEST_PROMPTS: + output = generate(model, tokenizer, prompt, max_new_tokens=args.tokens) + print(f"\nPROMPT:\n{prompt}") + print(f"OUTPUT:\n{output}") + print("-" * 60) + + # ── Loss / Perplexity ───────────────────────────────────────────────────── + loss, ppl = compute_loss(model, tokenizer, EVAL_PAIRS) + print("\n" + "=" * 60) + print(f"[METRIC] Loss: {loss:.4f} | Perplexity: {ppl:.2f}") + print("=" * 60) + +if __name__ == "__main__": + main() \ No newline at end of file