diff --git a/prepare/src/main/guessNNprepare/Main.scala b/prepare/src/main/guessNNprepare/Main.scala index 95ac12c..01a94ce 100644 --- a/prepare/src/main/guessNNprepare/Main.scala +++ b/prepare/src/main/guessNNprepare/Main.scala @@ -1,6 +1,6 @@ package guessNNprepare -import guessNNprepare.mains.ExtractEntities +import guessNNprepare.mains.{ComputeConcepts, ExtractEntities} import net.sourceforge.argparse4j.ArgumentParsers import net.sourceforge.argparse4j.inf.{ArgumentParser, ArgumentParserException, Namespace} @@ -9,12 +9,14 @@ object Main { val parser: ArgumentParser = ArgumentParsers.newFor("GuessNN-prepare").singleMetavar(true).build() val subParsers = parser.addSubparsers().dest("command").title("Commands").metavar("COMMAND").help("Command to execute") ExtractEntities.addCliArgs(subParsers.addParser("entities").help(ExtractEntities.description)) + ComputeConcepts.addCliArgs(subParsers.addParser("concepts").help(ComputeConcepts.description)) try { val ARGS: Namespace = parser.parseArgs(args) val command = ARGS.getString("command") command match { case "entities" => ExtractEntities.execute(ARGS) + case "concepts" => ComputeConcepts.execute(ARGS) case _ => System.err.printf(s"Unknown command ${command}"); System.exit(1) diff --git a/prepare/src/main/guessNNprepare/Utils.scala b/prepare/src/main/guessNNprepare/Utils.scala new file mode 100644 index 0000000..e4a1371 --- /dev/null +++ b/prepare/src/main/guessNNprepare/Utils.scala @@ -0,0 +1,25 @@ +package guessNNprepare + +import org.json.JSONArray + +import java.util.stream.StreamSupport +import scala.jdk.StreamConverters.StreamHasToScala +import scala.util.{Failure, Success, Try} + +object Utils { + + def JSONArrayElementsAs[T](array: JSONArray): Seq[T] = { + StreamSupport.stream(array.spliterator(), false).toScala(LazyList) + .map(obj => obj.asInstanceOf[T]) + } + + /** + * Convert a list of Try to a Try of a list. The Try is a failure if any element in the list of a failure. + */ + def flattenTryList[T](l: List[Try[T]]): Try[List[T]] = l match { + case Failure(e) :: _ => Failure(e) + case Success(value) :: rest => flattenTryList(rest).map(r => value :: r) + case Nil => Success(Nil) + } + +} diff --git a/prepare/src/main/guessNNprepare/mains/ComputeConcepts.scala b/prepare/src/main/guessNNprepare/mains/ComputeConcepts.scala new file mode 100644 index 0000000..818ec22 --- /dev/null +++ b/prepare/src/main/guessNNprepare/mains/ComputeConcepts.scala @@ -0,0 +1,46 @@ +package guessNNprepare.mains + +import guessNNprepare.{NamedEntity, Utils} +import net.sourceforge.argparse4j.impl.Arguments +import net.sourceforge.argparse4j.inf.{ArgumentParser, Namespace} +import org.json.{JSONArray, JSONObject, JSONTokener} + +import java.nio.file.{Files, Path, Paths} +import scala.util.Try + +object ComputeConcepts extends MainCommand { + override def description: String = "Compute the CNN for a given goal entity" + + override def addCliArgs(parser: ArgumentParser): Unit = { + parser.description(description) + + parser.addArgument("guessable_entities").`type`(Arguments.fileType().verifyCanRead()).help("JSON file containing all guessable entities") + } + + override def execute(ARGS: Namespace): Unit = { + // CLI args + val entitiesFilePath = Paths.get(ARGS.getString("guessable_entities")) + + println(s"Loading entities from ${entitiesFilePath}") + val entities: List[NamedEntity] = loadEntitiesOrFail(entitiesFilePath) + println(s"Loaded ${entities.length} entities") + + + } + + private def loadEntitiesOrFail(jsonFilePath: Path): List[NamedEntity] = { + Try(new JSONArray(new JSONTokener(Files.newInputStream(jsonFilePath)))) + .map(arr => + Utils.JSONArrayElementsAs[JSONObject](arr) + .map(NamedEntity.apply) + .toList + ) + .flatMap(Utils.flattenTryList) + .recover(err => { + System.err.println(s"Error when parsing entities: ${err}") + System.exit(1) + Nil + }) + .get + } +}