Skip to content

Commit

Permalink
bump jextract from version 21 to 22
Browse files Browse the repository at this point in the history
* Add MlirDialect and MlirTypeID to includeStructs.txt
* macOS support are available in jextract 22
  • Loading branch information
Emin017 committed Jan 14, 2025
1 parent ba4e01c commit ef56829
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 31 deletions.
2 changes: 2 additions & 0 deletions circtpanamabinding/includeStructs.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
MlirContext
MlirDialect
MlirDialectHandle
MlirStringRef
MlirType
MlirTypeID
MlirValue
MlirLocation
MlirAttribute
Expand Down
13 changes: 6 additions & 7 deletions panama.sc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ object utils extends Module {
// 21, 1-2, {linux-x64, macos-x64, windows-x64}
// 22, 1-2, {linux-x64, macos-aarch64, macos-x64, windows-x64}
def jextract(jdkVersion: Int, jextractVersion: String, os: String, platform: String) =
s"https://download.java.net/java/early_access/jextract/21/1/openjdk-${jdkVersion}-jextract+${jextractVersion}_${os}-${platform}_bin.tar.gz"
s"https://download.java.net/java/early_access/jextract/22/6/openjdk-${jdkVersion}-jextract+${jextractVersion}_${os}-${platform}_bin.tar.gz"

// use T.persistent to avoid download repeatedly
def circtInstallDir: T[os.Path] = T.persistent {
Expand Down Expand Up @@ -64,11 +64,11 @@ object utils extends Module {
val tarPath = T.dest / "jextract.tar.gz"
if (!os.exists(tarPath)) {
val url = jextract(
21,
"1-2",
22,
"6-47",
if (linux) "linux" else if (mac) "macos" else throw new Exception("unsupported os"),
// There is no macos-aarch64 for jextract 21, use x64 for now
if (amd64 || mac) "x64" else if (aarch64) "aarch64" else throw new Exception("unsupported arch")
if (amd64) "x64" else if (aarch64) "aarch64" else throw new Exception("unsupported arch")
)
T.ctx().log.info(s"Downloading jextract from ${url}")
mill.util.Util.download(url, os.rel / "jextract.tar.gz")
Expand Down Expand Up @@ -127,7 +127,6 @@ trait HasJextractGeneratedSources extends JavaModule {
++ Seq(
"-t", target(),
"--header-class-name", headerClassName(),
"--source",
"--output", T.dest.toString
) ++ includeFunctions().flatMap(f => Seq("--include-function", f)) ++
includeConstants().flatMap(f => Seq("--include-constant", f)) ++
Expand All @@ -142,7 +141,7 @@ trait HasJextractGeneratedSources extends JavaModule {
}
}

override def javacOptions = T(super.javacOptions() ++ Seq("--enable-preview", "--release", "21"))
override def javacOptions = T(super.javacOptions() ++ Seq("--release", "22"))
}

// Java Codegen for all declared functions.
Expand Down Expand Up @@ -173,7 +172,7 @@ trait HasCIRCTPanamaBindingModule extends JavaModule {

override def moduleDeps = super.moduleDeps ++ Some(circtPanamaBindingModule)
//
override def javacOptions = T(super.javacOptions() ++ Seq("--enable-preview", "--release", "21"))
override def javacOptions = T(super.javacOptions() ++ Seq("--release", "22"))

override def forkArgs: T[Seq[String]] = T(
super.forkArgs() ++ Seq("--enable-native-access=ALL-UNNAMED", "--enable-preview")
Expand Down
43 changes: 19 additions & 24 deletions panamalib/src/PanamaCIRCT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,15 @@ class PanamaCIRCT {
buffer.copyFrom(MemorySegment.ofArray(bytes))

val stringRef = circt.MlirStringRef.allocate(arena)
circt.MlirStringRef.data$set(stringRef, buffer)
circt.MlirStringRef.length$set(stringRef, bytes.length)
circt.MlirStringRef.data(stringRef, buffer)
circt.MlirStringRef.length(stringRef, bytes.length)

MlirStringRef(stringRef)
}

private def newStringCallback(callback: String => Unit): MlirStringCallback = {
val cb = new circt.MlirStringCallback {
def apply(message: MemorySegment, userData: MemorySegment) = {
callback(MlirStringRef(message).toString)
}
val cb: circt.MlirStringCallback.Function = (message, userData) => {
callback(MlirStringRef(message).toString)
}
MlirStringCallback(circt.MlirStringCallback.allocate(cb, arena))
}
Expand Down Expand Up @@ -262,10 +260,8 @@ class PanamaCIRCT {
CAPI.mlirOperationPrint(op.get, newStringCallback(callback).get, NULL)

def mlirOperationWriteBytecode(op: MlirOperation, callback: Array[Byte] => Unit) = {
val cb = new circt.MlirStringCallback {
def apply(message: MemorySegment, userData: MemorySegment) = {
callback(MlirStringRef(message).toBytes)
}
val cb: circt.MlirStringCallback.Function = (message, userData) => {
callback(MlirStringRef(message).toBytes)
}
val mlirCallback = MlirStringCallback(circt.MlirStringCallback.allocate(cb, arena))
CAPI.mlirOperationWriteBytecode(op.get, mlirCallback.get, NULL)
Expand Down Expand Up @@ -424,9 +420,9 @@ class PanamaCIRCT {
def circtFirtoolPopulateFinalizeIR(pm: MlirPassManager, options: CirctFirtoolFirtoolOptions) =
MlirLogicalResult(CAPI.circtFirtoolPopulateFinalizeIR(arena, pm.get, options.get))

def mlirLogicalResultIsSuccess(res: MlirLogicalResult): Boolean = circt.MlirLogicalResult.value$get(res.get) != 0
def mlirLogicalResultIsSuccess(res: MlirLogicalResult): Boolean = circt.MlirLogicalResult.value(res.get) != 0

def mlirLogicalResultIsFailure(res: MlirLogicalResult): Boolean = circt.MlirLogicalResult.value$get(res.get) == 0
def mlirLogicalResultIsFailure(res: MlirLogicalResult): Boolean = circt.MlirLogicalResult.value(res.get) == 0

def firrtlTypeGetUInt(width: Int) = MlirType(CAPI.firrtlTypeGetUInt(arena, mlirCtx, width))

Expand All @@ -449,9 +445,9 @@ class PanamaCIRCT {
fields.zipWithIndex.foreach {
case (field, i) =>
val fieldBuffer = buffer.asSlice(circt.FIRRTLBundleField.sizeof() * i, circt.FIRRTLBundleField.sizeof())
circt.FIRRTLBundleField.name$slice(fieldBuffer).copyFrom(mlirIdentifierGet(field.name).get)
circt.FIRRTLBundleField.isFlip$set(fieldBuffer, field.isFlip)
circt.FIRRTLBundleField.type$slice(fieldBuffer).copyFrom(field.tpe.get)
circt.FIRRTLBundleField.name(fieldBuffer).copyFrom(mlirIdentifierGet(field.name).get)
circt.FIRRTLBundleField.isFlip(fieldBuffer, field.isFlip)
circt.FIRRTLBundleField.`type`(fieldBuffer).copyFrom(field.tpe.get)
}
MlirType(CAPI.firrtlTypeGetBundle(arena, mlirCtx, fields.length, buffer))
}
Expand Down Expand Up @@ -484,9 +480,9 @@ class PanamaCIRCT {
elements.zipWithIndex.foreach {
case (element, i) =>
val elementBuffer = buffer.asSlice(circt.FIRRTLClassElement.sizeof() * i, circt.FIRRTLClassElement.sizeof())
circt.FIRRTLClassElement.name$slice(elementBuffer).copyFrom(mlirIdentifierGet(element.name).get)
circt.FIRRTLClassElement.type$slice(elementBuffer).copyFrom(element.tpe.get)
circt.FIRRTLClassElement.direction$set(elementBuffer, element.direction.get)
circt.FIRRTLClassElement.name(elementBuffer).copyFrom(mlirIdentifierGet(element.name).get)
circt.FIRRTLClassElement.`type`(elementBuffer).copyFrom(element.tpe.get)
circt.FIRRTLClassElement.direction(elementBuffer, element.direction.get)
}
MlirType(CAPI.firrtlTypeGetClass(arena, mlirCtx, name.get, elements.length, buffer))
}
Expand Down Expand Up @@ -551,13 +547,12 @@ class PanamaCIRCT {
)

def hwInstanceGraphForEachNode(instaceGraph: HWInstanceGraph, callback: HWInstanceGraphNode => Unit) = {
val nodeProcessorFn: circt.HWInstanceGraphNodeCallback.Function = (node, userData) => {
callback(HWInstanceGraphNode(node))
}
val cb = HWInstanceGraphNodeCallback(
circt.HWInstanceGraphNodeCallback.allocate(
new circt.HWInstanceGraphNodeCallback {
def apply(node: MemorySegment, userData: MemorySegment) = {
callback(HWInstanceGraphNode(node))
}
},
nodeProcessorFn,
arena
)
)
Expand Down Expand Up @@ -776,7 +771,7 @@ final case class MlirStringRef(ptr: MemorySegment) extends ForeignType[MemorySeg
private[panamalib] val sizeof = circt.MlirStringRef.sizeof().toInt

def toBytes: Array[Byte] = {
var slice = circt.MlirStringRef.data$get(ptr).asSlice(0, circt.MlirStringRef.length$get(ptr))
var slice = circt.MlirStringRef.data(ptr).asSlice(0, circt.MlirStringRef.length(ptr))
slice.toArray(JAVA_BYTE)
}

Expand Down

0 comments on commit ef56829

Please sign in to comment.