Warning, /network/kdeconnect-android/src/org/kde/kdeconnect/Backends/BluetoothBackend/ConnectionMultiplexer.kt is written in an unsupported language. File is not indexed.

0001 /*
0002  * SPDX-FileCopyrightText: 2019 Matthijs Tijink <matthijstijink@gmail.com>
0003  * SPDX-FileCopyrightText: 2024 Rob Emery <git@mintsoft.net>
0004  *
0005  * SPDX-License-Identifier: GPL-2.0-only OR GPL-3.0-only OR LicenseRef-KDE-Accepted-GPL
0006 */
0007 package org.kde.kdeconnect.Backends.BluetoothBackend
0008 
0009 import android.bluetooth.BluetoothSocket
0010 import android.util.Log
0011 import org.kde.kdeconnect.Helpers.ThreadHelper.execute
0012 import java.io.Closeable
0013 import java.io.IOException
0014 import java.io.InputStream
0015 import java.io.OutputStream
0016 import java.nio.ByteBuffer
0017 import java.nio.ByteOrder
0018 import java.util.UUID
0019 import java.util.concurrent.locks.Condition
0020 import java.util.concurrent.locks.ReentrantLock
0021 import kotlin.concurrent.withLock
0022 
/**
 * Multiplexes independent byte streams ("channels"), each identified by a [UUID], over one
 * [BluetoothSocket].
 *
 * Wire format: every frame starts with a [HEADER_SIZE]-byte big-endian header (a 1-byte message
 * type, a 2-byte unsigned payload length and the 16-byte channel UUID), followed by `length`
 * payload bytes. Flow control is credit based: a receiver grants credit with a MESSAGE_READ frame,
 * and the sender may only put that many bytes into MESSAGE_WRITE frames.
 *
 * The default channel ([defaultInputStream]/[defaultOutputStream]) exists from construction; more
 * channels are created with [newChannel] or by the peer. Incoming frames are parsed by a
 * [ListenRunnable] handed to `ThreadHelper.execute` in the constructor.
 *
 * Locking: [lock] guards [channels], [open] and frame writes to the socket; each [Channel] has its
 * own lock for its buffer and counters. Whenever both are held, [lock] is acquired first.
 */
class ConnectionMultiplexer(socket: BluetoothSocket) : Closeable {
    /** [InputStream] view of a [Channel]; closing the stream closes the channel. */
    private class ChannelInputStream(val channel: Channel) : InputStream(), Closeable {
        override fun available(): Int = channel.available()

        @Throws(IOException::class)
        override fun close() {
            channel.close()
        }

        /**
         * Reads one byte as an unsigned value in 0..255, or returns -1 once the channel is closed
         * and drained.
         */
        override fun read(): Int {
            val single = ByteArray(1)
            if (read(single, 0, 1) == -1) return -1
            // Mask off the sign extension: InputStream.read() must return 0..255, and a raw
            // Byte.toInt() of 0xFF would be -1, indistinguishable from end of stream.
            return single[0].toInt() and 0xFF
        }

        override fun read(b: ByteArray, off: Int, len: Int): Int = channel.read(b, off, len)

        override fun read(b: ByteArray): Int = read(b, 0, b.size)
    }

    /** [OutputStream] view of a [Channel]; closing the stream flushes and closes the channel. */
    private class ChannelOutputStream(val channel: Channel) : OutputStream(), Closeable {
        @Throws(IOException::class)
        override fun close() {
            channel.close()
        }

        @Throws(IOException::class)
        override fun flush() {
            channel.flush()
        }

        @Throws(IOException::class)
        override fun write(b: ByteArray, off: Int, len: Int) {
            channel.write(b, off, len)
        }

        @Throws(IOException::class)
        override fun write(b: Int) {
            write(byteArrayOf(b.toByte()), 0, 1)
        }

        @Throws(IOException::class)
        override fun write(b: ByteArray) {
            write(b, 0, b.size)
        }
    }

    /**
     * State of one logical stream.
     *
     * [readBuffer] is kept in write mode: its position is the number of received but not yet
     * consumed bytes. All fields are guarded by [lock]. Waiters on [lockCondition] are woken when
     * data or write credit arrives and when the channel is closed.
     */
    private class Channel(val multiplexer: ConnectionMultiplexer, val id: UUID) : Closeable {
        val readBuffer: ByteBuffer = ByteBuffer.allocate(BUFFER_SIZE)
        val lock = ReentrantLock()
        val lockCondition: Condition = lock.newCondition()

        var open = true
        var requestedReadAmount = 0 // bytes requested from the peer (MESSAGE_READ) but not yet received
        var freeWriteAmount = 0 // bytes the peer granted us that we have not sent yet

        fun available(): Int = lock.withLock { readBuffer.position() }

        /**
         * Copies up to [len] buffered bytes into [b] at [off]. Blocks until at least one byte is
         * available, asking the peer for more data when needed.
         *
         * @return the number of bytes copied, 0 if [len] is 0, or -1 if the channel is closed and
         * no buffered bytes remain.
         */
        fun read(b: ByteArray, off: Int, len: Int): Int {
            if (len == 0) return 0
            while (true) {
                val makeRequest = lock.withLock {
                    val count = minOf(len, readBuffer.position())
                    if (count > 0) {
                        readBuffer.flip()
                        readBuffer.get(b, off, count)
                        readBuffer.compact()
                        //TODO: non-blocking (opportunistic) read request
                        return count
                    }
                    if (!open) return -1
                    requestedReadAmount < BUFFER_SIZE
                }
                // The channel lock is released before calling readRequest, which acquires the
                // multiplexer lock first (see the class-level lock-ordering note).
                if (makeRequest) multiplexer.readRequest(id)
                lock.withLock {
                    if (!open) return -1
                    if (readBuffer.position() <= 0) {
                        try {
                            lockCondition.await()
                        } catch (ignored: InterruptedException) {
                            // Interrupted wake-up: the outer loop re-checks the channel state.
                        }
                    }
                }
            }
        }

        /**
         * Flushes pending output, then marks the channel closed and tells the peer. The channel is
         * closed even if the flush fails; the flush error is still propagated.
         */
        @Throws(IOException::class)
        override fun close() {
            try {
                flush()
            } finally {
                lock.withLock {
                    open = false
                    readBuffer.clear()
                    lockCondition.signalAll()
                }
                multiplexer.closeChannel(id)
            }
        }

        /** Marks the channel closed locally (no message is sent) and wakes any blocked reader/writer. */
        fun doClose() {
            lock.withLock {
                open = false
                lockCondition.signalAll()
            }
        }

        /**
         * Sends [len] bytes of [data] starting at [off]. Blocks while the peer has granted no write
         * credit.
         *
         * @throws IOException if the channel is or becomes closed.
         */
        @Throws(IOException::class)
        fun write(data: ByteArray, off: Int, len: Int) {
            var offset = off
            var remaining = len
            while (remaining > 0) {
                lock.withLock {
                    while (true) {
                        if (!open) throw IOException("Connection closed!")
                        if (freeWriteAmount > 0) break
                        try {
                            lockCondition.await()
                        } catch (ignored: InterruptedException) {
                            // Interrupted wake-up: re-check open/credit.
                        }
                    }
                }
                val written = multiplexer.writeRequest(id, data, offset, remaining)
                offset += written
                remaining -= written
            }
        }

        @Throws(IOException::class)
        fun flush() {
            multiplexer.flush()
        }
    }

    private var socket: BluetoothSocket? = socket
    private val channels: MutableMap<UUID, Channel> = HashMap()
    private val lock = ReentrantLock()
    private var open = true
    private var receivedProtocolVersion = false

    init {
        channels[DEFAULT_CHANNEL] = Channel(this, DEFAULT_CHANNEL)
        sendProtocolVersion()
        execute(ListenRunnable(socket))
    }

    /**
     * Allocates a frame for a [payloadLength]-byte payload and writes its header. The returned
     * buffer is positioned just after the header, ready for the payload.
     * [payloadLength] must fit in an unsigned 16-bit value.
     */
    private fun newFrame(type: Byte, payloadLength: Int, id: UUID): ByteBuffer =
        ByteBuffer.allocate(HEADER_SIZE + payloadLength).order(ByteOrder.BIG_ENDIAN).apply {
            put(type)
            // Unsigned 16-bit on the wire; toShort() keeps exactly the low 16 bits.
            putShort(payloadLength.toShort())
            putLong(id.mostSignificantBits)
            putLong(id.leastSignificantBits)
        }

    /** Writes the filled part of [frame] to the socket. Callers hold [lock]. */
    @Throws(IOException::class)
    private fun writeFrame(frame: ByteBuffer) {
        val out = socket?.outputStream ?: throw IOException("Connection closed!")
        out.write(frame.array(), 0, frame.position())
    }

    /** Announces the supported protocol version range (1..1) on the all-zero channel id. */
    @Throws(IOException::class)
    private fun sendProtocolVersion() {
        val frame = newFrame(MESSAGE_PROTOCOL_VERSION, 4, UUID(0L, 0L))
        frame.putShort(1.toShort()) // minimum supported version
        frame.putShort(1.toShort()) // maximum supported version
        lock.withLock { writeFrame(frame) }
    }

    /** Tears down the whole connection after an I/O failure: closes every channel and the socket. */
    private fun handleException(@Suppress("UNUSED_PARAMETER") ignored: IOException) {
        lock.withLock {
            open = false
            channels.values.forEach { it.doClose() }
            channels.clear()
            socket?.takeIf { it.isConnected }?.let { connected ->
                try {
                    connected.close()
                } catch (closeError: IOException) {
                    // Already failing; nothing more to do.
                }
            }
        }
    }

    /** Unregisters channel [id] and notifies the peer; no-op if the channel is unknown. */
    private fun closeChannel(id: UUID) {
        lock.withLock {
            if (channels.remove(id) == null) return
            try {
                writeFrame(newFrame(MESSAGE_CLOSE_CHANNEL, 0, id))
            } catch (e: IOException) {
                handleException(e)
            }
        }
    }

    /** Grants the peer credit to send as many bytes as channel [id]'s read buffer can still take. */
    private fun readRequest(id: UUID) {
        lock.withLock {
            val channel = channels[id] ?: return
            channel.lock.withLock {
                if (!channel.open) return
                val amount = BUFFER_SIZE - channel.readBuffer.position() - channel.requestedReadAmount
                if (amount <= 0) return
                val frame = newFrame(MESSAGE_READ, 2, id).putShort(amount.toShort())
                channel.requestedReadAmount += amount
                try {
                    writeFrame(frame)
                } catch (e: IOException) {
                    handleException(e)
                }
                channel.lockCondition.signalAll()
            }
        }
    }

    /**
     * Sends as much of `writeData[off, off + writeLen)` on channel [id] as the current credit
     * allows, in a single MESSAGE_WRITE frame.
     *
     * @return the number of bytes consumed from [writeData] (0 if the channel is gone, closed or
     * has no credit).
     */
    @Throws(IOException::class)
    private fun writeRequest(id: UUID, writeData: ByteArray, off: Int, writeLen: Int): Int {
        lock.withLock {
            val channel = channels[id] ?: return 0
            val frame = channel.lock.withLock {
                if (!channel.open || channel.freeWriteAmount == 0) return 0
                // Cap at BUFFER_SIZE: credit from several MESSAGE_READ frames can exceed it, and our
                // own reader rejects MESSAGE_WRITE payloads larger than BUFFER_SIZE (the peer
                // presumably applies a similar limit).
                val length = minOf(writeLen, channel.freeWriteAmount, BUFFER_SIZE)
                channel.freeWriteAmount -= length
                channel.lockCondition.signalAll()
                newFrame(MESSAGE_WRITE, length, id).put(writeData, off, length)
            }
            try {
                writeFrame(frame)
            } catch (e: IOException) {
                handleException(e)
            }
            return frame.position() - HEADER_SIZE
        }
    }

    @Throws(IOException::class)
    private fun flush() {
        lock.withLock {
            if (!open) return
            socket?.outputStream?.flush()
        }
    }

    /**
     * Closes the socket and every channel. Calling it again after the first call is a no-op.
     */
    @Throws(IOException::class)
    override fun close() {
        val current = socket ?: return
        try {
            // Close the socket before taking the lock: a writer blocked in socket I/O holds the
            // lock, and closing the socket makes that blocked call fail with an IOException.
            current.close()
        } finally {
            lock.withLock {
                open = false
                socket = null
                channels.values.forEach { it.doClose() }
                channels.clear()
            }
        }
    }

    /**
     * Opens a new channel and announces it to the peer.
     *
     * @return the new channel's id.
     * @throws IOException if the announcement cannot be sent; the connection is torn down.
     */
    @Throws(IOException::class)
    fun newChannel(): UUID {
        val id = UUID.randomUUID()
        lock.withLock {
            try {
                writeFrame(newFrame(MESSAGE_OPEN_CHANNEL, 0, id))
            } catch (e: IOException) {
                handleException(e)
                throw e
            }
            channels[id] = Channel(this, id)
        }
        return id
    }

    @get:Throws(IOException::class)
    val defaultInputStream: InputStream
        get() = getChannelInputStream(DEFAULT_CHANNEL)

    @get:Throws(IOException::class)
    val defaultOutputStream: OutputStream
        get() = getChannelOutputStream(DEFAULT_CHANNEL)

    /** @throws IOException if no channel with [id] is registered. */
    @Throws(IOException::class)
    fun getChannelInputStream(id: UUID): InputStream = lock.withLock {
        ChannelInputStream(channels[id] ?: throw IOException("Invalid channel!"))
    }

    /** @throws IOException if no channel with [id] is registered. */
    @Throws(IOException::class)
    fun getChannelOutputStream(id: UUID): OutputStream = lock.withLock {
        ChannelOutputStream(channels[id] ?: throw IOException("Invalid channel!"))
    }

    /** Reads and dispatches incoming frames until the connection closes or a protocol error occurs. */
    private inner class ListenRunnable(socket: BluetoothSocket) : Runnable {
        private val input: InputStream = socket.inputStream

        // Reused receive buffer for headers and payloads (only the listener thread touches it).
        private val messageBuffer = ByteArray(BUFFER_SIZE)

        /** Reads exactly [len] bytes into the start of [buffer], or throws at end of stream. */
        @Throws(IOException::class)
        private fun readBuffer(buffer: ByteArray, len: Int) {
            var numRead = 0
            while (numRead < len) {
                val count = input.read(buffer, numRead, len - numRead)
                if (count == -1) {
                    throw IOException("Couldn't read enough bytes!")
                }
                numRead += count
            }
        }

        /** Formats the first [length] bytes of [bytes] as `0xNN ` groups. */
        fun byteArrayToHexString(bytes: ByteArray, length: Int = bytes.size): String =
            bytes.take(length).joinToString(separator = "") { String.format("0x%02x ", it.toInt() and 0xff) }

        @Throws(IOException::class)
        private fun readMessage() {
            var data = messageBuffer
            readBuffer(data, HEADER_SIZE)
            val header = ByteBuffer.wrap(data, 0, HEADER_SIZE).order(ByteOrder.BIG_ENDIAN)
            val type = header.get()
            // signed short -> unsigned short (as int) conversion
            val length = header.short.toInt() and 0xFFFF
            val channelIdMostSigBits = header.long
            val channelIdLeastSigBits = header.long
            val channelId = UUID(channelIdMostSigBits, channelIdLeastSigBits)
            if (!receivedProtocolVersion && type != MESSAGE_PROTOCOL_VERSION) {
                Log.w(TAG, "Received invalid message of type $type (length $length) before the protocol version")
                Log.w(TAG, "'data_buffer:(" + byteArrayToHexString(data, HEADER_SIZE) + ") ")
                Log.w(TAG, "as string: '${String(data, 0, HEADER_SIZE, Charsets.ISO_8859_1)}' ")

                throw IOException("Did not receive protocol version message!")
            }
            when (type) {
                MESSAGE_OPEN_CHANNEL -> {
                    lock.withLock {
                        channels[channelId] = Channel(this@ConnectionMultiplexer, channelId)
                    }
                }
                MESSAGE_CLOSE_CHANNEL -> {
                    lock.withLock {
                        channels.remove(channelId)?.doClose()
                    }
                }
                MESSAGE_READ -> {
                    if (length != 2) {
                        throw IOException("Message length is invalid for 'MESSAGE_READ'!")
                    }
                    readBuffer(data, 2)
                    val amount = ByteBuffer.wrap(data, 0, 2).order(ByteOrder.BIG_ENDIAN).short.toInt() and 0xFFFF
                    lock.withLock {
                        val channel = channels[channelId] ?: return
                        channel.lock.withLock {
                            channel.freeWriteAmount += amount
                            channel.lockCondition.signalAll()
                        }
                    }
                }
                MESSAGE_WRITE -> {
                    if (length > BUFFER_SIZE) {
                        throw IOException("Message length is bigger than read size!")
                    }
                    readBuffer(data, length)
                    lock.withLock {
                        val channel = channels[channelId] ?: return
                        channel.lock.withLock {
                            if (channel.requestedReadAmount < length) {
                                throw IOException("No outstanding read requests of this length!")
                            }
                            channel.requestedReadAmount -= length
                            if (channel.readBuffer.position() + length > BUFFER_SIZE) {
                                throw IOException("Shouldn't be getting more data when the buffer is too full!")
                            }
                            channel.readBuffer.put(data, 0, length)
                            channel.lockCondition.signalAll()
                        }
                    }
                }
                MESSAGE_PROTOCOL_VERSION -> {
                    // Allow more than 4 bytes data, for future extensibility
                    if (length < 4) {
                        throw IOException("Message length is invalid for 'MESSAGE_PROTOCOL_VERSION'!")
                    }
                    // We might need a larger buffer to read this
                    if (length > data.size) {
                        data = ByteArray(1 shl 16)
                    }
                    readBuffer(data, length)

                    // Check remote endpoint protocol version (two unsigned shorts: min, max)
                    val versions = ByteBuffer.wrap(data, 0, 4).order(ByteOrder.BIG_ENDIAN)
                    val minimumVersion = versions.short.toInt() and 0xFFFF
                    val maximumVersion = versions.short.toInt() and 0xFFFF
                    if (minimumVersion > 1 || maximumVersion < 1) {
                        throw IOException("Unsupported protocol version $minimumVersion - $maximumVersion!")
                    }
                    // We now support receiving other messages
                    receivedProtocolVersion = true
                }
                else -> {
                    throw IOException("Invalid message type " + type.toInt())
                }
            }
        }

        override fun run() {
            while (true) {
                lock.withLock {
                    if (!open) {
                        Log.w(TAG, "connection not open, returning")
                        return
                    }
                }
                try {
                    readMessage()
                } catch (e: IOException) {
                    Log.w(TAG, "run caught IOException", e)
                    handleException(e)
                    return
                }
            }
        }
    }

    companion object {
        private const val TAG = "ConnectionMultiplexer"
        private val DEFAULT_CHANNEL = UUID.fromString("a0d0aaf4-1072-4d81-aa35-902a954b1266")
        private const val BUFFER_SIZE = 4096
        private const val HEADER_SIZE = 19 // type (1) + payload length (2) + channel UUID (16)
        private const val MESSAGE_PROTOCOL_VERSION: Byte = 0 //Negotiate the protocol version
        private const val MESSAGE_OPEN_CHANNEL: Byte = 1 //Open a new channel
        private const val MESSAGE_CLOSE_CHANNEL: Byte = 2 //Close a channel
        private const val MESSAGE_READ: Byte = 3 //Request some bytes from a channel
        private const val MESSAGE_WRITE: Byte = 4 //Write some bytes to a channel
    }
}