This commit is contained in:
Will Charczuk 2017-04-17 13:56:19 -07:00
parent 724d6e3c2a
commit 8c4ccc3bb6
3 changed files with 291 additions and 17 deletions

View file

@ -14,16 +14,26 @@ const (
var (
// ErrDimensionMismatch is a typical error.
ErrDimensionMismatch = errors.New("matrix is not square, cannot invert")
ErrDimensionMismatch = errors.New("dimension mismatch")
)
// New returns a new matrix.
func New(rows, cols int) *Matrix {
func New(rows, cols int, values ...float64) *Matrix {
if len(values) == 0 {
return &Matrix{
rows: rows,
cols: cols,
epsilon: DefaultEpsilon,
elements: make([]float64, rows*cols),
}
}
elems := make([]float64, rows*cols)
copy(elems, values)
return &Matrix{
rows: rows,
cols: cols,
epsilon: DefaultEpsilon,
elements: make([]float64, rows*cols),
elements: elems,
}
}
@ -92,9 +102,11 @@ func (m *Matrix) WithEpsilon(epsilon float64) *Matrix {
// Arrays returns the matrix as a two dimensional jagged array.
func (m *Matrix) Arrays() [][]float64 {
a := make([][]float64, m.rows, m.cols)
a := make([][]float64, m.rows)
for row := 0; row < m.rows; row++ {
a[row] = make([]float64, m.cols)
for col := 0; col < m.cols; col++ {
a[row][col] = m.Get(row, col)
}
@ -145,7 +157,7 @@ func (m *Matrix) Set(row, col int, val float64) {
func (m *Matrix) Col(col int) Vector {
values := make([]float64, m.rows)
for row := 0; row < m.rows; row++ {
values[col] = m.Get(row, col)
values[row] = m.Get(row, col)
}
return Vector(values)
}
@ -181,6 +193,17 @@ func (m *Matrix) DiagonalVector() Vector {
return Vector(values)
}
// Diagonal returns a matrix from the diagonal of a matrix.
func (m *Matrix) Diagonal() *Matrix {
rank := minInt(m.rows, m.cols)
m2 := New(rank, rank)
for index := 0; index < rank; index++ {
m2.Set(index, index, m.Get(index, index))
}
return m2
}
// Equals returns if a matrix equals another matrix.
func (m *Matrix) Equals(other *Matrix) bool {
if other == nil && m != nil {
@ -226,17 +249,6 @@ func (m *Matrix) U() *Matrix {
return m2
}
// Diagonal returns a matrix from the diagonal of a matrix.
func (m *Matrix) Diagonal() *Matrix {
rank := minInt(m.rows, m.cols)
m2 := New(rank, rank)
for index := 0; index < rank; index++ {
m2.Set(index, index, m.Get(index, index))
}
return m2
}
// String returns a string representation of the matrix.
func (m *Matrix) String() string {
buffer := bytes.NewBuffer(nil)