Skip to content

Commit 83f9d95

Browse files
authored
Merge pull request #128 from sallyom/add-api-key
add option to pass 'api_key' to gen_answers, judge_answers
2 parents 6b3495b + 3445ce0 commit 83f9d95

4 files changed

Lines changed: 27 additions & 8 deletions

File tree

src/instructlab/eval/mt_bench.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,27 +94,30 @@ class MTBenchEvaluator(AbstractMTBenchEvaluator):
9494

9595
name = "mt_bench"
9696

97-
def gen_answers(self, server_url) -> None:
97+
def gen_answers(self, server_url, api_key: str | None = None) -> None:
9898
"""
9999
Asks questions to model
100100
101101
Attributes
102102
server_url Model server endpoint (Ex: http://localhost:8000/v1) for the model being evaluated
103+
api_key API token for authenticating with model server
103104
"""
104105
logger.debug(locals())
105106
mt_bench_answers.generate_answers(
106107
self.model_name,
107108
server_url,
109+
api_key=api_key,
108110
output_dir=self.output_dir,
109111
max_workers=self.max_workers,
110112
)
111113

112-
def judge_answers(self, server_url) -> tuple:
114+
def judge_answers(self, server_url, api_key: str | None = None) -> tuple:
113115
"""
114116
Runs MT-Bench judgment
115117
116118
Attributes
117119
server_url Model server endpoint (Ex: http://localhost:8000/v1) for the judge model
120+
api_key API token for authenticating with model server
118121
119122
Returns:
120123
overall_score MT-Bench score for the overall model evaluation
@@ -126,6 +129,7 @@ def judge_answers(self, server_url) -> tuple:
126129
self.model_name,
127130
self.judge_model_name,
128131
server_url,
132+
api_key=api_key,
129133
max_workers=self.max_workers,
130134
output_dir=self.output_dir,
131135
merge_system_user_message=self.merge_system_user_message,
@@ -171,12 +175,13 @@ def __init__(
171175
self.taxonomy_git_repo_path = taxonomy_git_repo_path
172176
self.branch = branch
173177

174-
def gen_answers(self, server_url) -> None:
178+
def gen_answers(self, server_url, api_key: str | None = None) -> None:
175179
"""
176180
Asks questions to model
177181
178182
Attributes
179183
server_url Model server endpoint (Ex: http://localhost:8000/v1) for the model being evaluated
184+
api_key API token for authenticating with model server
180185
"""
181186
logger.debug(locals())
182187
mt_bench_branch_generator.generate(
@@ -188,19 +193,21 @@ def gen_answers(self, server_url) -> None:
188193
mt_bench_answers.generate_answers(
189194
self.model_name,
190195
server_url,
196+
api_key=api_key,
191197
branch=self.branch,
192198
output_dir=self.output_dir,
193199
data_dir=self.output_dir,
194200
max_workers=self.max_workers,
195201
bench_name="mt_bench_branch",
196202
)
197203

198-
def judge_answers(self, server_url) -> tuple:
204+
def judge_answers(self, server_url, api_key: str | None = None) -> tuple:
199205
"""
200206
Runs MT-Bench-Branch judgment. Judgments can be compared across runs with consistent question_id -> qna file name.
201207
202208
Attributes
203209
server_url Model server endpoint (Ex: http://localhost:8000/v1) for the judge model
210+
api_key API token for authenticating with model server
204211
205212
Returns:
206213
qa_pairs Question and answer pairs (with scores) from the evaluation
@@ -210,6 +217,7 @@ def judge_answers(self, server_url) -> tuple:
210217
self.model_name,
211218
self.judge_model_name,
212219
server_url,
220+
api_key=api_key,
213221
branch=self.branch,
214222
max_workers=self.max_workers,
215223
output_dir=self.output_dir,

src/instructlab/eval/mt_bench_answers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
# Third Party
99
# TODO need to look into this dependency
1010
from fastchat.model.model_adapter import get_conversation_template # type: ignore
11-
import openai
1211
import shortuuid
1312
import tqdm
1413

@@ -17,6 +16,7 @@
1716
from .mt_bench_common import (
1817
bench_dir,
1918
chat_completion_openai,
19+
get_openai_client,
2020
load_questions,
2121
temperature_config,
2222
)
@@ -98,6 +98,7 @@ def get_answer(
9898
def generate_answers(
9999
model_name,
100100
model_api_base,
101+
api_key=None,
101102
branch=None,
102103
output_dir="eval_output",
103104
data_dir=None,
@@ -111,7 +112,8 @@ def generate_answers(
111112
):
112113
"""Generate model answers to be judged"""
113114
logger.debug(locals())
114-
openai_client = openai.OpenAI(base_url=model_api_base, api_key="NO_API_KEY")
115+
116+
openai_client = get_openai_client(model_api_base, api_key)
115117

116118
if data_dir is None:
117119
data_dir = os.path.join(os.path.dirname(__file__), "data")

src/instructlab/eval/mt_bench_common.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,3 +365,10 @@ def check_data(questions, model_answers, ref_answers, models, judges):
365365
def get_model_list(answer_file):
366366
logger.debug(locals())
367367
return [os.path.splitext(os.path.basename(answer_file))[0]]
368+
369+
370+
def get_openai_client(model_api_base, api_key):
    """Build an OpenAI client pointed at *model_api_base*.

    When *api_key* is None, the placeholder "NO_API_KEY" is used so the
    client can still be constructed for servers that do not require auth.
    """
    effective_key = api_key if api_key is not None else "NO_API_KEY"
    return openai.OpenAI(base_url=model_api_base, api_key=effective_key)

src/instructlab/eval/mt_bench_judgment.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# Third Party
77
from tqdm import tqdm
88
import numpy as np
9-
import openai
109
import pandas as pd
1110

1211
# Local
@@ -18,6 +17,7 @@
1817
bench_dir,
1918
check_data,
2019
get_model_list,
20+
get_openai_client,
2121
load_judge_prompts,
2222
load_model_answers,
2323
load_questions,
@@ -278,6 +278,7 @@ def generate_judgment(
278278
model_name,
279279
judge_model_name,
280280
model_api_base,
281+
api_key=None,
281282
bench_name="mt_bench",
282283
output_dir="eval_output",
283284
data_dir=None,
@@ -288,7 +289,8 @@ def generate_judgment(
288289
):
289290
"""Generate judgment with scores and qa_pairs for a model"""
290291
logger.debug(locals())
291-
openai_client = openai.OpenAI(base_url=model_api_base, api_key="NO_API_KEY")
292+
293+
openai_client = get_openai_client(model_api_base, api_key)
292294

293295
first_n_env = os.environ.get("INSTRUCTLAB_EVAL_FIRST_N_QUESTIONS")
294296
if first_n_env is not None and first_n is None:

0 commit comments

Comments
 (0)