
Commit 2793267

Merge branch 'feature/ollama-support' into develop

2 parents e62afda + d085169
9 files changed: 413 additions & 10 deletions

Core/Sources/CodeCompletionService/CodeCompletionService.swift

Lines changed: 27 additions & 3 deletions
@@ -3,9 +3,9 @@ import Fundamental
 import Storage
 
 protocol CodeCompletionServiceType {
-    func getCompletion(
-        _ request: PromptStrategy
-    ) async throws -> AsyncStream<String>
+    associatedtype CompletionSequence: AsyncSequence where CompletionSequence.Element == String
+
+    func getCompletion(_ request: PromptStrategy) async throws -> CompletionSequence
 }
 
 extension CodeCompletionServiceType {
@@ -115,6 +115,18 @@ public struct CodeCompletionService {
             let result = try await service.getCompletions(prompt, count: count)
             try Task.checkCancellation()
             return result
+        case .ollama:
+            let service = OllamaService(
+                url: model.endpoint,
+                endpoint: .chatCompletion,
+                modelName: model.info.modelName,
+                stopWords: prompt.stopWords,
+                keepAlive: model.info.ollamaKeepAlive,
+                format: .none
+            )
+            let result = try await service.getCompletions(prompt, count: count)
+            try Task.checkCancellation()
+            return result
         case .unknown:
             throw Error.unknownFormat
         }
@@ -150,6 +162,18 @@ public struct CodeCompletionService {
             let result = try await service.getCompletions(prompt, count: count)
             try Task.checkCancellation()
             return result
+        case .ollama:
+            let service = OllamaService(
+                url: model.endpoint,
+                endpoint: .completion,
+                modelName: model.info.modelName,
+                stopWords: prompt.stopWords,
+                keepAlive: model.info.ollamaKeepAlive,
+                format: .none
+            )
+            let result = try await service.getCompletions(prompt, count: count)
+            try Task.checkCancellation()
+            return result
         case .unknown:
             throw Error.unknownFormat
         }
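
The protocol change above is the substance of this file's diff: instead of erasing every stream to AsyncStream<String>, each conformer now names a concrete AsyncSequence, which lets OllamaService return its compactMap-ed response stream directly. A minimal sketch of the pattern, using hypothetical StreamingService and EchoService names that are not part of this commit:

import Foundation

// Hypothetical illustration of the new protocol shape, not part of the commit.
protocol StreamingService {
    associatedtype CompletionSequence: AsyncSequence where CompletionSequence.Element == String
    func getCompletion(_ prompt: String) async throws -> CompletionSequence
}

struct EchoService: StreamingService {
    // The concrete sequence type is spelled out, the way OllamaService
    // names AsyncThrowingCompactMapSequence below.
    typealias CompletionSequence = AsyncThrowingStream<String, Error>

    func getCompletion(_ prompt: String) async throws -> CompletionSequence {
        AsyncThrowingStream { continuation in
            continuation.yield(prompt) // echo a single chunk and finish
            continuation.finish()
        }
    }
}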
OllamaService.swift (new file)

Lines changed: 259 additions & 0 deletions
@@ -0,0 +1,259 @@
+import CopilotForXcodeKit
+import Foundation
+import Fundamental
+
+public actor OllamaService {
+    let url: URL
+    let endpoint: Endpoint
+    let modelName: String
+    let maxToken: Int
+    let temperature: Double
+    let stopWords: [String]
+    let keepAlive: String
+    let format: ResponseFormat
+
+    public enum ResponseFormat: String {
+        case none = ""
+        case json = "json"
+    }
+
+    public enum Endpoint {
+        case completion
+        case chatCompletion
+    }
+
+    init(
+        url: String? = nil,
+        endpoint: Endpoint,
+        modelName: String,
+        maxToken: Int? = nil,
+        temperature: Double = 0.2,
+        stopWords: [String] = [],
+        keepAlive: String = "",
+        format: ResponseFormat = .none
+    ) {
+        self.url = url.flatMap(URL.init(string:)) ?? {
+            switch endpoint {
+            case .chatCompletion:
+                URL(string: "http://127.0.0.1:11434/api/chat")!
+            case .completion:
+                URL(string: "http://127.0.0.1:11434/api/generate")!
+            }
+        }()
+
+        self.endpoint = endpoint
+        self.modelName = modelName
+        self.maxToken = maxToken ?? 4096
+        self.temperature = temperature
+        self.stopWords = stopWords
+        self.keepAlive = keepAlive
+        self.format = format
+    }
+}
+
+extension OllamaService: CodeCompletionServiceType {
+    typealias CompletionSequence = AsyncThrowingCompactMapSequence<
+        ResponseStream<OllamaService.ChatCompletionResponseChunk>,
+        String
+    >
+
+    func getCompletion(
+        _ request: PromptStrategy
+    ) async throws -> CompletionSequence {
+        switch endpoint {
+        case .chatCompletion:
+            let messages = createMessages(from: request)
+            CodeCompletionLogger.logger.logPrompt(messages.map {
+                ($0.content, $0.role.rawValue)
+            })
+            let stream = try await sendMessages(messages)
+            return stream.compactMap { $0.message?.content }
+        case .completion:
+            let prompt = createPrompt(from: request)
+            CodeCompletionLogger.logger.logPrompt([(prompt, "user")])
+            let stream = try await sendPrompt(prompt)
+            return stream.compactMap { $0.response }
+        }
+    }
+}
+
+extension OllamaService {
+    struct Message: Codable, Equatable {
+        public enum Role: String, Codable {
+            case user
+            case assistant
+            case system
+        }
+
+        /// The role of the message.
+        public var role: Role
+        /// The content of the message.
+        public var content: String
+    }
+
+    enum Error: Swift.Error, LocalizedError {
+        case decodeError(Swift.Error)
+        case otherError(String)
+
+        public var errorDescription: String? {
+            switch self {
+            case let .decodeError(error):
+                return error.localizedDescription
+            case let .otherError(message):
+                return message
+            }
+        }
+    }
+}
+
+// MARK: - Chat Completion API
+
+/// https://github.com/ollama/ollama/blob/main/docs/api.md#chat-request-streaming
+extension OllamaService {
+    struct ChatCompletionRequestBody: Codable {
+        struct Options: Codable {
+            var temperature: Double
+            var stop: [String]
+            var num_predict: Int
+            var top_k: Int?
+            var top_p: Double?
+        }
+
+        var model: String
+        var messages: [Message]
+        var stream: Bool
+        var options: Options
+        var keep_alive: String?
+        var format: String?
+    }
+
+    struct ChatCompletionResponseChunk: Decodable {
+        var model: String
+        var message: Message?
+        var response: String?
+        var done: Bool
+        var total_duration: Int64?
+        var load_duration: Int64?
+        var prompt_eval_count: Int?
+        var prompt_eval_duration: Int64?
+        var eval_count: Int?
+        var eval_duration: Int64?
+    }
+
+    func createMessages(from request: PromptStrategy) -> [Message] {
+        let strategy = DefaultTruncateStrategy(maxTokenLimit: max(
+            maxToken / 3 * 2,
+            maxToken - 300 - 20
+        ))
+        let prompts = strategy.createTruncatedPrompt(promptStrategy: request)
+        return [
+            .init(role: .system, content: request.systemPrompt),
+        ] + prompts.map { prompt in
+            switch prompt.role {
+            case .user:
+                return .init(role: .user, content: prompt.content)
+            case .assistant:
+                return .init(role: .assistant, content: prompt.content)
+            }
+        }
+    }
+
+    func sendMessages(_ messages: [Message]) async throws
+        -> ResponseStream<ChatCompletionResponseChunk>
+    {
+        let requestBody = ChatCompletionRequestBody(
+            model: modelName,
+            messages: messages,
+            stream: true,
+            options: .init(
+                temperature: temperature,
+                stop: stopWords,
+                num_predict: 300
+            ),
+            keep_alive: keepAlive.isEmpty ? nil : keepAlive,
+            format: format == .none ? nil : format.rawValue
+        )
+
+        var request = URLRequest(url: url)
+        request.httpMethod = "POST"
+        let encoder = JSONEncoder()
+        request.httpBody = try encoder.encode(requestBody)
+        request.setValue("application/json", forHTTPHeaderField: "Content-Type")
+        let (result, response) = try await URLSession.shared.bytes(for: request)
+
+        guard let response = response as? HTTPURLResponse else {
+            throw CancellationError()
+        }
+
+        guard response.statusCode == 200 else {
+            let text = try await result.lines.reduce(into: "") { partialResult, current in
+                partialResult += current
+            }
+            throw Error.otherError(text)
+        }
+
+        return ResponseStream(result: result)
+    }
+}
+
+// MARK: - Completion API
+
+extension OllamaService {
+    struct CompletionRequestBody: Codable {
+        var model: String
+        var prompt: String
+        var stream: Bool
+        var options: ChatCompletionRequestBody.Options
+        var keep_alive: String?
+        var format: String?
+    }
+
+    func createPrompt(from request: PromptStrategy) -> String {
+        let strategy = DefaultTruncateStrategy(maxTokenLimit: max(
+            maxToken / 3 * 2,
+            maxToken - 300 - 20
+        ))
+        let prompts = strategy.createTruncatedPrompt(promptStrategy: request)
+        return ([request.systemPrompt] + prompts.map(\.content)).joined(separator: "\n\n")
+    }
+
+    func sendPrompt(_ prompt: String) async throws -> ResponseStream<ChatCompletionResponseChunk> {
+        let requestBody = CompletionRequestBody(
+            model: modelName,
+            prompt: prompt,
+            stream: true,
+            options: .init(
+                temperature: temperature,
+                stop: stopWords,
+                num_predict: 300
+            ),
+            keep_alive: keepAlive.isEmpty ? nil : keepAlive,
+            format: format == .none ? nil : format.rawValue
+        )
+
+        var request = URLRequest(url: url)
+        request.httpMethod = "POST"
+        let encoder = JSONEncoder()
+        request.httpBody = try encoder.encode(requestBody)
+        request.setValue("application/json", forHTTPHeaderField: "Content-Type")
+        let (result, response) = try await URLSession.shared.bytes(for: request)
+
+        guard let response = response as? HTTPURLResponse else {
+            throw CancellationError()
+        }
+
+        guard response.statusCode == 200 else {
+            let text = try await result.lines.reduce(into: "") { partialResult, current in
+                partialResult += current
+            }
+            throw Error.otherError(text)
+        }
+
+        return ResponseStream(result: result)
+    }
+
+    func countToken(_ message: Message) -> Int {
+        message.content.count
+    }
+}
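
Taken together, a call site inside the module would look roughly like the sketch below. This is not code from the commit: the model name, stop words, and keep-alive value are placeholders, and prompt stands for an existing PromptStrategy value. Ollama's keep_alive field controls how long the model stays loaded after a request.

// Hypothetical usage, inside an async throwing context in the same module.
let service = OllamaService(
    endpoint: .chatCompletion,       // POST /api/chat on the default local URL
    modelName: "codellama:7b-code",  // placeholder: any locally pulled model
    stopWords: ["\n\n"],
    keepAlive: "5m"                  // keep the model in memory for 5 minutes
)

var completion = ""
for try await chunk in try await service.getCompletion(prompt) {
    completion += chunk              // chunks arrive as Ollama streams them
}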
ResponseStream.swift (new file)

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+import Foundation
+
+struct ResponseStream<Chunk: Decodable>: AsyncSequence {
+    func makeAsyncIterator() -> Stream.AsyncIterator {
+        stream.makeAsyncIterator()
+    }
+
+    typealias Stream = AsyncThrowingStream<Chunk, Error>
+    typealias AsyncIterator = Stream.AsyncIterator
+    typealias Element = Chunk
+
+    let stream: Stream
+
+    init(result: URLSession.AsyncBytes, lineExtractor: @escaping (String) -> String? = { $0 }) {
+        stream = AsyncThrowingStream<Chunk, Error> { continuation in
+            let task = Task {
+                do {
+                    for try await line in result.lines {
+                        if Task.isCancelled { break }
+                        guard let content = lineExtractor(line)?.data(using: .utf8)
+                        else { continue }
+                        let chunk = try JSONDecoder().decode(Chunk.self, from: content)
+                        continuation.yield(chunk)
+                    }
+                    continuation.finish()
+                } catch {
+                    continuation.finish(throwing: error)
+                    result.task.cancel()
+                }
+            }
+            continuation.onTermination = { _ in
+                task.cancel()
+                result.task.cancel()
+            }
+        }
+    }
+}
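
ResponseStream exists because Ollama streams newline-delimited JSON rather than server-sent events: each line of the HTTP body is one complete JSON object, which is why the default lineExtractor simply passes lines through to JSONDecoder. A sketch of how one such line decodes; the literal is illustrative, not captured output:

// One NDJSON line from the chat endpoint (illustrative), decoded the same
// way ResponseStream decodes each line internally.
let line = #"{"model":"llama2","message":{"role":"assistant","content":"Hel"},"done":false}"#
let chunk = try JSONDecoder().decode(
    OllamaService.ChatCompletionResponseChunk.self,
    from: Data(line.utf8)
)
print(chunk.message?.content ?? "") // prints "Hel"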

Core/Sources/Fundamental/Models/ChatModel.swift

Lines changed: 10 additions & 1 deletion
@@ -22,6 +22,7 @@ public struct ChatModel: Codable, Equatable, Identifiable {
         case azureOpenAI
         case openAICompatible
         case googleAI
+        case ollama
 
         case unknown
     }
@@ -45,6 +46,8 @@ public struct ChatModel: Codable, Equatable, Identifiable {
         get { modelName }
         set { modelName = newValue }
     }
+    @FallbackDecoding<EmptyString>
+    public var ollamaKeepAlive: String
 
     public init(
         apiKeyName: String = "",
@@ -53,7 +56,8 @@ public struct ChatModel: Codable, Equatable, Identifiable {
         maxTokens: Int = 4000,
         supportsFunctionCalling: Bool = true,
         supportsOpenAIAPI2023_11: Bool = false,
-        modelName: String = ""
+        modelName: String = "",
+        ollamaKeepAlive: String = ""
     ) {
         self.apiKeyName = apiKeyName
         self.baseURL = baseURL
@@ -62,6 +66,7 @@ public struct ChatModel: Codable, Equatable, Identifiable {
         self.supportsFunctionCalling = supportsFunctionCalling
         self.supportsOpenAIAPI2023_11 = supportsOpenAIAPI2023_11
         self.modelName = modelName
+        self.ollamaKeepAlive = ollamaKeepAlive
     }
 }
 
@@ -86,6 +91,10 @@ public struct ChatModel: Codable, Equatable, Identifiable {
             let baseURL = info.baseURL
             if baseURL.isEmpty { return "https://generativelanguage.googleapis.com/v1" }
             return "\(baseURL)/v1/chat/completions"
+        case .ollama:
+            let baseURL = info.baseURL
+            if baseURL.isEmpty { return "http://localhost:11434/api/chat" }
+            return "\(baseURL)/api/chat"
         case .unknown:
             return ""
         }
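
Two details in this file are easy to miss: @FallbackDecoding<EmptyString> lets previously saved models decode even though they predate the ollamaKeepAlive key, and the new endpoint case falls back to Ollama's default local port. The resolution logic reduces to the following sketch, a hypothetical free-function restatement of the computed property above:

func ollamaChatEndpoint(baseURL: String) -> String {
    if baseURL.isEmpty { return "http://localhost:11434/api/chat" }
    return "\(baseURL)/api/chat"
}

ollamaChatEndpoint(baseURL: "")                      // "http://localhost:11434/api/chat"
ollamaChatEndpoint(baseURL: "http://10.0.0.5:11434") // "http://10.0.0.5:11434/api/chat"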
