Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 201 additions & 9 deletions repl/src/dotty/tools/repl/JLineTerminal.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package repl

import scala.language.unsafeNulls
import scala.io.AnsiColor

import java.io.{InputStream, InterruptedIOException}
import dotc.core.Contexts.*
import dotc.parsing.Scanners.Scanner
import dotc.parsing.Tokens.*
Expand All @@ -14,9 +16,17 @@ import org.jline.reader.Parser.ParseContext
import org.jline.reader.*
import org.jline.reader.impl.LineReaderImpl
import org.jline.reader.impl.history.DefaultHistory
import org.jline.terminal.Attributes
import org.jline.terminal.Attributes.ControlChar
import org.jline.terminal.TerminalBuilder
import org.jline.terminal.Terminal.Signal
import org.jline.utils.AttributedString
import org.jline.utils.NonBlockingReader

// `stdin` alternates between a background Ctrl-C monitor and the foreground
// wrapped `System.in` reader. These states track which side currently owns it.
private enum InputState:
case Monitoring, ForegroundRead, Closed

class JLineTerminal extends java.io.Closeable {
private val terminal =
Expand All @@ -30,7 +40,42 @@ class JLineTerminal extends java.io.Closeable {
builder.dumb(true)
builder.build()

private val originalAttributes = terminal.getAttributes
private val noIntrAttributes = new Attributes(originalAttributes)
noIntrAttributes.setControlChar(ControlChar.VINTR, 0)
terminal.setAttributes(noIntrAttributes)
terminal.enterRawMode()

private val history = new DefaultHistory
@volatile private var monitoringThread: Thread | Null = null

private val userLineReader =
LineReaderBuilder
.builder()
.terminal(terminal)
.parser(new reader.Parser {
private class ParsedLine(val inputLine: String, val inputCursor: Int) extends reader.ParsedLine {
def word(): String = inputLine
def wordCursor(): Int = inputCursor
def wordIndex(): Int = 0
def words(): java.util.List[String] = java.util.List.of(inputLine)
def line(): String = inputLine
def cursor(): Int = inputCursor
}

def parse(input: String, cursor: Int, context: ParseContext): reader.ParsedLine =
new ParsedLine(input, cursor)
})
.build()

bindCtrlCInterrupt(userLineReader)
private val userInput = new UserInputStream(userLineReader, terminal.encoding())

private def bindCtrlCInterrupt(lr: LineReader): Unit =
lr.getKeyMaps.get(LineReader.MAIN).bind(
new Widget { override def apply(): Boolean = throw new UserInterruptException("") },
"\u0003"
)

private def magenta(str: String)(using Context) =
// Deliberately do not use these properties on `Console` to avoid initializing it,
Expand Down Expand Up @@ -80,26 +125,60 @@ class JLineTerminal extends java.io.Closeable {
.option(DISABLE_EVENT_EXPANSION, true) // don't process escape sequences in input
.build()

lineReader.getKeyMaps.get(LineReader.MAIN).bind(
new Widget { override def apply(): Boolean = throw new UserInterruptException("") },
"\u0003"
)
bindCtrlCInterrupt(lineReader)
lineReader.readLine(prompt)
}

def close(): Unit =
terminal.close()
userInput.signalClosed()
// Defensive: normally withMonitoringCtrlC joins and nulls the thread,
// but if close() is called during an abnormal exit, clean up here.
monitoringThread match
case thread: Thread =>
Thread.interrupted() // clear interrupt flag in case user code interrupted this thread
thread.join()
case null =>
try terminal.setAttributes(originalAttributes)
finally terminal.close()

def userInputStream: InputStream =
userInput

/** Execute a block while monitoring for Ctrl-C keypresses.
* Calls the handler when Ctrl-C is detected during block execution.
*/
def withMonitoringCtrlC[T](handler: () => Unit)(block: => T): T = {
// If you change Ctrl+C handling in any way, such as by trying to read/peek from stdin for Ctrl+C,
// make sure you manually check that reading from, e.g., `Console.in` still works!
// Remember that the user can use stdin from code they enter into the REPL, we do not have exclusive access to it.
// If you change Ctrl+C handling in any way, make sure you manually check that
// reading from both `System.in` and `Console.in` still works in embedded hosts.
// In raw mode, SIGINT is not generated by the terminal (Ctrl-C is detected
// by reading byte 3 from the raw stream). This handler is a fallback for
// external signals, e.g. `kill -INT`.
val previousHandler = terminal.handle(Signal.INT, _ => handler())
val reader = terminal.reader()
userInput.startMonitoring()
val thread = new Thread(() =>
while userInput.waitUntilActive() == InputState.Monitoring do
val ch =
try reader.read(100L)
catch case _: Exception => -1

if ch == NonBlockingReader.READ_EXPIRED then ()
else if ch == NonBlockingReader.EOF then userInput.signalClosed()
else if ch == 3 then handler()
else userInput.enqueueChar(ch)
, "REPL-CtrlC-Monitor")
monitoringThread = thread
thread.setDaemon(true)
thread.start()

try block
finally terminal.handle(Signal.INT, previousHandler)
finally {
userInput.signalClosed()
Thread.interrupted() // clear interrupted flag so join below doesn't explode
thread.join()
monitoringThread = null
terminal.handle(Signal.INT, previousHandler)
}
}

/** Provide syntax highlighting */
Expand Down Expand Up @@ -204,3 +283,116 @@ class JLineTerminal extends java.io.Closeable {
}
}
}

/** A `System.in` wrapper that lets the REPL monitor raw terminal input for Ctrl-C
* without stealing bytes from user code reading from `System.in` / `Console.in`.
*
* The monitor thread peeks at terminal input while REPL code is running. Any
* non-Ctrl-C input it sees is buffered here so later `read()` calls from user
* code observe the same bytes instead of losing them to the monitor.
*/
private final class UserInputStream(
userLineReader: LineReader,
encoding: java.nio.charset.Charset
) extends InputStream {
private var bytes = new Array[Byte](16)
private var byteCount = 0
private var state = InputState.ForegroundRead

/** Blocks until the state is no longer ForegroundRead. Returns the active state. */
def waitUntilActive(): InputState = synchronized {
while state == InputState.ForegroundRead do wait()
state
}

def enqueueChar(ch: Int): Unit = synchronized {
val encoded = String.valueOf(ch.toChar).getBytes(encoding)
enqueueBytes(encoded)
}

def signalClosed(): Unit = synchronized {
state = InputState.Closed
notifyAll()
}

def startMonitoring(): Unit = synchronized {
byteCount = 0
state = InputState.Monitoring
notifyAll()
}

private def resumeMonitoring(): Unit = synchronized {
if state != InputState.Closed then
state = InputState.Monitoring
notifyAll()
}

private def enqueueBytes(data: Array[Byte]): Unit = synchronized {
ensureCapacity(byteCount + data.length)
Array.copy(data, 0, bytes, byteCount, data.length)
byteCount += data.length
}

private def pollByte(): Option[Int] = synchronized {
if byteCount > 0 then
val value = bytes(0) & 0xff
removePrefix(1)
Some(value)
else if state == InputState.Closed then Some(-1)
else
state = InputState.ForegroundRead
None
}

private def drainTo(buf: Array[Byte], offset: Int, maxLen: Int): Int = synchronized {
val n = math.min(maxLen, byteCount)
Array.copy(bytes, 0, buf, offset, n)
removePrefix(n)
n
}

private def ensureCapacity(required: Int): Unit =
if required > bytes.length then
var newSize = bytes.length
while newSize < required do newSize *= 2
val newBytes = new Array[Byte](newSize)
Array.copy(bytes, 0, newBytes, 0, byteCount)
bytes = newBytes

private def removePrefix(n: Int): Unit =
byteCount -= n
if byteCount > 0 then
Array.copy(bytes, n, bytes, 0, byteCount)

private def readUserInputByte(): Int = {
while true do
pollByte() match
case Some(value) => return value
case None =>
try
val line = userLineReader.readLine("")
val lineBytes = (line + System.lineSeparator()).getBytes(encoding)
enqueueBytes(lineBytes)
catch
case _: EndOfFileException =>
return -1
case _: UserInterruptException =>
throw new InterruptedIOException()
finally
resumeMonitoring()

-1
}

override def read(): Int =
readUserInputByte()

override def read(bytes: Array[Byte], offset: Int, length: Int): Int =
if length == 0 then 0
else
val first = read()
if first == -1 then -1
else
bytes(offset) = first.toByte
drainTo(bytes, offset + 1, length - 1) + 1
}
12 changes: 10 additions & 2 deletions repl/src/dotty/tools/repl/ReplDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,15 @@ class ReplDriver(settings: Array[String],
System.exit(130) // Standard exit code for SIGINT
}
) {
interpret(res)
val savedIn = System.in
val replIn = terminal.userInputStream
try
System.setIn(replIn)
scala.Console.withIn(replIn) {
interpret(res)
}
finally
System.setIn(savedIn)
}

loop(using newState)()
Expand Down Expand Up @@ -724,4 +732,4 @@ class ReplDriver(settings: Array[String],

end ReplDriver
object ReplDriver:
def pprintImport = "import pprint.pprintln\n"
def pprintImport = "import pprint.pprintln\n"
Loading