-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathImageSegmentationUtils.kt
More file actions
136 lines (115 loc) · 4.48 KB
/
ImageSegmentationUtils.kt
File metadata and controls
136 lines (115 loc) · 4.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package co.stonephone.stonecamera.utils
import android.content.Context
import android.graphics.Bitmap
import android.os.Build
import android.os.SystemClock
import android.util.Log
import androidx.annotation.RequiresApi
import org.tensorflow.lite.gpu.CompatibilityList
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.Rot90Op
import org.tensorflow.lite.task.core.BaseOptions
import org.tensorflow.lite.task.vision.segmenter.ImageSegmenter
import org.tensorflow.lite.task.vision.segmenter.OutputType
import org.tensorflow.lite.task.vision.segmenter.Segmentation
/**
 * Owns a TensorFlow Lite Task Library [ImageSegmenter] and runs it on bitmaps.
 *
 * The segmenter is created eagerly in `init` from the [MODEL_DEEPLABV3] asset and is
 * re-created on demand by [segment] after [clearImageSegmenter] has released it.
 * Results and failures are delivered through [imageSegmentationListener].
 *
 * @property numThreads number of threads passed to the TFLite base options.
 * @property currentDelegate one of [DELEGATE_CPU], [DELEGATE_GPU] or [DELEGATE_NNAPI];
 *   any other value leaves the base options untouched.
 * @property context context handed to the segmenter factory to load the model.
 * @property imageSegmentationListener optional receiver for results and errors.
 */
class ImageSegmentationUtils(
    var numThreads: Int = 2,
    var currentDelegate: Int = 0,
    val context: Context,
    val imageSegmentationListener: SegmentationListener?
) {
    private var imageSegmenter: ImageSegmenter? = null

    init {
        setupImageSegmenter()
    }

    /**
     * Releases the current segmenter, if any.
     *
     * The segmenter is closed before the reference is dropped so that its underlying
     * resources are released right away instead of lingering until garbage collection.
     * A later call to [segment] builds a fresh segmenter.
     */
    fun clearImageSegmenter() {
        imageSegmenter?.close()
        imageSegmenter = null
    }

    /**
     * Builds [imageSegmenter] from the current [numThreads] and [currentDelegate].
     *
     * On failure the segmenter stays `null`, the listener is notified and the cause is logged.
     */
    private fun setupImageSegmenter() {
        // General segmentation options, including the number of threads to use.
        val baseOptionsBuilder = BaseOptions.builder().setNumThreads(numThreads)

        // Select the requested hardware; the builder is left untouched for CPU.
        when (currentDelegate) {
            DELEGATE_CPU -> {
                // Default
            }
            DELEGATE_GPU -> {
                if (CompatibilityList().isDelegateSupportedOnThisDevice) {
                    baseOptionsBuilder.useGpu()
                } else {
                    // No GPU option is set in this case; only the listener is told.
                    imageSegmentationListener?.onError("GPU is not supported on this device")
                }
            }
            DELEGATE_NNAPI -> {
                baseOptionsBuilder.useNnapi()
            }
        }

        val optionsBuilder = ImageSegmenter.ImageSegmenterOptions.builder()
        optionsBuilder.setBaseOptions(baseOptionsBuilder.build())
        // CATEGORY_MASK is the output type requested here; the alternative,
        // CONFIDENCE_MAP, would instead give every pixel a confidence score
        // from 0.0f to 1.0f (a gray scale mapping of the image).
        optionsBuilder.setOutputType(OutputType.CATEGORY_MASK)

        try {
            imageSegmenter = ImageSegmenter.createFromFileAndOptions(
                context,
                MODEL_DEEPLABV3,
                optionsBuilder.build()
            )
        } catch (e: IllegalStateException) {
            reportSetupFailure(e)
        } catch (e: java.io.IOException) {
            // The factory is declared to throw IOException when the model cannot be
            // loaded; route it to the listener instead of crashing the caller.
            reportSetupFailure(e)
        }
    }

    /** Notifies the listener that initialization failed and logs the cause's message. */
    private fun reportSetupFailure(cause: Exception) {
        imageSegmentationListener?.onError(
            "Image segmentation failed to initialize. See error logs for details"
        )
        Log.e(TAG, "TFLite failed to load model with error: ${cause.message}")
    }

    /**
     * Segments [image] and forwards the outcome to the listener's `onResults`.
     *
     * If no segmenter exists yet it is (re)created first. When creation fails,
     * `onResults` is still invoked, with `null` results.
     *
     * @param image bitmap to segment.
     * @param imageRotation rotation of [image] in degrees; it is divided by 90 (integer
     *   division) and negated to obtain the argument of the rotation op.
     */
    @RequiresApi(Build.VERSION_CODES.Q)
    fun segment(image: Bitmap, imageRotation: Int) {
        if (imageSegmenter == null) {
            setupImageSegmenter()
        }

        // Inference time covers preprocessing plus the segmenter call.
        val startTime = SystemClock.uptimeMillis()

        // Preprocessor that rotates the image before segmentation.
        // See https://www.tensorflow.org/lite/inference_with_metadata/
        //            lite_support#imageprocessor_architecture
        val imageProcessor = ImageProcessor.Builder()
            .add(Rot90Op(-imageRotation / 90))
            .build()

        // Preprocess the image and convert it into a TensorImage for segmentation.
        val tensorImage = imageProcessor.process(TensorImage.fromBitmap(image))
        val segmentResult = imageSegmenter?.segment(tensorImage)
        val inferenceTime = SystemClock.uptimeMillis() - startTime

        // Height and width are those of the preprocessed (rotated) image.
        imageSegmentationListener?.onResults(
            segmentResult,
            inferenceTime,
            tensorImage.height,
            tensorImage.width
        )
    }

    /** Callbacks through which [ImageSegmentationUtils] reports its outcomes. */
    interface SegmentationListener {
        /** Called with a human-readable message when setup hits a problem. */
        fun onError(error: String)

        /**
         * Called after every [segment] invocation.
         *
         * @param results segmentation output, or `null` when no segmenter was available.
         * @param inferenceTime elapsed uptime in milliseconds for preprocessing and inference.
         * @param imageHeight height of the preprocessed image.
         * @param imageWidth width of the preprocessed image.
         */
        fun onResults(
            results: List<Segmentation>?,
            inferenceTime: Long,
            imageHeight: Int,
            imageWidth: Int
        )
    }

    companion object {
        const val DELEGATE_CPU = 0
        const val DELEGATE_GPU = 1
        const val DELEGATE_NNAPI = 2
        const val MODEL_DEEPLABV3 = "deeplabv3.tflite"
        private const val TAG = "Image Segmentation Helper"
    }
}