Conversation
| self, | ||
| model_path, | ||
| device=None, | ||
| dtype="float16", |
There was a problem hiding this comment.
为什么要指定数据类型,增加这个参数,默认的类型不能跑么
There was a problem hiding this comment.
为什么要指定数据类型,增加这个参数,默认的类型不能跑么
openai-community/gpt2 的 config.json 中没有 dtype 相关配置,无法通过读取 HF 配置自动获取数据类型,因此需要在测试命令中显式指定 dtype。如果不指定,当前代码无法确定模型应使用的数据类型进行加载。
There was a problem hiding this comment.
gpt2跑了,其他模型能跑么。 dtype=这个参数需要再斟酌。
接口的优先级更高,但下面的实现却是config.json的dtype优先级高
if self.hf_config.get("torch_dtype") is None and self.hf_config.get("dtype") is None:
self.hf_config["torch_dtype"] = dtype
There was a problem hiding this comment.
gpt2跑了,其他模型能跑么。 dtype=这个参数需要再斟酌。
接口的优先级更高,但下面的实现却是config.json的dtype优先级高 if self.hf_config.get("torch_dtype") is None and self.hf_config.get("dtype") is None: self.hf_config["torch_dtype"] = dtype
在 read_hf_config 通过判断模型是否为 gpt2,如果为 gpt2 增加 torch_dtype 为 fp32,进而不再引入 dtype 相关参数
|
|
||
| for k in f.keys(): | ||
| state_dict[k] = f.get_tensor(k).to(device=device) | ||
| state_dict[k] = f.get_tensor(k).to(device=device, dtype=dtype) |
There was a problem hiding this comment.
添加.to(dtype=dtype)的话,轶群的量化模型可能就不能跑了
There was a problem hiding this comment.
添加.to(dtype=dtype)的话,轶群的量化模型可能就不能跑了
已通过其他方式规避
|
给出tp=2的测试截图 |
| outputs = model.chat( | ||
| messages=conversations, | ||
| ) | ||
| if getattr(model.engine.tokenizer, "chat_template", None): |
| sampling_params = self._build_sampling_params(data) | ||
|
|
||
| req = self.engine.add_chat_request( | ||
| req = self._add_generation_request( |
There was a problem hiding this comment.
需要再斟酌一下,或许不应该这么修改
gpt2 不支持 chat_template,是需要和原有服务走的路径形成区分。如果不这样修改的话,怎么修改比较合适?


InfiniLM 支持 GPT 2,测试截图如下
InfiniLM 推理与Transformer 推理对比测试截图:

备注:
openai-community_gpt2 存储的数据类型为 fp32,paged_attention 相关算子只支持 bf16/fp16,后面的测试均基于静态 cache 进行;
服务端启动命令参考:
CUDA_VISIBLE_DEVICES=6,7 python python/infinilm/server/inference_server.py
--model /data-aisoft/mechdancer/models/openai-community_gpt2
--device nvidia
--tp 2
--num-blocks 1024
--block-size 256
--max-batch-size 32
--max-new-tokens 512
--host 0.0.0.0
--port 8000
客户端发送请求参考:
curl --noproxy '*' -v -X POST http://127.0.0.1:8000/chat/completions
-H 'Content-Type: application/json'
-d '{
"prompt": "tell me a story",
"stream": false,
"max_tokens": 50
}'