diff --git a/app.log b/app.log
new file mode 100644
index 0000000..e688b87
--- /dev/null
+++ b/app.log
@@ -0,0 +1,22 @@
+2026-03-24 14:58:30,184 - my_logger - INFO - Application start. Server waiting for [1, 1] clients.
+2026-03-24 15:01:29,011 - my_logger - INFO - Start training round 1
+2026-03-24 22:44:03,514 - my_logger - INFO - Application start. Server waiting for [1, 1] clients.
+2026-03-24 22:45:14,428 - my_logger - INFO - Start training round 1
+2026-03-25 09:13:23,571 - my_logger - INFO - Round 1 complete. Waiting for READY.
+2026-03-25 09:13:29,550 - my_logger - INFO - Round 1 fully complete.
+2026-03-25 09:13:32,331 - my_logger - INFO - Start training round 2
+2026-03-25 19:44:00,127 - my_logger - INFO - Application start. Server waiting for [1, 1] clients.
+2026-03-25 19:45:03,466 - my_logger - INFO - Start training round 1
+2026-03-25 19:47:07,712 - my_logger - INFO - Start training round 1
+2026-03-25 19:47:19,011 - my_logger - INFO - Start training round 1
+2026-03-25 19:49:41,646 - my_logger - INFO - Start training round 1
+2026-03-25 19:52:49,376 - my_logger - INFO - Application start. Server waiting for [1, 1] clients.
+2026-03-25 19:54:50,514 - my_logger - INFO - Start training round 1
+2026-03-25 19:56:16,092 - my_logger - INFO - Start training round 1
+2026-03-25 19:57:18,773 - my_logger - INFO - Application start. Server waiting for [1, 1] clients.
+2026-03-25 19:57:49,168 - my_logger - INFO - Start training round 1
+2026-03-25 20:05:50,761 - my_logger - INFO - Application start. Server waiting for [1, 1] clients.
+2026-03-25 20:06:20,483 - my_logger - INFO - Start training round 1
+2026-03-26 04:53:12,204 - my_logger - INFO - Round 1 complete. Waiting for READY.
+2026-03-26 04:53:17,911 - my_logger - INFO - Round 1 fully complete.
+2026-03-26 04:53:20,860 - my_logger - INFO - Start training round 2
diff --git a/client - Copy.py b/client - Copy.py
new file mode 100644
index 0000000..1f3e7b7
--- /dev/null
+++ b/client - Copy.py
@@ -0,0 +1,50 @@
+import pika
+import uuid
+import argparse
+import yaml
+import os
+
+import torch
+
+import src.Log
+from src.RpcClient import RpcClient
+
+# Command-line interface: the layer ID this client serves (starting from 1)
+# and, optionally, the device to run on.
+parser = argparse.ArgumentParser(description="Split learning framework")
+parser.add_argument('--layer_id', type=int, required=True, help='ID of layer, start from 1')
+parser.add_argument('--device', type=str, required=False, help='Device of client')
+
+args = parser.parse_args()
+
+# Explicit encoding so the YAML is decoded the same way on every platform
+# instead of depending on the OS default locale.
+with open("config.yaml", "r", encoding="utf-8") as file:
+    config = yaml.safe_load(file)
+
+# RabbitMQ connection settings come from the "rabbit" section of the config.
+client_id = uuid.uuid4()
+address = config["rabbit"]["address"]
+username = config["rabbit"]["username"]
+password = config["rabbit"]["password"]
+virtual_host = config["rabbit"]["virtual-host"]
+
+# Device selection: an explicit --device wins; otherwise use CUDA when
+# available and fall back to CPU.
+device = None
+if args.device is None:
+    if torch.cuda.is_available():
+        device = "cuda"
+        print(f"Using device: {torch.cuda.get_device_name(device)}")
+    else:
+        device = "cpu"
+        print("Using device: CPU")
+else:
+    device = args.device
+    print(f"Using device: {device}")
+
+credentials = pika.PlainCredentials(username, password)
+connection = pika.BlockingConnection(pika.ConnectionParameters(address, 5672, f'{virtual_host}', credentials))
+channel = connection.channel()
+
+if __name__ == "__main__":
+    src.Log.print_with_color("[>>>] Client sending registration message to server...", "red")
+
+    # Register with the server, then block handling its replies.
+    data = {"action": "REGISTER", "client_id": client_id, "layer_id": args.layer_id,"message": "Hello from Client!"}
+    client = RpcClient(client_id, args.layer_id, channel, device)
+    client.send_to_server(data)
+    client.wait_response()
+
diff --git a/client.py b/client.py
index 1f3e7b7..2d7f9ce 100644
--- a/client.py
+++ b/client.py
@@ -15,9 +15,8 @@
args = parser.parse_args()
-with open('config.yaml', 'r') as file:
+with open("config.yaml", "r", encoding="utf-8") as file:
config = yaml.safe_load(file)
-
client_id = uuid.uuid4()
address = config["rabbit"]["address"]
username = config["rabbit"]["username"]
@@ -37,7 +36,17 @@
print(f"Using device: {device}")
credentials = pika.PlainCredentials(username, password)
-connection = pika.BlockingConnection(pika.ConnectionParameters(address, 5672, f'{virtual_host}', credentials))
+# FIX: heartbeat=0 disables the heartbeat timeout, avoiding StreamLostError during long training runs
+connection = pika.BlockingConnection(
+ pika.ConnectionParameters(
+ host=address,
+ port=5672,
+ virtual_host=f'{virtual_host}',
+ credentials=credentials,
+ heartbeat=0,
+ blocked_connection_timeout=None,
+ )
+)
channel = connection.channel()
if __name__ == "__main__":
@@ -46,5 +55,4 @@
data = {"action": "REGISTER", "client_id": client_id, "layer_id": args.layer_id,"message": "Hello from Client!"}
client = RpcClient(client_id, args.layer_id, channel, device)
client.send_to_server(data)
- client.wait_response()
-
+ client.wait_response()
\ No newline at end of file
diff --git a/config.yaml b/config.yaml
index 6189867..203755e 100644
--- a/config.yaml
+++ b/config.yaml
@@ -1,12 +1,14 @@
name: SplitFedLLM
+
server:
- global-round: 1
+ global-round: 10
clients:
- 1
- 1
cut-layers: 4
- model-name: Bert # GPT2/Llama/Bert
- data-name: EMOTION # EMOTION/GSM8K
+ model-name: GPT2 # GPT2 / Llama / Bert
+ data-name: E2E # E2E / EMOTION / AG_NEWS / GSM8K
+ pretrained_path: GPT2.pt
model:
GPT2:
n_block: 12
@@ -17,15 +19,15 @@ server:
parameters:
load: True
save: True
- validation: True
+ validation: False
data-distribution:
non-iid: False
- num-sample: 500
- num-label: 10
+ num-sample: 10000
+ num-label: 1
dirichlet:
alpha: 1
refresh-each-round: True
- random-seed: 1
+ random-seed: 42
rabbit:
address: 127.0.0.1
@@ -34,14 +36,14 @@ rabbit:
virtual-host: /
log_path: .
-debug_mode: True
+debug_mode: False
learning:
- learning-rate: 0.00001
+ learning-rate: 0.00005
weight-decay: 0.01
- batch-size: 2
- control-count: 1
- clip-grad-norm: 0.0
+ batch-size: 8
+ control-count: 2
+ clip-grad-norm: 1.0
fine-tune:
enable: True
@@ -49,3 +51,14 @@ fine-tune:
LoRA:
r: 8
alpha: 16
+ QLoRA:
+ r: 8
+ alpha: 16
+ bits: 4
+ double_quant: True
+
+optimization:
+ flash_attention: False
+ precision: fp32
+ quantize_hidden: False
+ gradient_checkpointing: True
\ No newline at end of file
diff --git a/convert.py b/convert.py
new file mode 100644
index 0000000..5ab2029
--- /dev/null
+++ b/convert.py
@@ -0,0 +1,23 @@
+import torch
+from transformers import GPT2LMHeadModel
+
+# Load the fine-tuned Hugging Face GPT-2 checkpoint from the local directory.
+hf_model = GPT2LMHeadModel.from_pretrained("./gpt2_e2e_finetuned")
+
+hf_sd = hf_model.state_dict()
+new_sd = {}
+
+# Only a *leading* "transformer." is stripped; str.replace() would also
+# rewrite any later occurrence inside a key.
+prefix = "transformer."
+
+for k, v in hf_sd.items():
+    # Keys without the prefix (e.g. lm_head.*) are copied unchanged.
+    new_k = k[len(prefix):] if k.startswith(prefix) else k
+    new_sd[new_k] = v
+
+# Save the renamed state dict.
+torch.save(new_sd, "GPT2.pt")
+
+print("Converted → GPT2.pt")
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index b8f6635..efa0dbc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,19 @@
-torch
+torch>=2.0.0
pika~=1.3.2
transformers>=4.36.2
datasets
-peft
+peft>=0.9.0
numpy
nltk
rouge_score
+
+# ── Optimization ──────────────────────────────────────────────
+# QLoRA + INT8: quantize the base model down to 4-bit / 8-bit
+bitsandbytes>=0.43.0
+
+# Flash Attention 2: requires an Ampere or newer GPU (RTX 30xx, A100, H100, ...)
+# Install manually if needed: pip install flash-attn --no-build-isolation
+flash-attn>=2.5.0; sys_platform != "win32"
+
+# Accelerate: used together with bitsandbytes & gradient checkpointing
+accelerate>=0.27.0
diff --git a/server.py b/server.py
index 4fa6369..9a424d3 100644
--- a/server.py
+++ b/server.py
@@ -10,7 +10,7 @@
args = parser.parse_args()
-with open('config.yaml') as file:
+with open("config.yaml", "r", encoding="utf-8") as file:
config = yaml.safe_load(file)
address = config["rabbit"]["address"]
username = config["rabbit"]["username"]
diff --git a/src/Optimizer.py b/src/Optimizer.py
new file mode 100644
index 0000000..9adf242
--- /dev/null
+++ b/src/Optimizer.py
@@ -0,0 +1,321 @@
+
+
+import torch
+import torch.nn as nn
+import numpy as np
+from contextlib import contextmanager
+from transformers.pytorch_utils import Conv1D
+
+try:
+ from peft import (
+ LoraConfig, TaskType, get_peft_model,
+ prepare_model_for_kbit_training,
+ )
+ from transformers import BitsAndBytesConfig
+ HAS_PEFT = True
+except ImportError:
+ HAS_PEFT = False
+
+try:
+ import bitsandbytes as bnb
+ HAS_BNB_PKG = True
+except ImportError:
+ HAS_BNB_PKG = False
+
+
+def detect_fan_in_fan_out(model):
+    """Return True when ``model`` contains at least one ``Conv1D`` sub-module.
+
+    ``build_qlora_config`` forwards the result as the ``fan_in_fan_out``
+    argument of ``LoraConfig``.
+    """
+    return any(isinstance(sub, Conv1D) for sub in model.modules())
+
+def build_qlora_config(qlora_cfg: dict, model, model_name: str):
+    """Build the bitsandbytes and LoRA configurations for QLoRA fine-tuning.
+
+    Args:
+        qlora_cfg: Settings dict; reads ``bits`` (default 4), ``double_quant``
+            (default True), ``r`` (default 8) and ``alpha`` (default 16).
+        model: Model that will be adapted; only inspected for ``Conv1D``
+            layers to decide ``fan_in_fan_out``.
+        model_name: ``"GPT2"``, ``"Llama"`` or ``"Bert"``. Any other name
+            falls back to the target modules ``["query", "value"]`` and the
+            causal-LM task type.
+
+    Returns:
+        ``(bnb_config, lora_config)``, or ``(None, None)`` when bitsandbytes
+        or peft is not installed.
+    """
+    if not HAS_BNB_PKG or not HAS_PEFT:
+        return None, None
+
+    bits = qlora_cfg.get("bits", 4)
+    double_quant = qlora_cfg.get("double_quant", True)
+    r = qlora_cfg.get("r", 8)
+    alpha = qlora_cfg.get("alpha", 16)
+
+    # bf16 is only used when the GPU actually supports it; a CUDA device
+    # without bf16 support (and the CPU case) falls back to fp16.
+    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+        compute_dtype = torch.bfloat16
+    else:
+        compute_dtype = torch.float16
+
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=(bits == 4),
+        load_in_8bit=(bits == 8),
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_use_double_quant=double_quant,
+        bnb_4bit_compute_dtype=compute_dtype,
+    )
+
+    # Names of the sub-modules that receive LoRA adapters, per architecture.
+    target_map = {
+        "GPT2": ["c_attn", "c_proj", "c_fc"],
+        "Llama": ["q_proj", "k_proj", "v_proj", "o_proj",
+                  "gate_proj", "up_proj", "down_proj"],
+        "Bert": ["query", "key", "value", "dense"],
+    }
+    targets = target_map.get(model_name, ["query", "value"])
+
+    task = TaskType.SEQ_CLS if model_name == "Bert" else TaskType.CAUSAL_LM
+    fan_in = detect_fan_in_fan_out(model)
+
+    lora_config = LoraConfig(
+        task_type=task,
+        r=r,
+        lora_alpha=alpha,
+        lora_dropout=0.05,
+        bias="none",
+        target_modules=targets,
+        fan_in_fan_out=fan_in,
+    )
+    return bnb_config, lora_config
+
+
+def apply_qlora(model, lora_config, model_name: str):
+    """Prepare ``model`` for k-bit training and wrap it with LoRA adapters.
+
+    Args:
+        model: Model to adapt.
+        lora_config: ``LoraConfig`` passed to ``get_peft_model``.
+        model_name: Architecture name; ``"Bert"`` additionally unfreezes the
+            classifier head when the model has one.
+
+    Returns:
+        The PEFT-wrapped model, or ``model`` unchanged when peft is not
+        installed.
+    """
+    if not HAS_PEFT:
+        return model
+    model = prepare_model_for_kbit_training(
+        model, use_gradient_checkpointing=True
+    )
+    model = get_peft_model(model, lora_config)
+    # Bert needs a trainable classifier. Not every model part is guaranteed
+    # to own one (RpcClient only unfreezes it for layer 2 on the LoRA path),
+    # so a missing head is skipped instead of raising AttributeError.
+    if model_name == "Bert":
+        classifier = getattr(model, "classifier", None)
+        if classifier is not None:
+            for param in classifier.parameters():
+                param.requires_grad = True
+    model.print_trainable_parameters()
+    return model
+
+
+
+import torch.nn.functional as F
+
+def flash_scaled_dot_product(q, k, v, mask=None, dropout_p=0.0):
+    """Scaled dot-product attention, using the fused PyTorch kernel when available.
+
+    Args:
+        q, k, v: Query, key and value tensors; attention is computed over the
+            last two dimensions.
+        mask: Optional mask broadcastable to the attention scores. Non-zero
+            (or True) entries may be attended to; zero entries are masked
+            out. When omitted, a causal (lower-triangular) mask is applied.
+        dropout_p: Dropout probability applied to the attention weights. It
+            is applied whenever it is greater than zero; the caller decides
+            whether to pass a non-zero value.
+
+    Returns:
+        The attention output. In the fallback path it is cast back to the
+        dtype of ``q``.
+    """
+    # Normalise the mask once so both code paths agree: True = "may attend".
+    # F.scaled_dot_product_attention uses this same convention for boolean
+    # masks, so the mask must NOT be inverted before it is passed on.
+    keep = None if mask is None else (mask != 0)
+
+    if hasattr(F, "scaled_dot_product_attention"):
+        return F.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=keep,
+            dropout_p=dropout_p,
+            is_causal=keep is None,
+        )
+
+    # Manual fallback for PyTorch builds without the fused kernel; the maths
+    # is done in fp32 for numerical stability.
+    import math
+    scale = math.sqrt(q.size(-1))
+    q_fp32 = q.float()
+    k_fp32 = k.float()
+    v_fp32 = v.float()
+    att = (q_fp32 @ k_fp32.transpose(-2, -1)) / scale
+    if keep is None:
+        # Mirror is_causal=True of the fused path: lower-triangular mask.
+        rows, cols = att.shape[-2:]
+        keep = torch.ones(rows, cols, dtype=torch.bool, device=att.device).tril()
+    fill_val = torch.finfo(att.dtype).min / 2
+    att = att.masked_fill(~keep, fill_val)
+    att = torch.softmax(att, dim=-1)
+    if dropout_p > 0.0:
+        att = F.dropout(att, p=dropout_p, training=True)
+    out = att @ v_fp32
+    return out.to(q.dtype)
+
+@contextmanager
+def precision_context(precision: str, device: str = "cuda"):
+    """Context manager that enables CUDA autocast for reduced precisions.
+
+    ``"fp16"`` and ``"int8"`` autocast to float16 and ``"bf16"`` to bfloat16.
+    ``"fp32"``, any unrecognised value, or a machine without CUDA yields
+    without touching autocast. ``device`` is accepted but not consulted.
+    """
+    autocast_dtypes = {
+        "fp16": torch.float16,
+        "bf16": torch.bfloat16,
+        "int8": torch.float16,
+    }
+    target = autocast_dtypes.get(precision) if torch.cuda.is_available() else None
+
+    if target is None:
+        yield
+    else:
+        with torch.autocast(device_type="cuda", dtype=target):
+            yield
+
+
+def cast_model_precision(model: nn.Module, precision: str) -> nn.Module:
+    """Convert the whole model to the requested precision (non-QLoRA path).
+
+    ``"fp16"`` and ``"bf16"`` cast the model's dtype. ``"int8"`` swaps
+    ``nn.Linear`` layers for their bitsandbytes 8-bit counterpart when
+    bitsandbytes is installed, and otherwise prints a notice and leaves the
+    model untouched. Any other value returns the model unchanged.
+    """
+    if precision == "int8":
+        if not HAS_BNB_PKG:
+            print("[Optimizer] bitsandbytes chưa cài")
+            return model
+        return _replace_linear_int8(model)
+
+    if precision == "fp16":
+        return model.half()
+    if precision == "bf16":
+        return model.to(torch.bfloat16)
+
+    # Default: keep the model as it is.
+    return model
+
+
+def _replace_linear_int8(model: nn.Module) -> nn.Module:
+    """Recursively replace every ``nn.Linear`` with ``bnb.nn.Linear8bitLt``.
+
+    The original weights are wrapped in ``bnb.nn.Int8Params``. Assigning a
+    plain ``nn.Parameter`` would discard the parameter class that
+    ``Linear8bitLt`` creates for its weight, so the weight would never be
+    quantized.
+    NOTE(review): bitsandbytes performs the actual int8 conversion when the
+    module is moved to a CUDA device — confirm against the installed version.
+
+    Args:
+        model: Module whose children are rewritten in place.
+
+    Returns:
+        The same ``model`` object.
+    """
+    for name, module in model.named_children():
+        if isinstance(module, nn.Linear):
+            has_bias = module.bias is not None
+            new_layer = bnb.nn.Linear8bitLt(
+                module.in_features,
+                module.out_features,
+                bias=has_bias,
+                has_fp16_weights=False,
+                threshold=6.0,  # LLM.int8() outlier threshold
+            )
+            new_layer.weight = bnb.nn.Int8Params(
+                module.weight.data,
+                requires_grad=False,
+                has_fp16_weights=False,
+            )
+            if has_bias:
+                new_layer.bias = nn.Parameter(module.bias.data)
+            setattr(model, name, new_layer)
+        else:
+            _replace_linear_int8(module)
+    return model
+
+
+
+
+def quantize_hidden(tensor: torch.Tensor) -> tuple:
+    """Compress a hidden-state tensor to int8 with one symmetric scale.
+
+    The scale is ``max(|x|) / 127`` (replaced by ``1e-8`` when the tensor is
+    all zeros) and each value becomes ``round(x / scale)`` clamped to
+    ``[-127, 127]``.
+
+    Returns:
+        ``(q, scale)`` where ``q`` is an int8 NumPy array on the CPU and
+        ``scale`` is a Python float.
+    """
+    values = tensor.detach().float()
+    peak = values.abs().max().item()
+    scale = peak / 127.0
+    if scale == 0.0:
+        scale = 1e-8
+    levels = torch.clamp(torch.round(values / scale), -127, 127)
+    return levels.to(torch.int8).cpu().numpy(), scale
+
+
+def dequantize_hidden(q_numpy: np.ndarray, scale: float,
+                      device: str, requires_grad: bool = False) -> torch.Tensor:
+    """Expand a quantized array back to float32 and move it to ``device``.
+
+    Args:
+        q_numpy: Quantized values as produced by ``quantize_hidden``.
+        scale: Scale to multiply the values by.
+        device: Target device for the result.
+        requires_grad: When True, the returned tensor tracks gradients.
+    """
+    restored = torch.from_numpy(q_numpy).float().mul(scale).to(device)
+    if requires_grad:
+        restored.requires_grad_(True)
+    return restored
+
+
+
+def enable_gradient_checkpointing(model: nn.Module) -> nn.Module:
+    """Turn on gradient checkpointing using whatever hook the model offers.
+
+    Resolution order:
+      1. ``model.gradient_checkpointing_enable()`` when the method exists.
+      2. Otherwise ``model.enable_input_require_grads()`` when that exists
+         (this branch does not wrap anything in a checkpoint).
+      3. Otherwise child modules are wrapped via ``_wrap_checkpointing``.
+
+    Returns:
+        The same ``model`` object.
+    """
+    if hasattr(model, "gradient_checkpointing_enable"):
+        model.gradient_checkpointing_enable()
+        print("[Optimizer] gradient_checkpointing_enable() called.")
+        return model
+
+    if hasattr(model, "enable_input_require_grads"):
+        model.enable_input_require_grads()
+        return model
+
+    # Custom model: wrap each matching block individually.
+    _wrap_checkpointing(model)
+    return model
+
+
+def _wrap_checkpointing(model: nn.Module):
+    """Route matching child modules through ``torch.utils.checkpoint``.
+
+    A direct child whose attribute name contains ``layer``, ``block`` or
+    ``decoder`` (case-insensitive) gets its ``forward`` replaced by a
+    checkpointed version. Every other child is searched recursively.
+    """
+    from torch.utils.checkpoint import checkpoint as ckpt
+
+    markers = ("layer", "block", "decoder")
+
+    def checkpointed(inner_forward):
+        # Bind the original forward now so each wrapper keeps its own target.
+        def forward_with_checkpoint(*args, **kwargs):
+            def replay(*a, **kw):
+                return inner_forward(*a, **kw)
+            return ckpt(replay, *args, **kwargs, use_reentrant=False)
+        return forward_with_checkpoint
+
+    for child_name, child in model.named_children():
+        lowered = child_name.lower()
+        if not any(marker in lowered for marker in markers):
+            _wrap_checkpointing(child)
+            continue
+        child.forward = checkpointed(child.forward)
+
+
+
+def make_scaler(precision: str):
+    """Return a ``GradScaler`` for fp16 training on CUDA, otherwise ``None``.
+
+    Intended use together with the optimizer step::
+
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+
+    ``torch.amp.GradScaler`` is only present in newer PyTorch releases, so
+    older ones fall back to ``torch.cuda.amp.GradScaler``.
+    """
+    if precision != "fp16" or not torch.cuda.is_available():
+        return None
+    amp_scaler = getattr(torch.amp, "GradScaler", None)
+    if amp_scaler is not None:
+        return amp_scaler("cuda")
+    return torch.cuda.amp.GradScaler()
+
+
+class OptimizationBundle:
+    """Single object carrying all optimization settings.
+
+    It is built from the ``optimization`` config dict and handed through
+    RpcClient to the fine-tuning code, so that precision, hidden-state
+    quantization, gradient checkpointing and flash-attention flags travel
+    together.
+    """
+
+    def __init__(self, opt_cfg: dict):
+        read = opt_cfg.get
+        self.precision = read("precision", "fp32")
+        self.quantize_hidden = read("quantize_hidden", False)
+        self.gradient_checkpointing = read("gradient_checkpointing", False)
+        self.flash_attention = read("flash_attention", False)
+        # None unless make_scaler decides a GradScaler is needed.
+        self.scaler = make_scaler(self.precision)
+
+    def precision_ctx(self, device="cuda"):
+        """Return the autocast context manager for the configured precision."""
+        return precision_context(self.precision, device)
+
+    def quant(self, tensor: torch.Tensor):
+        """Compress a hidden state when quantization is enabled.
+
+        When it is disabled, the tensor is returned as a float16 NumPy array
+        with ``scale=None`` so both modes share the ``(array, scale)`` shape.
+        """
+        if not self.quantize_hidden:
+            half = tensor.detach().cpu().to(torch.float16)
+            return half.numpy(), None
+        return quantize_hidden(tensor)
+
+    def dequant(self, q_numpy, scale, device, requires_grad=False):
+        """Undo ``quant``: rebuild a tensor on ``device``."""
+        if scale is None:
+            # Not quantized: the payload is plain float16 data.
+            restored = torch.from_numpy(q_numpy.astype(np.float16)).to(device)
+            if requires_grad:
+                restored.requires_grad_(True)
+            return restored
+        return dequantize_hidden(q_numpy, scale, device, requires_grad)
+
+    def step(self, loss, optimizer):
+        """Run backward and the optimizer step, via the GradScaler if present."""
+        scaler = self.scaler
+        if scaler is None:
+            loss.backward()
+            optimizer.step()
+            return
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
diff --git a/src/RpcClient.py b/src/RpcClient.py
index 5657230..0b8e0e5 100644
--- a/src/RpcClient.py
+++ b/src/RpcClient.py
@@ -1,7 +1,8 @@
import time
import pickle
-import copy
-
+import re
+import os
+import torch
import src.Log
from src.fine_tune.GPT2 import Ft_GPT2
from src.fine_tune.Llama import Ft_Llama
@@ -10,150 +11,332 @@
from src.model.GPT2 import GPT2
from src.model.Llama import Llama
from src.model.Bert import Bert
-
+from src.Optimizer import (
+ OptimizationBundle,
+ build_qlora_config, apply_qlora,
+ cast_model_precision,
+ enable_gradient_checkpointing,
+)
from peft import LoraConfig, TaskType, get_peft_model
+
+def _build_lora_config(fine_tune_config: dict, model_name: str):
+    """Create the ``LoraConfig`` for ``model_name`` from the LoRA settings.
+
+    Rank and alpha are read from ``fine_tune_config["LoRA"]``. Target
+    modules, dropout and task type are fixed per architecture. Returns
+    ``None`` for a model name other than GPT2, Llama or Bert.
+    """
+    rank = fine_tune_config["LoRA"]["r"]
+    scaling = fine_tune_config["LoRA"]["alpha"]
+
+    # Architecture-specific keyword arguments; r / lora_alpha / bias are shared.
+    per_model = {
+        "GPT2": dict(
+            task_type=TaskType.CAUSAL_LM,
+            lora_dropout=0.05,
+            target_modules=["c_attn", "c_proj", "c_fc"],
+            fan_in_fan_out=True,
+        ),
+        "Llama": dict(
+            task_type=TaskType.CAUSAL_LM,
+            lora_dropout=0.05,
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                            "gate_proj", "up_proj", "down_proj"],
+        ),
+        "Bert": dict(
+            task_type="SEQ_CLS",
+            lora_dropout=0.1,
+            target_modules=["query", "key", "value", "dense"],
+        ),
+    }
+
+    specific = per_model.get(model_name)
+    if specific is None:
+        return None
+    return LoraConfig(r=rank, lora_alpha=scaling, bias="none", **specific)
+
+
+def _merge_lora_into_base(base_sd: dict, lora_sd: dict,
+                          lora_r: int, lora_alpha: int,
+                          model_name: str = "GPT2") -> dict:
+    """Fold LoRA adapter weights into a copy of the base state dict.
+
+    For every ``...lora_A.default.weight`` / ``...lora_B.default.weight``
+    pair whose base ``.weight`` key exists, the update
+    ``(B @ A) * (lora_alpha / lora_r)`` is added to the base weight. GPT2
+    gets the transposed update because its Conv1D weights are stored as
+    (in, out); other models use it as is.
+
+    If the result has ``wte.weight`` but no ``lm_head.weight``, the latter is
+    created as a clone of the former.
+
+    Returns:
+        A new dict; ``base_sd`` itself is not modified.
+    """
+    merged = dict(base_sd)
+    scale = lora_alpha / lora_r
+    transpose_delta = model_name == "GPT2"
+
+    # Group adapter matrices by the base weight key they belong to.
+    factors = {"A": {}, "B": {}}
+    for key, tensor in lora_sd.items():
+        hit = re.match(r'base_model\.model\.(.+)\.lora_([AB])\.default\.weight', key)
+        if hit is None:
+            continue
+        factors[hit.group(2)][hit.group(1) + ".weight"] = tensor.float()
+
+    applied = 0
+    for base_key, mat_a in factors["A"].items():  # mat_a: (r, in)
+        if base_key not in factors["B"] or base_key not in merged:
+            continue
+        product = factors["B"][base_key] @ mat_a  # (out, r) @ (r, in)
+        if transpose_delta:
+            product = product.T
+        original = merged[base_key]
+        merged[base_key] = (original.float() + product * scale).to(original.dtype)
+        applied += 1
+
+    if "lm_head.weight" not in merged and "wte.weight" in merged:
+        merged["lm_head.weight"] = merged["wte.weight"].clone()
+
+    src.Log.print_with_color(f"[LoRA] Applied {applied} deltas vào base weights.", "green")
+    return merged
+
+
class RpcClient:
def __init__(self, client_id, layer_id, channel, device):
- self.client_id = client_id
- self.layer_id = layer_id
- self.channel = channel
- self.model_train = None
- self.train_loader = None
- self.device = device
-
- self.response = None
- self.label_count = None
+ self.client_id = client_id
+ self.layer_id = layer_id
+ self.channel = channel
+ self.device = device
+ result = self.channel.queue_declare(queue="", exclusive=True)
+ self.callback_queue = result.method.queue
+ self.model_train = None
+ self.train_loader = None
+ self.response = None
+ self._refresh_loader = False
def wait_response(self):
- status = True
- reply_queue_name = f'reply_{self.client_id}'
- self.channel.queue_declare(reply_queue_name, durable=False)
- while status:
- method_frame, header_frame, body = self.channel.basic_get(queue=reply_queue_name, auto_ack=True)
+ reply_queue_name = f"reply_{self.client_id}"
+ self.channel.queue_declare(queue=reply_queue_name, durable=False)
+ while True:
+ method_frame, _, body = self.channel.basic_get(
+ queue=reply_queue_name, auto_ack=True
+ )
if body:
status = self.response_message(body)
- time.sleep(0.5)
+ if not status:
+ break
+ time.sleep(0.1)
def response_message(self, body):
self.response = pickle.loads(body)
- src.Log.print_with_color(f"[<<<] Client received: {self.response['message']}", "blue")
+ src.Log.print_with_color(
+ f"[<<<] Client received: {self.response['message']}", "blue"
+ )
action = self.response["action"]
- state_dict = self.response["parameters"]
-
if action == "START":
- model = None
- model_name = self.response["model_name"]
- cut_layers = self.response['cut_layers']
- # label_count = self.response['label_count']
- total_block = self.response['total_block']
- clip_grad_norm = self.response['clip_grad_norm']
- data_name = self.response["data_name"]
- num_sample = self.response["num_sample"]
- fine_tune_config = self.response['fine_tune_config']
-
- batch_size = self.response["batch_size"]
- lr = self.response["lr"]
- weight_decay = self.response["weight_decay"]
- control_count = self.response["control_count"]
-
- if model_name == 'GPT2':
- self.model_train = Ft_GPT2(self.client_id, self.layer_id, self.channel, self.device)
- elif model_name == 'Llama':
- self.model_train = Ft_Llama(self.client_id, self.layer_id, self.channel, self.device)
- elif model_name == 'Bert':
- self.model_train = Ft_Bert(self.client_id, self.layer_id, self.channel, self.device)
-
- if fine_tune_config['name'] == 'LoRA':
- if model_name == 'GPT2':
- peft_config = LoraConfig(
- task_type=TaskType.CAUSAL_LM,
- r=fine_tune_config['LoRA']['r'], lora_alpha=fine_tune_config['LoRA']['alpha'], lora_dropout=0.05, bias="none",
- target_modules=["c_attn", "c_proj", "c_fc"],
- fan_in_fan_out=True
- )
- elif model_name == 'Llama':
- peft_config = LoraConfig(
- task_type=TaskType.CAUSAL_LM,
- r=fine_tune_config['LoRA']['r'], lora_alpha=fine_tune_config['LoRA']['alpha'], lora_dropout=0.05, bias="none",
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
- )
- elif model_name == 'Bert':
- peft_config = LoraConfig(
- task_type="SEQ_CLS",
- r=8,lora_alpha=16,lora_dropout=0.1,bias="none",
- target_modules=["query", "key", "value", "dense"]
- )
- else:
- peft_config = None
- else:
- peft_config = None
-
- if model_name == 'GPT2':
- klass = GPT2
- elif model_name == 'Llama':
- klass = Llama
- elif model_name == 'Bert':
- klass = Bert
- else:
- klass = globals()[f'GPT2']
+ return self._handle_start()
+ elif action == "STOP":
+ return False
+ return True
- if self.layer_id == 1:
- model = klass(layer_id=1, n_block=cut_layers)
+ def _handle_start(self):
+ resp = self.response
+ model_name = resp["model_name"]
+ cut_layers = resp["cut_layers"]
+ total_block = resp["total_block"]
+ clip_grad_norm = resp["clip_grad_norm"]
+ data_name = resp["data_name"]
+ num_sample = resp["num_sample"]
+ fine_tune_config = resp["fine_tune_config"]
+ opt_cfg = resp.get("opt_config", {})
+ batch_size = resp["batch_size"]
+ lr = resp["lr"]
+ weight_decay = resp["weight_decay"]
+ control_count = resp["control_count"]
+ refresh_each_round = resp.get("refresh_each_round", False)
+
+ opt = OptimizationBundle(opt_cfg)
+ use_flash = getattr(opt, "flash_attention", False)
+ if use_flash and self.device == "cpu":
+ use_flash = False
+
+ klass_map = {"GPT2": GPT2, "Llama": Llama, "Bert": Bert}
+ klass = klass_map.get(model_name, GPT2)
+ if self.layer_id == 1:
+ model = klass(layer_id=1, n_block=cut_layers, use_flash=use_flash)
+ else:
+ model = klass(layer_id=2, n_block=total_block - cut_layers, use_flash=use_flash)
+
+ # ── 2. Load weights từ local file riêng của từng layer ───────────────
+
+ layer_file = f"{model_name}_layer{self.layer_id}.pt"
+ pretrained_file = f"{model_name}.pt"
+ base_file = f"{model_name}_base.pt"
+ load_file = layer_file if os.path.exists(layer_file) else pretrained_file
+
+ # ACCUMULATION FIX: Snapshot base weights 1 lần duy nhất.
+ # Mọi round đều merge LoRA vào base gốc, không cộng dồn delta.
+ if not os.path.exists(base_file) and os.path.exists(pretrained_file):
+ import shutil
+ shutil.copy2(pretrained_file, base_file)
+ src.Log.print_with_color(f"[INFO] Created base snapshot: {base_file}", "green")
+
+ if os.path.exists(load_file):
+ src.Log.print_with_color(
+ f"[INFO] Layer {self.layer_id}: Loading {load_file}", "green"
+ )
+ full_sd = torch.load(load_file, map_location="cpu")
+ new_sd = {}
+
+ if self.layer_id == 1:
+ for k, v in full_sd.items():
+ if k.startswith("h."):
+ idx = int(k.split(".")[1])
+ if idx < cut_layers:
+ new_sd[k] = v
+ elif k.startswith(("wte", "wpe")):
+ new_sd[k] = v
else:
- model = klass(layer_id=2, n_block=total_block-cut_layers)
+ for k, v in full_sd.items():
+ if k.startswith("h."):
+ idx = int(k.split(".")[1])
+ if idx >= cut_layers:
+ new_k = k.replace(f"h.{idx}", f"h.{idx - cut_layers}")
+ new_sd[new_k] = v
+ elif k.startswith(("ln_f", "lm_head")):
+ new_sd[k] = v
- # Read parameters and load to model
- if state_dict:
- model.load_state_dict(state_dict)
+ missing, unexpected = model.load_state_dict(new_sd, strict=False)
+ src.Log.print_with_color(
+ f"[DEBUG] Layer {self.layer_id} missing={len(missing)} "
+ f"unexpected={len(unexpected)}", "blue"
+ )
+ else:
+ src.Log.print_with_color(
+ f"[WARN] {load_file} not found — using random init", "yellow"
+ )
- if fine_tune_config['enable']:
- if model_name == 'Bert':
- model = get_peft_model(model, peft_config)
- if self.layer_id == 2:
- for param in model.classifier.parameters():
- param.requires_grad = True
+ # LoRA adapter
+ ft_name = fine_tune_config.get("name", "LoRA")
+ peft_config = None
+
+ if fine_tune_config.get("enable", False):
+ if ft_name == "QLoRA":
+ _, peft_config = build_qlora_config(
+ fine_tune_config["QLoRA"], model, model_name
+ )
+ if peft_config:
+ model = apply_qlora(model, peft_config, model_name)
else:
+ peft_config = _build_lora_config(fine_tune_config, model_name)
+ if peft_config:
+ model = get_peft_model(model, peft_config)
+ else:
+ peft_config = _build_lora_config(fine_tune_config, model_name)
+ if peft_config:
model = get_peft_model(model, peft_config)
- model.print_trainable_parameters()
- model.to(self.device)
+ if model_name == "Bert" and self.layer_id == 2:
+ for p in model.classifier.parameters():
+ p.requires_grad = True
+ model.print_trainable_parameters()
- # Start training
- if self.layer_id == 1:
- if self.train_loader is None:
- self.train_loader = dataloader(model_name, data_name, batch_size, num_sample, train=True)
+ if ft_name != "QLoRA":
+ model = cast_model_precision(model, opt.precision)
+ if opt.gradient_checkpointing:
+ model = enable_gradient_checkpointing(model)
+
+ model.to(self.device)
+
+ ft_map = {"GPT2": Ft_GPT2, "Llama": Ft_Llama, "Bert": Ft_Bert}
+ self.model_train = ft_map.get(model_name, Ft_GPT2)(
+ self.client_id, self.layer_id, self.channel, self.device
+ )
+
+ if refresh_each_round:
+ self.train_loader = None
+
+ if self.layer_id == 1 and self.train_loader is None:
+ self.train_loader = dataloader(
+ model_name, data_name, batch_size, num_sample, train=True
+ )
- result, size = self.model_train.first_layer(model, lr, weight_decay, clip_grad_norm,
- control_count, self.train_loader)
+ if self.layer_id == 1:
+ result, size = self.model_train.first_layer(
+ model, lr, weight_decay, clip_grad_norm,
+ control_count, self.train_loader, opt=opt
+ )
+ else:
+ result, size = self.model_train.last_layer(
+ model, lr, weight_decay, clip_grad_norm, opt=opt
+ )
+ if fine_tune_config.get("enable", False) and peft_config is not None:
+ if not result:
+ src.Log.print_with_color(
+ f"[WARN] Layer {self.layer_id}: Training failed, bỏ qua lưu LoRA.",
+ "yellow"
+ )
else:
- result, size = self.model_train.last_layer(model, lr, weight_decay, clip_grad_norm)
-
- # Stop training, then send parameters to server
- if fine_tune_config['enable']:
- model = model.merge_and_unload()
-
- model_state_dict = copy.deepcopy(model.state_dict())
-
- if self.device != "cpu":
- for key in model_state_dict:
- model_state_dict[key] = model_state_dict[key].to('cpu')
- data = {"action": "UPDATE", "client_id": self.client_id, "layer_id": self.layer_id,
- "result": result, "size": size,
- "message": "Sent parameters to Server", "parameters": model_state_dict}
- src.Log.print_with_color("[>>>] Client sent parameters to server", "red")
- self.send_to_server(data)
- return True
- elif action == "STOP":
- return False
+ lora_sd = {
+ k: v.detach().cpu().float()
+ for k, v in model.state_dict().items()
+ if "lora" in k.lower()
+ }
+ src.Log.print_with_color(
+ f"[INFO] Layer {self.layer_id}: {len(lora_sd)} LoRA keys collected.",
+ "green"
+ )
- def send_to_server(self, message):
- self.response = None
- self.channel.queue_declare('rpc_queue', durable=False)
- self.channel.basic_publish(exchange='',
- routing_key='rpc_queue',
- body=pickle.dumps(message))
+ # ACCUMULATION FIX: Luôn merge vào base gốc (không phải layer file đã có delta)
+ merge_src = base_file if os.path.exists(base_file) else pretrained_file
+ if os.path.exists(merge_src):
+ base_sd = torch.load(merge_src, map_location="cpu")
+ if "wte.weight" in base_sd:
+ lora_r = fine_tune_config.get("LoRA", {}).get("r", 8)
+ lora_alpha = fine_tune_config.get("LoRA", {}).get("alpha", 16)
+ merged_sd = _merge_lora_into_base(
+ base_sd, lora_sd, lora_r, lora_alpha, model_name
+ )
+
+
+ partial_sd = {}
+ if self.layer_id == 1:
+ for k, v in merged_sd.items():
+ if k.startswith(("wte", "wpe")):
+ partial_sd[k] = v
+ elif k.startswith("h."):
+ idx = int(k.split(".")[1])
+ if idx < cut_layers:
+ partial_sd[k] = v
+ else:
+ for k, v in merged_sd.items():
+ if k.startswith("h."):
+ idx = int(k.split(".")[1])
+ if idx >= cut_layers:
+ new_k = k.replace(f"h.{idx}", f"h.{idx - cut_layers}")
+ partial_sd[new_k] = v
+ elif k.startswith(("ln_f", "lm_head")):
+ partial_sd[k] = v
+
+ torch.save(partial_sd, layer_file)
+ size_mb = os.path.getsize(layer_file) / 1e6
+ src.Log.print_with_color(
+ f"[INFO] Layer {self.layer_id}: Saved {len(partial_sd)} keys "
+ f"→ {layer_file} ({size_mb:.1f} MB)", "green"
+ )
+ else:
+ src.Log.print_with_color(
+ f"[WARN] Layer {self.layer_id}: {merge_src} không có base weights.",
+ "yellow"
+ )
+ else:
+ src.Log.print_with_color(
+ f"[WARN] Layer {self.layer_id}: Không tìm thấy {merge_src} để merge.",
+ "yellow"
+ )
+
- return self.response
+ ready_msg = {
+ "action": "READY",
+ "client_id": self.client_id,
+ "layer_id": self.layer_id,
+ "message": f"Layer {self.layer_id} saved LoRA, ready for next round.",
+ }
+ src.Log.print_with_color(
+ f"[>>>] Layer {self.layer_id}: Gửi READY về server.", "red"
+ )
+ self.send_to_server(ready_msg)
+ return True
+
+ def send_to_server(self, message):
+ self.response = None
+ self.channel.queue_declare("rpc_queue", durable=False)
+ self.channel.basic_publish(
+ exchange="", routing_key="rpc_queue", body=pickle.dumps(message)
+ )
+ return self.response
\ No newline at end of file
diff --git a/src/Server.py b/src/Server.py
index 44946db..42fe8f4 100644
--- a/src/Server.py
+++ b/src/Server.py
@@ -1,3 +1,6 @@
+from src.model.GPT2 import GPT2
+from src.model.Llama import Llama
+from src.model.Bert import Bert
import torch
import os
import random
@@ -5,301 +8,332 @@
import pickle
import sys
import numpy as np
-import copy
import src.Log
import src.Utils
-
-from src.model.GPT2 import GPT2
-from src.model.Llama import Llama
-from src.model.Bert import Bert
from src.val.get_val import get_val
+
class Server:
def __init__(self, config):
+ """Read the configuration, connect to RabbitMQ and register the RPC consumer.
+
+ Args:
+ config: Nested mapping. The keys read here live under "rabbit",
+ "server", "learning" and "fine-tune" (all required), "optimization"
+ (optional), plus the top-level "log_path" and "debug_mode".
+ """
- # RabbitMQ
- address = config["rabbit"]["address"]
- username = config["rabbit"]["username"]
- password = config["rabbit"]["password"]
+ address = config["rabbit"]["address"]
+ username = config["rabbit"]["username"]
+ password = config["rabbit"]["password"]
virtual_host = config["rabbit"]["virtual-host"]
- self.model_name = config["server"]["model-name"]
- self.data_name = config["server"]["data-name"]
- self.total_clients = config["server"]["clients"]
- self.cut_layers = config["server"]["cut-layers"]
- self.global_round = config["server"]["global-round"]
- self.round = self.global_round
+ self.model_name = config["server"]["model-name"]
+ self.data_name = config["server"]["data-name"]
+ self.total_clients = config["server"]["clients"]
+ self.cut_layers = config["server"]["cut-layers"]
+ self.global_round = config["server"]["global-round"]
+ # Rounds still to run; on_request decrements it once every client is READY.
+ self.round = self.global_round
self.save_parameters = config["server"]["parameters"]["save"]
self.load_parameters = config["server"]["parameters"]["load"]
- self.validation = config["server"]["validation"]
-
- # Clients
- self.total_block = config["server"]["model"][self.model_name]["n_block"]
- self.batch_size = config["learning"]["batch-size"]
- self.lr = config["learning"]["learning-rate"]
- self.weight_decay = config["learning"]["weight-decay"]
- self.control_count = config["learning"]["control-count"]
- self.clip_grad_norm = config["learning"]["clip-grad-norm"]
+ self.validation = config["server"]["validation"]
+
+ # Training hyper-parameters forwarded to the clients in the START message.
+ self.total_block = config["server"]["model"][self.model_name]["n_block"]
+ self.batch_size = config["learning"]["batch-size"]
+ self.lr = config["learning"]["learning-rate"]
+ self.weight_decay = config["learning"]["weight-decay"]
+ self.control_count = config["learning"]["control-count"]
+ self.clip_grad_norm = config["learning"]["clip-grad-norm"]
self.data_distribution = config["server"]["data-distribution"]
- # Data distribution
- self.non_iid = self.data_distribution["non-iid"]
- self.num_label = self.data_distribution["num-label"]
- self.num_sample = self.data_distribution["num-sample"]
- self.refresh_each_round = self.data_distribution["refresh-each-round"]
- self.random_seed = config["server"]["random-seed"]
- self.label_counts = None
+ self.non_iid = self.data_distribution["non-iid"]
+ self.num_label = self.data_distribution["num-label"]
+ self.num_sample = self.data_distribution["num-sample"]
+ self.refresh_each_round = self.data_distribution.get("refresh-each-round", False)
+ self.random_seed = config["server"].get("random-seed", 1)
+
+ self.fine_tune_config = config["fine-tune"]
+ self.opt_config = config.get("optimization", {})
+ self.config = config
- # Fine tune config
- self.fine_tune_config = config['fine-tune']
+ # Model hyper-parameters; each falls back to the default given here when
+ # the key is absent from config["server"].
+ self.model_params = {
+ "vocab_size": config["server"].get("vocab_size", 50257),
+ "n_embd": config["server"].get("n_embd", 768),
+ "n_layer": config["server"].get("n_layer", 12),
+ "n_head": config["server"].get("n_head", 12),
+ "pretrained_path": config["server"].get("pretrained_path", f"{self.model_name}.pt"),
+ }
if self.random_seed:
random.seed(self.random_seed)
- log_path = config["log_path"]
-
+ log_path = config["log_path"]
credentials = pika.PlainCredentials(username, password)
self.connection = pika.BlockingConnection(
- pika.ConnectionParameters(address, 5672, f'{virtual_host}', credentials))
+ # NOTE(review): heartbeat=0 / blocked_connection_timeout=None are presumably
+ # meant to keep the connection open across long rounds — confirm.
+ pika.ConnectionParameters(
+ host=address, port=5672,
+ virtual_host=f"{virtual_host}",
+ credentials=credentials,
+ heartbeat=0,
+ blocked_connection_timeout=None,
+ )
+ )
self.channel = self.connection.channel()
- self.channel.queue_declare(queue='rpc_queue')
+ self.channel.queue_declare(queue="rpc_queue")
- self.count_update = [0 for _ in range(len(self.total_clients))]
+ # Counters of NOTIFY / READY messages received in the current round.
+ self.count_notify = 0
+ self.count_ready = 0
self.register_clients = [0 for _ in range(len(self.total_clients))]
- self.count_notify = 0
- self.responses = {}
- self.list_clients = []
- self.round_result = True
-
- self.global_model_parameters = [[] for _ in range(len(self.total_clients))]
- self.global_client_sizes = [[] for _ in range(len(self.total_clients))]
- self.avg_state_dict = []
+ self.responses = {}
+ self.list_clients = []
+ self.round_result = True
self.channel.basic_qos(prefetch_count=1)
self.reply_channel = self.connection.channel()
- self.channel.basic_consume(queue='rpc_queue', on_message_callback=self.on_request)
+ self.channel.basic_consume(queue="rpc_queue", on_message_callback=self.on_request)
debug_mode = config["debug_mode"]
self.logger = src.Log.Logger(f"{log_path}/app.log", debug_mode)
- self.logger.log_info(f"Application start. Server is waiting for {self.total_clients} clients.")
- src.Log.print_with_color(f"Application start. Server is waiting for {self.total_clients} clients.", "green")
+ self.logger.log_info(
+ f"Application start. Server waiting for {self.total_clients} clients."
+ )
+ src.Log.print_with_color(
+ f"Application start. Server waiting for {self.total_clients} clients.", "green"
+ )
+
+ # ── Helpers ───────────────────────────────────────────────────────────────
def distribution(self):
+ """Fill ``self.label_counts`` with a per-client, per-label sample-count table.
+
+ The table has ``sum(self.total_clients)`` rows and ``self.num_label``
+ columns. Non-IID: each row is a Dirichlet draw scaled by ``num_sample``
+ and truncated to int. Otherwise every cell is ``num_sample // num_label``.
+ """
+ num_clients = sum(self.total_clients)
if self.non_iid:
- label_distribution = np.random.dirichlet([self.data_distribution["dirichlet"]["alpha"]] * self.num_label,
- self.total_clients[0])
-
- self.label_counts = (label_distribution * self.num_sample).astype(int)
+ label_dist = np.random.dirichlet(
+ [self.data_distribution["dirichlet"]["alpha"]] * self.num_label,
+ num_clients
+ )
+ self.label_counts = (label_dist * self.num_sample).astype(int)
else:
- self.label_counts = np.full((self.total_clients[0], self.num_label), self.num_sample // self.num_label)
+ self.label_counts = np.full(
+ (num_clients, self.num_label),
+ self.num_sample // self.num_label
+ )
+
+ # ── Main handler ──────────────────────────────────────────────────────────
def on_request(self, ch, method, props, body):
+ """Handle one message consumed from ``rpc_queue``.
+
+ The body is unpickled and dispatched on ``message["action"]``:
+ REGISTER (start round 1 once enough clients registered), NOTIFY (send
+ PAUSE once every client finished training), READY (merge the layer files,
+ then start the next round or send STOP and exit) and UPDATE (deprecated,
+ ignored). The delivery is acknowledged at the end; if the ack raises,
+ the server reconnects.
+ """
- message = pickle.loads(body)
- routing_key = props.reply_to
- action = message["action"]
+ # NOTE(review): pickle.loads can execute arbitrary code from the payload;
+ # this is only safe if every publisher on rpc_queue is trusted.
+ message = pickle.loads(body)
+ action = message["action"]
client_id = message["client_id"]
- layer_id = message["layer_id"]
- self.responses[routing_key] = message
+ layer_id = message["layer_id"]
+ # ── REGISTER ──────────────────────────────────────────────────────────
if action == "REGISTER":
if (str(client_id), layer_id) not in self.list_clients:
self.list_clients.append((str(client_id), layer_id))
- src.Log.print_with_color(f"[<<<] Received message from client: {message}", "blue")
- # Save messages from clients
+ src.Log.print_with_color(
+ f"[<<<] REGISTER from client {client_id} layer {layer_id}", "blue"
+ )
+ # NOTE(review): the counter looks like it is bumped for every REGISTER,
+ # including repeats; with the >= check below a re-registering client
+ # would restart round 1 — confirm against the real indentation.
self.register_clients[layer_id - 1] += 1
- # If consumed all clients - Register for first time
- if self.register_clients == self.total_clients:
- src.Log.print_with_color("All clients are connected. Sending notifications.", "green")
-
+ if all(
+ self.register_clients[i] >= self.total_clients[i]
+ for i in range(len(self.total_clients))
+ ):
+ src.Log.print_with_color("All clients connected. Starting round 1.", "green")
self.distribution()
-
- self.logger.log_info(f"Start training round 1")
+ self.logger.log_info("Start training round 1")
self.notify_clients()
- elif action == "NOTIFY":
- src.Log.print_with_color(f"[<<<] Received message from client: {message}", "blue")
- message = {"action": "PAUSE",
- "message": "Pause training and please send your parameters",
- "parameters": None}
+ # The client reports that it finished training. The server sends PAUSE so the client knows it may save its LoRA.
+ elif action == "NOTIFY":
+ src.Log.print_with_color(
+ f"[<<<] NOTIFY from client {client_id} layer {layer_id}", "blue"
+ )
self.count_notify += 1
- if self.count_notify == self.total_clients[0]:
+ if self.count_notify == sum(self.total_clients):
self.count_notify = 0
- src.Log.print_with_color(f"Received all the finish training notification", "yellow")
-
- for (client_id, layer_id) in self.list_clients:
- self.send_to_response(client_id, pickle.dumps(message))
-
- elif action == "UPDATE":
- # self.distribution()
- data_message = message["message"]
- result = message["result"]
- src.Log.print_with_color(f"[<<<] Received message from {client_id}: {data_message}", "blue")
-
- self.count_update[layer_id - 1] += 1
- if not result:
- self.round_result = False
-
- # Save client's model parameters
- if self.save_parameters and self.round_result:
- model_state_dict = message["parameters"]
- client_size = message["size"]
- self.global_model_parameters[layer_id - 1].append(model_state_dict)
- self.global_client_sizes[layer_id - 1].append(client_size)
-
- # If consumed all client's parameters
- if self.count_update == self.total_clients:
- src.Log.print_with_color("Collected all parameters.", "yellow")
- if self.save_parameters and self.round_result:
-
- self.avg_all_parameters()
- self.global_model_parameters = [[] for _ in range(len(self.total_clients))]
- self.global_client_sizes = [[] for _ in range(len(self.total_clients))]
-
- self.count_update = [0 for _ in range(len(self.total_clients))]
- # Test
- if self.save_parameters and self.validation and self.round_result:
- state_dict_full = self.concatenate()
- self.avg_state_dict = []
- if not get_val(self.model_name, self.data_name, state_dict_full,self.logger):
- self.logger.log_warning("Training failed!")
- else:
- # Save to files
- torch.save(state_dict_full, f'{self.model_name}.pt')
- self.round -= 1
- else:
- self.round -= 1
-
- # Start a new training round
- self.round_result = True
+ current_round = self.global_round - self.round + 1
+ src.Log.print_with_color(
+ f"All clients finished training round {current_round}.", "yellow"
+ )
+
+ # Send PAUSE → the client saves its LoRA only after receiving it, then sends READY
+ pause_msg = {
+ "action": "PAUSE",
+ "message": f"Round {current_round} done. Save your LoRA.",
+ "current_round": current_round,
+ "parameters": None,
+ }
+ for (cid, lid) in self.list_clients:
+ self.send_to_response(cid, pickle.dumps(pause_msg))
+
+ self.logger.log_info(f"Round {current_round} complete. Waiting for READY.")
+
+ # The server only sends START for the next round once every READY has arrived.
+ elif action == "READY":
+ src.Log.print_with_color(
+ f"[<<<] READY from client {client_id} layer {layer_id}", "blue"
+ )
+ self.count_ready += 1
+
+ if self.count_ready == sum(self.total_clients):
+ self.count_ready = 0
+ self.round -= 1
+ current_round = self.global_round - self.round
+
+ src.Log.print_with_color(
+ f"All clients ready. Round {current_round} fully complete.", "green"
+ )
+ self.logger.log_info(f"Round {current_round} fully complete.")
+
+ # Merge GPT2_layer1.pt + GPT2_layer2.pt → GPT2.pt
+ self._merge_layer_files()
if self.round > 0:
- self.logger.log_info(f"Start training round {self.global_round - self.round + 1}")
- if self.save_parameters:
- self.notify_clients()
- else:
- self.notify_clients(register=False)
+ next_round = self.global_round - self.round + 1
+ self.logger.log_info(f"Start training round {next_round}")
+ self.notify_clients()
else:
self.logger.log_info("Stop training !!!")
self.notify_clients(start=False)
sys.exit()
- ch.basic_ack(delivery_tag=method.delivery_tag)
+ elif action == "UPDATE":
+ src.Log.print_with_color(
+ f"[WARN] Deprecated UPDATE from client {client_id} — ignored.", "yellow"
+ )
- def notify_clients(self, start=True, register=True):
+ try:
+ ch.basic_ack(delivery_tag=method.delivery_tag)
+ except Exception as e:
+ src.Log.print_with_color(f"[WARN] basic_ack failed: {e}", "yellow")
+ self._reconnect()
- # Send message to clients when consumed all clients
- if self.model_name == 'GPT2':
- klass = GPT2
- elif self.model_name == 'Llama':
- klass = Llama
- elif self.model_name == 'Bert':
- klass = Bert
- else:
- klass = globals()[f'{self.model_name}']
+ def notify_clients(self, start=True):
+ """Send START, or STOP when ``start`` is False, to every registered client.
+
+ The START message carries the training hyper-parameters held on the
+ server; its "parameters" entry is always None.
+ """
for (client_id, layer_id) in self.list_clients:
- # Read parameters file
- filepath = f'{self.model_name}.pt'
- state_dict = None
-
- if start:
- if self.load_parameters and register:
- if os.path.exists(filepath):
- full_state_dict = torch.load(filepath, weights_only=True)
-
- if layer_id == 1:
- model = klass(layer_id=1, n_block=self.cut_layers)
- state_dict = model.state_dict()
- keys = state_dict.keys()
-
- for key in keys:
- state_dict[key] = full_state_dict[key]
-
- else:
- model = klass(layer_id=2, n_block=self.total_block - self.cut_layers)
- state_dict = model.state_dict()
- state_dict = src.Utils.change_keys(state_dict, self.cut_layers, True)
- keys = state_dict.keys()
-
- for key in keys:
- state_dict[key] = full_state_dict[key]
-
- state_dict =src.Utils.change_keys(state_dict, self.cut_layers, False)
- src.Log.print_with_color(f"Load pretrain model successfully", "green")
-
- else:
- src.Log.print_with_color(f"File {filepath} does not exist.", "yellow")
- self.logger.log_info(f"File {filepath} does not exist.")
-
- src.Log.print_with_color(f"[>>>] Sent start training request to client {client_id}", "red")
-
- response = {"action": "START",
- "message": "Server accept the connection!",
- "parameters": copy.deepcopy(state_dict),
- "cut_layers": self.cut_layers,
- "total_block": self.total_block,
- "model_name": self.model_name,
- "data_name": self.data_name,
- "num_sample": self.num_sample,
- "control_count": self.control_count,
- "batch_size": self.batch_size,
- "lr": self.lr,
- "weight_decay": self.weight_decay,
- "clip_grad_norm": self.clip_grad_norm,
- "fine_tune_config": self.fine_tune_config
- }
- self.send_to_response(client_id, pickle.dumps(response))
-
+ if not start:
+ self.send_to_response(
+ client_id,
+ pickle.dumps({"action": "STOP", "message": "Stop training!", "parameters": None})
+ )
+ src.Log.print_with_color(f"[>>>] STOP → client {client_id}", "red")
+ continue
+ response = {
+ "action": "START",
+ "message": "Server accept the connection!",
+ "parameters": None,
+ "cut_layers": self.cut_layers,
+ "total_block": self.total_block,
+ "model_name": self.model_name,
+ "data_name": self.data_name,
+ "num_sample": self.num_sample,
+ "control_count": self.control_count,
+ "batch_size": self.batch_size,
+ "lr": self.lr,
+ "weight_decay": self.weight_decay,
+ "clip_grad_norm": self.clip_grad_norm,
+ "fine_tune_config": self.fine_tune_config,
+ "opt_config": self.opt_config,
+ "refresh_each_round": self.refresh_each_round,
+ }
+ src.Log.print_with_color(
+ f"[>>>] START → client {client_id} layer {layer_id}", "red"
+ )
+ self.send_to_response(client_id, pickle.dumps(response))
+
+
+    def _merge_layer_files(self):
+        """Merge the two per-layer checkpoints into one full checkpoint.
+
+        Reads ``{model_name}_layer1.pt`` and ``{model_name}_layer2.pt`` from the
+        working directory and writes ``{model_name}.pt``. Layer-1 keys are
+        copied unchanged. Layer-2 block keys are renumbered back to global
+        indices (``h.0`` → ``h.{cut_layers}``, ``h.1`` → ``h.{cut_layers + 1}``,
+        ...); its other keys are copied unchanged. When the merged dict has
+        ``wte.weight`` but no ``lm_head.weight``, the latter is filled with a
+        copy of the former.
+
+        The merge is skipped with a warning when either input file is missing.
+        """
+        f1 = f"{self.model_name}_layer1.pt"
+        f2 = f"{self.model_name}_layer2.pt"
+
+        if not os.path.exists(f1) or not os.path.exists(f2):
+            src.Log.print_with_color(
+                f"[WARN] Merge skipped: {f1} exists={os.path.exists(f1)}, "
+                f"{f2} exists={os.path.exists(f2)}", "yellow"
+            )
+            return
+
+        # weights_only=True restricts unpickling to tensors and plain containers,
+        # so a tampered checkpoint file cannot execute code while being loaded.
+        sd1 = torch.load(f1, map_location="cpu", weights_only=True)
+        sd2 = torch.load(f2, map_location="cpu", weights_only=True)
+
+        merged = {}
+
+        for k, v in sd1.items():
+            merged[k] = v
+
+        for k, v in sd2.items():
+            if k.startswith("h."):
+                # "h.<local idx>.<rest>" → "h.<local idx + cut_layers>.<rest>"
+                parts = k.split(".")
+                old_idx = int(parts[1])
+                new_idx = old_idx + self.cut_layers
+                new_k = ".".join([parts[0], str(new_idx)] + parts[2:])
+                merged[new_k] = v
             else:
- src.Log.print_with_color(f"[>>>] Sent stop training request to client {client_id}", "red")
- response = {"action": "STOP",
- "message": "Stop training!",
- "parameters": None}
- self.send_to_response(client_id, pickle.dumps(response))
+                merged[k] = v
- def start(self):
- self.channel.start_consuming()
- def send_to_response(self, client_id, message):
- reply_queue_name = f'reply_{client_id}'
- self.reply_channel.queue_declare(reply_queue_name, durable=False)
+        if "lm_head.weight" not in merged and "wte.weight" in merged:
+            merged["lm_head.weight"] = merged["wte.weight"].clone()
- src.Log.print_with_color(f"[>>>] Sent notification to client {client_id}", "red")
- self.reply_channel.basic_publish(
- exchange='',
- routing_key=reply_queue_name,
- body=message
+        out_file = f"{self.model_name}.pt"
+        torch.save(merged, out_file)
+        src.Log.print_with_color(
+            f"[>>>] Merged {f1} + {f2} → {out_file} "
+            f"({len(merged)} keys)", "green"
         )
+        self.logger.log_info(f"Merged layer files → {out_file}")
+
+    def _reconnect(self):
+        """Drop the current RabbitMQ connection and build a fresh one.
+
+        Closing the old connection is best-effort. The new connection uses the
+        settings under ``self.config["rabbit"]``; both channels are recreated,
+        ``rpc_queue`` is re-declared and ``on_request`` is registered again as
+        its consumer.
+        """
+        src.Log.print_with_color("[>>>] Reconnecting to RabbitMQ...", "yellow")
+        try:
+            self.connection.close()
+        except Exception:
+            # The old connection may already be dead; nothing to clean up then.
+            pass
+
+        rabbit_cfg = self.config["rabbit"]
+        credentials = pika.PlainCredentials(rabbit_cfg["username"], rabbit_cfg["password"])
+        params = pika.ConnectionParameters(
+            host=rabbit_cfg["address"],
+            port=5672,
+            virtual_host=rabbit_cfg["virtual-host"],
+            credentials=credentials,
+            heartbeat=0,
+            blocked_connection_timeout=None,
+        )
+        self.connection = pika.BlockingConnection(params)
+
+        self.channel = self.connection.channel()
+        self.reply_channel = self.connection.channel()
+        consumer_channel = self.channel
+        consumer_channel.queue_declare(queue="rpc_queue")
+        consumer_channel.basic_qos(prefetch_count=1)
+        consumer_channel.basic_consume(queue="rpc_queue", on_message_callback=self.on_request)
+        src.Log.print_with_color("[>>>] Reconnected successfully.", "green")
- def avg_all_parameters(self):
- layer_sizes = self.global_client_sizes
- layer_params = self.global_model_parameters
-
- for layer_idx, list_state_dicts in enumerate(layer_params):
- list_sizes = layer_sizes[layer_idx]
- if not list_state_dicts or not list_sizes:
- self.avg_state_dict.append({})
- continue
- avg_sd = src.Utils.fed_avg_state_dicts(list_state_dicts, weights=list_sizes)
- self.avg_state_dict.append(avg_sd)
-
- def concatenate(self):
- avg_layers = self.avg_state_dict
- if not avg_layers:
- print(f"Warning: don't has averaged layers, skipping.")
-
- full_dict = {}
- for idx, layer_dict in enumerate(avg_layers):
- if idx == 0:
- sd = layer_dict
- full_dict.update(copy.deepcopy(sd))
- else:
- sd = src.Utils.change_keys(layer_dict, self.cut_layers, True)
- full_dict.update(copy.deepcopy(sd))
+    def start(self):
+        """Consume from ``rpc_queue`` forever, reconnecting after any error.
+
+        Any ``Exception`` raised while consuming is logged and followed by a
+        reconnect attempt; when the reconnect itself fails, the method sleeps
+        five seconds before the loop tries again.
+        """
+        while True:
+            try:
+                self.channel.start_consuming()
+            # Exception already covers pika's channel/connection errors.
+            except Exception as err:
+                src.Log.print_with_color(
+                    f"[WARN] Connection error: {err} — reconnecting...", "yellow"
+                )
+                try:
+                    self._reconnect()
+                except Exception as reconnect_err:
+                    src.Log.print_with_color(f"[ERROR] Reconnect failed: {reconnect_err}", "red")
+                    import time
+                    time.sleep(5)
- return full_dict
+    def send_to_response(self, client_id, message):
+        """Publish the ``message`` body to the queue ``reply_{client_id}``.
+
+        The queue is declared (non-durable) on the reply channel before
+        publishing; ``message`` is sent as-is.
+        """
+        target_queue = f"reply_{client_id}"
+        reply_ch = self.reply_channel
+        reply_ch.queue_declare(target_queue, durable=False)
+        reply_ch.basic_publish(exchange="", routing_key=target_queue, body=message)
\ No newline at end of file
diff --git a/src/Utils.py b/src/Utils.py
index 8f76063..258dca2 100644
--- a/src/Utils.py
+++ b/src/Utils.py
@@ -49,7 +49,7 @@ def change_keys(state_dict, num, increase=True):
return new_state_dict
-def fed_avg_state_dicts(state_dicts, weights = None):
+def fed_avg_state_dicts(state_dicts, weights=None):
num = len(state_dicts)
if num == 0:
raise ValueError("fed_avg_state_dicts: don't have any state_dict.")
@@ -57,25 +57,35 @@ def fed_avg_state_dicts(state_dicts, weights = None):
if weights is None:
weights = [1.0] * num
total_w = sum(weights)
+ if total_w == 0:
+ raise ValueError("fed_avg_state_dicts: tổng weight = 0.")
all_keys = set().union(*(sd.keys() for sd in state_dicts))
avg_dict = {}
for key in all_keys:
+ acc = None
+ acc_w = 0.0 # Fix: theo dõi tổng weight thực sự đóng góp cho key này
- acc = None
for sd, w in zip(state_dicts, weights):
if key not in sd:
continue
t = sd[key].float()
if torch.isnan(t).any():
- t = torch.nan_to_num(t) # zero-fill
+ t = torch.nan_to_num(t, nan=0.0)
t = t * w
- acc = t if acc is None else acc + t
+ acc = t if acc is None else acc + t
+ acc_w += w
- avg = acc / total_w
+ # Fix: if the key is not present in any state_dict → skip it
+ if acc is None or acc_w == 0:
+ continue
+
+ avg = acc / acc_w # Fix: chia cho tổng weight thực tế của key, không total_w
- orig = next(sd[key] for sd in state_dicts if key in sd)
+ orig = next((sd[key] for sd in state_dicts if key in sd), None)
+ if orig is None:
+ continue
if orig.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.bool):
avg = avg.round().to(orig.dtype)
else:
diff --git a/src/__pycache__/Log.cpython-313.pyc b/src/__pycache__/Log.cpython-313.pyc
new file mode 100644
index 0000000..30bbf19
Binary files /dev/null and b/src/__pycache__/Log.cpython-313.pyc differ
diff --git a/src/__pycache__/Optimizer.cpython-310.pyc b/src/__pycache__/Optimizer.cpython-310.pyc
new file mode 100644
index 0000000..4b5eb8c
Binary files /dev/null and b/src/__pycache__/Optimizer.cpython-310.pyc differ
diff --git a/src/__pycache__/Optimizer.cpython-313.pyc b/src/__pycache__/Optimizer.cpython-313.pyc
new file mode 100644
index 0000000..e7740b9
Binary files /dev/null and b/src/__pycache__/Optimizer.cpython-313.pyc differ
diff --git a/src/__pycache__/RpcClient.cpython-313.pyc b/src/__pycache__/RpcClient.cpython-313.pyc
new file mode 100644
index 0000000..f5a8bd4
Binary files /dev/null and b/src/__pycache__/RpcClient.cpython-313.pyc differ
diff --git a/src/__pycache__/Server.cpython-313.pyc b/src/__pycache__/Server.cpython-313.pyc
new file mode 100644
index 0000000..6a59da5
Binary files /dev/null and b/src/__pycache__/Server.cpython-313.pyc differ
diff --git a/src/__pycache__/Utils.cpython-313.pyc b/src/__pycache__/Utils.cpython-313.pyc
new file mode 100644
index 0000000..d84d518
Binary files /dev/null and b/src/__pycache__/Utils.cpython-313.pyc differ
diff --git a/src/__pycache__/__init__.cpython-310.pyc b/src/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000..e0edd48
Binary files /dev/null and b/src/__pycache__/__init__.cpython-310.pyc differ
diff --git a/src/__pycache__/__init__.cpython-313.pyc b/src/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000..ff6e90a
Binary files /dev/null and b/src/__pycache__/__init__.cpython-313.pyc differ
diff --git a/src/convert.py b/src/convert.py
new file mode 100644
index 0000000..f916861
--- /dev/null
+++ b/src/convert.py
@@ -0,0 +1,13 @@
+import torch
+from transformers import GPT2LMHeadModel
+
+ # load HF model
+ model = GPT2LMHeadModel.from_pretrained("./gpt2_e2e_finetuned")
+
+ # get the state_dict
+ state_dict = model.state_dict()
+
+ # save it in the format the server uses
+ torch.save(state_dict, "GPT2.pt")
+
+ print("Saved GPT2.pt")
\ No newline at end of file
diff --git a/src/dataset/__pycache__/EMOTION.cpython-313.pyc b/src/dataset/__pycache__/EMOTION.cpython-313.pyc
new file mode 100644
index 0000000..30f2083
Binary files /dev/null and b/src/dataset/__pycache__/EMOTION.cpython-313.pyc differ
diff --git a/src/dataset/__pycache__/GSM8K.cpython-313.pyc b/src/dataset/__pycache__/GSM8K.cpython-313.pyc
new file mode 100644
index 0000000..1a9910e
Binary files /dev/null and b/src/dataset/__pycache__/GSM8K.cpython-313.pyc differ
diff --git a/src/dataset/__pycache__/__init__.cpython-313.pyc b/src/dataset/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000..516a025
Binary files /dev/null and b/src/dataset/__pycache__/__init__.cpython-313.pyc differ
diff --git a/src/dataset/__pycache__/dataloader.cpython-313.pyc b/src/dataset/__pycache__/dataloader.cpython-313.pyc
new file mode 100644
index 0000000..d60545d
Binary files /dev/null and b/src/dataset/__pycache__/dataloader.cpython-313.pyc differ
diff --git a/src/dataset/dataloader.py b/src/dataset/dataloader.py
index fbfe2af..cfbd983 100644
--- a/src/dataset/dataloader.py
+++ b/src/dataset/dataloader.py
@@ -12,6 +12,64 @@
from torch.utils.data import DataLoader
def dataloader(model_name =None, data_name=None, batch_size=None, distribution=500, train=True):
+ if data_name == 'E2E':
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ tokenizer.pad_token = tokenizer.eos_token
+ path = os.path.join("data/", "e2e_train.jsonl")
+
+ with open(path, "r", encoding="utf-8") as f:
+ data = [json.loads(line) for line in f if line.strip()]
+
+ random.shuffle(data)
+
+ from torch.utils.data import Dataset
+
+        class E2EDataset(Dataset):
+            """Tokenized prompt/target pairs for causal-LM fine-tuning.
+
+            Each item is a dict with ``input_ids``, ``attention_mask`` and
+            ``labels``, all of length ``max_length``. ``labels`` equals
+            ``input_ids`` on the target tokens and is -100 on the prompt and
+            on padding, so only the target contributes to the loss.
+            """
+
+            def __init__(self, tokenizer, data, max_length=512):
+                self.tokenizer = tokenizer
+                self.data = data
+                self.max_length = max_length
+
+            def __len__(self):
+                return len(self.data)
+
+            def __getitem__(self, idx):
+                ex = self.data[idx]
+
+                prompt = f" {ex['input']} "
+                target = ex["output"] + " <|endoftext|>"
+
+                full = prompt + " " + target
+
+                enc = self.tokenizer(
+                    full,
+                    truncation=True,
+                    padding="max_length",
+                    max_length=self.max_length,
+                    return_tensors="pt"
+                )
+
+                input_ids = enc["input_ids"].squeeze(0)
+                attention_mask = enc["attention_mask"].squeeze(0)
+
+                labels = input_ids.clone()
+
+                prompt_len = len(
+                    self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
+                )
+                labels[:prompt_len] = -100
+                # Ignore padding as well. It is masked through attention_mask
+                # rather than by token id, so a real end-of-text token in the
+                # target keeps its label even if the pad token has the same id.
+                labels[attention_mask == 0] = -100
+
+                return {
+                    "input_ids": input_ids,
+                    "attention_mask": attention_mask,
+                    "labels": labels
+                }
+
+ train_set = E2EDataset(tokenizer, data)
+ train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
+
+ return train_loader
if data_name == 'GSM8K':
if model_name == 'GPT2':
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
@@ -34,7 +92,50 @@ def dataloader(model_name =None, data_name=None, batch_size=None, distribution=5
print(f"{len(train_set)} train examples")
- train_set = GSM8K(tokenizer, train_set, False)
+ from torch.utils.data import Dataset
+
+            class E2EDataset(Dataset):
+                """Tokenized question/answer pairs for causal-LM fine-tuning.
+
+                Each item is a dict with ``input_ids``, ``attention_mask`` and
+                ``labels``, all of length ``max_length``. ``labels`` equals
+                ``input_ids`` on the answer tokens and is -100 on the prompt
+                and on padding, so only the answer contributes to the loss.
+                """
+
+                def __init__(self, tokenizer, data, max_length=512):
+                    self.tokenizer = tokenizer
+                    self.data = data
+                    self.max_length = max_length
+
+                def __len__(self):
+                    return len(self.data)
+
+                def __getitem__(self, idx):
+                    ex = self.data[idx]
+
+                    prompt = f"### Input:\n{ex['question']}\n\n### Output:\n"
+                    target = ex["answer"] + " <|endoftext|>"
+
+                    full = prompt + target
+
+                    enc = self.tokenizer(
+                        full,
+                        truncation=True,
+                        padding="max_length",
+                        max_length=self.max_length,
+                        return_tensors="pt"
+                    )
+
+                    input_ids = enc["input_ids"].squeeze(0)
+                    attention_mask = enc["attention_mask"].squeeze(0)
+
+                    labels = input_ids.clone()
+
+                    # The prompt is tokenized once, only to learn its length.
+                    prompt_len = len(self.tokenizer(prompt, add_special_tokens=False)["input_ids"])
+
+                    labels[:prompt_len] = -100
+                    # Ignore padding as well; masking by attention_mask keeps the
+                    # label of a real end-of-text token inside the answer.
+                    labels[attention_mask == 0] = -100
+
+                    return {
+                        "input_ids": input_ids,
+                        "attention_mask": attention_mask,
+                        "labels": labels
+                    }
+ train_set = E2EDataset(tokenizer, train_set)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
return train_loader
else:
@@ -55,6 +156,28 @@ def dataloader(model_name =None, data_name=None, batch_size=None, distribution=5
return test_loader
if data_name == 'EMOTION':
+
+ dataset = load_dataset(
+ 'emotion',
+ download_mode='reuse_dataset_if_exists',
+ cache_dir='./hf_cache'
+ )
+ tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+ if train:
+ # EMOTION has 6 classes (sadness/joy/love/anger/fear/surprise)
+ num_label = int(distribution / 6)
+ distribution = [num_label] * 6
+ train_texts, train_labels = load_train_EMOTION(dataset, distribution)
+ train_set = EMOTIONDataset(train_texts, train_labels, tokenizer, max_length=128)
+ train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
+ return train_loader
+ else:
+ test_texts, test_label = load_test_EMOTION(2000, dataset)
+ test_set = EMOTIONDataset(test_texts, test_label, tokenizer, max_length=128)
+ test_loader = DataLoader(test_set, batch_size=100, shuffle=False)
+ return test_loader
+
+ elif data_name == 'AG_NEWS':
dataset = load_dataset(
'ag_news',
download_mode='reuse_dataset_if_exists',
@@ -63,7 +186,7 @@ def dataloader(model_name =None, data_name=None, batch_size=None, distribution=5
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
if train:
num_label = int(distribution / 4)
- distribution = [num_label, num_label, num_label, num_label]
+ distribution = [num_label] * 4
train_texts, train_labels = load_train_EMOTION(dataset, distribution)
train_set = EMOTIONDataset(train_texts, train_labels, tokenizer, max_length=128)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
@@ -72,4 +195,4 @@ def dataloader(model_name =None, data_name=None, batch_size=None, distribution=5
test_texts, test_label = load_test_EMOTION(2000, dataset)
test_set = EMOTIONDataset(test_texts, test_label, tokenizer, max_length=128)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False)
- return test_loader
+ return test_loader
\ No newline at end of file
diff --git a/src/fine_tune/Bert.py b/src/fine_tune/Bert.py
index 4154a40..dcb359c 100644
--- a/src/fine_tune/Bert.py
+++ b/src/fine_tune/Bert.py
@@ -7,229 +7,249 @@
import torch.nn as nn
import src.Log
+from src.Optimizer import OptimizationBundle
+
+
class Ft_Bert:
def __init__(self, client_id, layer_id, channel, device):
+ """Store the client identity, the message channel and the device."""
- self.client_id = client_id
- self.layer_id = layer_id
- self.channel = channel
- self.device = device
+ self.client_id = client_id
+ self.layer_id = layer_id
+ self.channel = channel
+ self.device = device
self.data_count = 0
- self.size = None
-
- def send_intermediate_output(self, data_id, output, labels, trace):
-
- forward_queue_name = f'intermediate_queue_{self.layer_id}'
- self.channel.queue_declare(forward_queue_name, durable=False)
-
- if trace:
- trace.append(self.client_id)
- message = pickle.dumps(
- {"data_id": data_id, "data": output.detach().cpu().numpy(), "label": labels.cpu(), "trace": trace}
- )
- else:
- message = pickle.dumps(
- {"data_id": data_id, "data": output.detach().cpu().numpy(), "label": labels.cpu(), "trace": [self.client_id]}
- )
+ # Byte length of the first pickled message; set lazily by the send methods.
+ self.size = None
+
+ # ── Communication ────────────────────────────────────────
+
+ def send_intermediate_output(self, data_id, q_numpy, scale, labels, trace):
+ """Publish one forward-pass payload to ``intermediate_queue_{layer_id}``.
+
+ The pickled message carries ``data_id``, ``q_numpy`` (as "data"), ``scale``,
+ the labels moved to CPU, and the trace extended with this client's id.
+ The caller's ``trace`` list is not modified.
+ NOTE(review): q_numpy/scale are presumably a quantized activation array
+ and its scale factor — confirm with the caller.
+ """
+ fwd_q = f"intermediate_queue_{self.layer_id}"
+ self.channel.queue_declare(fwd_q, durable=False)
+ # Parsed as (list(trace) + [id]) if trace else [id].
+ trace_out = list(trace) + [self.client_id] if trace else [self.client_id]
+ msg = pickle.dumps({
+ "data_id": data_id,
+ "data": q_numpy,
+ "scale": scale,
+ "label": labels.cpu(),
+ "trace": trace_out,
+ })
if self.size is None:
- self.size = len(message)
- print(f'Length message: {self.size} (bytes).')
- self.channel.basic_publish(
- exchange='',
- routing_key=forward_queue_name,
- body=message
- )
+ self.size = len(msg)
+ print(f"Length message: {self.size} bytes.")
+ self.channel.basic_publish(exchange="", routing_key=fwd_q, body=msg)
def send_gradient(self, data_id, gradient, trace):
- to_client_id = trace[-1]
- trace.pop(-1)
- backward_queue_name = f'gradient_queue_{self.layer_id - 1}_{to_client_id}'
- self.channel.queue_declare(queue=backward_queue_name, durable=False)
-
- message = pickle.dumps(
- {"data_id": data_id, "data": gradient.detach().cpu().numpy(), "trace": trace})
-
+ to_id = trace[-1]
+ trace = trace[:-1]
+ bwd_q = f"gradient_queue_{self.layer_id - 1}_{to_id}"
+ self.channel.queue_declare(queue=bwd_q, durable=False)
+ msg = pickle.dumps({
+ "data_id": data_id,
+ "data": gradient.detach().cpu().numpy(),
+ "trace": trace,
+ })
if self.size is None:
- self.size = len(message)
- print(f'Length message: {self.size} (bytes).')
+ self.size = len(msg)
+ self.channel.basic_publish(exchange="", routing_key=bwd_q, body=msg)
+
+ def send_to_server(self, message):
+ self.channel.queue_declare("rpc_queue", durable=False)
self.channel.basic_publish(
- exchange='',
- routing_key=backward_queue_name,
- body=message
+ exchange="", routing_key="rpc_queue", body=pickle.dumps(message)
)
- def send_to_server(self, message):
- self.channel.queue_declare('rpc_queue', durable=False)
- self.channel.basic_publish(exchange='',
- routing_key='rpc_queue',
- body=pickle.dumps(message))
+ # ── Layer 1 ───────────────────────────────────────────────
+
+ def first_layer(self, model, lr, weight_decay, clip_grad_norm,
+ control_count=1, train_loader=None,
+ opt: OptimizationBundle = None):
- def first_layer(self, model, lr, weight_decay, clip_grad_norm, control_count=1, train_loader=None):
- optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+ if opt is None:
+ opt = OptimizationBundle({})
- backward_queue_name = f'gradient_queue_{self.layer_id}_{self.client_id}'
- self.channel.queue_declare(queue=backward_queue_name, durable=False)
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+ bwd_q_name = f"gradient_queue_{self.layer_id}_{self.client_id}"
+ self.channel.queue_declare(queue=bwd_q_name, durable=False)
self.channel.basic_qos(prefetch_count=1)
model = model.to(self.device)
- data_iter = iter(train_loader)
- forward = []
- backward = []
- comm = []
- num_forward = 0
- num_backward = 0
- end_data = False
- data_store = {}
+ data_iter = iter(train_loader)
+ forward_t = []
+ backward_t = []
+ comm_t = []
+ num_fwd = num_bwd = 0
+ end_data = False
+ data_store = {} # {uuid: input_ids_tensor}
with tqdm(total=len(train_loader), desc="Processing", unit="step") as pbar:
while True:
- # Training model
model.train()
- optimizer.zero_grad()
- # Process gradient
- method_frame, header_frame, body = self.channel.basic_get(queue=backward_queue_name, auto_ack=True)
+ # ── Backward pass nếu có gradient ──────────
+ method_frame, _, body = self.channel.basic_get(
+ queue=bwd_q_name, auto_ack=True
+ )
if method_frame and body:
- num_backward += 1
- received_data = pickle.loads(body)
- gradient_numpy = received_data["data"]
- gradient = torch.tensor(gradient_numpy).to(self.device)
- data_id = received_data["data_id"]
-
- data_input = data_store.pop(data_id)
- start_backward = time.time()
- output = model(input_ids=data_input)
+ num_bwd += 1
+ recv = pickle.loads(body)
+ gradient = torch.tensor(recv["data"]).to(self.device)
+ data_id = recv["data_id"]
+ data_input = data_store.pop(data_id) # Fix Bug 7: xóa ngay sau dùng
+
+ t0 = time.time()
+ optimizer.zero_grad()
+ with opt.precision_ctx(self.device):
+ output = model(input_ids=data_input)
+
+ # Fix Bug 1: chỉ gọi backward 1 lần với gradient từ layer sau
+ # KHÔNG gọi opt.step() ở đây vì đây là split backward
output.backward(gradient=gradient)
- optimizer.step()
- time_backward = time.time() - start_backward
- backward.append(time_backward)
+ if clip_grad_norm > 0:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
+
+ # Fix Bug 1: scaler.step nếu FP16, không thì optimizer.step thường
+ if opt.scaler is not None:
+ opt.scaler.step(optimizer)
+ opt.scaler.update()
+ else:
+ optimizer.step()
+
+ backward_t.append(time.time() - t0)
+
else:
- # speed control
+ # Fix Bug 7: giới hạn data_store để tránh OOM
if len(data_store) > control_count:
continue
-
try:
- batch = next(data_iter)
- input_ids = batch['input_ids'].to(self.device)
- labels = batch['labels'].to(self.device)
+ batch = next(data_iter)
+ inp_ids = batch["input_ids"].to(self.device)
+ labels = batch["labels"].to(self.device)
data_id = uuid.uuid4()
- data_store[data_id] = input_ids
+ data_store[data_id] = inp_ids
- start_forward = time.time()
- intermediate_output = model(input_ids=input_ids)
- time_forward = time.time() - start_forward
- forward.append(time_forward)
- intermediate_output = intermediate_output.detach().requires_grad_(True)
+ t0 = time.time()
+ with opt.precision_ctx(self.device):
+ inter_out = model(input_ids=inp_ids)
+ forward_t.append(time.time() - t0)
- num_forward += 1
+ inter_out = inter_out.detach().requires_grad_(True)
+ num_fwd += 1
self.data_count += 1
-
pbar.update(1)
- start_comm = time.time()
- self.send_intermediate_output(data_id, intermediate_output, labels, trace=None)
- time_comm = time.time() - start_comm
- comm.append(time_comm)
+
+ t0 = time.time()
+ q_numpy, scale = opt.quant(inter_out)
+ self.send_intermediate_output(data_id, q_numpy, scale, labels, trace=None)
+ comm_t.append(time.time() - t0)
except StopIteration:
end_data = True
- if end_data and (num_forward == num_backward):
+ if end_data and num_fwd == num_bwd:
break
-
- notify_data = {"action": "NOTIFY", "client_id": self.client_id, "layer_id": self.layer_id,
- "message": "Finish training!"}
-
+ notify = {
+ "action": "NOTIFY", "client_id": self.client_id,
+ "layer_id": self.layer_id, "message": "Finish training!",
+ }
src.Log.print_with_color("[>>>] Finish training!", "red")
- self.send_to_server(notify_data)
+ self.send_to_server(notify)
- broadcast_queue_name = f'reply_{self.client_id}'
+ bcast_q = f"reply_{self.client_id}"
while True:
- method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True)
+ _, _, body = self.channel.basic_get(queue=bcast_q, auto_ack=True)
if body:
- received_data = pickle.loads(body)
- src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue")
- if received_data["action"] == "PAUSE":
- print(f'Forward time: {forward}s.')
- print(f'Backward time: {backward}s.')
- print(f'Comm time: {comm}s.')
+ recv = pickle.loads(body)
+ src.Log.print_with_color(f"[<<<] {recv}", "blue")
+ if recv["action"] == "PAUSE":
+ print(f"Forward: {forward_t}")
+ print(f"Backward: {backward_t}")
+ print(f"Comm: {comm_t}")
return True, self.data_count
time.sleep(0.5)
- def last_layer(self, model, lr, weight_decay, clip_grad_norm):
- optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
- criterion = nn.CrossEntropyLoss()
- result = True
+ # ── Layer 2 ───────────────────────────────────────────────
+
+ def last_layer(self, model, lr, weight_decay, clip_grad_norm,
+ opt: OptimizationBundle = None):
+
+ if opt is None:
+ opt = OptimizationBundle({})
- forward_queue_name = f'intermediate_queue_{self.layer_id - 1}'
- self.channel.queue_declare(queue=forward_queue_name, durable=False)
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+ criterion = nn.CrossEntropyLoss()
+ result = True
+
+ fwd_q_name = f"intermediate_queue_{self.layer_id - 1}"
+ self.channel.queue_declare(queue=fwd_q_name, durable=False)
self.channel.basic_qos(prefetch_count=1)
- print('Waiting for intermediate output. To exit press CTRL+C')
+ print("Waiting for intermediate output. To exit press CTRL+C")
model.to(self.device)
model.train()
- execute = []
- comm = []
+
+ exec_t = []
+ comm_t = []
while True:
- method_frame, header_frame, body = self.channel.basic_get(queue=forward_queue_name, auto_ack=True)
+ method_frame, _, body = self.channel.basic_get(
+ queue=fwd_q_name, auto_ack=True
+ )
if method_frame and body:
-
optimizer.zero_grad()
- received_data = pickle.loads(body)
- intermediate_output_numpy = received_data["data"]
- trace = received_data["trace"]
- data_id = received_data["data_id"]
- labels = received_data["label"].to(self.device)
- start_exec = time.time()
- intermediate_output = torch.tensor(intermediate_output_numpy, requires_grad=True).to(self.device)
+ recv = pickle.loads(body)
+ trace = recv["trace"]
+ data_id = recv["data_id"]
+ labels = recv["label"].to(self.device)
- output = model(input_ids=intermediate_output)
+ inter_out = opt.dequant(
+ recv["data"], recv.get("scale"), self.device, requires_grad=True
+ )
- loss = criterion(output, labels)
+ t0 = time.time()
+ with opt.precision_ctx(self.device):
+ output = model(input_ids=inter_out)
+ loss = criterion(output, labels)
- if torch.isnan(loss).any():
+ if torch.isnan(loss):
src.Log.print_with_color("NaN detected in loss", "yellow")
result = False
- print(f"Loss: {loss.item()}")
- intermediate_output.retain_grad()
- loss.backward()
+ print(f"Loss: {loss.item():.4f}")
+
+ # Fix Bug 3: retain_grad() TRƯỚC backward
+ inter_out.retain_grad()
+
+ # Fix Bug 2: backward + clip + step chỉ 1 lần, không lồng opt.step
+ if opt.scaler is not None:
+ opt.scaler.scale(loss).backward()
+ opt.scaler.unscale_(optimizer)
+ if clip_grad_norm > 0:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
+ opt.scaler.step(optimizer)
+ opt.scaler.update()
+ else:
+ loss.backward()
+ if clip_grad_norm > 0:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
+ optimizer.step()
- optimizer.step()
- time_exec = time.time() - start_exec
- execute.append(time_exec)
+ exec_t.append(time.time() - t0)
self.data_count += 1
- gradient = intermediate_output.grad
- start_comm = time.time()
- self.send_gradient(data_id, gradient, trace)
- time_comm = time.time() - start_comm
- comm.append(time_comm)
+ # Fix Bug 3: inter_out.grad giờ đã có sau backward
+ t0 = time.time()
+ self.send_gradient(data_id, inter_out.grad, trace)
+ comm_t.append(time.time() - t0)
else:
- broadcast_queue_name = f'reply_{self.client_id}'
- method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True)
+ bcast_q = f"reply_{self.client_id}"
+ _, _, body = self.channel.basic_get(queue=bcast_q, auto_ack=True)
if body:
- received_data = pickle.loads(body)
- src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue")
- if received_data["action"] == "PAUSE":
- print(f'Forward + Backward: {execute}s')
- print(f'Comm time: {comm}s')
+ recv = pickle.loads(body)
+ src.Log.print_with_color(f"[<<<] {recv}", "blue")
+ if recv["action"] == "PAUSE":
+ print(f"Exec: {exec_t}")
+ print(f"Comm: {comm_t}")
return result, self.data_count
- def train_on_middle_layer(self, model, lr, momentum, clip_grad_norm, control_count=5, cluster=None):
- pass
-
- def alone_training(self, model, lr, momentum, clip_grad_norm, train_loader=None, cluster=None):
- pass
-
-
-
-
-
-
-
-
-
-
-
-
+ def train_on_middle_layer(self, *args, **kwargs): pass
+ def alone_training(self, *args, **kwargs): pass
diff --git a/src/fine_tune/GPT2.py b/src/fine_tune/GPT2.py
index ce6cbf1..e31a988 100644
--- a/src/fine_tune/GPT2.py
+++ b/src/fine_tune/GPT2.py
@@ -7,244 +7,337 @@
import torch.nn as nn
import src.Log
+from src.Optimizer import OptimizationBundle
from transformers import GPT2Tokenizer
+from torch.optim.lr_scheduler import CosineAnnealingLR
+
+
class Ft_GPT2:
def __init__(self, client_id, layer_id, channel, device):
- self.client_id = client_id
- self.layer_id = layer_id
- self.channel = channel
- self.device = device
+ self.client_id = client_id
+ self.layer_id = layer_id
+ self.channel = channel
+ self.device = device
self.data_count = 0
- self.gpu_cpu = []
- self.encode = []
- self.tokenizer = None
-
- def send_intermediate_output(self, data_id, output, attention_mask, labels, trace):
-
- forward_queue_name = f'intermediate_queue_{self.layer_id}'
- self.channel.queue_declare(forward_queue_name, durable=False)
-
- if trace:
- trace.append(self.client_id)
- start_cpu = time.time()
- output = output.detach().cpu().numpy()
- labels = labels.cpu()
-
- self.gpu_cpu.append(time.time() - start_cpu)
- message = pickle.dumps(
- {"data_id": data_id, "data": output, "label": labels, "trace": trace,
- "attention_mask": attention_mask.cpu()}
- )
- else:
- message = pickle.dumps(
- {"data_id": data_id, "data": output.detach().cpu().numpy(), "label": labels.cpu(), "trace": [self.client_id],
- "attention_mask" :attention_mask.cpu()}
- )
- print(f'len message : {len(message)} bytes')
+    def send_intermediate_output(self, data_id, q_numpy, scale, mask_out, labels, trace):
+        """Pickle one activation payload (data, scale, label, trace, mask) and publish it on this layer's forward queue."""
+        queue_name = f"intermediate_queue_{self.layer_id}"
+        # Build a fresh hop list ending with this client; the caller's trace is left untouched.
+        hops = [*trace, self.client_id] if trace else [self.client_id]
+        self.channel.queue_declare(queue_name, durable=False)
+        payload = pickle.dumps({
+            "data_id": data_id, "data": q_numpy, "scale": scale,
+            "label": labels.cpu(), "trace": hops, "mask": mask_out.cpu(),
+        })
+        print(f"len message: {len(payload)} bytes")
+        self.channel.basic_publish(
+            exchange="", routing_key=queue_name, body=payload
+        )
+
+    def send_end_signal(self):  # Publish an {"action": "END"} sentinel on this layer's forward queue.
+        fwd_q = f"intermediate_queue_{self.layer_id}"  # same queue name send_intermediate_output uses
+        self.channel.queue_declare(fwd_q, durable=False)
         self.channel.basic_publish(
-            exchange='',
-            routing_key=forward_queue_name,
-            body=message
+            exchange="", routing_key=fwd_q,
+            body=pickle.dumps({"action": "END"})  # last_layer checks recv.get("action") == "END"
         )
+        src.Log.print_with_color("[>>>] Client 1 gửi END signal cho client 2", "yellow")
def send_gradient(self, data_id, gradient, trace):
- to_client_id = trace[-1]
- trace.pop(-1)
- backward_queue_name = f'gradient_queue_{self.layer_id - 1}_{to_client_id}'
- self.channel.queue_declare(queue=backward_queue_name, durable=False)
-
- message = pickle.dumps(
- {"data_id": data_id, "data": gradient.detach().cpu().numpy(), "trace": trace})
+ to_id = trace[-1]
+ trace = trace[:-1]
+ bwd_q = f"gradient_queue_{self.layer_id - 1}_{to_id}"
+ self.channel.queue_declare(queue=bwd_q, durable=False)
+ msg = pickle.dumps({
+ "data_id": data_id,
+ "data": gradient.detach().cpu().numpy(),
+ "trace": trace,
+ })
+ self.channel.basic_publish(exchange="", routing_key=bwd_q, body=msg)
+ def send_to_server(self, message):
+ self.channel.queue_declare("rpc_queue", durable=False)
self.channel.basic_publish(
- exchange='',
- routing_key=backward_queue_name,
- body=message
+ exchange="", routing_key="rpc_queue", body=pickle.dumps(message)
)
- def send_to_server(self, message):
- self.channel.queue_declare('rpc_queue', durable=False)
- self.channel.basic_publish(exchange='',
- routing_key='rpc_queue',
- body=pickle.dumps(message))
- def first_layer(self, model, lr, weight_decay, clip_grad_norm, control_count=1,
- train_loader=None):
- optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
- backward_queue_name = f'gradient_queue_{self.layer_id}_{self.client_id}'
- self.channel.queue_declare(queue=backward_queue_name, durable=False)
- self.channel.basic_qos(prefetch_count=1)
+    def first_layer(self, model, lr, weight_decay, clip_grad_norm,
+                    control_count=1, train_loader=None,
+                    opt: OptimizationBundle = None):
+        """Run the layer-1 loop: send detached activations downstream, recompute and backprop each batch when its gradient returns; returns (True, data_count) on PAUSE."""
+        if opt is None:
+            opt = OptimizationBundle({})
+        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+        bwd_q_name = f"gradient_queue_{self.layer_id}_{self.client_id}"
+        self.channel.queue_declare(queue=bwd_q_name, durable=False)
+        self.channel.basic_qos(prefetch_count=1)
         model = model.to(self.device)
-        forward_time = []
-        backward_time = []
-        comm_time = []
-
-        for i in range(1):
-            data_iter = iter(train_loader)
-            num_forward = 0
-            num_backward = 0
-            end_data = False
-            data_store = {}
-
-            with tqdm(total=len(train_loader), desc="Processing", unit="step") as pbar:
-                while True:
-                    # Training model
-                    model.train()
-                    optimizer.zero_grad()
-                    # Process gradient
-                    method_frame, header_frame, body = self.channel.basic_get(queue=backward_queue_name, auto_ack=True)
-                    if method_frame and body:
-                        num_backward += 1
-                        received_data = pickle.loads(body)
-                        gradient_numpy = received_data["data"]
-                        gradient = torch.tensor(gradient_numpy).to(self.device)
-                        data_id = received_data["data_id"]
-
-                        data_input = data_store.pop(data_id)
-                        start_backward = time.time()
-                        output, mask = model(input_ids=data_input[0], attention_mask=data_input[1])
-                        output.backward(gradient=gradient)
-                        optimizer.step()
-                        stop_backward = time.time()
-                        backward_time.append(stop_backward - start_backward)
-                    else:
-                        # speed control
-                        if len(data_store) >= control_count:
-                            continue
-
-                        try:
-                            batch = next(data_iter)
-                            input_ids = batch['input_ids'].to(self.device)
-                            attention_mask = batch['attention_mask'].to(self.device)
-                            labels = batch['labels'].to(self.device)
-                            data_id = uuid.uuid4()
-                            data_store[data_id] = (input_ids, attention_mask)
-                            start_forward = time.time()
-                            intermediate_output, mask = model(input_ids=input_ids, attention_mask=attention_mask)
-                            stop_forward = time.time()
-                            forward_time.append(stop_forward - start_forward)
-                            intermediate_output = intermediate_output.detach().requires_grad_(True)
-
-                            num_forward += 1
-                            self.data_count += 1
-
-                            pbar.update(1)
-                            start_comm = time.time()
-                            self.send_intermediate_output(data_id, intermediate_output, mask,
-                                                          labels, trace=None)
-                            stop_comm = time.time()
-                            comm_time.append(start_comm - stop_comm)
-
-                        except StopIteration:
-                            end_data = True
-
-                    if end_data and (num_forward == num_backward):
-                        break
-
-        notify_data = {"action": "NOTIFY", "client_id": self.client_id, "layer_id": self.layer_id,
-                       "message": "Finish training!"}
-
-        src.Log.print_with_color("[>>>] Finish training!", "red")
-        self.send_to_server(notify_data)
-
-        broadcast_queue_name = f'reply_{self.client_id}'
-        while True: # Wait for broadcast
-            method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True)
+        forward_t = []
+        backward_t = []
+        comm_t = []
+        data_iter = iter(train_loader)
+        num_fwd = num_bwd = 0
+        end_data = False
+        data_store = {}  # {data_id: (input_ids, attention_mask)} for batches still awaiting a gradient
+        scheduler = CosineAnnealingLR(
+            optimizer, T_max=max(len(train_loader), 1), eta_min=lr / 10
+        )
+
+        with tqdm(total=len(train_loader), desc="Layer 1", unit="step") as pbar:
+            while True:
+                model.train()
+
+                method_frame, _, body = self.channel.basic_get(
+                    queue=bwd_q_name, auto_ack=True
+                )
+                if method_frame and body:
+                    num_bwd += 1
+                    recv = pickle.loads(body)
+                    gradient = torch.tensor(recv["data"]).to(self.device)
+                    data_id = recv["data_id"]
+                    # Release the cached inputs as soon as their gradient arrives.
+                    inp_ids, attn = data_store.pop(data_id)
+                    if torch.isnan(gradient).any():
+                        print("[ERROR] Skip backward layer 1 (NaN gradient)")
+                        continue
+                    # Recompute the forward pass: the tensor sent downstream was detached, so it carries no graph to the parameters.
+                    optimizer.zero_grad()
+                    with opt.precision_ctx(self.device):
+                        inter = model(input_ids=inp_ids, attention_mask=attn)["hidden_states"]
+                    # Cast the gradient to the recomputed activation's dtype before backprop.
+                    inter.backward(gradient.to(dtype=inter.dtype))
+                    has_nan = any(
+                        p.grad is not None and torch.isnan(p.grad).any()
+                        for p in model.parameters()
+                    )
+                    if has_nan:
+                        print("[ERROR] NaN grad → skip step")
+                        optimizer.zero_grad()
+                        continue
+
+                    if clip_grad_norm > 0:
+                        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
+
+                    optimizer.step()
+                    scheduler.step()
+                    pbar.update(1)
+
+                else:
+                    # Flow control: do not run ahead while control_count batches still await gradients.
+                    if len(data_store) >= control_count:
+                        continue
+                    try:
+                        batch = next(data_iter)
+                        inp_ids = batch["input_ids"].to(self.device)
+                        attn = batch["attention_mask"].to(self.device)
+                        labels = batch["labels"].to(self.device)
+                        data_id = uuid.uuid4()
+
+                        t0 = time.time()
+                        with opt.precision_ctx(self.device):
+                            out = model(input_ids=inp_ids, attention_mask=attn)
+                        mask_out = out["mask"]
+                        # Only a detached copy leaves this client; cache the inputs so backward can rebuild the graph.
+                        inter = out["hidden_states"].detach().requires_grad_(True)
+                        data_store[data_id] = (inp_ids, attn)
+                        forward_t.append(time.time() - t0)
+
+                        num_fwd += 1
+                        self.data_count += 1
+
+                        t0 = time.time()
+                        q_numpy, scale = opt.quant(inter)
+                        self.send_intermediate_output(
+                            data_id, q_numpy, scale, mask_out, labels, trace=None
+                        )
+                        comm_t.append(time.time() - t0)
+
+                    except StopIteration:
+                        end_data = True
+
+                if end_data and num_fwd == num_bwd:
+                    break
+
+        self.send_end_signal()
+
+        notify = {
+            "action": "NOTIFY",
+            "client_id": self.client_id,
+            "layer_id": self.layer_id,
+            "message": "Finish training!",
+        }
+        src.Log.print_with_color("[>>>] Client 1 gửi NOTIFY về server", "red")
+        self.send_to_server(notify)
+
+        # Wait for the PAUSE message from the server.
+        bcast_q = f"reply_{self.client_id}"
+        while True:
+            _, _, body = self.channel.basic_get(queue=bcast_q, auto_ack=True)
             if body:
-                received_data = pickle.loads(body)
-                src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue")
-                if received_data["action"] == "PAUSE":
-                    print(f'forward: {forward_time}')
-                    print(f'backward: {backward_time}')
-                    print(f'comm: {comm_time}')
+                recv = pickle.loads(body)
+                src.Log.print_with_color(f"[<<<] Client 1: {recv['action']}", "blue")
+                if recv["action"] == "PAUSE":
+                    src.Log.print_with_color(
+                        f"[INFO] Forward: {len(forward_t)} steps, "
+                        f"Backward: {num_bwd} steps", "green"
+                    )
                     return True, self.data_count
             time.sleep(0.5)
- def last_layer(self, model, lr, weight_decay, clip_grad_norm):
+
+
+ def last_layer(self, model, lr, weight_decay, clip_grad_norm,
+ opt: OptimizationBundle = None):
+
+ if opt is None:
+ opt = OptimizationBundle({})
+
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
- pad_id = tokenizer.pad_token_id
- if pad_id is None:
+ if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
- pad_id = tokenizer.eos_token_id
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
- criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
- result = True
+ criterion = nn.CrossEntropyLoss(ignore_index=-100)
+ result = True
+
+ scheduler = CosineAnnealingLR(optimizer, T_max=1000, eta_min=lr / 10)
- forward_queue_name = f'intermediate_queue_{self.layer_id - 1}'
- self.channel.queue_declare(queue=forward_queue_name, durable=False)
+ fwd_q_name = f"intermediate_queue_{self.layer_id - 1}"
+ self.channel.queue_declare(queue=fwd_q_name, durable=False)
self.channel.basic_qos(prefetch_count=1)
- print('Waiting for intermediate output. To exit press CTRL+C')
+ src.Log.print_with_color("Layer 2: Waiting for hidden states...", "green")
model.to(self.device)
model.train()
- exec_time = []
- comm_time = []
+
+ exec_t = []
+ comm_t = []
+ num_received = 0
+ num_grad_sent = 0
+ end_received = False
+ nan_count = 0
+
while True:
- method_frame, header_frame, body = self.channel.basic_get(queue=forward_queue_name, auto_ack=True)
+ method_frame, _, body = self.channel.basic_get(
+ queue=fwd_q_name, auto_ack=True
+ )
if method_frame and body:
- optimizer.zero_grad()
- received_data = pickle.loads(body)
- intermediate_output_numpy = received_data["data"]
- attention_mask = received_data["attention_mask"].to(self.device)
- trace = received_data["trace"]
- data_id = received_data["data_id"]
- labels = received_data["label"].to(self.device)
-
- intermediate_output = torch.tensor(intermediate_output_numpy, requires_grad=True).to(self.device)
-
- start = time.time()
- output, _ = model(input_ids=intermediate_output, attention_mask=attention_mask)
- shift_logits = output[:, :-1, :].contiguous() # [B, L-1, V]
- shift_labels = labels[:, 1:].contiguous() # [B, L-1]
-
- loss = criterion(
- shift_logits.view(-1, shift_logits.size(-1)), # [(B*(L-1)), V]
- shift_labels.view(-1) # [(B*(L-1))]
+ recv = pickle.loads(body)
+
+ # END sentinel
+ if recv.get("action") == "END":
+ src.Log.print_with_color(
+ f"[<<<] END received. Received {num_received} batches, "
+ f"sent {num_grad_sent} gradients.", "yellow"
+ )
+ end_received = True
+ continue
+
+ mask = recv["mask"].to(self.device)
+ trace = recv["trace"]
+ data_id = recv["data_id"]
+ labels = recv["label"].to(self.device)
+ num_received += 1
+
+ inter = opt.dequant(
+ recv["data"], recv.get("scale"), self.device, requires_grad=True
)
+ # FIX: đồng bộ dtype của inter với model để tránh Half/Float mismatch
+ model_dtype = next(model.parameters()).dtype
+ inter = inter.to(dtype=model_dtype)
- # loss = criterion(output.view(-1, output.size(-1)), labels.view(-1))
- if torch.isnan(loss).any():
- src.Log.print_with_color("NaN detected in loss", "yellow")
- result = False
-
- print(f"Loss: {loss.item()}")
- intermediate_output.retain_grad()
- loss.backward()
+ optimizer.zero_grad()
- optimizer.step()
- exec_time.append(time.time() - start)
+ t0 = time.time()
+ with opt.precision_ctx(self.device):
+ out = model(input_ids=inter, attention_mask=mask)
+ output = out["logits"]
+ shift_logits = output[:, :-1, :].contiguous()
+ shift_labels = labels[:, 1:].contiguous()
+ loss = criterion(
+ shift_logits.view(-1, shift_logits.size(-1)),
+ shift_labels.view(-1),
+ )
+
+ if torch.isnan(loss):
+ src.Log.print_with_color("NaN loss — sending zero gradient", "yellow")
+ nan_count += 1
+ num_grad_sent += 1
+ self.send_gradient(data_id, torch.zeros_like(inter), trace)
+ continue
+
+ print(f"Loss: {loss.item():.4f}")
+
+ inter.retain_grad()
+
+ if opt.scaler is not None:
+ opt.scaler.scale(loss).backward()
+ opt.scaler.unscale_(optimizer)
+ if clip_grad_norm > 0:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
+ opt.scaler.step(optimizer)
+ opt.scaler.update()
+ else:
+ loss.backward()
+ if clip_grad_norm > 0:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
+ optimizer.step()
+
+ scheduler.step()
+ exec_t.append(time.time() - t0)
self.data_count += 1
- gradient = intermediate_output.grad
- start_comm = time.time()
- self.send_gradient(data_id, gradient, trace) # 1F1B
- comm_time.append(time.time() - start_comm)
- # Check training process
- else:
- broadcast_queue_name = f'reply_{self.client_id}'
- method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True)
- if body:
- received_data = pickle.loads(body)
- src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue")
- if received_data["action"] == "PAUSE":
- print(f'exec_time: {exec_time}')
- print(f'comm_time: {comm_time}')
- return result, self.data_count
-
- def train_on_middle_layer(self, model, lr, momentum, clip_grad_norm, control_count=5, cluster=None):
- pass
-
- def alone_training(self, model, lr, momentum, clip_grad_norm, train_loader=None, cluster=None):
- pass
-
-
+ t0 = time.time()
+ grad = inter.grad if inter.grad is not None else torch.zeros_like(inter)
+ if inter.grad is None:
+ src.Log.print_with_color(
+ "[WARN] inter.grad is None — sending zero gradient", "yellow"
+ )
+ self.send_gradient(data_id, grad, trace)
+ num_grad_sent += 1
+ comm_t.append(time.time() - t0)
+ else:
+ if end_received and num_grad_sent == num_received:
+ if num_received > 0 and nan_count / num_received > 0.5:
+ result = False
+ src.Log.print_with_color(
+ f"[WARN] {nan_count}/{num_received} batches NaN → round failed",
+ "yellow"
+ )
+
+ src.Log.print_with_color(
+ f"[>>>] Tất cả {num_grad_sent} gradient đã gửi. Gửi NOTIFY.", "red"
+ )
+ notify = {
+ "action": "NOTIFY",
+ "client_id": self.client_id,
+ "layer_id": self.layer_id,
+ "message": "Finish training!",
+ }
+ self.send_to_server(notify)
+ end_received = False
+
+ # Chờ PAUSE từ server
+ bcast_q = f"reply_{self.client_id}"
+ _, _, body = self.channel.basic_get(queue=bcast_q, auto_ack=True)
+ if body:
+ recv = pickle.loads(body)
+ src.Log.print_with_color(f"[<<<] Client 2: {recv['action']}", "blue")
+ if recv["action"] == "PAUSE":
+ src.Log.print_with_color(
+ f"[INFO] Exec: {len(exec_t)} steps", "green"
+ )
+ return result, self.data_count
+ time.sleep(0.1)
+    def train_on_middle_layer(self, *args, **kwargs): pass  # placeholder: accepts any arguments and does nothing
+    def alone_training(self, *args, **kwargs): pass  # placeholder: accepts any arguments and does nothing
\ No newline at end of file
diff --git a/src/fine_tune/Llama.py b/src/fine_tune/Llama.py
index 5435d21..50a0c46 100644
--- a/src/fine_tune/Llama.py
+++ b/src/fine_tune/Llama.py
@@ -7,208 +7,226 @@
import torch.nn as nn
import src.Log
+from src.Optimizer import OptimizationBundle
+
class Ft_Llama:
def __init__(self, client_id, layer_id, channel, device):
- self.client_id = client_id
- self.layer_id = layer_id
- self.channel = channel
- self.device = device
+ self.client_id = client_id
+ self.layer_id = layer_id
+ self.channel = channel
+ self.device = device
self.data_count = 0
- def send_intermediate_output(self, data_id, output, attention_mask, labels, trace):
-
- forward_queue_name = f'intermediate_queue_{self.layer_id}'
- self.channel.queue_declare(forward_queue_name, durable=False)
-
- if trace:
- trace.append(self.client_id)
- message = pickle.dumps(
- {"data_id": data_id, "data": output.detach().cpu().numpy(), "label": labels.cpu(), "trace": trace,
- "attention_mask": attention_mask.cpu()}
- )
- else:
- message = pickle.dumps(
- {"data_id": data_id, "data": output.detach().cpu().numpy(), "label": labels.cpu(), "trace": [self.client_id],
- "attention_mask" :attention_mask.cpu()}
- )
-
- self.channel.basic_publish(
- exchange='',
- routing_key=forward_queue_name,
- body=message
- )
+ # ── Giao tiếp ────────────────────────────────────────────
+
+ def send_intermediate_output(self, data_id, q_numpy, scale, attention_mask, labels, trace):
+ fwd_q = f"intermediate_queue_{self.layer_id}"
+ self.channel.queue_declare(fwd_q, durable=False)
+ trace_out = list(trace) + [self.client_id] if trace else [self.client_id]
+ msg = pickle.dumps({
+ "data_id": data_id,
+ "data": q_numpy,
+ "scale": scale,
+ "label": labels.cpu(),
+ "trace": trace_out,
+ "attention_mask": attention_mask.cpu(),
+ })
+ self.channel.basic_publish(exchange="", routing_key=fwd_q, body=msg)
def send_gradient(self, data_id, gradient, trace):
- to_client_id = trace[-1]
- trace.pop(-1)
- backward_queue_name = f'gradient_queue_{self.layer_id - 1}_{to_client_id}'
- self.channel.queue_declare(queue=backward_queue_name, durable=False)
-
- message = pickle.dumps(
- {"data_id": data_id, "data": gradient.detach().cpu().numpy(), "trace": trace})
+ to_id = trace[-1]
+ trace = trace[:-1]
+ bwd_q = f"gradient_queue_{self.layer_id - 1}_{to_id}"
+ self.channel.queue_declare(queue=bwd_q, durable=False)
+ msg = pickle.dumps({
+ "data_id": data_id,
+ "data": gradient.detach().cpu().numpy(),
+ "trace": trace,
+ })
+ self.channel.basic_publish(exchange="", routing_key=bwd_q, body=msg)
+ def send_to_server(self, message):
+ self.channel.queue_declare("rpc_queue", durable=False)
self.channel.basic_publish(
- exchange='',
- routing_key=backward_queue_name,
- body=message
+ exchange="", routing_key="rpc_queue", body=pickle.dumps(message)
)
- def send_to_server(self, message):
- self.channel.queue_declare('rpc_queue', durable=False)
- self.channel.basic_publish(exchange='',
- routing_key='rpc_queue',
- body=pickle.dumps(message))
+ # ── Layer 1 ───────────────────────────────────────────────
- def first_layer(self, model, lr, weight_decay, clip_grad_norm, control_count=1,
- train_loader=None):
- optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+ def first_layer(self, model, lr, weight_decay, clip_grad_norm,
+ control_count=1, train_loader=None,
+ opt: OptimizationBundle = None):
- backward_queue_name = f'gradient_queue_{self.layer_id}_{self.client_id}'
- self.channel.queue_declare(queue=backward_queue_name, durable=False)
- self.channel.basic_qos(prefetch_count=1)
+ if opt is None:
+ opt = OptimizationBundle({})
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+ bwd_q_name = f"gradient_queue_{self.layer_id}_{self.client_id}"
+ self.channel.queue_declare(queue=bwd_q_name, durable=False)
+ self.channel.basic_qos(prefetch_count=1)
model = model.to(self.device)
- for i in range(1):
- data_iter = iter(train_loader)
- num_forward = 0
- num_backward = 0
- end_data = False
- data_store = {}
-
- with tqdm(total=len(train_loader), desc="Processing", unit="step") as pbar:
- while True:
- # Training model
- model.train()
- optimizer.zero_grad()
+ data_iter = iter(train_loader)
+ num_fwd = num_bwd = 0
+ end_data = False
+ data_store = {} # {uuid: (input_ids, attention_mask)}
+
+ with tqdm(total=len(train_loader), desc="Processing", unit="step") as pbar:
+ while True:
+ model.train()
+
+ method_frame, _, body = self.channel.basic_get(
+ queue=bwd_q_name, auto_ack=True
+ )
+ if method_frame and body:
+ num_bwd += 1
+ recv = pickle.loads(body)
+ gradient = torch.tensor(recv["data"]).to(self.device)
+ data_id = recv["data_id"]
+ inp, mask = data_store.pop(data_id) # Fix Bug 7
- # Process gradient
- method_frame, header_frame, body = self.channel.basic_get(queue=backward_queue_name, auto_ack=True)
- if method_frame and body:
- num_backward += 1
- received_data = pickle.loads(body)
- gradient_numpy = received_data["data"]
- gradient = torch.tensor(gradient_numpy).to(self.device)
- data_id = received_data["data_id"]
-
- data_input = data_store.pop(data_id)
- output, mask = model(input_ids=data_input[0], attention_mask=data_input[1])
- output.backward(gradient=gradient)
- optimizer.step()
+ optimizer.zero_grad()
+ with opt.precision_ctx(self.device):
+ output, _ = model(input_ids=inp, attention_mask=mask)
+
+ # Fix Bug 1: only one backward pass, using the split gradient
+ output.backward(gradient=gradient)
+ if clip_grad_norm > 0:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
+ if opt.scaler is not None:
+ opt.scaler.step(optimizer)
+ opt.scaler.update()
else:
- # speed control
- if len(data_store) >= control_count:
- continue
-
- try:
- batch = next(data_iter)
- input_ids = batch['input_ids'].to(self.device)
- attention_mask = batch['attention_mask'].to(self.device)
- labels = batch['input_ids'].to(self.device)
- data_id = uuid.uuid4()
- data_store[data_id] = (input_ids, attention_mask)
-
- intermediate_output, mask = model(input_ids=input_ids, attention_mask=attention_mask)
- intermediate_output = intermediate_output.detach().requires_grad_(True)
-
- num_forward += 1
- self.data_count += 1
-
- pbar.update(1)
-
- self.send_intermediate_output(data_id, intermediate_output, mask,
- labels, trace=None)
-
- except StopIteration:
- end_data = True
-
- if end_data and (num_forward == num_backward):
- break
-
- notify_data = {"action": "NOTIFY", "client_id": self.client_id, "layer_id": self.layer_id,
- "message": "Finish training!"}
+ optimizer.step()
+ else:
+ # Fix Bug 7: limit the pipeline depth
+ if len(data_store) >= control_count:
+ continue
+ try:
+ batch = next(data_iter)
+ inp_ids = batch["input_ids"].to(self.device)
+ attn = batch["attention_mask"].to(self.device)
+ labels = batch["labels"].to(self.device)
+ data_id = uuid.uuid4()
+ data_store[data_id] = (inp_ids, attn)
+
+ with opt.precision_ctx(self.device):
+ inter, mask_out = model(input_ids=inp_ids, attention_mask=attn)
+
+ inter = inter.detach().requires_grad_(True)
+ num_fwd += 1
+ self.data_count += 1
+ pbar.update(1)
+
+ q_numpy, scale = opt.quant(inter)
+ self.send_intermediate_output(
+ data_id, q_numpy, scale, mask_out, labels, trace=None
+ )
+
+ except StopIteration:
+ end_data = True
+
+ if end_data and num_fwd == num_bwd:
+ break
+
+ notify = {
+ "action": "NOTIFY", "client_id": self.client_id,
+ "layer_id": self.layer_id, "message": "Finish training!",
+ }
src.Log.print_with_color("[>>>] Finish training!", "red")
- self.send_to_server(notify_data)
+ self.send_to_server(notify)
- broadcast_queue_name = f'reply_{self.client_id}'
- while True: # Wait for broadcast
- method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True)
+ bcast_q = f"reply_{self.client_id}"
+ while True:
+ _, _, body = self.channel.basic_get(queue=bcast_q, auto_ack=True)
if body:
- received_data = pickle.loads(body)
- src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue")
- if received_data["action"] == "PAUSE":
+ recv = pickle.loads(body)
+ src.Log.print_with_color(f"[<<<] {recv}", "blue")
+ if recv["action"] == "PAUSE":
return True, self.data_count
time.sleep(0.5)
- def last_layer(self, model, lr, weight_decay, clip_grad_norm):
+ # ── Layer 2 ───────────────────────────────────────────────
+
+ def last_layer(self, model, lr, weight_decay, clip_grad_norm,
+ opt: OptimizationBundle = None):
+
+ if opt is None:
+ opt = OptimizationBundle({})
+
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss(ignore_index=-100)
- result = True
+ result = True
- forward_queue_name = f'intermediate_queue_{self.layer_id - 1}'
- self.channel.queue_declare(queue=forward_queue_name, durable=False)
+ fwd_q_name = f"intermediate_queue_{self.layer_id - 1}"
+ self.channel.queue_declare(queue=fwd_q_name, durable=False)
self.channel.basic_qos(prefetch_count=1)
- print('Waiting for intermediate output. To exit press CTRL+C')
+ print("Waiting for intermediate output. To exit press CTRL+C")
model.to(self.device)
model.train()
+
while True:
- method_frame, header_frame, body = self.channel.basic_get(queue=forward_queue_name, auto_ack=True)
+ method_frame, _, body = self.channel.basic_get(
+ queue=fwd_q_name, auto_ack=True
+ )
if method_frame and body:
-
optimizer.zero_grad()
- received_data = pickle.loads(body)
- intermediate_output_numpy = received_data["data"]
- attention_mask = received_data["attention_mask"].to(self.device)
- trace = received_data["trace"]
- data_id = received_data["data_id"]
- labels = received_data["label"].to(self.device)
-
- intermediate_output = torch.tensor(intermediate_output_numpy, requires_grad=True).to(self.device)
-
- output, _ = model(input_ids=intermediate_output, attention_mask=attention_mask)
-
- loss = criterion(output.view(-1, output.size(-1)), labels.view(-1))
- if torch.isnan(loss).any():
+ recv = pickle.loads(body)
+ attn = recv["attention_mask"].to(self.device)
+ trace = recv["trace"]
+ data_id = recv["data_id"]
+ labels = recv["label"].to(self.device)
+
+ inter = opt.dequant(
+ recv["data"], recv.get("scale"), self.device, requires_grad=True
+ )
+
+ with opt.precision_ctx(self.device):
+ output, _ = model(input_ids=inter, attention_mask=attn)
+ # Causal LM loss: shift
+ shift_logits = output[:, :-1, :].contiguous()
+ shift_labels = labels[:, 1:].contiguous()
+ loss = criterion(
+ shift_logits.view(-1, shift_logits.size(-1)),
+ shift_labels.view(-1),
+ )
+
+ if torch.isnan(loss):
src.Log.print_with_color("NaN detected in loss", "yellow")
result = False
- print(f"Loss: {loss.item()}")
- intermediate_output.retain_grad()
- loss.backward()
+ print(f"Loss: {loss.item():.4f}")
+
+ # Fix Bug 4: call retain_grad() BEFORE backward
+ inter.retain_grad()
+
+ # Fix Bug 2: 1 backward, 1 optimizer.step
+ if opt.scaler is not None:
+ opt.scaler.scale(loss).backward()
+ opt.scaler.unscale_(optimizer)
+ if clip_grad_norm > 0:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
+ opt.scaler.step(optimizer)
+ opt.scaler.update()
+ else:
+ loss.backward()
+ if clip_grad_norm > 0:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
+ optimizer.step()
- optimizer.step()
self.data_count += 1
+ self.send_gradient(data_id, inter.grad, trace)
- gradient = intermediate_output.grad
-
- self.send_gradient(data_id, gradient, trace) # 1F1B
- # Check training process
else:
- broadcast_queue_name = f'reply_{self.client_id}'
- method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True)
+ bcast_q = f"reply_{self.client_id}"
+ _, _, body = self.channel.basic_get(queue=bcast_q, auto_ack=True)
if body:
- received_data = pickle.loads(body)
- src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue")
- if received_data["action"] == "PAUSE":
+ recv = pickle.loads(body)
+ src.Log.print_with_color(f"[<<<] {recv}", "blue")
+ if recv["action"] == "PAUSE":
return result, self.data_count
- def train_on_middle_layer(self, model, lr, momentum, clip_grad_norm, control_count=5, cluster=None):
- pass
-
- def alone_training(self, model, lr, momentum, clip_grad_norm, train_loader=None, cluster=None):
- pass
-
-
-
-
-
-
-
-
-
-
-
-
-
+ def train_on_middle_layer(self, *args, **kwargs): pass
+ def alone_training(self, *args, **kwargs): pass
diff --git a/src/fine_tune/__pycache__/Bert.cpython-313.pyc b/src/fine_tune/__pycache__/Bert.cpython-313.pyc
new file mode 100644
index 0000000..8c97500
Binary files /dev/null and b/src/fine_tune/__pycache__/Bert.cpython-313.pyc differ
diff --git a/src/fine_tune/__pycache__/GPT2.cpython-313.pyc b/src/fine_tune/__pycache__/GPT2.cpython-313.pyc
new file mode 100644
index 0000000..10b263b
Binary files /dev/null and b/src/fine_tune/__pycache__/GPT2.cpython-313.pyc differ
diff --git a/src/fine_tune/__pycache__/Llama.cpython-313.pyc b/src/fine_tune/__pycache__/Llama.cpython-313.pyc
new file mode 100644
index 0000000..299657e
Binary files /dev/null and b/src/fine_tune/__pycache__/Llama.cpython-313.pyc differ
diff --git a/src/fine_tune/__pycache__/__init__.cpython-313.pyc b/src/fine_tune/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000..8e4d268
Binary files /dev/null and b/src/fine_tune/__pycache__/__init__.cpython-313.pyc differ
diff --git a/src/model/Bert.py b/src/model/Bert.py
index 6e44e5f..c4769bc 100644
--- a/src/model/Bert.py
+++ b/src/model/Bert.py
@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
+from src.Optimizer import flash_scaled_dot_product
class DotDict(dict):
def __getattr__(self, k):
@@ -37,47 +38,35 @@ def forward(self, input_ids, token_type_ids=None):
return embeddings
class BertSdpaSelfAttention(nn.Module):
- def __init__(self, hidden_size, num_attention_heads, dropout_prob):
+ def __init__(self, hidden_size, num_attention_heads, dropout_prob, use_flash=False):
super(BertSdpaSelfAttention, self).__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_size = int(hidden_size / num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.use_flash = use_flash
self.query = nn.Linear(hidden_size, self.all_head_size)
self.key = nn.Linear(hidden_size, self.all_head_size)
self.value = nn.Linear(hidden_size, self.all_head_size)
self.dropout = nn.Dropout(dropout_prob)
+ self._dropout_p = dropout_prob
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
- return x.permute(0, 2, 1, 3)
+ return x.permute(0, 2, 1, 3) # (B, H, T, D)
def forward(self, hidden_states):
+ q = self.transpose_for_scores(self.query(hidden_states))
+ k = self.transpose_for_scores(self.key(hidden_states))
+ v = self.transpose_for_scores(self.value(hidden_states))
- mixed_query_layer = self.query(hidden_states)
- mixed_key_layer = self.key(hidden_states)
- mixed_value_layer = self.value(hidden_states)
-
- query_layer = self.transpose_for_scores(mixed_query_layer)
- key_layer = self.transpose_for_scores(mixed_key_layer)
- value_layer = self.transpose_for_scores(mixed_value_layer)
-
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
-
- import math
- attention_scores = attention_scores / math.sqrt(self.attention_head_size)
-
- attention_probs = F.softmax(attention_scores, dim=-1)
- attention_probs = self.dropout(attention_probs)
-
- context_layer = torch.matmul(attention_probs, value_layer)
+ dp = self._dropout_p if self.training else 0.0
+ context_layer = flash_scaled_dot_product(q, k, v, dropout_p=dp)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
-
- return context_layer
+ new_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ return context_layer.view(*new_shape)
class BertSelfOutput(nn.Module):
def __init__(self, hidden_size, dropout_prob):
@@ -93,9 +82,10 @@ def forward(self, hidden_states, input_tensor):
return hidden_states
class BertAttention(nn.Module):
- def __init__(self, hidden_size, num_attention_heads, dropout_prob):
+ def __init__(self, hidden_size, num_attention_heads, dropout_prob, use_flash=False):
super(BertAttention, self).__init__()
- self.self = BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob)
+ self.self = BertSdpaSelfAttention(hidden_size, num_attention_heads,
+ dropout_prob, use_flash=use_flash)
self.output = BertSelfOutput(hidden_size, dropout_prob)
def forward(self, hidden_states):
@@ -128,9 +118,11 @@ def forward(self, hidden_states, input_tensor):
return hidden_states
class BertLayer(nn.Module):
- def __init__(self, hidden_size, num_attention_heads, intermediate_size, dropout_prob):
+ def __init__(self, hidden_size, num_attention_heads, intermediate_size,
+ dropout_prob, use_flash=False):
super(BertLayer, self).__init__()
- self.attention = BertAttention(hidden_size, num_attention_heads, dropout_prob)
+ self.attention = BertAttention(hidden_size, num_attention_heads,
+ dropout_prob, use_flash=use_flash)
self.intermediate = BertIntermediate(hidden_size, intermediate_size)
self.output = BertOutput(hidden_size, intermediate_size, dropout_prob)
@@ -165,10 +157,12 @@ def forward(self, pooled_output):
class Bert(nn.Module):
def __init__( self, vocab_size=28996, hidden_size=768, num_attention_heads=12, intermediate_size=3072,
- max_position_embeddings=512, type_vocab_size=2, dropout_prob=0.1, layer_id=0, n_block=12
+ max_position_embeddings=512, type_vocab_size=2, dropout_prob=0.1, layer_id=0, n_block=12,
+ use_flash=False
):
super(Bert, self).__init__()
self.layer_id = layer_id
+ self.use_flash = use_flash
self.config = DotDict(
model_type="bert",
vocab_size=vocab_size,
@@ -180,29 +174,30 @@ def __init__( self, vocab_size=28996, hidden_size=768, num_attention_heads=12, i
use_return_dict=True, output_attentions=False, output_hidden_states=False
)
+ def _make_layers(n):
+ return nn.ModuleList([
+ BertLayer(hidden_size, num_attention_heads, intermediate_size,
+ dropout_prob, use_flash=use_flash)
+ for _ in range(n)
+ ])
+
if self.layer_id == 1:
- self.embeddings = BertEmbeddings(vocab_size=vocab_size, hidden_size=hidden_size, max_position_embeddings=max_position_embeddings,
- type_vocab_size=type_vocab_size,dropout_prob=dropout_prob)
- self.layers = nn.ModuleList(
- [BertLayer(hidden_size, num_attention_heads, intermediate_size, dropout_prob)
- for _ in range(n_block)]
- )
+ self.embeddings = BertEmbeddings(
+ vocab_size=vocab_size, hidden_size=hidden_size,
+ max_position_embeddings=max_position_embeddings,
+ type_vocab_size=type_vocab_size, dropout_prob=dropout_prob)
+ self.layers = _make_layers(n_block)
elif self.layer_id == 2:
- self.layers = nn.ModuleList(
- [BertLayer(hidden_size, num_attention_heads, intermediate_size, dropout_prob)
- for _ in range(n_block)]
- )
+ self.layers = _make_layers(n_block)
self.pooler = BertPooler(hidden_size)
self.dropout = nn.Dropout(dropout_prob)
self.classifier = nn.Linear(hidden_size, 4)
else:
- self.embeddings = BertEmbeddings(vocab_size=vocab_size, hidden_size=hidden_size,
- max_position_embeddings=max_position_embeddings,
- type_vocab_size=type_vocab_size, dropout_prob=dropout_prob)
- self.layers = nn.ModuleList(
- [BertLayer(hidden_size, num_attention_heads, intermediate_size, dropout_prob)
- for _ in range(n_block)]
- )
+ self.embeddings = BertEmbeddings(
+ vocab_size=vocab_size, hidden_size=hidden_size,
+ max_position_embeddings=max_position_embeddings,
+ type_vocab_size=type_vocab_size, dropout_prob=dropout_prob)
+ self.layers = _make_layers(n_block)
self.pooler = BertPooler(hidden_size)
self.dropout = nn.Dropout(dropout_prob)
self.classifier = nn.Linear(hidden_size, 4)
diff --git a/src/model/GPT2.py b/src/model/GPT2.py
index ce4b9bd..1a4067c 100644
--- a/src/model/GPT2.py
+++ b/src/model/GPT2.py
@@ -4,6 +4,7 @@
from transformers.pytorch_utils import Conv1D
from transformers.activations import NewGELUActivation
+from src.Optimizer import flash_scaled_dot_product
class DotDict(dict):
def __getattr__(self, k):
@@ -14,38 +15,35 @@ def __delattr__(self, k): del self[k]
class Attention(nn.Module):
- def __init__(self, embed_size: int, num_heads: int, dropout=0.0):
+ def __init__(self, embed_size: int, num_heads: int, dropout=0.0, use_flash=False):
super().__init__()
assert embed_size % num_heads == 0
self.embed_size = embed_size
self.num_heads = num_heads
self.head_dim = embed_size // num_heads
+ self.use_flash = use_flash
self.c_attn = Conv1D(3 * embed_size, embed_size)
self.c_proj = Conv1D(embed_size, embed_size)
self.attn_drop = nn.Dropout(dropout)
self.resid_drop = nn.Dropout(dropout)
+ self._dropout_p = dropout
def forward(self, x, mask=None):
B, T, E = x.shape
H, D = self.num_heads, self.head_dim
- qkv = self.c_attn(x) # (B, T, 3E)
- q, k, v = qkv.split(E, dim=-1) # (B, T, E) x3
- q = q.view(B, T, H, D)
- k = k.view(B, T, H, D)
- v = v.view(B, T, H, D)
-
- # (B, H, T, T)
- att = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(D)
- if mask is not None:
- att = att.masked_fill(mask == 0, float("-1e20"))
- att = torch.softmax(att, dim=-1)
- att = self.attn_drop(att)
-
- # (B, T, H, D) -> (B, T, E)
- y = torch.einsum("bhqk,bkhd->bqhd", att, v).contiguous().view(B, T, E)
+ qkv = self.c_attn(x)
+ q, k, v = qkv.split(E, dim=-1)
+ q = q.view(B, T, H, D).transpose(1, 2)
+ k = k.view(B, T, H, D).transpose(1, 2)
+ v = v.view(B, T, H, D).transpose(1, 2)
+
+ dp = self._dropout_p if self.training else 0.0
+ y = flash_scaled_dot_product(q, k, v, mask=mask, dropout_p=dp) # (B,H,T,D)
+
+ y = y.transpose(1, 2).contiguous().view(B, T, E)
y = self.c_proj(y)
y = self.resid_drop(y)
return y
@@ -66,10 +64,10 @@ def forward(self, x):
return x
class GPT2Block(nn.Module):
- def __init__(self, embed_size: int, num_heads: int, dropout: float):
+ def __init__(self, embed_size: int, num_heads: int, dropout: float, use_flash=False):
super().__init__()
self.ln_1 = nn.LayerNorm(embed_size)
- self.attn = Attention(embed_size, num_heads, dropout)
+ self.attn = Attention(embed_size, num_heads, dropout, use_flash=use_flash)
self.ln_2 = nn.LayerNorm(embed_size)
self.mlp = MLP(embed_size, dropout)
@@ -79,12 +77,12 @@ def forward(self, x, mask):
return x
def build_causal_mask(B, T, device):
- return torch.ones(T, T, device=device).tril().unsqueeze(0).unsqueeze(1).expand(B, 1, T, T)
+ return torch.ones(T, T, device=device).tril().unsqueeze(0).unsqueeze(1).expand(B, 1, T, T).contiguous()
class GPT2(nn.Module):
def __init__(self, vocab_size=50257, max_length=1024, n_layer=12, n_head=12, n_embd=768, dropout=0.1,
- layer_id=0, n_block=12):
+ layer_id=0, n_block=12, use_flash=False):
super().__init__()
self.vocab_size = vocab_size
self.max_length = max_length
@@ -98,87 +96,99 @@ def __init__(self, vocab_size=50257, max_length=1024, n_layer=12, n_head=12, n_e
bos_token_id=50256, eos_token_id=50256, pad_token_id=0,
is_encoder_decoder=False, tie_word_embeddings=False,
)
+
+ def _make_blocks(n):
+ return nn.ModuleList([
+ GPT2Block(n_embd, n_head, dropout, use_flash=use_flash)
+ for _ in range(n)
+ ])
+
if self.layer_id == 1:
self.wte = nn.Embedding(vocab_size, n_embd)
self.wpe = nn.Embedding(max_length, n_embd)
self.drop = nn.Dropout(dropout)
-
- self.h = nn.ModuleList([
- GPT2Block(n_embd, n_head, dropout)
- for _ in range(n_block)
- ])
+ self.h = _make_blocks(n_block)
elif self.layer_id == 2:
- self.h = nn.ModuleList([
- GPT2Block(n_embd, n_head, dropout)
- for _ in range(n_block)
- ])
+ self.h = _make_blocks(n_block)
self.ln_f = nn.LayerNorm(n_embd)
self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
else:
self.wte = nn.Embedding(vocab_size, n_embd)
self.wpe = nn.Embedding(max_length, n_embd)
self.drop = nn.Dropout(dropout)
-
- self.h = nn.ModuleList([
- GPT2Block(n_embd, n_head, dropout)
- for _ in range(n_block)
- ])
+ self.h = _make_blocks(n_block)
self.ln_f = nn.LayerNorm(n_embd)
self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
+ self.lm_head.weight = self.wte.weight
+
+ def forward(self, input_ids=None, attention_mask=None, **kwargs):
- def forward(self, input_ids, attention_mask=None,
- **kwargs):
if self.layer_id == 1:
B, T = input_ids.shape
h = self.wte(input_ids)
- pos = torch.arange(0, T).unsqueeze(0).to(attention_mask.device)
+ pos = torch.arange(0, T).unsqueeze(0).to(input_ids.device)
h = self.drop(h + self.wpe(pos))
- mask = build_causal_mask(B, T, attention_mask.device)
+
+ mask = build_causal_mask(B, T, input_ids.device)
+
if attention_mask is not None:
key_mask = attention_mask[:, None, None, :].to(mask.dtype)
qry_mask = attention_mask[:, None, :, None].to(mask.dtype)
- mask = mask * key_mask * qry_mask
+ mask = (mask * key_mask * qry_mask) > 0
for blk in self.h:
h = blk(h, mask)
+ return {"hidden_states": h, "mask": mask}
+
elif self.layer_id == 2:
h = input_ids
mask = attention_mask
+
for blk in self.h:
- h = blk(h, attention_mask)
+ h = blk(h, mask)
+
h = self.ln_f(h)
- h = self.lm_head(h)
+ logits = self.lm_head(h)
+
+ return {"logits": logits}
else:
B, T = input_ids.shape
h = self.wte(input_ids)
- pos = torch.arange(0, T).unsqueeze(0).to(attention_mask.device)
+
+ pos = torch.arange(0, T).unsqueeze(0).to(input_ids.device)
h = self.drop(h + self.wpe(pos))
- mask = build_causal_mask(B, T, attention_mask.device)
+
+ mask = build_causal_mask(B, T, input_ids.device)
+
if attention_mask is not None:
key_mask = attention_mask[:, None, None, :].to(mask.dtype)
qry_mask = attention_mask[:, None, :, None].to(mask.dtype)
- mask = mask * key_mask * qry_mask
+ mask = (mask * key_mask * qry_mask) > 0
for blk in self.h:
h = blk(h, mask)
- h = self.ln_f(h)
- h = self.lm_head(h)
- return h, mask
+ h = self.ln_f(h)
+ logits = self.lm_head(h)
+ return {
+ "logits": logits,
+ "hidden_states": h,
+ "mask": mask
+ }
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
- B, T = input_ids.shape
- device = input_ids.device
- causal = torch.ones(T, T, device=device).tril().unsqueeze(0).unsqueeze(1).expand(B, 1, T, T)
- if attention_mask is not None:
- key_mask = attention_mask[:, None, None, :].to(causal.dtype)
- qry_mask = attention_mask[:, None, :, None].to(causal.dtype)
- mask = causal * key_mask * qry_mask
- else:
- mask = causal
- return {"input_ids": input_ids, "mask": mask}
+ B, T = input_ids.shape
+ device = input_ids.device
+ causal = torch.ones(T, T, device=device).tril().unsqueeze(0).unsqueeze(1).expand(B, 1, T, T).contiguous()
+ if attention_mask is not None:
+ key_mask = attention_mask[:, None, None, :].to(causal.dtype)
+ qry_mask = attention_mask[:, None, :, None].to(causal.dtype)
+ mask = causal * key_mask * qry_mask
+ else:
+ mask = causal
+ return {"input_ids": input_ids, "mask": mask}
def shift_labels_for_lm(input_ids, ignore_index=-100):
labels = input_ids.clone()
diff --git a/src/model/Llama.py b/src/model/Llama.py
index bcc5b8e..6808026 100644
--- a/src/model/Llama.py
+++ b/src/model/Llama.py
@@ -3,6 +3,7 @@
import torch.nn.functional as F
import math
+from src.Optimizer import flash_scaled_dot_product
class DotDict(dict):
def __getattr__(self, k):
@@ -35,7 +36,7 @@ def apply_rotary_pos_emb(q, k, cos, sin):
return q_rot, k_rot
def build_causal_mask(B, T, device):
- return torch.ones(T, T, device=device).tril().unsqueeze(0).unsqueeze(1).expand(B, 1, T, T)
+ return torch.ones(T, T, device=device).tril().unsqueeze(0).unsqueeze(1).expand(B, 1, T, T).contiguous()
class LlamaRotaryEmbedding(nn.Module):
def __init__(
@@ -62,17 +63,19 @@ def __init__(
self._set_cos_sin_cache(max_position_embeddings, dtype=torch.float32, device=device)
def _set_cos_sin_cache(self, seq_len: int, dtype: torch.dtype, device=None):
+ # Fix Bug 12: when device=None (re-invoked from forward), use inv_freq.device
device = device if device is not None else self.inv_freq.device
- t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) # [seq_len]
+ t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
freq = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat([freq, freq], dim=-1)
cos = emb.cos()[None, None, :, :] # [1, 1, seq_len, dim]
- sin = emb.sin()[None, None, :, :] # [1, 1, seq_len, dim]
+ sin = emb.sin()[None, None, :, :]
- self.cos_cached = cos.to(dtype=dtype, device=device)
- self.sin_cached = sin.to(dtype=dtype, device=device)
+ # Store in FP32 and cast at use time to avoid a precision mismatch
+ self.cos_cached = cos.to(dtype=torch.float32, device=device)
+ self.sin_cached = sin.to(dtype=torch.float32, device=device)
self.max_seq_len_cached = seq_len
def forward(self, x, seq_len: int = None):
@@ -80,19 +83,22 @@ def forward(self, x, seq_len: int = None):
seq_len = x.shape[-2]
if seq_len > self.max_seq_len_cached:
- self._set_cos_sin_cache(seq_len, dtype=x.dtype, device=x.device)
+ self._set_cos_sin_cache(seq_len, dtype=torch.float32, device=x.device)
+ # Cast to x's dtype and device at use time (supports FP16/BF16)
cos = self.cos_cached[:, :, :seq_len, :].to(dtype=x.dtype, device=x.device)
sin = self.sin_cached[:, :, :seq_len, :].to(dtype=x.dtype, device=x.device)
return cos, sin
class LlamaAttention(nn.Module):
- def __init__(self, hidden_size, num_heads, num_kv_heads, dropout=0.0, head_dim=64):
+ def __init__(self, hidden_size, num_heads, num_kv_heads, dropout=0.0, head_dim=64,
+ use_flash=False):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
+ self.use_flash = use_flash
self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
self.k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
@@ -101,6 +107,7 @@ def __init__(self, hidden_size, num_heads, num_kv_heads, dropout=0.0, head_dim=6
self.rope = LlamaRotaryEmbedding(head_dim)
self.dropout = nn.Dropout(dropout)
+ self._dropout_p = dropout
def forward(self, x, attention_mask=None):
B, T, C = x.size()
@@ -116,17 +123,9 @@ def forward(self, x, attention_mask=None):
k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
- att = (q @ k.transpose(-1, -2)) / math.sqrt(self.head_dim)
+ dp = self._dropout_p if self.training else 0.0
+ out = flash_scaled_dot_product(q, k, v, mask=attention_mask, dropout_p=dp)
- if attention_mask is not None:
- att = att.masked_fill(
- attention_mask == 0,
- torch.finfo(att.dtype).min
- )
-
- att = F.softmax(att, dim=-1)
- att = self.dropout(att)
- out = att @ v
out = out.transpose(1, 2).contiguous().view(B, T, C)
return self.o_proj(out)
@@ -143,10 +142,12 @@ def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class LlamaDecoderLayer(nn.Module):
- def __init__(self, hidden_size, num_heads, num_kv_heads, intermediate_size, rms_eps=1e-6):
+ def __init__(self, hidden_size, num_heads, num_kv_heads, intermediate_size,
+ rms_eps=1e-6, use_flash=False):
super().__init__()
self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_eps)
- self.self_attn = LlamaAttention(hidden_size, num_heads, num_kv_heads)
+ self.self_attn = LlamaAttention(hidden_size, num_heads, num_kv_heads,
+ use_flash=use_flash)
self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_eps)
self.mlp = LlamaMLP(hidden_size, intermediate_size)
@@ -159,9 +160,10 @@ def forward(self, x, attention_mask=None):
class Llama(nn.Module):
def __init__(self, vocab_size=32000, hidden_size=768, intermediate_size=3072, num_attention_heads=12,
num_key_value_heads=12,
- layer_id=0, n_block=12):
+ layer_id=0, n_block=12, use_flash=False):
super().__init__()
self.layer_id = layer_id
+ self.use_flash = use_flash
self.config = DotDict(
model_type="llama",
vocab_size=vocab_size,
@@ -181,25 +183,25 @@ def __init__(self, vocab_size=32000, hidden_size=768, intermediate_size=3072, nu
use_cache=True,
torch_dtype="float32",
)
+
+ def _make_layers(n):
+ return nn.ModuleList([
+ LlamaDecoderLayer(hidden_size, num_attention_heads,
+ num_key_value_heads, intermediate_size,
+ use_flash=use_flash)
+ for _ in range(n)
+ ])
+
if self.layer_id == 1:
self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
- self.layers = nn.ModuleList([
- LlamaDecoderLayer(hidden_size, num_attention_heads, num_key_value_heads, intermediate_size) for _ in
- range(n_block)
- ])
+ self.layers = _make_layers(n_block)
elif self.layer_id == 2:
- self.layers = nn.ModuleList([
- LlamaDecoderLayer(hidden_size, num_attention_heads, num_key_value_heads, intermediate_size) for _ in
- range(n_block)
- ])
+ self.layers = _make_layers(n_block)
self.norm = LlamaRMSNorm(hidden_size)
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
else:
self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
- self.layers = nn.ModuleList([
- LlamaDecoderLayer(hidden_size, num_attention_heads, num_key_value_heads, intermediate_size) for _ in
- range(n_block)
- ])
+ self.layers = _make_layers(n_block)
self.norm = LlamaRMSNorm(hidden_size)
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
@@ -208,7 +210,9 @@ def forward(self, input_ids, attention_mask=None, **kwargs):
B, T = input_ids.shape
x = self.embed_tokens(input_ids)
- masks = build_causal_mask(B, T, attention_mask.device)
+ # Fix: fall back to input_ids.device when attention_mask is None
+ ref_device = attention_mask.device if attention_mask is not None else input_ids.device
+ masks = build_causal_mask(B, T, ref_device)
if attention_mask is not None:
key_mask = attention_mask[:, None, None, :].to(masks.dtype)
qry_mask = attention_mask[:, None, :, None].to(masks.dtype)
@@ -229,7 +233,8 @@ def forward(self, input_ids, attention_mask=None, **kwargs):
B, T = input_ids.shape
x = self.embed_tokens(input_ids)
- masks = build_causal_mask(B, T, attention_mask.device)
+ ref_device = attention_mask.device if attention_mask is not None else input_ids.device
+ masks = build_causal_mask(B, T, ref_device)
if attention_mask is not None:
key_mask = attention_mask[:, None, None, :].to(masks.dtype)
qry_mask = attention_mask[:, None, :, None].to(masks.dtype)
@@ -246,11 +251,18 @@ def forward(self, input_ids, attention_mask=None, **kwargs):
     def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
+        """Assemble the keyword arguments for one generation step.
+
+        Builds a (B, 1, T, T) lower-triangular causal mask, multiplies in the
+        optional padding mask on both the key and the query axis, and returns
+        it as ``attention_mask`` together with ``input_ids`` and ``kwargs``.
+        """
         B, T = input_ids.shape
         device = input_ids.device
-        causal = torch.ones(T, T, device=device).tril().unsqueeze(0).unsqueeze(1).expand(B, 1, T, T)
+        # expand() only returns a view; contiguous() materialises the mask.
+        causal = torch.ones(T, T, device=device).tril().unsqueeze(0).unsqueeze(1).expand(B, 1, T, T).contiguous()
         if attention_mask is not None:
             key_mask = attention_mask[:, None, None, :].to(causal.dtype)
             qry_mask = attention_mask[:, None, :, None].to(causal.dtype)
             masks = causal * key_mask * qry_mask
         else:
             masks = causal
         return {"input_ids": input_ids, "attention_mask": masks, **kwargs}
diff --git a/src/model/__pycache__/Bert.cpython-313.pyc b/src/model/__pycache__/Bert.cpython-313.pyc
new file mode 100644
index 0000000..279f3da
Binary files /dev/null and b/src/model/__pycache__/Bert.cpython-313.pyc differ
diff --git a/src/model/__pycache__/GPT2.cpython-310.pyc b/src/model/__pycache__/GPT2.cpython-310.pyc
new file mode 100644
index 0000000..23c41c5
Binary files /dev/null and b/src/model/__pycache__/GPT2.cpython-310.pyc differ
diff --git a/src/model/__pycache__/GPT2.cpython-313.pyc b/src/model/__pycache__/GPT2.cpython-313.pyc
new file mode 100644
index 0000000..eb3dfec
Binary files /dev/null and b/src/model/__pycache__/GPT2.cpython-313.pyc differ
diff --git a/src/model/__pycache__/Llama.cpython-313.pyc b/src/model/__pycache__/Llama.cpython-313.pyc
new file mode 100644
index 0000000..d29f953
Binary files /dev/null and b/src/model/__pycache__/Llama.cpython-313.pyc differ
diff --git a/src/model/__pycache__/__init__.cpython-310.pyc b/src/model/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000..dcb7fd2
Binary files /dev/null and b/src/model/__pycache__/__init__.cpython-310.pyc differ
diff --git a/src/model/__pycache__/__init__.cpython-313.pyc b/src/model/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000..5412035
Binary files /dev/null and b/src/model/__pycache__/__init__.cpython-313.pyc differ
diff --git a/src/val/Bert.py b/src/val/Bert.py
index 48b0d40..f048c52 100644
--- a/src/val/Bert.py
+++ b/src/val/Bert.py
@@ -9,7 +9,7 @@ def val_Bert(model_name, data_name, state_dict_full, logger):
criterion = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- test_loader = dataloader(model_name==model_name, data_name=data_name, train=False)
+ test_loader = dataloader(model_name=model_name, data_name=data_name, train=False)
model = Bert()
model = model.to(device)
model.load_state_dict(state_dict_full)
@@ -35,6 +35,8 @@ def val_Bert(model_name, data_name, state_dict_full, logger):
logger.log_info(f"Test Loss: {avg_loss:.2f}; Test Acc: {acc:.2f}")
+ return True
+
diff --git a/src/val/GPT2.py b/src/val/GPT2.py
index 5973ec7..4e3de3c 100644
--- a/src/val/GPT2.py
+++ b/src/val/GPT2.py
@@ -1,3 +1,4 @@
+import os
import torch
import torch.nn as nn
from tqdm import tqdm
@@ -7,96 +8,163 @@
import re
-def extract_final_number(s: str) -> str:
+# Metrics for E2E (requires: pip install sacrebleu rouge-score)
+try:
+ from sacrebleu import corpus_bleu
+ from rouge_score import rouge_scorer as rouge_scorer_lib
+ BLEU_AVAILABLE = True
+except ImportError:
+ BLEU_AVAILABLE = False
+
+def extract_final_number(s: str) -> str:
+ """Return the number after a "####" marker in s, else the last number in s, else ""."""
+ # None is treated like an empty answer.
if s is None:
return ""
-
+ # Preferred form: an optionally signed integer/decimal right after "####".
m = re.search(r"####\s*([\-+]?\d+(?:\.\d+)?)", s)
if m:
return m.group(1).strip()
-
+ # Fallback: the last number that appears anywhere in the text.
nums = re.findall(r"[\-+]?\d+(?:\.\d+)?", s)
if nums:
return nums[-1].strip()
return ""
+
+def _greedy_generate(model, prompt_ids, prompt_mask, max_new_tokens, pad_id, device):
+ """Greedily decode up to max_new_tokens token ids after the given prompt.
+
+ The growing sequence and its mask are re-fed to the model on every step and
+ the arg-max token at the last position of batch row 0 is appended. Decoding
+ stops once pad_id is produced; that token is included in the returned list.
+
+ NOTE(review): the model output stored under "hidden_states" is indexed as
+ per-token vocabulary scores here - confirm against GPT2.forward.
+ """
+ cur_ids = prompt_ids.clone()
+ cur_mask = prompt_mask.clone()
+ generated = []
+
+ for _ in range(max_new_tokens):
+ out = model(input_ids=cur_ids, attention_mask=cur_mask)
+ logits = out["hidden_states"]
+ # Only batch row 0 is read, so callers pass one prompt at a time.
+ next_token = logits[0, -1, :].argmax(-1).item()
+ generated.append(next_token)
+
+ if next_token == pad_id:
+ break
+
+ # Extend ids and mask by one position for the next step.
+ next_tensor = torch.tensor([[next_token]], device=device)
+ next_mask = torch.ones((1, 1), device=device, dtype=cur_mask.dtype)
+ cur_ids = torch.cat([cur_ids, next_tensor], dim=1)
+ cur_mask = torch.cat([cur_mask, next_mask], dim=1)
+
+ return generated
+
+
def val_GPT2(model_name, data_name, state_dict_full, logger):
+ """Evaluate GPT2 on the test split and return the evaluated model's state dict.
+
+ Weights are loaded in two passes with strict=False: first from
+ "<model_name>.pt" when that file exists, then from state_dict_full when it
+ is non-empty. For every batch the summed cross-entropy over non-pad labels
+ is accumulated, and each sample is additionally decoded greedily from a
+ prompt prefix. Samples are scored with BLEU / ROUGE-L when data_name is
+ 'E2E' and the metric packages imported, otherwise by exact match of the
+ final number.
+ """
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Eval device:", device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
-
pad_id = tokenizer.pad_token_id
- loss_fct = nn.CrossEntropyLoss(
- ignore_index=pad_id,
- )
+ # pad_id is the EOS id (pad_token is set to eos_token above), so EOS labels are ignored too.
+ loss_fct = nn.CrossEntropyLoss(ignore_index=pad_id, reduction='sum')
test_loader = dataloader(model_name=model_name, data_name=data_name, train=False)
model = GPT2()
- model.load_state_dict(state_dict_full)
+
+ # Load the base checkpoint first; the received weights are overlaid below.
+ pretrained_path = f"{model_name}.pt"
+ if os.path.exists(pretrained_path):
+ # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
+ base_state = torch.load(pretrained_path, map_location="cpu")
+ missing, unexpected = model.load_state_dict(base_state, strict=False)
+ if missing:
+ logger.log_info(f"[val_GPT2] Base load missing keys: {missing}")
+ if unexpected:
+ logger.log_info(f"[val_GPT2] Base load unexpected keys: {unexpected}")
+ else:
+ logger.log_warning(
+ f"[val_GPT2] Pretrained file '{pretrained_path}' not found. "
+ f"Validating with random base weights."
+ )
+
+ # strict=False tolerates a state dict that covers only part of the model.
+ if state_dict_full:
+ missing, unexpected = model.load_state_dict(state_dict_full, strict=False)
+ if missing:
+ logger.log_info(f"[val_GPT2] LoRA overlay missing keys: {missing}")
+ if unexpected:
+ logger.log_info(f"[val_GPT2] LoRA overlay unexpected keys: {unexpected}")
+
model = model.to(device)
model.eval()
- total_loss = 0.0
- total_tokens = 0
-
- total_samples = 0
+ total_loss = 0.0
+ total_tokens = 0
+ total_samples = 0
correct_samples = 0
+ # First token id the tokenizer produces for "\n"; used below to split prompt from target.
+ nl_id = tokenizer.encode("\n", add_special_tokens=False)[0]
+ is_e2e = (data_name == 'E2E')
+ hypotheses = []
+ references = []
+
with torch.no_grad():
for batch in tqdm(test_loader):
input_ids = batch['input_ids'].to(device)
- mask = batch['attention_mask'].to(device)
- labels = batch['labels'].to(device)
+ mask = batch['attention_mask'].to(device)
+ labels = batch['labels'].to(device)
- logits, _ = model(input_ids, mask)
+ out = model(input_ids, mask)
+ logits = out["hidden_states"]
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous()
-
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
- shift_labels.view(-1)
+ shift_labels.view(-1),
)
- total_loss += loss.item()
-
- predicted_tokens = torch.argmax(logits, dim=-1)
-
- generated_texts = [
- tokenizer.decode(p, skip_special_tokens=True).strip()
- for p in predicted_tokens
- ]
- reference_texts = [
- tokenizer.decode(l, skip_special_tokens=True).strip()
- for l in labels
- ]
-
- for gen, ref in zip(generated_texts, reference_texts):
- if not gen:
- gen = " "
- if not ref:
- ref = " "
- # print(f'gen : {gen}')
- # print(f'ref : {ref}')
- # break
- # break
-
- pred_num = extract_final_number(gen)
- gold_num = extract_final_number(ref)
+ # The loss is a sum, so count the non-pad labels to normalise per token at the end.
+ valid = (shift_labels != pad_id).sum().item()
+ total_loss += loss.item()
+ total_tokens += valid
+
+ for i in range(input_ids.size(0)):
+ tokens = input_ids[i].tolist()
+ # Prompt = tokens up to and including the first newline token, else the first half.
+ if nl_id in tokens:
+ cut = tokens.index(nl_id) + 1
+ else:
+ cut = len(tokens) // 2
+
+ prompt_ids = input_ids[i, :cut].unsqueeze(0)
+ prompt_mask = mask[i, :cut].unsqueeze(0)
+
+ generated = _greedy_generate(
+ model, prompt_ids, prompt_mask,
+ max_new_tokens=64, pad_id=pad_id, device=device
+ )
+
+ gen_text = tokenizer.decode(generated, skip_special_tokens=True).strip()
+ ref_text = tokenizer.decode(
+ labels[i][labels[i] != pad_id], skip_special_tokens=True
+ ).strip()
total_samples += 1
- if gold_num and pred_num == gold_num:
- correct_samples += 1
-
- avg_loss = total_loss / max(total_samples, 1)
- accuracy = correct_samples / max(total_samples, 1)
-
- print(f"Loss / token : {avg_loss:.4f}; Accuracy (answer) : {accuracy * 100:.2f}%")
- logger.log_info(
- f"Loss / token : {avg_loss:.4f}; Accuracy (answer) : {accuracy * 100:.2f}%"
- )
+ if is_e2e:
+ # E2E: scored with BLEU/ROUGE
+ hypotheses.append(gen_text)
+ references.append(ref_text)
+ else:
+ # GSM8K: exact match on the final number
+ pred_num = extract_final_number(gen_text)
+ gold_num = extract_final_number(ref_text)
+ if gold_num and pred_num == gold_num:
+ correct_samples += 1
+
+ avg_loss = total_loss / max(total_tokens, 1)
+
+ # NOTE(review): for E2E without the metric packages this falls through to the
+ # accuracy branch, where correct_samples is never incremented (reports 0%).
+ if is_e2e and BLEU_AVAILABLE and hypotheses:
+ bleu = corpus_bleu(hypotheses, [references]).score
+ scorer = rouge_scorer_lib.RougeScorer(["rougeL"], use_stemmer=True)
+ rouge_l = sum(
+ scorer.score(r, h)["rougeL"].fmeasure
+ for h, r in zip(hypotheses, references)
+ ) / len(hypotheses)
+ print(f"Loss / token : {avg_loss:.4f}; BLEU: {bleu:.2f}; ROUGE-L: {rouge_l:.4f}")
+ logger.log_info(f"Loss / token : {avg_loss:.4f}; BLEU: {bleu:.2f}; ROUGE-L: {rouge_l:.4f}")
+ else:
+ accuracy = correct_samples / max(total_samples, 1)
+ print(f"Loss / token : {avg_loss:.4f}; Accuracy (answer) : {accuracy * 100:.2f}%")
+ logger.log_info(f"Loss / token : {avg_loss:.4f}; Accuracy (answer) : {accuracy * 100:.2f}%")
+
+ return model.state_dict()
\ No newline at end of file
diff --git a/src/val/Llama.py b/src/val/Llama.py
index c62caec..4c6da23 100644
--- a/src/val/Llama.py
+++ b/src/val/Llama.py
@@ -14,10 +14,11 @@ def val_Llama(model_name, data_name, state_dict_full, logger):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Eval device:", device)
- smooth = SmoothingFunction().method1
- scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
+ smooth = SmoothingFunction().method1
+ scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
tokenizer = AutoTokenizer.from_pretrained("JackFram/llama-160m")
tokenizer.pad_token = tokenizer.eos_token
+ pad_id = tokenizer.pad_token_id
test_loader = dataloader(model_name=model_name, data_name=data_name, train=False)
@@ -26,74 +27,70 @@ def val_Llama(model_name, data_name, state_dict_full, logger):
model = model.to(device)
model.eval()
- total_bleu = 0.0
- total_rouge = 0.0
- total_perplexity = 0.0
+ # Fix: create the criterion once outside the loop (avoids building a new object for every batch)
+ loss_fct = nn.CrossEntropyLoss(ignore_index=pad_id, reduction='sum')
+
+ total_bleu = 0.0
+ total_rouge = 0.0
+ total_log_loss = 0.0 # Fix: tích lũy log-loss thay vì perplexity trực tiếp
+ total_valid_tok = 0 # Fix: đếm token hợp lệ để avg chính xác
count = 0
with torch.no_grad():
for batch in tqdm(test_loader):
input_ids = batch['input_ids'].to(device)
- mask = batch['attention_mask'].to(device)
- labels = batch['input_ids'].to(device)
+ mask = batch['attention_mask'].to(device)
+ labels = batch['input_ids'].to(device) # causal LM: labels = input
logits, _ = model(input_ids, mask)
predicted_tokens = torch.argmax(logits, dim=-1)
-
- generated_texts = [
+ generated_texts = [
tokenizer.decode(p, skip_special_tokens=True).strip()
for p in predicted_tokens
]
- reference_texts = [
+ reference_texts = [
tokenizer.decode(l, skip_special_tokens=True).strip()
for l in labels
]
for gen, ref in zip(generated_texts, reference_texts):
- if not gen:
- gen = " "
- if not ref:
- ref = " "
-
- bleu_score = sentence_bleu(
- [ref.split()],
- gen.split(),
- smoothing_function=smooth
- )
+ gen = gen or " "
+ ref = ref or " "
+ bleu = sentence_bleu([ref.split()], gen.split(),
+ smoothing_function=smooth)
try:
rouge_l = scorer.score(ref, gen)["rougeL"].fmeasure
- except:
+ except Exception:
rouge_l = 0.0
- total_bleu += bleu_score
+ total_bleu += bleu
total_rouge += rouge_l
count += 1
+ # Fix: perplexity = exp(avg_cross_entropy_per_token)
+ # Accumulate the total cross-entropy and the total token count; apply exp once at the end
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
- loss_fct = nn.CrossEntropyLoss(
- ignore_index=tokenizer.pad_token_id,
- reduction='sum'
- )
-
- loss = loss_fct(
+ loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
- shift_labels.view(-1)
+ shift_labels.view(-1),
)
+ valid_tokens = (shift_labels != pad_id).sum().item()
+ total_log_loss += loss.item()
+ total_valid_tok += max(valid_tokens, 1)
- valid_tokens = (shift_labels != tokenizer.pad_token_id).sum().item()
- perplexity = math.exp(loss.item() / max(1, valid_tokens))
- total_perplexity += perplexity
+ # Fix: compute perplexity from the average log-loss (avoids exp overflow per batch)
+ avg_log_loss = total_log_loss / max(total_valid_tok, 1)
+ # Clamp to avoid math overflow in exp()
+ avg_perplexity = math.exp(min(avg_log_loss, 20.0))
- avg_bleu = total_bleu / count
- avg_rouge = total_rouge / count
- avg_perplexity = total_perplexity / count
+ avg_bleu = total_bleu / max(count, 1)
+ avg_rouge = total_rouge / max(count, 1)
print(f"Evaluation Results: BLEU: {avg_bleu:.4f}, ROUGE-L: {avg_rouge:.4f}, Perplexity: {avg_perplexity:.4f}")
-
logger.log_info(
f"Evaluation Results: BLEU: {avg_bleu:.4f}, ROUGE-L: {avg_rouge:.4f}, Perplexity: {avg_perplexity:.4f}"
)
diff --git a/src/val/__pycache__/Bert.cpython-313.pyc b/src/val/__pycache__/Bert.cpython-313.pyc
new file mode 100644
index 0000000..f22348b
Binary files /dev/null and b/src/val/__pycache__/Bert.cpython-313.pyc differ
diff --git a/src/val/__pycache__/GPT2.cpython-313.pyc b/src/val/__pycache__/GPT2.cpython-313.pyc
new file mode 100644
index 0000000..2b4277b
Binary files /dev/null and b/src/val/__pycache__/GPT2.cpython-313.pyc differ
diff --git a/src/val/__pycache__/Llama.cpython-313.pyc b/src/val/__pycache__/Llama.cpython-313.pyc
new file mode 100644
index 0000000..9d068d7
Binary files /dev/null and b/src/val/__pycache__/Llama.cpython-313.pyc differ
diff --git a/src/val/__pycache__/__init__.cpython-313.pyc b/src/val/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000..3cc69da
Binary files /dev/null and b/src/val/__pycache__/__init__.cpython-313.pyc differ
diff --git a/src/val/__pycache__/get_val.cpython-313.pyc b/src/val/__pycache__/get_val.cpython-313.pyc
new file mode 100644
index 0000000..54c88b2
Binary files /dev/null and b/src/val/__pycache__/get_val.cpython-313.pyc differ
diff --git a/src/val/get_val.py b/src/val/get_val.py
index 700c749..9bfc08c 100644
--- a/src/val/get_val.py
+++ b/src/val/get_val.py
@@ -2,11 +2,22 @@
from src.val.Llama import val_Llama
from src.val.Bert import val_Bert
+
def get_val(model_name, data_name, state_dict_full, logger):
- if model_name == 'GPT2':
- val_GPT2(model_name, data_name, state_dict_full, logger)
- elif model_name == 'Llama':
- val_Llama(model_name, data_name, state_dict_full, logger)
- elif model_name == 'Bert':
- val_Bert(model_name, data_name, state_dict_full, logger)
- return True
+ """
+ Run validation for the given model and return the validator's result,
+ or None when the model name is unknown or validation raises.
+
+ Original author's note: the server stores the returned full state dict
+ in GPT2.pt for the next round.
+ NOTE(review): only val_GPT2 visibly returns a state dict; val_Bert
+ returns True. Confirm what callers expect for non-GPT2 models.
+ """
+ try:
+ if model_name == 'GPT2':
+ return val_GPT2(model_name, data_name, state_dict_full, logger)
+ elif model_name == 'Llama':
+ return val_Llama(model_name, data_name, state_dict_full, logger)
+ elif model_name == 'Bert':
+ return val_Bert(model_name, data_name, state_dict_full, logger)
+ else:
+ logger.log_warning(f"Unknown model_name '{model_name}' for validation.")
+ return None
+ # Any validator failure is logged (message only) and reported to the caller as None.
+ except Exception as e:
+ logger.log_error(f"Validation failed with exception: {e}")
+ return None
\ No newline at end of file
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..3c58111
--- /dev/null
+++ b/test.py
@@ -0,0 +1,177 @@
+"""
+test.py — Evaluate GPT2 checkpoints after training SplitFedLLM
+
+Usage:
+ python test.py # test GPT2.pt
+ python test.py --model GPT2_round3.pt
+ python test.py --model GPT2.pt --device cpu
+"""
+import argparse
+import os
+import torch
+import torch.nn as nn
+from transformers import GPT2Tokenizer
+from src.model.GPT2 import GPT2
+
+# ── Args ──────────────────────────────────────────────────────────────────────
+# NOTE: the arguments are parsed at module level, so importing this file reads sys.argv.
+parser = argparse.ArgumentParser()
+parser.add_argument("--model", default="GPT2.pt", help="Path to .pt checkpoint")
+parser.add_argument("--device", default=None, help="cpu / cuda (auto-detect nếu bỏ qua)")
+parser.add_argument("--tokens", default=80, type=int, help="Max new tokens khi generate")
+args = parser.parse_args()
+
+# An explicit --device wins; otherwise CUDA is used when available, else CPU.
+DEVICE = torch.device(
+ args.device if args.device
+ else ("cuda" if torch.cuda.is_available() else "cpu")
+)
+print(f"[INFO] Device: {DEVICE}")
+
+# Prompts passed to generate() in main(); their outputs are printed, not scored.
+TEST_PROMPTS = [
+ " name[The Golden Curry], food[Fast food], customer rating[low], area[riverside], familyFriendly[yes], near[Café Rouge] ",
+ " name[Fitzbillies], eatType[coffee shop], food[French], priceRange[£20-25], customer rating[3 out of 5] ",
+ " name[The Twenty Two], eatType[restaurant], food[Italian], familyFriendly[no] ",
+ " name[Cotto], eatType[coffee shop], food[Indian], priceRange[moderate], area[riverside], near[The Portland Arms] ",
+ " name[Giraffe], eatType[pub], food[Fast food], area[city centre], familyFriendly[no] ",
+]
+
+# (prompt, reference) pairs passed to compute_loss() in main().
+EVAL_PAIRS = [
+ (
+ " name[The Golden Curry], food[Fast food] ",
+ "The Golden Curry is a fast food place with a low customer rating located near Café Rouge."
+ ),
+ (
+ " name[Fitzbillies], eatType[coffee shop], food[French] ",
+ "Fitzbillies is a French coffee shop with a customer rating of 3 out of 5."
+ ),
+]
+
+# ── Load model ────────────────────────────────────────────────────────────────
+def load_model(path: str):
+ """Load a GPT2 checkpoint from path and return the model in eval mode on DEVICE.
+
+ The state dict is loaded with strict=False; only the number of missing and
+ unexpected keys is printed. Raises FileNotFoundError when path does not exist.
+ """
+ if not os.path.exists(path):
+ raise FileNotFoundError(f"Không tìm thấy: {path}")
+
+ print(f"\n[INFO] Loading {path} ({os.path.getsize(path)/1e6:.1f} MB)")
+ # NOTE(review): torch.load unpickles the file - only load checkpoints from a
+ # trusted source (consider weights_only=True).
+ sd = torch.load(path, map_location="cpu")
+
+ model = GPT2()
+ missing, unexpected = model.load_state_dict(sd, strict=False)
+ print(f"[DEBUG] Missing keys : {len(missing)}")
+ print(f"[DEBUG] Unexpected keys: {len(unexpected)}")
+
+ # Point lm_head at the token-embedding weight when the model has both attributes.
+ if hasattr(model, "lm_head") and hasattr(model, "wte"):
+ model.lm_head.weight = model.wte.weight
+ print("[OK] Weight tying applied")
+
+ model.to(DEVICE)
+ model.eval()
+ return model
+
+# ── Generate ──────────────────────────────────────────────────────────────────
+def generate(model, tokenizer, prompt: str, max_new_tokens: int = 80) -> str:
+    """Sample a continuation of *prompt* with top-k sampling and repetition penalties.
+
+    On every step the whole sequence is fed to *model*, the scores of the last
+    position are penalised for tokens already in the sequence, divided by a
+    temperature of 0.7, and the next token is drawn from the 40 best
+    candidates. Generation stops at the tokenizer's EOS token or after
+    *max_new_tokens* sampled tokens.
+
+    Args:
+        model: Callable returning a mapping with a ``"logits"`` entry when
+            invoked as ``model(input_ids=ids)``.
+        tokenizer: Object providing ``encode``, ``decode`` and ``eos_token_id``.
+        prompt: Text to condition on.
+        max_new_tokens: Upper bound on the number of sampled tokens.
+
+    Returns:
+        The decoded continuation (prompt excluded), stripped of surrounding
+        whitespace.
+    """
+
+    def _penalize(scores, token_ids, penalty):
+        # Sign-aware penalty: positive scores are divided, negative scores are
+        # multiplied, so a seen token always becomes less likely. Dividing a
+        # negative score would move it towards zero and favour repetition.
+        idx = torch.tensor(sorted(set(token_ids)), device=scores.device)
+        vals = scores[0, idx]
+        scores[0, idx] = torch.where(vals < 0, vals * penalty, vals / penalty)
+
+    ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
+    generated = []
+
+    for _ in range(max_new_tokens):
+        with torch.no_grad():
+            out = model(input_ids=ids)
+        logits = out["logits"][:, -1, :].clone()  # scores for the next token
+
+        # Penalise every token seen so far (factor 1.5) ...
+        _penalize(logits, ids[0].tolist(), 1.5)
+        # ... and once more the tokens among the last 20 positions (factor 1.3).
+        _penalize(logits, ids[0, -20:].tolist(), 1.3)
+
+        # Temperature scaling, then sample among the 40 highest-scoring tokens.
+        logits /= 0.7
+        topk_vals, topk_idx = torch.topk(logits, 40)
+        probs = torch.softmax(topk_vals, dim=-1)
+        next_token = topk_idx[0, torch.multinomial(probs, 1)]
+
+        if next_token.item() == tokenizer.eos_token_id:
+            break
+
+        generated.append(next_token.item())
+        ids = torch.cat([ids, next_token.view(1, 1)], dim=1)
+
+    return tokenizer.decode(generated, skip_special_tokens=True).strip()
+
+# ── Loss / Perplexity ─────────────────────────────────────────────────────────
+def compute_loss(model, tokenizer, pairs: list) -> tuple:
+    """Return the mean per-token loss and perplexity over the reference texts.
+
+    Each ``(prompt, reference)`` pair is scored in one forward pass over
+    ``prompt + " " + reference``. Label positions belonging to the prompt are
+    set to ``-100`` so that only reference tokens contribute to the loss.
+
+    NOTE(review): the mask width is the token count of the prompt encoded on
+    its own, which assumes that encoding is a prefix of the encoding of the
+    full text - confirm for the tokenizer in use.
+
+    Args:
+        model: Callable returning a mapping with a ``"logits"`` entry when
+            invoked as ``model(input_ids=ids)``.
+        tokenizer: Object providing ``encode``.
+        pairs: List of ``(prompt, reference)`` string tuples.
+
+    Returns:
+        ``(avg_loss, ppl)``: the summed cross-entropy divided by the number of
+        scored tokens, and its exponential.
+    """
+    criterion = nn.CrossEntropyLoss(ignore_index=-100, reduction="sum")
+    total_loss = 0.0
+    total_tokens = 0
+
+    with torch.no_grad():
+        for prompt, ref in pairs:
+            full_text = prompt + " " + ref
+            full_ids = tokenizer.encode(full_text, return_tensors="pt").to(DEVICE)
+            prompt_len = len(tokenizer.encode(prompt))
+
+            out = model(input_ids=full_ids)
+            logits = out["logits"]
+
+            # Position i predicts token i + 1, hence the one-step shift.
+            shift_logits = logits[:, :-1, :].contiguous()
+            shift_labels = full_ids[:, 1:].contiguous().clone()
+
+            # Mask the prompt part. The width is clamped at 0: for an empty
+            # prompt, ``:prompt_len - 1`` would be ``:-1`` and hide every
+            # label except the last one.
+            n_masked = max(prompt_len - 1, 0)
+            shift_labels[:, :n_masked] = -100
+
+            loss = criterion(
+                shift_logits.view(-1, shift_logits.size(-1)),
+                shift_labels.view(-1),
+            )
+            n_ref_tokens = (shift_labels != -100).sum().item()
+            total_loss += loss.item()
+            total_tokens += n_ref_tokens
+
+    avg_loss = total_loss / max(total_tokens, 1)
+    ppl = torch.exp(torch.tensor(avg_loss)).item()
+    return avg_loss, ppl
+
+# ── Sanity check ──────────────────────────────────────────────────────────────
+def sanity_check(model, tokenizer):
+ """Run one forward pass on "The restaurant" and print the 5 highest-scoring next tokens.
+
+ Raises RuntimeError when the model output has no "logits" entry.
+ """
+ print("\n[Sanity check]")
+ ids = tokenizer.encode("The restaurant", return_tensors="pt").to(DEVICE)
+ with torch.no_grad():
+ out = model(input_ids=ids)
+ print("[DEBUG] Output keys:", list(out.keys()))
+ # Fail early: generate() and compute_loss() both read out["logits"].
+ if "logits" not in out:
+ raise RuntimeError("Model không trả về 'logits' → kiểm tra GPT2.forward()")
+ topk = torch.topk(out["logits"][0, -1], 5)
+ print("[DEBUG] Top next tokens:", [tokenizer.decode([i]) for i in topk.indices.tolist()])
+
+# ── Main ──────────────────────────────────────────────────────────────────────
+def main():
+ """Load the checkpoint given by --model, sanity-check it, print samples and report loss/perplexity."""
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ # Reuse the EOS token as the pad token.
+ tokenizer.pad_token = tokenizer.eos_token
+
+ model = load_model(args.model)
+ sanity_check(model, tokenizer)
+
+ # ── Generation ────────────────────────────────────────────────────────────
+ print("\n" + "=" * 60)
+ for prompt in TEST_PROMPTS:
+ output = generate(model, tokenizer, prompt, max_new_tokens=args.tokens)
+ print(f"\nPROMPT:\n{prompt}")
+ print(f"OUTPUT:\n{output}")
+ print("-" * 60)
+
+ # ── Loss / Perplexity ─────────────────────────────────────────────────────
+ loss, ppl = compute_loss(model, tokenizer, EVAL_PAIRS)
+ print("\n" + "=" * 60)
+ print(f"[METRIC] Loss: {loss:.4f} | Perplexity: {ppl:.2f}")
+ print("=" * 60)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file