poly regression!
This commit is contained in:
parent
5150caf515
commit
8445577ef4
3 changed files with 198 additions and 0 deletions
131
matrix/matrix.go
131
matrix/matrix.go
|
|
@ -3,6 +3,7 @@ package matrix
|
|||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
)
|
||||
|
||||
|
|
@ -14,6 +15,9 @@ const (
|
|||
var (
|
||||
// ErrDimensionMismatch is a typical error.
|
||||
ErrDimensionMismatch = errors.New("dimension mismatch")
|
||||
|
||||
// ErrSingularValue is a typical error.
|
||||
ErrSingularValue = errors.New("singular value")
|
||||
)
|
||||
|
||||
// New returns a new matrix.
|
||||
|
|
@ -220,6 +224,64 @@ func (m *Matrix) Row(row int) Vector {
|
|||
return Vector(values)
|
||||
}
|
||||
|
||||
// SubMatrix returns a sub matrix from a given outer matrix.
|
||||
func (m *Matrix) SubMatrix(i, j, rows, cols int) *Matrix {
|
||||
return &Matrix{
|
||||
elements: m.elements[i*m.stride+j : i*m.stride+j+(rows-1)*m.stride+cols],
|
||||
stride: m.stride,
|
||||
epsilon: m.epsilon,
|
||||
}
|
||||
}
|
||||
|
||||
// ScaleRow applies a scale to an entire row.
|
||||
func (m *Matrix) ScaleRow(row int, scale float64) {
|
||||
startIndex := row * m.stride
|
||||
for i := startIndex; i < m.stride; i++ {
|
||||
m.elements[i] = m.elements[i] * scale
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Matrix) scaleAddRow(rd int, rs int, f float64) {
|
||||
indexd := rd * m.stride
|
||||
indexs := rs * m.stride
|
||||
for col := 0; col < m.stride; col++ {
|
||||
m.elements[indexd] += f * m.elements[indexs]
|
||||
indexd++
|
||||
indexs++
|
||||
}
|
||||
}
|
||||
|
||||
// SwapRows swaps a row in the matrix in place.
|
||||
func (m *Matrix) SwapRows(i, j int) {
|
||||
var vi, vj float64
|
||||
for col := 0; col < m.stride; col++ {
|
||||
vi = m.Get(i, col)
|
||||
vj = m.Get(j, col)
|
||||
m.Set(i, col, vj)
|
||||
m.Set(j, col, vi)
|
||||
}
|
||||
}
|
||||
|
||||
// Augment concatenates two matrices about the horizontal.
|
||||
func (m *Matrix) Augment(m2 *Matrix) (*Matrix, error) {
|
||||
mr, mc := m.Size()
|
||||
m2r, m2c := m2.Size()
|
||||
if mr != m2r {
|
||||
return nil, ErrDimensionMismatch
|
||||
}
|
||||
|
||||
m3 := Zero(mr, mc+m2c)
|
||||
for row := 0; row < mr; row++ {
|
||||
for col := 0; col < mc; col++ {
|
||||
m3.Set(row, col, m.Get(row, col))
|
||||
}
|
||||
for col := 0; col < m2c; col++ {
|
||||
m3.Set(row, mc+col, m2.Get(row, col))
|
||||
}
|
||||
}
|
||||
return m3, nil
|
||||
}
|
||||
|
||||
// Copy returns a duplicate of a given matrix.
|
||||
func (m *Matrix) Copy() *Matrix {
|
||||
m2 := &Matrix{stride: m.stride, epsilon: m.epsilon, elements: make([]float64, len(m.elements))}
|
||||
|
|
@ -354,6 +416,30 @@ func (m *Matrix) Pivotize() *Matrix {
|
|||
return p
|
||||
}
|
||||
|
||||
// Times returns the product of a matrix and another.
|
||||
func (m *Matrix) Times(m2 *Matrix) (*Matrix, error) {
|
||||
mr, mc := m.Size()
|
||||
m2r, m2c := m2.Size()
|
||||
|
||||
if mc != m2r {
|
||||
return nil, fmt.Errorf("cannot multiply (%dx%d) and (%dx%d)", mr, mc, m2r, m2c)
|
||||
//return nil, ErrDimensionMismatch
|
||||
}
|
||||
|
||||
c := Zero(mr, m2c)
|
||||
|
||||
for i := 0; i < mr; i++ {
|
||||
sums := c.elements[i*c.stride : (i+1)*c.stride]
|
||||
for k, a := range m.elements[i*m.stride : i*m.stride+m.stride] {
|
||||
for j, b := range m2.elements[k*m2.stride : k*m2.stride+m2.stride] {
|
||||
sums[j] += a * b
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Decompositions
|
||||
|
||||
// LU performs the LU decomposition.
|
||||
|
|
@ -459,3 +545,48 @@ func (m *Matrix) QR() (q, r *Matrix) {
|
|||
|
||||
return
|
||||
}
|
||||
|
||||
// Transpose flips a matrix about its diagonal, returning a new copy.
|
||||
func (m *Matrix) Transpose() *Matrix {
|
||||
rows, cols := m.Size()
|
||||
m2 := Zero(cols, rows)
|
||||
for i := 0; i < rows; i++ {
|
||||
for j := 0; j < cols; j++ {
|
||||
m2.Set(j, i, m.Get(i, j))
|
||||
}
|
||||
}
|
||||
return m2
|
||||
}
|
||||
|
||||
// Inverse returns a matrix such that M*I==1.
|
||||
func (m *Matrix) Inverse() (*Matrix, error) {
|
||||
if !m.IsSymmetric() {
|
||||
return nil, ErrDimensionMismatch
|
||||
}
|
||||
|
||||
rows, cols := m.Size()
|
||||
|
||||
aug, _ := m.Augment(Eye(rows))
|
||||
for i := 0; i < rows; i++ {
|
||||
j := i
|
||||
for k := i; k < rows; k++ {
|
||||
if math.Abs(aug.Get(k, i)) > math.Abs(aug.Get(j, i)) {
|
||||
j = k
|
||||
}
|
||||
}
|
||||
if j != i {
|
||||
aug.SwapRows(i, j)
|
||||
}
|
||||
if aug.Get(i, i) == 0 {
|
||||
return nil, ErrSingularValue
|
||||
}
|
||||
aug.ScaleRow(i, 1.0/aug.Get(i, i))
|
||||
for k := 0; k < rows; k++ {
|
||||
if k == i {
|
||||
continue
|
||||
}
|
||||
aug.scaleAddRow(k, i, -aug.Get(k, i))
|
||||
}
|
||||
}
|
||||
return aug.SubMatrix(0, cols, rows, cols), nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue