diff --git a/repl/src/dotty/tools/repl/JLineTerminal.scala b/repl/src/dotty/tools/repl/JLineTerminal.scala index 64605847767f..c1ce2e06ff9e 100644 --- a/repl/src/dotty/tools/repl/JLineTerminal.scala +++ b/repl/src/dotty/tools/repl/JLineTerminal.scala @@ -3,6 +3,8 @@ package repl import scala.language.unsafeNulls import scala.io.AnsiColor + +import java.io.{InputStream, InterruptedIOException} import dotc.core.Contexts.* import dotc.parsing.Scanners.Scanner import dotc.parsing.Tokens.* @@ -14,9 +16,17 @@ import org.jline.reader.Parser.ParseContext import org.jline.reader.* import org.jline.reader.impl.LineReaderImpl import org.jline.reader.impl.history.DefaultHistory +import org.jline.terminal.Attributes +import org.jline.terminal.Attributes.ControlChar import org.jline.terminal.TerminalBuilder import org.jline.terminal.Terminal.Signal import org.jline.utils.AttributedString +import org.jline.utils.NonBlockingReader + +// `stdin` alternates between a background Ctrl-C monitor and the foreground +// wrapped `System.in` reader. These states track which side currently owns it. +private enum InputState: + case Monitoring, ForegroundRead, Closed class JLineTerminal extends java.io.Closeable { private val terminal = @@ -30,7 +40,42 @@ class JLineTerminal extends java.io.Closeable { builder.dumb(true) builder.build() + private val originalAttributes = terminal.getAttributes + private val noIntrAttributes = new Attributes(originalAttributes) + noIntrAttributes.setControlChar(ControlChar.VINTR, 0) + terminal.setAttributes(noIntrAttributes) + terminal.enterRawMode() + private val history = new DefaultHistory + @volatile private var monitoringThread: Thread | Null = null + + private val userLineReader = + LineReaderBuilder + .builder() + .terminal(terminal) + .parser(new reader.Parser { + private class ParsedLine(val inputLine: String, val inputCursor: Int) extends reader.ParsedLine { + def word(): String = inputLine + def wordCursor(): Int = inputCursor + def wordIndex(): Int = 0 + def words(): java.util.List[String] = java.util.List.of(inputLine) + def line(): String = inputLine + def cursor(): Int = inputCursor + } + + def parse(input: String, cursor: Int, context: ParseContext): reader.ParsedLine = + new ParsedLine(input, cursor) + }) + .build() + + bindCtrlCInterrupt(userLineReader) + private val userInput = new UserInputStream(userLineReader, terminal.encoding()) + + private def bindCtrlCInterrupt(lr: LineReader): Unit = + lr.getKeyMaps.get(LineReader.MAIN).bind( + new Widget { override def apply(): Boolean = throw new UserInterruptException("") }, + "\u0003" + ) private def magenta(str: String)(using Context) = // Deliberately do not use these properties on `Console` to avoid initializing it, @@ -80,26 +125,60 @@ class JLineTerminal extends java.io.Closeable { .option(DISABLE_EVENT_EXPANSION, true) // don't process escape sequences in input .build() - lineReader.getKeyMaps.get(LineReader.MAIN).bind( - new Widget { override def apply(): Boolean = throw new UserInterruptException("") }, - "\u0003" - ) + bindCtrlCInterrupt(lineReader) lineReader.readLine(prompt) } def close(): Unit = - terminal.close() + userInput.signalClosed() + // Defensive: normally withMonitoringCtrlC joins and nulls the thread, + // but if close() is called during an abnormal exit, clean up here. + monitoringThread match + case thread: Thread => + Thread.interrupted() // clear interrupt flag in case user code interrupted this thread + thread.join() + case null => + try terminal.setAttributes(originalAttributes) + finally terminal.close() + + def userInputStream: InputStream = + userInput /** Execute a block while monitoring for Ctrl-C keypresses. * Calls the handler when Ctrl-C is detected during block execution. */ def withMonitoringCtrlC[T](handler: () => Unit)(block: => T): T = { - // If you change Ctrl+C handling in any way, such as by trying to read/peek from stdin for Ctrl+C, - // make sure you manually check that reading from, e.g., `Console.in` still works! - // Remember that the user can use stdin from code they enter into the REPL, we do not have exclusive access to it. + // If you change Ctrl+C handling in any way, make sure you manually check that + // reading from both `System.in` and `Console.in` still works in embedded hosts. + // In raw mode, SIGINT is not generated by the terminal (Ctrl-C is detected + // by reading byte 3 from the raw stream). This handler is a fallback for + // external signals, e.g. `kill -INT`. val previousHandler = terminal.handle(Signal.INT, _ => handler()) + val reader = terminal.reader() + userInput.startMonitoring() + val thread = new Thread(() => + while userInput.waitUntilActive() == InputState.Monitoring do + val ch = + try reader.read(100L) + catch case _: Exception => -1 + + if ch == NonBlockingReader.READ_EXPIRED then () + else if ch == NonBlockingReader.EOF then userInput.signalClosed() + else if ch == 3 then handler() + else userInput.enqueueChar(ch) + , "REPL-CtrlC-Monitor") + monitoringThread = thread + thread.setDaemon(true) + thread.start() + try block - finally terminal.handle(Signal.INT, previousHandler) + finally { + userInput.signalClosed() + Thread.interrupted() // clear interrupted flag so join below doesn't explode + thread.join() + monitoringThread = null + terminal.handle(Signal.INT, previousHandler) + } } /** Provide syntax highlighting */ @@ -204,3 +283,116 @@ class JLineTerminal extends java.io.Closeable { } } } + +/** A `System.in` wrapper that lets the REPL monitor raw terminal input for Ctrl-C + * without stealing bytes from user code reading from `System.in` / `Console.in`. + * + * The monitor thread peeks at terminal input while REPL code is running. Any + * non-Ctrl-C input it sees is buffered here so later `read()` calls from user + * code observe the same bytes instead of losing them to the monitor. + */ +private final class UserInputStream( + userLineReader: LineReader, + encoding: java.nio.charset.Charset +) extends InputStream { + private var bytes = new Array[Byte](16) + private var byteCount = 0 + private var state = InputState.ForegroundRead + + /** Blocks until the state is no longer ForegroundRead. Returns the active state. */ + def waitUntilActive(): InputState = synchronized { + while state == InputState.ForegroundRead do wait() + state + } + + def enqueueChar(ch: Int): Unit = synchronized { + val encoded = String.valueOf(ch.toChar).getBytes(encoding) + enqueueBytes(encoded) + } + + def signalClosed(): Unit = synchronized { + state = InputState.Closed + notifyAll() + } + + def startMonitoring(): Unit = synchronized { + byteCount = 0 + state = InputState.Monitoring + notifyAll() + } + + private def resumeMonitoring(): Unit = synchronized { + if state != InputState.Closed then + state = InputState.Monitoring + notifyAll() + } + + private def enqueueBytes(data: Array[Byte]): Unit = synchronized { + ensureCapacity(byteCount + data.length) + Array.copy(data, 0, bytes, byteCount, data.length) + byteCount += data.length + } + + private def pollByte(): Option[Int] = synchronized { + if byteCount > 0 then + val value = bytes(0) & 0xff + removePrefix(1) + Some(value) + else if state == InputState.Closed then Some(-1) + else + state = InputState.ForegroundRead + None + } + + private def drainTo(buf: Array[Byte], offset: Int, maxLen: Int): Int = synchronized { + val n = math.min(maxLen, byteCount) + Array.copy(bytes, 0, buf, offset, n) + removePrefix(n) + n + } + + private def ensureCapacity(required: Int): Unit = + if required > bytes.length then + var newSize = bytes.length + while newSize < required do newSize *= 2 + val newBytes = new Array[Byte](newSize) + Array.copy(bytes, 0, newBytes, 0, byteCount) + bytes = newBytes + + private def removePrefix(n: Int): Unit = + byteCount -= n + if byteCount > 0 then + Array.copy(bytes, n, bytes, 0, byteCount) + + private def readUserInputByte(): Int = { + while true do + pollByte() match + case Some(value) => return value + case None => + try + val line = userLineReader.readLine("") + val lineBytes = (line + System.lineSeparator()).getBytes(encoding) + enqueueBytes(lineBytes) + catch + case _: EndOfFileException => + return -1 + case _: UserInterruptException => + throw new InterruptedIOException() + finally + resumeMonitoring() + + -1 + } + + override def read(): Int = + readUserInputByte() + + override def read(bytes: Array[Byte], offset: Int, length: Int): Int = + if length == 0 then 0 + else + val first = read() + if first == -1 then -1 + else + bytes(offset) = first.toByte + drainTo(bytes, offset + 1, length - 1) + 1 +} diff --git a/repl/src/dotty/tools/repl/ReplDriver.scala b/repl/src/dotty/tools/repl/ReplDriver.scala index 4f905cd31aa9..7fd7aca454ac 100644 --- a/repl/src/dotty/tools/repl/ReplDriver.scala +++ b/repl/src/dotty/tools/repl/ReplDriver.scala @@ -249,7 +249,15 @@ class ReplDriver(settings: Array[String], System.exit(130) // Standard exit code for SIGINT } ) { - interpret(res) + val savedIn = System.in + val replIn = terminal.userInputStream + try + System.setIn(replIn) + scala.Console.withIn(replIn) { + interpret(res) + } + finally + System.setIn(savedIn) } loop(using newState)() @@ -724,4 +732,4 @@ class ReplDriver(settings: Array[String], end ReplDriver object ReplDriver: - def pprintImport = "import pprint.pprintln\n" \ No newline at end of file + def pprintImport = "import pprint.pprintln\n"