diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index a1ac6562d..d002036c0 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -81,7 +81,11 @@ for (i, token) in tokens.enumerated() { batch.token[i] = token batch.pos[i] = Int32(i) batch.n_seq_id[i] = 1 - batch.seq_id[i][0] = 0 + // batch.seq_id[i][0] = 0 + // TODO: is this the proper way to do this? + if let seq_id = batch.seq_id[i] { + seq_id[0] = 0 + } batch.logits[i] = 0 } @@ -171,7 +175,9 @@ while n_cur <= n_len { batch.token[Int(batch.n_tokens)] = new_token_id batch.pos[Int(batch.n_tokens)] = n_cur batch.n_seq_id[Int(batch.n_tokens)] = 1 - batch.seq_id[Int(batch.n_tokens)][0] = Int32(i) + if let seq_id = batch.seq_id[Int(batch.n_tokens)] { + seq_id[0] = Int32(i) + } batch.logits[Int(batch.n_tokens)] = 1 i_batch[i] = batch.n_tokens