Update DebertaV3.cs
Browse files- DebertaV3.cs +7 -1
DebertaV3.cs
CHANGED
|
@@ -34,6 +34,12 @@ public sealed class DebertaV3 : MonoBehaviour
|
|
| 34 |
Model loadedModel = ModelLoader.Load(model);
|
| 35 |
engine = WorkerFactory.CreateWorker(backend, loadedModel);
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
string[] hypotheses = classes.Select(x => hypothesisTemplate.Replace("{}", x)).ToArray();
|
| 38 |
Batch batch = GetTokenizedBatch(text, hypotheses);
|
| 39 |
float[] scores = GetBatchScores(batch);
|
|
@@ -106,7 +112,7 @@ public sealed class DebertaV3 : MonoBehaviour
|
|
| 106 |
// To obtain a single value (score) per example, a softmax function is applied
|
| 107 |
|
| 108 |
TensorFloat tensorScores;
|
| 109 |
-
if (multipleTrueClasses || logits.shape.
|
| 110 |
{
|
| 111 |
// Softmax over the entailment vs. contradiction dimension for each label independently
|
| 112 |
tensorScores = ops.Softmax(logits, -1);
|
|
|
|
| 34 |
Model loadedModel = ModelLoader.Load(model);
|
| 35 |
engine = WorkerFactory.CreateWorker(backend, loadedModel);
|
| 36 |
|
| 37 |
+
if (classes.Length == 0)
|
| 38 |
+
{
|
| 39 |
+
Debug.LogError("There need to be more than 0 classes");
|
| 40 |
+
return;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
string[] hypotheses = classes.Select(x => hypothesisTemplate.Replace("{}", x)).ToArray();
|
| 44 |
Batch batch = GetTokenizedBatch(text, hypotheses);
|
| 45 |
float[] scores = GetBatchScores(batch);
|
|
|
|
| 112 |
// To obtain a single value (score) per example, a softmax function is applied
|
| 113 |
|
| 114 |
TensorFloat tensorScores;
|
| 115 |
+
if (multipleTrueClasses || logits.shape.Length(0, 1) == 1)
|
| 116 |
{
|
| 117 |
// Softmax over the entailment vs. contradiction dimension for each label independently
|
| 118 |
tensorScores = ops.Softmax(logits, -1);
|