MoreThanText proxy server.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

598 lines
14 KiB

// morethantext/proxy/sessionmanager_test.go
package proxy
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"testing"
"time"
)
// Fixture types shared by the session manager tests.
type (
	// testServer pairs a session manager with a channel on which the
	// messages it emits are recorded for later inspection.
	testServer struct {
		sm       *sessionManager
		msgStore chan sessionMessage
	}

	// webRig bundles a session manager with the HTTP test server that
	// serves it and a client used to issue requests against that server.
	webRig struct {
		sm *sessionManager
		ws *httptest.Server
		wc *http.Client
	}
)
const (
	// smTestLoopCount is the iteration count used by the looping tests.
	smTestLoopCount = 1000
	// bodyString is the fixed response body written by smTestHandler.
	bodyString = "Testing!"
)

// smTestHandler is a stateless http.Handler that answers every request
// with bodyString.
type smTestHandler struct{}

// ServeHTTP writes bodyString to the response and ignores the request.
func (smTestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	fmt.Fprintf(w, "%s", bodyString)
}
// newTestServer builds a testServer around a session manager created with
// smTestHandler and starts a goroutine that loops every outgoing message
// back into the manager's input while also recording it in msgStore.
// msgStore is closed once the manager's msgOut channel has been closed and
// drained, so tests can range over it after closing msgOut.
func newTestServer() *testServer {
	server := new(testServer)
	server.sm = newSessionManager(smTestHandler{})
	server.msgStore = make(chan sessionMessage, 10)

	relay := func() {
		for outgoing := range server.sm.msgOut {
			server.sm.msgIn <- outgoing
			server.msgStore <- outgoing
		}
		close(server.msgStore)
	}
	go relay()

	return server
}
// newBenchmarkServer returns a session manager created with smTestHandler
// whose outgoing messages are fed straight back into its own input by a
// background goroutine. Nothing is recorded, unlike newTestServer.
func newBenchmarkServer() *sessionManager {
	manager := newSessionManager(smTestHandler{})

	loopback := func() {
		for outgoing := range manager.msgOut {
			manager.msgIn <- outgoing
		}
	}
	go loopback()

	return manager
}
// newWebRig starts a TLS test server that serves a loopback session manager
// and returns it together with an HTTP client that trusts the server's
// certificate and keeps cookies in a fresh jar. Any setup failure aborts
// the calling test through t.Fatal.
func newWebRig(t *testing.T) *webRig {
	sm := newBenchmarkServer()
	ws := httptest.NewTLSServer(sm)

	// The client must trust the test server's self-generated certificate.
	cert, err := x509.ParseCertificate(ws.TLS.Certificates[0].Certificate[0])
	if err != nil {
		t.Fatal(err)
	}
	certpool := x509.NewCertPool()
	certpool.AddCert(cert)

	jar, err := cookiejar.New(nil)
	if err != nil {
		t.Fatal(err)
	}

	return &webRig{
		sm: sm,
		ws: ws,
		wc: &http.Client{
			Jar: jar,
			Transport: &http.Transport{
				TLSClientConfig: &tls.Config{
					RootCAs: certpool,
				},
			},
		},
	}
}
// findCookie scans cookies for entries whose Name equals name and returns
// the last one found, or nil when none matches.
func (wr *webRig) findCookie(name string, cookies []*http.Cookie) *http.Cookie {
	var match *http.Cookie
	for i := 0; i < len(cookies); i++ {
		if candidate := cookies[i]; candidate.Name == name {
			match = candidate
		}
	}
	return match
}
// get issues a GET request for url with the rig's client and returns the
// response. A transport error aborts the calling test through t.Fatal.
// The caller is responsible for closing the response body.
func (wr *webRig) get(t *testing.T, url string) *http.Response {
	// Honor the requested URL; previously the argument was ignored and the
	// server root was always fetched.
	res, err := wr.wc.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	return res
}
// setupBenchmark creates b.N sessions on sm, resets the benchmark timer so
// the setup cost is not measured, and returns the session ids in creation
// order.
func setupBenchmark(b *testing.B, sm *sessionManager) []string {
	// The final length is known up front, so allocate the backing array once.
	holder := make([]string, 0, b.N)
	for n := 0; n < b.N; n++ {
		sessid, _ := sm.newSession()
		holder = append(holder, sessid)
	}
	b.ResetTimer()
	return holder
}
// TestGetSessionMessage stores a session with a sessionUpdate message and
// checks that a following sessionRequest message reports the same request id.
func TestGetSessionMessage(t *testing.T) {
	sessid := newID()
	reqid := newID()
	ts := newTestServer()
	newSessMsg := sessionMessage{
		action:    sessionUpdate,
		sessionID: sessid,
		requestID: reqid,
	}
	ts.sm.msgIn <- newSessMsg
	infoChan := make(chan sessionInfo)
	reqMsg := sessionMessage{
		action:    sessionRequest,
		sessionID: sessid,
		infoChan:  infoChan,
	}
	ts.sm.msgIn <- reqMsg
	answer := <-infoChan
	if answer.requestID != reqid {
		// Report the request id that was compared, not the whole struct.
		t.Errorf("For request message: Expected '%s', but got '%s'.", reqid, answer.requestID)
	}
}
// TestGetBadSessionMessage requests a session id that was never stored and
// expects the reply to carry an empty request id.
func TestGetBadSessionMessage(t *testing.T) {
	ts := newTestServer()
	infoChan := make(chan sessionInfo)
	reqMsg := sessionMessage{
		action:    sessionRequest,
		sessionID: "Fred",
		infoChan:  infoChan,
	}
	ts.sm.msgIn <- reqMsg
	result := <-infoChan
	if result.requestID != "" {
		// Report the request id that was compared, not the whole struct.
		t.Errorf("Bad message: Expected '' got '%s'.", result.requestID)
	}
}
// TestDeleteSessionMessage stores a session, deletes it with a
// sessionDelete message, and expects a following sessionRequest to report
// an empty request id.
func TestDeleteSessionMessage(t *testing.T) {
	sessid := newID()
	reqid := newID()
	ts := newTestServer()
	newSessMsg := sessionMessage{
		action:    sessionUpdate,
		sessionID: sessid,
		requestID: reqid,
	}
	ts.sm.msgIn <- newSessMsg
	msg := sessionMessage{
		action:    sessionDelete,
		sessionID: sessid,
	}
	ts.sm.msgIn <- msg
	infoChan := make(chan sessionInfo)
	reqMsg := sessionMessage{
		action:    sessionRequest,
		sessionID: sessid,
		infoChan:  infoChan,
	}
	ts.sm.msgIn <- reqMsg
	result := <-infoChan
	if result.requestID != "" {
		// Report the request id that was compared, not the whole struct.
		t.Errorf("Delete message: Expected '', but got '%s'.", result.requestID)
	}
}
// TestNewSession checks that newSession returns ids of length idLength and
// that a matching sessionUpdate message was emitted on msgOut.
func TestNewSession(t *testing.T) {
	ts := newTestServer()
	sessid, reqid := ts.sm.newSession()
	// Closing msgOut ends the relay goroutine, which then closes msgStore
	// so the range loop below terminates.
	close(ts.sm.msgOut)
	sessLength := len(sessid)
	if sessLength != idLength {
		t.Errorf("Session id was %d characters: Expected %d.", sessLength, idLength)
	}
	requestLength := len(reqid)
	if requestLength != idLength {
		t.Errorf("Request id was %d characters: Expected %d.", requestLength, idLength)
	}
	msg := sessionMessage{
		action:    sessionUpdate,
		sessionID: sessid,
		requestID: reqid,
	}
	found := false
	for readMsg := range ts.msgStore {
		if readMsg == msg {
			found = true
		}
	}
	if !found {
		t.Errorf("No update message: action: %d; session id: %s; request id: %s", msg.action, msg.sessionID, msg.requestID)
	}
}
// TestNewSessionRandomID creates smTestLoopCount sessions and fails if any
// session id or request id has already been seen in an earlier iteration.
func TestNewSessionRandomID(t *testing.T) {
	manager := newBenchmarkServer()
	seen := map[string]bool{}
	for round := 0; round < smTestLoopCount; round++ {
		sessid, reqid := manager.newSession()
		if _, dup := seen[sessid]; dup {
			t.Errorf("Session ID '%s' was repeated.", sessid)
		}
		if _, dup := seen[reqid]; dup {
			t.Errorf("Request ID '%s' was repeated.", reqid)
		}
		seen[sessid] = true
		seen[reqid] = true
	}
}
// TestGetSession checks that getSession returns the request id that
// newSession handed out for the same session.
func TestGetSession(t *testing.T) {
	server := newTestServer()
	sessid, want := server.sm.newSession()
	if got := server.sm.getSession(sessid); want != got {
		t.Errorf("For request id: Expected '%s' but got '%s'.", want, got)
	}
}
// TestNextRequestID checks that nextRequestID returns a request id that
// differs from the original one and that a sessionUpdate message carrying
// the new id was emitted on msgOut.
func TestNextRequestID(t *testing.T) {
	ts := newTestServer()
	sessid, reqid := ts.sm.newSession()
	nreqid := ts.sm.nextRequestID(sessid)
	// Closing msgOut ends the relay goroutine, which then closes msgStore
	// so the range loop below terminates.
	close(ts.sm.msgOut)
	if nreqid == reqid {
		t.Errorf("New request id equaled the old one: '%s'", reqid)
	}
	msg := sessionMessage{
		action:    sessionUpdate,
		sessionID: sessid,
		requestID: nreqid,
	}
	found := false
	for readMsg := range ts.msgStore {
		if readMsg == msg {
			found = true
		}
	}
	if !found {
		t.Errorf("No update id msg: action: %d; session id: %s; request id: %s", msg.action, msg.sessionID, msg.requestID)
	}
}
// TestNextRequestIDisRandom calls nextRequestID smTestLoopCount times on one
// session and fails if any request id, including the initial one, repeats.
func TestNextRequestIDisRandom(t *testing.T) {
	manager := newBenchmarkServer()
	sessid, current := manager.newSession()
	seen := map[string]bool{current: true}
	for round := 0; round < smTestLoopCount; round++ {
		current = manager.nextRequestID(sessid)
		if _, dup := seen[current]; dup {
			t.Errorf("Request ID '%s' was repeated.", current)
		}
		seen[current] = true
	}
}
// TestResetSession checks that resetSession hands out a new session id and
// that both a sessionUpdate message for the new session and a sessionDelete
// message for the old session were emitted on msgOut.
func TestResetSession(t *testing.T) {
	ts := newTestServer()
	sessid, _ := ts.sm.newSession()
	nsessid, nreqid := ts.sm.resetSession(sessid)
	// Closing msgOut ends the relay goroutine, which then closes msgStore
	// so the range loop below terminates.
	close(ts.sm.msgOut)
	if nsessid == sessid {
		t.Errorf("Both old and new session equaled %s", sessid)
	}
	umsg := sessionMessage{
		action:    sessionUpdate,
		sessionID: nsessid,
		requestID: nreqid,
	}
	dmsg := sessionMessage{
		action:    sessionDelete,
		sessionID: sessid,
	}
	ufound := false
	dfound := false
	for readMsg := range ts.msgStore {
		if readMsg == umsg {
			ufound = true
		}
		if readMsg == dmsg {
			dfound = true
		}
	}
	if !ufound {
		t.Errorf("No update msg: action: %d; session id: %s; request id: %s", umsg.action, umsg.sessionID, umsg.requestID)
	}
	if !dfound {
		// Describe the delete message that was expected, not the update one.
		t.Errorf("No delete msg: action: %d; session id: %s", dmsg.action, dmsg.sessionID)
	}
}
// TestSessionAge checks that a fresh session reports a positive age and
// that a second reading is strictly larger than the first.
func TestSessionAge(t *testing.T) {
	manager := newBenchmarkServer()
	sessid, _ := manager.newSession()

	var first time.Duration = manager.sessionAge(sessid)
	if !(first > 0) {
		t.Errorf("Initial session age was %d", first)
	}

	later := manager.sessionAge(sessid)
	if !(later > first) {
		t.Errorf("Session '%s' is not aging: First age: %d, Later age: %d.", sessid, first, later)
	}
}
// TestSessionUpdateAge waits one second, calls nextRequestID, and expects
// the session age read afterwards to be no more than one second.
func TestSessionUpdateAge(t *testing.T) {
	manager := newBenchmarkServer()
	sessid, _ := manager.newSession()
	time.Sleep(time.Second)
	manager.nextRequestID(sessid)

	var observed time.Duration = manager.sessionAge(sessid)
	if time.Second < observed {
		t.Errorf("Age after update was %d, which is greater than a second", observed)
	}
}
// TestSessionRequestAge reads the session age before and after a getSession
// call and fails if the later reading is smaller than the earlier one.
func TestSessionRequestAge(t *testing.T) {
	manager := newBenchmarkServer()
	sessid, _ := manager.newSession()

	var before time.Duration = manager.sessionAge(sessid)
	manager.getSession(sessid)
	var after time.Duration = manager.sessionAge(sessid)

	if before > after {
		t.Errorf("Age before request %d, after %d", before, after)
	}
}
// TestSessionArray creates smTestLoopCount sessions and checks that
// sessionList returns exactly those session ids: the same count, nothing
// unknown, nothing duplicated, and nothing missing.
func TestSessionArray(t *testing.T) {
	sm := newBenchmarkServer()
	sessMap := make(map[string]bool, smTestLoopCount)
	for i := 0; i < smTestLoopCount; i++ {
		sessid, _ := sm.newSession()
		sessMap[sessid] = true
	}
	result := sm.sessionList()
	if len(result) != len(sessMap) {
		t.Errorf("Created %d sessions, but got %d sessions.", len(sessMap), len(result))
	}
	// Track which created sessions were returned so that a duplicate entry
	// cannot hide a missing one.
	returned := make(map[string]bool, len(result))
	for _, sessid := range result {
		if !sessMap[sessid] {
			t.Errorf("Session '%s' was returned but was never created.", sessid)
		}
		if returned[sessid] {
			t.Errorf("Session '%s' was returned more than once.", sessid)
		}
		returned[sessid] = true
	}
	for sessid := range sessMap {
		if !returned[sessid] {
			t.Errorf("Session '%s' was not in the returned slice.", sessid)
		}
	}
}
// TestDefaultCleanUpSettings checks that a new session manager has a maximum
// session age of 24 hours and a check period of one hour.
func TestDefaultCleanUpSettings(t *testing.T) {
	var (
		wantMaxAge = 24 * time.Hour
		wantCheck  = time.Hour
	)
	manager := newBenchmarkServer()
	if got := manager.maxAge; got != wantMaxAge {
		t.Errorf("The default maximum age of a session was %d. It should have been %d.", got, wantMaxAge)
	}
	if got := manager.check; got != wantCheck {
		t.Errorf("The check period was %d. It should have been %d.", got, wantCheck)
	}
}
// TestSetMaxAge checks that setMaxAge stores the given duration in maxAge.
func TestSetMaxAge(t *testing.T) {
	want := time.Hour
	manager := newBenchmarkServer()
	manager.setMaxAge(want)
	if got := manager.maxAge; got != want {
		t.Errorf("Expected maxAge: %d; received %d", want, got)
	}
}
// TestSetCheckTime checks that setCheckTime stores the given duration in
// check.
func TestSetCheckTime(t *testing.T) {
	want := 15 * time.Minute
	manager := newBenchmarkServer()
	manager.setCheckTime(want)
	if got := manager.check; got != want {
		t.Errorf("Expected check: %d; received %d.", want, got)
	}
}
// TestCleanup sets a 500ms check period and a maximum age of four periods,
// then creates two sessions two periods apart. Five periods after the first
// session was created it must be gone while the second must still exist;
// two periods later the second must be gone as well.
func TestCleanup(t *testing.T) {
	period := 500 * time.Millisecond
	manager := newBenchmarkServer()
	manager.setCheckTime(period)
	manager.setMaxAge(4 * period)

	older, _ := manager.newSession()
	time.Sleep(2 * period)
	newer, _ := manager.newSession()
	time.Sleep(3 * period)

	olderReq := manager.getSession(older)
	newerReq := manager.getSession(newer)
	if olderReq != "" {
		t.Error("The first session did not delete after four check periods.")
	}
	if newerReq == "" {
		t.Error("The second session deleted before four check periods.")
	}

	time.Sleep(2 * period)
	if finalReq := manager.getSession(newer); finalReq != "" {
		t.Error("The second session did not delete after four check periods.")
	}
}
// BenchmarkNewSession measures newSession on a loopback session manager.
func BenchmarkNewSession(b *testing.B) {
	manager := newBenchmarkServer()
	for remaining := b.N; remaining > 0; remaining-- {
		manager.newSession()
	}
}
// BenchmarkDeleteSession measures sessionDelete over sessions that were
// created by setupBenchmark before the timer was reset.
func BenchmarkDeleteSession(b *testing.B) {
	manager := newBenchmarkServer()
	sessions := setupBenchmark(b, manager)
	for i := 0; i != b.N; i++ {
		sessid := sessions[i]
		manager.sessionDelete(sessid)
	}
}
// BenchmarkGetSession measures getSession over sessions that were created
// by setupBenchmark before the timer was reset.
func BenchmarkGetSession(b *testing.B) {
	manager := newBenchmarkServer()
	sessions := setupBenchmark(b, manager)
	for i := 0; i != b.N; i++ {
		sessid := sessions[i]
		manager.getSession(sessid)
	}
}
// BenchmarkNextRequestID measures nextRequestID over sessions that were
// created by setupBenchmark before the timer was reset.
func BenchmarkNextRequestID(b *testing.B) {
	manager := newBenchmarkServer()
	sessions := setupBenchmark(b, manager)
	for i := 0; i != b.N; i++ {
		sessid := sessions[i]
		manager.nextRequestID(sessid)
	}
}
// BenchmarkResetSession measures resetSession over sessions that were
// created by setupBenchmark before the timer was reset.
func BenchmarkResetSession(b *testing.B) {
	manager := newBenchmarkServer()
	sessions := setupBenchmark(b, manager)
	for i := 0; i != b.N; i++ {
		sessid := sessions[i]
		manager.resetSession(sessid)
	}
}
// BenchmarkSessionList measures sessionList on a manager that setupBenchmark
// has populated with b.N sessions.
func BenchmarkSessionList(b *testing.B) {
	manager := newBenchmarkServer()
	_ = setupBenchmark(b, manager)
	for remaining := b.N; remaining > 0; remaining-- {
		manager.sessionList()
	}
}
// TestSetSessionCookie checks that the first response sets a session cookie
// whose value has length idLength, whose path is "/", and which is flagged
// HttpOnly and Secure.
func TestSetSessionCookie(t *testing.T) {
	wr := newWebRig(t)
	defer wr.ws.Close()
	res := wr.get(t, wr.ws.URL)
	// Release the connection; the body itself is not inspected here.
	defer res.Body.Close()
	session := wr.findCookie(sname, res.Cookies())
	if session == nil {
		t.Fatal("No session cookie was created.")
	}
	if len(session.Value) != idLength {
		t.Errorf("Session id was '%s', and it should have been a genid", session.Value)
	}
	if session.Path != "/" {
		t.Errorf("Session path was '%s', and it should have been '/'", session.Path)
	}
	if !session.HttpOnly {
		t.Error("Session cookie should be http only")
	}
	if !session.Secure {
		t.Error("Session cookie should be secure.")
	}
}
// TestSetRequestCookie checks that the first response sets a request cookie
// whose value has length idLength, whose path is "/", and which is flagged
// HttpOnly and Secure.
func TestSetRequestCookie(t *testing.T) {
	wr := newWebRig(t)
	defer wr.ws.Close()
	res := wr.get(t, wr.ws.URL)
	// Release the connection; the body itself is not inspected here.
	defer res.Body.Close()
	request := wr.findCookie(rname, res.Cookies())
	if request == nil {
		t.Fatal("No request cookie was created")
	}
	if len(request.Value) != idLength {
		t.Errorf("Request id was '%s', and it should have been a genid", request.Value)
	}
	if request.Path != "/" {
		t.Errorf("Request path was '%s', and it should have been '/'", request.Path)
	}
	if !request.HttpOnly {
		t.Error("Request cookie should be http only")
	}
	if !request.Secure {
		t.Error("Request cookie should be secure.")
	}
}
// TestSessionStorage checks that the session manager stores, for the session
// id sent in the session cookie, the request id sent in the request cookie.
func TestSessionStorage(t *testing.T) {
	wr := newWebRig(t)
	defer wr.ws.Close()
	res := wr.get(t, wr.ws.URL)
	defer res.Body.Close()
	sessCookie := wr.findCookie(sname, res.Cookies())
	reqCookie := wr.findCookie(rname, res.Cookies())
	// Fail cleanly instead of panicking on a nil cookie.
	if sessCookie == nil || reqCookie == nil {
		t.Fatal("The response did not carry both a session and a request cookie.")
	}
	sessid := sessCookie.Value
	reqid := reqCookie.Value
	result := wr.sm.getSession(sessid)
	if result != reqid {
		t.Errorf("For session '%s', got request '%s', expected '%s'", sessid, result, reqid)
	}
}
// TestSessionCookieSentOnce issues two requests with the same cookie jar and
// checks that the second response does not set the session cookie again.
func TestSessionCookieSentOnce(t *testing.T) {
	wr := newWebRig(t)
	defer wr.ws.Close()
	// Only the side effect of the first request (cookies stored in the jar)
	// matters, so its body is closed right away.
	wr.get(t, wr.ws.URL).Body.Close()
	res := wr.get(t, wr.ws.URL)
	defer res.Body.Close()
	result := wr.findCookie(sname, res.Cookies())
	if result != nil {
		t.Error("Session cookie should only be sent once.")
	}
}
// TestRequestCookieAlwaysUpdated issues two requests with the same cookie
// jar and checks that the second response carries a different request id,
// and that the session manager stores that newer id for the session.
func TestRequestCookieAlwaysUpdated(t *testing.T) {
	wr := newWebRig(t)
	defer wr.ws.Close()

	firstPass := wr.get(t, wr.ws.URL)
	defer firstPass.Body.Close()
	sessCookie := wr.findCookie(sname, firstPass.Cookies())
	reqCookie1 := wr.findCookie(rname, firstPass.Cookies())
	// Fail cleanly instead of panicking on a nil cookie.
	if sessCookie == nil || reqCookie1 == nil {
		t.Fatal("The first response did not carry both a session and a request cookie.")
	}
	sessid := sessCookie.Value
	reqid1 := reqCookie1.Value

	secondPass := wr.get(t, wr.ws.URL)
	defer secondPass.Body.Close()
	reqCookie2 := wr.findCookie(rname, secondPass.Cookies())
	if reqCookie2 == nil {
		t.Fatal("The second response did not carry a request cookie.")
	}
	reqid2 := reqCookie2.Value

	result := wr.sm.getSession(sessid)
	if reqid2 == reqid1 {
		t.Error("Request id did not change between sessions.")
	}
	if result != reqid2 {
		t.Errorf("For second pass session '%s', got request '%s', expected '%s'", sessid, result, reqid2)
	}
}
// TestMissingRequestCookie replays the cookies of a first response without
// the request cookie and checks that the second response sets a session
// cookie with a different value.
func TestMissingRequestCookie(t *testing.T) {
	wr := newWebRig(t)
	defer wr.ws.Close()
	res1 := wr.get(t, wr.ws.URL)
	defer res1.Body.Close()

	// Keep every cookie except the request cookie.
	var cookies []*http.Cookie
	for _, cookie := range res1.Cookies() {
		if cookie.Name != rname {
			cookies = append(cookies, cookie)
		}
	}
	jar, err := cookiejar.New(nil)
	if err != nil {
		t.Fatal(err)
	}
	u, err := url.Parse(wr.ws.URL)
	if err != nil {
		t.Fatal(err)
	}
	jar.SetCookies(u, cookies)
	wr.wc.Jar = jar

	res2 := wr.get(t, wr.ws.URL)
	defer res2.Body.Close()
	before := wr.findCookie(sname, res1.Cookies())
	after := wr.findCookie(sname, res2.Cookies())
	// Fail cleanly instead of panicking on a nil cookie.
	if before == nil {
		t.Fatal("The first response did not carry a session cookie.")
	}
	if after == nil {
		t.Fatal("The second response did not carry a session cookie.")
	}
	if before.Value == after.Value {
		t.Error("Session did not reset when the request cookie was missing.")
	}
}
// TestBadRequestCookie replays the cookies of a first response with the
// request cookie's value replaced by "bad id". It checks that the second
// response sets a session cookie with a different value and that the
// original session is no longer known to the session manager.
func TestBadRequestCookie(t *testing.T) {
	wr := newWebRig(t)
	defer wr.ws.Close()
	res1 := wr.get(t, wr.ws.URL)
	defer res1.Body.Close()

	// Keep every cookie but corrupt the request cookie's value.
	var cookies []*http.Cookie
	for _, cookie := range res1.Cookies() {
		if cookie.Name == rname {
			cookie.Value = "bad id"
		}
		cookies = append(cookies, cookie)
	}
	jar, err := cookiejar.New(nil)
	if err != nil {
		t.Fatal(err)
	}
	u, err := url.Parse(wr.ws.URL)
	if err != nil {
		t.Fatal(err)
	}
	jar.SetCookies(u, cookies)
	wr.wc.Jar = jar

	res2 := wr.get(t, wr.ws.URL)
	defer res2.Body.Close()
	before := wr.findCookie(sname, res1.Cookies())
	after := wr.findCookie(sname, res2.Cookies())
	// Fail cleanly instead of panicking on a nil cookie.
	if before == nil {
		t.Fatal("The first response did not carry a session cookie.")
	}
	if after == nil {
		t.Fatal("The second response did not carry a session cookie.")
	}
	if before.Value == after.Value {
		t.Error("Session did not reset when the request cookie was bad.")
	}
	if wr.sm.getSession(before.Value) != "" {
		t.Error("The bad session did not get deleted.")
	}
}
// TestPagePassThrough checks that the body written by the wrapped handler
// reaches the client unchanged.
func TestPagePassThrough(t *testing.T) {
	wr := newWebRig(t)
	defer wr.ws.Close()
	res := wr.get(t, wr.ws.URL)
	defer res.Body.Close()
	data, err := ioutil.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	holder := string(data)
	if holder != bodyString {
		t.Errorf("Expected the body to be '%s', but got '%s'.", bodyString, holder)
	}
}