Skip to content

Commit

Permalink
fix externel data (#1215)
Browse files Browse the repository at this point in the history
* no module constrain for fusion eval

* fix external data larger than 2GB

* power of 2 to square

* change buffer size from int to  long
  • Loading branch information
xhuohai authored Jun 4, 2024
1 parent 2c81f40 commit 1c9f388
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 23 deletions.
1 change: 1 addition & 0 deletions src/Nncase.Compiler/Compiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.FoldTwoPads>();
p.Add<Passes.Rules.Neutral.SwapBinaryArgs>();
p.Add<Passes.Rules.Neutral.FoldDilatedConv2D>();
p.Add<Passes.Rules.Neutral.PowOf2ToSquare>();
});

passManager.AddWithName<EGraphRulesPass>("NeutralOptimizeTranspose").Configure(p =>
Expand Down
10 changes: 5 additions & 5 deletions src/Nncase.Core/TIR/Script.cs
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ public static Buffer CreateBuffer(TensorType tensorType, MemoryLocation location

var dimensions = tensorType.Shape.ToValueArray();
var strides = TensorUtilities.GetStrides(dimensions);
var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes;
var size = TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes;
var memspan = new MemSpan(size, location);
buffer = new Buffer(name, tensorType.DType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray());
return buffer;
Expand Down Expand Up @@ -268,7 +268,7 @@ public static Buffer AttachBuffer(Expr start, TensorType tensorType, MemoryLocat

var dimensions = tensorType.Shape.ToValueArray();
var strides = TensorUtilities.GetStrides(dimensions);
var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes;
var size = TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes;
var memspan = new MemSpan(start, size, location);
buffer = new Buffer(name, tensorType.DType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray());
return buffer;
Expand All @@ -286,7 +286,7 @@ public static Buffer AttachBuffer(TensorConst @const, out Buffer buffer, [Caller

var dimensions = @const.CheckedShape.ToValueArray();
var strides = TensorUtilities.GetStrides(dimensions);
var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * @const.CheckedDataType.SizeInBytes;
var size = TensorUtilities.GetProduct(dimensions.ToArray()) * @const.CheckedDataType.SizeInBytes;
var memspan = new MemSpan(IR.F.Buffer.DDrOf(@const), size, MemoryLocation.Rdata);
buffer = new Buffer(name, @const.CheckedDataType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray());
return buffer;
Expand All @@ -304,7 +304,7 @@ public static Buffer AttachBuffer(Buffer originBuffer, Expr offset, TensorType t

var dimensions = tensorType.Shape.ToValueArray();
var strides = TensorUtilities.GetStrides(dimensions);
var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes;
var size = TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes;
buffer = new Buffer(name, tensorType.DType, originBuffer.MemSpan.SubSpan(offset, size), dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray());
return buffer;
}
Expand All @@ -322,7 +322,7 @@ public static Buffer AttachBuffer(TensorType tensorType, MemoryLocation location
@var = new Var(TensorType.Pointer(tensorType.DType));
var dimensions = tensorType.Shape.ToValueArray();
var strides = TensorUtilities.GetStrides(dimensions);
var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes;
var size = TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes;
buffer = new Buffer(name, tensorType.DType, new MemSpan(@var, size, location), dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray());
return buffer;
}
Expand Down
4 changes: 2 additions & 2 deletions src/Nncase.Core/TensorUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -404,9 +404,9 @@ public static int[] ToInts(this ReadOnlySpan<long> longs)

public static int[] ToInts(this long[] longs) => ToInts((ReadOnlySpan<long>)longs);

public static int GetSize(Span<int> shapes, Span<int> strides, int elementSize)
public static long GetSize(Span<int> shapes, Span<int> strides, int elementSize)
{
int size = 0;
long size = 0;
for (int i = 0; i < shapes.Length; i++)
{
size += (shapes[i] - 1) * strides[i];
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Evaluator/EvaluateVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ protected override IValue VisitLeafCall(Call expr)
{
Op op => CompilerServices.EvaluateOp(op, _context, _evaluator_cache),
Function func => CompilerServices.Evaluate(func.Body, CreateFunctionEvaluateArguments(func.Parameters, expr.Arguments), _evaluator_cache),
Fusion { ModuleKind: "stackvm" } fusion => CompilerServices.Evaluate(fusion.Body, CreateFunctionEvaluateArguments(fusion.Parameters, expr.Arguments), _evaluator_cache),
Fusion fusion => CompilerServices.Evaluate(fusion.Body, CreateFunctionEvaluateArguments(fusion.Parameters, expr.Arguments), _evaluator_cache),
_ => throw new NotImplementedException(expr.Target.ToString()),
};
}
Expand Down
32 changes: 29 additions & 3 deletions src/Nncase.Importer/Onnx/DataGatter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,23 @@ private bool EmptyTensor(TensorProto tensor)
return tensor.Dims.Count == 1 && tensor.Dims[0] == 0;
}

private Tensor GetExternalTensor<T>(BinaryReader br, DataType dataType, long length, Shape shape)
where T : unmanaged, IEquatable<T>
{
var tensorArray = new T[length / dataType.SizeInBytes];
var totalRead = 0;
int chunk = 1024 * 1024 * 1024;
for (long l = length; l > 0; l -= chunk)
{
var tmpBuffer = br.ReadBytes((int)Math.Min(chunk, l));

Buffer.BlockCopy(tmpBuffer, 0, tensorArray, totalRead, tmpBuffer.Length);
totalRead += tmpBuffer.Length / dataType.SizeInBytes;
}

return Tensor.From(tensorArray, shape);
}

private Tensor GetTensor(TensorProto tensor)
{
var shape = GetShape(tensor).ToValueArray();
Expand Down Expand Up @@ -115,10 +132,19 @@ private Tensor GetTensor(TensorProto tensor)
var location = Path.Join(parent, externalData[0].Value);
var offset = externalDataCount > 1L ? long.Parse(externalData[1].Value) : 0;
using var br = new BinaryReader(new FileStream(location, FileMode.Open));
var length = externalDataCount > 1 ? int.Parse(externalData[2].Value) : (int)br.BaseStream.Length;
var length = externalDataCount > 1 ? long.Parse(externalData[2].Value) : br.BaseStream.Length;
br.BaseStream.Seek(offset, SeekOrigin.Begin);
var buffer = br.ReadBytes(length);
return Tensor.FromBytes(type, buffer, shape);

return type switch
{
var t when t == DataTypes.Float32 => GetExternalTensor<float>(br, type, length, shape),
var t when t == DataTypes.Float64 => GetExternalTensor<double>(br, type, length, shape),
var t when t == DataTypes.Int32 => GetExternalTensor<int>(br, type, length, shape),
var t when t == DataTypes.Int64 => GetExternalTensor<long>(br, type, length, shape),
var t when t == DataTypes.Int8 => GetExternalTensor<sbyte>(br, type, length, shape),
var t when t == DataTypes.UInt8 => GetExternalTensor<byte>(br, type, length, shape),
_ => throw new NotSupportedException($"Not supported onnx constant data type {type}"),
};
}

return dt switch
Expand Down
8 changes: 4 additions & 4 deletions src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@ namespace Nncase.Passes.BufferSchedule;

public sealed class Interval
{
public Interval(int start, int end)
public Interval(long start, long end)
{
Start = start;
Stop = end;
}

public int Start { get; set; }
public long Start { get; set; }

public int Stop { get; set; }
public long Stop { get; set; }

public int Size => Stop - Start;
public long Size => Stop - Start;

public override string ToString()
{
Expand Down
12 changes: 6 additions & 6 deletions src/Nncase.Passes/BufferSchedule/BufferScheduler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public virtual void ExternalConstrains(CpModel model, IReadOnlyDictionary<Expr,
if (expr is Call { Target: IR.Tensors.Concat } concatCall && concatCall.Arguments[0] is IR.Tuple tuple)
{
// the concat inputs must contiguous
int offset = 0;
long offset = 0;
for (int i = 0; i < tuple.Fields.Length; i++)
{
model.Add((boxs[concatCall].Y.StartExpr() + offset) == boxs[tuple.Fields[i]].Y.StartExpr());
Expand All @@ -36,7 +36,7 @@ public virtual void ExternalConstrains(CpModel model, IReadOnlyDictionary<Expr,

// the split outputs must contiguous
var users = splitCall.GetUsers();
int offset = 0;
long offset = 0;
foreach (var user in users.OrderBy(e => ((Call)e).Arguments[1].Evaluate().AsTensor().ToScalar<int>()))
{
model.Add((boxs[splitCall].Y.StartExpr() + offset) == boxs[user].Y.StartExpr());
Expand All @@ -56,7 +56,7 @@ public void Schedule(IReadOnlyDictionary<Expr, ScheduleBuffer> bufferMap)
var model = new CpModel();
var noOverlap = model.AddNoOverlap2D();
var boxs = new Dictionary<Expr, (IntervalVar X, IntervalVar Y)>(ReferenceEqualityComparer.Instance);
var timeMap = new Dictionary<int, List<Expr>>();
var timeMap = new Dictionary<long, List<Expr>>();
var yStarts = new List<IntVar>();
foreach (var (expr, item) in bufferMap)
{
Expand All @@ -74,7 +74,7 @@ public void Schedule(IReadOnlyDictionary<Expr, ScheduleBuffer> bufferMap)
yStarts.Add(memStartVar);
boxs.Add(expr, (xInterval, yInterval));

for (int time = item.TimeInterval.Start; time < item.TimeInterval.Stop; time++)
for (long time = item.TimeInterval.Start; time < item.TimeInterval.Stop; time++)
{
if (!timeMap.TryGetValue(time, out var timelist))
{
Expand All @@ -100,8 +100,8 @@ public void Schedule(IReadOnlyDictionary<Expr, ScheduleBuffer> bufferMap)

foreach (var (k, _) in bufferMap)
{
bufferMap[k].MemInterval.Start = checked((int)solver.Value(boxs[k].Y.StartExpr()));
bufferMap[k].MemInterval.Stop = checked((int)solver.Value(boxs[k].Y.EndExpr()));
bufferMap[k].MemInterval.Start = checked(solver.Value(boxs[k].Y.StartExpr()));
bufferMap[k].MemInterval.Stop = checked(solver.Value(boxs[k].Y.EndExpr()));
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ protected override Unit VisitLeafCall(Call expr)
return Unit.Default;
}

protected virtual int ComputeBufferSize(IRType type, out int[] shape, out int[] stride)
protected virtual long ComputeBufferSize(IRType type, out int[] shape, out int[] stride)
{
shape = Array.Empty<int>();
stride = Array.Empty<int>();
var size = 0;
long size = 0;
if (type is TensorType tensorType)
{
shape = tensorType.Shape.ToValueArray();
Expand Down
41 changes: 41 additions & 0 deletions src/Nncase.Passes/Rules/Neutral/PowOf2ToSquare.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Nncase.IR;
using Nncase.IR.Math;
using Nncase.IR.Tensors;
using Nncase.PatternMatch;
using static Nncase.IR.F.Math;
using static Nncase.IR.TypePatternUtility;
using static Nncase.PatternMatch.F.Math;
using static Nncase.PatternMatch.F.NN;
using static Nncase.PatternMatch.Utility;

namespace Nncase.Passes.Rules.Neutral;

[RuleGenerator]
public sealed partial class PowOf2ToSquare : RewriteRule<CallPattern>
{
/// <inheritdoc/>
public override CallPattern Pattern { get; } =
IsBinary(
"pow",
"call",
p => p.BinaryOp is BinaryOp.Pow,
IsWildcard("input"),
IsTensorConst("power"));

private Expr? GetReplace(Expr input, TensorConst power)
{
if (power.Value.ToArray<float>().All(x => x == 2))
{
return Unary(UnaryOp.Square, input);
}

return null;
}
}

0 comments on commit 1c9f388

Please sign in to comment.