service_provider.go 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683
  1. package saml
  2. import (
  3. "bytes"
  4. "compress/flate"
  5. "crypto/rsa"
  6. "crypto/x509"
  7. "encoding/base64"
  8. "encoding/xml"
  9. "errors"
  10. "fmt"
  11. "html/template"
  12. "net/http"
  13. "net/url"
  14. "regexp"
  15. "time"
  16. "github.com/beevik/etree"
  17. "github.com/crewjam/saml/logger"
  18. "github.com/crewjam/saml/xmlenc"
  19. dsig "github.com/russellhaering/goxmldsig"
  20. "github.com/russellhaering/goxmldsig/etreeutils"
  21. )
  22. // NameIDFormat is the format of the id
  23. type NameIDFormat string
  24. // Element returns an XML element representation of n.
  25. func (n NameIDFormat) Element() *etree.Element {
  26. el := etree.NewElement("")
  27. el.SetText(string(n))
  28. return el
  29. }
  30. // Name ID formats
  31. const (
  32. UnspecifiedNameIDFormat NameIDFormat = "urn:oasis:names:tc:SAML:2.0:nameid-format:unspecified"
  33. TransientNameIDFormat NameIDFormat = "urn:oasis:names:tc:SAML:2.0:nameid-format:transient"
  34. EmailAddressNameIDFormat NameIDFormat = "urn:oasis:names:tc:SAML:2.0:nameid-format:emailAddress"
  35. PersistentNameIDFormat NameIDFormat = "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent"
  36. )
// ServiceProvider implements a SAML Service Provider (SP).
//
// In SAML, service providers delegate responsibility for identifying
// clients to an identity provider. If you are writing an application
// that uses passwords (or whatever) stored somewhere else, then you
// are a service provider.
//
// See the example directory for an example of a web application using
// the service provider interface.
type ServiceProvider struct {
	// Key is the RSA private key whose public part is Certificate.
	// ParseResponse uses it to decrypt encrypted assertions.
	// NOTE(review): no code in this file signs anything with Key; confirm
	// whether it is used for signing elsewhere.
	Key *rsa.PrivateKey

	// Certificate is the RSA public part of Key. Metadata publishes it for
	// both "signing" and "encryption" and dereferences it without a nil check.
	Certificate *x509.Certificate

	// MetadataURL is the full URL to the metadata endpoint on this host,
	// i.e. https://example.com/saml/metadata. It also serves as this SP's
	// entity ID: it is the EntityID in Metadata, the Issuer of authentication
	// requests, and the Audience expected in assertion conditions.
	MetadataURL url.URL

	// AcsURL is the full URL to the SAML Assertion Consumer Service endpoint
	// on this host, i.e. https://example.com/saml/acs. Responses must name it
	// as their Destination and as the SubjectConfirmation Recipient.
	AcsURL url.URL

	// IDPMetadata is the metadata from the identity provider. It supplies the
	// IDP's SSO endpoints, its signing certificates and its expected Issuer
	// (EntityID). Several methods in this file dereference it without a nil
	// check.
	IDPMetadata *EntityDescriptor

	// AuthnNameIDFormat is the format used in the NameIDPolicy for
	// authentication requests. When empty, TransientNameIDFormat is used;
	// when UnspecifiedNameIDFormat, no specific format is requested.
	AuthnNameIDFormat NameIDFormat

	// MetadataValidDuration is a duration used to calculate the validUntil
	// attribute in the metadata endpoint. Values <= 0 select
	// DefaultValidDuration.
	MetadataValidDuration time.Duration

	// Logger is used to log messages, for example in the event of errors.
	Logger logger.Interface

	// ForceAuthn allows you to force re-authentication of users even if the user
	// has an SSO session at the IdP. It is copied as-is into each AuthnRequest.
	ForceAuthn *bool
}
  71. // MaxIssueDelay is the longest allowed time between when a SAML assertion is
  72. // issued by the IDP and the time it is received by ParseResponse. This is used
  73. // to prevent old responses from being replayed (while allowing for some clock
  74. // drift between the SP and IDP).
  75. var MaxIssueDelay = time.Second * 90
  76. // MaxClockSkew allows for leeway for clock skew between the IDP and SP when
  77. // validating assertions. It defaults to 180 seconds (matches shibboleth).
  78. var MaxClockSkew = time.Second * 180
  79. // DefaultValidDuration is how long we assert that the SP metadata is valid.
  80. const DefaultValidDuration = time.Hour * 24 * 2
  81. // DefaultCacheDuration is how long we ask the IDP to cache the SP metadata.
  82. const DefaultCacheDuration = time.Hour * 24 * 1
  83. // Metadata returns the service provider metadata
  84. func (sp *ServiceProvider) Metadata() *EntityDescriptor {
  85. validDuration := DefaultValidDuration
  86. if sp.MetadataValidDuration > 0 {
  87. validDuration = sp.MetadataValidDuration
  88. }
  89. authnRequestsSigned := false
  90. wantAssertionsSigned := true
  91. validUntil := TimeNow().Add(validDuration)
  92. return &EntityDescriptor{
  93. EntityID: sp.MetadataURL.String(),
  94. ValidUntil: validUntil,
  95. SPSSODescriptors: []SPSSODescriptor{
  96. {
  97. SSODescriptor: SSODescriptor{
  98. RoleDescriptor: RoleDescriptor{
  99. ProtocolSupportEnumeration: "urn:oasis:names:tc:SAML:2.0:protocol",
  100. KeyDescriptors: []KeyDescriptor{
  101. {
  102. Use: "signing",
  103. KeyInfo: KeyInfo{
  104. Certificate: base64.StdEncoding.EncodeToString(sp.Certificate.Raw),
  105. },
  106. },
  107. {
  108. Use: "encryption",
  109. KeyInfo: KeyInfo{
  110. Certificate: base64.StdEncoding.EncodeToString(sp.Certificate.Raw),
  111. },
  112. EncryptionMethods: []EncryptionMethod{
  113. {Algorithm: "http://www.w3.org/2001/04/xmlenc#aes128-cbc"},
  114. {Algorithm: "http://www.w3.org/2001/04/xmlenc#aes192-cbc"},
  115. {Algorithm: "http://www.w3.org/2001/04/xmlenc#aes256-cbc"},
  116. {Algorithm: "http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p"},
  117. },
  118. },
  119. },
  120. ValidUntil: validUntil,
  121. },
  122. },
  123. AuthnRequestsSigned: &authnRequestsSigned,
  124. WantAssertionsSigned: &wantAssertionsSigned,
  125. AssertionConsumerServices: []IndexedEndpoint{
  126. {
  127. Binding: HTTPPostBinding,
  128. Location: sp.AcsURL.String(),
  129. Index: 1,
  130. },
  131. },
  132. },
  133. },
  134. }
  135. }
  136. // MakeRedirectAuthenticationRequest creates a SAML authentication request using
  137. // the HTTP-Redirect binding. It returns a URL that we will redirect the user to
  138. // in order to start the auth process.
  139. func (sp *ServiceProvider) MakeRedirectAuthenticationRequest(relayState string) (*url.URL, error) {
  140. req, err := sp.MakeAuthenticationRequest(sp.GetSSOBindingLocation(HTTPRedirectBinding))
  141. if err != nil {
  142. return nil, err
  143. }
  144. return req.Redirect(relayState), nil
  145. }
  146. // Redirect returns a URL suitable for using the redirect binding with the request
  147. func (req *AuthnRequest) Redirect(relayState string) *url.URL {
  148. w := &bytes.Buffer{}
  149. w1 := base64.NewEncoder(base64.StdEncoding, w)
  150. w2, _ := flate.NewWriter(w1, 9)
  151. doc := etree.NewDocument()
  152. doc.SetRoot(req.Element())
  153. if _, err := doc.WriteTo(w2); err != nil {
  154. panic(err)
  155. }
  156. w2.Close()
  157. w1.Close()
  158. rv, _ := url.Parse(req.Destination)
  159. query := rv.Query()
  160. query.Set("SAMLRequest", string(w.Bytes()))
  161. if relayState != "" {
  162. query.Set("RelayState", relayState)
  163. }
  164. rv.RawQuery = query.Encode()
  165. return rv
  166. }
  167. // GetSSOBindingLocation returns URL for the IDP's Single Sign On Service binding
  168. // of the specified type (HTTPRedirectBinding or HTTPPostBinding)
  169. func (sp *ServiceProvider) GetSSOBindingLocation(binding string) string {
  170. for _, idpSSODescriptor := range sp.IDPMetadata.IDPSSODescriptors {
  171. for _, singleSignOnService := range idpSSODescriptor.SingleSignOnServices {
  172. if singleSignOnService.Binding == binding {
  173. return singleSignOnService.Location
  174. }
  175. }
  176. }
  177. return ""
  178. }
  179. // getIDPSigningCerts returns the certificates which we can use to verify things
  180. // signed by the IDP in PEM format, or nil if no such certificate is found.
  181. func (sp *ServiceProvider) getIDPSigningCerts() ([]*x509.Certificate, error) {
  182. var certStrs []string
  183. for _, idpSSODescriptor := range sp.IDPMetadata.IDPSSODescriptors {
  184. for _, keyDescriptor := range idpSSODescriptor.KeyDescriptors {
  185. if keyDescriptor.Use == "signing" {
  186. certStrs = append(certStrs, keyDescriptor.KeyInfo.Certificate)
  187. }
  188. }
  189. }
  190. // If there are no explicitly signing certs, just return the first
  191. // non-empty cert we find.
  192. if len(certStrs) == 0 {
  193. for _, idpSSODescriptor := range sp.IDPMetadata.IDPSSODescriptors {
  194. for _, keyDescriptor := range idpSSODescriptor.KeyDescriptors {
  195. if keyDescriptor.Use == "" && keyDescriptor.KeyInfo.Certificate != "" {
  196. certStrs = append(certStrs, keyDescriptor.KeyInfo.Certificate)
  197. break
  198. }
  199. }
  200. }
  201. }
  202. if len(certStrs) == 0 {
  203. return nil, errors.New("cannot find any signing certificate in the IDP SSO descriptor")
  204. }
  205. var certs []*x509.Certificate
  206. // cleanup whitespace
  207. regex := regexp.MustCompile(`\s+`)
  208. for _, certStr := range certStrs {
  209. certStr = regex.ReplaceAllString(certStr, "")
  210. certBytes, err := base64.StdEncoding.DecodeString(certStr)
  211. if err != nil {
  212. return nil, fmt.Errorf("cannot parse certificate: %s", err)
  213. }
  214. parsedCert, err := x509.ParseCertificate(certBytes)
  215. if err != nil {
  216. return nil, err
  217. }
  218. certs = append(certs, parsedCert)
  219. }
  220. return certs, nil
  221. }
  222. // MakeAuthenticationRequest produces a new AuthnRequest object for idpURL.
  223. func (sp *ServiceProvider) MakeAuthenticationRequest(idpURL string) (*AuthnRequest, error) {
  224. var nameIDFormat string
  225. switch sp.AuthnNameIDFormat {
  226. case "":
  227. // To maintain library back-compat, use "transient" if unset.
  228. nameIDFormat = string(TransientNameIDFormat)
  229. case UnspecifiedNameIDFormat:
  230. // Spec defines an empty value as "unspecified" so don't set one.
  231. default:
  232. nameIDFormat = string(sp.AuthnNameIDFormat)
  233. }
  234. allowCreate := true
  235. req := AuthnRequest{
  236. AssertionConsumerServiceURL: sp.AcsURL.String(),
  237. Destination: idpURL,
  238. ProtocolBinding: HTTPPostBinding, // default binding for the response
  239. ID: fmt.Sprintf("id-%x", randomBytes(20)),
  240. IssueInstant: TimeNow(),
  241. Version: "2.0",
  242. Issuer: &Issuer{
  243. Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity",
  244. Value: sp.MetadataURL.String(),
  245. },
  246. NameIDPolicy: &NameIDPolicy{
  247. AllowCreate: &allowCreate,
  248. // TODO(ross): figure out exactly policy we need
  249. // urn:mace:shibboleth:1.0:nameIdentifier
  250. // urn:oasis:names:tc:SAML:2.0:nameid-format:transient
  251. Format: &nameIDFormat,
  252. },
  253. ForceAuthn: sp.ForceAuthn,
  254. }
  255. return &req, nil
  256. }
  257. // MakePostAuthenticationRequest creates a SAML authentication request using
  258. // the HTTP-POST binding. It returns HTML text representing an HTML form that
  259. // can be sent presented to a browser to initiate the login process.
  260. func (sp *ServiceProvider) MakePostAuthenticationRequest(relayState string) ([]byte, error) {
  261. req, err := sp.MakeAuthenticationRequest(sp.GetSSOBindingLocation(HTTPPostBinding))
  262. if err != nil {
  263. return nil, err
  264. }
  265. return req.Post(relayState), nil
  266. }
  267. // Post returns an HTML form suitable for using the HTTP-POST binding with the request
  268. func (req *AuthnRequest) Post(relayState string) []byte {
  269. doc := etree.NewDocument()
  270. doc.SetRoot(req.Element())
  271. reqBuf, err := doc.WriteToBytes()
  272. if err != nil {
  273. panic(err)
  274. }
  275. encodedReqBuf := base64.StdEncoding.EncodeToString(reqBuf)
  276. tmpl := template.Must(template.New("saml-post-form").Parse(`` +
  277. `<form method="post" action="{{.URL}}" id="SAMLRequestForm">` +
  278. `<input type="hidden" name="SAMLRequest" value="{{.SAMLRequest}}" />` +
  279. `<input type="hidden" name="RelayState" value="{{.RelayState}}" />` +
  280. `<input id="SAMLSubmitButton" type="submit" value="Submit" />` +
  281. `</form>` +
  282. `<script>document.getElementById('SAMLSubmitButton').style.visibility="hidden";` +
  283. `document.getElementById('SAMLRequestForm').submit();</script>`))
  284. data := struct {
  285. URL string
  286. SAMLRequest string
  287. RelayState string
  288. }{
  289. URL: req.Destination,
  290. SAMLRequest: encodedReqBuf,
  291. RelayState: relayState,
  292. }
  293. rv := bytes.Buffer{}
  294. if err := tmpl.Execute(&rv, data); err != nil {
  295. panic(err)
  296. }
  297. return rv.Bytes()
  298. }
  299. // AssertionAttributes is a list of AssertionAttribute
  300. type AssertionAttributes []AssertionAttribute
  301. // Get returns the assertion attribute whose Name or FriendlyName
  302. // matches name, or nil if no matching attribute is found.
  303. func (aa AssertionAttributes) Get(name string) *AssertionAttribute {
  304. for _, attr := range aa {
  305. if attr.Name == name {
  306. return &attr
  307. }
  308. if attr.FriendlyName == name {
  309. return &attr
  310. }
  311. }
  312. return nil
  313. }
  314. // AssertionAttribute represents an attribute of the user extracted from
  315. // a SAML Assertion.
  316. type AssertionAttribute struct {
  317. FriendlyName string
  318. Name string
  319. Value string
  320. }
  321. // InvalidResponseError is the error produced by ParseResponse when it fails.
  322. // The underlying error is in PrivateErr. Response is the response as it was
  323. // known at the time validation failed. Now is the time that was used to validate
  324. // time-dependent parts of the assertion.
  325. type InvalidResponseError struct {
  326. PrivateErr error
  327. Response string
  328. Now time.Time
  329. }
  330. func (ivr *InvalidResponseError) Error() string {
  331. return fmt.Sprintf("Authentication failed")
  332. }
  333. // ParseResponse extracts the SAML IDP response received in req, validates
  334. // it, and returns the verified attributes of the request.
  335. //
  336. // This function handles decrypting the message, verifying the digital
  337. // signature on the assertion, and verifying that the specified conditions
  338. // and properties are met.
  339. //
  340. // If the function fails it will return an InvalidResponseError whose
  341. // properties are useful in describing which part of the parsing process
  342. // failed. However, to discourage inadvertent disclosure the diagnostic
  343. // information, the Error() method returns a static string.
  344. func (sp *ServiceProvider) ParseResponse(req *http.Request, possibleRequestIDs []string) (*Assertion, error) {
  345. now := TimeNow()
  346. retErr := &InvalidResponseError{
  347. Now: now,
  348. Response: req.PostForm.Get("SAMLResponse"),
  349. }
  350. rawResponseBuf, err := base64.StdEncoding.DecodeString(req.PostForm.Get("SAMLResponse"))
  351. if err != nil {
  352. retErr.PrivateErr = fmt.Errorf("cannot parse base64: %s", err)
  353. return nil, retErr
  354. }
  355. retErr.Response = string(rawResponseBuf)
  356. // do some validation first before we decrypt
  357. resp := Response{}
  358. if err := xml.Unmarshal(rawResponseBuf, &resp); err != nil {
  359. retErr.PrivateErr = fmt.Errorf("cannot unmarshal response: %s", err)
  360. return nil, retErr
  361. }
  362. if resp.Destination != sp.AcsURL.String() {
  363. retErr.PrivateErr = fmt.Errorf("`Destination` does not match AcsURL (expected %q)", sp.AcsURL.String())
  364. return nil, retErr
  365. }
  366. requestIDvalid := false
  367. for _, possibleRequestID := range possibleRequestIDs {
  368. if resp.InResponseTo == possibleRequestID {
  369. requestIDvalid = true
  370. }
  371. }
  372. if !requestIDvalid {
  373. retErr.PrivateErr = fmt.Errorf("`InResponseTo` does not match any of the possible request IDs (expected %v)", possibleRequestIDs)
  374. return nil, retErr
  375. }
  376. if resp.IssueInstant.Add(MaxIssueDelay).Before(now) {
  377. retErr.PrivateErr = fmt.Errorf("IssueInstant expired at %s", resp.IssueInstant.Add(MaxIssueDelay))
  378. return nil, retErr
  379. }
  380. if resp.Issuer.Value != sp.IDPMetadata.EntityID {
  381. retErr.PrivateErr = fmt.Errorf("Issuer does not match the IDP metadata (expected %q)", sp.IDPMetadata.EntityID)
  382. return nil, retErr
  383. }
  384. if resp.Status.StatusCode.Value != StatusSuccess {
  385. retErr.PrivateErr = fmt.Errorf("Status code was not %s", StatusSuccess)
  386. return nil, retErr
  387. }
  388. var assertion *Assertion
  389. if resp.EncryptedAssertion == nil {
  390. doc := etree.NewDocument()
  391. if err := doc.ReadFromBytes(rawResponseBuf); err != nil {
  392. retErr.PrivateErr = err
  393. return nil, retErr
  394. }
  395. // TODO(ross): verify that the namespace is urn:oasis:names:tc:SAML:2.0:protocol
  396. responseEl := doc.Root()
  397. if responseEl.Tag != "Response" {
  398. retErr.PrivateErr = fmt.Errorf("expected to find a response object, not %s", doc.Root().Tag)
  399. return nil, retErr
  400. }
  401. if err = sp.validateSigned(responseEl); err != nil {
  402. retErr.PrivateErr = err
  403. return nil, retErr
  404. }
  405. assertion = resp.Assertion
  406. }
  407. // decrypt the response
  408. if resp.EncryptedAssertion != nil {
  409. doc := etree.NewDocument()
  410. if err := doc.ReadFromBytes(rawResponseBuf); err != nil {
  411. retErr.PrivateErr = err
  412. return nil, retErr
  413. }
  414. var key interface{} = sp.Key
  415. keyEl := doc.FindElement("//EncryptedAssertion/EncryptedKey")
  416. if keyEl != nil {
  417. key, err = xmlenc.Decrypt(sp.Key, keyEl)
  418. if err != nil {
  419. retErr.PrivateErr = fmt.Errorf("failed to decrypt key from response: %s", err)
  420. return nil, retErr
  421. }
  422. }
  423. el := doc.FindElement("//EncryptedAssertion/EncryptedData")
  424. plaintextAssertion, err := xmlenc.Decrypt(key, el)
  425. if err != nil {
  426. retErr.PrivateErr = fmt.Errorf("failed to decrypt response: %s", err)
  427. return nil, retErr
  428. }
  429. retErr.Response = string(plaintextAssertion)
  430. doc = etree.NewDocument()
  431. if err := doc.ReadFromBytes(plaintextAssertion); err != nil {
  432. retErr.PrivateErr = fmt.Errorf("cannot parse plaintext response %v", err)
  433. return nil, retErr
  434. }
  435. if err := sp.validateSigned(doc.Root()); err != nil {
  436. retErr.PrivateErr = err
  437. return nil, retErr
  438. }
  439. assertion = &Assertion{}
  440. if err := xml.Unmarshal(plaintextAssertion, assertion); err != nil {
  441. retErr.PrivateErr = err
  442. return nil, retErr
  443. }
  444. }
  445. if err := sp.validateAssertion(assertion, possibleRequestIDs, now); err != nil {
  446. retErr.PrivateErr = fmt.Errorf("assertion invalid: %s", err)
  447. return nil, retErr
  448. }
  449. return assertion, nil
  450. }
  451. // validateAssertion checks that the conditions specified in assertion match
  452. // the requirements to accept. If validation fails, it returns an error describing
  453. // the failure. (The digital signature on the assertion is not checked -- this
  454. // should be done before calling this function).
  455. func (sp *ServiceProvider) validateAssertion(assertion *Assertion, possibleRequestIDs []string, now time.Time) error {
  456. if assertion.IssueInstant.Add(MaxIssueDelay).Before(now) {
  457. return fmt.Errorf("expired on %s", assertion.IssueInstant.Add(MaxIssueDelay))
  458. }
  459. if assertion.Issuer.Value != sp.IDPMetadata.EntityID {
  460. return fmt.Errorf("issuer is not %q", sp.IDPMetadata.EntityID)
  461. }
  462. for _, subjectConfirmation := range assertion.Subject.SubjectConfirmations {
  463. requestIDvalid := false
  464. for _, possibleRequestID := range possibleRequestIDs {
  465. if subjectConfirmation.SubjectConfirmationData.InResponseTo == possibleRequestID {
  466. requestIDvalid = true
  467. break
  468. }
  469. }
  470. if !requestIDvalid {
  471. return fmt.Errorf("SubjectConfirmation one of the possible request IDs (%v)", possibleRequestIDs)
  472. }
  473. if subjectConfirmation.SubjectConfirmationData.Recipient != sp.AcsURL.String() {
  474. return fmt.Errorf("SubjectConfirmation Recipient is not %s", sp.AcsURL.String())
  475. }
  476. if subjectConfirmation.SubjectConfirmationData.NotOnOrAfter.Add(MaxClockSkew).Before(now) {
  477. return fmt.Errorf("SubjectConfirmationData is expired")
  478. }
  479. }
  480. if assertion.Conditions.NotBefore.Add(-MaxClockSkew).After(now) {
  481. return fmt.Errorf("Conditions is not yet valid")
  482. }
  483. if assertion.Conditions.NotOnOrAfter.Add(MaxClockSkew).Before(now) {
  484. return fmt.Errorf("Conditions is expired")
  485. }
  486. audienceRestrictionsValid := len(assertion.Conditions.AudienceRestrictions) == 0
  487. for _, audienceRestriction := range assertion.Conditions.AudienceRestrictions {
  488. if audienceRestriction.Audience.Value == sp.MetadataURL.String() {
  489. audienceRestrictionsValid = true
  490. }
  491. }
  492. if !audienceRestrictionsValid {
  493. return fmt.Errorf("Conditions AudienceRestriction does not contain %q", sp.MetadataURL.String())
  494. }
  495. return nil
  496. }
  497. func findChild(parentEl *etree.Element, childNS string, childTag string) (*etree.Element, error) {
  498. for _, childEl := range parentEl.ChildElements() {
  499. if childEl.Tag != childTag {
  500. continue
  501. }
  502. ctx, err := etreeutils.NSBuildParentContext(childEl)
  503. if err != nil {
  504. return nil, err
  505. }
  506. ctx, err = ctx.SubContext(childEl)
  507. if err != nil {
  508. return nil, err
  509. }
  510. ns, err := ctx.LookupPrefix(childEl.Space)
  511. if err != nil {
  512. return nil, fmt.Errorf("[%s]:%s cannot find prefix %s: %v", childNS, childTag, childEl.Space, err)
  513. }
  514. if ns != childNS {
  515. continue
  516. }
  517. return childEl, nil
  518. }
  519. return nil, nil
  520. }
  521. // validateSigned returns a nil error iff each of the signatures on the Response and Assertion elements
  522. // are valid and there is at least one signature.
  523. func (sp *ServiceProvider) validateSigned(responseEl *etree.Element) error {
  524. haveSignature := false
  525. // Some SAML responses have the signature on the Response object, and some on the Assertion
  526. // object, and some on both. We will require that at least one signature be present and that
  527. // all signatures be valid
  528. sigEl, err := findChild(responseEl, "http://www.w3.org/2000/09/xmldsig#", "Signature")
  529. if err != nil {
  530. return err
  531. }
  532. if sigEl != nil {
  533. if err = sp.validateSignature(responseEl); err != nil {
  534. return fmt.Errorf("cannot validate signature on Response: %v", err)
  535. }
  536. haveSignature = true
  537. }
  538. assertionEl, err := findChild(responseEl, "urn:oasis:names:tc:SAML:2.0:assertion", "Assertion")
  539. if err != nil {
  540. return err
  541. }
  542. if assertionEl != nil {
  543. sigEl, err := findChild(assertionEl, "http://www.w3.org/2000/09/xmldsig#", "Signature")
  544. if err != nil {
  545. return err
  546. }
  547. if sigEl != nil {
  548. if err = sp.validateSignature(assertionEl); err != nil {
  549. return fmt.Errorf("cannot validate signature on Response: %v", err)
  550. }
  551. haveSignature = true
  552. }
  553. }
  554. if !haveSignature {
  555. return errors.New("either the Response or Assertion must be signed")
  556. }
  557. return nil
  558. }
  559. // validateSignature returns nill iff the Signature embedded in the element is valid
  560. func (sp *ServiceProvider) validateSignature(el *etree.Element) error {
  561. certs, err := sp.getIDPSigningCerts()
  562. if err != nil {
  563. return err
  564. }
  565. certificateStore := dsig.MemoryX509CertificateStore{
  566. Roots: certs,
  567. }
  568. validationContext := dsig.NewDefaultValidationContext(&certificateStore)
  569. validationContext.IdAttribute = "ID"
  570. if Clock != nil {
  571. validationContext.Clock = Clock
  572. }
  573. // Some SAML responses contain a RSAKeyValue element. One of two things is happening here:
  574. //
  575. // (1) We're getting something signed by a key we already know about -- the public key
  576. // of the signing cert provided in the metadata.
  577. // (2) We're getting something signed by a key we *don't* know about, and which we have
  578. // no ability to verify.
  579. //
  580. // The best course of action is to just remove the KeyInfo so that dsig falls back to
  581. // verifying against the public key provided in the metadata.
  582. if el.FindElement("./Signature/KeyInfo/X509Data/X509Certificate") == nil {
  583. if sigEl := el.FindElement("./Signature"); sigEl != nil {
  584. if keyInfo := sigEl.FindElement("KeyInfo"); keyInfo != nil {
  585. sigEl.RemoveChild(keyInfo)
  586. }
  587. }
  588. }
  589. ctx, err := etreeutils.NSBuildParentContext(el)
  590. if err != nil {
  591. return err
  592. }
  593. ctx, err = ctx.SubContext(el)
  594. if err != nil {
  595. return err
  596. }
  597. el, err = etreeutils.NSDetatch(ctx, el)
  598. if err != nil {
  599. return err
  600. }
  601. _, err = validationContext.Validate(el)
  602. return err
  603. }